diff --git a/.changeset/connection-tags-property.md b/.changeset/connection-tags-property.md new file mode 100644 index 0000000..cfac511 --- /dev/null +++ b/.changeset/connection-tags-property.md @@ -0,0 +1,5 @@ +--- +"partyserver": minor +--- + +Add `connection.tags` property to read back tags assigned via `getConnectionTags()`. Works in both hibernating and in-memory modes. Tags are validated and always include the connection id as the first tag. diff --git a/packages/partyserver/src/connection.ts b/packages/partyserver/src/connection.ts index 713f2f3..f3d3648 100644 --- a/packages/partyserver/src/connection.ts +++ b/packages/partyserver/src/connection.ts @@ -34,6 +34,7 @@ type ConnectionAttachments = { // TODO: remove this once we have // durable object level setState server: string; + tags: string[]; }; __user?: unknown; }; @@ -57,11 +58,20 @@ function tryGetPartyServerMeta( if (!pk || typeof pk !== "object") { return null; } - const { id, server } = pk as { id?: unknown; server?: unknown }; + const { id, server, tags } = pk as { + id?: unknown; + server?: unknown; + tags?: unknown; + }; if (typeof id !== "string" || typeof server !== "string") { return null; } - return pk as ConnectionAttachments["__pk"]; + // Default tags to [] for connections created before tags were stored + return { + id, + server, + tags: Array.isArray(tags) ? tags : [] + } as ConnectionAttachments["__pk"]; } catch { return null; } @@ -138,6 +148,12 @@ export const createLazyConnection = ( return attachments.get(ws).__pk.server; } }, + tags: { + get() { + // Default to [] for connections accepted before tags were stored + return attachments.get(ws).__pk.tags ?? []; + } + }, socket: { get() { return ws; @@ -233,6 +249,36 @@ class HibernatingConnectionIterator implements IterableIterator< } } +/** + * Deduplicate and validate connection tags. + * Returns the final tag array (always includes the connection id as the first tag). + */ +function prepareTags(connectionId: string, userTags: string[]): string[] { + const tags = [connectionId, ...userTags.filter((t) => t !== connectionId)]; + + // validate tags against documented restrictions + // https://developers.cloudflare.com/durable-objects/api/hibernatable-websockets-api/#state-methods-for-websockets + if (tags.length > 10) { + throw new Error( + "A connection can only have 10 tags, including the default id tag." + ); + } + + for (const tag of tags) { + if (typeof tag !== "string") { + throw new Error(`A connection tag must be a string. Received: ${tag}`); + } + if (tag === "") { + throw new Error("A connection tag must not be an empty string."); + } + if (tag.length > 256) { + throw new Error("A connection tag must not exceed 256 characters"); + } + } + + return tags; +} + export interface ConnectionManager { getCount(): number; getConnection(id: string): Connection | undefined; @@ -280,12 +326,16 @@ export class InMemoryConnectionManager implements ConnectionManager { accept(connection: Connection, options: { tags: string[]; server: string }) { connection.accept(); + const tags = prepareTags(connection.id, options.tags); + this.#connections.set(connection.id, connection); - this.tags.set(connection, [ - // make sure we have id tag - connection.id, - ...options.tags.filter((t) => t !== connection.id) - ]); + this.tags.set(connection, tags); + + // Expose tags on the connection object itself + Object.defineProperty(connection, "tags", { + get: () => tags, + configurable: true + }); const removeConnection = () => { this.#connections.delete(connection.id); @@ -336,37 +386,14 @@ export class HibernatingConnectionManager implements ConnectionManager { } accept(connection: Connection, options: { tags: string[]; server: string }) { - // dedupe tags in case user already provided id tag - const tags = [ - connection.id, - ...options.tags.filter((t) => t !== connection.id) - ]; - - // validate tags against documented restrictions - // shttps://developers.cloudflare.com/durable-objects/api/hibernatable-websockets-api/#state-methods-for-websockets - if (tags.length > 10) { - throw new Error( - "A connection can only have 10 tags, including the default id tag." - ); - } - - for (const tag of tags) { - if (typeof tag !== "string") { - throw new Error(`A connection tag must be a string. Received: ${tag}`); - } - if (tag === "") { - throw new Error("A connection tag must not be an empty string."); - } - if (tag.length > 256) { - throw new Error("A connection tag must not exceed 256 characters"); - } - } + const tags = prepareTags(connection.id, options.tags); this.controller.acceptWebSocket(connection, tags); connection.serializeAttachment({ __pk: { id: connection.id, - server: options.server + server: options.server, + tags }, __user: null }); diff --git a/packages/partyserver/src/index.ts b/packages/partyserver/src/index.ts index 17e7411..aa3ea32 100644 --- a/packages/partyserver/src/index.ts +++ b/packages/partyserver/src/index.ts @@ -431,6 +431,7 @@ Did you try connecting directly to this Durable Object? Try using getServerByNam let connection: Connection = Object.assign(serverWebSocket, { id: connectionId, server: this.name, + tags: [] as string[], state: null as unknown as ConnectionState, setState(setState: T | ConnectionSetStateFn) { let state: T; diff --git a/packages/partyserver/src/tests/index.test.ts b/packages/partyserver/src/tests/index.test.ts index fd11183..141e041 100644 --- a/packages/partyserver/src/tests/index.test.ts +++ b/packages/partyserver/src/tests/index.test.ts @@ -658,3 +658,105 @@ describe("CORS", () => { ); }); }); + +describe("Connection tags", () => { + it("exposes tags on a hibernating connection", async () => { + const ctx = createExecutionContext(); + const request = new Request( + "http://example.com/parties/tags-server/room1", + { + headers: { Upgrade: "websocket" } + } + ); + const response = await worker.fetch(request, env, ctx); + const ws = response.webSocket!; + ws.accept(); + + const { promise, resolve, reject } = Promise.withResolvers(); + ws.addEventListener("message", (message) => { + try { + const tags = JSON.parse(message.data as string) as string[]; + // Should include the auto-prepended connection id plus the custom tags + expect(tags).toHaveLength(3); + expect(tags[0]).toBeTypeOf("string"); // connection id + expect(tags).toContain("role:admin"); + expect(tags).toContain("room:lobby"); + resolve(); + } catch (e) { + reject(e); + } finally { + ws.close(); + } + }); + + return promise; + }); + + it("exposes tags on a hibernating connection after wake-up", async () => { + const ctx = createExecutionContext(); + const request = new Request( + "http://example.com/parties/tags-server/room2", + { + headers: { Upgrade: "websocket" } + } + ); + const response = await worker.fetch(request, env, ctx); + const ws = response.webSocket!; + ws.accept(); + + // Wait for the onConnect message + const connectMessage = await new Promise((resolve) => { + ws.addEventListener("message", (e) => resolve(e.data as string), { + once: true + }); + }); + const connectTags = JSON.parse(connectMessage) as string[]; + expect(connectTags).toContain("role:admin"); + + // Send a message to trigger onMessage, which reads tags again + ws.send("ping"); + const wakeMessage = await new Promise((resolve) => { + ws.addEventListener("message", (e) => resolve(e.data as string), { + once: true + }); + }); + const wakeTags = JSON.parse(wakeMessage) as string[]; + expect(wakeTags).toHaveLength(3); + expect(wakeTags).toContain("role:admin"); + expect(wakeTags).toContain("room:lobby"); + + ws.close(); + }); + + it("exposes tags on a non-hibernating (in-memory) connection", async () => { + const ctx = createExecutionContext(); + const request = new Request( + "http://example.com/parties/tags-server-in-memory/room1", + { + headers: { Upgrade: "websocket" } + } + ); + const response = await worker.fetch(request, env, ctx); + const ws = response.webSocket!; + ws.accept(); + + const { promise, resolve, reject } = Promise.withResolvers(); + ws.addEventListener("message", (message) => { + try { + const tags = JSON.parse(message.data as string) as string[]; + // Should include the auto-prepended connection id plus the custom tags + expect(tags).toHaveLength(3); + expect(tags[0]).toBeTypeOf("string"); // connection id + expect(tags).toContain("role:viewer"); + expect(tags).toContain("room:general"); + resolve(); + } catch (e) { + reject(e); + } finally { + ws.close(); + } + }); + + return promise; + }); +}); diff --git a/packages/partyserver/src/tests/worker.ts b/packages/partyserver/src/tests/worker.ts index 35b929f..1928a9b 100644 --- a/packages/partyserver/src/tests/worker.ts +++ b/packages/partyserver/src/tests/worker.ts @@ -21,6 +21,8 @@ export type Env = { CustomCorsServer: DurableObjectNamespace; FailingOnStartServer: DurableObjectNamespace; HibernatingNameInMessage: DurableObjectNamespace; + TagsServer: DurableObjectNamespace; + TagsServerInMemory: DurableObjectNamespace; }; export class Stateful extends Server { @@ -307,6 +309,49 @@ export class HibernatingNameInMessage extends Server { } } +/** + * Tests that connection.tags is readable in hibernating mode. + */ +export class TagsServer extends Server { + static options = { + hibernate: true + }; + + getConnectionTags( + _connection: Connection, + _ctx: ConnectionContext + ): string[] { + return ["role:admin", "room:lobby"]; + } + + onConnect(connection: Connection): void { + connection.send(JSON.stringify(connection.tags)); + } + + onMessage(connection: Connection, _message: WSMessage): void { + // Also verify tags survive hibernation wake-up + connection.send(JSON.stringify(connection.tags)); + } +} + +/** + * Tests that connection.tags is readable in non-hibernating (in-memory) mode. + */ +export class TagsServerInMemory extends Server { + // no hibernate — uses the in-memory path + + getConnectionTags( + _connection: Connection, + _ctx: ConnectionContext + ): string[] { + return ["role:viewer", "room:general"]; + } + + onConnect(connection: Connection): void { + connection.send(JSON.stringify(connection.tags)); + } +} + export class CorsServer extends Server { onRequest(): Response | Promise { return Response.json({ cors: true }); diff --git a/packages/partyserver/src/tests/wrangler.jsonc b/packages/partyserver/src/tests/wrangler.jsonc index 36a8860..6e1e8c4 100644 --- a/packages/partyserver/src/tests/wrangler.jsonc +++ b/packages/partyserver/src/tests/wrangler.jsonc @@ -58,6 +58,14 @@ { "name": "HibernatingNameInMessage", "class_name": "HibernatingNameInMessage" + }, + { + "name": "TagsServer", + "class_name": "TagsServer" + }, + { + "name": "TagsServerInMemory", + "class_name": "TagsServerInMemory" } ] }, @@ -76,7 +84,9 @@ "HibernatingOnStartServer", "AlarmServer", "FailingOnStartServer", - "HibernatingNameInMessage" + "HibernatingNameInMessage", + "TagsServer", + "TagsServerInMemory" ] } ] diff --git a/packages/partyserver/src/types.ts b/packages/partyserver/src/types.ts index 2a00507..2ddca78 100644 --- a/packages/partyserver/src/types.ts +++ b/packages/partyserver/src/types.ts @@ -70,6 +70,12 @@ export type Connection = WebSocket & { */ deserializeAttachment(): T | null; + /** + * Tags assigned to this connection via {@link Server.getConnectionTags}. + * Always includes the connection id as the first tag. + */ + tags: readonly string[]; + /** * Server's name */