feat: implement WebSocket for real-time trial updates

- Create standalone WebSocket server (ws-server.ts) on port 3001 using Bun
- Add ws_connections table to track active connections in database
- Create global WebSocket manager that persists across component unmounts
- Fix useWebSocket hook to prevent infinite re-renders and use refs
- Fix TrialForm Select components with proper default values
- Add trialId to WebSocket URL for server-side tracking
- Update package.json with dev:ws script for separate WS server
This commit is contained in:
2026-03-22 00:48:43 -04:00
parent 20d6d3de1a
commit a5762ec935
9 changed files with 1257 additions and 481 deletions
+109 -25
View File
@@ -35,6 +35,7 @@ import { GetObjectCommand } from "@aws-sdk/client-s3";
import { getSignedUrl } from "@aws-sdk/s3-request-presigner";
import { env } from "~/env";
import { uploadFile } from "~/lib/storage/minio";
import { wsManager } from "~/server/services/websocket-manager";
// Helper function to check if user has access to trial
async function checkTrialAccess(
@@ -591,6 +592,16 @@ export const trialsRouter = createTRPCRouter({
data: { userId },
});
// Broadcast trial status update
await wsManager.broadcast(input.id, {
type: "trial_status",
data: {
trial: trial[0],
current_step_index: 0,
timestamp: Date.now(),
},
});
return trial[0];
}),
@@ -643,6 +654,16 @@ export const trialsRouter = createTRPCRouter({
data: { userId, notes: input.notes },
});
// Broadcast trial status update
await wsManager.broadcast(input.id, {
type: "trial_status",
data: {
trial,
current_step_index: 0,
timestamp: Date.now(),
},
});
return trial;
}),
@@ -696,6 +717,16 @@ export const trialsRouter = createTRPCRouter({
data: { userId, reason: input.reason },
});
// Broadcast trial status update
await wsManager.broadcast(input.id, {
type: "trial_status",
data: {
trial: trial[0],
current_step_index: 0,
timestamp: Date.now(),
},
});
return trial[0];
}),
@@ -846,6 +877,15 @@ export const trialsRouter = createTRPCRouter({
})
.returning();
// Broadcast new event to all subscribers
await wsManager.broadcast(input.trialId, {
type: "trial_event",
data: {
event,
timestamp: Date.now(),
},
});
return event;
}),
@@ -881,6 +921,15 @@ export const trialsRouter = createTRPCRouter({
})
.returning();
// Broadcast intervention to all subscribers
await wsManager.broadcast(input.trialId, {
type: "intervention_logged",
data: {
intervention,
timestamp: Date.now(),
},
});
return intervention;
}),
@@ -936,6 +985,15 @@ export const trialsRouter = createTRPCRouter({
});
}
// Broadcast annotation to all subscribers
await wsManager.broadcast(input.trialId, {
type: "annotation_added",
data: {
annotation,
timestamp: Date.now(),
},
});
return annotation;
}),
@@ -1302,20 +1360,33 @@ export const trialsRouter = createTRPCRouter({
}
// Log the manual robot action execution
await db.insert(trialEvents).values({
trialId: input.trialId,
eventType: "manual_robot_action",
actionId: null, // Ad-hoc action, not linked to a protocol action definition
const [event] = await db
.insert(trialEvents)
.values({
trialId: input.trialId,
eventType: "manual_robot_action",
actionId: null,
data: {
userId,
pluginName: input.pluginName,
actionId: input.actionId,
parameters: input.parameters,
result: result.data,
duration: result.duration,
},
timestamp: new Date(),
createdBy: userId,
})
.returning();
// Broadcast robot action to all subscribers
await wsManager.broadcast(input.trialId, {
type: "trial_action_executed",
data: {
userId,
pluginName: input.pluginName,
actionId: input.actionId,
parameters: input.parameters,
result: result.data,
duration: result.duration,
action_type: `${input.pluginName}.${input.actionId}`,
event,
timestamp: Date.now(),
},
timestamp: new Date(),
createdBy: userId,
});
return {
@@ -1347,21 +1418,34 @@ export const trialsRouter = createTRPCRouter({
"wizard",
]);
await db.insert(trialEvents).values({
trialId: input.trialId,
eventType: "manual_robot_action",
const [event] = await db
.insert(trialEvents)
.values({
trialId: input.trialId,
eventType: "manual_robot_action",
data: {
userId,
pluginName: input.pluginName,
actionId: input.actionId,
parameters: input.parameters,
result: input.result,
duration: input.duration,
error: input.error,
executionMode: "websocket_client",
},
timestamp: new Date(),
createdBy: userId,
})
.returning();
// Broadcast robot action to all subscribers
await wsManager.broadcast(input.trialId, {
type: "trial_action_executed",
data: {
userId,
pluginName: input.pluginName,
actionId: input.actionId,
parameters: input.parameters,
result: input.result,
duration: input.duration,
error: input.error,
executionMode: "websocket_client",
action_type: `${input.pluginName}.${input.actionId}`,
event,
timestamp: Date.now(),
},
timestamp: new Date(),
createdBy: userId,
});
return { success: true };
+19
View File
@@ -485,6 +485,25 @@ export const trials = createTable("trial", {
metadata: jsonb("metadata").default({}),
});
export const wsConnections = createTable("ws_connection", {
id: uuid("id").notNull().primaryKey().defaultRandom(),
trialId: uuid("trial_id")
.notNull()
.references(() => trials.id, { onDelete: "cascade" }),
clientId: text("client_id").notNull().unique(),
userId: text("user_id"),
connectedAt: timestamp("connected_at", { withTimezone: true })
.default(sql`CURRENT_TIMESTAMP`)
.notNull(),
});
export const wsConnectionsRelations = relations(wsConnections, ({ one }) => ({
trial: one(trials, {
fields: [wsConnections.trialId],
references: [trials.id],
}),
}));
export const steps = createTable(
"step",
{
+272
View File
@@ -0,0 +1,272 @@
import { db } from "~/server/db";
import {
trials,
trialEvents,
wsConnections,
experiments,
} from "~/server/db/schema";
import { eq, sql } from "drizzle-orm";
interface ClientConnection {
socket: WebSocket;
trialId: string;
userId: string | null;
connectedAt: number;
}
type OutgoingMessage = {
type: string;
data: Record<string, unknown>;
};
class WebSocketManager {
private clients: Map<string, ClientConnection> = new Map();
private heartbeatIntervals: Map<string, ReturnType<typeof setInterval>> =
new Map();
private getTrialRoomClients(trialId: string): ClientConnection[] {
const clients: ClientConnection[] = [];
for (const [, client] of this.clients) {
if (client.trialId === trialId) {
clients.push(client);
}
}
return clients;
}
addClient(clientId: string, connection: ClientConnection): void {
this.clients.set(clientId, connection);
console.log(
`[WS] Client ${clientId} added for trial ${connection.trialId}. Total: ${this.clients.size}`,
);
}
removeClient(clientId: string): void {
const client = this.clients.get(clientId);
if (client) {
console.log(
`[WS] Client ${clientId} removed from trial ${client.trialId}`,
);
}
const heartbeatInterval = this.heartbeatIntervals.get(clientId);
if (heartbeatInterval) {
clearInterval(heartbeatInterval);
this.heartbeatIntervals.delete(clientId);
}
this.clients.delete(clientId);
}
async subscribe(
clientId: string,
socket: WebSocket,
trialId: string,
userId: string | null,
): Promise<void> {
const client: ClientConnection = {
socket,
trialId,
userId,
connectedAt: Date.now(),
};
this.clients.set(clientId, client);
const heartbeatInterval = setInterval(() => {
this.sendToClient(clientId, { type: "heartbeat", data: {} });
}, 30000);
this.heartbeatIntervals.set(clientId, heartbeatInterval);
console.log(
`[WS] Client ${clientId} subscribed to trial ${trialId}. Total clients: ${this.clients.size}`,
);
}
unsubscribe(clientId: string): void {
const client = this.clients.get(clientId);
if (client) {
console.log(
`[WS] Client ${clientId} unsubscribed from trial ${client.trialId}`,
);
}
const heartbeatInterval = this.heartbeatIntervals.get(clientId);
if (heartbeatInterval) {
clearInterval(heartbeatInterval);
this.heartbeatIntervals.delete(clientId);
}
this.clients.delete(clientId);
}
sendToClient(clientId: string, message: OutgoingMessage): void {
const client = this.clients.get(clientId);
if (client?.socket.readyState === 1) {
try {
client.socket.send(JSON.stringify(message));
} catch (error) {
console.error(`[WS] Error sending to client ${clientId}:`, error);
this.unsubscribe(clientId);
}
}
}
async broadcast(trialId: string, message: OutgoingMessage): Promise<void> {
const clients = this.getTrialRoomClients(trialId);
if (clients.length === 0) {
return;
}
const messageStr = JSON.stringify(message);
const disconnectedClients: string[] = [];
for (const [clientId, client] of this.clients) {
if (client.trialId === trialId && client.socket.readyState === 1) {
try {
client.socket.send(messageStr);
} catch (error) {
console.error(
`[WS] Error broadcasting to client ${clientId}:`,
error,
);
disconnectedClients.push(clientId);
}
}
}
for (const clientId of disconnectedClients) {
this.unsubscribe(clientId);
}
console.log(
`[WS] Broadcast to ${clients.length} clients for trial ${trialId}: ${message.type}`,
);
}
async broadcastToAll(message: OutgoingMessage): Promise<void> {
const messageStr = JSON.stringify(message);
const disconnectedClients: string[] = [];
for (const [clientId, client] of this.clients) {
if (client.socket.readyState === 1) {
try {
client.socket.send(messageStr);
} catch (error) {
console.error(
`[WS] Error broadcasting to client ${clientId}:`,
error,
);
disconnectedClients.push(clientId);
}
}
}
for (const clientId of disconnectedClients) {
this.unsubscribe(clientId);
}
}
async getTrialStatus(trialId: string): Promise<{
trial: {
id: string;
status: string;
startedAt: Date | null;
completedAt: Date | null;
};
currentStepIndex: number;
} | null> {
const [trial] = await db
.select({
id: trials.id,
status: trials.status,
startedAt: trials.startedAt,
completedAt: trials.completedAt,
})
.from(trials)
.where(eq(trials.id, trialId))
.limit(1);
if (!trial) {
return null;
}
return {
trial: {
id: trial.id,
status: trial.status,
startedAt: trial.startedAt,
completedAt: trial.completedAt,
},
currentStepIndex: 0,
};
}
async getTrialEvents(
trialId: string,
limit: number = 100,
): Promise<unknown[]> {
const events = await db
.select()
.from(trialEvents)
.where(eq(trialEvents.trialId, trialId))
.orderBy(trialEvents.timestamp)
.limit(limit);
return events;
}
getTrialStatusSync(trialId: string): {
trial: {
id: string;
status: string;
startedAt: Date | null;
completedAt: Date | null;
};
currentStepIndex: number;
} | null {
return null;
}
getTrialEventsSync(trialId: string, limit: number = 100): unknown[] {
return [];
}
getConnectionCount(trialId?: string): number {
if (trialId) {
return this.getTrialRoomClients(trialId).length;
}
return this.clients.size;
}
getConnectedTrialIds(): string[] {
const trialIds = new Set<string>();
for (const [, client] of this.clients) {
trialIds.add(client.trialId);
}
return Array.from(trialIds);
}
async getTrialsWithActiveConnections(studyIds?: string[]): Promise<string[]> {
const conditions =
studyIds && studyIds.length > 0
? sql`${wsConnections.trialId} IN (
SELECT ${trials.id} FROM ${trials}
WHERE ${trials.experimentId} IN (
SELECT ${experiments.id} FROM ${experiments}
WHERE ${experiments.studyId} IN (${sql.raw(studyIds.map((id) => `'${id}'`).join(","))})
)
)`
: undefined;
const connections = await db
.selectDistinct({ trialId: wsConnections.trialId })
.from(wsConnections);
return connections.map((c) => c.trialId);
}
}
export const wsManager = new WebSocketManager();