diff --git a/api/apps/api_app.py b/api/apps/api_app.py index cc417c9a2..f66eb8067 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -18,7 +18,7 @@ import os import re from datetime import datetime, timedelta from flask import request, Response -from api.db.services.llm_service import TenantLLMService +from api.db.services.llm_service import LLMBundle from flask_login import login_required, current_user from api.db import VALID_FILE_TYPES, VALID_TASK_STATUS, FileType, LLMType, ParserType, FileSource @@ -875,14 +875,12 @@ def retrieval(): data=False, message='Knowledge bases use different embedding models or does not exist."', code=settings.RetCode.AUTHENTICATION_ERROR) - embd_mdl = TenantLLMService.model_instance( - kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id) + embd_mdl = LLMBundle(kbs[0].tenant_id, LLMType.EMBEDDING, llm_name=kbs[0].embd_id) rerank_mdl = None if req.get("rerank_id"): - rerank_mdl = TenantLLMService.model_instance( - kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"]) + rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, llm_name=req["rerank_id"]) if req.get("keyword", False): - chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT) + chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT) question += keyword_extraction(chat_mdl, question) ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,