From 5038552ed9966e7f5823577a83509f4abd7e6d5e Mon Sep 17 00:00:00 2001 From: Mohammed Tawileh <38083272+mktawileh@users.noreply.github.com> Date: Thu, 7 Nov 2024 05:36:28 +0300 Subject: [PATCH] fix: improve embedding model validation logic for dataset operations (#3235) What problem does this PR solve? When creating or updating datasets with custom embedding models (e.g., Ollama), the validation logic was too restrictive and prevented valid models from being used. The previous implementation would reject valid custom models if they weren't in the predefined list, even when they existed in TenantLLMService. Changes: - Simplify and improve the embedding model validation flow in create/update endpoints - Check TenantLLMService for custom models before rejecting - Make validation logic more consistent between create and update operations ### What problem does this PR solve? This fix allows users to successfully create and update datasets with custom embedding models while maintaining proper validation checks. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: Jin Hai Co-authored-by: Kevin Hu Co-authored-by: liuhua <10215101452@stu.ecnu.edu.cn> --- api/apps/sdk/dataset.py | 33 +++++++++++---------------------- 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index 6cf0ed5a7..0349a3bcd 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -159,21 +159,15 @@ def create(tenant_id): embd_model = LLMService.query( llm_name=req["embedding_model"], model_type="embedding" ) + if embd_model: + if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding",llm_name=req.get("embedding_model"),): + return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist") + if not embd_model: + embd_model=TenantLLMService.query(tenant_id=tenant_id,model_type="embedding", llm_name=req.get("embedding_model")) if not embd_model: return get_error_data_result( f"`embedding_model` {req.get('embedding_model')} doesn't exist" ) - if embd_model: - if req[ - "embedding_model" - ] not in valid_embedding_models and not TenantLLMService.query( - tenant_id=tenant_id, - model_type="embedding", - llm_name=req.get("embedding_model"), - ): - return get_error_data_result( - f"`embedding_model` {req.get('embedding_model')} doesn't exist" - ) key_mapping = { "chunk_num": "chunk_count", "doc_num": "document_count", @@ -403,21 +397,16 @@ def update(tenant_id, dataset_id): embd_model = LLMService.query( llm_name=req["embedding_model"], model_type="embedding" ) + if embd_model: + if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding",llm_name=req.get("embedding_model"),): + return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist") + if not embd_model: + embd_model=TenantLLMService.query(tenant_id=tenant_id,model_type="embedding", llm_name=req.get("embedding_model")) + if not embd_model: return get_error_data_result( f"`embedding_model` {req.get('embedding_model')} doesn't exist" ) - if embd_model: - if req[ - "embedding_model" - ] not in valid_embedding_models and not TenantLLMService.query( - tenant_id=tenant_id, - model_type="embedding", - llm_name=req.get("embedding_model"), - ): - return get_error_data_result( - f"`embedding_model` {req.get('embedding_model')} doesn't exist" - ) req["embd_id"] = req.pop("embedding_model") if "name" in req: req["name"] = req["name"].strip()