diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index 998e4abaa..74b2e6c95 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -20,7 +20,7 @@ from api.db.services.dialog_service import keyword_extraction from rag.app.qa import rmPrefix, beAdoc from rag.nlp import rag_tokenizer from api.db import LLMType, ParserType -from api.db.services.llm_service import TenantLLMService +from api.db.services.llm_service import TenantLLMService, LLMBundle from api import settings import xxhash import re @@ -1331,18 +1331,14 @@ def retrieval_test(tenant_id): e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) if not e: return get_error_data_result(message="Dataset not found!") - embd_mdl = TenantLLMService.model_instance( - kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id - ) + embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id) rerank_mdl = None if req.get("rerank_id"): - rerank_mdl = TenantLLMService.model_instance( - kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"] - ) + rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK, llm_name=req["rerank_id"]) if req.get("keyword", False): - chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT) + chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) question += keyword_extraction(chat_mdl, question) retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler