search between multiple indices for team function (#3079)

### What problem does this PR solve?

#2834 
### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Kevin Hu 2024-10-29 13:19:01 +08:00 committed by GitHub
parent c5a3146a8c
commit 2d1fbefdb5
7 changed files with 54 additions and 18 deletions

**agent/component/__init__.py**

```diff
@@ -29,6 +29,7 @@ from .jin10 import Jin10, Jin10Param
 from .tushare import TuShare, TuShareParam
 from .akshare import AkShare, AkShareParam
 from .crawler import Crawler, CrawlerParam
+from .invoke import Invoke, InvokeParam


 def component_class(class_name):
```

**agent/component/generate.py**

```diff
@@ -17,6 +17,7 @@ import re
 from functools import partial
 import pandas as pd
 from api.db import LLMType
+from api.db.services.dialog_service import message_fit_in
 from api.db.services.llm_service import LLMBundle
 from api.settings import retrievaler
 from agent.component.base import ComponentBase, ComponentParamBase
```
```diff
@@ -112,7 +113,7 @@ class Generate(ComponentBase):
         kwargs["input"] = input
         for n, v in kwargs.items():
-            prompt = re.sub(r"\{%s\}" % re.escape(n), str(v), prompt)
+            prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)

         downstreams = self._canvas.get_component(self._id)["downstream"]
         if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[
```
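The change above escapes the replacement value, not just the pattern: `re.sub` treats backslashes in the replacement string as escape sequences or group references, so a component input containing e.g. a Windows path could crash prompt assembly. A minimal sketch of the failure mode (prompt and value are invented for illustration):

```python
import re

prompt = "Summarize: {input}"
value = r"C:\Users\test"  # user-supplied value containing backslashes

# Unescaped replacement: "\U" is an invalid escape in the replacement
# template, so re.sub raises re.error ("bad escape \U").
try:
    re.sub(r"\{input\}", value, prompt)
except re.error as e:
    print("unescaped replacement failed:", e)

# Escaped replacement, as in the patched Generate component: the value
# is inserted literally.
print(re.sub(r"\{input\}", re.escape(value), prompt))
# -> Summarize: C:\Users\test
```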
```diff
@@ -124,8 +125,10 @@ class Generate(ComponentBase):
                        retrieval_res["empty_response"]) else "Nothing found in knowledgebase!", "reference": []}
                 return pd.DataFrame([res])

-        ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size),
-                            self._param.gen_conf())
+        msg = self._canvas.get_history(self._param.message_history_window_size)
+        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
+        ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
         if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
             res = self.set_cite(retrieval_res, ans)
             return pd.DataFrame([res])
```
```diff
@@ -141,9 +144,10 @@ class Generate(ComponentBase):
             self.set_output(res)
             return

+        msg = self._canvas.get_history(self._param.message_history_window_size)
+        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
         answer = ""
-        for ans in chat_mdl.chat_streamly(prompt, self._canvas.get_history(self._param.message_history_window_size),
-                                          self._param.gen_conf()):
+        for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()):
             res = {"content": ans, "reference": []}
             answer = ans
             yield res
```
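Both the blocking and the streaming branch now pass the history through `message_fit_in` before calling the model, so the system prompt plus the most recent turns stay within roughly 97% of the model's context window. A minimal sketch of the idea with a hypothetical `fit_messages` stand-in and a crude whitespace token count (the real RAGFlow helper has its own tokenizer and trimming strategy):

```python
# Hypothetical stand-in for message_fit_in: keep the system prompt and
# the newest turns inside a token budget, dropping the oldest history
# first.
def fit_messages(messages, max_tokens):
    def count(msgs):
        # crude whitespace "tokenizer", for illustration only
        return sum(len(m["content"].split()) for m in msgs)

    system, history = messages[0], list(messages[1:])
    while history and count([system] + history) > max_tokens:
        history.pop(0)  # drop the oldest turn
    return count([system] + history), [system] + history


history = [{"role": "user", "content": "hello " * 500},
           {"role": "user", "content": "what is RAG?"}]
budget = int(256 * 0.97)  # 97% of the context window, as in the diff
_, msg = fit_messages([{"role": "system", "content": "Be helpful."}] + history, budget)
print(len(msg), msg[0]["role"])  # the system prompt survives; old turns may not
```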

**agent/component/invoke.py**

```diff
@@ -14,10 +14,10 @@
 # limitations under the License.
 #
 import json
+import re
 from abc import ABC
 import requests
+from deepdoc.parser import HtmlParser
 from agent.component.base import ComponentBase, ComponentParamBase
```
```diff
@@ -34,11 +34,13 @@ class InvokeParam(ComponentParamBase):
         self.variables = []
         self.url = ""
         self.timeout = 60
+        self.clean_html = False

     def check(self):
         self.check_valid_value(self.method.lower(), "Type of content from the crawler", ['get', 'post', 'put'])
         self.check_empty(self.url, "End point URL")
         self.check_positive_integer(self.timeout, "Timeout time in second")
+        self.check_boolean(self.clean_html, "Clean HTML")


 class Invoke(ComponentBase, ABC):
```
```diff
@@ -63,7 +65,7 @@ class Invoke(ComponentBase, ABC):
         if self._param.headers:
             headers = json.loads(self._param.headers)
         proxies = None
-        if self._param.proxy:
+        if re.sub(r"https?:?/?/?", "", self._param.proxy):
             proxies = {"http": self._param.proxy, "https": self._param.proxy}

         if method == 'get':
```
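The rewritten condition strips the URL scheme before testing truthiness, so a proxy field containing only `http://` or `https://` (scheme, no host) no longer switches the request onto a proxy. A quick sketch of the guard (proxy values are illustrative):

```python
import re

# Scheme-only values no longer count as a configured proxy.
for proxy in ["", "http://", "https://", "http://10.0.0.1:3128"]:
    enabled = bool(re.sub(r"https?:?/?/?", "", proxy))
    print(repr(proxy), "->", enabled)
# ''                      -> False
# 'http://'               -> False
# 'https://'              -> False
# 'http://10.0.0.1:3128'  -> True
```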
```diff
@@ -72,6 +74,10 @@ class Invoke(ComponentBase, ABC):
                                     headers=headers,
                                     proxies=proxies,
                                     timeout=self._param.timeout)
+            if self._param.clean_html:
+                sections = HtmlParser()(None, response.content)
+                return Invoke.be_output("\n".join(sections))
             return Invoke.be_output(response.text)

         if method == 'put':
```
if method == 'put': if method == 'put':
```diff
@@ -80,5 +86,18 @@ class Invoke(ComponentBase, ABC):
                                     headers=headers,
                                     proxies=proxies,
                                     timeout=self._param.timeout)
+            if self._param.clean_html:
+                sections = HtmlParser()(None, response.content)
+                return Invoke.be_output("\n".join(sections))
+            return Invoke.be_output(response.text)
+
+        if method == 'post':
+            response = requests.post(url=url,
+                                     json=args,
+                                     headers=headers,
+                                     proxies=proxies,
+                                     timeout=self._param.timeout)
+            if self._param.clean_html:
+                sections = HtmlParser()(None, response.content)
+                return Invoke.be_output("\n".join(sections))
             return Invoke.be_output(response.text)
```
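Each HTTP branch now honours the new `clean_html` flag: the response body is run through `HtmlParser()(None, response.content)` and the extracted text sections are joined before being emitted. A rough sketch of that branch using a standard-library extractor as a stand-in (the real `deepdoc` parser is readability-based and far better at dropping boilerplate):

```python
import requests
from html.parser import HTMLParser


# Stand-in for deepdoc's HtmlParser: collect the visible text nodes.
class TextExtractor(HTMLParser):
    def __init__(self):
        super().__init__()
        self.sections = []

    def handle_data(self, data):
        if data.strip():
            self.sections.append(data.strip())


response = requests.get("https://example.com", timeout=60)
extractor = TextExtractor()
extractor.feed(response.text)
print("\n".join(extractor.sections))  # what Invoke.be_output would receive
```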

**api/db/services/dialog_service.py**

```diff
@@ -205,7 +205,9 @@ def chat(dialog, messages, stream=True, **kwargs):
     else:
         if prompt_config.get("keyword", False):
             questions[-1] += keyword_extraction(chat_mdl, questions[-1])
-        kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
+
+        tenant_ids = list(set([kb.tenant_id for kb in kbs]))
+        kbinfos = retr.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
                                  dialog.similarity_threshold,
                                  dialog.vector_similarity_weight,
                                  doc_ids=attachments,
```
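This is the team half of the PR: instead of searching only the dialog owner's index via `dialog.tenant_id`, `chat` now collects one tenant id per distinct knowledge-base owner, so knowledge bases shared by other team members get queried too. A minimal sketch of the dedup step (`KB` is a stand-in for the knowledge-base record loaded earlier in `chat`):

```python
class KB:  # minimal stand-in for the ORM knowledge-base record
    def __init__(self, tenant_id):
        self.tenant_id = tenant_id


kbs = [KB("tenant_a"), KB("tenant_b"), KB("tenant_a")]
tenant_ids = list(set([kb.tenant_id for kb in kbs]))  # as in the diff
print(sorted(tenant_ids))  # ['tenant_a', 'tenant_b'] -> one index each
```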

**deepdoc/parser/html_parser.py**

```diff
@@ -16,11 +16,13 @@ import readability
 import html_text
 import chardet

+
 def get_encoding(file):
     with open(file,'rb') as f:
         tmp = chardet.detect(f.read())
         return tmp['encoding']

+
 class RAGFlowHtmlParser:
     def __call__(self, fnm, binary=None):
         txt = ""
```

**rag/nlp/search.py**

```diff
@@ -79,7 +79,7 @@ class Dealer:
                                 Q("bool", must_not=Q("range", available_int={"lt": 1})))
         return bqry

-    def search(self, req, idxnm, emb_mdl=None, highlight=False):
+    def search(self, req, idxnms, emb_mdl=None, highlight=False):
         qst = req.get("question", "")
         bqry, keywords = self.qryr.question(qst, min_match="30%")
         bqry = self._add_filters(bqry, req)
```
```diff
@@ -134,7 +134,7 @@ class Dealer:
             del s["highlight"]
         q_vec = s["knn"]["query_vector"]
         es_logger.info("【Q】: {}".format(json.dumps(s)))
-        res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
+        res = self.es.search(deepcopy(s), idxnms=idxnms, timeout="600s", src=src)
         es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
         if self.es.getTotal(res) == 0 and "knn" in s:
             bqry, _ = self.qryr.question(qst, min_match="10%")
```
```diff
@@ -144,7 +144,7 @@ class Dealer:
             s["query"] = bqry.to_dict()
             s["knn"]["filter"] = bqry.to_dict()
             s["knn"]["similarity"] = 0.17
-            res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
+            res = self.es.search(s, idxnms=idxnms, timeout="600s", src=src)
             es_logger.info("【Q】: {}".format(json.dumps(s)))

         kwds = set([])
```
```diff
@@ -358,20 +358,26 @@ class Dealer:
             rag_tokenizer.tokenize(ans).split(" "),
             rag_tokenizer.tokenize(inst).split(" "))

-    def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
+    def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2,
                   vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None, highlight=False):
         ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
         if not question:
             return ranks
+
         RERANK_PAGE_LIMIT = 3
         req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size*RERANK_PAGE_LIMIT, 128),
                "question": question, "vector": True, "topk": top,
                "similarity": similarity_threshold,
                "available_int": 1}
+
         if page > RERANK_PAGE_LIMIT:
             req["page"] = page
             req["size"] = page_size
-        sres = self.search(req, index_name(tenant_id), embd_mdl, highlight)
+
+        if isinstance(tenant_ids, str):
+            tenant_ids = tenant_ids.split(",")
+
+        sres = self.search(req, [index_name(tid) for tid in tenant_ids], embd_mdl, highlight)
         ranks["total"] = sres.total

         if page <= RERANK_PAGE_LIMIT:
```
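`retrieval` then fans the tenant ids out into one Elasticsearch index name each, accepting either a list or a comma-separated string. A sketch of the fan-out, assuming `index_name` maps a tenant id to a per-tenant index such as `ragflow_<tenant_id>` (the exact naming is an assumption here):

```python
def index_name(uid):
    # assumed per-tenant index naming convention
    return "ragflow_{}".format(uid)


tenant_ids = "tenant_a,tenant_b"  # callers may still pass a single string
if isinstance(tenant_ids, str):
    tenant_ids = tenant_ids.split(",")

idxnms = [index_name(tid) for tid in tenant_ids]
print(idxnms)  # ['ragflow_tenant_a', 'ragflow_tenant_b']
```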
```diff
@@ -467,7 +473,7 @@ class Dealer:
         s = Search()
         s = s.query(Q("match", doc_id=doc_id))[0:max_count]
         s = s.to_dict()
-        es_res = self.es.search(s, idxnm=index_name(tenant_id), timeout="600s", src=fields)
+        es_res = self.es.search(s, idxnms=index_name(tenant_id), timeout="600s", src=fields)
         res = []
         for index, chunk in enumerate(es_res['hits']['hits']):
             res.append({fld: chunk['_source'].get(fld) for fld in fields})
```

**rag/utils/es_conn.py**

```diff
@@ -221,12 +221,14 @@ class ESConnection:
         return False

-    def search(self, q, idxnm=None, src=False, timeout="2s"):
+    def search(self, q, idxnms=None, src=False, timeout="2s"):
         if not isinstance(q, dict):
             q = Search().query(q).to_dict()
+        if isinstance(idxnms, str):
+            idxnms = idxnms.split(",")
         for i in range(3):
             try:
-                res = self.es.search(index=(self.idxnm if not idxnm else idxnm),
+                res = self.es.search(index=(self.idxnm if not idxnms else idxnms),
                                      body=q,
                                      timeout=timeout,
                                      # search_type="dfs_query_then_fetch",
```
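`ESConnection.search` can pass that list straight to the client, since the Elasticsearch Python client accepts multiple index names for a single query. A minimal sketch (host and index names are placeholders; the `body=`/`timeout=` style matches the 7.x client used here):

```python
from elasticsearch import Elasticsearch

es = Elasticsearch(["http://localhost:9200"])  # placeholder host

# One query across every team member's per-tenant index.
res = es.search(index=["ragflow_tenant_a", "ragflow_tenant_b"],
                body={"query": {"match_all": {}}},
                timeout="600s")
print(res["hits"]["total"])
```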