From 6789c2ac197b9bcf39b318d9b9e1e63d6e896f9b Mon Sep 17 00:00:00 2001
From: Mert <101130780+mertalev@users.noreply.github.com>
Date: Mon, 31 Mar 2025 11:06:57 -0400
Subject: [PATCH] feat(ml): better multilingual search with nllb models
(#13567)
---
...t-serach.webp => mobile-smart-search.webp} | Bin
docs/docs/features/searching.md | 17 +++-
.../immich_ml/models/clip/textual.py | 22 +++--
.../immich_ml/models/constants.py | 60 +++++++++++++
machine-learning/test_main.py | 82 ++++++++++++++++++
.../models/search/search_filter.model.dart | 8 +-
mobile/lib/pages/search/search.page.dart | 2 +
mobile/lib/services/search.service.dart | 1 +
.../openapi/lib/model/smart_search_dto.dart | 19 +++-
open-api/immich-openapi-specs.json | 3 +
open-api/typescript-sdk/src/fetch-client.ts | 1 +
server/src/dtos/search.dto.ts | 5 ++
.../machine-learning.repository.ts | 5 +-
server/src/services/search.service.spec.ts | 81 +++++++++++++++++
server/src/services/search.service.ts | 10 +--
.../[[assetId=id]]/+page.svelte | 3 +-
16 files changed, 301 insertions(+), 18 deletions(-)
rename docs/docs/features/img/{moblie-smart-serach.webp => mobile-smart-search.webp} (100%)
diff --git a/docs/docs/features/img/moblie-smart-serach.webp b/docs/docs/features/img/mobile-smart-search.webp
similarity index 100%
rename from docs/docs/features/img/moblie-smart-serach.webp
rename to docs/docs/features/img/mobile-smart-search.webp
diff --git a/docs/docs/features/searching.md b/docs/docs/features/searching.md
index 15f83949f22..7c7e3872181 100644
--- a/docs/docs/features/searching.md
+++ b/docs/docs/features/searching.md
@@ -45,7 +45,7 @@ Some search examples:
-
+
@@ -56,7 +56,20 @@ Navigating to `Administration > Settings > Machine Learning Settings > Smart Sea
### CLIP models
-More powerful models can be used for more accurate search results, but are slower and can require more server resources. Check the dropdowns below to see how they compare in memory usage, speed and quality by language.
+The default search model is fast, but there are many other options that can provide better search results. The tradeoff of using these models is that they're slower and/or use more memory (both when indexing images with background Smart Search jobs and when searching).
+
+The first step of choosing the right model for you is to know which languages your users will search in.
+
+If your users will only search in English, then the [CLIP][huggingface-clip] section is the first place to look. This is a curated list of the models that generally perform the best for their size class. The models here are ordered from higher to lower quality. This means that the top models will generally rank the most relevant results higher and have a higher capacity to understand descriptive, detailed, and/or niche queries. The models are also generally ordered from larger to smaller, so consider the impact on memory usage, job processing and search speed when deciding on one. The smaller models in this list are not too different in quality and many times faster.
+
+[Multilingual models][huggingface-multilingual-clip] are also available so users can search in their native language. Use these models if you expect non-English searches to be common. They can be separated into three search patterns:
+
+- `nllb` models expect the search query to be in the language specified in the user settings
+- `xlm` and `siglip2` models understand search text regardless of the current language setting
+
+`nllb` models tend to perform the best and are recommended when users primarily searches in their native, non-English language. `xlm` and `siglip2` models are more flexible and are recommended for mixed language search, where the same user might search in different languages at different times.
+
+For more details, check the tables below to see how they compare in memory usage, speed and quality by language.
Once you've chosen a model, follow these steps:
diff --git a/machine-learning/immich_ml/models/clip/textual.py b/machine-learning/immich_ml/models/clip/textual.py
index 603cd294000..c1b3a9eba44 100644
--- a/machine-learning/immich_ml/models/clip/textual.py
+++ b/machine-learning/immich_ml/models/clip/textual.py
@@ -10,6 +10,7 @@ from tokenizers import Encoding, Tokenizer
from immich_ml.config import log
from immich_ml.models.base import InferenceModel
+from immich_ml.models.constants import WEBLATE_TO_FLORES200
from immich_ml.models.transforms import clean_text, serialize_np_array
from immich_ml.schemas import ModelSession, ModelTask, ModelType
@@ -18,8 +19,9 @@ class BaseCLIPTextualEncoder(InferenceModel):
depends = []
identity = (ModelType.TEXTUAL, ModelTask.SEARCH)
- def _predict(self, inputs: str, **kwargs: Any) -> str:
- res: NDArray[np.float32] = self.session.run(None, self.tokenize(inputs))[0][0]
+ def _predict(self, inputs: str, language: str | None = None, **kwargs: Any) -> str:
+ tokens = self.tokenize(inputs, language=language)
+ res: NDArray[np.float32] = self.session.run(None, tokens)[0][0]
return serialize_np_array(res)
def _load(self) -> ModelSession:
@@ -28,6 +30,7 @@ class BaseCLIPTextualEncoder(InferenceModel):
self.tokenizer = self._load_tokenizer()
tokenizer_kwargs: dict[str, Any] | None = self.text_cfg.get("tokenizer_kwargs")
self.canonicalize = tokenizer_kwargs is not None and tokenizer_kwargs.get("clean") == "canonicalize"
+ self.is_nllb = self.model_name.startswith("nllb")
log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
return session
@@ -37,7 +40,7 @@ class BaseCLIPTextualEncoder(InferenceModel):
pass
@abstractmethod
- def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
+ def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
pass
@property
@@ -92,14 +95,23 @@ class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
return tokenizer
- def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
+ def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
text = clean_text(text, canonicalize=self.canonicalize)
+ if self.is_nllb and language is not None:
+ flores_code = WEBLATE_TO_FLORES200.get(language)
+ if flores_code is None:
+ no_country = language.split("-")[0]
+ flores_code = WEBLATE_TO_FLORES200.get(no_country)
+ if flores_code is None:
+ log.warning(f"Language '{language}' not found, defaulting to 'en'")
+ flores_code = "eng_Latn"
+ text = f"{flores_code}{text}"
tokens: Encoding = self.tokenizer.encode(text)
return {"text": np.array([tokens.ids], dtype=np.int32)}
class MClipTextualEncoder(OpenClipTextualEncoder):
- def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
+ def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
text = clean_text(text, canonicalize=self.canonicalize)
tokens: Encoding = self.tokenizer.encode(text)
return {
diff --git a/machine-learning/immich_ml/models/constants.py b/machine-learning/immich_ml/models/constants.py
index 85b5b539917..41b0990f71f 100644
--- a/machine-learning/immich_ml/models/constants.py
+++ b/machine-learning/immich_ml/models/constants.py
@@ -86,6 +86,66 @@ RKNN_SUPPORTED_SOCS = ["rk3566", "rk3568", "rk3576", "rk3588"]
RKNN_COREMASK_SUPPORTED_SOCS = ["rk3576", "rk3588"]
+WEBLATE_TO_FLORES200 = {
+ "af": "afr_Latn",
+ "ar": "arb_Arab",
+ "az": "azj_Latn",
+ "be": "bel_Cyrl",
+ "bg": "bul_Cyrl",
+ "ca": "cat_Latn",
+ "cs": "ces_Latn",
+ "da": "dan_Latn",
+ "de": "deu_Latn",
+ "el": "ell_Grek",
+ "en": "eng_Latn",
+ "es": "spa_Latn",
+ "et": "est_Latn",
+ "fa": "pes_Arab",
+ "fi": "fin_Latn",
+ "fr": "fra_Latn",
+ "he": "heb_Hebr",
+ "hi": "hin_Deva",
+ "hr": "hrv_Latn",
+ "hu": "hun_Latn",
+ "hy": "hye_Armn",
+ "id": "ind_Latn",
+ "it": "ita_Latn",
+ "ja": "jpn_Hira",
+ "kmr": "kmr_Latn",
+ "ko": "kor_Hang",
+ "lb": "ltz_Latn",
+ "lt": "lit_Latn",
+ "lv": "lav_Latn",
+ "mfa": "zsm_Latn",
+ "mk": "mkd_Cyrl",
+ "mn": "khk_Cyrl",
+ "mr": "mar_Deva",
+ "ms": "zsm_Latn",
+ "nb-NO": "nob_Latn",
+ "nn": "nno_Latn",
+ "nl": "nld_Latn",
+ "pl": "pol_Latn",
+ "pt-BR": "por_Latn",
+ "pt": "por_Latn",
+ "ro": "ron_Latn",
+ "ru": "rus_Cyrl",
+ "sk": "slk_Latn",
+ "sl": "slv_Latn",
+ "sr-Cyrl": "srp_Cyrl",
+ "sv": "swe_Latn",
+ "ta": "tam_Taml",
+ "te": "tel_Telu",
+ "th": "tha_Thai",
+ "tr": "tur_Latn",
+ "uk": "ukr_Cyrl",
+ "ur": "urd_Arab",
+ "vi": "vie_Latn",
+ "zh-CN": "zho_Hans",
+ "zh-Hans": "zho_Hans",
+ "zh-TW": "zho_Hant",
+}
+
+
def get_model_source(model_name: str) -> ModelSource | None:
cleaned_name = clean_name(model_name)
diff --git a/machine-learning/test_main.py b/machine-learning/test_main.py
index 4a3696f320b..a19ec65c5f4 100644
--- a/machine-learning/test_main.py
+++ b/machine-learning/test_main.py
@@ -494,6 +494,88 @@ class TestCLIP:
assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
mock_tokenizer.encode.assert_called_once_with("test search query")
+ def test_openclip_tokenizer_adds_flores_token_for_nllb(
+ self,
+ mocker: MockerFixture,
+ clip_model_cfg: dict[str, Any],
+ clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
+ ) -> None:
+ mocker.patch.object(OpenClipTextualEncoder, "download")
+ mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
+ mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
+ mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
+ mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
+ mock_ids = [randint(0, 50000) for _ in range(77)]
+ mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
+
+ clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache")
+ clip_encoder._load()
+ clip_encoder.tokenize("test search query", language="de")
+
+ mock_tokenizer.encode.assert_called_once_with("deu_Latntest search query")
+
+ def test_openclip_tokenizer_removes_country_code_from_language_for_nllb_if_not_found(
+ self,
+ mocker: MockerFixture,
+ clip_model_cfg: dict[str, Any],
+ clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
+ ) -> None:
+ mocker.patch.object(OpenClipTextualEncoder, "download")
+ mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
+ mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
+ mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
+ mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
+ mock_ids = [randint(0, 50000) for _ in range(77)]
+ mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
+
+ clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache")
+ clip_encoder._load()
+ clip_encoder.tokenize("test search query", language="de-CH")
+
+ mock_tokenizer.encode.assert_called_once_with("deu_Latntest search query")
+
+ def test_openclip_tokenizer_falls_back_to_english_for_nllb_if_language_code_not_found(
+ self,
+ mocker: MockerFixture,
+ clip_model_cfg: dict[str, Any],
+ clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
+ warning: mock.Mock,
+ ) -> None:
+ mocker.patch.object(OpenClipTextualEncoder, "download")
+ mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
+ mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
+ mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
+ mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
+ mock_ids = [randint(0, 50000) for _ in range(77)]
+ mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
+
+ clip_encoder = OpenClipTextualEncoder("nllb-clip-base-siglip__mrl", cache_dir="test_cache")
+ clip_encoder._load()
+ clip_encoder.tokenize("test search query", language="unknown")
+
+ mock_tokenizer.encode.assert_called_once_with("eng_Latntest search query")
+ warning.assert_called_once_with("Language 'unknown' not found, defaulting to 'en'")
+
+ def test_openclip_tokenizer_does_not_add_flores_token_for_non_nllb_model(
+ self,
+ mocker: MockerFixture,
+ clip_model_cfg: dict[str, Any],
+ clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
+ ) -> None:
+ mocker.patch.object(OpenClipTextualEncoder, "download")
+ mocker.patch.object(OpenClipTextualEncoder, "model_cfg", clip_model_cfg)
+ mocker.patch.object(OpenClipTextualEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
+ mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
+ mock_tokenizer = mocker.patch("immich_ml.models.clip.textual.Tokenizer.from_file", autospec=True).return_value
+ mock_ids = [randint(0, 50000) for _ in range(77)]
+ mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
+
+ clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
+ clip_encoder._load()
+ clip_encoder.tokenize("test search query", language="de")
+
+ mock_tokenizer.encode.assert_called_once_with("test search query")
+
def test_mclip_tokenizer(
self,
mocker: MockerFixture,
diff --git a/mobile/lib/models/search/search_filter.model.dart b/mobile/lib/models/search/search_filter.model.dart
index 87e7b24e346..598b71ef4eb 100644
--- a/mobile/lib/models/search/search_filter.model.dart
+++ b/mobile/lib/models/search/search_filter.model.dart
@@ -236,6 +236,7 @@ class SearchFilter {
String? context;
String? filename;
String? description;
+ String? language;
Set people;
SearchLocationFilter location;
SearchCameraFilter camera;
@@ -249,6 +250,7 @@ class SearchFilter {
this.context,
this.filename,
this.description,
+ this.language,
required this.people,
required this.location,
required this.camera,
@@ -279,6 +281,7 @@ class SearchFilter {
String? context,
String? filename,
String? description,
+ String? language,
Set? people,
SearchLocationFilter? location,
SearchCameraFilter? camera,
@@ -290,6 +293,7 @@ class SearchFilter {
context: context ?? this.context,
filename: filename ?? this.filename,
description: description ?? this.description,
+ language: language ?? this.language,
people: people ?? this.people,
location: location ?? this.location,
camera: camera ?? this.camera,
@@ -301,7 +305,7 @@ class SearchFilter {
@override
String toString() {
- return 'SearchFilter(context: $context, filename: $filename, description: $description, people: $people, location: $location, camera: $camera, date: $date, display: $display, mediaType: $mediaType)';
+ return 'SearchFilter(context: $context, filename: $filename, description: $description, language: $language, people: $people, location: $location, camera: $camera, date: $date, display: $display, mediaType: $mediaType)';
}
@override
@@ -311,6 +315,7 @@ class SearchFilter {
return other.context == context &&
other.filename == filename &&
other.description == description &&
+ other.language == language &&
other.people == people &&
other.location == location &&
other.camera == camera &&
@@ -324,6 +329,7 @@ class SearchFilter {
return context.hashCode ^
filename.hashCode ^
description.hashCode ^
+ language.hashCode ^
people.hashCode ^
location.hashCode ^
camera.hashCode ^
diff --git a/mobile/lib/pages/search/search.page.dart b/mobile/lib/pages/search/search.page.dart
index 9ff8caff1dc..b2bed73c6a2 100644
--- a/mobile/lib/pages/search/search.page.dart
+++ b/mobile/lib/pages/search/search.page.dart
@@ -48,6 +48,8 @@ class SearchPage extends HookConsumerWidget {
isFavorite: false,
),
mediaType: prefilter?.mediaType ?? AssetType.other,
+ language:
+ "${context.locale.languageCode}-${context.locale.countryCode}",
),
);
diff --git a/mobile/lib/services/search.service.dart b/mobile/lib/services/search.service.dart
index 4c6c80abf3c..44ace788527 100644
--- a/mobile/lib/services/search.service.dart
+++ b/mobile/lib/services/search.service.dart
@@ -60,6 +60,7 @@ class SearchService {
response = await _apiService.searchApi.searchSmart(
SmartSearchDto(
query: filter.context!,
+ language: filter.language,
country: filter.location.country,
state: filter.location.state,
city: filter.location.city,
diff --git a/mobile/openapi/lib/model/smart_search_dto.dart b/mobile/openapi/lib/model/smart_search_dto.dart
index f377c23f223..47c800ff095 100644
--- a/mobile/openapi/lib/model/smart_search_dto.dart
+++ b/mobile/openapi/lib/model/smart_search_dto.dart
@@ -25,6 +25,7 @@ class SmartSearchDto {
this.isNotInAlbum,
this.isOffline,
this.isVisible,
+ this.language,
this.lensModel,
this.libraryId,
this.make,
@@ -132,6 +133,14 @@ class SmartSearchDto {
///
bool? isVisible;
+ ///
+ /// Please note: This property should have been non-nullable! Since the specification file
+ /// does not include a default value (using the "default:" property), however, the generated
+ /// source code must fall back to having a nullable type.
+ /// Consider adding a "default:" property in the specification file to hide this note.
+ ///
+ String? language;
+
String? lensModel;
String? libraryId;
@@ -271,6 +280,7 @@ class SmartSearchDto {
other.isNotInAlbum == isNotInAlbum &&
other.isOffline == isOffline &&
other.isVisible == isVisible &&
+ other.language == language &&
other.lensModel == lensModel &&
other.libraryId == libraryId &&
other.make == make &&
@@ -308,6 +318,7 @@ class SmartSearchDto {
(isNotInAlbum == null ? 0 : isNotInAlbum!.hashCode) +
(isOffline == null ? 0 : isOffline!.hashCode) +
(isVisible == null ? 0 : isVisible!.hashCode) +
+ (language == null ? 0 : language!.hashCode) +
(lensModel == null ? 0 : lensModel!.hashCode) +
(libraryId == null ? 0 : libraryId!.hashCode) +
(make == null ? 0 : make!.hashCode) +
@@ -331,7 +342,7 @@ class SmartSearchDto {
(withExif == null ? 0 : withExif!.hashCode);
@override
- String toString() => 'SmartSearchDto[city=$city, country=$country, createdAfter=$createdAfter, createdBefore=$createdBefore, deviceId=$deviceId, isArchived=$isArchived, isEncoded=$isEncoded, isFavorite=$isFavorite, isMotion=$isMotion, isNotInAlbum=$isNotInAlbum, isOffline=$isOffline, isVisible=$isVisible, lensModel=$lensModel, libraryId=$libraryId, make=$make, model=$model, page=$page, personIds=$personIds, query=$query, rating=$rating, size=$size, state=$state, tagIds=$tagIds, takenAfter=$takenAfter, takenBefore=$takenBefore, trashedAfter=$trashedAfter, trashedBefore=$trashedBefore, type=$type, updatedAfter=$updatedAfter, updatedBefore=$updatedBefore, withArchived=$withArchived, withDeleted=$withDeleted, withExif=$withExif]';
+ String toString() => 'SmartSearchDto[city=$city, country=$country, createdAfter=$createdAfter, createdBefore=$createdBefore, deviceId=$deviceId, isArchived=$isArchived, isEncoded=$isEncoded, isFavorite=$isFavorite, isMotion=$isMotion, isNotInAlbum=$isNotInAlbum, isOffline=$isOffline, isVisible=$isVisible, language=$language, lensModel=$lensModel, libraryId=$libraryId, make=$make, model=$model, page=$page, personIds=$personIds, query=$query, rating=$rating, size=$size, state=$state, tagIds=$tagIds, takenAfter=$takenAfter, takenBefore=$takenBefore, trashedAfter=$trashedAfter, trashedBefore=$trashedBefore, type=$type, updatedAfter=$updatedAfter, updatedBefore=$updatedBefore, withArchived=$withArchived, withDeleted=$withDeleted, withExif=$withExif]';
Map toJson() {
final json = {};
@@ -395,6 +406,11 @@ class SmartSearchDto {
} else {
// json[r'isVisible'] = null;
}
+ if (this.language != null) {
+ json[r'language'] = this.language;
+ } else {
+ // json[r'language'] = null;
+ }
if (this.lensModel != null) {
json[r'lensModel'] = this.lensModel;
} else {
@@ -508,6 +524,7 @@ class SmartSearchDto {
isNotInAlbum: mapValueOfType(json, r'isNotInAlbum'),
isOffline: mapValueOfType(json, r'isOffline'),
isVisible: mapValueOfType(json, r'isVisible'),
+ language: mapValueOfType(json, r'language'),
lensModel: mapValueOfType(json, r'lensModel'),
libraryId: mapValueOfType(json, r'libraryId'),
make: mapValueOfType(json, r'make'),
diff --git a/open-api/immich-openapi-specs.json b/open-api/immich-openapi-specs.json
index 5ba08ab80bf..b948ef0386a 100644
--- a/open-api/immich-openapi-specs.json
+++ b/open-api/immich-openapi-specs.json
@@ -11853,6 +11853,9 @@
"isVisible": {
"type": "boolean"
},
+ "language": {
+ "type": "string"
+ },
"lensModel": {
"nullable": true,
"type": "string"
diff --git a/open-api/typescript-sdk/src/fetch-client.ts b/open-api/typescript-sdk/src/fetch-client.ts
index 252ce9bc69f..26929ba4e65 100644
--- a/open-api/typescript-sdk/src/fetch-client.ts
+++ b/open-api/typescript-sdk/src/fetch-client.ts
@@ -924,6 +924,7 @@ export type SmartSearchDto = {
isNotInAlbum?: boolean;
isOffline?: boolean;
isVisible?: boolean;
+ language?: string;
lensModel?: string | null;
libraryId?: string | null;
make?: string;
diff --git a/server/src/dtos/search.dto.ts b/server/src/dtos/search.dto.ts
index 3589331c78c..e0b5c9b7793 100644
--- a/server/src/dtos/search.dto.ts
+++ b/server/src/dtos/search.dto.ts
@@ -191,6 +191,11 @@ export class SmartSearchDto extends BaseSearchDto {
@IsNotEmpty()
query!: string;
+ @IsString()
+ @IsNotEmpty()
+ @Optional()
+ language?: string;
+
@IsInt()
@Min(1)
@Type(() => Number)
diff --git a/server/src/repositories/machine-learning.repository.ts b/server/src/repositories/machine-learning.repository.ts
index 95aa4cff1eb..a52bc58bc35 100644
--- a/server/src/repositories/machine-learning.repository.ts
+++ b/server/src/repositories/machine-learning.repository.ts
@@ -53,6 +53,7 @@ export interface Face {
export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse;
export type DetectedFaces = { faces: Face[] } & VisualResponse;
export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;
+export type TextEncodingOptions = ModelOptions & { language?: string };
@Injectable()
export class MachineLearningRepository {
@@ -170,8 +171,8 @@ export class MachineLearningRepository {
return response[ModelTask.SEARCH];
}
- async encodeText(urls: string[], text: string, { modelName }: CLIPConfig) {
- const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
+ async encodeText(urls: string[], text: string, { language, modelName }: TextEncodingOptions) {
+ const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName, options: { language } } } };
const response = await this.predict(urls, { text }, request);
return response[ModelTask.SEARCH];
}
diff --git a/server/src/services/search.service.spec.ts b/server/src/services/search.service.spec.ts
index 79f3a77ebe3..51c6b55e11b 100644
--- a/server/src/services/search.service.spec.ts
+++ b/server/src/services/search.service.spec.ts
@@ -1,3 +1,4 @@
+import { BadRequestException } from '@nestjs/common';
import { mapAsset } from 'src/dtos/asset-response.dto';
import { SearchSuggestionType } from 'src/dtos/search.dto';
import { SearchService } from 'src/services/search.service';
@@ -15,6 +16,7 @@ describe(SearchService.name, () => {
beforeEach(() => {
({ sut, mocks } = newTestService(SearchService));
+ mocks.partner.getAll.mockResolvedValue([]);
});
it('should work', () => {
@@ -155,4 +157,83 @@ describe(SearchService.name, () => {
expect(mocks.search.getCameraModels).toHaveBeenCalledWith([authStub.user1.user.id], expect.anything());
});
});
+
+ describe('searchSmart', () => {
+ beforeEach(() => {
+ mocks.search.searchSmart.mockResolvedValue({ hasNextPage: false, items: [] });
+ mocks.machineLearning.encodeText.mockResolvedValue('[1, 2, 3]');
+ });
+
+ it('should raise a BadRequestException if machine learning is disabled', async () => {
+ mocks.systemMetadata.get.mockResolvedValue({
+ machineLearning: { enabled: false },
+ });
+
+ await expect(sut.searchSmart(authStub.user1, { query: 'test' })).rejects.toThrowError(
+ new BadRequestException('Smart search is not enabled'),
+ );
+ });
+
+ it('should raise a BadRequestException if smart search is disabled', async () => {
+ mocks.systemMetadata.get.mockResolvedValue({
+ machineLearning: { clip: { enabled: false } },
+ });
+
+ await expect(sut.searchSmart(authStub.user1, { query: 'test' })).rejects.toThrowError(
+ new BadRequestException('Smart search is not enabled'),
+ );
+ });
+
+ it('should work', async () => {
+ await sut.searchSmart(authStub.user1, { query: 'test' });
+
+ expect(mocks.machineLearning.encodeText).toHaveBeenCalledWith(
+ [expect.any(String)],
+ 'test',
+ expect.objectContaining({ modelName: expect.any(String) }),
+ );
+ expect(mocks.search.searchSmart).toHaveBeenCalledWith(
+ { page: 1, size: 100 },
+ { query: 'test', embedding: '[1, 2, 3]', userIds: [authStub.user1.user.id] },
+ );
+ });
+
+ it('should consider page and size parameters', async () => {
+ await sut.searchSmart(authStub.user1, { query: 'test', page: 2, size: 50 });
+
+ expect(mocks.machineLearning.encodeText).toHaveBeenCalledWith(
+ [expect.any(String)],
+ 'test',
+ expect.objectContaining({ modelName: expect.any(String) }),
+ );
+ expect(mocks.search.searchSmart).toHaveBeenCalledWith(
+ { page: 2, size: 50 },
+ expect.objectContaining({ query: 'test', embedding: '[1, 2, 3]', userIds: [authStub.user1.user.id] }),
+ );
+ });
+
+ it('should use clip model specified in config', async () => {
+ mocks.systemMetadata.get.mockResolvedValue({
+ machineLearning: { clip: { modelName: 'ViT-B-16-SigLIP__webli' } },
+ });
+
+ await sut.searchSmart(authStub.user1, { query: 'test' });
+
+ expect(mocks.machineLearning.encodeText).toHaveBeenCalledWith(
+ [expect.any(String)],
+ 'test',
+ expect.objectContaining({ modelName: 'ViT-B-16-SigLIP__webli' }),
+ );
+ });
+
+ it('should use language specified in request', async () => {
+ await sut.searchSmart(authStub.user1, { query: 'test', language: 'de' });
+
+ expect(mocks.machineLearning.encodeText).toHaveBeenCalledWith(
+ [expect.any(String)],
+ 'test',
+ expect.objectContaining({ language: 'de' }),
+ );
+ });
+ });
});
diff --git a/server/src/services/search.service.ts b/server/src/services/search.service.ts
index e2ad9e7f99f..1c0c0ad4906 100644
--- a/server/src/services/search.service.ts
+++ b/server/src/services/search.service.ts
@@ -78,12 +78,10 @@ export class SearchService extends BaseService {
}
const userIds = await this.getUserIdsToSearch(auth);
-
- const embedding = await this.machineLearningRepository.encodeText(
- machineLearning.urls,
- dto.query,
- machineLearning.clip,
- );
+ const embedding = await this.machineLearningRepository.encodeText(machineLearning.urls, dto.query, {
+ modelName: machineLearning.clip.modelName,
+ language: dto.language,
+ });
const page = dto.page ?? 1;
const size = dto.size || 100;
const { hasNextPage, items } = await this.searchRepository.searchSmart(
diff --git a/web/src/routes/(user)/search/[[photos=photos]]/[[assetId=id]]/+page.svelte b/web/src/routes/(user)/search/[[photos=photos]]/[[assetId=id]]/+page.svelte
index e5e336521c1..c750f02aedd 100644
--- a/web/src/routes/(user)/search/[[photos=photos]]/[[assetId=id]]/+page.svelte
+++ b/web/src/routes/(user)/search/[[photos=photos]]/[[assetId=id]]/+page.svelte
@@ -33,7 +33,7 @@
} from '@immich/sdk';
import { mdiArrowLeft, mdiDotsVertical, mdiImageOffOutline, mdiPlus, mdiSelectAll } from '@mdi/js';
import type { Viewport } from '$lib/stores/assets-store.svelte';
- import { locale } from '$lib/stores/preferences.store';
+ import { lang, locale } from '$lib/stores/preferences.store';
import LoadingSpinner from '$lib/components/shared-components/loading-spinner.svelte';
import { handlePromiseError } from '$lib/utils';
import { parseUtcDate } from '$lib/utils/date-time';
@@ -153,6 +153,7 @@
page: nextPage,
withExif: true,
isVisible: true,
+ language: $lang,
...terms,
};