mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-04-22 14:10:01 +08:00
search between multiple indices for team function (#3079)
### What problem does this PR solve?

#2834

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
c5a3146a8c
commit
2d1fbefdb5
@ -29,6 +29,7 @@ from .jin10 import Jin10, Jin10Param
|
|||||||
from .tushare import TuShare, TuShareParam
|
from .tushare import TuShare, TuShareParam
|
||||||
from .akshare import AkShare, AkShareParam
|
from .akshare import AkShare, AkShareParam
|
||||||
from .crawler import Crawler, CrawlerParam
|
from .crawler import Crawler, CrawlerParam
|
||||||
|
from .invoke import Invoke, InvokeParam
|
||||||
|
|
||||||
|
|
||||||
def component_class(class_name):
|
def component_class(class_name):
|
||||||
|
@ -17,6 +17,7 @@ import re
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from api.db import LLMType
|
from api.db import LLMType
|
||||||
|
from api.db.services.dialog_service import message_fit_in
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.settings import retrievaler
|
from api.settings import retrievaler
|
||||||
from agent.component.base import ComponentBase, ComponentParamBase
|
from agent.component.base import ComponentBase, ComponentParamBase
|
||||||
@ -112,7 +113,7 @@ class Generate(ComponentBase):
|
|||||||
|
|
||||||
kwargs["input"] = input
|
kwargs["input"] = input
|
||||||
for n, v in kwargs.items():
|
for n, v in kwargs.items():
|
||||||
prompt = re.sub(r"\{%s\}" % re.escape(n), str(v), prompt)
|
prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)
|
||||||
|
|
||||||
downstreams = self._canvas.get_component(self._id)["downstream"]
|
downstreams = self._canvas.get_component(self._id)["downstream"]
|
||||||
if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[
|
if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[
|
||||||
@ -124,8 +125,10 @@ class Generate(ComponentBase):
|
|||||||
retrieval_res["empty_response"]) else "Nothing found in knowledgebase!", "reference": []}
|
retrieval_res["empty_response"]) else "Nothing found in knowledgebase!", "reference": []}
|
||||||
return pd.DataFrame([res])
|
return pd.DataFrame([res])
|
||||||
|
|
||||||
ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size),
|
msg = self._canvas.get_history(self._param.message_history_window_size)
|
||||||
self._param.gen_conf())
|
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
|
||||||
|
ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
|
||||||
|
|
||||||
if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
|
if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
|
||||||
res = self.set_cite(retrieval_res, ans)
|
res = self.set_cite(retrieval_res, ans)
|
||||||
return pd.DataFrame([res])
|
return pd.DataFrame([res])
|
||||||
@ -141,9 +144,10 @@ class Generate(ComponentBase):
|
|||||||
self.set_output(res)
|
self.set_output(res)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
msg = self._canvas.get_history(self._param.message_history_window_size)
|
||||||
|
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
|
||||||
answer = ""
|
answer = ""
|
||||||
for ans in chat_mdl.chat_streamly(prompt, self._canvas.get_history(self._param.message_history_window_size),
|
for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()):
|
||||||
self._param.gen_conf()):
|
|
||||||
res = {"content": ans, "reference": []}
|
res = {"content": ans, "reference": []}
|
||||||
answer = ans
|
answer = ans
|
||||||
yield res
|
yield res
|
||||||
|
@ -14,10 +14,10 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
from deepdoc.parser import HtmlParser
|
||||||
from agent.component.base import ComponentBase, ComponentParamBase
|
from agent.component.base import ComponentBase, ComponentParamBase
|
||||||
|
|
||||||
|
|
||||||
@ -34,11 +34,13 @@ class InvokeParam(ComponentParamBase):
|
|||||||
self.variables = []
|
self.variables = []
|
||||||
self.url = ""
|
self.url = ""
|
||||||
self.timeout = 60
|
self.timeout = 60
|
||||||
|
self.clean_html = False
|
||||||
|
|
||||||
def check(self):
|
def check(self):
|
||||||
self.check_valid_value(self.method.lower(), "Type of content from the crawler", ['get', 'post', 'put'])
|
self.check_valid_value(self.method.lower(), "Type of content from the crawler", ['get', 'post', 'put'])
|
||||||
self.check_empty(self.url, "End point URL")
|
self.check_empty(self.url, "End point URL")
|
||||||
self.check_positive_integer(self.timeout, "Timeout time in second")
|
self.check_positive_integer(self.timeout, "Timeout time in second")
|
||||||
|
self.check_boolean(self.clean_html, "Clean HTML")
|
||||||
|
|
||||||
|
|
||||||
class Invoke(ComponentBase, ABC):
|
class Invoke(ComponentBase, ABC):
|
||||||
@ -63,7 +65,7 @@ class Invoke(ComponentBase, ABC):
|
|||||||
if self._param.headers:
|
if self._param.headers:
|
||||||
headers = json.loads(self._param.headers)
|
headers = json.loads(self._param.headers)
|
||||||
proxies = None
|
proxies = None
|
||||||
if self._param.proxy:
|
if re.sub(r"https?:?/?/?", "", self._param.proxy):
|
||||||
proxies = {"http": self._param.proxy, "https": self._param.proxy}
|
proxies = {"http": self._param.proxy, "https": self._param.proxy}
|
||||||
|
|
||||||
if method == 'get':
|
if method == 'get':
|
||||||
@ -72,6 +74,10 @@ class Invoke(ComponentBase, ABC):
|
|||||||
headers=headers,
|
headers=headers,
|
||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
timeout=self._param.timeout)
|
timeout=self._param.timeout)
|
||||||
|
if self._param.clean_html:
|
||||||
|
sections = HtmlParser()(None, response.content)
|
||||||
|
return Invoke.be_output("\n".join(sections))
|
||||||
|
|
||||||
return Invoke.be_output(response.text)
|
return Invoke.be_output(response.text)
|
||||||
|
|
||||||
if method == 'put':
|
if method == 'put':
|
||||||
@ -80,5 +86,18 @@ class Invoke(ComponentBase, ABC):
|
|||||||
headers=headers,
|
headers=headers,
|
||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
timeout=self._param.timeout)
|
timeout=self._param.timeout)
|
||||||
|
if self._param.clean_html:
|
||||||
|
sections = HtmlParser()(None, response.content)
|
||||||
|
return Invoke.be_output("\n".join(sections))
|
||||||
|
return Invoke.be_output(response.text)
|
||||||
|
|
||||||
|
if method == 'post':
|
||||||
|
response = requests.post(url=url,
|
||||||
|
json=args,
|
||||||
|
headers=headers,
|
||||||
|
proxies=proxies,
|
||||||
|
timeout=self._param.timeout)
|
||||||
|
if self._param.clean_html:
|
||||||
|
sections = HtmlParser()(None, response.content)
|
||||||
|
return Invoke.be_output("\n".join(sections))
|
||||||
return Invoke.be_output(response.text)
|
return Invoke.be_output(response.text)
|
||||||
|
@ -205,7 +205,9 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
else:
|
else:
|
||||||
if prompt_config.get("keyword", False):
|
if prompt_config.get("keyword", False):
|
||||||
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
|
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
|
||||||
kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
|
|
||||||
|
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||||
|
kbinfos = retr.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
|
||||||
dialog.similarity_threshold,
|
dialog.similarity_threshold,
|
||||||
dialog.vector_similarity_weight,
|
dialog.vector_similarity_weight,
|
||||||
doc_ids=attachments,
|
doc_ids=attachments,
|
||||||
|
@ -16,11 +16,13 @@ import readability
|
|||||||
import html_text
|
import html_text
|
||||||
import chardet
|
import chardet
|
||||||
|
|
||||||
|
|
||||||
def get_encoding(file):
|
def get_encoding(file):
|
||||||
with open(file,'rb') as f:
|
with open(file,'rb') as f:
|
||||||
tmp = chardet.detect(f.read())
|
tmp = chardet.detect(f.read())
|
||||||
return tmp['encoding']
|
return tmp['encoding']
|
||||||
|
|
||||||
|
|
||||||
class RAGFlowHtmlParser:
|
class RAGFlowHtmlParser:
|
||||||
def __call__(self, fnm, binary=None):
|
def __call__(self, fnm, binary=None):
|
||||||
txt = ""
|
txt = ""
|
||||||
|
@ -79,7 +79,7 @@ class Dealer:
|
|||||||
Q("bool", must_not=Q("range", available_int={"lt": 1})))
|
Q("bool", must_not=Q("range", available_int={"lt": 1})))
|
||||||
return bqry
|
return bqry
|
||||||
|
|
||||||
def search(self, req, idxnm, emb_mdl=None, highlight=False):
|
def search(self, req, idxnms, emb_mdl=None, highlight=False):
|
||||||
qst = req.get("question", "")
|
qst = req.get("question", "")
|
||||||
bqry, keywords = self.qryr.question(qst, min_match="30%")
|
bqry, keywords = self.qryr.question(qst, min_match="30%")
|
||||||
bqry = self._add_filters(bqry, req)
|
bqry = self._add_filters(bqry, req)
|
||||||
@ -134,7 +134,7 @@ class Dealer:
|
|||||||
del s["highlight"]
|
del s["highlight"]
|
||||||
q_vec = s["knn"]["query_vector"]
|
q_vec = s["knn"]["query_vector"]
|
||||||
es_logger.info("【Q】: {}".format(json.dumps(s)))
|
es_logger.info("【Q】: {}".format(json.dumps(s)))
|
||||||
res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
|
res = self.es.search(deepcopy(s), idxnms=idxnms, timeout="600s", src=src)
|
||||||
es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
|
es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
|
||||||
if self.es.getTotal(res) == 0 and "knn" in s:
|
if self.es.getTotal(res) == 0 and "knn" in s:
|
||||||
bqry, _ = self.qryr.question(qst, min_match="10%")
|
bqry, _ = self.qryr.question(qst, min_match="10%")
|
||||||
@ -144,7 +144,7 @@ class Dealer:
|
|||||||
s["query"] = bqry.to_dict()
|
s["query"] = bqry.to_dict()
|
||||||
s["knn"]["filter"] = bqry.to_dict()
|
s["knn"]["filter"] = bqry.to_dict()
|
||||||
s["knn"]["similarity"] = 0.17
|
s["knn"]["similarity"] = 0.17
|
||||||
res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
|
res = self.es.search(s, idxnms=idxnms, timeout="600s", src=src)
|
||||||
es_logger.info("【Q】: {}".format(json.dumps(s)))
|
es_logger.info("【Q】: {}".format(json.dumps(s)))
|
||||||
|
|
||||||
kwds = set([])
|
kwds = set([])
|
||||||
@ -358,20 +358,26 @@ class Dealer:
|
|||||||
rag_tokenizer.tokenize(ans).split(" "),
|
rag_tokenizer.tokenize(ans).split(" "),
|
||||||
rag_tokenizer.tokenize(inst).split(" "))
|
rag_tokenizer.tokenize(inst).split(" "))
|
||||||
|
|
||||||
def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
|
def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2,
|
||||||
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None, highlight=False):
|
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None, highlight=False):
|
||||||
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
|
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
|
||||||
if not question:
|
if not question:
|
||||||
return ranks
|
return ranks
|
||||||
|
|
||||||
RERANK_PAGE_LIMIT = 3
|
RERANK_PAGE_LIMIT = 3
|
||||||
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size*RERANK_PAGE_LIMIT, 128),
|
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size*RERANK_PAGE_LIMIT, 128),
|
||||||
"question": question, "vector": True, "topk": top,
|
"question": question, "vector": True, "topk": top,
|
||||||
"similarity": similarity_threshold,
|
"similarity": similarity_threshold,
|
||||||
"available_int": 1}
|
"available_int": 1}
|
||||||
|
|
||||||
if page > RERANK_PAGE_LIMIT:
|
if page > RERANK_PAGE_LIMIT:
|
||||||
req["page"] = page
|
req["page"] = page
|
||||||
req["size"] = page_size
|
req["size"] = page_size
|
||||||
sres = self.search(req, index_name(tenant_id), embd_mdl, highlight)
|
|
||||||
|
if isinstance(tenant_ids, str):
|
||||||
|
tenant_ids = tenant_ids.split(",")
|
||||||
|
|
||||||
|
sres = self.search(req, [index_name(tid) for tid in tenant_ids], embd_mdl, highlight)
|
||||||
ranks["total"] = sres.total
|
ranks["total"] = sres.total
|
||||||
|
|
||||||
if page <= RERANK_PAGE_LIMIT:
|
if page <= RERANK_PAGE_LIMIT:
|
||||||
@ -467,7 +473,7 @@ class Dealer:
|
|||||||
s = Search()
|
s = Search()
|
||||||
s = s.query(Q("match", doc_id=doc_id))[0:max_count]
|
s = s.query(Q("match", doc_id=doc_id))[0:max_count]
|
||||||
s = s.to_dict()
|
s = s.to_dict()
|
||||||
es_res = self.es.search(s, idxnm=index_name(tenant_id), timeout="600s", src=fields)
|
es_res = self.es.search(s, idxnms=index_name(tenant_id), timeout="600s", src=fields)
|
||||||
res = []
|
res = []
|
||||||
for index, chunk in enumerate(es_res['hits']['hits']):
|
for index, chunk in enumerate(es_res['hits']['hits']):
|
||||||
res.append({fld: chunk['_source'].get(fld) for fld in fields})
|
res.append({fld: chunk['_source'].get(fld) for fld in fields})
|
||||||
|
@ -221,12 +221,14 @@ class ESConnection:
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def search(self, q, idxnm=None, src=False, timeout="2s"):
|
def search(self, q, idxnms=None, src=False, timeout="2s"):
|
||||||
if not isinstance(q, dict):
|
if not isinstance(q, dict):
|
||||||
q = Search().query(q).to_dict()
|
q = Search().query(q).to_dict()
|
||||||
|
if isinstance(idxnms, str):
|
||||||
|
idxnms = idxnms.split(",")
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
try:
|
try:
|
||||||
res = self.es.search(index=(self.idxnm if not idxnm else idxnm),
|
res = self.es.search(index=(self.idxnm if not idxnms else idxnms),
|
||||||
body=q,
|
body=q,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
# search_type="dfs_query_then_fetch",
|
# search_type="dfs_query_then_fetch",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user