solve knowledgegraph issue when calling gemini model (#2738)

### What problem does this PR solve?
#2720

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
JobSmithManipulation 2024-10-08 18:27:04 +08:00 committed by GitHub
parent d92acdcf1d
commit 16472eb3ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -62,7 +62,7 @@ class Base(ABC):
stream=True, stream=True,
**gen_conf) **gen_conf)
for resp in response: for resp in response:
if not resp.choices:continue if not resp.choices: continue
if not resp.choices[0].delta.content: if not resp.choices[0].delta.content:
resp.choices[0].delta.content = "" resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content ans += resp.choices[0].delta.content
@ -72,7 +72,7 @@ class Base(ABC):
+ num_tokens_from_string(resp.choices[0].delta.content) + num_tokens_from_string(resp.choices[0].delta.content)
) )
if not hasattr(resp, "usage") or not resp.usage if not hasattr(resp, "usage") or not resp.usage
else resp.usage.get("total_tokens",total_tokens) else resp.usage.get("total_tokens", total_tokens)
) )
if resp.choices[0].finish_reason == "length": if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english( ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
@ -87,13 +87,13 @@ class Base(ABC):
class GptTurbo(Base): class GptTurbo(Base):
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"): def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
if not base_url: base_url="https://api.openai.com/v1" if not base_url: base_url = "https://api.openai.com/v1"
super().__init__(key, model_name, base_url) super().__init__(key, model_name, base_url)
class MoonshotChat(Base): class MoonshotChat(Base):
def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"): def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
if not base_url: base_url="https://api.moonshot.cn/v1" if not base_url: base_url = "https://api.moonshot.cn/v1"
super().__init__(key, model_name, base_url) super().__init__(key, model_name, base_url)
@ -108,7 +108,7 @@ class XinferenceChat(Base):
class DeepSeekChat(Base): class DeepSeekChat(Base):
def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"): def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
if not base_url: base_url="https://api.deepseek.com/v1" if not base_url: base_url = "https://api.deepseek.com/v1"
super().__init__(key, model_name, base_url) super().__init__(key, model_name, base_url)
@ -178,7 +178,7 @@ class BaiChuanChat(Base):
stream=True, stream=True,
**self._format_params(gen_conf)) **self._format_params(gen_conf))
for resp in response: for resp in response:
if not resp.choices:continue if not resp.choices: continue
if not resp.choices[0].delta.content: if not resp.choices[0].delta.content:
resp.choices[0].delta.content = "" resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content ans += resp.choices[0].delta.content
@ -252,7 +252,8 @@ class QWenChat(Base):
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
yield ans yield ans
else: else:
yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access")<0 else "Out of credit. Please set the API key in **settings > Model providers.**" yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
"Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
except Exception as e: except Exception as e:
yield ans + "\n**ERROR**: " + str(e) yield ans + "\n**ERROR**: " + str(e)
@ -298,7 +299,7 @@ class ZhipuChat(Base):
**gen_conf **gen_conf
) )
for resp in response: for resp in response:
if not resp.choices[0].delta.content:continue if not resp.choices[0].delta.content: continue
delta = resp.choices[0].delta.content delta = resp.choices[0].delta.content
ans += delta ans += delta
if resp.choices[0].finish_reason == "length": if resp.choices[0].finish_reason == "length":
@ -411,7 +412,7 @@ class LocalLLM(Base):
self.client = Client(port=12345, protocol="grpc", asyncio=True) self.client = Client(port=12345, protocol="grpc", asyncio=True)
def _prepare_prompt(self, system, history, gen_conf): def _prepare_prompt(self, system, history, gen_conf):
from rag.svr.jina_server import Prompt,Generation from rag.svr.jina_server import Prompt, Generation
if system: if system:
history.insert(0, {"role": "system", "content": system}) history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf: if "max_tokens" in gen_conf:
@ -419,7 +420,7 @@ class LocalLLM(Base):
return Prompt(message=history, gen_conf=gen_conf) return Prompt(message=history, gen_conf=gen_conf)
def _stream_response(self, endpoint, prompt): def _stream_response(self, endpoint, prompt):
from rag.svr.jina_server import Prompt,Generation from rag.svr.jina_server import Prompt, Generation
answer = "" answer = ""
try: try:
res = self.client.stream_doc( res = self.client.stream_doc(
@ -583,7 +584,7 @@ class MistralChat(Base):
messages=history, messages=history,
**gen_conf) **gen_conf)
for resp in response: for resp in response:
if not resp.choices or not resp.choices[0].delta.content:continue if not resp.choices or not resp.choices[0].delta.content: continue
ans += resp.choices[0].delta.content ans += resp.choices[0].delta.content
total_tokens += 1 total_tokens += 1
if resp.choices[0].finish_reason == "length": if resp.choices[0].finish_reason == "length":
@ -620,9 +621,8 @@ class BedrockChat(Base):
gen_conf["topP"] = gen_conf["top_p"] gen_conf["topP"] = gen_conf["top_p"]
_ = gen_conf.pop("top_p") _ = gen_conf.pop("top_p")
for item in history: for item in history:
if not isinstance(item["content"],list) and not isinstance(item["content"],tuple): if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
item["content"] = [{"text":item["content"]}] item["content"] = [{"text": item["content"]}]
try: try:
# Send the message to the model, using a basic inference configuration. # Send the message to the model, using a basic inference configuration.
@ -630,7 +630,7 @@ class BedrockChat(Base):
modelId=self.model_name, modelId=self.model_name,
messages=history, messages=history,
inferenceConfig=gen_conf, inferenceConfig=gen_conf,
system=[{"text": (system if system else "Answer the user's message.")}] , system=[{"text": (system if system else "Answer the user's message.")}],
) )
# Extract and print the response text. # Extract and print the response text.
@ -652,8 +652,8 @@ class BedrockChat(Base):
gen_conf["topP"] = gen_conf["top_p"] gen_conf["topP"] = gen_conf["top_p"]
_ = gen_conf.pop("top_p") _ = gen_conf.pop("top_p")
for item in history: for item in history:
if not isinstance(item["content"],list) and not isinstance(item["content"],tuple): if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
item["content"] = [{"text":item["content"]}] item["content"] = [{"text": item["content"]}]
if self.model_name.split('.')[0] == 'ai21': if self.model_name.split('.')[0] == 'ai21':
try: try:
@ -693,8 +693,8 @@ class BedrockChat(Base):
class GeminiChat(Base): class GeminiChat(Base):
def __init__(self, key, model_name,base_url=None): def __init__(self, key, model_name, base_url=None):
from google.generativeai import client,GenerativeModel from google.generativeai import client, GenerativeModel
client.configure(api_key=key) client.configure(api_key=key)
_client = client.get_default_generative_client() _client = client.get_default_generative_client()
@ -702,8 +702,7 @@ class GeminiChat(Base):
self.model = GenerativeModel(model_name=self.model_name) self.model = GenerativeModel(model_name=self.model_name)
self.model._client = _client self.model._client = _client
def chat(self, system, history, gen_conf):
def chat(self,system,history,gen_conf):
from google.generativeai.types import content_types from google.generativeai.types import content_types
if system: if system:
@ -717,7 +716,9 @@ class GeminiChat(Base):
for item in history: for item in history:
if 'role' in item and item['role'] == 'assistant': if 'role' in item and item['role'] == 'assistant':
item['role'] = 'model' item['role'] = 'model'
if 'content' in item : if 'role' in item and item['role'] == 'system':
item['role'] = 'user'
if 'content' in item:
item['parts'] = item.pop('content') item['parts'] = item.pop('content')
try: try:
@ -742,13 +743,13 @@ class GeminiChat(Base):
for item in history: for item in history:
if 'role' in item and item['role'] == 'assistant': if 'role' in item and item['role'] == 'assistant':
item['role'] = 'model' item['role'] = 'model'
if 'content' in item : if 'content' in item:
item['parts'] = item.pop('content') item['parts'] = item.pop('content')
ans = "" ans = ""
try: try:
response = self.model.generate_content( response = self.model.generate_content(
history, history,
generation_config=gen_conf,stream=True) generation_config=gen_conf, stream=True)
for resp in response: for resp in response:
ans += resp.text ans += resp.text
yield ans yield ans
@ -760,7 +761,7 @@ class GeminiChat(Base):
class GroqChat: class GroqChat:
def __init__(self, key, model_name,base_url=''): def __init__(self, key, model_name, base_url=''):
self.client = Groq(api_key=key) self.client = Groq(api_key=key)
self.model_name = model_name self.model_name = model_name
@ -942,7 +943,7 @@ class CoHereChat(Base):
class LeptonAIChat(Base): class LeptonAIChat(Base):
def __init__(self, key, model_name, base_url=None): def __init__(self, key, model_name, base_url=None):
if not base_url: if not base_url:
base_url = os.path.join("https://"+model_name+".lepton.run","api","v1") base_url = os.path.join("https://" + model_name + ".lepton.run", "api", "v1")
super().__init__(key, model_name, base_url) super().__init__(key, model_name, base_url)
@ -1058,7 +1059,7 @@ class HunyuanChat(Base):
) )
_gen_conf = {} _gen_conf = {}
_history = [{k.capitalize(): v for k, v in item.items() } for item in history] _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
if system: if system:
_history.insert(0, {"Role": "system", "Content": system}) _history.insert(0, {"Role": "system", "Content": system})
if "temperature" in gen_conf: if "temperature" in gen_conf:
@ -1084,7 +1085,7 @@ class HunyuanChat(Base):
) )
_gen_conf = {} _gen_conf = {}
_history = [{k.capitalize(): v for k, v in item.items() } for item in history] _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
if system: if system:
_history.insert(0, {"Role": "system", "Content": system}) _history.insert(0, {"Role": "system", "Content": system})
@ -1141,9 +1142,9 @@ class BaiduYiyanChat(Base):
import qianfan import qianfan
key = json.loads(key) key = json.loads(key)
ak = key.get("yiyan_ak","") ak = key.get("yiyan_ak", "")
sk = key.get("yiyan_sk","") sk = key.get("yiyan_sk", "")
self.client = qianfan.ChatCompletion(ak=ak,sk=sk) self.client = qianfan.ChatCompletion(ak=ak, sk=sk)
self.model_name = model_name.lower() self.model_name = model_name.lower()
self.system = "" self.system = ""
@ -1151,7 +1152,8 @@ class BaiduYiyanChat(Base):
if system: if system:
self.system = system self.system = system
gen_conf["penalty_score"] = ( gen_conf["penalty_score"] = (
(gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2 (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
0)) / 2
) + 1 ) + 1
if "max_tokens" in gen_conf: if "max_tokens" in gen_conf:
gen_conf["max_output_tokens"] = gen_conf["max_tokens"] gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
@ -1174,7 +1176,8 @@ class BaiduYiyanChat(Base):
if system: if system:
self.system = system self.system = system
gen_conf["penalty_score"] = ( gen_conf["penalty_score"] = (
(gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2 (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
0)) / 2
) + 1 ) + 1
if "max_tokens" in gen_conf: if "max_tokens" in gen_conf:
gen_conf["max_output_tokens"] = gen_conf["max_tokens"] gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
@ -1415,4 +1418,3 @@ class GoogleChat(Base):
yield ans + "\n**ERROR**: " + str(e) yield ans + "\n**ERROR**: " + str(e)
yield response._chunks[-1].usage_metadata.total_token_count yield response._chunks[-1].usage_metadata.total_token_count