Fix English query bug (#840)

### What problem does this PR solve?

#834 

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
KevinHuSh 2024-05-20 12:23:51 +08:00 committed by GitHub
parent 6683179d6a
commit 2b36283712
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 80 additions and 7 deletions

View File

@ -118,7 +118,7 @@ def chat(dialog, messages, stream=True, **kwargs):
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=kwargs.get("doc_ids", "").split(","),
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
top=1024, aggs=False)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
chat_logger.info(

View File

@ -20,6 +20,7 @@ from openai import OpenAI
import openai
from ollama import Client
from rag.nlp import is_english
from rag.utils import num_tokens_from_string
class Base(ABC):
@ -255,3 +256,46 @@ class OllamaChat(Base):
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
yield 0
class LocalLLM(Base):
    """Chat model backed by an RPC server on ``127.0.0.1:7860``.

    Requests are sent through :class:`LocalLLM.RPCProxy`, which pickles
    ``(function_name, args, kwargs)`` over a ``multiprocessing.connection``
    socket and unpickles the reply.
    """

    class RPCProxy:
        """Proxy whose attribute accesses become remote procedure calls."""

        def __init__(self, host, port):
            """Remember the server address and open the connection."""
            self.host = host
            self.port = int(port)
            self.__conn()

        def __conn(self):
            """(Re)open the authenticated connection to the RPC server."""
            # Imported locally: at module level the name ``Client`` is bound
            # to ``ollama.Client`` by the file's imports.
            from multiprocessing.connection import Client
            self._connection = Client(
                (self.host, self.port), authkey=b'infiniflow-token4kevinhu')

        def __getattr__(self, name):
            """Return a callable that invokes remote function ``name``.

            Raises:
                AttributeError: for private/dunder names, so that
                    introspection (``copy``, ``pickle``, ``hasattr`` probes)
                    is not silently answered with an RPC stub.
            """
            import pickle

            if name.startswith("_"):
                raise AttributeError(name)

            def do_rpc(*args, **kwargs):
                last_err = None
                # Up to three attempts, reconnecting after each failure.
                for _ in range(3):
                    try:
                        self._connection.send(
                            pickle.dumps((name, args, kwargs)))
                        # NOTE(security): pickle.loads executes arbitrary code
                        # from the peer; only safe while the server is trusted.
                        return pickle.loads(self._connection.recv())
                    except Exception as e:
                        last_err = e
                        self.__conn()
                # Chain the last failure so the root cause is not lost.
                raise Exception("RPC connection lost!") from last_err

            return do_rpc

    def __init__(self, key, model_name="glm-3-turbo"):
        """Create the RPC client; ``key`` and ``model_name`` are not used here."""
        self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)

    def chat(self, system, history, gen_conf):
        """Send a conversation to the remote ``chat`` function.

        Args:
            system: Optional system prompt; when truthy it is sent as the
                first message with role ``"system"``.
            history: List of message dicts. It is not modified.
            gen_conf: Generation settings forwarded unchanged to the server.

        Returns:
            ``(answer, token_count)`` on success, or
            ``("**ERROR**: <message>", 0)`` if the call raised.
        """
        # Work on a copy: inserting into the caller's list would stack up
        # system prompts whenever the same history object is reused.
        messages = list(history)
        if system:
            messages.insert(0, {"role": "system", "content": system})
        try:
            ans = self.client.chat(
                messages,
                gen_conf
            )
            return ans, num_tokens_from_string(ans)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

View File

@ -2,9 +2,10 @@ import argparse
import pickle
import random
import time
from copy import deepcopy
from multiprocessing.connection import Listener
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
def torch_gc():
@ -95,6 +96,32 @@ def chat(messages, gen_conf):
return str(e)
def chat_streamly(messages, gen_conf):
    """Generate a reply for ``messages`` and yield it piece by piece.

    Args:
        messages: Chat messages passed to ``tokenizer.apply_chat_template``.
        gen_conf: Generation settings; copied before use. An optional
            ``max_tokens`` entry is forwarded as ``max_new_tokens``.

    Yields:
        Decoded text fragments as they are produced, or a single
        ``"**ERROR**: <message>"`` string if anything raised.
    """
    model = Model()
    try:
        # ``TextStreamer`` only prints to stdout and cannot be iterated;
        # ``TextIteratorStreamer`` queues decoded text for a consumer loop.
        from transformers import TextIteratorStreamer

        torch_gc()
        conf = deepcopy(gen_conf)
        print(messages, conf)
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        # skip_prompt: yield only newly generated text, not the echoed prompt.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True)
        conf["inputs"] = model_inputs.input_ids
        conf["streamer"] = streamer
        # ``generate`` expects ``max_new_tokens``; tolerate a missing key.
        if "max_tokens" in conf:
            conf["max_new_tokens"] = conf.pop("max_tokens")
        # Generation runs in a worker thread so this generator can drain the
        # streamer while tokens are still being produced.
        thread = Thread(target=model.generate, kwargs=conf)
        thread.start()
        for new_text in streamer:
            yield new_text
        thread.join()
    except Exception as e:
        yield "**ERROR**: " + str(e)
def Model():
global models
random.seed(time.time())
@ -113,6 +140,7 @@ if __name__ == "__main__":
handler = RPCHandler()
handler.register_function(chat)
handler.register_function(chat_streamly)
models = []
for _ in range(1):

View File

@ -36,7 +36,7 @@ class EsQueryer:
patts = [
(r"是*(什么样的|哪家|一下|那家|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""),
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
(r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down)", " ")
(r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down) ", " ")
]
for r, p in patts:
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
@ -44,7 +44,7 @@ class EsQueryer:
def question(self, txt, tbl="qa", min_match="60%"):
txt = re.sub(
r"[ \r\n\t,,。??/`!&]+",
r"[ \r\n\t,,。??/`!&\^%%]+",
" ",
rag_tokenizer.tradi2simp(
rag_tokenizer.strQ2B(
@ -53,9 +53,10 @@ class EsQueryer:
if not self.isChinese(txt):
tks = rag_tokenizer.tokenize(txt).split(" ")
q = copy.deepcopy(tks)
for i in range(1, len(tks)):
q.append("\"%s %s\"^2" % (tks[i - 1], tks[i]))
tks_w = self.tw.weights(tks)
q = [re.sub(r"[ \\\"']+", "", tk)+"^{:.4f}".format(w) for tk, w in tks_w]
for i in range(1, len(tks_w)):
q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
if not q:
q.append(txt)
return Q("bool",