mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-15 21:55:55 +08:00
Chat Use CVmodel (#1607)
### What problem does this PR solve? #1230 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
347cb61f26
commit
58df013722
@ -13,6 +13,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import os
|
||||||
|
import json
|
||||||
import re
|
import re
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
@ -26,6 +28,7 @@ from rag.app.resume import forbidden_select_fields4resume
|
|||||||
from rag.nlp import keyword_extraction
|
from rag.nlp import keyword_extraction
|
||||||
from rag.nlp.search import index_name
|
from rag.nlp.search import index_name
|
||||||
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
||||||
|
from api.utils.file_utils import get_project_base_directory
|
||||||
|
|
||||||
|
|
||||||
class DialogService(CommonService):
|
class DialogService(CommonService):
|
||||||
@ -73,6 +76,15 @@ def message_fit_in(msg, max_length=4000):
|
|||||||
return max_length, msg
|
return max_length, msg
|
||||||
|
|
||||||
|
|
||||||
|
def llm_id2llm_type(llm_id):
|
||||||
|
fnm = os.path.join(get_project_base_directory(), "conf")
|
||||||
|
llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
|
||||||
|
for llm_factory in llm_factories["factory_llm_infos"]:
|
||||||
|
for llm in llm_factory["llm"]:
|
||||||
|
if llm_id == llm["llm_name"]:
|
||||||
|
return llm["model_type"].strip(",")[-1]
|
||||||
|
|
||||||
|
|
||||||
def chat(dialog, messages, stream=True, **kwargs):
|
def chat(dialog, messages, stream=True, **kwargs):
|
||||||
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
||||||
llm = LLMService.query(llm_name=dialog.llm_id)
|
llm = LLMService.query(llm_name=dialog.llm_id)
|
||||||
@ -91,7 +103,10 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
|
|
||||||
questions = [m["content"] for m in messages if m["role"] == "user"]
|
questions = [m["content"] for m in messages if m["role"] == "user"]
|
||||||
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
||||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
if llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||||
|
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||||
|
else:
|
||||||
|
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||||
|
|
||||||
prompt_config = dialog.prompt_config
|
prompt_config = dialog.prompt_config
|
||||||
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
||||||
@ -328,7 +343,10 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
|||||||
|
|
||||||
|
|
||||||
def relevant(tenant_id, llm_id, question, contents: list):
|
def relevant(tenant_id, llm_id, question, contents: list):
|
||||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
|
if llm_id2llm_type(llm_id) == "image2text":
|
||||||
|
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
|
||||||
|
else:
|
||||||
|
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
|
||||||
prompt = """
|
prompt = """
|
||||||
You are a grader assessing relevance of a retrieved document to a user question.
|
You are a grader assessing relevance of a retrieved document to a user question.
|
||||||
It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
|
It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
|
||||||
@ -347,7 +365,10 @@ def relevant(tenant_id, llm_id, question, contents: list):
|
|||||||
|
|
||||||
|
|
||||||
def rewrite(tenant_id, llm_id, question):
|
def rewrite(tenant_id, llm_id, question):
|
||||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
|
if llm_id2llm_type(llm_id) == "image2text":
|
||||||
|
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
|
||||||
|
else:
|
||||||
|
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
|
||||||
prompt = """
|
prompt = """
|
||||||
You are an expert at query expansion to generate a paraphrasing of a question.
|
You are an expert at query expansion to generate a paraphrasing of a question.
|
||||||
I can't retrieval relevant information from the knowledge base by using user's question directly.
|
I can't retrieval relevant information from the knowledge base by using user's question directly.
|
||||||
|
@ -70,7 +70,7 @@ class TenantLLMService(CommonService):
|
|||||||
elif llm_type == LLMType.SPEECH2TEXT.value:
|
elif llm_type == LLMType.SPEECH2TEXT.value:
|
||||||
mdlnm = tenant.asr_id
|
mdlnm = tenant.asr_id
|
||||||
elif llm_type == LLMType.IMAGE2TEXT.value:
|
elif llm_type == LLMType.IMAGE2TEXT.value:
|
||||||
mdlnm = tenant.img2txt_id
|
mdlnm = tenant.img2txt_id if not llm_name else llm_name
|
||||||
elif llm_type == LLMType.CHAT.value:
|
elif llm_type == LLMType.CHAT.value:
|
||||||
mdlnm = tenant.llm_id if not llm_name else llm_name
|
mdlnm = tenant.llm_id if not llm_name else llm_name
|
||||||
elif llm_type == LLMType.RERANK:
|
elif llm_type == LLMType.RERANK:
|
||||||
|
@ -26,6 +26,7 @@ from io import BytesIO
|
|||||||
import json
|
import json
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from rag.nlp import is_english
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from api.utils.file_utils import get_project_base_directory
|
||||||
|
|
||||||
@ -36,7 +37,60 @@ class Base(ABC):
|
|||||||
|
|
||||||
def describe(self, image, max_tokens=300):
|
def describe(self, image, max_tokens=300):
|
||||||
raise NotImplementedError("Please implement encode method!")
|
raise NotImplementedError("Please implement encode method!")
|
||||||
|
|
||||||
|
def chat(self, system, history, gen_conf, image=""):
|
||||||
|
if system:
|
||||||
|
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
|
||||||
|
try:
|
||||||
|
for his in history:
|
||||||
|
if his["role"] == "user":
|
||||||
|
his["content"] = self.chat_prompt(his["content"], image)
|
||||||
|
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model=self.model_name,
|
||||||
|
messages=history,
|
||||||
|
max_tokens=gen_conf.get("max_tokens", 1000),
|
||||||
|
temperature=gen_conf.get("temperature", 0.3),
|
||||||
|
top_p=gen_conf.get("top_p", 0.7)
|
||||||
|
)
|
||||||
|
return response.choices[0].message.content.strip(), response.usage.total_tokens
|
||||||
|
except Exception as e:
|
||||||
|
return "**ERROR**: " + str(e), 0
|
||||||
|
|
||||||
|
def chat_streamly(self, system, history, gen_conf, image=""):
|
||||||
|
if system:
|
||||||
|
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
|
||||||
|
|
||||||
|
ans = ""
|
||||||
|
tk_count = 0
|
||||||
|
try:
|
||||||
|
for his in history:
|
||||||
|
if his["role"] == "user":
|
||||||
|
his["content"] = self.chat_prompt(his["content"], image)
|
||||||
|
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model=self.model_name,
|
||||||
|
messages=history,
|
||||||
|
max_tokens=gen_conf.get("max_tokens", 1000),
|
||||||
|
temperature=gen_conf.get("temperature", 0.3),
|
||||||
|
top_p=gen_conf.get("top_p", 0.7),
|
||||||
|
stream=True
|
||||||
|
)
|
||||||
|
for resp in response:
|
||||||
|
if not resp.choices[0].delta.content: continue
|
||||||
|
delta = resp.choices[0].delta.content
|
||||||
|
ans += delta
|
||||||
|
if resp.choices[0].finish_reason == "length":
|
||||||
|
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
||||||
|
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||||
|
tk_count = resp.usage.total_tokens
|
||||||
|
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
|
||||||
|
yield ans
|
||||||
|
except Exception as e:
|
||||||
|
yield ans + "\n**ERROR**: " + str(e)
|
||||||
|
|
||||||
|
yield tk_count
|
||||||
|
|
||||||
def image2base64(self, image):
|
def image2base64(self, image):
|
||||||
if isinstance(image, bytes):
|
if isinstance(image, bytes):
|
||||||
return base64.b64encode(image).decode("utf-8")
|
return base64.b64encode(image).decode("utf-8")
|
||||||
@ -68,6 +122,21 @@ class Base(ABC):
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def chat_prompt(self, text, b64):
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:image/jpeg;base64,{b64}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": text
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class GptV4(Base):
|
class GptV4(Base):
|
||||||
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
|
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
|
||||||
@ -140,6 +209,12 @@ class QWenCV(Base):
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def chat_prompt(self, text, b64):
|
||||||
|
return [
|
||||||
|
{"image": f"{b64}"},
|
||||||
|
{"text": text},
|
||||||
|
]
|
||||||
|
|
||||||
def describe(self, image, max_tokens=300):
|
def describe(self, image, max_tokens=300):
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from dashscope import MultiModalConversation
|
from dashscope import MultiModalConversation
|
||||||
@ -149,6 +224,66 @@ class QWenCV(Base):
|
|||||||
return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
|
return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
|
||||||
return response.message, 0
|
return response.message, 0
|
||||||
|
|
||||||
|
def chat(self, system, history, gen_conf, image=""):
|
||||||
|
from http import HTTPStatus
|
||||||
|
from dashscope import MultiModalConversation
|
||||||
|
if system:
|
||||||
|
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
|
||||||
|
|
||||||
|
for his in history:
|
||||||
|
if his["role"] == "user":
|
||||||
|
his["content"] = self.chat_prompt(his["content"], image)
|
||||||
|
response = MultiModalConversation.call(model=self.model_name, messages=history,
|
||||||
|
max_tokens=gen_conf.get("max_tokens", 1000),
|
||||||
|
temperature=gen_conf.get("temperature", 0.3),
|
||||||
|
top_p=gen_conf.get("top_p", 0.7))
|
||||||
|
|
||||||
|
ans = ""
|
||||||
|
tk_count = 0
|
||||||
|
if response.status_code == HTTPStatus.OK:
|
||||||
|
ans += response.output.choices[0]['message']['content']
|
||||||
|
tk_count += response.usage.total_tokens
|
||||||
|
if response.output.choices[0].get("finish_reason", "") == "length":
|
||||||
|
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
||||||
|
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||||
|
return ans, tk_count
|
||||||
|
|
||||||
|
return "**ERROR**: " + response.message, tk_count
|
||||||
|
|
||||||
|
def chat_streamly(self, system, history, gen_conf, image=""):
|
||||||
|
from http import HTTPStatus
|
||||||
|
from dashscope import MultiModalConversation
|
||||||
|
if system:
|
||||||
|
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
|
||||||
|
|
||||||
|
for his in history:
|
||||||
|
if his["role"] == "user":
|
||||||
|
his["content"] = self.chat_prompt(his["content"], image)
|
||||||
|
|
||||||
|
ans = ""
|
||||||
|
tk_count = 0
|
||||||
|
try:
|
||||||
|
response = MultiModalConversation.call(model=self.model_name, messages=history,
|
||||||
|
max_tokens=gen_conf.get("max_tokens", 1000),
|
||||||
|
temperature=gen_conf.get("temperature", 0.3),
|
||||||
|
top_p=gen_conf.get("top_p", 0.7),
|
||||||
|
stream=True)
|
||||||
|
for resp in response:
|
||||||
|
if resp.status_code == HTTPStatus.OK:
|
||||||
|
ans = resp.output.choices[0]['message']['content']
|
||||||
|
tk_count = resp.usage.total_tokens
|
||||||
|
if resp.output.choices[0].get("finish_reason", "") == "length":
|
||||||
|
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
||||||
|
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||||
|
yield ans
|
||||||
|
else:
|
||||||
|
yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
|
||||||
|
"Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
|
||||||
|
except Exception as e:
|
||||||
|
yield ans + "\n**ERROR**: " + str(e)
|
||||||
|
|
||||||
|
yield tk_count
|
||||||
|
|
||||||
|
|
||||||
class Zhipu4V(Base):
|
class Zhipu4V(Base):
|
||||||
def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
|
def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
|
||||||
@ -166,6 +301,59 @@ class Zhipu4V(Base):
|
|||||||
)
|
)
|
||||||
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
||||||
|
|
||||||
|
def chat(self, system, history, gen_conf, image=""):
|
||||||
|
if system:
|
||||||
|
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
|
||||||
|
try:
|
||||||
|
for his in history:
|
||||||
|
if his["role"] == "user":
|
||||||
|
his["content"] = self.chat_prompt(his["content"], image)
|
||||||
|
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model=self.model_name,
|
||||||
|
messages=history,
|
||||||
|
max_tokens=gen_conf.get("max_tokens", 1000),
|
||||||
|
temperature=gen_conf.get("temperature", 0.3),
|
||||||
|
top_p=gen_conf.get("top_p", 0.7)
|
||||||
|
)
|
||||||
|
return response.choices[0].message.content.strip(), response.usage.total_tokens
|
||||||
|
except Exception as e:
|
||||||
|
return "**ERROR**: " + str(e), 0
|
||||||
|
|
||||||
|
def chat_streamly(self, system, history, gen_conf, image=""):
|
||||||
|
if system:
|
||||||
|
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
|
||||||
|
|
||||||
|
ans = ""
|
||||||
|
tk_count = 0
|
||||||
|
try:
|
||||||
|
for his in history:
|
||||||
|
if his["role"] == "user":
|
||||||
|
his["content"] = self.chat_prompt(his["content"], image)
|
||||||
|
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model=self.model_name,
|
||||||
|
messages=history,
|
||||||
|
max_tokens=gen_conf.get("max_tokens", 1000),
|
||||||
|
temperature=gen_conf.get("temperature", 0.3),
|
||||||
|
top_p=gen_conf.get("top_p", 0.7),
|
||||||
|
stream=True
|
||||||
|
)
|
||||||
|
for resp in response:
|
||||||
|
if not resp.choices[0].delta.content: continue
|
||||||
|
delta = resp.choices[0].delta.content
|
||||||
|
ans += delta
|
||||||
|
if resp.choices[0].finish_reason == "length":
|
||||||
|
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
||||||
|
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||||
|
tk_count = resp.usage.total_tokens
|
||||||
|
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
|
||||||
|
yield ans
|
||||||
|
except Exception as e:
|
||||||
|
yield ans + "\n**ERROR**: " + str(e)
|
||||||
|
|
||||||
|
yield tk_count
|
||||||
|
|
||||||
|
|
||||||
class OllamaCV(Base):
|
class OllamaCV(Base):
|
||||||
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
||||||
@ -188,6 +376,63 @@ class OllamaCV(Base):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return "**ERROR**: " + str(e), 0
|
return "**ERROR**: " + str(e), 0
|
||||||
|
|
||||||
|
def chat(self, system, history, gen_conf, image=""):
|
||||||
|
if system:
|
||||||
|
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
for his in history:
|
||||||
|
if his["role"] == "user":
|
||||||
|
his["images"] = [image]
|
||||||
|
options = {}
|
||||||
|
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
|
||||||
|
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
|
||||||
|
if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
|
||||||
|
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
|
||||||
|
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
|
||||||
|
response = self.client.chat(
|
||||||
|
model=self.model_name,
|
||||||
|
messages=history,
|
||||||
|
options=options,
|
||||||
|
keep_alive=-1
|
||||||
|
)
|
||||||
|
|
||||||
|
ans = response["message"]["content"].strip()
|
||||||
|
return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
|
||||||
|
except Exception as e:
|
||||||
|
return "**ERROR**: " + str(e), 0
|
||||||
|
|
||||||
|
def chat_streamly(self, system, history, gen_conf, image=""):
|
||||||
|
if system:
|
||||||
|
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
|
||||||
|
|
||||||
|
for his in history:
|
||||||
|
if his["role"] == "user":
|
||||||
|
his["images"] = [image]
|
||||||
|
options = {}
|
||||||
|
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
|
||||||
|
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
|
||||||
|
if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
|
||||||
|
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
|
||||||
|
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
|
||||||
|
ans = ""
|
||||||
|
try:
|
||||||
|
response = self.client.chat(
|
||||||
|
model=self.model_name,
|
||||||
|
messages=history,
|
||||||
|
stream=True,
|
||||||
|
options=options,
|
||||||
|
keep_alive=-1
|
||||||
|
)
|
||||||
|
for resp in response:
|
||||||
|
if resp["done"]:
|
||||||
|
yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
|
||||||
|
ans += resp["message"]["content"]
|
||||||
|
yield ans
|
||||||
|
except Exception as e:
|
||||||
|
yield ans + "\n**ERROR**: " + str(e)
|
||||||
|
yield 0
|
||||||
|
|
||||||
|
|
||||||
class LocalAICV(Base):
|
class LocalAICV(Base):
|
||||||
def __init__(self, key, model_name, base_url, lang="Chinese"):
|
def __init__(self, key, model_name, base_url, lang="Chinese"):
|
||||||
@ -236,7 +481,7 @@ class XinferenceCV(Base):
|
|||||||
|
|
||||||
class GeminiCV(Base):
|
class GeminiCV(Base):
|
||||||
def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
|
def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
|
||||||
from google.generativeai import client,GenerativeModel
|
from google.generativeai import client, GenerativeModel, GenerationConfig
|
||||||
client.configure(api_key=key)
|
client.configure(api_key=key)
|
||||||
_client = client.get_default_generative_client()
|
_client = client.get_default_generative_client()
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
@ -258,6 +503,59 @@ class GeminiCV(Base):
|
|||||||
)
|
)
|
||||||
return res.text,res.usage_metadata.total_token_count
|
return res.text,res.usage_metadata.total_token_count
|
||||||
|
|
||||||
|
def chat(self, system, history, gen_conf, image=""):
|
||||||
|
if system:
|
||||||
|
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
|
||||||
|
try:
|
||||||
|
for his in history:
|
||||||
|
if his["role"] == "assistant":
|
||||||
|
his["role"] = "model"
|
||||||
|
his["parts"] = [his["content"]]
|
||||||
|
his.pop("content")
|
||||||
|
if his["role"] == "user":
|
||||||
|
his["parts"] = [his["content"]]
|
||||||
|
his.pop("content")
|
||||||
|
history[-1]["parts"].append(f"data:image/jpeg;base64," + image)
|
||||||
|
|
||||||
|
response = self.model.generate_content(history, generation_config=GenerationConfig(
|
||||||
|
max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
|
||||||
|
top_p=gen_conf.get("top_p", 0.7)))
|
||||||
|
|
||||||
|
ans = response.text
|
||||||
|
return ans, response.usage_metadata.total_token_count
|
||||||
|
except Exception as e:
|
||||||
|
return "**ERROR**: " + str(e), 0
|
||||||
|
|
||||||
|
def chat_streamly(self, system, history, gen_conf, image=""):
|
||||||
|
if system:
|
||||||
|
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
|
||||||
|
|
||||||
|
ans = ""
|
||||||
|
tk_count = 0
|
||||||
|
try:
|
||||||
|
for his in history:
|
||||||
|
if his["role"] == "assistant":
|
||||||
|
his["role"] = "model"
|
||||||
|
his["parts"] = [his["content"]]
|
||||||
|
his.pop("content")
|
||||||
|
if his["role"] == "user":
|
||||||
|
his["parts"] = [his["content"]]
|
||||||
|
his.pop("content")
|
||||||
|
history[-1]["parts"].append(f"data:image/jpeg;base64," + image)
|
||||||
|
|
||||||
|
response = self.model.generate_content(history, generation_config=GenerationConfig(
|
||||||
|
max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
|
||||||
|
top_p=gen_conf.get("top_p", 0.7)), stream=True)
|
||||||
|
|
||||||
|
for resp in response:
|
||||||
|
if not resp.text: continue
|
||||||
|
ans += resp.text
|
||||||
|
yield ans
|
||||||
|
except Exception as e:
|
||||||
|
yield ans + "\n**ERROR**: " + str(e)
|
||||||
|
|
||||||
|
yield response._chunks[-1].usage_metadata.total_token_count
|
||||||
|
|
||||||
|
|
||||||
class OpenRouterCV(Base):
|
class OpenRouterCV(Base):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -46,7 +46,7 @@ const LlmSettingItems = ({ prefix, formItemLayout = {} }: IProps) => {
|
|||||||
{...formItemLayout}
|
{...formItemLayout}
|
||||||
rules={[{ required: true, message: t('modelMessage') }]}
|
rules={[{ required: true, message: t('modelMessage') }]}
|
||||||
>
|
>
|
||||||
<Select options={modelOptions[LlmModelType.Chat]} showSearch />
|
<Select options={[...modelOptions[LlmModelType.Chat], ...modelOptions[LlmModelType.Image2text],]} showSearch/>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Divider></Divider>
|
<Divider></Divider>
|
||||||
<Form.Item
|
<Form.Item
|
||||||
|
Loading…
x
Reference in New Issue
Block a user