fix english query bug (#840)

### What problem does this PR solve?

#834

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

parent 6683179d6a
commit 2b36283712
```diff
@@ -118,7 +118,7 @@ def chat(dialog, messages, stream=True, **kwargs):
     kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
                                     dialog.similarity_threshold,
                                     dialog.vector_similarity_weight,
-                                    doc_ids=kwargs.get("doc_ids", "").split(","),
+                                    doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
                                     top=1024, aggs=False)
     knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
     chat_logger.info(
```
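With `kwargs.get("doc_ids", "")`, a request that supplies no `doc_ids` produced `['']` after the split, i.e. a filter on a single empty document id rather than no filter at all. A minimal sketch of the difference (illustrative only, not part of the patch):

```python
kwargs = {}  # caller supplied no doc_ids

old = kwargs.get("doc_ids", "").split(",")                           # [''] -- filters on an empty id
new = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None  # None -- no document filter

print(old, new)  # [''] None
```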
```diff
@@ -20,6 +20,7 @@ from openai import OpenAI
 import openai
 from ollama import Client
 from rag.nlp import is_english
+from rag.utils import num_tokens_from_string


 class Base(ABC):
```
```diff
@@ -255,3 +256,46 @@ class OllamaChat(Base):
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
         yield 0
+
+
+class LocalLLM(Base):
+    class RPCProxy:
+        def __init__(self, host, port):
+            self.host = host
+            self.port = int(port)
+            self.__conn()
+
+        def __conn(self):
+            from multiprocessing.connection import Client
+            self._connection = Client(
+                (self.host, self.port), authkey=b'infiniflow-token4kevinhu')
+
+        def __getattr__(self, name):
+            import pickle
+
+            def do_rpc(*args, **kwargs):
+                for _ in range(3):
+                    try:
+                        self._connection.send(
+                            pickle.dumps((name, args, kwargs)))
+                        return pickle.loads(self._connection.recv())
+                    except Exception as e:
+                        self.__conn()
+                raise Exception("RPC connection lost!")
+
+            return do_rpc
+
+    def __init__(self, key, model_name="glm-3-turbo"):
+        self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        try:
+            ans = self.client.chat(
+                history,
+                gen_conf
+            )
+            return ans, num_tokens_from_string(ans)
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
```
```diff
@@ -2,9 +2,10 @@ import argparse
 import pickle
 import random
 import time
+from copy import deepcopy
 from multiprocessing.connection import Listener
 from threading import Thread
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer


 def torch_gc():
```
```diff
@@ -95,6 +96,32 @@ def chat(messages, gen_conf):
         return str(e)


+def chat_streamly(messages, gen_conf):
+    global tokenizer
+    model = Model()
+    try:
+        torch_gc()
+        conf = deepcopy(gen_conf)
+        print(messages, conf)
+        text = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+        streamer = TextStreamer(tokenizer)
+        conf["inputs"] = model_inputs.input_ids
+        conf["streamer"] = streamer
+        conf["max_new_tokens"] = conf["max_tokens"]
+        del conf["max_tokens"]
+        thread = Thread(target=model.generate, kwargs=conf)
+        thread.start()
+        for _, new_text in enumerate(streamer):
+            yield new_text
+    except Exception as e:
+        yield "**ERROR**: " + str(e)
+
+
 def Model():
     global models
     random.seed(time.time())
```
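One review caveat, offered tentatively: in the transformers library, `TextStreamer` prints decoded text to stdout and does not implement iteration, so `for _, new_text in enumerate(streamer)` would raise a `TypeError` on versions I am aware of; the class built for this generate-in-a-`Thread`-and-iterate pattern is `TextIteratorStreamer`. A sketch of that documented pattern, reusing the names from the hunk above:

```python
from threading import Thread
from transformers import TextIteratorStreamer

def stream_generate(model, tokenizer, model_inputs, max_new_tokens=256):
    # TextIteratorStreamer buffers decoded chunks in a queue and is iterable.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    Thread(target=model.generate,
           kwargs={"inputs": model_inputs.input_ids,
                   "streamer": streamer,
                   "max_new_tokens": max_new_tokens}).start()
    for new_text in streamer:  # blocks until generate() pushes the next chunk
        yield new_text
```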
```diff
@@ -113,6 +140,7 @@ if __name__ == "__main__":

     handler = RPCHandler()
     handler.register_function(chat)
+    handler.register_function(chat_streamly)

     models = []
     for _ in range(1):
```
```diff
@@ -36,7 +36,7 @@ class EsQueryer:
         patts = [
             (r"是*(什么样的|哪家|一下|那家|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""),
             (r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
-            (r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down)", " ")
+            (r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down) ", " ")
         ]
         for r, p in patts:
             txt = re.sub(r, p, txt, flags=re.IGNORECASE)
```
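This one-character change is the heart of the titular fix: without a trailing space, the stop-word alternation also matched stop words that were merely the prefix of a longer word. A quick illustration with the pattern trimmed to three alternatives:

```python
import re

old = r"(^| )(in|is|are)"    # no right boundary
new = r"(^| )(in|is|are) "   # trailing space: whole words only

print(re.sub(old, " ", "is information useful", flags=re.IGNORECASE))
# '  formation useful'  -- the 'in' prefix of 'information' is eaten
print(re.sub(new, " ", "is information useful", flags=re.IGNORECASE))
# ' information useful'
```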
```diff
@@ -44,7 +44,7 @@ class EsQueryer:

     def question(self, txt, tbl="qa", min_match="60%"):
         txt = re.sub(
-            r"[ \r\n\t,,。??/`!!&]+",
+            r"[ \r\n\t,,。??/`!!&\^%%]+",
             " ",
             rag_tokenizer.tradi2simp(
                 rag_tokenizer.strQ2B(
```
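`^` and `%` join the scrub list because the English branch below now emits `term^weight` boosts into an Elasticsearch query string, so carets or percent signs typed by the user must be neutralized before the query is assembled. Roughly:

```python
import re

print(re.sub(r"[ \r\n\t,,。??/`!!&\^%%]+", " ", "what^2 is 50% off?"))
# 'what 2 is 50 off '
```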
```diff
@@ -53,9 +53,10 @@ class EsQueryer:

         if not self.isChinese(txt):
             tks = rag_tokenizer.tokenize(txt).split(" ")
-            q = copy.deepcopy(tks)
-            for i in range(1, len(tks)):
-                q.append("\"%s %s\"^2" % (tks[i - 1], tks[i]))
+            tks_w = self.tw.weights(tks)
+            q = [re.sub(r"[ \\\"']+", "", tk)+"^{:.4f}".format(w) for tk, w in tks_w]
+            for i in range(1, len(tks_w)):
+                q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
             if not q:
                 q.append(txt)
             return Q("bool",
```
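Net effect on English queries: every term now carries its own weight from `self.tw.weights(tks)` instead of a flat boost, and each adjacent bigram is boosted by twice the heavier of its two terms. A sketch with made-up weights (real values come from the term-weight model):

```python
# Hypothetical weights, for illustration only.
tks_w = [("english", 0.30), ("query", 0.45), ("bug", 0.25)]

q = ["%s^%.4f" % (tk, w) for tk, w in tks_w]
for i in range(1, len(tks_w)):
    q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0],
                                 max(tks_w[i - 1][1], tks_w[i][1]) * 2))
print(" ".join(q))
# english^0.3000 query^0.4500 bug^0.2500 "english query"^0.9000 "query bug"^0.9000
```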