diff --git a/server/src/domain/database/database.service.spec.ts b/server/src/domain/database/database.service.spec.ts new file mode 100644 index 00000000000..827061ce566 --- /dev/null +++ b/server/src/domain/database/database.service.spec.ts @@ -0,0 +1,134 @@ +import { ImmichLogger } from '@app/infra/logger'; +import { newDatabaseRepositoryMock } from '@test'; +import { Version } from '../domain.constant'; +import { DatabaseExtension, IDatabaseRepository } from '../repositories'; +import { DatabaseService } from './database.service'; + +describe(DatabaseService.name, () => { + let sut: DatabaseService; + let databaseMock: jest.Mocked; + let fatalLog: jest.SpyInstance; + + beforeEach(async () => { + databaseMock = newDatabaseRepositoryMock(); + fatalLog = jest.spyOn(ImmichLogger.prototype, 'fatal'); + + sut = new DatabaseService(databaseMock); + + sut.minVectorsVersion = new Version(0, 1, 1); + sut.maxVectorsVersion = new Version(0, 1, 11); + }); + + afterEach(() => { + fatalLog.mockRestore(); + }); + + it('should work', () => { + expect(sut).toBeDefined(); + }); + + describe('init', () => { + it('should return if minimum supported vectors version is installed', async () => { + databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 1, 1)); + + await sut.init(); + + expect(databaseMock.createExtension).toHaveBeenCalledWith(DatabaseExtension.VECTORS); + expect(databaseMock.createExtension).toHaveBeenCalledTimes(1); + expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1); + expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1); + expect(fatalLog).not.toHaveBeenCalled(); + }); + + it('should return if maximum supported vectors version is installed', async () => { + databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 1, 11)); + + await sut.init(); + + expect(databaseMock.createExtension).toHaveBeenCalledWith(DatabaseExtension.VECTORS); + expect(databaseMock.createExtension).toHaveBeenCalledTimes(1); + expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1); + expect(databaseMock.runMigrations).toHaveBeenCalledTimes(1); + expect(fatalLog).not.toHaveBeenCalled(); + }); + + it('should throw an error if vectors version is not installed even after createVectors', async () => { + databaseMock.getExtensionVersion.mockResolvedValueOnce(null); + + await expect(sut.init()).rejects.toThrow('Unexpected: The pgvecto.rs extension is not installed.'); + + expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1); + expect(databaseMock.createExtension).toHaveBeenCalledTimes(1); + expect(databaseMock.runMigrations).not.toHaveBeenCalled(); + }); + + it('should throw an error if vectors version is below minimum supported version', async () => { + databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 0, 1)); + + await expect(sut.init()).rejects.toThrow(/('tensorchord\/pgvecto-rs:pg14-v0.1.11')/s); + + expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1); + expect(databaseMock.runMigrations).not.toHaveBeenCalled(); + }); + + it('should throw an error if vectors version is above maximum supported version', async () => { + databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 1, 12)); + + await expect(sut.init()).rejects.toThrow( + /('DROP EXTENSION IF EXISTS vectors').*('tensorchord\/pgvecto-rs:pg14-v0\.1\.11')/s, + ); + + expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1); + expect(databaseMock.runMigrations).not.toHaveBeenCalled(); + }); + + it('should throw an error if vectors version is a nightly', async () => { + databaseMock.getExtensionVersion.mockResolvedValueOnce(new Version(0, 0, 0)); + + await expect(sut.init()).rejects.toThrow( + /(nightly).*('DROP EXTENSION IF EXISTS vectors').*('tensorchord\/pgvecto-rs:pg14-v0\.1\.11')/s, + ); + + expect(databaseMock.getExtensionVersion).toHaveBeenCalledTimes(1); + expect(databaseMock.createExtension).toHaveBeenCalledTimes(1); + expect(databaseMock.runMigrations).not.toHaveBeenCalled(); + }); + + it('should throw error if vectors extension could not be created', async () => { + databaseMock.createExtension.mockRejectedValueOnce(new Error('Failed to create extension')); + + await expect(sut.init()).rejects.toThrow('Failed to create extension'); + + expect(fatalLog).toHaveBeenCalledTimes(1); + expect(fatalLog.mock.calls[0][0]).toMatch(/('tensorchord\/pgvecto-rs:pg14-v0\.1\.11').*(v1\.91\.0)/s); + expect(databaseMock.createExtension).toHaveBeenCalledTimes(1); + expect(databaseMock.runMigrations).not.toHaveBeenCalled(); + }); + + for (const major of [14, 15, 16]) { + it(`should suggest image with postgres ${major} if database is ${major}`, async () => { + databaseMock.getExtensionVersion.mockResolvedValue(new Version(0, 0, 1)); + databaseMock.getPostgresVersion.mockResolvedValueOnce(new Version(major, 0, 0)); + + await expect(sut.init()).rejects.toThrow(new RegExp(`tensorchord\/pgvecto-rs:pg${major}-v0\\.1\\.11`, 's')); + }); + } + + it('should not suggest image if postgres version is not in 14, 15 or 16', async () => { + databaseMock.getExtensionVersion.mockResolvedValue(new Version(0, 0, 1)); + [13, 17].forEach((major) => databaseMock.getPostgresVersion.mockResolvedValueOnce(new Version(major, 0, 0))); + + await expect(sut.init()).rejects.toThrow(/^(?:(?!tensorchord\/pgvecto-rs).)*$/s); + await expect(sut.init()).rejects.toThrow(/^(?:(?!tensorchord\/pgvecto-rs).)*$/s); + }); + + it('should set the image to the maximum supported version', async () => { + databaseMock.getExtensionVersion.mockResolvedValue(new Version(0, 0, 1)); + + await expect(sut.init()).rejects.toThrow(/('tensorchord\/pgvecto-rs:pg14-v0\.1\.11')/s); + + sut.maxVectorsVersion = new Version(0, 1, 12); + await expect(sut.init()).rejects.toThrow(/('tensorchord\/pgvecto-rs:pg14-v0\.1\.12')/s); + }); + }); +}); diff --git a/server/src/domain/database/database.service.ts b/server/src/domain/database/database.service.ts new file mode 100644 index 00000000000..f70fc48d5e6 --- /dev/null +++ b/server/src/domain/database/database.service.ts @@ -0,0 +1,69 @@ +import { ImmichLogger } from '@app/infra/logger'; +import { Inject, Injectable } from '@nestjs/common'; +import { QueryFailedError } from 'typeorm'; +import { Version } from '../domain.constant'; +import { DatabaseExtension, IDatabaseRepository } from '../repositories'; + +@Injectable() +export class DatabaseService { + private logger = new ImmichLogger(DatabaseService.name); + minVectorsVersion = new Version(0, 1, 1); + maxVectorsVersion = new Version(0, 1, 11); + + constructor(@Inject(IDatabaseRepository) private databaseRepository: IDatabaseRepository) {} + + async init() { + await this.createVectors(); + await this.assertVectors(); + await this.databaseRepository.runMigrations(); + } + + private async assertVectors() { + const version = await this.databaseRepository.getExtensionVersion(DatabaseExtension.VECTORS); + if (version == null) { + throw new Error('Unexpected: The pgvecto.rs extension is not installed.'); + } + + const image = await this.getVectorsImage(); + const suggestion = image ? `, such as with the docker image '${image}'` : ''; + + if (version.isEqual(new Version(0, 0, 0))) { + throw new Error( + `The pgvecto.rs extension version is ${version}, which means it is a nightly release.` + + `Please run 'DROP EXTENSION IF EXISTS vectors' and switch to a release version${suggestion}.`, + ); + } + + if (version.isNewerThan(this.maxVectorsVersion)) { + throw new Error(` + The pgvecto.rs extension version is ${version} instead of ${this.maxVectorsVersion}. + Please run 'DROP EXTENSION IF EXISTS vectors' and switch to ${this.maxVectorsVersion}${suggestion}.`); + } + + if (version.isOlderThan(this.minVectorsVersion)) { + throw new Error(` + The pgvecto.rs extension version is ${version}, which is older than the minimum supported version ${this.minVectorsVersion}. + Please upgrade to this version or later${suggestion}.`); + } + } + + private async createVectors() { + await this.databaseRepository.createExtension(DatabaseExtension.VECTORS).catch(async (err: QueryFailedError) => { + const image = await this.getVectorsImage(); + this.logger.fatal(` + Failed to create pgvecto.rs extension. + If you have not updated your Postgres instance to a docker image that supports pgvecto.rs (such as '${image}'), please do so. + See the v1.91.0 release notes for more info: https://github.com/immich-app/immich/releases/tag/v1.91.0' + `); + throw err; + }); + } + + private async getVectorsImage() { + const { major } = await this.databaseRepository.getPostgresVersion(); + if (![14, 15, 16].includes(major)) { + return null; + } + return `tensorchord/pgvecto-rs:pg${major}-v${this.maxVectorsVersion}`; + } +} diff --git a/server/src/domain/database/index.ts b/server/src/domain/database/index.ts new file mode 100644 index 00000000000..cd4e1d21772 --- /dev/null +++ b/server/src/domain/database/index.ts @@ -0,0 +1 @@ +export * from './database.service'; diff --git a/server/src/domain/domain.constant.spec.ts b/server/src/domain/domain.constant.spec.ts index 0163254301b..89465638bc3 100644 --- a/server/src/domain/domain.constant.spec.ts +++ b/server/src/domain/domain.constant.spec.ts @@ -1,4 +1,4 @@ -import { ServerVersion, mimeTypes } from './domain.constant'; +import { Version, mimeTypes } from './domain.constant'; describe('mimeTypes', () => { for (const { mimetype, extension } of [ @@ -196,64 +196,76 @@ describe('mimeTypes', () => { }); describe('ServerVersion', () => { + const tests = [ + { this: new Version(0, 0, 1), other: new Version(0, 0, 0), expected: 1 }, + { this: new Version(0, 1, 0), other: new Version(0, 0, 0), expected: 1 }, + { this: new Version(1, 0, 0), other: new Version(0, 0, 0), expected: 1 }, + { this: new Version(0, 0, 0), other: new Version(0, 0, 1), expected: -1 }, + { this: new Version(0, 0, 0), other: new Version(0, 1, 0), expected: -1 }, + { this: new Version(0, 0, 0), other: new Version(1, 0, 0), expected: -1 }, + { this: new Version(0, 0, 0), other: new Version(0, 0, 0), expected: 0 }, + { this: new Version(0, 0, 1), other: new Version(0, 0, 1), expected: 0 }, + { this: new Version(0, 1, 0), other: new Version(0, 1, 0), expected: 0 }, + { this: new Version(1, 0, 0), other: new Version(1, 0, 0), expected: 0 }, + { this: new Version(1, 0), other: new Version(1, 0, 0), expected: 0 }, + { this: new Version(1, 0), other: new Version(1, 0, 1), expected: -1 }, + { this: new Version(1, 1), other: new Version(1, 0, 1), expected: 1 }, + { this: new Version(1), other: new Version(1, 0, 0), expected: 0 }, + { this: new Version(1), other: new Version(1, 0, 1), expected: -1 }, + ]; + + describe('compare', () => { + for (const { this: thisVersion, other: otherVersion, expected } of tests) { + it(`should return ${expected} when comparing ${thisVersion} to ${otherVersion}`, () => { + expect(thisVersion.compare(otherVersion)).toEqual(expected); + }); + } + }); + + describe('isOlderThan', () => { + for (const { this: thisVersion, other: otherVersion, expected } of tests) { + const bool = expected < 0; + it(`should return ${bool} when comparing ${thisVersion} to ${otherVersion}`, () => { + expect(thisVersion.isOlderThan(otherVersion)).toEqual(bool); + }); + } + }); + + describe('isEqual', () => { + for (const { this: thisVersion, other: otherVersion, expected } of tests) { + const bool = expected === 0; + it(`should return ${bool} when comparing ${thisVersion} to ${otherVersion}`, () => { + expect(thisVersion.isEqual(otherVersion)).toEqual(bool); + }); + } + }); + describe('isNewerThan', () => { - it('should work on patch versions', () => { - expect(new ServerVersion(0, 0, 1).isNewerThan(new ServerVersion(0, 0, 0))).toBe(true); - expect(new ServerVersion(1, 72, 1).isNewerThan(new ServerVersion(1, 72, 0))).toBe(true); - - expect(new ServerVersion(0, 0, 0).isNewerThan(new ServerVersion(0, 0, 1))).toBe(false); - expect(new ServerVersion(1, 72, 0).isNewerThan(new ServerVersion(1, 72, 1))).toBe(false); - }); - - it('should work on minor versions', () => { - expect(new ServerVersion(0, 1, 0).isNewerThan(new ServerVersion(0, 0, 0))).toBe(true); - expect(new ServerVersion(1, 72, 0).isNewerThan(new ServerVersion(1, 71, 0))).toBe(true); - expect(new ServerVersion(1, 72, 0).isNewerThan(new ServerVersion(1, 71, 9))).toBe(true); - - expect(new ServerVersion(0, 0, 0).isNewerThan(new ServerVersion(0, 1, 0))).toBe(false); - expect(new ServerVersion(1, 71, 0).isNewerThan(new ServerVersion(1, 72, 0))).toBe(false); - expect(new ServerVersion(1, 71, 9).isNewerThan(new ServerVersion(1, 72, 0))).toBe(false); - }); - - it('should work on major versions', () => { - expect(new ServerVersion(1, 0, 0).isNewerThan(new ServerVersion(0, 0, 0))).toBe(true); - expect(new ServerVersion(2, 0, 0).isNewerThan(new ServerVersion(1, 71, 0))).toBe(true); - - expect(new ServerVersion(0, 0, 0).isNewerThan(new ServerVersion(1, 0, 0))).toBe(false); - expect(new ServerVersion(1, 71, 0).isNewerThan(new ServerVersion(2, 0, 0))).toBe(false); - }); - - it('should work on equal', () => { - for (const version of [ - new ServerVersion(0, 0, 0), - new ServerVersion(0, 0, 1), - new ServerVersion(0, 1, 1), - new ServerVersion(0, 1, 0), - new ServerVersion(1, 1, 1), - new ServerVersion(1, 0, 0), - new ServerVersion(1, 72, 1), - new ServerVersion(1, 72, 0), - new ServerVersion(1, 73, 9), - ]) { - expect(version.isNewerThan(version)).toBe(false); - } - }); + for (const { this: thisVersion, other: otherVersion, expected } of tests) { + const bool = expected > 0; + it(`should return ${bool} when comparing ${thisVersion} to ${otherVersion}`, () => { + expect(thisVersion.isNewerThan(otherVersion)).toEqual(bool); + }); + } }); describe('fromString', () => { const tests = [ - { scenario: 'leading v', value: 'v1.72.2', expected: new ServerVersion(1, 72, 2) }, - { scenario: 'uppercase v', value: 'V1.72.2', expected: new ServerVersion(1, 72, 2) }, - { scenario: 'missing v', value: '1.72.2', expected: new ServerVersion(1, 72, 2) }, - { scenario: 'large patch', value: '1.72.123', expected: new ServerVersion(1, 72, 123) }, - { scenario: 'large minor', value: '1.123.0', expected: new ServerVersion(1, 123, 0) }, - { scenario: 'large major', value: '123.0.0', expected: new ServerVersion(123, 0, 0) }, - { scenario: 'major bump', value: 'v2.0.0', expected: new ServerVersion(2, 0, 0) }, + { scenario: 'leading v', value: 'v1.72.2', expected: new Version(1, 72, 2) }, + { scenario: 'uppercase v', value: 'V1.72.2', expected: new Version(1, 72, 2) }, + { scenario: 'missing v', value: '1.72.2', expected: new Version(1, 72, 2) }, + { scenario: 'large patch', value: '1.72.123', expected: new Version(1, 72, 123) }, + { scenario: 'large minor', value: '1.123.0', expected: new Version(1, 123, 0) }, + { scenario: 'large major', value: '123.0.0', expected: new Version(123, 0, 0) }, + { scenario: 'major bump', value: 'v2.0.0', expected: new Version(2, 0, 0) }, + { scenario: 'has dash', value: '14.10-1', expected: new Version(14, 10, 1) }, + { scenario: 'missing patch', value: '14.10', expected: new Version(14, 10, 0) }, + { scenario: 'only major', value: '14', expected: new Version(14, 0, 0) }, ]; for (const { scenario, value, expected } of tests) { it(`should correctly parse ${scenario}`, () => { - const actual = ServerVersion.fromString(value); + const actual = Version.fromString(value); expect(actual.major).toEqual(expected.major); expect(actual.minor).toEqual(expected.minor); expect(actual.patch).toEqual(expected.patch); diff --git a/server/src/domain/domain.constant.ts b/server/src/domain/domain.constant.ts index a6ac96ba74f..557cd4b18b3 100644 --- a/server/src/domain/domain.constant.ts +++ b/server/src/domain/domain.constant.ts @@ -6,17 +6,17 @@ import pkg from 'src/../../package.json'; export const AUDIT_LOG_MAX_DURATION = Duration.fromObject({ days: 100 }); export const ONE_HOUR = Duration.fromObject({ hours: 1 }); -export interface IServerVersion { +export interface IVersion { major: number; minor: number; patch: number; } -export class ServerVersion implements IServerVersion { +export class Version implements IVersion { constructor( public readonly major: number, - public readonly minor: number, - public readonly patch: number, + public readonly minor: number = 0, + public readonly patch: number = 0, ) {} toString() { @@ -28,33 +28,45 @@ export class ServerVersion implements IServerVersion { return { major, minor, patch }; } - static fromString(version: string): ServerVersion { - const regex = /(?:v)?(?\d+)\.(?\d+)\.(?\d+)/i; + static fromString(version: string): Version { + const regex = /(?:v)?(?\d+)(?:\.(?\d+))?(?:[\.-](?\d+))?/i; const matchResult = version.match(regex); if (matchResult) { - const [, major, minor, patch] = matchResult.map(Number); - return new ServerVersion(major, minor, patch); + const { major, minor = '0', patch = '0' } = matchResult.groups as { [K in keyof IVersion]: string }; + return new Version(Number(major), Number(minor), Number(patch)); } else { throw new Error(`Invalid version format: ${version}`); } } - isNewerThan(version: ServerVersion): boolean { - const equalMajor = this.major === version.major; - const equalMinor = this.minor === version.minor; + compare(version: Version): number { + for (const key of ['major', 'minor', 'patch'] as const) { + const diff = this[key] - version[key]; + if (diff !== 0) { + return diff > 0 ? 1 : -1; + } + } - return ( - this.major > version.major || - (equalMajor && this.minor > version.minor) || - (equalMajor && equalMinor && this.patch > version.patch) - ); + return 0; + } + + isOlderThan(version: Version): boolean { + return this.compare(version) < 0; + } + + isEqual(version: Version): boolean { + return this.compare(version) === 0; + } + + isNewerThan(version: Version): boolean { + return this.compare(version) > 0; } } export const envName = (process.env.NODE_ENV || 'development').toUpperCase(); export const isDev = process.env.NODE_ENV === 'development'; -export const serverVersion = ServerVersion.fromString(pkg.version); +export const serverVersion = Version.fromString(pkg.version); export const APP_MEDIA_LOCATION = process.env.IMMICH_MEDIA_LOCATION || './upload'; diff --git a/server/src/domain/domain.module.ts b/server/src/domain/domain.module.ts index 5851d3a9080..4cd4b77dbf0 100644 --- a/server/src/domain/domain.module.ts +++ b/server/src/domain/domain.module.ts @@ -6,6 +6,7 @@ import { APIKeyService } from './api-key'; import { AssetService } from './asset'; import { AuditService } from './audit'; import { AuthService } from './auth'; +import { DatabaseService } from './database'; import { JobService } from './job'; import { LibraryService } from './library'; import { MediaService } from './media'; @@ -29,6 +30,7 @@ const providers: Provider[] = [ AssetService, AuditService, AuthService, + DatabaseService, JobService, MediaService, MetadataService, diff --git a/server/src/domain/index.ts b/server/src/domain/index.ts index e76159d400b..ca3d4ced465 100644 --- a/server/src/domain/index.ts +++ b/server/src/domain/index.ts @@ -5,6 +5,7 @@ export * from './api-key'; export * from './asset'; export * from './audit'; export * from './auth'; +export * from './database'; export * from './domain.config'; export * from './domain.constant'; export * from './domain.module'; diff --git a/server/src/domain/repositories/database.repository.ts b/server/src/domain/repositories/database.repository.ts new file mode 100644 index 00000000000..9075a0a9c77 --- /dev/null +++ b/server/src/domain/repositories/database.repository.ts @@ -0,0 +1,16 @@ +import { Version } from '../domain.constant'; + +export enum DatabaseExtension { + CUBE = 'cube', + EARTH_DISTANCE = 'earthdistance', + VECTORS = 'vectors', +} + +export const IDatabaseRepository = 'IDatabaseRepository'; + +export interface IDatabaseRepository { + getExtensionVersion(extName: string): Promise; + getPostgresVersion(): Promise; + createExtension(extension: DatabaseExtension): Promise; + runMigrations(options?: { transaction?: 'all' | 'none' | 'each' }): Promise; +} diff --git a/server/src/domain/repositories/index.ts b/server/src/domain/repositories/index.ts index f812e6ee595..997cd4e8b5a 100644 --- a/server/src/domain/repositories/index.ts +++ b/server/src/domain/repositories/index.ts @@ -6,6 +6,7 @@ export * from './asset.repository'; export * from './audit.repository'; export * from './communication.repository'; export * from './crypto.repository'; +export * from './database.repository'; export * from './job.repository'; export * from './library.repository'; export * from './machine-learning.repository'; diff --git a/server/src/domain/server-info/server-info.dto.ts b/server/src/domain/server-info/server-info.dto.ts index b5b693a34ba..e8c68d559bf 100644 --- a/server/src/domain/server-info/server-info.dto.ts +++ b/server/src/domain/server-info/server-info.dto.ts @@ -1,4 +1,4 @@ -import { FeatureFlags, IServerVersion } from '@app/domain'; +import { FeatureFlags, IVersion } from '@app/domain'; import { ApiProperty, ApiResponseProperty } from '@nestjs/swagger'; import { SystemConfigThemeDto } from '../system-config/dto/system-config-theme.dto'; @@ -25,7 +25,7 @@ export class ServerInfoResponseDto { diskUsagePercentage!: number; } -export class ServerVersionResponseDto implements IServerVersion { +export class ServerVersionResponseDto implements IVersion { @ApiProperty({ type: 'integer' }) major!: number; @ApiProperty({ type: 'integer' }) diff --git a/server/src/domain/server-info/server-info.service.ts b/server/src/domain/server-info/server-info.service.ts index 014dbfc8da6..5a1bed5674a 100644 --- a/server/src/domain/server-info/server-info.service.ts +++ b/server/src/domain/server-info/server-info.service.ts @@ -1,7 +1,7 @@ import { ImmichLogger } from '@app/infra/logger'; import { Inject, Injectable } from '@nestjs/common'; import { DateTime } from 'luxon'; -import { ServerVersion, isDev, mimeTypes, serverVersion } from '../domain.constant'; +import { Version, isDev, mimeTypes, serverVersion } from '../domain.constant'; import { asHumanReadable } from '../domain.util'; import { ClientEvent, @@ -138,7 +138,7 @@ export class ServerInfoService { } const githubRelease = await this.repository.getGitHubRelease(); - const githubVersion = ServerVersion.fromString(githubRelease.tag_name); + const githubVersion = Version.fromString(githubRelease.tag_name); const publishedAt = new Date(githubRelease.published_at); this.releaseVersion = githubVersion; this.releaseVersionCheckedAt = DateTime.now(); diff --git a/server/src/immich/app.service.ts b/server/src/immich/app.service.ts index 79da7b2bed3..1ac1dd9ffcc 100644 --- a/server/src/immich/app.service.ts +++ b/server/src/immich/app.service.ts @@ -1,5 +1,6 @@ import { AuthService, + DatabaseService, JobService, ONE_HOUR, OpenGraphTags, @@ -46,6 +47,7 @@ export class AppService { private serverService: ServerInfoService, private sharedLinkService: SharedLinkService, private storageService: StorageService, + private databaseService: DatabaseService, ) {} @Interval(ONE_HOUR.as('milliseconds')) @@ -59,6 +61,7 @@ export class AppService { } async init() { + await this.databaseService.init(); await this.configService.init(); this.storageService.init(); await this.serverService.handleVersionCheck(); diff --git a/server/src/immich/main.ts b/server/src/immich/main.ts index 2b2158ce62c..05f3f038117 100644 --- a/server/src/immich/main.ts +++ b/server/src/immich/main.ts @@ -1,5 +1,5 @@ import { envName, isDev, serverVersion } from '@app/domain'; -import { WebSocketAdapter, databaseChecks } from '@app/infra'; +import { WebSocketAdapter } from '@app/infra'; import { ImmichLogger } from '@app/infra/logger'; import { NestFactory } from '@nestjs/core'; import { NestExpressApplication } from '@nestjs/platform-express'; @@ -31,8 +31,6 @@ export async function bootstrap() { app.useStaticAssets('www'); app.use(app.get(AppService).ssr(excludePaths)); - await databaseChecks(); - const server = await app.listen(port); server.requestTimeout = 30 * 60 * 1000; diff --git a/server/src/infra/database.config.ts b/server/src/infra/database.config.ts index eb41b17bb38..88a46cd068a 100644 --- a/server/src/infra/database.config.ts +++ b/server/src/infra/database.config.ts @@ -1,4 +1,4 @@ -import { DataSource, QueryRunner } from 'typeorm'; +import { DataSource } from 'typeorm'; import { PostgresConnectionOptions } from 'typeorm/driver/postgres/PostgresConnectionOptions'; const url = process.env.DB_URL; @@ -18,58 +18,10 @@ export const databaseConfig: PostgresConnectionOptions = { synchronize: false, migrations: [__dirname + '/migrations/*.{js,ts}'], subscribers: [__dirname + '/subscribers/*.{js,ts}'], - migrationsRun: true, + migrationsRun: false, connectTimeoutMS: 10000, // 10 seconds ...urlOrParts, }; // this export is used by TypeORM commands in package.json#scripts export const dataSource = new DataSource(databaseConfig); - -export async function databaseChecks() { - if (!dataSource.isInitialized) { - await dataSource.initialize(); - } - - await assertVectors(dataSource); - await enablePrefilter(dataSource); - await dataSource.runMigrations(); -} - -export async function enablePrefilter(runner: DataSource | QueryRunner) { - await runner.query(`SET vectors.enable_prefilter = on`); -} - -export async function getExtensionVersion(extName: string, runner: DataSource | QueryRunner): Promise { - const res = await runner.query(`SELECT extversion FROM pg_extension WHERE extname = $1`, [extName]); - return res[0]?.['extversion'] ?? null; -} - -export async function getPostgresVersion(runner: DataSource | QueryRunner): Promise { - const res = await runner.query(`SHOW server_version`); - return res[0]['server_version'].split('.')[0]; -} - -export async function assertVectors(runner: DataSource | QueryRunner) { - const postgresVersion = await getPostgresVersion(runner); - const expected = ['0.1.1', '0.1.11']; - const image = `tensorchord/pgvecto-rs:pg${postgresVersion}-v${expected[expected.length - 1]}`; - - await runner.query('CREATE EXTENSION IF NOT EXISTS vectors').catch((err) => { - console.error( - 'Failed to create pgvecto.rs extension. ' + - `If you have not updated your Postgres instance to an image that supports pgvecto.rs (such as '${image}'), please do so. ` + - 'See the v1.91.0 release notes for more info: https://github.com/immich-app/immich/releases/tag/v1.91.0', - ); - throw err; - }); - - const version = await getExtensionVersion('vectors', runner); - if (version != null && !expected.includes(version)) { - throw new Error( - `The pgvecto.rs extension version is ${version} instead of the expected version ${ - expected[expected.length - 1] - }.` + `If you're using the 'latest' tag, please switch to '${image}'.`, - ); - } -} diff --git a/server/src/infra/infra.module.ts b/server/src/infra/infra.module.ts index 4d5cadeb635..aef41f5627e 100644 --- a/server/src/infra/infra.module.ts +++ b/server/src/infra/infra.module.ts @@ -6,6 +6,7 @@ import { IAuditRepository, ICommunicationRepository, ICryptoRepository, + IDatabaseRepository, IJobRepository, IKeyRepository, ILibraryRepository, @@ -43,6 +44,7 @@ import { AuditRepository, CommunicationRepository, CryptoRepository, + DatabaseRepository, FilesystemProvider, JobRepository, LibraryRepository, @@ -70,6 +72,7 @@ const providers: Provider[] = [ { provide: IAuditRepository, useClass: AuditRepository }, { provide: ICommunicationRepository, useClass: CommunicationRepository }, { provide: ICryptoRepository, useClass: CryptoRepository }, + { provide: IDatabaseRepository, useClass: DatabaseRepository }, { provide: IJobRepository, useClass: JobRepository }, { provide: ILibraryRepository, useClass: LibraryRepository }, { provide: IKeyRepository, useClass: ApiKeyRepository }, diff --git a/server/src/infra/logger.ts b/server/src/infra/logger.ts index c059111d219..183ffb492fb 100644 --- a/server/src/infra/logger.ts +++ b/server/src/infra/logger.ts @@ -5,7 +5,7 @@ import { LogLevel } from './entities'; const LOG_LEVELS = [LogLevel.VERBOSE, LogLevel.DEBUG, LogLevel.LOG, LogLevel.WARN, LogLevel.ERROR, LogLevel.FATAL]; export class ImmichLogger extends ConsoleLogger { - private static logLevels: LogLevel[] = []; + private static logLevels: LogLevel[] = [LogLevel.WARN, LogLevel.ERROR, LogLevel.FATAL]; constructor(context: string) { super(context); diff --git a/server/src/infra/migrations/1700713871511-UsePgVectors.ts b/server/src/infra/migrations/1700713871511-UsePgVectors.ts index 9b13f83643f..9f8a72cff34 100644 --- a/server/src/infra/migrations/1700713871511-UsePgVectors.ts +++ b/server/src/infra/migrations/1700713871511-UsePgVectors.ts @@ -1,13 +1,11 @@ import { getCLIPModelInfo } from '@app/domain/smart-info/smart-info.constant'; import { MigrationInterface, QueryRunner } from 'typeorm'; -import { assertVectors } from '../database.config'; export class UsePgVectors1700713871511 implements MigrationInterface { name = 'UsePgVectors1700713871511'; public async up(queryRunner: QueryRunner): Promise { - await assertVectors(queryRunner); - + await queryRunner.query(`CREATE EXTENSION IF NOT EXISTS vectors`); const faceDimQuery = await queryRunner.query(` SELECT CARDINALITY(embedding::real[]) as dimsize FROM asset_faces diff --git a/server/src/infra/repositories/database.repository.ts b/server/src/infra/repositories/database.repository.ts new file mode 100644 index 00000000000..71875342efc --- /dev/null +++ b/server/src/infra/repositories/database.repository.ts @@ -0,0 +1,28 @@ +import { DatabaseExtension, IDatabaseRepository, Version } from '@app/domain'; +import { Injectable } from '@nestjs/common'; +import { InjectDataSource } from '@nestjs/typeorm'; +import { DataSource } from 'typeorm'; + +@Injectable() +export class DatabaseRepository implements IDatabaseRepository { + constructor(@InjectDataSource() private dataSource: DataSource) {} + + async getExtensionVersion(extension: DatabaseExtension): Promise { + const res = await this.dataSource.query(`SELECT extversion FROM pg_extension WHERE extname = $1`, [extension]); + const version = res[0]?.['extversion']; + return version == null ? null : Version.fromString(version); + } + + async getPostgresVersion(): Promise { + const res = await this.dataSource.query(`SHOW server_version`); + return Version.fromString(res[0]['server_version']); + } + + async createExtension(extension: DatabaseExtension): Promise { + await this.dataSource.query(`CREATE EXTENSION IF NOT EXISTS ${extension}`); + } + + async runMigrations(options?: { transaction?: 'all' | 'none' | 'each' }): Promise { + await this.dataSource.runMigrations(options); + } +} diff --git a/server/src/infra/repositories/index.ts b/server/src/infra/repositories/index.ts index 0fd0070e23d..63b8f2afb27 100644 --- a/server/src/infra/repositories/index.ts +++ b/server/src/infra/repositories/index.ts @@ -6,6 +6,7 @@ export * from './asset.repository'; export * from './audit.repository'; export * from './communication.repository'; export * from './crypto.repository'; +export * from './database.repository'; export * from './filesystem.provider'; export * from './job.repository'; export * from './library.repository'; diff --git a/server/src/infra/repositories/smart-info.repository.ts b/server/src/infra/repositories/smart-info.repository.ts index 340ffa99902..73679af1103 100644 --- a/server/src/infra/repositories/smart-info.repository.ts +++ b/server/src/infra/repositories/smart-info.repository.ts @@ -52,6 +52,7 @@ export class SmartInfoRepository implements ISmartInfoRepository { let results: AssetEntity[] = []; await this.assetRepository.manager.transaction(async (manager) => { await manager.query(`SET LOCAL vectors.k = '${numResults}'`); + await manager.query(`SET LOCAL vectors.enable_prefilter = on`); results = await manager .createQueryBuilder(AssetEntity, 'a') .innerJoin('a.smartSearch', 's') diff --git a/server/src/infra/sql/smart.info.repository.sql b/server/src/infra/sql/smart.info.repository.sql index c03168d879c..b2cb551a611 100644 --- a/server/src/infra/sql/smart.info.repository.sql +++ b/server/src/infra/sql/smart.info.repository.sql @@ -4,6 +4,8 @@ START TRANSACTION SET LOCAL vectors.k = '100' +SET + LOCAL vectors.enable_prefilter = on SELECT "a"."id" AS "a_id", "a"."deviceAssetId" AS "a_deviceAssetId", diff --git a/server/src/microservices/app.service.ts b/server/src/microservices/app.service.ts index 32a1c114840..ef8af48aca8 100644 --- a/server/src/microservices/app.service.ts +++ b/server/src/microservices/app.service.ts @@ -1,6 +1,7 @@ import { AssetService, AuditService, + DatabaseService, IDeleteFilesJob, JobName, JobService, @@ -31,9 +32,11 @@ export class AppService { private storageTemplateService: StorageTemplateService, private storageService: StorageService, private userService: UserService, + private databaseService: DatabaseService, ) {} async init() { + await this.databaseService.init(); await this.configService.init(); await this.jobService.registerHandlers({ [JobName.ASSET_DELETION]: (data) => this.assetService.handleAssetDeletion(data), diff --git a/server/src/microservices/main.ts b/server/src/microservices/main.ts index c50fa94252f..0a85cb37a48 100644 --- a/server/src/microservices/main.ts +++ b/server/src/microservices/main.ts @@ -1,5 +1,5 @@ import { envName, serverVersion } from '@app/domain'; -import { WebSocketAdapter, databaseChecks } from '@app/infra'; +import { WebSocketAdapter } from '@app/infra'; import { ImmichLogger } from '@app/infra/logger'; import { NestFactory } from '@nestjs/core'; import { MicroservicesModule } from './microservices.module'; @@ -12,7 +12,6 @@ export async function bootstrap() { app.useLogger(app.get(ImmichLogger)); app.useWebSocketAdapter(new WebSocketAdapter(app)); - await databaseChecks(); await app.listen(port); diff --git a/server/test/repositories/database.repository.mock.ts b/server/test/repositories/database.repository.mock.ts new file mode 100644 index 00000000000..b68a6c277db --- /dev/null +++ b/server/test/repositories/database.repository.mock.ts @@ -0,0 +1,10 @@ +import { IDatabaseRepository, Version } from '@app/domain'; + +export const newDatabaseRepositoryMock = (): jest.Mocked => { + return { + getExtensionVersion: jest.fn(), + getPostgresVersion: jest.fn().mockResolvedValue(new Version(14, 0, 0)), + createExtension: jest.fn().mockImplementation(() => Promise.resolve()), + runMigrations: jest.fn(), + }; +}; diff --git a/server/test/repositories/index.ts b/server/test/repositories/index.ts index b68de4ba221..f625dc5213d 100644 --- a/server/test/repositories/index.ts +++ b/server/test/repositories/index.ts @@ -5,6 +5,7 @@ export * from './asset.repository.mock'; export * from './audit.repository.mock'; export * from './communication.repository.mock'; export * from './crypto.repository.mock'; +export * from './database.repository.mock'; export * from './job.repository.mock'; export * from './library.repository.mock'; export * from './machine-learning.repository.mock'; diff --git a/server/test/test-utils.ts b/server/test/test-utils.ts index 9fac33427e1..c4167fb1c26 100644 --- a/server/test/test-utils.ts +++ b/server/test/test-utils.ts @@ -1,6 +1,6 @@ import { AssetCreate, IJobRepository, JobItem, JobItemHandler, LibraryResponseDto, QueueName } from '@app/domain'; import { AppModule } from '@app/immich'; -import { dataSource, databaseChecks } from '@app/infra'; +import { dataSource } from '@app/infra'; import { AssetEntity, AssetType, LibraryType } from '@app/infra/entities'; import { INestApplication } from '@nestjs/common'; import { Test } from '@nestjs/testing'; @@ -24,7 +24,9 @@ export interface ResetOptions { } export const db = { reset: async (options?: ResetOptions) => { - await databaseChecks(); + if (!dataSource.isInitialized) { + await dataSource.initialize(); + } await dataSource.transaction(async (em) => { const entities = options?.entities || []; const tableNames = @@ -87,10 +89,7 @@ export const testApp = { app = await moduleFixture.createNestApplication().init(); await app.listen(0); - - if (jobs) { - await app.get(AppService).init(); - } + await app.get(AppService).init(); const port = app.getHttpServer().address().port; const protocol = app instanceof Server ? 'https' : 'http'; @@ -99,6 +98,7 @@ export const testApp = { return app; }, reset: async (options?: ResetOptions) => { + await app.get(AppService).init(); await db.reset(options); }, teardown: async () => {