From b7a0cf2470cdb4ed1008bd2c13b6e24ceca7d876 Mon Sep 17 00:00:00 2001 From: Tin Pecirep Date: Wed, 23 Apr 2025 16:05:00 +0200 Subject: [PATCH] feat: add oauth2 code verifier * fix: ensure oauth state param matches before finishing oauth flow Signed-off-by: Tin Pecirep * chore: upgrade openid-client to v6 Signed-off-by: Tin Pecirep * feat: use PKCE for oauth2 on supported clients Signed-off-by: Tin Pecirep * feat: use state and PKCE in mobile app Signed-off-by: Tin Pecirep * fix: remove obsolete oauth repository init Signed-off-by: Tin Pecirep * fix: rewrite callback url if mobile redirect url is enabled Signed-off-by: Tin Pecirep * fix: propagate oidc client error cause when oauth callback fails Signed-off-by: Tin Pecirep * fix: adapt auth service tests to required state and PKCE params Signed-off-by: Tin Pecirep * fix: update sdk types Signed-off-by: Tin Pecirep * fix: adapt oauth e2e test to work with PKCE Signed-off-by: Tin Pecirep * fix: allow insecure (http) oauth clients Signed-off-by: Tin Pecirep --------- Signed-off-by: Tin Pecirep Co-authored-by: Jason Rasmussen --- e2e/src/api/specs/oauth.e2e-spec.ts | 118 +++++++++++---- mobile/lib/services/oauth.service.dart | 16 +- .../lib/widgets/forms/login/login_form.dart | 33 ++++- .../lib/model/o_auth_callback_dto.dart | 43 ++++-- .../openapi/lib/model/o_auth_config_dto.dart | 43 ++++-- mobile/pubspec.lock | 2 +- mobile/pubspec.yaml | 1 + open-api/immich-openapi-specs.json | 12 ++ open-api/typescript-sdk/src/fetch-client.ts | 4 + server/package-lock.json | 49 +++--- server/package.json | 2 +- server/src/controllers/oauth.controller.ts | 32 +++- server/src/dtos/auth.dto.ts | 20 ++- server/src/enum.ts | 2 + server/src/repositories/oauth.repository.ts | 74 ++++++--- server/src/services/auth.service.spec.ts | 140 +++++++++++------- server/src/services/auth.service.ts | 68 ++++++--- server/src/utils/response.ts | 2 + 18 files changed, 469 insertions(+), 192 deletions(-) diff --git a/e2e/src/api/specs/oauth.e2e-spec.ts b/e2e/src/api/specs/oauth.e2e-spec.ts index 9cd5f0252a6..3b1e75d3e5d 100644 --- a/e2e/src/api/specs/oauth.e2e-spec.ts +++ b/e2e/src/api/specs/oauth.e2e-spec.ts @@ -6,6 +6,7 @@ import { startOAuth, updateConfig, } from '@immich/sdk'; +import { createHash, randomBytes } from 'node:crypto'; import { errorDto } from 'src/responses'; import { OAuthClient, OAuthUser } from 'src/setup/auth-server'; import { app, asBearerAuth, baseUrl, utils } from 'src/utils'; @@ -21,18 +22,30 @@ const mobileOverrideRedirectUri = 'https://photos.immich.app/oauth/mobile-redire const redirect = async (url: string, cookies?: string[]) => { const { headers } = await request(url) - .get('/') + .get('') .set('Cookie', cookies || []); return { cookies: (headers['set-cookie'] as unknown as string[]) || [], location: headers.location }; }; +// Function to generate a code challenge from the verifier +const generateCodeChallenge = async (codeVerifier: string): Promise => { + const hashed = createHash('sha256').update(codeVerifier).digest(); + return hashed.toString('base64url'); +}; + const loginWithOAuth = async (sub: OAuthUser | string, redirectUri?: string) => { - const { url } = await startOAuth({ oAuthConfigDto: { redirectUri: redirectUri ?? `${baseUrl}/auth/login` } }); + const state = randomBytes(16).toString('base64url'); + const codeVerifier = randomBytes(64).toString('base64url'); + const codeChallenge = await generateCodeChallenge(codeVerifier); + + const { url } = await startOAuth({ + oAuthConfigDto: { redirectUri: redirectUri ?? `${baseUrl}/auth/login`, state, codeChallenge }, + }); // login const response1 = await redirect(url.replace(authServer.internal, authServer.external)); const response2 = await request(authServer.external + response1.location) - .post('/') + .post('') .set('Cookie', response1.cookies) .type('form') .send({ prompt: 'login', login: sub, password: 'password' }); @@ -40,7 +53,7 @@ const loginWithOAuth = async (sub: OAuthUser | string, redirectUri?: string) => // approve const response3 = await redirect(response2.header.location, response1.cookies); const response4 = await request(authServer.external + response3.location) - .post('/') + .post('') .type('form') .set('Cookie', response3.cookies) .send({ prompt: 'consent' }); @@ -51,9 +64,9 @@ const loginWithOAuth = async (sub: OAuthUser | string, redirectUri?: string) => expect(redirectUrl).toBeDefined(); const params = new URL(redirectUrl).searchParams; expect(params.get('code')).toBeDefined(); - expect(params.get('state')).toBeDefined(); + expect(params.get('state')).toBe(state); - return redirectUrl; + return { url: redirectUrl, state, codeVerifier }; }; const setupOAuth = async (token: string, dto: Partial) => { @@ -119,9 +132,42 @@ describe(`/oauth`, () => { expect(body).toEqual(errorDto.badRequest(['url should not be empty'])); }); - it('should auto register the user by default', async () => { - const url = await loginWithOAuth('oauth-auto-register'); + it(`should throw an error if the state is not provided`, async () => { + const { url } = await loginWithOAuth('oauth-auto-register'); const { status, body } = await request(app).post('/oauth/callback').send({ url }); + expect(status).toBe(400); + expect(body).toEqual(errorDto.badRequest('OAuth state is missing')); + }); + + it(`should throw an error if the state mismatches`, async () => { + const callbackParams = await loginWithOAuth('oauth-auto-register'); + const { state } = await loginWithOAuth('oauth-auto-register'); + const { status, body } = await request(app) + .post('/oauth/callback') + .send({ ...callbackParams, state }); + expect(status).toBeGreaterThanOrEqual(400); + }); + + it(`should throw an error if the codeVerifier is not provided`, async () => { + const { url, state } = await loginWithOAuth('oauth-auto-register'); + const { status, body } = await request(app).post('/oauth/callback').send({ url, state }); + expect(status).toBe(400); + expect(body).toEqual(errorDto.badRequest('OAuth code verifier is missing')); + }); + + it(`should throw an error if the codeVerifier doesn't match the challenge`, async () => { + const callbackParams = await loginWithOAuth('oauth-auto-register'); + const { codeVerifier } = await loginWithOAuth('oauth-auto-register'); + const { status, body } = await request(app) + .post('/oauth/callback') + .send({ ...callbackParams, codeVerifier }); + console.log(body); + expect(status).toBeGreaterThanOrEqual(400); + }); + + it('should auto register the user by default', async () => { + const callbackParams = await loginWithOAuth('oauth-auto-register'); + const { status, body } = await request(app).post('/oauth/callback').send(callbackParams); expect(status).toBe(201); expect(body).toMatchObject({ accessToken: expect.any(String), @@ -132,16 +178,30 @@ describe(`/oauth`, () => { }); }); + it('should allow passing state and codeVerifier via cookies', async () => { + const { url, state, codeVerifier } = await loginWithOAuth('oauth-auto-register'); + const { status, body } = await request(app) + .post('/oauth/callback') + .set('Cookie', [`immich_oauth_state=${state}`, `immich_oauth_code_verifier=${codeVerifier}`]) + .send({ url }); + expect(status).toBe(201); + expect(body).toMatchObject({ + accessToken: expect.any(String), + userId: expect.any(String), + userEmail: 'oauth-auto-register@immich.app', + }); + }); + it('should handle a user without an email', async () => { - const url = await loginWithOAuth(OAuthUser.NO_EMAIL); - const { status, body } = await request(app).post('/oauth/callback').send({ url }); + const callbackParams = await loginWithOAuth(OAuthUser.NO_EMAIL); + const { status, body } = await request(app).post('/oauth/callback').send(callbackParams); expect(status).toBe(400); expect(body).toEqual(errorDto.badRequest('OAuth profile does not have an email address')); }); it('should set the quota from a claim', async () => { - const url = await loginWithOAuth(OAuthUser.WITH_QUOTA); - const { status, body } = await request(app).post('/oauth/callback').send({ url }); + const callbackParams = await loginWithOAuth(OAuthUser.WITH_QUOTA); + const { status, body } = await request(app).post('/oauth/callback').send(callbackParams); expect(status).toBe(201); expect(body).toMatchObject({ accessToken: expect.any(String), @@ -154,8 +214,8 @@ describe(`/oauth`, () => { }); it('should set the storage label from a claim', async () => { - const url = await loginWithOAuth(OAuthUser.WITH_USERNAME); - const { status, body } = await request(app).post('/oauth/callback').send({ url }); + const callbackParams = await loginWithOAuth(OAuthUser.WITH_USERNAME); + const { status, body } = await request(app).post('/oauth/callback').send(callbackParams); expect(status).toBe(201); expect(body).toMatchObject({ accessToken: expect.any(String), @@ -176,8 +236,8 @@ describe(`/oauth`, () => { buttonText: 'Login with Immich', signingAlgorithm: 'RS256', }); - const url = await loginWithOAuth('oauth-RS256-token'); - const { status, body } = await request(app).post('/oauth/callback').send({ url }); + const callbackParams = await loginWithOAuth('oauth-RS256-token'); + const { status, body } = await request(app).post('/oauth/callback').send(callbackParams); expect(status).toBe(201); expect(body).toMatchObject({ accessToken: expect.any(String), @@ -196,8 +256,8 @@ describe(`/oauth`, () => { buttonText: 'Login with Immich', profileSigningAlgorithm: 'RS256', }); - const url = await loginWithOAuth('oauth-signed-profile'); - const { status, body } = await request(app).post('/oauth/callback').send({ url }); + const callbackParams = await loginWithOAuth('oauth-signed-profile'); + const { status, body } = await request(app).post('/oauth/callback').send(callbackParams); expect(status).toBe(201); expect(body).toMatchObject({ userId: expect.any(String), @@ -213,8 +273,8 @@ describe(`/oauth`, () => { buttonText: 'Login with Immich', signingAlgorithm: 'something-that-does-not-work', }); - const url = await loginWithOAuth('oauth-signed-bad'); - const { status, body } = await request(app).post('/oauth/callback').send({ url }); + const callbackParams = await loginWithOAuth('oauth-signed-bad'); + const { status, body } = await request(app).post('/oauth/callback').send(callbackParams); expect(status).toBe(500); expect(body).toMatchObject({ error: 'Internal Server Error', @@ -235,8 +295,8 @@ describe(`/oauth`, () => { }); it('should not auto register the user', async () => { - const url = await loginWithOAuth('oauth-no-auto-register'); - const { status, body } = await request(app).post('/oauth/callback').send({ url }); + const callbackParams = await loginWithOAuth('oauth-no-auto-register'); + const { status, body } = await request(app).post('/oauth/callback').send(callbackParams); expect(status).toBe(400); expect(body).toEqual(errorDto.badRequest('User does not exist and auto registering is disabled.')); }); @@ -247,8 +307,8 @@ describe(`/oauth`, () => { email: 'oauth-user3@immich.app', password: 'password', }); - const url = await loginWithOAuth('oauth-user3'); - const { status, body } = await request(app).post('/oauth/callback').send({ url }); + const callbackParams = await loginWithOAuth('oauth-user3'); + const { status, body } = await request(app).post('/oauth/callback').send(callbackParams); expect(status).toBe(201); expect(body).toMatchObject({ userId, @@ -286,13 +346,15 @@ describe(`/oauth`, () => { }); it('should auto register the user by default', async () => { - const url = await loginWithOAuth('oauth-mobile-override', 'app.immich:///oauth-callback'); - expect(url).toEqual(expect.stringContaining(mobileOverrideRedirectUri)); + const callbackParams = await loginWithOAuth('oauth-mobile-override', 'app.immich:///oauth-callback'); + expect(callbackParams.url).toEqual(expect.stringContaining(mobileOverrideRedirectUri)); // simulate redirecting back to mobile app - const redirectUri = url.replace(mobileOverrideRedirectUri, 'app.immich:///oauth-callback'); + const url = callbackParams.url.replace(mobileOverrideRedirectUri, 'app.immich:///oauth-callback'); - const { status, body } = await request(app).post('/oauth/callback').send({ url: redirectUri }); + const { status, body } = await request(app) + .post('/oauth/callback') + .send({ ...callbackParams, url }); expect(status).toBe(201); expect(body).toMatchObject({ accessToken: expect.any(String), diff --git a/mobile/lib/services/oauth.service.dart b/mobile/lib/services/oauth.service.dart index ddd97522f82..9a54a8d7c96 100644 --- a/mobile/lib/services/oauth.service.dart +++ b/mobile/lib/services/oauth.service.dart @@ -13,6 +13,8 @@ class OAuthService { Future getOAuthServerUrl( String serverUrl, + String state, + String codeChallenge, ) async { // Resolve API server endpoint from user provided serverUrl await _apiService.resolveAndSetEndpoint(serverUrl); @@ -22,7 +24,11 @@ class OAuthService { ); final dto = await _apiService.oAuthApi.startOAuth( - OAuthConfigDto(redirectUri: redirectUri), + OAuthConfigDto( + redirectUri: redirectUri, + state: state, + codeChallenge: codeChallenge, + ), ); final authUrl = dto?.url; @@ -31,7 +37,11 @@ class OAuthService { return authUrl; } - Future oAuthLogin(String oauthUrl) async { + Future oAuthLogin( + String oauthUrl, + String state, + String codeVerifier, + ) async { String result = await FlutterWebAuth2.authenticate( url: oauthUrl, callbackUrlScheme: callbackUrlScheme, @@ -49,6 +59,8 @@ class OAuthService { return await _apiService.oAuthApi.finishOAuth( OAuthCallbackDto( url: result, + state: state, + codeVerifier: codeVerifier, ), ); } diff --git a/mobile/lib/widgets/forms/login/login_form.dart b/mobile/lib/widgets/forms/login/login_form.dart index 7af52b413d9..3433648e9f6 100644 --- a/mobile/lib/widgets/forms/login/login_form.dart +++ b/mobile/lib/widgets/forms/login/login_form.dart @@ -1,6 +1,9 @@ +import 'dart:convert'; import 'dart:io'; +import 'dart:math'; import 'package:auto_route/auto_route.dart'; +import 'package:crypto/crypto.dart'; import 'package:easy_localization/easy_localization.dart'; import 'package:flutter/material.dart'; import 'package:flutter/services.dart'; @@ -203,13 +206,32 @@ class LoginForm extends HookConsumerWidget { } } + String generateRandomString(int length) { + final random = Random.secure(); + return base64Url + .encode(List.generate(32, (i) => random.nextInt(256))); + } + + Future generatePKCECodeChallenge(String codeVerifier) async { + var bytes = utf8.encode(codeVerifier); + var digest = sha256.convert(bytes); + return base64Url.encode(digest.bytes).replaceAll('=', ''); + } + oAuthLogin() async { var oAuthService = ref.watch(oAuthServiceProvider); String? oAuthServerUrl; + final state = generateRandomString(32); + final codeVerifier = generateRandomString(64); + final codeChallenge = await generatePKCECodeChallenge(codeVerifier); + try { - oAuthServerUrl = await oAuthService - .getOAuthServerUrl(sanitizeUrl(serverEndpointController.text)); + oAuthServerUrl = await oAuthService.getOAuthServerUrl( + sanitizeUrl(serverEndpointController.text), + state, + codeChallenge, + ); isLoading.value = true; @@ -230,8 +252,11 @@ class LoginForm extends HookConsumerWidget { if (oAuthServerUrl != null) { try { - final loginResponseDto = - await oAuthService.oAuthLogin(oAuthServerUrl); + final loginResponseDto = await oAuthService.oAuthLogin( + oAuthServerUrl, + state, + codeVerifier, + ); if (loginResponseDto == null) { return; diff --git a/mobile/openapi/lib/model/o_auth_callback_dto.dart b/mobile/openapi/lib/model/o_auth_callback_dto.dart index d0b98d5c6f5..ebe2661c52a 100644 --- a/mobile/openapi/lib/model/o_auth_callback_dto.dart +++ b/mobile/openapi/lib/model/o_auth_callback_dto.dart @@ -14,25 +14,36 @@ class OAuthCallbackDto { /// Returns a new [OAuthCallbackDto] instance. OAuthCallbackDto({ required this.url, + required this.state, + required this.codeVerifier, }); String url; + String state; + String codeVerifier; @override - bool operator ==(Object other) => identical(this, other) || other is OAuthCallbackDto && - other.url == url; + bool operator ==(Object other) => + identical(this, other) || + other is OAuthCallbackDto && + other.url == url && + other.state == state && + other.codeVerifier == codeVerifier; @override int get hashCode => - // ignore: unnecessary_parenthesis - (url.hashCode); + // ignore: unnecessary_parenthesis + (url.hashCode) + (state.hashCode) + (codeVerifier.hashCode); @override - String toString() => 'OAuthCallbackDto[url=$url]'; + String toString() => + 'OAuthCallbackDto[url=$url, state=$state, codeVerifier=$codeVerifier]'; Map toJson() { final json = {}; - json[r'url'] = this.url; + json[r'url'] = this.url; + json[r'state'] = this.state; + json[r'codeVerifier'] = this.codeVerifier; return json; } @@ -46,12 +57,17 @@ class OAuthCallbackDto { return OAuthCallbackDto( url: mapValueOfType(json, r'url')!, + state: mapValueOfType(json, r'state')!, + codeVerifier: mapValueOfType(json, r'codeVerifier')!, ); } return null; } - static List listFromJson(dynamic json, {bool growable = false,}) { + static List listFromJson( + dynamic json, { + bool growable = false, + }) { final result = []; if (json is List && json.isNotEmpty) { for (final row in json) { @@ -79,13 +95,19 @@ class OAuthCallbackDto { } // maps a json object with a list of OAuthCallbackDto-objects as value to a dart map - static Map> mapListFromJson(dynamic json, {bool growable = false,}) { + static Map> mapListFromJson( + dynamic json, { + bool growable = false, + }) { final map = >{}; if (json is Map && json.isNotEmpty) { // ignore: parameter_assignments json = json.cast(); for (final entry in json.entries) { - map[entry.key] = OAuthCallbackDto.listFromJson(entry.value, growable: growable,); + map[entry.key] = OAuthCallbackDto.listFromJson( + entry.value, + growable: growable, + ); } } return map; @@ -94,6 +116,7 @@ class OAuthCallbackDto { /// The list of required keys that must be present in a JSON. static const requiredKeys = { 'url', + 'state', + 'codeVerifier', }; } - diff --git a/mobile/openapi/lib/model/o_auth_config_dto.dart b/mobile/openapi/lib/model/o_auth_config_dto.dart index 86c79b4e04f..e142c17c069 100644 --- a/mobile/openapi/lib/model/o_auth_config_dto.dart +++ b/mobile/openapi/lib/model/o_auth_config_dto.dart @@ -14,25 +14,36 @@ class OAuthConfigDto { /// Returns a new [OAuthConfigDto] instance. OAuthConfigDto({ required this.redirectUri, + required this.state, + required this.codeChallenge, }); String redirectUri; + String state; + String codeChallenge; @override - bool operator ==(Object other) => identical(this, other) || other is OAuthConfigDto && - other.redirectUri == redirectUri; + bool operator ==(Object other) => + identical(this, other) || + other is OAuthConfigDto && + other.redirectUri == redirectUri && + other.state == state && + other.codeChallenge == codeChallenge; @override int get hashCode => - // ignore: unnecessary_parenthesis - (redirectUri.hashCode); + // ignore: unnecessary_parenthesis + (redirectUri.hashCode) + (state.hashCode) + (codeChallenge.hashCode); @override - String toString() => 'OAuthConfigDto[redirectUri=$redirectUri]'; + String toString() => + 'OAuthConfigDto[redirectUri=$redirectUri, state=$state, codeChallenge=$codeChallenge]'; Map toJson() { final json = {}; - json[r'redirectUri'] = this.redirectUri; + json[r'redirectUri'] = this.redirectUri; + json[r'state'] = this.state; + json[r'codeChallenge'] = this.codeChallenge; return json; } @@ -46,12 +57,17 @@ class OAuthConfigDto { return OAuthConfigDto( redirectUri: mapValueOfType(json, r'redirectUri')!, + state: mapValueOfType(json, r'state')!, + codeChallenge: mapValueOfType(json, r'codeChallenge')!, ); } return null; } - static List listFromJson(dynamic json, {bool growable = false,}) { + static List listFromJson( + dynamic json, { + bool growable = false, + }) { final result = []; if (json is List && json.isNotEmpty) { for (final row in json) { @@ -79,13 +95,19 @@ class OAuthConfigDto { } // maps a json object with a list of OAuthConfigDto-objects as value to a dart map - static Map> mapListFromJson(dynamic json, {bool growable = false,}) { + static Map> mapListFromJson( + dynamic json, { + bool growable = false, + }) { final map = >{}; if (json is Map && json.isNotEmpty) { // ignore: parameter_assignments json = json.cast(); for (final entry in json.entries) { - map[entry.key] = OAuthConfigDto.listFromJson(entry.value, growable: growable,); + map[entry.key] = OAuthConfigDto.listFromJson( + entry.value, + growable: growable, + ); } } return map; @@ -94,6 +116,7 @@ class OAuthConfigDto { /// The list of required keys that must be present in a JSON. static const requiredKeys = { 'redirectUri', + 'state', + 'codeChallenge', }; } - diff --git a/mobile/pubspec.lock b/mobile/pubspec.lock index 9e8aced11cf..7e490edd25a 100644 --- a/mobile/pubspec.lock +++ b/mobile/pubspec.lock @@ -303,7 +303,7 @@ packages: source: hosted version: "0.3.4+2" crypto: - dependency: transitive + dependency: "direct main" description: name: crypto sha256: "1e445881f28f22d6140f181e07737b22f1e099a5e1ff94b0af2f9e4a463f4855" diff --git a/mobile/pubspec.yaml b/mobile/pubspec.yaml index 44d2e7e5d14..4e57b0fb3bb 100644 --- a/mobile/pubspec.yaml +++ b/mobile/pubspec.yaml @@ -22,6 +22,7 @@ dependencies: collection: ^1.18.0 connectivity_plus: ^6.1.3 crop_image: ^1.0.16 + crypto: ^3.0.6 device_info_plus: ^11.3.3 dynamic_color: ^1.7.0 easy_image_viewer: ^1.5.1 diff --git a/open-api/immich-openapi-specs.json b/open-api/immich-openapi-specs.json index 169c0763416..c9ea04ac5f2 100644 --- a/open-api/immich-openapi-specs.json +++ b/open-api/immich-openapi-specs.json @@ -10354,6 +10354,12 @@ }, "OAuthCallbackDto": { "properties": { + "codeVerifier": { + "type": "string" + }, + "state": { + "type": "string" + }, "url": { "type": "string" } @@ -10365,8 +10371,14 @@ }, "OAuthConfigDto": { "properties": { + "codeChallenge": { + "type": "string" + }, "redirectUri": { "type": "string" + }, + "state": { + "type": "string" } }, "required": [ diff --git a/open-api/typescript-sdk/src/fetch-client.ts b/open-api/typescript-sdk/src/fetch-client.ts index e45449c9cd6..3b0b32916da 100644 --- a/open-api/typescript-sdk/src/fetch-client.ts +++ b/open-api/typescript-sdk/src/fetch-client.ts @@ -688,12 +688,16 @@ export type TestEmailResponseDto = { }; export type OAuthConfigDto = { redirectUri: string; + state?: string; + codeChallenge?: string; }; export type OAuthAuthorizeResponseDto = { url: string; }; export type OAuthCallbackDto = { url: string; + state?: string; + codeVerifier?: string; }; export type PartnerResponseDto = { avatarColor: UserAvatarColor; diff --git a/server/package-lock.json b/server/package-lock.json index 297370187d3..72fd6f451d6 100644 --- a/server/package-lock.json +++ b/server/package-lock.json @@ -52,7 +52,7 @@ "nestjs-kysely": "^1.1.0", "nestjs-otel": "^6.0.0", "nodemailer": "^6.9.13", - "openid-client": "^5.4.3", + "openid-client": "^6.3.3", "pg": "^8.11.3", "picomatch": "^4.0.2", "react": "^19.0.0", @@ -11370,9 +11370,9 @@ } }, "node_modules/jose": { - "version": "4.15.9", - "resolved": "https://registry.npmjs.org/jose/-/jose-4.15.9.tgz", - "integrity": "sha512-1vUQX+IdDMVPj4k8kOxgUqlcK518yluMuGZwqlr44FS1ppZB/5GWh4rZG89erpOBOJjU/OBsnCVFfapsRz6nEA==", + "version": "6.0.8", + "resolved": "https://registry.npmjs.org/jose/-/jose-6.0.8.tgz", + "integrity": "sha512-EyUPtOKyTYq+iMOszO42eobQllaIjJnwkZ2U93aJzNyPibCy7CEvT9UQnaCVB51IAd49gbNdCew1c0LcLTCB2g==", "license": "MIT", "funding": { "url": "https://github.com/sponsors/panva" @@ -11879,18 +11879,6 @@ "dev": true, "license": "MIT" }, - "node_modules/lru-cache": { - "version": "6.0.0", - "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", - "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", - "license": "ISC", - "dependencies": { - "yallist": "^4.0.0" - }, - "engines": { - "node": ">=10" - } - }, "node_modules/luxon": { "version": "3.6.1", "resolved": "https://registry.npmjs.org/luxon/-/luxon-3.6.1.tgz", @@ -12750,6 +12738,14 @@ "set-blocking": "^2.0.0" } }, + "node_modules/oauth4webapi": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/oauth4webapi/-/oauth4webapi-3.3.0.tgz", + "integrity": "sha512-ZlozhPlFfobzh3hB72gnBFLjXpugl/dljz1fJSRdqaV2r3D5dmi5lg2QWI0LmUYuazmE+b5exsloEv6toUtw9g==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/panva" + } "node_modules/nwsapi": { "version": "2.2.20", "resolved": "https://registry.npmjs.org/nwsapi/-/nwsapi-2.2.20.tgz", @@ -12869,29 +12865,18 @@ } }, "node_modules/openid-client": { - "version": "5.7.1", - "resolved": "https://registry.npmjs.org/openid-client/-/openid-client-5.7.1.tgz", - "integrity": "sha512-jDBPgSVfTnkIh71Hg9pRvtJc6wTwqjRkN88+gCFtYWrlP4Yx2Dsrow8uPi3qLr/aeymPF3o2+dS+wOpglK04ew==", + "version": "6.3.3", + "resolved": "https://registry.npmjs.org/openid-client/-/openid-client-6.3.3.tgz", + "integrity": "sha512-lTK8AV8SjqCM4qznLX0asVESAwzV39XTVdfMAM185ekuaZCnkWdPzcxMTXNlsm9tsUAMa1Q30MBmKAykdT1LWw==", "license": "MIT", "dependencies": { - "jose": "^4.15.9", - "lru-cache": "^6.0.0", - "object-hash": "^2.2.0", - "oidc-token-hash": "^5.0.3" + "jose": "^6.0.6", + "oauth4webapi": "^3.3.0" }, "funding": { "url": "https://github.com/sponsors/panva" } }, - "node_modules/openid-client/node_modules/object-hash": { - "version": "2.2.0", - "resolved": "https://registry.npmjs.org/object-hash/-/object-hash-2.2.0.tgz", - "integrity": "sha512-gScRMn0bS5fH+IuwyIFgnh9zBdo4DV+6GhygmWM9HyNJSgS0hScp1f5vjtm7oIIOiT9trXrShAkLFSc2IqKNgw==", - "license": "MIT", - "engines": { - "node": ">= 6" - } - }, "node_modules/optionator": { "version": "0.9.4", "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz", diff --git a/server/package.json b/server/package.json index 441f04abebf..178f0fb0a03 100644 --- a/server/package.json +++ b/server/package.json @@ -77,7 +77,7 @@ "nestjs-kysely": "^1.1.0", "nestjs-otel": "^6.0.0", "nodemailer": "^6.9.13", - "openid-client": "^5.4.3", + "openid-client": "^6.3.3", "pg": "^8.11.3", "picomatch": "^4.0.2", "react": "^19.0.0", diff --git a/server/src/controllers/oauth.controller.ts b/server/src/controllers/oauth.controller.ts index b5b94030f2e..23ddff5ddcb 100644 --- a/server/src/controllers/oauth.controller.ts +++ b/server/src/controllers/oauth.controller.ts @@ -29,17 +29,35 @@ export class OAuthController { } @Post('authorize') - startOAuth(@Body() dto: OAuthConfigDto): Promise { - return this.service.authorize(dto); + async startOAuth( + @Body() dto: OAuthConfigDto, + @Res({ passthrough: true }) res: Response, + @GetLoginDetails() loginDetails: LoginDetails, + ): Promise { + const { url, state, codeVerifier } = await this.service.authorize(dto); + return respondWithCookie( + res, + { url }, + { + isSecure: loginDetails.isSecure, + values: [ + { key: ImmichCookie.OAUTH_STATE, value: state }, + { key: ImmichCookie.OAUTH_CODE_VERIFIER, value: codeVerifier }, + ], + }, + ); } @Post('callback') async finishOAuth( + @Req() request: Request, @Res({ passthrough: true }) res: Response, @Body() dto: OAuthCallbackDto, @GetLoginDetails() loginDetails: LoginDetails, ): Promise { - const body = await this.service.callback(dto, loginDetails); + const body = await this.service.callback(dto, request.headers, loginDetails); + res.clearCookie(ImmichCookie.OAUTH_STATE); + res.clearCookie(ImmichCookie.OAUTH_CODE_VERIFIER); return respondWithCookie(res, body, { isSecure: loginDetails.isSecure, values: [ @@ -52,8 +70,12 @@ export class OAuthController { @Post('link') @Authenticated() - linkOAuthAccount(@Auth() auth: AuthDto, @Body() dto: OAuthCallbackDto): Promise { - return this.service.link(auth, dto); + linkOAuthAccount( + @Req() request: Request, + @Auth() auth: AuthDto, + @Body() dto: OAuthCallbackDto, + ): Promise { + return this.service.link(auth, dto, request.headers); } @Post('unlink') diff --git a/server/src/dtos/auth.dto.ts b/server/src/dtos/auth.dto.ts index 7f2ffa58783..a1978d39dd0 100644 --- a/server/src/dtos/auth.dto.ts +++ b/server/src/dtos/auth.dto.ts @@ -3,11 +3,11 @@ import { Transform } from 'class-transformer'; import { IsEmail, IsNotEmpty, IsString, MinLength } from 'class-validator'; import { AuthApiKey, AuthSession, AuthSharedLink, AuthUser, UserAdmin } from 'src/database'; import { ImmichCookie } from 'src/enum'; -import { toEmail } from 'src/validation'; +import { Optional, toEmail } from 'src/validation'; export type CookieResponse = { isSecure: boolean; - values: Array<{ key: ImmichCookie; value: string }>; + values: Array<{ key: ImmichCookie; value: string | null }>; }; export class AuthDto { @@ -87,12 +87,28 @@ export class OAuthCallbackDto { @IsString() @ApiProperty() url!: string; + + @Optional() + @IsString() + state?: string; + + @Optional() + @IsString() + codeVerifier?: string; } export class OAuthConfigDto { @IsNotEmpty() @IsString() redirectUri!: string; + + @Optional() + @IsString() + state?: string; + + @Optional() + @IsString() + codeChallenge?: string; } export class OAuthAuthorizeResponseDto { diff --git a/server/src/enum.ts b/server/src/enum.ts index e5c6039be8a..baf864aa497 100644 --- a/server/src/enum.ts +++ b/server/src/enum.ts @@ -8,6 +8,8 @@ export enum ImmichCookie { AUTH_TYPE = 'immich_auth_type', IS_AUTHENTICATED = 'immich_is_authenticated', SHARED_LINK_TOKEN = 'immich_shared_link_token', + OAUTH_STATE = 'immich_oauth_state', + OAUTH_CODE_VERIFIER = 'immich_oauth_code_verifier', } export enum ImmichHeader { diff --git a/server/src/repositories/oauth.repository.ts b/server/src/repositories/oauth.repository.ts index dc19a1fe015..d3e03720897 100644 --- a/server/src/repositories/oauth.repository.ts +++ b/server/src/repositories/oauth.repository.ts @@ -1,5 +1,5 @@ import { Injectable, InternalServerErrorException } from '@nestjs/common'; -import { custom, generators, Issuer, UserinfoResponse } from 'openid-client'; +import type { UserInfoResponse } from 'openid-client' with { 'resolution-mode': 'import' }; import { LoggingRepository } from 'src/repositories/logging.repository'; export type OAuthConfig = { @@ -12,7 +12,7 @@ export type OAuthConfig = { scope: string; signingAlgorithm: string; }; -export type OAuthProfile = UserinfoResponse; +export type OAuthProfile = UserInfoResponse; @Injectable() export class OAuthRepository { @@ -20,30 +20,47 @@ export class OAuthRepository { this.logger.setContext(OAuthRepository.name); } - init() { - custom.setHttpOptionsDefaults({ timeout: 30_000 }); - } - - async authorize(config: OAuthConfig, redirectUrl: string) { + async authorize(config: OAuthConfig, redirectUrl: string, state?: string, codeChallenge?: string) { + const { buildAuthorizationUrl, randomState, randomPKCECodeVerifier, calculatePKCECodeChallenge } = await import( + 'openid-client' + ); const client = await this.getClient(config); - return client.authorizationUrl({ + state ??= randomState(); + let codeVerifier: string | null; + if (codeChallenge) { + codeVerifier = null; + } else { + codeVerifier = randomPKCECodeVerifier(); + codeChallenge = await calculatePKCECodeChallenge(codeVerifier); + } + const url = buildAuthorizationUrl(client, { redirect_uri: redirectUrl, scope: config.scope, - state: generators.state(), - }); + state, + code_challenge: codeChallenge, + code_challenge_method: 'S256', + }).toString(); + return { url, state, codeVerifier }; } async getLogoutEndpoint(config: OAuthConfig) { const client = await this.getClient(config); - return client.issuer.metadata.end_session_endpoint; + return client.serverMetadata().end_session_endpoint; } - async getProfile(config: OAuthConfig, url: string, redirectUrl: string): Promise { + async getProfile( + config: OAuthConfig, + url: string, + expectedState: string, + codeVerifier: string, + ): Promise { + const { authorizationCodeGrant, fetchUserInfo, ...oidc } = await import('openid-client'); const client = await this.getClient(config); - const params = client.callbackParams(url); + const pkceCodeVerifier = client.serverMetadata().supportsPKCE() ? codeVerifier : undefined; + try { - const tokens = await client.callback(redirectUrl, params, { state: params.state }); - const profile = await client.userinfo(tokens.access_token || ''); + const tokens = await authorizationCodeGrant(client, new URL(url), { expectedState, pkceCodeVerifier }); + const profile = await fetchUserInfo(client, tokens.access_token, oidc.skipSubjectCheck); if (!profile.sub) { throw new Error('Unexpected profile response, no `sub`'); } @@ -59,6 +76,11 @@ export class OAuthRepository { ); } + if (error.code === 'OAUTH_INVALID_RESPONSE') { + this.logger.warn(`Invalid response from authorization server. Cause: ${error.cause?.message}`); + throw error.cause; + } + throw error; } } @@ -83,14 +105,20 @@ export class OAuthRepository { signingAlgorithm, }: OAuthConfig) { try { - const issuer = await Issuer.discover(issuerUrl); - return new issuer.Client({ - client_id: clientId, - client_secret: clientSecret, - response_types: ['code'], - userinfo_signed_response_alg: profileSigningAlgorithm === 'none' ? undefined : profileSigningAlgorithm, - id_token_signed_response_alg: signingAlgorithm, - }); + const { allowInsecureRequests, discovery } = await import('openid-client'); + return await discovery( + new URL(issuerUrl), + clientId, + { + client_secret: clientSecret, + response_types: ['code'], + userinfo_signed_response_alg: profileSigningAlgorithm === 'none' ? undefined : profileSigningAlgorithm, + id_token_signed_response_alg: signingAlgorithm, + timeout: 30_000, + }, + undefined, + { execute: [allowInsecureRequests] }, + ); } catch (error: any | AggregateError) { this.logger.error(`Error in OAuth discovery: ${error}`, error?.stack, error?.errors); throw new InternalServerErrorException(`Error in OAuth discovery: ${error}`, { cause: error }); diff --git a/server/src/services/auth.service.spec.ts b/server/src/services/auth.service.spec.ts index b1bfe00e852..46241599253 100644 --- a/server/src/services/auth.service.spec.ts +++ b/server/src/services/auth.service.spec.ts @@ -55,7 +55,7 @@ describe(AuthService.name, () => { beforeEach(() => { ({ sut, mocks } = newTestService(AuthService)); - mocks.oauth.authorize.mockResolvedValue('access-token'); + mocks.oauth.authorize.mockResolvedValue({ url: 'http://test', state: 'state', codeVerifier: 'codeVerifier' }); mocks.oauth.getProfile.mockResolvedValue({ sub, email }); mocks.oauth.getLogoutEndpoint.mockResolvedValue('http://end-session-endpoint'); }); @@ -64,16 +64,6 @@ describe(AuthService.name, () => { expect(sut).toBeDefined(); }); - describe('onBootstrap', () => { - it('should init the repo', () => { - mocks.oauth.init.mockResolvedValue(); - - sut.onBootstrap(); - - expect(mocks.oauth.init).toHaveBeenCalled(); - }); - }); - describe('login', () => { it('should throw an error if password login is disabled', async () => { mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.disabled); @@ -519,16 +509,22 @@ describe(AuthService.name, () => { describe('callback', () => { it('should throw an error if OAuth is not enabled', async () => { - await expect(sut.callback({ url: '' }, loginDetails)).rejects.toBeInstanceOf(BadRequestException); + await expect( + sut.callback({ url: '', state: 'xyz789', codeVerifier: 'foo' }, {}, loginDetails), + ).rejects.toBeInstanceOf(BadRequestException); }); it('should not allow auto registering', async () => { mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.oauthEnabled); mocks.user.getByEmail.mockResolvedValue(void 0); - await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).rejects.toBeInstanceOf( - BadRequestException, - ); + await expect( + sut.callback( + { url: 'http://immich/auth/login?code=abc123', state: 'xyz789', codeVerifier: 'foo' }, + {}, + loginDetails, + ), + ).rejects.toBeInstanceOf(BadRequestException); expect(mocks.user.getByEmail).toHaveBeenCalledTimes(1); }); @@ -541,9 +537,13 @@ describe(AuthService.name, () => { mocks.user.update.mockResolvedValue(user); mocks.session.create.mockResolvedValue(factory.session()); - await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).resolves.toEqual( - oauthResponse(user), - ); + await expect( + sut.callback( + { url: 'http://immich/auth/login?code=abc123', state: 'xyz789', codeVerifier: 'foobar' }, + {}, + loginDetails, + ), + ).resolves.toEqual(oauthResponse(user)); expect(mocks.user.getByEmail).toHaveBeenCalledTimes(1); expect(mocks.user.update).toHaveBeenCalledWith(user.id, { oauthId: sub }); @@ -557,9 +557,13 @@ describe(AuthService.name, () => { mocks.user.getAdmin.mockResolvedValue(user); mocks.user.create.mockResolvedValue(user); - await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).rejects.toThrow( - BadRequestException, - ); + await expect( + sut.callback( + { url: 'http://immich/auth/login?code=abc123', state: 'xyz789', codeVerifier: 'foobar' }, + {}, + loginDetails, + ), + ).rejects.toThrow(BadRequestException); expect(mocks.user.update).not.toHaveBeenCalled(); expect(mocks.user.create).not.toHaveBeenCalled(); @@ -574,9 +578,13 @@ describe(AuthService.name, () => { mocks.user.create.mockResolvedValue(user); mocks.session.create.mockResolvedValue(factory.session()); - await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).resolves.toEqual( - oauthResponse(user), - ); + await expect( + sut.callback( + { url: 'http://immich/auth/login?code=abc123', state: 'xyz789', codeVerifier: 'foobar' }, + {}, + loginDetails, + ), + ).resolves.toEqual(oauthResponse(user)); expect(mocks.user.getByEmail).toHaveBeenCalledTimes(2); // second call is for domain check before create expect(mocks.user.create).toHaveBeenCalledTimes(1); @@ -592,18 +600,19 @@ describe(AuthService.name, () => { mocks.session.create.mockResolvedValue(factory.session()); mocks.oauth.getProfile.mockResolvedValue({ sub, email: undefined }); - await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).rejects.toBeInstanceOf( - BadRequestException, - ); + await expect( + sut.callback( + { url: 'http://immich/auth/login?code=abc123', state: 'xyz789', codeVerifier: 'foobar' }, + {}, + loginDetails, + ), + ).rejects.toBeInstanceOf(BadRequestException); expect(mocks.user.getByEmail).not.toHaveBeenCalled(); expect(mocks.user.create).not.toHaveBeenCalled(); }); for (const url of [ - 'app.immich:/', - 'app.immich://', - 'app.immich:///', 'app.immich:/oauth-callback?code=abc123', 'app.immich://oauth-callback?code=abc123', 'app.immich:///oauth-callback?code=abc123', @@ -615,9 +624,14 @@ describe(AuthService.name, () => { mocks.user.getByOAuthId.mockResolvedValue(user); mocks.session.create.mockResolvedValue(factory.session()); - await sut.callback({ url }, loginDetails); + await sut.callback({ url, state: 'xyz789', codeVerifier: 'foo' }, {}, loginDetails); - expect(mocks.oauth.getProfile).toHaveBeenCalledWith(expect.objectContaining({}), url, 'http://mobile-redirect'); + expect(mocks.oauth.getProfile).toHaveBeenCalledWith( + expect.objectContaining({}), + 'http://mobile-redirect?code=abc123', + 'xyz789', + 'foo', + ); }); } @@ -630,9 +644,13 @@ describe(AuthService.name, () => { mocks.user.create.mockResolvedValue(user); mocks.session.create.mockResolvedValue(factory.session()); - await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).resolves.toEqual( - oauthResponse(user), - ); + await expect( + sut.callback( + { url: 'http://immich/auth/login?code=abc123', state: 'xyz789', codeVerifier: 'foo' }, + {}, + loginDetails, + ), + ).resolves.toEqual(oauthResponse(user)); expect(mocks.user.create).toHaveBeenCalledWith(expect.objectContaining({ quotaSizeInBytes: 1_073_741_824 })); }); @@ -647,9 +665,13 @@ describe(AuthService.name, () => { mocks.user.create.mockResolvedValue(user); mocks.session.create.mockResolvedValue(factory.session()); - await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).resolves.toEqual( - oauthResponse(user), - ); + await expect( + sut.callback( + { url: 'http://immich/auth/login?code=abc123', state: 'xyz789', codeVerifier: 'foo' }, + {}, + loginDetails, + ), + ).resolves.toEqual(oauthResponse(user)); expect(mocks.user.create).toHaveBeenCalledWith(expect.objectContaining({ quotaSizeInBytes: 1_073_741_824 })); }); @@ -664,9 +686,13 @@ describe(AuthService.name, () => { mocks.user.create.mockResolvedValue(user); mocks.session.create.mockResolvedValue(factory.session()); - await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).resolves.toEqual( - oauthResponse(user), - ); + await expect( + sut.callback( + { url: 'http://immich/auth/login?code=abc123', state: 'xyz789', codeVerifier: 'foo' }, + {}, + loginDetails, + ), + ).resolves.toEqual(oauthResponse(user)); expect(mocks.user.create).toHaveBeenCalledWith(expect.objectContaining({ quotaSizeInBytes: 1_073_741_824 })); }); @@ -681,9 +707,13 @@ describe(AuthService.name, () => { mocks.user.create.mockResolvedValue(user); mocks.session.create.mockResolvedValue(factory.session()); - await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).resolves.toEqual( - oauthResponse(user), - ); + await expect( + sut.callback( + { url: 'http://immich/auth/login?code=abc123', state: 'xyz789', codeVerifier: 'foo' }, + {}, + loginDetails, + ), + ).resolves.toEqual(oauthResponse(user)); expect(mocks.user.create).toHaveBeenCalledWith({ email: user.email, @@ -705,9 +735,13 @@ describe(AuthService.name, () => { mocks.user.create.mockResolvedValue(user); mocks.session.create.mockResolvedValue(factory.session()); - await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' }, loginDetails)).resolves.toEqual( - oauthResponse(user), - ); + await expect( + sut.callback( + { url: 'http://immich/auth/login?code=abc123', state: 'xyz789', codeVerifier: 'foo' }, + {}, + loginDetails, + ), + ).resolves.toEqual(oauthResponse(user)); expect(mocks.user.create).toHaveBeenCalledWith({ email: user.email, @@ -779,7 +813,11 @@ describe(AuthService.name, () => { mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.enabled); mocks.user.update.mockResolvedValue(user); - await sut.link(auth, { url: 'http://immich/user-settings?code=abc123' }); + await sut.link( + auth, + { url: 'http://immich/user-settings?code=abc123', state: 'xyz789', codeVerifier: 'foo' }, + {}, + ); expect(mocks.user.update).toHaveBeenCalledWith(auth.user.id, { oauthId: sub }); }); @@ -792,9 +830,9 @@ describe(AuthService.name, () => { mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.enabled); mocks.user.getByOAuthId.mockResolvedValue({ id: 'other-user' } as UserAdmin); - await expect(sut.link(auth, { url: 'http://immich/user-settings?code=abc123' })).rejects.toBeInstanceOf( - BadRequestException, - ); + await expect( + sut.link(auth, { url: 'http://immich/user-settings?code=abc123', state: 'xyz789', codeVerifier: 'foo' }, {}), + ).rejects.toBeInstanceOf(BadRequestException); expect(mocks.user.update).not.toHaveBeenCalled(); }); diff --git a/server/src/services/auth.service.ts b/server/src/services/auth.service.ts index ee4ca4dc5d8..b250b63a5ef 100644 --- a/server/src/services/auth.service.ts +++ b/server/src/services/auth.service.ts @@ -7,13 +7,11 @@ import { join } from 'node:path'; import { LOGIN_URL, MOBILE_REDIRECT, SALT_ROUNDS } from 'src/constants'; import { StorageCore } from 'src/cores/storage.core'; import { UserAdmin } from 'src/database'; -import { OnEvent } from 'src/decorators'; import { AuthDto, ChangePasswordDto, LoginCredentialDto, LogoutResponseDto, - OAuthAuthorizeResponseDto, OAuthCallbackDto, OAuthConfigDto, SignUpDto, @@ -52,11 +50,6 @@ export type ValidateRequest = { @Injectable() export class AuthService extends BaseService { - @OnEvent({ name: 'app.bootstrap' }) - onBootstrap() { - this.oauthRepository.init(); - } - async login(dto: LoginCredentialDto, details: LoginDetails) { const config = await this.getConfig({ withCache: false }); if (!config.passwordLogin.enabled) { @@ -176,20 +169,35 @@ export class AuthService extends BaseService { return `${MOBILE_REDIRECT}?${url.split('?')[1] || ''}`; } - async authorize(dto: OAuthConfigDto): Promise { + async authorize(dto: OAuthConfigDto) { const { oauth } = await this.getConfig({ withCache: false }); if (!oauth.enabled) { throw new BadRequestException('OAuth is not enabled'); } - const url = await this.oauthRepository.authorize(oauth, this.resolveRedirectUri(oauth, dto.redirectUri)); - return { url }; + return await this.oauthRepository.authorize( + oauth, + this.resolveRedirectUri(oauth, dto.redirectUri), + dto.state, + dto.codeChallenge, + ); } - async callback(dto: OAuthCallbackDto, loginDetails: LoginDetails) { + async callback(dto: OAuthCallbackDto, headers: IncomingHttpHeaders, loginDetails: LoginDetails) { + const expectedState = dto.state ?? this.getCookieOauthState(headers); + if (!expectedState?.length) { + throw new BadRequestException('OAuth state is missing'); + } + + const codeVerifier = dto.codeVerifier ?? this.getCookieCodeVerifier(headers); + if (!codeVerifier?.length) { + throw new BadRequestException('OAuth code verifier is missing'); + } + const { oauth } = await this.getConfig({ withCache: false }); - const profile = await this.oauthRepository.getProfile(oauth, dto.url, this.resolveRedirectUri(oauth, dto.url)); + const url = this.resolveRedirectUri(oauth, dto.url); + const profile = await this.oauthRepository.getProfile(oauth, url, expectedState, codeVerifier); const { autoRegister, defaultStorageQuota, storageLabelClaim, storageQuotaClaim } = oauth; this.logger.debug(`Logging in with OAuth: ${JSON.stringify(profile)}`); let user: UserAdmin | undefined = await this.userRepository.getByOAuthId(profile.sub); @@ -271,13 +279,19 @@ export class AuthService extends BaseService { } } - async link(auth: AuthDto, dto: OAuthCallbackDto): Promise { + async link(auth: AuthDto, dto: OAuthCallbackDto, headers: IncomingHttpHeaders): Promise { + const expectedState = dto.state ?? this.getCookieOauthState(headers); + if (!expectedState?.length) { + throw new BadRequestException('OAuth state is missing'); + } + + const codeVerifier = dto.codeVerifier ?? this.getCookieCodeVerifier(headers); + if (!codeVerifier?.length) { + throw new BadRequestException('OAuth code verifier is missing'); + } + const { oauth } = await this.getConfig({ withCache: false }); - const { sub: oauthId } = await this.oauthRepository.getProfile( - oauth, - dto.url, - this.resolveRedirectUri(oauth, dto.url), - ); + const { sub: oauthId } = await this.oauthRepository.getProfile(oauth, dto.url, expectedState, codeVerifier); const duplicate = await this.userRepository.getByOAuthId(oauthId); if (duplicate && duplicate.id !== auth.user.id) { this.logger.warn(`OAuth link account failed: sub is already linked to another user (${duplicate.email}).`); @@ -320,6 +334,16 @@ export class AuthService extends BaseService { return cookies[ImmichCookie.ACCESS_TOKEN] || null; } + private getCookieOauthState(headers: IncomingHttpHeaders): string | null { + const cookies = parse(headers.cookie || ''); + return cookies[ImmichCookie.OAUTH_STATE] || null; + } + + private getCookieCodeVerifier(headers: IncomingHttpHeaders): string | null { + const cookies = parse(headers.cookie || ''); + return cookies[ImmichCookie.OAUTH_CODE_VERIFIER] || null; + } + async validateSharedLink(key: string | string[]): Promise { key = Array.isArray(key) ? key[0] : key; @@ -399,11 +423,9 @@ export class AuthService extends BaseService { { mobileRedirectUri, mobileOverrideEnabled }: { mobileRedirectUri: string; mobileOverrideEnabled: boolean }, url: string, ) { - const redirectUri = url.split('?')[0]; - const isMobile = redirectUri.startsWith('app.immich:/'); - if (isMobile && mobileOverrideEnabled && mobileRedirectUri) { - return mobileRedirectUri; + if (mobileOverrideEnabled && mobileRedirectUri) { + return url.replace(/app\.immich:\/+oauth-callback/, mobileRedirectUri); } - return redirectUri; + return url; } } diff --git a/server/src/utils/response.ts b/server/src/utils/response.ts index 679d947afb7..a50e86a4ffb 100644 --- a/server/src/utils/response.ts +++ b/server/src/utils/response.ts @@ -15,6 +15,8 @@ export const respondWithCookie = (res: Response, body: T, { isSecure, values const cookieOptions: Record = { [ImmichCookie.AUTH_TYPE]: defaults, [ImmichCookie.ACCESS_TOKEN]: defaults, + [ImmichCookie.OAUTH_STATE]: defaults, + [ImmichCookie.OAUTH_CODE_VERIFIER]: defaults, // no httpOnly so that the client can know the auth state [ImmichCookie.IS_AUTHENTICATED]: { ...defaults, httpOnly: false }, [ImmichCookie.SHARED_LINK_TOKEN]: { ...defaults, maxAge: Duration.fromObject({ days: 1 }).toMillis() },