diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index 6cf0ed5a7..0349a3bcd 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -159,21 +159,15 @@ def create(tenant_id): embd_model = LLMService.query( llm_name=req["embedding_model"], model_type="embedding" ) + if embd_model: + if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding",llm_name=req.get("embedding_model"),): + return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist") + if not embd_model: + embd_model=TenantLLMService.query(tenant_id=tenant_id,model_type="embedding", llm_name=req.get("embedding_model")) if not embd_model: return get_error_data_result( f"`embedding_model` {req.get('embedding_model')} doesn't exist" ) - if embd_model: - if req[ - "embedding_model" - ] not in valid_embedding_models and not TenantLLMService.query( - tenant_id=tenant_id, - model_type="embedding", - llm_name=req.get("embedding_model"), - ): - return get_error_data_result( - f"`embedding_model` {req.get('embedding_model')} doesn't exist" - ) key_mapping = { "chunk_num": "chunk_count", "doc_num": "document_count", @@ -403,21 +397,16 @@ def update(tenant_id, dataset_id): embd_model = LLMService.query( llm_name=req["embedding_model"], model_type="embedding" ) + if embd_model: + if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding",llm_name=req.get("embedding_model"),): + return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist") + if not embd_model: + embd_model=TenantLLMService.query(tenant_id=tenant_id,model_type="embedding", llm_name=req.get("embedding_model")) + if not embd_model: return get_error_data_result( f"`embedding_model` {req.get('embedding_model')} doesn't exist" ) - if embd_model: - if req[ - "embedding_model" - ] not in valid_embedding_models and not TenantLLMService.query( - tenant_id=tenant_id, - model_type="embedding", - llm_name=req.get("embedding_model"), - ): - return get_error_data_result( - f"`embedding_model` {req.get('embedding_model')} doesn't exist" - ) req["embd_id"] = req.pop("embedding_model") if "name" in req: req["name"] = req["name"].strip()