⚠ This page is served via a proxy. Original site: https://github.com
This service does not collect credentials or authentication data.
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/connection-tags-property.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"partyserver": minor
---

Add `connection.tags` property to read back tags assigned via `getConnectionTags()`. Works in both hibernating and in-memory modes. Tags are validated and always include the connection id as the first tag.
93 changes: 60 additions & 33 deletions packages/partyserver/src/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type ConnectionAttachments = {
// TODO: remove this once we have
// durable object level setState
server: string;
tags: string[];
};
__user?: unknown;
};
Expand All @@ -57,11 +58,20 @@ function tryGetPartyServerMeta(
if (!pk || typeof pk !== "object") {
return null;
}
const { id, server } = pk as { id?: unknown; server?: unknown };
const { id, server, tags } = pk as {
id?: unknown;
server?: unknown;
tags?: unknown;
};
if (typeof id !== "string" || typeof server !== "string") {
return null;
}
return pk as ConnectionAttachments["__pk"];
// Default tags to [] for connections created before tags were stored
return {
id,
server,
tags: Array.isArray(tags) ? tags : []
} as ConnectionAttachments["__pk"];
} catch {
return null;
}
Expand Down Expand Up @@ -138,6 +148,12 @@ export const createLazyConnection = (
return attachments.get(ws).__pk.server;
}
},
tags: {
get() {
// Default to [] for connections accepted before tags were stored
return attachments.get(ws).__pk.tags ?? [];
}
},
socket: {
get() {
return ws;
Expand Down Expand Up @@ -233,6 +249,36 @@ class HibernatingConnectionIterator<T> implements IterableIterator<
}
}

/**
* Deduplicate and validate connection tags.
* Returns the final tag array (always includes the connection id as the first tag).
*/
function prepareTags(connectionId: string, userTags: string[]): string[] {
const tags = [connectionId, ...userTags.filter((t) => t !== connectionId)];

// validate tags against documented restrictions
// https://developers.cloudflare.com/durable-objects/api/hibernatable-websockets-api/#state-methods-for-websockets
if (tags.length > 10) {
throw new Error(
"A connection can only have 10 tags, including the default id tag."
);
}

for (const tag of tags) {
if (typeof tag !== "string") {
throw new Error(`A connection tag must be a string. Received: ${tag}`);
}
if (tag === "") {
throw new Error("A connection tag must not be an empty string.");
}
if (tag.length > 256) {
throw new Error("A connection tag must not exceed 256 characters");
}
}

return tags;
}

