From 24bdedf80260eba2aff853f46406c856c3a37a20 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Wed, 10 Jan 2024 20:48:16 +0800 Subject: [PATCH] fix get embedding model provider in empty dataset (#1986) Co-authored-by: jyong --- api/core/indexing_runner.py | 36 ++++++++++++------- api/core/model_manager.py | 2 ++ .../baichuan/text_embedding/text_embedding.py | 4 +-- 3 files changed, 28 insertions(+), 14 deletions(-) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 7f1b0a5147..a55729b9ff 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -221,12 +221,18 @@ class IndexingRunner: if not dataset: raise ValueError('Dataset not found.') if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality': - embedding_model_instance = self.model_manager.get_model_instance( - tenant_id=tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model - ) + if dataset.embedding_model_provider: + embedding_model_instance = self.model_manager.get_model_instance( + tenant_id=tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model + ) + else: + embedding_model_instance = self.model_manager.get_default_model_instance( + tenant_id=tenant_id, + model_type=ModelType.TEXT_EMBEDDING, + ) else: if indexing_technique == 'high_quality': embedding_model_instance = self.model_manager.get_default_model_instance( @@ -328,12 +334,18 @@ class IndexingRunner: if not dataset: raise ValueError('Dataset not found.') if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality': - embedding_model_instance = self.model_manager.get_model_instance( - tenant_id=tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model - ) + if dataset.embedding_model_provider: + embedding_model_instance = self.model_manager.get_model_instance( + tenant_id=tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model + ) + else: + embedding_model_instance = self.model_manager.get_default_model_instance( + tenant_id=tenant_id, + model_type=ModelType.TEXT_EMBEDDING, + ) else: if indexing_technique == 'high_quality': embedding_model_instance = self.model_manager.get_default_model_instance( diff --git a/api/core/model_manager.py b/api/core/model_manager.py index dab30b2c3a..c732e40995 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -179,6 +179,8 @@ class ModelManager: :param model: model name :return: """ + if not provider: + return self.get_default_model_instance(tenant_id, model_type) provider_model_bundle = self._provider_manager.get_provider_model_bundle( tenant_id=tenant_id, provider=provider, diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py index 8847e020b6..d0487c62b0 100644 --- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py @@ -69,9 +69,9 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): raise InsufficientAccountBalance(msg) elif err == 'invalid_authentication': raise InvalidAuthenticationError(msg) - elif 'rate' in err: + elif err and 'rate' in err: raise RateLimitReachedError(msg) - elif 'internal' in err: + elif err and 'internal' in err: raise InternalServerError(msg) elif err == 'api_key_empty': raise InvalidAPIKeyError(msg)