diff --git a/src/services/auth.js b/src/services/auth.js index b5f1e97..82d4bb9 100644 --- a/src/services/auth.js +++ b/src/services/auth.js @@ -34,6 +34,14 @@ import { AuthenticationError } from "#src/utils/errors.js"; * @property {string} [jti] - JWT ID */ +/** + * @typedef {Object} JWTData + * @property {JWTHeader} header - The JWT header + * @property {JWTClaims} claims - The JWT claims + * @property {Buffer} signature - The JWT signature + * @property {string} signedData - The signed data (header + claims) + */ + let jwtKey; const logger = new Logger("AUTH"); const ALGORITHM = { @@ -84,7 +92,7 @@ function base64Decode(str) { * Signs and creates a JsonWebToken * * @param {JWTClaims} claims - The claims to include in the token - * @param {WithImplicitCoercion} [key] - Optional key, defaults to the configured jwtKey + * @param {WithImplicitCoercion | Buffer} [key] - Optional key, defaults to the configured jwtKey * @param {Object} [options] * @param {string} [options.algorithm] - The algorithm to use, defaults to HS256 * @returns {string} - The signed JsonWebToken @@ -144,31 +152,53 @@ function safeEqual(a, b) { * @throws {AuthenticationError} */ export function verify(jsonWebToken, key = jwtKey) { - const keyBuffer = Buffer.isBuffer(key) ? key : Buffer.from(key, "base64"); - let parsedJWT; - try { - parsedJWT = parseJwt(jsonWebToken); - } catch { - throw new AuthenticationError("Invalid JWT format"); - } - const { header, claims, signature, signedData } = parsedJWT; - const expectedSignature = ALGORITHM_FUNCTIONS[header.alg]?.(signedData, keyBuffer); - if (!expectedSignature) { - throw new AuthenticationError(`Unsupported algorithm: ${header.alg}`); - } - if (!safeEqual(signature, expectedSignature)) { - throw new AuthenticationError("Invalid signature"); - } - // `exp`, `iat` and `nbf` are in seconds (`NumericDate` per RFC7519) - const now = Math.floor(Date.now() / 1000); - if (claims.exp && claims.exp < now) { - throw new AuthenticationError("Token expired"); - } - if (claims.nbf && claims.nbf > now) { - throw new AuthenticationError("Token not valid yet"); + const jwt = new JsonWebToken(jsonWebToken); + return jwt.verify(key); +} + +export class JsonWebToken { + /** + * @type {JWTData} + */ + unsafe; + /** + * @param {string} jsonWebToken + */ + constructor(jsonWebToken) { + let payload; + try { + payload = parseJwt(jsonWebToken); + } catch { + throw new AuthenticationError("Malformed JWT"); + } + this.unsafe = payload; } - if (claims.iat && claims.iat > now + 60) { - throw new AuthenticationError("Token issued in the future"); + + /** + * @param {WithImplicitCoercion} [key] buffer/b64 str + * @return {JWTClaims} + */ + verify(key = jwtKey) { + const { header, claims, signature, signedData } = this.unsafe; + const keyBuffer = Buffer.isBuffer(key) ? key : Buffer.from(key, "base64"); + const expectedSignature = ALGORITHM_FUNCTIONS[header.alg]?.(signedData, keyBuffer); + if (!expectedSignature) { + throw new AuthenticationError(`Unsupported algorithm: ${header.alg}`); + } + if (!safeEqual(signature, expectedSignature)) { + throw new AuthenticationError("Invalid signature"); + } + // `exp`, `iat` and `nbf` are in seconds (`NumericDate` per RFC7519) + const now = Math.floor(Date.now() / 1000); + if (claims.exp && claims.exp < now) { + throw new AuthenticationError("Token expired"); + } + if (claims.nbf && claims.nbf > now) { + throw new AuthenticationError("Token not valid yet"); + } + if (claims.iat && claims.iat > now + 60) { + throw new AuthenticationError("Token issued in the future"); + } + return claims; } - return claims; } diff --git a/src/services/ws.js b/src/services/ws.js index 66589f5..a26a1b5 100644 --- a/src/services/ws.js +++ b/src/services/ws.js @@ -7,11 +7,11 @@ import { Logger, extractRequestInfo } from "#src/utils/utils.js"; import { AuthenticationError, OvercrowdedError } from "#src/utils/errors.js"; import { SESSION_CLOSE_CODE } from "#src/models/session.js"; import { Channel } from "#src/models/channel.js"; -import { verify } from "#src/services/auth.js"; +import { JsonWebToken } from "#src/services/auth.js"; /** * @typedef Credentials - * @property {string} channelUUID + * @property {string} channelUUID deprecated, this is obtained from the jwt * @property {string} jwt */ @@ -53,7 +53,6 @@ export async function start(options) { /** @type {Credentials | String} can be a string (the jwt) for backwards compatibility with version 1.1 and earlier */ const credentials = JSON.parse(message); const session = connect(webSocket, { - channelUUID: credentials?.channelUUID, jwt: credentials.jwt || credentials, }); session.remote = remoteAddress; @@ -102,22 +101,14 @@ export function close() { * @param {import("ws").WebSocket} webSocket * @param {Credentials} */ -function connect(webSocket, { channelUUID, jwt }) { - let channel = Channel.records.get(channelUUID); - const authResult = verify(jwt, channel?.key); - const { sfu_channel_uuid, session_id, ice_servers } = authResult; - if (!channelUUID && sfu_channel_uuid) { - // Cases where the channelUUID is not provided in the credentials for backwards compatibility with version 1.1 and earlier. - channel = Channel.records.get(sfu_channel_uuid); - if (channel.key) { - throw new AuthenticationError( - "A channel with a key can only be accessed by providing a channelUUID in the credentials" - ); - } - } +function connect(webSocket, { jwt }) { + const token = new JsonWebToken(jwt); + const channel = Channel.records.get(token.unsafe.claims.sfu_channel_uuid); if (!channel) { throw new AuthenticationError(`Channel does not exist`); } + const authResult = token.verify(channel.key); + const { session_id, ice_servers } = authResult; if (!session_id) { throw new AuthenticationError("Malformed JWT payload"); } diff --git a/tests/network.test.js b/tests/network.test.js index 1cc385b..ec73819 100644 --- a/tests/network.test.js +++ b/tests/network.test.js @@ -264,7 +264,7 @@ describe("Full network", () => { expect(closeEvent.code).toBe(SESSION_CLOSE_CODE.P_TIMEOUT); }); test("A client can broadcast arbitrary messages to other clients on a channel that does not have webRTC", async () => { - const channelUUID = await network.getChannelUUID(false); + const channelUUID = await network.getChannelUUID({ useWebRtc: false }); const user1 = await network.connect(channelUUID, 1); const user2 = await network.connect(channelUUID, 2); const sender = await network.connect(channelUUID, 3); diff --git a/tests/security.test.js b/tests/security.test.js index a7780a5..0e77d57 100644 --- a/tests/security.test.js +++ b/tests/security.test.js @@ -39,4 +39,17 @@ describe("Security", () => { const [event] = await once(websocket, "close"); expect(event).toBe(WS_CLOSE_CODE.TIMEOUT); }); + test("cannot use the default jwt key to access a keyed channel", async () => { + const channelUUID = await network.getChannelUUID({ key: "channel-specific-key" }); + const channel = Channel.records.get(channelUUID); + await expect(network.connect(channelUUID, 3)).rejects.toThrow(); + expect(channel.sessions.size).toBe(0); + }); + test("can join a keyed channel with the appropriate key", async () => { + const key = "channel-specific-key"; + const channelUUID = await network.getChannelUUID({ key }); + const channel = Channel.records.get(channelUUID); + await network.connect(channelUUID, 4, { key }); + expect(channel.sessions.size).toBe(1); + }); }); diff --git a/tests/utils/network.js b/tests/utils/network.js index 624924b..dbd35ae 100644 --- a/tests/utils/network.js +++ b/tests/utils/network.js @@ -12,8 +12,8 @@ import { Channel } from "#src/models/channel.js"; const HMAC_B64_KEY = "u6bsUQEWrHdKIuYplirRnbBmLbrKV5PxKG7DtA71mng="; const HMAC_KEY = Buffer.from(HMAC_B64_KEY, "base64"); -export function makeJwt(data) { - return auth.sign(data, HMAC_KEY, { algorithm: "HS256" }); +export function makeJwt(data, { key = HMAC_KEY } = {}) { + return auth.sign(data, key, { algorithm: "HS256" }); } /** @@ -44,13 +44,15 @@ export class LocalNetwork { } /** - * @param {boolean} [useWebRtc] + * @param {Object} [param0] + * @param {boolean} [useWebRtc=true] + * @param {string} [key] the channel-specific key * @returns {Promise} */ - async getChannelUUID(useWebRtc = true) { + async getChannelUUID({ useWebRtc = true, key = HMAC_B64_KEY } = {}) { const jwt = this.makeJwt({ iss: `http://${this.hostname}:${this.port}/`, - key: HMAC_B64_KEY, + key, }); const response = await fetch( `http://${this.hostname}:${this.port}/v${http.API_VERSION}/channel?webRTC=${useWebRtc}`, @@ -70,10 +72,12 @@ export class LocalNetwork { * * @param {string} channelUUID * @param {number} sessionId + * @param {Object} [options] + * @param {string} [options.key] the key to use to authenticate the session (this should be the key of the channel) * @returns { Promise<{ session: import("#src/models/session.js").Session, sfuClient: import("#src/client.js").SfuClient }>} * @throws {Error} if the client is closed before being authenticated */ - async connect(channelUUID, sessionId) { + async connect(channelUUID, sessionId, { key = HMAC_KEY } = {}) { const sfuClient = new SfuClient(); this._sfuClients.push(sfuClient); sfuClient._createDevice = () => { @@ -100,10 +104,13 @@ export class LocalNetwork { }); sfuClient.connect( `ws://${this.hostname}:${this.port}`, - this.makeJwt({ - sfu_channel_uuid: channelUUID, - session_id: sessionId, - }), + this.makeJwt( + { + sfu_channel_uuid: channelUUID, + session_id: sessionId, + }, + { key } + ), { channelUUID } ); const channel = Channel.records.get(channelUUID);