export interface ConnectionManager {
getCount(): number;
getConnection<TState>(id: string): Connection<TState> | undefined;
Expand Down Expand Up @@ -280,12 +326,16 @@ export class InMemoryConnectionManager<TState> implements ConnectionManager {
accept(connection: Connection, options: { tags: string[]; server: string }) {
connection.accept();

const tags = prepareTags(connection.id, options.tags);

this.#connections.set(connection.id, connection);
this.tags.set(connection, [
// make sure we have id tag
connection.id,
...options.tags.filter((t) => t !== connection.id)
]);
this.tags.set(connection, tags);

// Expose tags on the connection object itself
Object.defineProperty(connection, "tags", {
get: () => tags,
configurable: true
});

const removeConnection = () => {
this.#connections.delete(connection.id);
Expand Down Expand Up @@ -336,37 +386,14 @@ export class HibernatingConnectionManager<TState> implements ConnectionManager {
}

accept(connection: Connection, options: { tags: string[]; server: string }) {
// dedupe tags in case user already provided id tag
const tags = [
connection.id,
...options.tags.filter((t) => t !== connection.id)
];

// validate tags against documented restrictions
// shttps://developers.cloudflare.com/durable-objects/api/hibernatable-websockets-api/#state-methods-for-websockets
if (tags.length > 10) {
throw new Error(
"A connection can only have 10 tags, including the default id tag."
);
}

for (const tag of tags) {
if (typeof tag !== "string") {
throw new Error(`A connection tag must be a string. Received: ${tag}`);
}
if (tag === "") {
throw new Error("A connection tag must not be an empty string.");
}
if (tag.length > 256) {
throw new Error("A connection tag must not exceed 256 characters");
}
}
const tags = prepareTags(connection.id, options.tags);

this.controller.acceptWebSocket(connection, tags);
connection.serializeAttachment({
__pk: {
id: connection.id,
server: options.server
server: options.server,
tags
},
__user: null
});
Expand Down
1 change: 1 addition & 0 deletions packages/partyserver/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ Did you try connecting directly to this Durable Object? Try using getServerByNam
let connection: Connection = Object.assign(serverWebSocket, {
id: connectionId,
server: this.name,
tags: [] as string[],
state: null as unknown as ConnectionState<unknown>,
setState<T = unknown>(setState: T | ConnectionSetStateFn<T>) {
let state: T;
Expand Down
102 changes: 102 additions & 0 deletions packages/partyserver/src/tests/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -658,3 +658,105 @@ describe("CORS", () => {
);
});
});

describe("Connection tags", () => {
it("exposes tags on a hibernating connection", async () => {
const ctx = createExecutionContext();
const request = new Request(
"http://example.com/parties/tags-server/room1",
{
headers: { Upgrade: "websocket" }
}
);
const response = await worker.fetch(request, env, ctx);
const ws = response.webSocket!;
ws.accept();

const { promise, resolve, reject } = Promise.withResolvers<void>();
ws.addEventListener("message", (message) => {
try {
const tags = JSON.parse(message.data as string) as string[];
// Should include the auto-prepended connection id plus the custom tags
expect(tags).toHaveLength(3);
expect(tags[0]).toBeTypeOf("string"); // connection id
expect(tags).toContain("role:admin");
expect(tags).toContain("room:lobby");
resolve();
} catch (e) {
reject(e);
} finally {
ws.close();
}
});

return promise;
});

it("exposes tags on a hibernating connection after wake-up", async () => {
const ctx = createExecutionContext();
const request = new Request(
"http://example.com/parties/tags-server/room2",
{
headers: { Upgrade: "websocket" }
}
);
const response = await worker.fetch(request, env, ctx);
const ws = response.webSocket!;
ws.accept();

// Wait for the onConnect message
const connectMessage = await new Promise<string>((resolve) => {
ws.addEventListener("message", (e) => resolve(e.data as string), {
once: true
});
});
const connectTags = JSON.parse(connectMessage) as string[];
expect(connectTags).toContain("role:admin");

// Send a message to trigger onMessage, which reads tags again
ws.send("ping");
const wakeMessage = await new Promise<string>((resolve) => {
ws.addEventListener("message", (e) => resolve(e.data as string), {
once: true
});
});
const wakeTags = JSON.parse(wakeMessage) as string[];
expect(wakeTags).toHaveLength(3);
expect(wakeTags).toContain("role:admin");
expect(wakeTags).toContain("room:lobby");

ws.close();
});

it("exposes tags on a non-hibernating (in-memory) connection", async () => {
const ctx = createExecutionContext();
const request = new Request(
"http://example.com/parties/tags-server-in-memory/room1",
{
headers: { Upgrade: "websocket" }
}
);
const response = await worker.fetch(request, env, ctx);
const ws = response.webSocket!;
ws.accept();

const { promise, resolve, reject } = Promise.withResolvers<void>();
ws.addEventListener("message", (message) => {
try {
const tags = JSON.parse(message.data as string) as string[];
// Should include the auto-prepended connection id plus the custom tags
expect(tags).toHaveLength(3);
expect(tags[0]).toBeTypeOf("string"); // connection id
expect(tags).toContain("role:viewer");
expect(tags).toContain("room:general");
resolve();
} catch (e) {
reject(e);
} finally {
ws.close();
}
});

return promise;
});
});
45 changes: 45 additions & 0 deletions packages/partyserver/src/tests/worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ export type Env = {
CustomCorsServer: DurableObjectNamespace<CustomCorsServer>;
FailingOnStartServer: DurableObjectNamespace<FailingOnStartServer>;
HibernatingNameInMessage: DurableObjectNamespace<HibernatingNameInMessage>;
TagsServer: DurableObjectNamespace<TagsServer>;
TagsServerInMemory: DurableObjectNamespace<TagsServerInMemory>;
};

export class Stateful extends Server {
Expand Down Expand Up @@ -307,6 +309,49 @@ export class HibernatingNameInMessage extends Server {
}
}

/**
* Tests that connection.tags is readable in hibernating mode.
*/
export class TagsServer extends Server {
static options = {
hibernate: true
};

getConnectionTags(
_connection: Connection,
_ctx: ConnectionContext
): string[] {
return ["role:admin", "room:lobby"];
}

onConnect(connection: Connection): void {
connection.send(JSON.stringify(connection.tags));
}

onMessage(connection: Connection, _message: WSMessage): void {
// Also verify tags survive hibernation wake-up
connection.send(JSON.stringify(connection.tags));
}
}

/**
* Tests that connection.tags is readable in non-hibernating (in-memory) mode.
*/
export class TagsServerInMemory extends Server {
// no hibernate — uses the in-memory path

getConnectionTags(
_connection: Connection,
_ctx: ConnectionContext
): string[] {
return ["role:viewer", "room:general"];
}

onConnect(connection: Connection): void {
connection.send(JSON.stringify(connection.tags));
}
}

export class CorsServer extends Server {
onRequest(): Response | Promise<Response> {
return Response.json({ cors: true });
Expand Down
12 changes: 11 additions & 1 deletion packages/partyserver/src/tests/wrangler.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@
{
"name": "HibernatingNameInMessage",
"class_name": "HibernatingNameInMessage"
},
{
"name": "TagsServer",
"class_name": "TagsServer"
},
{
"name": "TagsServerInMemory",
"class_name": "TagsServerInMemory"
}
]
},
Expand All @@ -76,7 +84,9 @@
"HibernatingOnStartServer",
"AlarmServer",
"FailingOnStartServer",
"HibernatingNameInMessage"
"HibernatingNameInMessage",
"TagsServer",
"TagsServerInMemory"
]
}
]
Expand Down
6 changes: 6 additions & 0 deletions packages/partyserver/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ export type Connection<TState = unknown> = WebSocket & {
*/
deserializeAttachment<T = unknown>(): T | null;

/**
* Tags assigned to this connection via {@link Server.getConnectionTags}.
* Always includes the connection id as the first tag.
*/
tags: readonly string[];

/**
* Server's name
*/
Expand Down
Loading