solve knowledge graph issue when calling Gemini model (#2738)
### What problem does this PR solve?

#2720

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
parent d92acdcf1d
commit 16472eb3ea
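Why this fix is needed: Google's `google.generativeai` chat API accepts only `user` and `model` roles inside `history`, while the knowledge graph flow (see #2720) passes an OpenAI-style history that can contain a `system` turn, so calls to Gemini failed. `GeminiChat.chat` now remaps `assistant` to `model` and `system` to `user` before renaming `content` to `parts`. A minimal sketch of that normalization, as a hypothetical standalone helper (not part of this PR):

```python
def normalize_history_for_gemini(history):
    """Remap OpenAI-style chat turns into the shape Gemini accepts.

    Hypothetical helper mirroring the loop this PR adds inline in
    GeminiChat.chat; `history` is a list of
    {"role": ..., "content": ...} dicts.
    """
    for item in history:
        if item.get("role") == "assistant":
            item["role"] = "model"   # Gemini calls the assistant side "model"
        if item.get("role") == "system":
            item["role"] = "user"    # Gemini history rejects "system" turns
        if "content" in item:
            item["parts"] = item.pop("content")  # Gemini expects "parts"
    return history
```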
@@ -23,7 +23,7 @@ from ollama import Client
 from rag.nlp import is_english
 from rag.utils import num_tokens_from_string
 from groq import Groq
 import os
 import json
 import requests
 import asyncio
@@ -62,17 +62,17 @@ class Base(ABC):
                 stream=True,
                 **gen_conf)
             for resp in response:
-                if not resp.choices:continue
+                if not resp.choices: continue
                 if not resp.choices[0].delta.content:
                     resp.choices[0].delta.content = ""
                 ans += resp.choices[0].delta.content
                 total_tokens = (
                     (
                         total_tokens
                         + num_tokens_from_string(resp.choices[0].delta.content)
                     )
                     if not hasattr(resp, "usage") or not resp.usage
-                    else resp.usage.get("total_tokens",total_tokens)
+                    else resp.usage.get("total_tokens", total_tokens)
                 )
                 if resp.choices[0].finish_reason == "length":
                     ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
@@ -87,13 +87,13 @@ class Base(ABC):
 
 class GptTurbo(Base):
     def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
-        if not base_url: base_url="https://api.openai.com/v1"
+        if not base_url: base_url = "https://api.openai.com/v1"
         super().__init__(key, model_name, base_url)
 
 
 class MoonshotChat(Base):
     def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
-        if not base_url: base_url="https://api.moonshot.cn/v1"
+        if not base_url: base_url = "https://api.moonshot.cn/v1"
         super().__init__(key, model_name, base_url)
 
 
@@ -108,7 +108,7 @@ class XinferenceChat(Base):
 
 class DeepSeekChat(Base):
     def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
-        if not base_url: base_url="https://api.deepseek.com/v1"
+        if not base_url: base_url = "https://api.deepseek.com/v1"
         super().__init__(key, model_name, base_url)
 
 
@@ -178,14 +178,14 @@ class BaiChuanChat(Base):
                 stream=True,
                 **self._format_params(gen_conf))
             for resp in response:
-                if not resp.choices:continue
+                if not resp.choices: continue
                 if not resp.choices[0].delta.content:
                     resp.choices[0].delta.content = ""
                 ans += resp.choices[0].delta.content
                 total_tokens = (
                     (
                         total_tokens
                         + num_tokens_from_string(resp.choices[0].delta.content)
                     )
                     if not hasattr(resp, "usage")
                     else resp.usage["total_tokens"]
@@ -252,7 +252,8 @@ class QWenChat(Base):
                     [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                 yield ans
             else:
-                yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access")<0 else "Out of credit. Please set the API key in **settings > Model providers.**"
+                yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
+                    "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
 
@@ -298,7 +299,7 @@ class ZhipuChat(Base):
                 **gen_conf
             )
             for resp in response:
-                if not resp.choices[0].delta.content:continue
+                if not resp.choices[0].delta.content: continue
                 delta = resp.choices[0].delta.content
                 ans += delta
                 if resp.choices[0].finish_reason == "length":
@@ -411,7 +412,7 @@ class LocalLLM(Base):
         self.client = Client(port=12345, protocol="grpc", asyncio=True)
 
     def _prepare_prompt(self, system, history, gen_conf):
-        from rag.svr.jina_server import Prompt,Generation
+        from rag.svr.jina_server import Prompt, Generation
         if system:
             history.insert(0, {"role": "system", "content": system})
         if "max_tokens" in gen_conf:
@@ -419,7 +420,7 @@ class LocalLLM(Base):
         return Prompt(message=history, gen_conf=gen_conf)
 
     def _stream_response(self, endpoint, prompt):
-        from rag.svr.jina_server import Prompt,Generation
+        from rag.svr.jina_server import Prompt, Generation
         answer = ""
         try:
             res = self.client.stream_doc(
@@ -463,10 +464,10 @@ class VolcEngineChat(Base):
 
 class MiniMaxChat(Base):
     def __init__(
         self,
         key,
         model_name,
         base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
     ):
         if not base_url:
             base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
@@ -583,7 +584,7 @@ class MistralChat(Base):
                 messages=history,
                 **gen_conf)
             for resp in response:
-                if not resp.choices or not resp.choices[0].delta.content:continue
+                if not resp.choices or not resp.choices[0].delta.content: continue
                 ans += resp.choices[0].delta.content
                 total_tokens += 1
                 if resp.choices[0].finish_reason == "length":
@@ -620,9 +621,8 @@ class BedrockChat(Base):
             gen_conf["topP"] = gen_conf["top_p"]
             _ = gen_conf.pop("top_p")
         for item in history:
-            if not isinstance(item["content"],list) and not isinstance(item["content"],tuple):
-                item["content"] = [{"text":item["content"]}]
+            if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
+                item["content"] = [{"text": item["content"]}]
 
-
         try:
             # Send the message to the model, using a basic inference configuration.
@@ -630,9 +630,9 @@ class BedrockChat(Base):
                 modelId=self.model_name,
                 messages=history,
                 inferenceConfig=gen_conf,
-                system=[{"text": (system if system else "Answer the user's message.")}] ,
+                system=[{"text": (system if system else "Answer the user's message.")}],
             )
 
             # Extract and print the response text.
             ans = response["output"]["message"]["content"][0]["text"]
             return ans, num_tokens_from_string(ans)
@@ -652,9 +652,9 @@ class BedrockChat(Base):
             gen_conf["topP"] = gen_conf["top_p"]
             _ = gen_conf.pop("top_p")
         for item in history:
-            if not isinstance(item["content"],list) and not isinstance(item["content"],tuple):
-                item["content"] = [{"text":item["content"]}]
+            if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
+                item["content"] = [{"text": item["content"]}]
 
         if self.model_name.split('.')[0] == 'ai21':
             try:
                 response = self.client.converse(
@@ -684,7 +684,7 @@ class BedrockChat(Base):
                 if "contentBlockDelta" in resp:
                     ans += resp["contentBlockDelta"]["delta"]["text"]
                     yield ans
 
         except (ClientError, Exception) as e:
             yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"
 
@@ -693,22 +693,21 @@ class BedrockChat(Base):
 
 class GeminiChat(Base):
 
-    def __init__(self, key, model_name,base_url=None):
-        from google.generativeai import client,GenerativeModel
+    def __init__(self, key, model_name, base_url=None):
+        from google.generativeai import client, GenerativeModel
 
         client.configure(api_key=key)
         _client = client.get_default_generative_client()
         self.model_name = 'models/' + model_name
         self.model = GenerativeModel(model_name=self.model_name)
         self.model._client = _client
 
-
-    def chat(self,system,history,gen_conf):
+    def chat(self, system, history, gen_conf):
         from google.generativeai.types import content_types
 
         if system:
             self.model._system_instruction = content_types.to_content(system)
 
         if 'max_tokens' in gen_conf:
             gen_conf['max_output_tokens'] = gen_conf['max_tokens']
         for k in list(gen_conf.keys()):
@@ -717,9 +716,11 @@ class GeminiChat(Base):
         for item in history:
             if 'role' in item and item['role'] == 'assistant':
                 item['role'] = 'model'
-            if 'content' in item :
+            if 'role' in item and item['role'] == 'system':
+                item['role'] = 'user'
+            if 'content' in item:
                 item['parts'] = item.pop('content')
 
         try:
             response = self.model.generate_content(
                 history,
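Note on the hunk above: this is the actual knowledge graph fix. Assuming the knowledge graph prompts produce an OpenAI-style history like the example below (message text is illustrative), the new branch turns the previously rejected `system` turn into an accepted `user` turn:

```python
history = [
    {"role": "system", "content": "Extract entities and relations."},
    {"role": "user", "content": "Berlin is the capital of Germany."},
]
# After the loop in this hunk runs, history becomes:
# [{'role': 'user', 'parts': 'Extract entities and relations.'},
#  {'role': 'user', 'parts': 'Berlin is the capital of Germany.'}]
```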
@@ -731,7 +732,7 @@ class GeminiChat(Base):
 
     def chat_streamly(self, system, history, gen_conf):
         from google.generativeai.types import content_types
 
         if system:
             self.model._system_instruction = content_types.to_content(system)
         if 'max_tokens' in gen_conf:
@@ -742,13 +743,13 @@ class GeminiChat(Base):
         for item in history:
             if 'role' in item and item['role'] == 'assistant':
                 item['role'] = 'model'
-            if 'content' in item :
+            if 'content' in item:
                 item['parts'] = item.pop('content')
         ans = ""
         try:
             response = self.model.generate_content(
                 history,
-                generation_config=gen_conf,stream=True)
+                generation_config=gen_conf, stream=True)
             for resp in response:
                 ans += resp.text
                 yield ans
@@ -756,11 +757,11 @@ class GeminiChat(Base):
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
 
         yield response._chunks[-1].usage_metadata.total_token_count
 
 
 class GroqChat:
-    def __init__(self, key, model_name,base_url=''):
+    def __init__(self, key, model_name, base_url=''):
         self.client = Groq(api_key=key)
         self.model_name = model_name
 
@@ -942,7 +943,7 @@ class CoHereChat(Base):
 class LeptonAIChat(Base):
     def __init__(self, key, model_name, base_url=None):
         if not base_url:
-            base_url = os.path.join("https://"+model_name+".lepton.run","api","v1")
+            base_url = os.path.join("https://" + model_name + ".lepton.run", "api", "v1")
         super().__init__(key, model_name, base_url)
 
 
@@ -1058,7 +1059,7 @@ class HunyuanChat(Base):
         )
 
         _gen_conf = {}
-        _history = [{k.capitalize(): v for k, v in item.items() } for item in history]
+        _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
         if system:
             _history.insert(0, {"Role": "system", "Content": system})
         if "temperature" in gen_conf:
@@ -1084,7 +1085,7 @@ class HunyuanChat(Base):
         )
 
         _gen_conf = {}
-        _history = [{k.capitalize(): v for k, v in item.items() } for item in history]
+        _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
         if system:
             _history.insert(0, {"Role": "system", "Content": system})
 
@@ -1121,7 +1122,7 @@ class HunyuanChat(Base):
 
 class SparkChat(Base):
     def __init__(
         self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
     ):
         if not base_url:
             base_url = "https://spark-api-open.xf-yun.com/v1"
@@ -1141,9 +1142,9 @@ class BaiduYiyanChat(Base):
         import qianfan
 
         key = json.loads(key)
-        ak = key.get("yiyan_ak","")
-        sk = key.get("yiyan_sk","")
-        self.client = qianfan.ChatCompletion(ak=ak,sk=sk)
+        ak = key.get("yiyan_ak", "")
+        sk = key.get("yiyan_sk", "")
+        self.client = qianfan.ChatCompletion(ak=ak, sk=sk)
         self.model_name = model_name.lower()
         self.system = ""
 
@@ -1151,16 +1152,17 @@ class BaiduYiyanChat(Base):
         if system:
             self.system = system
         gen_conf["penalty_score"] = (
-            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
+            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
+                                                                0)) / 2
         ) + 1
         if "max_tokens" in gen_conf:
             gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
         ans = ""
 
         try:
             response = self.client.do(
                 model=self.model_name,
                 messages=history,
                 system=self.system,
                 **gen_conf
             ).body
@@ -1174,8 +1176,9 @@ class BaiduYiyanChat(Base):
         if system:
             self.system = system
         gen_conf["penalty_score"] = (
-            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
+            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
+                                                                0)) / 2
         ) + 1
         if "max_tokens" in gen_conf:
             gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
         ans = ""
@@ -1183,8 +1186,8 @@ class BaiduYiyanChat(Base):
 
         try:
             response = self.client.do(
                 model=self.model_name,
                 messages=history,
                 system=self.system,
                 stream=True,
                 **gen_conf
@@ -1415,4 +1418,3 @@ class GoogleChat(Base):
             yield ans + "\n**ERROR**: " + str(e)
 
         yield response._chunks[-1].usage_metadata.total_token_count
-