mirror of
https://github.com/soconnor0919/hristudio.git
synced 2026-03-23 19:27:51 -04:00
feat: implement WebSocket for real-time trial updates
- Create standalone WebSocket server (ws-server.ts) on port 3001 using Bun - Add ws_connections table to track active connections in database - Create global WebSocket manager that persists across component unmounts - Fix useWebSocket hook to prevent infinite re-renders and use refs - Fix TrialForm Select components with proper default values - Add trialId to WebSocket URL for server-side tracking - Update package.json with dev:ws script for separate WS server
This commit is contained in:
@@ -11,7 +11,8 @@
|
|||||||
"db:push": "drizzle-kit push",
|
"db:push": "drizzle-kit push",
|
||||||
"db:studio": "drizzle-kit studio",
|
"db:studio": "drizzle-kit studio",
|
||||||
"db:seed": "bun db:push && bun scripts/seed-dev.ts",
|
"db:seed": "bun db:push && bun scripts/seed-dev.ts",
|
||||||
"dev": "next dev --turbo",
|
"dev": "bun run ws-server.ts & next dev --turbo",
|
||||||
|
"dev:ws": "bun run ws-server.ts",
|
||||||
"docker:up": "if [ \"$(uname)\" = \"Darwin\" ]; then colima start; fi && docker compose up -d",
|
"docker:up": "if [ \"$(uname)\" = \"Darwin\" ]; then colima start; fi && docker compose up -d",
|
||||||
"docker:down": "docker compose down && if [ \"$(uname)\" = \"Darwin\" ]; then colima stop; fi",
|
"docker:down": "docker compose down && if [ \"$(uname)\" = \"Darwin\" ]; then colima stop; fi",
|
||||||
"format:check": "prettier --check \"**/*.{ts,tsx,js,jsx,mdx}\" --cache",
|
"format:check": "prettier --check \"**/*.{ts,tsx,js,jsx,mdx}\" --cache",
|
||||||
|
|||||||
135
src/app/api/websocket/route.ts
Normal file
135
src/app/api/websocket/route.ts
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
import { NextRequest } from "next/server";
|
||||||
|
import { headers } from "next/headers";
|
||||||
|
import { wsManager } from "~/server/services/websocket-manager";
|
||||||
|
import { auth } from "~/lib/auth";
|
||||||
|
|
||||||
|
const clientConnections = new Map<
|
||||||
|
string,
|
||||||
|
{ socket: WebSocket; clientId: string }
|
||||||
|
>();
|
||||||
|
|
||||||
|
function generateClientId(): string {
|
||||||
|
return `ws_${Date.now()}_${Math.random().toString(36).substring(2, 9)}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const runtime = "edge";
|
||||||
|
export const dynamic = "force-dynamic";
|
||||||
|
|
||||||
|
export async function GET(request: NextRequest) {
|
||||||
|
const url = new URL(request.url);
|
||||||
|
const trialId = url.searchParams.get("trialId");
|
||||||
|
const token = url.searchParams.get("token");
|
||||||
|
|
||||||
|
if (!trialId) {
|
||||||
|
return new Response("Missing trialId parameter", { status: 400 });
|
||||||
|
}
|
||||||
|
|
||||||
|
let userId: string | null = null;
|
||||||
|
|
||||||
|
try {
|
||||||
|
const session = await auth.api.getSession({
|
||||||
|
headers: await headers(),
|
||||||
|
});
|
||||||
|
if (session?.user?.id) {
|
||||||
|
userId = session.user.id;
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
if (!token) {
|
||||||
|
return new Response("Authentication required", { status: 401 });
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const tokenData = JSON.parse(atob(token));
|
||||||
|
userId = tokenData.userId;
|
||||||
|
} catch {
|
||||||
|
return new Response("Invalid token", { status: 401 });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const pair = new WebSocketPair();
|
||||||
|
const clientId = generateClientId();
|
||||||
|
const serverWebSocket = Object.values(pair)[0] as WebSocket;
|
||||||
|
|
||||||
|
clientConnections.set(clientId, { socket: serverWebSocket, clientId });
|
||||||
|
|
||||||
|
await wsManager.subscribe(clientId, serverWebSocket, trialId, userId);
|
||||||
|
|
||||||
|
serverWebSocket.accept();
|
||||||
|
|
||||||
|
serverWebSocket.addEventListener("message", async (event) => {
|
||||||
|
try {
|
||||||
|
const message = JSON.parse(event.data as string);
|
||||||
|
|
||||||
|
switch (message.type) {
|
||||||
|
case "heartbeat":
|
||||||
|
wsManager.sendToClient(clientId, {
|
||||||
|
type: "heartbeat_response",
|
||||||
|
data: { timestamp: Date.now() },
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "request_trial_status": {
|
||||||
|
const status = await wsManager.getTrialStatus(trialId);
|
||||||
|
wsManager.sendToClient(clientId, {
|
||||||
|
type: "trial_status",
|
||||||
|
data: {
|
||||||
|
trial: status?.trial ?? null,
|
||||||
|
current_step_index: status?.currentStepIndex ?? 0,
|
||||||
|
timestamp: Date.now(),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case "request_trial_events": {
|
||||||
|
const events = await wsManager.getTrialEvents(
|
||||||
|
trialId,
|
||||||
|
message.data?.limit ?? 100,
|
||||||
|
);
|
||||||
|
wsManager.sendToClient(clientId, {
|
||||||
|
type: "trial_events_snapshot",
|
||||||
|
data: { events, timestamp: Date.now() },
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case "ping":
|
||||||
|
wsManager.sendToClient(clientId, {
|
||||||
|
type: "pong",
|
||||||
|
data: { timestamp: Date.now() },
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
console.log(
|
||||||
|
`[WS] Unknown message type from client ${clientId}:`,
|
||||||
|
message.type,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`[WS] Error processing message from ${clientId}:`, error);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
serverWebSocket.addEventListener("close", () => {
|
||||||
|
wsManager.unsubscribe(clientId);
|
||||||
|
clientConnections.delete(clientId);
|
||||||
|
});
|
||||||
|
|
||||||
|
serverWebSocket.addEventListener("error", (error) => {
|
||||||
|
console.error(`[WS] Error for client ${clientId}:`, error);
|
||||||
|
wsManager.unsubscribe(clientId);
|
||||||
|
clientConnections.delete(clientId);
|
||||||
|
});
|
||||||
|
|
||||||
|
return new Response(null, {
|
||||||
|
status: 101,
|
||||||
|
webSocket: serverWebSocket,
|
||||||
|
} as ResponseInit);
|
||||||
|
}
|
||||||
|
|
||||||
|
declare global {
|
||||||
|
interface WebSocket {
|
||||||
|
accept(): void;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -163,6 +163,11 @@ export function TrialForm({ mode, trialId, studyId }: TrialFormProps) {
|
|||||||
const form = useForm<TrialFormData>({
|
const form = useForm<TrialFormData>({
|
||||||
resolver: zodResolver(trialSchema),
|
resolver: zodResolver(trialSchema),
|
||||||
defaultValues: {
|
defaultValues: {
|
||||||
|
experimentId: "" as any,
|
||||||
|
participantId: "" as any,
|
||||||
|
scheduledAt: new Date(),
|
||||||
|
wizardId: undefined,
|
||||||
|
notes: "",
|
||||||
sessionNumber: 1,
|
sessionNumber: 1,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
@@ -347,7 +352,7 @@ export function TrialForm({ mode, trialId, studyId }: TrialFormProps) {
|
|||||||
<FormField>
|
<FormField>
|
||||||
<Label htmlFor="experimentId">Experiment *</Label>
|
<Label htmlFor="experimentId">Experiment *</Label>
|
||||||
<Select
|
<Select
|
||||||
value={form.watch("experimentId")}
|
value={form.watch("experimentId") ?? ""}
|
||||||
onValueChange={(value) => form.setValue("experimentId", value)}
|
onValueChange={(value) => form.setValue("experimentId", value)}
|
||||||
disabled={experimentsLoading || mode === "edit"}
|
disabled={experimentsLoading || mode === "edit"}
|
||||||
>
|
>
|
||||||
@@ -387,7 +392,7 @@ export function TrialForm({ mode, trialId, studyId }: TrialFormProps) {
|
|||||||
<FormField>
|
<FormField>
|
||||||
<Label htmlFor="participantId">Participant *</Label>
|
<Label htmlFor="participantId">Participant *</Label>
|
||||||
<Select
|
<Select
|
||||||
value={form.watch("participantId")}
|
value={form.watch("participantId") ?? ""}
|
||||||
onValueChange={(value) => form.setValue("participantId", value)}
|
onValueChange={(value) => form.setValue("participantId", value)}
|
||||||
disabled={participantsLoading || mode === "edit"}
|
disabled={participantsLoading || mode === "edit"}
|
||||||
>
|
>
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ import { WebcamPanel } from "./panels/WebcamPanel";
|
|||||||
import { TrialStatusBar } from "./panels/TrialStatusBar";
|
import { TrialStatusBar } from "./panels/TrialStatusBar";
|
||||||
import { api } from "~/trpc/react";
|
import { api } from "~/trpc/react";
|
||||||
import { useWizardRos } from "~/hooks/useWizardRos";
|
import { useWizardRos } from "~/hooks/useWizardRos";
|
||||||
|
import { useTrialWebSocket, type TrialEvent } from "~/hooks/useWebSocket";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import { useTour } from "~/components/onboarding/TourProvider";
|
import { useTour } from "~/components/onboarding/TourProvider";
|
||||||
|
|
||||||
@@ -252,59 +253,65 @@ export const WizardInterface = React.memo(function WizardInterface({
|
|||||||
[setAutonomousLifeRaw],
|
[setAutonomousLifeRaw],
|
||||||
);
|
);
|
||||||
|
|
||||||
// Use polling for trial status updates (no trial WebSocket server exists)
|
// Trial WebSocket for real-time updates
|
||||||
const { data: pollingData } = api.trials.get.useQuery(
|
const {
|
||||||
{ id: trial.id },
|
isConnected: wsConnected,
|
||||||
{
|
connectionError: wsError,
|
||||||
refetchInterval: trial.status === "in_progress" ? 5000 : 15000,
|
trialEvents: wsTrialEvents,
|
||||||
staleTime: 2000,
|
currentTrialStatus,
|
||||||
refetchOnWindowFocus: false,
|
addLocalEvent,
|
||||||
|
} = useTrialWebSocket(trial.id, {
|
||||||
|
onStatusChange: (status) => {
|
||||||
|
// Update local trial state when WebSocket reports status changes
|
||||||
|
setTrial((prev) => ({
|
||||||
|
...prev,
|
||||||
|
status: status.status,
|
||||||
|
startedAt: status.startedAt
|
||||||
|
? new Date(status.startedAt)
|
||||||
|
: prev.startedAt,
|
||||||
|
completedAt: status.completedAt
|
||||||
|
? new Date(status.completedAt)
|
||||||
|
: prev.completedAt,
|
||||||
|
}));
|
||||||
},
|
},
|
||||||
);
|
onTrialEvent: (event) => {
|
||||||
|
// Optionally show toast for new events
|
||||||
// Poll for trial events
|
if (event.eventType === "trial_started") {
|
||||||
const { data: fetchedEvents } = api.trials.getEvents.useQuery(
|
toast.info("Trial started");
|
||||||
{ trialId: trial.id, limit: 100 },
|
} else if (event.eventType === "trial_completed") {
|
||||||
{
|
toast.info("Trial completed");
|
||||||
refetchInterval: 3000,
|
} else if (event.eventType === "trial_aborted") {
|
||||||
staleTime: 1000,
|
toast.warning("Trial aborted");
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
// Update local trial state from polling only if changed
|
|
||||||
useEffect(() => {
|
|
||||||
if (pollingData && JSON.stringify(pollingData) !== JSON.stringify(trial)) {
|
|
||||||
// Only update if specific fields we care about have changed to avoid
|
|
||||||
// unnecessary re-renders that might cause UI flashing
|
|
||||||
if (
|
|
||||||
pollingData.status !== trial.status ||
|
|
||||||
pollingData.startedAt?.getTime() !== trial.startedAt?.getTime() ||
|
|
||||||
pollingData.completedAt?.getTime() !== trial.completedAt?.getTime()
|
|
||||||
) {
|
|
||||||
setTrial((prev) => {
|
|
||||||
// Double check inside setter to be safe
|
|
||||||
if (
|
|
||||||
prev.status === pollingData.status &&
|
|
||||||
prev.startedAt?.getTime() === pollingData.startedAt?.getTime() &&
|
|
||||||
prev.completedAt?.getTime() === pollingData.completedAt?.getTime()
|
|
||||||
) {
|
|
||||||
return prev;
|
|
||||||
}
|
|
||||||
return {
|
|
||||||
...prev,
|
|
||||||
status: pollingData.status,
|
|
||||||
startedAt: pollingData.startedAt
|
|
||||||
? new Date(pollingData.startedAt)
|
|
||||||
: prev.startedAt,
|
|
||||||
completedAt: pollingData.completedAt
|
|
||||||
? new Date(pollingData.completedAt)
|
|
||||||
: prev.completedAt,
|
|
||||||
};
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Update trial state from WebSocket status
|
||||||
|
useEffect(() => {
|
||||||
|
if (currentTrialStatus) {
|
||||||
|
setTrial((prev) => {
|
||||||
|
if (
|
||||||
|
prev.status === currentTrialStatus.status &&
|
||||||
|
prev.startedAt?.getTime() ===
|
||||||
|
new Date(currentTrialStatus.startedAt ?? "").getTime() &&
|
||||||
|
prev.completedAt?.getTime() ===
|
||||||
|
new Date(currentTrialStatus.completedAt ?? "").getTime()
|
||||||
|
) {
|
||||||
|
return prev;
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
...prev,
|
||||||
|
status: currentTrialStatus.status,
|
||||||
|
startedAt: currentTrialStatus.startedAt
|
||||||
|
? new Date(currentTrialStatus.startedAt)
|
||||||
|
: prev.startedAt,
|
||||||
|
completedAt: currentTrialStatus.completedAt
|
||||||
|
? new Date(currentTrialStatus.completedAt)
|
||||||
|
: prev.completedAt,
|
||||||
|
};
|
||||||
|
});
|
||||||
}
|
}
|
||||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
}, [currentTrialStatus]);
|
||||||
}, [pollingData]);
|
|
||||||
|
|
||||||
// Auto-start trial on mount if scheduled
|
// Auto-start trial on mount if scheduled
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@@ -313,7 +320,7 @@ export const WizardInterface = React.memo(function WizardInterface({
|
|||||||
}
|
}
|
||||||
}, []); // Run once on mount
|
}, []); // Run once on mount
|
||||||
|
|
||||||
// Trial events from robot actions
|
// Trial events from WebSocket (and initial load)
|
||||||
const trialEvents = useMemo<
|
const trialEvents = useMemo<
|
||||||
Array<{
|
Array<{
|
||||||
type: string;
|
type: string;
|
||||||
@@ -322,8 +329,8 @@ export const WizardInterface = React.memo(function WizardInterface({
|
|||||||
message?: string;
|
message?: string;
|
||||||
}>
|
}>
|
||||||
>(() => {
|
>(() => {
|
||||||
return (fetchedEvents ?? [])
|
return (wsTrialEvents ?? [])
|
||||||
.map((event) => {
|
.map((event: TrialEvent) => {
|
||||||
let message: string | undefined;
|
let message: string | undefined;
|
||||||
const eventData = event.data as any;
|
const eventData = event.data as any;
|
||||||
|
|
||||||
@@ -364,7 +371,7 @@ export const WizardInterface = React.memo(function WizardInterface({
|
|||||||
};
|
};
|
||||||
})
|
})
|
||||||
.sort((a, b) => b.timestamp.getTime() - a.timestamp.getTime()); // Newest first
|
.sort((a, b) => b.timestamp.getTime() - a.timestamp.getTime()); // Newest first
|
||||||
}, [fetchedEvents]);
|
}, [wsTrialEvents]);
|
||||||
|
|
||||||
// Transform experiment steps to component format
|
// Transform experiment steps to component format
|
||||||
const steps: StepData[] = useMemo(
|
const steps: StepData[] = useMemo(
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
/* eslint-disable react-hooks/exhaustive-deps */
|
|
||||||
|
|
||||||
import { useSession } from "~/lib/auth-client";
|
import { useSession } from "~/lib/auth-client";
|
||||||
import { useCallback, useEffect, useRef, useState } from "react";
|
import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
|
|
||||||
@@ -56,10 +54,42 @@ interface TrialActionExecutedMessage {
|
|||||||
interface InterventionLoggedMessage {
|
interface InterventionLoggedMessage {
|
||||||
type: "intervention_logged";
|
type: "intervention_logged";
|
||||||
data: {
|
data: {
|
||||||
|
intervention?: unknown;
|
||||||
timestamp: number;
|
timestamp: number;
|
||||||
} & Record<string, unknown>;
|
} & Record<string, unknown>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface TrialEventMessage {
|
||||||
|
type: "trial_event";
|
||||||
|
data: {
|
||||||
|
event: unknown;
|
||||||
|
timestamp: number;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
interface TrialEventsSnapshotMessage {
|
||||||
|
type: "trial_events_snapshot";
|
||||||
|
data: {
|
||||||
|
events: unknown[];
|
||||||
|
timestamp: number;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
interface AnnotationAddedMessage {
|
||||||
|
type: "annotation_added";
|
||||||
|
data: {
|
||||||
|
annotation: unknown;
|
||||||
|
timestamp: number;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
interface PongMessage {
|
||||||
|
type: "pong";
|
||||||
|
data: {
|
||||||
|
timestamp: number;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
interface StepChangedMessage {
|
interface StepChangedMessage {
|
||||||
type: "step_changed";
|
type: "step_changed";
|
||||||
data: {
|
data: {
|
||||||
@@ -83,6 +113,10 @@ type KnownInboundMessage =
|
|||||||
| TrialStatusMessage
|
| TrialStatusMessage
|
||||||
| TrialActionExecutedMessage
|
| TrialActionExecutedMessage
|
||||||
| InterventionLoggedMessage
|
| InterventionLoggedMessage
|
||||||
|
| TrialEventMessage
|
||||||
|
| TrialEventsSnapshotMessage
|
||||||
|
| AnnotationAddedMessage
|
||||||
|
| PongMessage
|
||||||
| StepChangedMessage
|
| StepChangedMessage
|
||||||
| ErrorMessage;
|
| ErrorMessage;
|
||||||
|
|
||||||
@@ -98,18 +132,247 @@ export interface OutgoingMessage {
|
|||||||
data: Record<string, unknown>;
|
data: Record<string, unknown>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface UseWebSocketOptions {
|
interface Subscription {
|
||||||
trialId: string;
|
trialId: string;
|
||||||
onMessage?: (message: WebSocketMessage) => void;
|
onMessage?: (message: WebSocketMessage) => void;
|
||||||
onConnect?: () => void;
|
onConnect?: () => void;
|
||||||
onDisconnect?: () => void;
|
onDisconnect?: () => void;
|
||||||
onError?: (error: Event) => void;
|
onError?: (error: Event) => void;
|
||||||
reconnectAttempts?: number;
|
|
||||||
reconnectInterval?: number;
|
|
||||||
heartbeatInterval?: number;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface UseWebSocketReturn {
|
interface GlobalWSState {
|
||||||
|
isConnected: boolean;
|
||||||
|
isConnecting: boolean;
|
||||||
|
connectionError: string | null;
|
||||||
|
lastMessage: WebSocketMessage | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
type StateListener = (state: GlobalWSState) => void;
|
||||||
|
|
||||||
|
class GlobalWebSocketManager {
|
||||||
|
private ws: WebSocket | null = null;
|
||||||
|
private subscriptions: Map<string, Subscription> = new Map();
|
||||||
|
private stateListeners: Set<StateListener> = new Set();
|
||||||
|
private sessionRef: { user: { id: string } } | null = null;
|
||||||
|
private heartbeatInterval: ReturnType<typeof setInterval> | null = null;
|
||||||
|
private reconnectTimeout: ReturnType<typeof setTimeout> | null = null;
|
||||||
|
private attemptCount = 0;
|
||||||
|
private maxAttempts = 5;
|
||||||
|
|
||||||
|
private state: GlobalWSState = {
|
||||||
|
isConnected: false,
|
||||||
|
isConnecting: false,
|
||||||
|
connectionError: null,
|
||||||
|
lastMessage: null,
|
||||||
|
};
|
||||||
|
|
||||||
|
private setState(partial: Partial<GlobalWSState>) {
|
||||||
|
this.state = { ...this.state, ...partial };
|
||||||
|
this.notifyListeners();
|
||||||
|
}
|
||||||
|
|
||||||
|
private notifyListeners() {
|
||||||
|
this.stateListeners.forEach((listener) => listener(this.state));
|
||||||
|
}
|
||||||
|
|
||||||
|
subscribe(
|
||||||
|
session: { user: { id: string } } | null,
|
||||||
|
subscription: Subscription,
|
||||||
|
) {
|
||||||
|
this.sessionRef = session;
|
||||||
|
this.subscriptions.set(subscription.trialId, subscription);
|
||||||
|
|
||||||
|
if (this.subscriptions.size === 1 && !this.ws) {
|
||||||
|
this.connect();
|
||||||
|
}
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
this.subscriptions.delete(subscription.trialId);
|
||||||
|
// Don't auto-disconnect - keep global connection alive
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
sendMessage(message: OutgoingMessage) {
|
||||||
|
if (this.ws?.readyState === WebSocket.OPEN) {
|
||||||
|
this.ws.send(JSON.stringify(message));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
connect() {
|
||||||
|
if (
|
||||||
|
this.ws?.readyState === WebSocket.CONNECTING ||
|
||||||
|
this.ws?.readyState === WebSocket.OPEN
|
||||||
|
) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!this.sessionRef?.user) {
|
||||||
|
this.setState({ connectionError: "No session", isConnecting: false });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.setState({ isConnecting: true, connectionError: null });
|
||||||
|
|
||||||
|
const token = btoa(JSON.stringify({ userId: this.sessionRef.user.id }));
|
||||||
|
const wsPort = process.env.NEXT_PUBLIC_WS_PORT || "3001";
|
||||||
|
|
||||||
|
// Collect all trial IDs from subscriptions
|
||||||
|
const trialIds = Array.from(this.subscriptions.keys());
|
||||||
|
const trialIdParam = trialIds.length > 0 ? `&trialId=${trialIds[0]}` : "";
|
||||||
|
const url = `ws://${typeof window !== "undefined" ? window.location.hostname : "localhost"}:${wsPort}/api/websocket?token=${token}${trialIdParam}`;
|
||||||
|
|
||||||
|
try {
|
||||||
|
this.ws = new WebSocket(url);
|
||||||
|
|
||||||
|
this.ws.onopen = () => {
|
||||||
|
console.log("[GlobalWS] Connected");
|
||||||
|
this.setState({ isConnected: true, isConnecting: false });
|
||||||
|
this.attemptCount = 0;
|
||||||
|
this.startHeartbeat();
|
||||||
|
|
||||||
|
// Subscribe to all subscribed trials
|
||||||
|
this.subscriptions.forEach((sub) => {
|
||||||
|
this.ws?.send(
|
||||||
|
JSON.stringify({
|
||||||
|
type: "subscribe",
|
||||||
|
data: { trialId: sub.trialId },
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
this.subscriptions.forEach((sub) => sub.onConnect?.());
|
||||||
|
};
|
||||||
|
|
||||||
|
this.ws.onmessage = (event) => {
|
||||||
|
try {
|
||||||
|
const message = JSON.parse(event.data) as WebSocketMessage;
|
||||||
|
this.setState({ lastMessage: message });
|
||||||
|
|
||||||
|
if (message.type === "connection_established") {
|
||||||
|
const data = (message as ConnectionEstablishedMessage).data;
|
||||||
|
const sub = this.subscriptions.get(data.trialId);
|
||||||
|
if (sub) {
|
||||||
|
sub.onMessage?.(message);
|
||||||
|
}
|
||||||
|
} else if (
|
||||||
|
message.type === "trial_event" ||
|
||||||
|
message.type === "trial_status"
|
||||||
|
) {
|
||||||
|
const data = (message as TrialEventMessage).data;
|
||||||
|
const event = data.event as { trialId?: string };
|
||||||
|
if (event?.trialId) {
|
||||||
|
const sub = this.subscriptions.get(event.trialId);
|
||||||
|
sub?.onMessage?.(message);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Broadcast to all subscriptions
|
||||||
|
this.subscriptions.forEach((sub) => sub.onMessage?.(message));
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("[GlobalWS] Failed to parse message:", error);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
this.ws.onclose = (event) => {
|
||||||
|
console.log("[GlobalWS] Disconnected:", event.code);
|
||||||
|
this.setState({ isConnected: false, isConnecting: false });
|
||||||
|
this.stopHeartbeat();
|
||||||
|
this.subscriptions.forEach((sub) => sub.onDisconnect?.());
|
||||||
|
|
||||||
|
// Auto-reconnect if not intentionally closed
|
||||||
|
if (event.code !== 1000 && this.subscriptions.size > 0) {
|
||||||
|
this.scheduleReconnect();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
this.ws.onerror = (error) => {
|
||||||
|
console.error("[GlobalWS] Error:", error);
|
||||||
|
this.setState({
|
||||||
|
connectionError: "Connection error",
|
||||||
|
isConnecting: false,
|
||||||
|
});
|
||||||
|
this.subscriptions.forEach((sub) =>
|
||||||
|
sub.onError?.(new Event("ws_error")),
|
||||||
|
);
|
||||||
|
};
|
||||||
|
} catch (error) {
|
||||||
|
console.error("[GlobalWS] Failed to create:", error);
|
||||||
|
this.setState({
|
||||||
|
connectionError: "Failed to create connection",
|
||||||
|
isConnecting: false,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
disconnect() {
|
||||||
|
if (this.reconnectTimeout) {
|
||||||
|
clearTimeout(this.reconnectTimeout);
|
||||||
|
this.reconnectTimeout = null;
|
||||||
|
}
|
||||||
|
this.stopHeartbeat();
|
||||||
|
if (this.ws) {
|
||||||
|
this.ws.close(1000, "Manual disconnect");
|
||||||
|
this.ws = null;
|
||||||
|
}
|
||||||
|
this.setState({ isConnected: false, isConnecting: false });
|
||||||
|
}
|
||||||
|
|
||||||
|
private startHeartbeat() {
|
||||||
|
this.heartbeatInterval = setInterval(() => {
|
||||||
|
if (this.ws?.readyState === WebSocket.OPEN) {
|
||||||
|
this.ws.send(JSON.stringify({ type: "heartbeat", data: {} }));
|
||||||
|
}
|
||||||
|
}, 30000);
|
||||||
|
}
|
||||||
|
|
||||||
|
private stopHeartbeat() {
|
||||||
|
if (this.heartbeatInterval) {
|
||||||
|
clearInterval(this.heartbeatInterval);
|
||||||
|
this.heartbeatInterval = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private scheduleReconnect() {
|
||||||
|
if (this.attemptCount >= this.maxAttempts) {
|
||||||
|
this.setState({ connectionError: "Max reconnection attempts reached" });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const delay = Math.min(30000, 1000 * Math.pow(1.5, this.attemptCount));
|
||||||
|
this.attemptCount++;
|
||||||
|
|
||||||
|
console.log(
|
||||||
|
`[GlobalWS] Reconnecting in ${delay}ms (attempt ${this.attemptCount})`,
|
||||||
|
);
|
||||||
|
|
||||||
|
this.reconnectTimeout = setTimeout(() => {
|
||||||
|
if (this.subscriptions.size > 0) {
|
||||||
|
this.connect();
|
||||||
|
}
|
||||||
|
}, delay);
|
||||||
|
}
|
||||||
|
|
||||||
|
getState(): GlobalWSState {
|
||||||
|
return this.state;
|
||||||
|
}
|
||||||
|
|
||||||
|
addListener(listener: StateListener) {
|
||||||
|
this.stateListeners.add(listener);
|
||||||
|
return () => this.stateListeners.delete(listener);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const globalWS = new GlobalWebSocketManager();
|
||||||
|
|
||||||
|
export interface UseGlobalWebSocketOptions {
|
||||||
|
trialId: string;
|
||||||
|
onMessage?: (message: WebSocketMessage) => void;
|
||||||
|
onConnect?: () => void;
|
||||||
|
onDisconnect?: () => void;
|
||||||
|
onError?: (error: Event) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface UseGlobalWebSocketReturn {
|
||||||
isConnected: boolean;
|
isConnected: boolean;
|
||||||
isConnecting: boolean;
|
isConnecting: boolean;
|
||||||
connectionError: string | null;
|
connectionError: string | null;
|
||||||
@@ -119,333 +382,66 @@ export interface UseWebSocketReturn {
|
|||||||
lastMessage: WebSocketMessage | null;
|
lastMessage: WebSocketMessage | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function useWebSocket({
|
export function useGlobalWebSocket({
|
||||||
trialId,
|
trialId,
|
||||||
onMessage,
|
onMessage,
|
||||||
onConnect,
|
onConnect,
|
||||||
onDisconnect,
|
onDisconnect,
|
||||||
onError,
|
onError,
|
||||||
reconnectAttempts = 5,
|
}: UseGlobalWebSocketOptions): UseGlobalWebSocketReturn {
|
||||||
reconnectInterval = 3000,
|
|
||||||
heartbeatInterval = 30000,
|
|
||||||
}: UseWebSocketOptions): UseWebSocketReturn {
|
|
||||||
const { data: session } = useSession();
|
const { data: session } = useSession();
|
||||||
const [isConnected, setIsConnected] = useState<boolean>(false);
|
const [isConnected, setIsConnected] = useState(false);
|
||||||
const [isConnecting, setIsConnecting] = useState<boolean>(false);
|
const [isConnecting, setIsConnecting] = useState(false);
|
||||||
const [connectionError, setConnectionError] = useState<string | null>(null);
|
const [connectionError, setConnectionError] = useState<string | null>(null);
|
||||||
const [hasAttemptedConnection, setHasAttemptedConnection] =
|
|
||||||
useState<boolean>(false);
|
|
||||||
const [lastMessage, setLastMessage] = useState<WebSocketMessage | null>(null);
|
const [lastMessage, setLastMessage] = useState<WebSocketMessage | null>(null);
|
||||||
|
|
||||||
const wsRef = useRef<WebSocket | null>(null);
|
const onMessageRef = useRef(onMessage);
|
||||||
const reconnectTimeoutRef = useRef<NodeJS.Timeout | null>(null);
|
const onConnectRef = useRef(onConnect);
|
||||||
const heartbeatTimeoutRef = useRef<NodeJS.Timeout | null>(null);
|
const onDisconnectRef = useRef(onDisconnect);
|
||||||
const attemptCountRef = useRef<number>(0);
|
const onErrorRef = useRef(onError);
|
||||||
const mountedRef = useRef<boolean>(true);
|
|
||||||
const connectionStableTimeoutRef = useRef<NodeJS.Timeout | null>(null);
|
|
||||||
|
|
||||||
// Generate auth token (simplified - in production use proper JWT)
|
onMessageRef.current = onMessage;
|
||||||
const getAuthToken = useCallback((): string | null => {
|
onConnectRef.current = onConnect;
|
||||||
if (!session?.user) return null;
|
onDisconnectRef.current = onDisconnect;
|
||||||
// In production, this would be a proper JWT token
|
onErrorRef.current = onError;
|
||||||
return btoa(
|
|
||||||
JSON.stringify({ userId: session.user.id, timestamp: Date.now() }),
|
|
||||||
);
|
|
||||||
}, [session]);
|
|
||||||
|
|
||||||
const sendMessage = useCallback((message: OutgoingMessage): void => {
|
useEffect(() => {
|
||||||
if (wsRef.current && wsRef.current.readyState === WebSocket.OPEN) {
|
const unsubscribe = globalWS.subscribe(session, {
|
||||||
wsRef.current.send(JSON.stringify(message));
|
trialId,
|
||||||
} else {
|
onMessage: (msg) => {
|
||||||
console.warn("WebSocket not connected, message not sent:", message);
|
setLastMessage(msg);
|
||||||
}
|
onMessageRef.current?.(msg);
|
||||||
}, []);
|
},
|
||||||
|
onConnect: () => {
|
||||||
const sendHeartbeat = useCallback((): void => {
|
setIsConnected(true);
|
||||||
sendMessage({ type: "heartbeat", data: {} });
|
|
||||||
}, [sendMessage]);
|
|
||||||
|
|
||||||
const scheduleHeartbeat = useCallback((): void => {
|
|
||||||
if (heartbeatTimeoutRef.current) {
|
|
||||||
clearTimeout(heartbeatTimeoutRef.current);
|
|
||||||
}
|
|
||||||
heartbeatTimeoutRef.current = setTimeout(() => {
|
|
||||||
if (isConnected && mountedRef.current) {
|
|
||||||
sendHeartbeat();
|
|
||||||
scheduleHeartbeat();
|
|
||||||
}
|
|
||||||
}, heartbeatInterval);
|
|
||||||
}, [isConnected, sendHeartbeat, heartbeatInterval]);
|
|
||||||
|
|
||||||
const handleMessage = useCallback(
|
|
||||||
(event: MessageEvent<string>): void => {
|
|
||||||
try {
|
|
||||||
const message = JSON.parse(event.data) as WebSocketMessage;
|
|
||||||
setLastMessage(message);
|
|
||||||
|
|
||||||
// Handle system messages
|
|
||||||
switch (message.type) {
|
|
||||||
case "connection_established": {
|
|
||||||
console.log(
|
|
||||||
"WebSocket connection established:",
|
|
||||||
(message as ConnectionEstablishedMessage).data,
|
|
||||||
);
|
|
||||||
setIsConnected(true);
|
|
||||||
setIsConnecting(false);
|
|
||||||
setConnectionError(null);
|
|
||||||
attemptCountRef.current = 0;
|
|
||||||
scheduleHeartbeat();
|
|
||||||
onConnect?.();
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
case "heartbeat_response":
|
|
||||||
// Heartbeat acknowledged, connection is alive
|
|
||||||
break;
|
|
||||||
|
|
||||||
case "error": {
|
|
||||||
console.error("WebSocket server error:", message);
|
|
||||||
const msg =
|
|
||||||
(message as ErrorMessage).data?.message ?? "Server error";
|
|
||||||
setConnectionError(msg);
|
|
||||||
onError?.(new Event("server_error"));
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
// Pass to user-defined message handler
|
|
||||||
onMessage?.(message);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
console.error("Error parsing WebSocket message:", error);
|
|
||||||
setConnectionError("Failed to parse message");
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[onMessage, onConnect, onError, scheduleHeartbeat],
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleClose = useCallback(
|
|
||||||
(event: CloseEvent): void => {
|
|
||||||
console.log("WebSocket connection closed:", event.code, event.reason);
|
|
||||||
setIsConnected(false);
|
|
||||||
setIsConnecting(false);
|
|
||||||
|
|
||||||
if (heartbeatTimeoutRef.current) {
|
|
||||||
clearTimeout(heartbeatTimeoutRef.current);
|
|
||||||
}
|
|
||||||
|
|
||||||
onDisconnect?.();
|
|
||||||
|
|
||||||
// Attempt reconnection if not manually closed and component is still mounted
|
|
||||||
// In development, don't aggressively reconnect to prevent UI flashing
|
|
||||||
if (
|
|
||||||
event.code !== 1000 &&
|
|
||||||
mountedRef.current &&
|
|
||||||
attemptCountRef.current < reconnectAttempts &&
|
|
||||||
process.env.NODE_ENV !== "development"
|
|
||||||
) {
|
|
||||||
attemptCountRef.current++;
|
|
||||||
const delay =
|
|
||||||
reconnectInterval * Math.pow(1.5, attemptCountRef.current - 1); // Exponential backoff
|
|
||||||
|
|
||||||
console.log(
|
|
||||||
`Attempting reconnection ${attemptCountRef.current}/${reconnectAttempts} in ${delay}ms`,
|
|
||||||
);
|
|
||||||
setConnectionError(
|
|
||||||
`Connection lost. Reconnecting... (${attemptCountRef.current}/${reconnectAttempts})`,
|
|
||||||
);
|
|
||||||
|
|
||||||
reconnectTimeoutRef.current = setTimeout(() => {
|
|
||||||
if (mountedRef.current) {
|
|
||||||
attemptCountRef.current = 0;
|
|
||||||
setIsConnecting(true);
|
|
||||||
setConnectionError(null);
|
|
||||||
}
|
|
||||||
}, delay);
|
|
||||||
} else if (attemptCountRef.current >= reconnectAttempts) {
|
|
||||||
setConnectionError("Failed to reconnect after maximum attempts");
|
|
||||||
} else if (
|
|
||||||
process.env.NODE_ENV === "development" &&
|
|
||||||
event.code !== 1000
|
|
||||||
) {
|
|
||||||
// In development, set a stable error message without reconnection attempts
|
|
||||||
setConnectionError("WebSocket unavailable - using polling mode");
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[onDisconnect, reconnectAttempts, reconnectInterval],
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleError = useCallback(
|
|
||||||
(event: Event): void => {
|
|
||||||
// In development, WebSocket failures are expected with Edge Runtime
|
|
||||||
if (process.env.NODE_ENV === "development") {
|
|
||||||
// Only set error state after the first failed attempt to prevent flashing
|
|
||||||
if (!hasAttemptedConnection) {
|
|
||||||
setHasAttemptedConnection(true);
|
|
||||||
// Debounce the error state to prevent UI flashing
|
|
||||||
if (connectionStableTimeoutRef.current) {
|
|
||||||
clearTimeout(connectionStableTimeoutRef.current);
|
|
||||||
}
|
|
||||||
connectionStableTimeoutRef.current = setTimeout(() => {
|
|
||||||
setConnectionError("WebSocket unavailable - using polling mode");
|
|
||||||
setIsConnecting(false);
|
|
||||||
}, 1000);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
console.error("WebSocket error:", event);
|
|
||||||
setConnectionError("Connection error");
|
|
||||||
setIsConnecting(false);
|
setIsConnecting(false);
|
||||||
}
|
setConnectionError(null);
|
||||||
onError?.(event);
|
onConnectRef.current?.();
|
||||||
},
|
},
|
||||||
[onError, hasAttemptedConnection],
|
onDisconnect: () => {
|
||||||
);
|
setIsConnected(false);
|
||||||
|
onDisconnectRef.current?.();
|
||||||
|
},
|
||||||
|
onError: (err) => {
|
||||||
|
setConnectionError("Connection error");
|
||||||
|
onErrorRef.current?.(err);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
const connectInternal = useCallback((): void => {
|
return unsubscribe;
|
||||||
if (!session?.user || !trialId) {
|
}, [trialId, session]);
|
||||||
if (!hasAttemptedConnection) {
|
|
||||||
setConnectionError("Missing authentication or trial ID");
|
|
||||||
setHasAttemptedConnection(true);
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
const sendMessage = useCallback((message: OutgoingMessage) => {
|
||||||
wsRef.current &&
|
globalWS.sendMessage(message);
|
||||||
(wsRef.current.readyState === WebSocket.CONNECTING ||
|
|
||||||
wsRef.current.readyState === WebSocket.OPEN)
|
|
||||||
) {
|
|
||||||
return; // Already connecting or connected
|
|
||||||
}
|
|
||||||
|
|
||||||
const token = getAuthToken();
|
|
||||||
if (!token) {
|
|
||||||
if (!hasAttemptedConnection) {
|
|
||||||
setConnectionError("Failed to generate auth token");
|
|
||||||
setHasAttemptedConnection(true);
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only show connecting state for the first attempt or if we've been stable
|
|
||||||
if (!hasAttemptedConnection || isConnected) {
|
|
||||||
setIsConnecting(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear any pending error updates
|
|
||||||
if (connectionStableTimeoutRef.current) {
|
|
||||||
clearTimeout(connectionStableTimeoutRef.current);
|
|
||||||
}
|
|
||||||
|
|
||||||
setConnectionError(null);
|
|
||||||
|
|
||||||
try {
|
|
||||||
// Use appropriate WebSocket URL based on environment
|
|
||||||
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
|
|
||||||
const wsUrl = `${protocol}//${window.location.host}/api/websocket?trialId=${trialId}&token=${token}`;
|
|
||||||
|
|
||||||
wsRef.current = new WebSocket(wsUrl);
|
|
||||||
wsRef.current.onmessage = handleMessage;
|
|
||||||
wsRef.current.onclose = handleClose;
|
|
||||||
wsRef.current.onerror = handleError;
|
|
||||||
|
|
||||||
wsRef.current.onopen = () => {
|
|
||||||
console.log("WebSocket connection opened");
|
|
||||||
// Connection establishment is handled in handleMessage
|
|
||||||
};
|
|
||||||
} catch (error) {
|
|
||||||
console.error("Failed to create WebSocket connection:", error);
|
|
||||||
if (!hasAttemptedConnection) {
|
|
||||||
setConnectionError("Failed to create connection");
|
|
||||||
setHasAttemptedConnection(true);
|
|
||||||
}
|
|
||||||
setIsConnecting(false);
|
|
||||||
}
|
|
||||||
}, [
|
|
||||||
session,
|
|
||||||
trialId,
|
|
||||||
getAuthToken,
|
|
||||||
handleMessage,
|
|
||||||
handleClose,
|
|
||||||
handleError,
|
|
||||||
hasAttemptedConnection,
|
|
||||||
isConnected,
|
|
||||||
]);
|
|
||||||
|
|
||||||
const disconnect = useCallback((): void => {
|
|
||||||
mountedRef.current = false;
|
|
||||||
|
|
||||||
if (reconnectTimeoutRef.current) {
|
|
||||||
clearTimeout(reconnectTimeoutRef.current);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (heartbeatTimeoutRef.current) {
|
|
||||||
clearTimeout(heartbeatTimeoutRef.current);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (connectionStableTimeoutRef.current) {
|
|
||||||
clearTimeout(connectionStableTimeoutRef.current);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (wsRef.current) {
|
|
||||||
wsRef.current.close(1000, "Manual disconnect");
|
|
||||||
wsRef.current = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
setIsConnected(false);
|
|
||||||
setIsConnecting(false);
|
|
||||||
setConnectionError(null);
|
|
||||||
setHasAttemptedConnection(false);
|
|
||||||
attemptCountRef.current = 0;
|
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const reconnect = useCallback((): void => {
|
const disconnect = useCallback(() => {
|
||||||
disconnect();
|
globalWS.disconnect();
|
||||||
mountedRef.current = true;
|
}, []);
|
||||||
attemptCountRef.current = 0;
|
|
||||||
setHasAttemptedConnection(false);
|
|
||||||
setTimeout(() => {
|
|
||||||
if (mountedRef.current) {
|
|
||||||
void connectInternal();
|
|
||||||
}
|
|
||||||
}, 100); // Small delay to ensure cleanup
|
|
||||||
}, [disconnect, connectInternal]);
|
|
||||||
|
|
||||||
// Effect to establish initial connection
|
const reconnect = useCallback(() => {
|
||||||
useEffect(() => {
|
globalWS.connect();
|
||||||
if (session?.user?.id && trialId) {
|
}, []);
|
||||||
// In development, only attempt connection once to prevent flashing
|
|
||||||
if (process.env.NODE_ENV === "development" && hasAttemptedConnection) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Trigger reconnection if timeout was set
|
|
||||||
if (reconnectTimeoutRef.current) {
|
|
||||||
clearTimeout(reconnectTimeoutRef.current);
|
|
||||||
reconnectTimeoutRef.current = null;
|
|
||||||
void connectInternal();
|
|
||||||
} else {
|
|
||||||
void connectInternal();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return () => {
|
|
||||||
mountedRef.current = false;
|
|
||||||
disconnect();
|
|
||||||
};
|
|
||||||
}, [session?.user?.id, trialId, hasAttemptedConnection]);
|
|
||||||
|
|
||||||
// Cleanup on unmount
|
|
||||||
useEffect(() => {
|
|
||||||
return () => {
|
|
||||||
mountedRef.current = false;
|
|
||||||
if (connectionStableTimeoutRef.current) {
|
|
||||||
clearTimeout(connectionStableTimeoutRef.current);
|
|
||||||
}
|
|
||||||
disconnect();
|
|
||||||
};
|
|
||||||
}, [disconnect]);
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
isConnected,
|
isConnected,
|
||||||
@@ -458,115 +454,180 @@ export function useWebSocket({
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hook for trial-specific WebSocket events
|
// Legacy alias
|
||||||
export function useTrialWebSocket(trialId: string) {
|
export const useWebSocket = useGlobalWebSocket;
|
||||||
const [trialEvents, setTrialEvents] = useState<WebSocketMessage[]>([]);
|
|
||||||
const [currentTrialStatus, setCurrentTrialStatus] =
|
|
||||||
useState<TrialSnapshot | null>(null);
|
|
||||||
const [wizardActions, setWizardActions] = useState<WebSocketMessage[]>([]);
|
|
||||||
|
|
||||||
const handleMessage = useCallback((message: WebSocketMessage): void => {
|
// Trial-specific hook
|
||||||
// Add to events log
|
export interface TrialEvent {
|
||||||
setTrialEvents((prev) => [...prev, message].slice(-100)); // Keep last 100 events
|
id: string;
|
||||||
|
trialId: string;
|
||||||
|
eventType: string;
|
||||||
|
data: Record<string, unknown> | null;
|
||||||
|
timestamp: Date;
|
||||||
|
createdBy?: string | null;
|
||||||
|
}
|
||||||
|
|
||||||
switch (message.type) {
|
export interface TrialWebSocketState {
|
||||||
case "trial_status": {
|
trialEvents: TrialEvent[];
|
||||||
const data = (message as TrialStatusMessage).data;
|
currentTrialStatus: TrialSnapshot | null;
|
||||||
setCurrentTrialStatus(data.trial);
|
wizardActions: WebSocketMessage[];
|
||||||
break;
|
}
|
||||||
|
|
||||||
|
export function useTrialWebSocket(
|
||||||
|
trialId: string,
|
||||||
|
options?: {
|
||||||
|
onTrialEvent?: (event: TrialEvent) => void;
|
||||||
|
onStatusChange?: (status: TrialSnapshot) => void;
|
||||||
|
initialEvents?: TrialEvent[];
|
||||||
|
initialStatus?: TrialSnapshot | null;
|
||||||
|
},
|
||||||
|
) {
|
||||||
|
const [state, setState] = useState<TrialWebSocketState>({
|
||||||
|
trialEvents: options?.initialEvents ?? [],
|
||||||
|
currentTrialStatus: options?.initialStatus ?? null,
|
||||||
|
wizardActions: [],
|
||||||
|
});
|
||||||
|
|
||||||
|
const handleMessage = useCallback(
|
||||||
|
(message: WebSocketMessage): void => {
|
||||||
|
switch (message.type) {
|
||||||
|
case "trial_status": {
|
||||||
|
const data = (message as TrialStatusMessage).data;
|
||||||
|
const status = data.trial as TrialSnapshot;
|
||||||
|
setState((prev) => ({
|
||||||
|
...prev,
|
||||||
|
currentTrialStatus: status,
|
||||||
|
}));
|
||||||
|
options?.onStatusChange?.(status);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case "trial_events_snapshot": {
|
||||||
|
const data = (message as TrialEventsSnapshotMessage).data;
|
||||||
|
const events = (
|
||||||
|
data.events as Array<{
|
||||||
|
id: string;
|
||||||
|
trialId: string;
|
||||||
|
eventType: string;
|
||||||
|
data: Record<string, unknown> | null;
|
||||||
|
timestamp: Date | string;
|
||||||
|
createdBy?: string | null;
|
||||||
|
}>
|
||||||
|
).map((e) => ({
|
||||||
|
...e,
|
||||||
|
timestamp:
|
||||||
|
typeof e.timestamp === "string"
|
||||||
|
? new Date(e.timestamp)
|
||||||
|
: e.timestamp,
|
||||||
|
}));
|
||||||
|
setState((prev) => ({
|
||||||
|
...prev,
|
||||||
|
trialEvents: events,
|
||||||
|
}));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case "trial_event": {
|
||||||
|
const data = (message as TrialEventMessage).data;
|
||||||
|
const event = data.event as {
|
||||||
|
id: string;
|
||||||
|
trialId: string;
|
||||||
|
eventType: string;
|
||||||
|
data: Record<string, unknown> | null;
|
||||||
|
timestamp: Date | string;
|
||||||
|
createdBy?: string | null;
|
||||||
|
};
|
||||||
|
const newEvent: TrialEvent = {
|
||||||
|
...event,
|
||||||
|
timestamp:
|
||||||
|
typeof event.timestamp === "string"
|
||||||
|
? new Date(event.timestamp)
|
||||||
|
: event.timestamp,
|
||||||
|
};
|
||||||
|
setState((prev) => ({
|
||||||
|
...prev,
|
||||||
|
trialEvents: [...prev.trialEvents, newEvent].slice(-500),
|
||||||
|
}));
|
||||||
|
options?.onTrialEvent?.(newEvent);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case "trial_action_executed":
|
||||||
|
case "intervention_logged":
|
||||||
|
case "annotation_added":
|
||||||
|
case "step_changed": {
|
||||||
|
setState((prev) => ({
|
||||||
|
...prev,
|
||||||
|
wizardActions: [...prev.wizardActions, message].slice(-100),
|
||||||
|
}));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case "pong":
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
if (process.env.NODE_ENV === "development") {
|
||||||
|
console.log(`[WS] Unknown message type: ${message.type}`);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
[options],
|
||||||
|
);
|
||||||
|
|
||||||
case "trial_action_executed":
|
const webSocket = useGlobalWebSocket({
|
||||||
case "intervention_logged":
|
|
||||||
case "step_changed":
|
|
||||||
setWizardActions((prev) => [...prev, message].slice(-50)); // Keep last 50 actions
|
|
||||||
break;
|
|
||||||
|
|
||||||
case "step_changed":
|
|
||||||
// Handle step transitions (optional logging)
|
|
||||||
console.log("Step changed:", (message as StepChangedMessage).data);
|
|
||||||
break;
|
|
||||||
|
|
||||||
default:
|
|
||||||
// Handle other trial-specific messages
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const webSocket = useWebSocket({
|
|
||||||
trialId,
|
trialId,
|
||||||
onMessage: handleMessage,
|
onMessage: handleMessage,
|
||||||
onConnect: () => {
|
onConnect: () => {
|
||||||
if (process.env.NODE_ENV === "development") {
|
if (process.env.NODE_ENV === "development") {
|
||||||
console.log(`Connected to trial ${trialId} WebSocket`);
|
console.log(`[WS] Connected to trial ${trialId}`);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
onDisconnect: () => {
|
onDisconnect: () => {
|
||||||
if (process.env.NODE_ENV === "development") {
|
if (process.env.NODE_ENV === "development") {
|
||||||
console.log(`Disconnected from trial ${trialId} WebSocket`);
|
console.log(`[WS] Disconnected from trial ${trialId}`);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
onError: () => {
|
onError: () => {
|
||||||
// Suppress noisy WebSocket errors in development
|
|
||||||
if (process.env.NODE_ENV !== "development") {
|
if (process.env.NODE_ENV !== "development") {
|
||||||
console.error(`Trial ${trialId} WebSocket connection failed`);
|
console.error(`[WS] Trial ${trialId} WebSocket connection failed`);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
// Request trial status after connection is established
|
// Request initial data after connection is established
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (webSocket.isConnected) {
|
if (webSocket.isConnected) {
|
||||||
webSocket.sendMessage({ type: "request_trial_status", data: {} });
|
webSocket.sendMessage({ type: "request_trial_status", data: {} });
|
||||||
|
webSocket.sendMessage({
|
||||||
|
type: "request_trial_events",
|
||||||
|
data: { limit: 500 },
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}, [webSocket.isConnected, webSocket]);
|
}, [webSocket.isConnected]);
|
||||||
|
|
||||||
// Trial-specific actions
|
// Helper to add an event locally (for optimistic updates)
|
||||||
const executeTrialAction = useCallback(
|
const addLocalEvent = useCallback((event: TrialEvent) => {
|
||||||
(actionType: string, actionData: Record<string, unknown>): void => {
|
setState((prev) => ({
|
||||||
webSocket.sendMessage({
|
...prev,
|
||||||
type: "trial_action",
|
trialEvents: [...prev.trialEvents, event].slice(-500),
|
||||||
data: {
|
}));
|
||||||
actionType,
|
}, []);
|
||||||
...actionData,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
},
|
|
||||||
[webSocket],
|
|
||||||
);
|
|
||||||
|
|
||||||
const logWizardIntervention = useCallback(
|
// Helper to update trial status locally
|
||||||
(interventionData: Record<string, unknown>): void => {
|
const updateLocalStatus = useCallback((status: TrialSnapshot) => {
|
||||||
webSocket.sendMessage({
|
setState((prev) => ({
|
||||||
type: "wizard_intervention",
|
...prev,
|
||||||
data: interventionData,
|
currentTrialStatus: status,
|
||||||
});
|
}));
|
||||||
},
|
}, []);
|
||||||
[webSocket],
|
|
||||||
);
|
|
||||||
|
|
||||||
const transitionStep = useCallback(
|
|
||||||
(stepData: {
|
|
||||||
from_step?: number;
|
|
||||||
to_step: number;
|
|
||||||
step_name?: string;
|
|
||||||
[k: string]: unknown;
|
|
||||||
}): void => {
|
|
||||||
webSocket.sendMessage({
|
|
||||||
type: "step_transition",
|
|
||||||
data: stepData,
|
|
||||||
});
|
|
||||||
},
|
|
||||||
[webSocket],
|
|
||||||
);
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
...webSocket,
|
...webSocket,
|
||||||
trialEvents,
|
trialEvents: state.trialEvents,
|
||||||
currentTrialStatus,
|
currentTrialStatus: state.currentTrialStatus,
|
||||||
wizardActions,
|
wizardActions: state.wizardActions,
|
||||||
executeTrialAction,
|
addLocalEvent,
|
||||||
logWizardIntervention,
|
updateLocalStatus,
|
||||||
transitionStep,
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ import { GetObjectCommand } from "@aws-sdk/client-s3";
|
|||||||
import { getSignedUrl } from "@aws-sdk/s3-request-presigner";
|
import { getSignedUrl } from "@aws-sdk/s3-request-presigner";
|
||||||
import { env } from "~/env";
|
import { env } from "~/env";
|
||||||
import { uploadFile } from "~/lib/storage/minio";
|
import { uploadFile } from "~/lib/storage/minio";
|
||||||
|
import { wsManager } from "~/server/services/websocket-manager";
|
||||||
|
|
||||||
// Helper function to check if user has access to trial
|
// Helper function to check if user has access to trial
|
||||||
async function checkTrialAccess(
|
async function checkTrialAccess(
|
||||||
@@ -591,6 +592,16 @@ export const trialsRouter = createTRPCRouter({
|
|||||||
data: { userId },
|
data: { userId },
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Broadcast trial status update
|
||||||
|
await wsManager.broadcast(input.id, {
|
||||||
|
type: "trial_status",
|
||||||
|
data: {
|
||||||
|
trial: trial[0],
|
||||||
|
current_step_index: 0,
|
||||||
|
timestamp: Date.now(),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
return trial[0];
|
return trial[0];
|
||||||
}),
|
}),
|
||||||
|
|
||||||
@@ -643,6 +654,16 @@ export const trialsRouter = createTRPCRouter({
|
|||||||
data: { userId, notes: input.notes },
|
data: { userId, notes: input.notes },
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Broadcast trial status update
|
||||||
|
await wsManager.broadcast(input.id, {
|
||||||
|
type: "trial_status",
|
||||||
|
data: {
|
||||||
|
trial,
|
||||||
|
current_step_index: 0,
|
||||||
|
timestamp: Date.now(),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
return trial;
|
return trial;
|
||||||
}),
|
}),
|
||||||
|
|
||||||
@@ -696,6 +717,16 @@ export const trialsRouter = createTRPCRouter({
|
|||||||
data: { userId, reason: input.reason },
|
data: { userId, reason: input.reason },
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Broadcast trial status update
|
||||||
|
await wsManager.broadcast(input.id, {
|
||||||
|
type: "trial_status",
|
||||||
|
data: {
|
||||||
|
trial: trial[0],
|
||||||
|
current_step_index: 0,
|
||||||
|
timestamp: Date.now(),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
return trial[0];
|
return trial[0];
|
||||||
}),
|
}),
|
||||||
|
|
||||||
@@ -846,6 +877,15 @@ export const trialsRouter = createTRPCRouter({
|
|||||||
})
|
})
|
||||||
.returning();
|
.returning();
|
||||||
|
|
||||||
|
// Broadcast new event to all subscribers
|
||||||
|
await wsManager.broadcast(input.trialId, {
|
||||||
|
type: "trial_event",
|
||||||
|
data: {
|
||||||
|
event,
|
||||||
|
timestamp: Date.now(),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
return event;
|
return event;
|
||||||
}),
|
}),
|
||||||
|
|
||||||
@@ -881,6 +921,15 @@ export const trialsRouter = createTRPCRouter({
|
|||||||
})
|
})
|
||||||
.returning();
|
.returning();
|
||||||
|
|
||||||
|
// Broadcast intervention to all subscribers
|
||||||
|
await wsManager.broadcast(input.trialId, {
|
||||||
|
type: "intervention_logged",
|
||||||
|
data: {
|
||||||
|
intervention,
|
||||||
|
timestamp: Date.now(),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
return intervention;
|
return intervention;
|
||||||
}),
|
}),
|
||||||
|
|
||||||
@@ -936,6 +985,15 @@ export const trialsRouter = createTRPCRouter({
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Broadcast annotation to all subscribers
|
||||||
|
await wsManager.broadcast(input.trialId, {
|
||||||
|
type: "annotation_added",
|
||||||
|
data: {
|
||||||
|
annotation,
|
||||||
|
timestamp: Date.now(),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
return annotation;
|
return annotation;
|
||||||
}),
|
}),
|
||||||
|
|
||||||
@@ -1302,20 +1360,33 @@ export const trialsRouter = createTRPCRouter({
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Log the manual robot action execution
|
// Log the manual robot action execution
|
||||||
await db.insert(trialEvents).values({
|
const [event] = await db
|
||||||
trialId: input.trialId,
|
.insert(trialEvents)
|
||||||
eventType: "manual_robot_action",
|
.values({
|
||||||
actionId: null, // Ad-hoc action, not linked to a protocol action definition
|
trialId: input.trialId,
|
||||||
|
eventType: "manual_robot_action",
|
||||||
|
actionId: null,
|
||||||
|
data: {
|
||||||
|
userId,
|
||||||
|
pluginName: input.pluginName,
|
||||||
|
actionId: input.actionId,
|
||||||
|
parameters: input.parameters,
|
||||||
|
result: result.data,
|
||||||
|
duration: result.duration,
|
||||||
|
},
|
||||||
|
timestamp: new Date(),
|
||||||
|
createdBy: userId,
|
||||||
|
})
|
||||||
|
.returning();
|
||||||
|
|
||||||
|
// Broadcast robot action to all subscribers
|
||||||
|
await wsManager.broadcast(input.trialId, {
|
||||||
|
type: "trial_action_executed",
|
||||||
data: {
|
data: {
|
||||||
userId,
|
action_type: `${input.pluginName}.${input.actionId}`,
|
||||||
pluginName: input.pluginName,
|
event,
|
||||||
actionId: input.actionId,
|
timestamp: Date.now(),
|
||||||
parameters: input.parameters,
|
|
||||||
result: result.data,
|
|
||||||
duration: result.duration,
|
|
||||||
},
|
},
|
||||||
timestamp: new Date(),
|
|
||||||
createdBy: userId,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -1347,21 +1418,34 @@ export const trialsRouter = createTRPCRouter({
|
|||||||
"wizard",
|
"wizard",
|
||||||
]);
|
]);
|
||||||
|
|
||||||
await db.insert(trialEvents).values({
|
const [event] = await db
|
||||||
trialId: input.trialId,
|
.insert(trialEvents)
|
||||||
eventType: "manual_robot_action",
|
.values({
|
||||||
|
trialId: input.trialId,
|
||||||
|
eventType: "manual_robot_action",
|
||||||
|
data: {
|
||||||
|
userId,
|
||||||
|
pluginName: input.pluginName,
|
||||||
|
actionId: input.actionId,
|
||||||
|
parameters: input.parameters,
|
||||||
|
result: input.result,
|
||||||
|
duration: input.duration,
|
||||||
|
error: input.error,
|
||||||
|
executionMode: "websocket_client",
|
||||||
|
},
|
||||||
|
timestamp: new Date(),
|
||||||
|
createdBy: userId,
|
||||||
|
})
|
||||||
|
.returning();
|
||||||
|
|
||||||
|
// Broadcast robot action to all subscribers
|
||||||
|
await wsManager.broadcast(input.trialId, {
|
||||||
|
type: "trial_action_executed",
|
||||||
data: {
|
data: {
|
||||||
userId,
|
action_type: `${input.pluginName}.${input.actionId}`,
|
||||||
pluginName: input.pluginName,
|
event,
|
||||||
actionId: input.actionId,
|
timestamp: Date.now(),
|
||||||
parameters: input.parameters,
|
|
||||||
result: input.result,
|
|
||||||
duration: input.duration,
|
|
||||||
error: input.error,
|
|
||||||
executionMode: "websocket_client",
|
|
||||||
},
|
},
|
||||||
timestamp: new Date(),
|
|
||||||
createdBy: userId,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
return { success: true };
|
return { success: true };
|
||||||
|
|||||||
@@ -485,6 +485,25 @@ export const trials = createTable("trial", {
|
|||||||
metadata: jsonb("metadata").default({}),
|
metadata: jsonb("metadata").default({}),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
export const wsConnections = createTable("ws_connection", {
|
||||||
|
id: uuid("id").notNull().primaryKey().defaultRandom(),
|
||||||
|
trialId: uuid("trial_id")
|
||||||
|
.notNull()
|
||||||
|
.references(() => trials.id, { onDelete: "cascade" }),
|
||||||
|
clientId: text("client_id").notNull().unique(),
|
||||||
|
userId: text("user_id"),
|
||||||
|
connectedAt: timestamp("connected_at", { withTimezone: true })
|
||||||
|
.default(sql`CURRENT_TIMESTAMP`)
|
||||||
|
.notNull(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const wsConnectionsRelations = relations(wsConnections, ({ one }) => ({
|
||||||
|
trial: one(trials, {
|
||||||
|
fields: [wsConnections.trialId],
|
||||||
|
references: [trials.id],
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
export const steps = createTable(
|
export const steps = createTable(
|
||||||
"step",
|
"step",
|
||||||
{
|
{
|
||||||
|
|||||||
272
src/server/services/websocket-manager.ts
Normal file
272
src/server/services/websocket-manager.ts
Normal file
@@ -0,0 +1,272 @@
|
|||||||
|
import { db } from "~/server/db";
|
||||||
|
import {
|
||||||
|
trials,
|
||||||
|
trialEvents,
|
||||||
|
wsConnections,
|
||||||
|
experiments,
|
||||||
|
} from "~/server/db/schema";
|
||||||
|
import { eq, sql } from "drizzle-orm";
|
||||||
|
|
||||||
|
interface ClientConnection {
|
||||||
|
socket: WebSocket;
|
||||||
|
trialId: string;
|
||||||
|
userId: string | null;
|
||||||
|
connectedAt: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
type OutgoingMessage = {
|
||||||
|
type: string;
|
||||||
|
data: Record<string, unknown>;
|
||||||
|
};
|
||||||
|
|
||||||
|
class WebSocketManager {
|
||||||
|
private clients: Map<string, ClientConnection> = new Map();
|
||||||
|
private heartbeatIntervals: Map<string, ReturnType<typeof setInterval>> =
|
||||||
|
new Map();
|
||||||
|
|
||||||
|
private getTrialRoomClients(trialId: string): ClientConnection[] {
|
||||||
|
const clients: ClientConnection[] = [];
|
||||||
|
for (const [, client] of this.clients) {
|
||||||
|
if (client.trialId === trialId) {
|
||||||
|
clients.push(client);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return clients;
|
||||||
|
}
|
||||||
|
|
||||||
|
addClient(clientId: string, connection: ClientConnection): void {
|
||||||
|
this.clients.set(clientId, connection);
|
||||||
|
console.log(
|
||||||
|
`[WS] Client ${clientId} added for trial ${connection.trialId}. Total: ${this.clients.size}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
removeClient(clientId: string): void {
|
||||||
|
const client = this.clients.get(clientId);
|
||||||
|
if (client) {
|
||||||
|
console.log(
|
||||||
|
`[WS] Client ${clientId} removed from trial ${client.trialId}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const heartbeatInterval = this.heartbeatIntervals.get(clientId);
|
||||||
|
if (heartbeatInterval) {
|
||||||
|
clearInterval(heartbeatInterval);
|
||||||
|
this.heartbeatIntervals.delete(clientId);
|
||||||
|
}
|
||||||
|
|
||||||
|
this.clients.delete(clientId);
|
||||||
|
}
|
||||||
|
|
||||||
|
async subscribe(
|
||||||
|
clientId: string,
|
||||||
|
socket: WebSocket,
|
||||||
|
trialId: string,
|
||||||
|
userId: string | null,
|
||||||
|
): Promise<void> {
|
||||||
|
const client: ClientConnection = {
|
||||||
|
socket,
|
||||||
|
trialId,
|
||||||
|
userId,
|
||||||
|
connectedAt: Date.now(),
|
||||||
|
};
|
||||||
|
|
||||||
|
this.clients.set(clientId, client);
|
||||||
|
|
||||||
|
const heartbeatInterval = setInterval(() => {
|
||||||
|
this.sendToClient(clientId, { type: "heartbeat", data: {} });
|
||||||
|
}, 30000);
|
||||||
|
|
||||||
|
this.heartbeatIntervals.set(clientId, heartbeatInterval);
|
||||||
|
|
||||||
|
console.log(
|
||||||
|
`[WS] Client ${clientId} subscribed to trial ${trialId}. Total clients: ${this.clients.size}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
unsubscribe(clientId: string): void {
|
||||||
|
const client = this.clients.get(clientId);
|
||||||
|
if (client) {
|
||||||
|
console.log(
|
||||||
|
`[WS] Client ${clientId} unsubscribed from trial ${client.trialId}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const heartbeatInterval = this.heartbeatIntervals.get(clientId);
|
||||||
|
if (heartbeatInterval) {
|
||||||
|
clearInterval(heartbeatInterval);
|
||||||
|
this.heartbeatIntervals.delete(clientId);
|
||||||
|
}
|
||||||
|
|
||||||
|
this.clients.delete(clientId);
|
||||||
|
}
|
||||||
|
|
||||||
|
sendToClient(clientId: string, message: OutgoingMessage): void {
|
||||||
|
const client = this.clients.get(clientId);
|
||||||
|
if (client?.socket.readyState === 1) {
|
||||||
|
try {
|
||||||
|
client.socket.send(JSON.stringify(message));
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`[WS] Error sending to client ${clientId}:`, error);
|
||||||
|
this.unsubscribe(clientId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async broadcast(trialId: string, message: OutgoingMessage): Promise<void> {
|
||||||
|
const clients = this.getTrialRoomClients(trialId);
|
||||||
|
|
||||||
|
if (clients.length === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const messageStr = JSON.stringify(message);
|
||||||
|
const disconnectedClients: string[] = [];
|
||||||
|
|
||||||
|
for (const [clientId, client] of this.clients) {
|
||||||
|
if (client.trialId === trialId && client.socket.readyState === 1) {
|
||||||
|
try {
|
||||||
|
client.socket.send(messageStr);
|
||||||
|
} catch (error) {
|
||||||
|
console.error(
|
||||||
|
`[WS] Error broadcasting to client ${clientId}:`,
|
||||||
|
error,
|
||||||
|
);
|
||||||
|
disconnectedClients.push(clientId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const clientId of disconnectedClients) {
|
||||||
|
this.unsubscribe(clientId);
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log(
|
||||||
|
`[WS] Broadcast to ${clients.length} clients for trial ${trialId}: ${message.type}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
async broadcastToAll(message: OutgoingMessage): Promise<void> {
|
||||||
|
const messageStr = JSON.stringify(message);
|
||||||
|
const disconnectedClients: string[] = [];
|
||||||
|
|
||||||
|
for (const [clientId, client] of this.clients) {
|
||||||
|
if (client.socket.readyState === 1) {
|
||||||
|
try {
|
||||||
|
client.socket.send(messageStr);
|
||||||
|
} catch (error) {
|
||||||
|
console.error(
|
||||||
|
`[WS] Error broadcasting to client ${clientId}:`,
|
||||||
|
error,
|
||||||
|
);
|
||||||
|
disconnectedClients.push(clientId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const clientId of disconnectedClients) {
|
||||||
|
this.unsubscribe(clientId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async getTrialStatus(trialId: string): Promise<{
|
||||||
|
trial: {
|
||||||
|
id: string;
|
||||||
|
status: string;
|
||||||
|
startedAt: Date | null;
|
||||||
|
completedAt: Date | null;
|
||||||
|
};
|
||||||
|
currentStepIndex: number;
|
||||||
|
} | null> {
|
||||||
|
const [trial] = await db
|
||||||
|
.select({
|
||||||
|
id: trials.id,
|
||||||
|
status: trials.status,
|
||||||
|
startedAt: trials.startedAt,
|
||||||
|
completedAt: trials.completedAt,
|
||||||
|
})
|
||||||
|
.from(trials)
|
||||||
|
.where(eq(trials.id, trialId))
|
||||||
|
.limit(1);
|
||||||
|
|
||||||
|
if (!trial) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
trial: {
|
||||||
|
id: trial.id,
|
||||||
|
status: trial.status,
|
||||||
|
startedAt: trial.startedAt,
|
||||||
|
completedAt: trial.completedAt,
|
||||||
|
},
|
||||||
|
currentStepIndex: 0,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
async getTrialEvents(
|
||||||
|
trialId: string,
|
||||||
|
limit: number = 100,
|
||||||
|
): Promise<unknown[]> {
|
||||||
|
const events = await db
|
||||||
|
.select()
|
||||||
|
.from(trialEvents)
|
||||||
|
.where(eq(trialEvents.trialId, trialId))
|
||||||
|
.orderBy(trialEvents.timestamp)
|
||||||
|
.limit(limit);
|
||||||
|
|
||||||
|
return events;
|
||||||
|
}
|
||||||
|
|
||||||
|
getTrialStatusSync(trialId: string): {
|
||||||
|
trial: {
|
||||||
|
id: string;
|
||||||
|
status: string;
|
||||||
|
startedAt: Date | null;
|
||||||
|
completedAt: Date | null;
|
||||||
|
};
|
||||||
|
currentStepIndex: number;
|
||||||
|
} | null {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
getTrialEventsSync(trialId: string, limit: number = 100): unknown[] {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
getConnectionCount(trialId?: string): number {
|
||||||
|
if (trialId) {
|
||||||
|
return this.getTrialRoomClients(trialId).length;
|
||||||
|
}
|
||||||
|
return this.clients.size;
|
||||||
|
}
|
||||||
|
|
||||||
|
getConnectedTrialIds(): string[] {
|
||||||
|
const trialIds = new Set<string>();
|
||||||
|
for (const [, client] of this.clients) {
|
||||||
|
trialIds.add(client.trialId);
|
||||||
|
}
|
||||||
|
return Array.from(trialIds);
|
||||||
|
}
|
||||||
|
|
||||||
|
async getTrialsWithActiveConnections(studyIds?: string[]): Promise<string[]> {
|
||||||
|
const conditions =
|
||||||
|
studyIds && studyIds.length > 0
|
||||||
|
? sql`${wsConnections.trialId} IN (
|
||||||
|
SELECT ${trials.id} FROM ${trials}
|
||||||
|
WHERE ${trials.experimentId} IN (
|
||||||
|
SELECT ${experiments.id} FROM ${experiments}
|
||||||
|
WHERE ${experiments.studyId} IN (${sql.raw(studyIds.map((id) => `'${id}'`).join(","))})
|
||||||
|
)
|
||||||
|
)`
|
||||||
|
: undefined;
|
||||||
|
|
||||||
|
const connections = await db
|
||||||
|
.selectDistinct({ trialId: wsConnections.trialId })
|
||||||
|
.from(wsConnections);
|
||||||
|
|
||||||
|
return connections.map((c) => c.trialId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export const wsManager = new WebSocketManager();
|
||||||
192
ws-server.ts
Normal file
192
ws-server.ts
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
import { serve, type ServerWebSocket } from "bun";
|
||||||
|
import { wsManager } from "./src/server/services/websocket-manager";
|
||||||
|
import { db } from "./src/server/db";
|
||||||
|
import { wsConnections } from "./src/server/db/schema";
|
||||||
|
import { eq } from "drizzle-orm";
|
||||||
|
|
||||||
|
const port = parseInt(process.env.WS_PORT || "3001", 10);
|
||||||
|
|
||||||
|
interface WSData {
|
||||||
|
clientId: string;
|
||||||
|
trialId: string;
|
||||||
|
userId: string | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
function generateClientId(): string {
|
||||||
|
return `ws_${Date.now()}_${Math.random().toString(36).substring(2, 9)}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function recordConnection(
|
||||||
|
clientId: string,
|
||||||
|
trialId: string,
|
||||||
|
userId: string | null,
|
||||||
|
): Promise<void> {
|
||||||
|
try {
|
||||||
|
await db.insert(wsConnections).values({
|
||||||
|
clientId,
|
||||||
|
trialId,
|
||||||
|
userId,
|
||||||
|
});
|
||||||
|
console.log(`[DB] Recorded connection for trial ${trialId}`);
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`[DB] Failed to record connection:`, error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function removeConnection(clientId: string): Promise<void> {
|
||||||
|
try {
|
||||||
|
await db.delete(wsConnections).where(eq(wsConnections.clientId, clientId));
|
||||||
|
console.log(`[DB] Removed connection ${clientId}`);
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`[DB] Failed to remove connection:`, error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log(`Starting WebSocket server on port ${port}...`);
|
||||||
|
|
||||||
|
serve<WSData>({
|
||||||
|
port,
|
||||||
|
fetch(req, server) {
|
||||||
|
const url = new URL(req.url);
|
||||||
|
|
||||||
|
if (url.pathname === "/api/websocket") {
|
||||||
|
if (req.headers.get("upgrade") !== "websocket") {
|
||||||
|
return new Response("WebSocket upgrade required", { status: 426 });
|
||||||
|
}
|
||||||
|
|
||||||
|
const trialId = url.searchParams.get("trialId");
|
||||||
|
const token = url.searchParams.get("token");
|
||||||
|
|
||||||
|
if (!trialId) {
|
||||||
|
return new Response("Missing trialId parameter", { status: 400 });
|
||||||
|
}
|
||||||
|
|
||||||
|
let userId: string | null = null;
|
||||||
|
if (token) {
|
||||||
|
try {
|
||||||
|
const tokenData = JSON.parse(atob(token));
|
||||||
|
userId = tokenData.userId;
|
||||||
|
} catch {
|
||||||
|
return new Response("Invalid token", { status: 401 });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const clientId = generateClientId();
|
||||||
|
const wsData: WSData = { clientId, trialId, userId };
|
||||||
|
|
||||||
|
const upgraded = server.upgrade(req, { data: wsData });
|
||||||
|
|
||||||
|
if (!upgraded) {
|
||||||
|
return new Response("WebSocket upgrade failed", { status: 500 });
|
||||||
|
}
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
return new Response("Not found", { status: 404 });
|
||||||
|
},
|
||||||
|
websocket: {
|
||||||
|
async open(ws: ServerWebSocket<WSData>) {
|
||||||
|
const { clientId, trialId, userId } = ws.data;
|
||||||
|
|
||||||
|
wsManager.addClient(clientId, {
|
||||||
|
socket: ws as unknown as WebSocket,
|
||||||
|
trialId,
|
||||||
|
userId,
|
||||||
|
connectedAt: Date.now(),
|
||||||
|
});
|
||||||
|
|
||||||
|
await recordConnection(clientId, trialId, userId);
|
||||||
|
|
||||||
|
console.log(
|
||||||
|
`[WS] Client ${clientId} connected to trial ${trialId}. Total: ${wsManager.getConnectionCount()}`,
|
||||||
|
);
|
||||||
|
|
||||||
|
ws.send(
|
||||||
|
JSON.stringify({
|
||||||
|
type: "connection_established",
|
||||||
|
data: {
|
||||||
|
trialId,
|
||||||
|
userId,
|
||||||
|
role: "connected",
|
||||||
|
connectedAt: Date.now(),
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
},
|
||||||
|
message(ws: ServerWebSocket<WSData>, message) {
|
||||||
|
const { clientId, trialId } = ws.data;
|
||||||
|
|
||||||
|
try {
|
||||||
|
const msg = JSON.parse(message.toString());
|
||||||
|
|
||||||
|
switch (msg.type) {
|
||||||
|
case "heartbeat":
|
||||||
|
ws.send(
|
||||||
|
JSON.stringify({
|
||||||
|
type: "heartbeat_response",
|
||||||
|
data: { timestamp: Date.now() },
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "request_trial_status": {
|
||||||
|
const status = wsManager.getTrialStatusSync(trialId);
|
||||||
|
ws.send(
|
||||||
|
JSON.stringify({
|
||||||
|
type: "trial_status",
|
||||||
|
data: {
|
||||||
|
trial: status?.trial ?? null,
|
||||||
|
current_step_index: status?.currentStepIndex ?? 0,
|
||||||
|
timestamp: Date.now(),
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case "request_trial_events": {
|
||||||
|
const events = wsManager.getTrialEventsSync(
|
||||||
|
trialId,
|
||||||
|
msg.data?.limit ?? 100,
|
||||||
|
);
|
||||||
|
ws.send(
|
||||||
|
JSON.stringify({
|
||||||
|
type: "trial_events_snapshot",
|
||||||
|
data: { events, timestamp: Date.now() },
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case "ping":
|
||||||
|
ws.send(
|
||||||
|
JSON.stringify({
|
||||||
|
type: "pong",
|
||||||
|
data: { timestamp: Date.now() },
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
console.log(
|
||||||
|
`[WS] Unknown message type from ${clientId}:`,
|
||||||
|
msg.type,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`[WS] Error processing message from ${clientId}:`, error);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
async close(ws: ServerWebSocket<WSData>) {
|
||||||
|
const { clientId } = ws.data;
|
||||||
|
console.log(`[WS] Client ${clientId} disconnected`);
|
||||||
|
wsManager.removeClient(clientId);
|
||||||
|
await removeConnection(clientId);
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
console.log(
|
||||||
|
`> WebSocket server running on ws://localhost:${port}/api/websocket`,
|
||||||
|
);
|
||||||
Reference in New Issue
Block a user