ragflow/agent/component/retrieval.py
zhangcdian bf0d516e49
Agent Update: Fix Role Issue and Enhance KB Search (#5842)
### What problem does this PR solve?

**Update generate.py:**
Issue: Some model providers strictly validate the format of the input
conversation and require that the role of the first message is not
`assistant`; otherwise the request fails with an error.
Solution: Removed the system-configured agent opening statement so that
the first message in the conversation passed to the model never has the
`assistant` role.
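
The idea behind this fix can be sketched as follows. This is a minimal illustration, not the actual `generate.py` change: the helper name `drop_leading_assistant` and the `{"role": ..., "content": ...}` message shape are assumptions made for the example.

```python
from typing import Dict, List

# Assumed message shape for this sketch: {"role": "...", "content": "..."}
Message = Dict[str, str]


def drop_leading_assistant(messages: List[Message]) -> List[Message]:
    """Return a copy of `messages` without assistant turns at the front.

    Hypothetical helper: it skips every leading message whose role is
    "assistant" (for example an agent opening statement) and keeps the
    rest of the conversation unchanged.
    """
    start = 0
    while start < len(messages) and messages[start].get("role") == "assistant":
        start += 1
    return messages[start:]


history = [
    {"role": "assistant", "content": "Hi! I'm your assistant."},
    {"role": "user", "content": "What is RAGFlow?"},
    {"role": "assistant", "content": "An open-source RAG engine."},
]
cleaned = drop_leading_assistant(history)
```

After the call, `cleaned` starts with the user turn, which is what providers with strict role validation expect; later assistant turns are left in place.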

**Update retrieval.py:**
Issue: Knowledge base retrieval currently uses the entire conversation
as the query, which can make retrieval results inaccurate.
Solution: Retrieval now uses only the last question asked by the user
as the query, which improves retrieval accuracy.
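
A standalone sketch of this extraction step, assuming the component input is a plain-text transcript in which each user turn starts with `USER:`. The function name `last_user_question` and the trailing `.strip()` are additions for this example and are not part of `retrieval.py`.

```python
def last_user_question(transcript: str) -> str:
    """Return the text of the last line that starts with "USER:".

    Returns an empty string when the transcript has no user turn.
    """
    lines = transcript.split("\n")
    # Keep only the text after the first "USER:" on each user line.
    user_queries = [
        line.split("USER:", 1)[1] for line in lines if line.startswith("USER:")
    ]
    # .strip() is an assumption of this sketch; it removes the space after "USER:".
    return user_queries[-1].strip() if user_queries else ""


transcript = (
    "USER: What is RAGFlow?\n"
    "ASSISTANT: An open-source RAG engine.\n"
    "USER: How do I add a knowledge base?"
)
question = last_user_question(transcript)
```

Because the transcript is split on newlines, a user question that spans several lines would be reduced to its first line; this is a property of the line-based approach rather than of the sketch alone.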

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Performance Improvement
2025-03-10 18:29:58 +08:00


#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from abc import ABC

import pandas as pd

from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api import settings
from agent.component.base import ComponentBase, ComponentParamBase
from rag.app.tag import label_question


class RetrievalParam(ComponentParamBase):
    """
    Define the Retrieval component parameters.
    """

    def __init__(self):
        super().__init__()
        self.similarity_threshold = 0.2
        self.keywords_similarity_weight = 0.5
        self.top_n = 8
        self.top_k = 1024
        self.kb_ids = []
        self.rerank_id = ""
        self.empty_response = ""

    def check(self):
        self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold")
        self.check_decimal_float(self.keywords_similarity_weight, "[Retrieval] Keyword similarity weight")
        self.check_positive_number(self.top_n, "[Retrieval] Top N")


class Retrieval(ComponentBase, ABC):
    component_name = "Retrieval"

    def _run(self, history, **kwargs):
        query = self.get_input()
        query = str(query["content"][0]) if "content" in query else ""

        # Use only the last "USER:" line of the conversation as the search query.
        lines = query.split('\n')
        user_queries = [line.split("USER:", 1)[1] for line in lines if line.startswith("USER:")]
        query = user_queries[-1] if user_queries else ""

        kbs = KnowledgebaseService.get_by_ids(self._param.kb_ids)
        if not kbs:
            return Retrieval.be_output("")

        # All selected knowledge bases must share one embedding model.
        embd_nms = list(set([kb.embd_id for kb in kbs]))
        assert len(embd_nms) == 1, "Knowledge bases use different embedding models."

        embd_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, embd_nms[0])
        self._canvas.set_embedding_model(embd_nms[0])

        # The rerank model is optional and only loaded when configured.
        rerank_mdl = None
        if self._param.rerank_id:
            rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)

        # The weight passed here is the complement of the keyword similarity weight.
        kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
                                                 1, self._param.top_n,
                                                 self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
                                                 aggs=False, rerank_mdl=rerank_mdl,
                                                 rank_feature=label_question(query, kbs))

        # No chunks found: return an empty output, with the configured fallback text if any.
        if not kbinfos["chunks"]:
            df = Retrieval.be_output("")
            if self._param.empty_response and self._param.empty_response.strip():
                df["empty_response"] = self._param.empty_response
            return df

        # Expose the weighted chunk text under the "content" column.
        df = pd.DataFrame(kbinfos["chunks"])
        df["content"] = df["content_with_weight"]
        del df["content_with_weight"]
        logging.debug("{} {}".format(query, df))
        return df