feat: support openai stream usage (#4140)
Parent: e7fe7ec0f6
Commit: d5d8b98d82
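
In short: when streaming, the request now sets stream_options={"include_usage": True}, so the OpenAI API appends one extra chunk at the very end of the stream whose choices list is empty and whose usage field carries the token counts for the whole request. A minimal sketch of what that looks like against the OpenAI Python SDK (the model name and prompt are placeholders, not part of this commit):

    from openai import OpenAI

    client = OpenAI()  # reads OPENAI_API_KEY from the environment

    stream = client.chat.completions.create(
        model="gpt-3.5-turbo",  # placeholder model
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
        stream_options={"include_usage": True},
    )

    for chunk in stream:
        if len(chunk.choices) == 0:
            # the final, usage-only chunk: no choices, just token counts
            if chunk.usage:
                print(chunk.usage.prompt_tokens, chunk.usage.completion_tokens)
            continue
        delta = chunk.choices[0].delta
        if delta.content:
            print(delta.content, end="")

The diff: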
@@ -378,6 +378,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         if user:
             extra_model_kwargs['user'] = user
 
+        if stream:
+            extra_model_kwargs['stream_options'] = {
+                "include_usage": True
+            }
+
         # text completion model
         response = client.completions.create(
             prompt=prompt_messages[0].content,
@@ -446,8 +451,24 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         :return: llm response chunk generator result
         """
         full_text = ''
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        final_chunk = LLMResultChunk(
+            model=model,
+            prompt_messages=prompt_messages,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=AssistantPromptMessage(content=''),
+            )
+        )
+
         for chunk in response:
             if len(chunk.choices) == 0:
+                if chunk.usage:
+                    # calculate num tokens
+                    prompt_tokens = chunk.usage.prompt_tokens
+                    completion_tokens = chunk.usage.completion_tokens
                 continue
 
             delta = chunk.choices[0]
@@ -464,20 +485,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
             full_text += text
 
             if delta.finish_reason is not None:
-                # calculate num tokens
-                if chunk.usage:
-                    # transform usage
-                    prompt_tokens = chunk.usage.prompt_tokens
-                    completion_tokens = chunk.usage.completion_tokens
-                else:
-                    # calculate num tokens
-                    prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
-                    completion_tokens = self._num_tokens_from_string(model, full_text)
-
-                # transform usage
-                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-                yield LLMResultChunk(
+                final_chunk = LLMResultChunk(
                     model=chunk.model,
                     prompt_messages=prompt_messages,
                     system_fingerprint=chunk.system_fingerprint,
@@ -485,7 +493,6 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                         index=delta.index,
                         message=assistant_prompt_message,
                         finish_reason=delta.finish_reason,
-                        usage=usage
                     )
                 )
             else:
@@ -499,6 +506,19 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                     )
                 )
 
+        if not prompt_tokens:
+            prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
+
+        if not completion_tokens:
+            completion_tokens = self._num_tokens_from_string(model, full_text)
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        final_chunk.delta.usage = usage
+
+        yield final_chunk
+
     def _chat_generate(self, model: str, credentials: dict,
                        prompt_messages: list[PromptMessage], model_parameters: dict,
                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
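
The restructuring in the hunks above (mirrored for the chat generator in the hunks below) follows one pattern: usage is no longer attached to the chunk that carries finish_reason; instead a held-back final_chunk is filled in either from the API's usage-only chunk or from a local token count, and yielded once the stream is exhausted. A self-contained sketch of that pattern, using simplified stand-in types rather than dify's LLMResultChunk:

    from dataclasses import dataclass
    from typing import Callable, Iterable, Iterator, Optional

    @dataclass
    class Usage:
        prompt_tokens: int
        completion_tokens: int

    @dataclass
    class Chunk:
        text: str = ''
        usage: Optional[Usage] = None

    def stream_with_final_usage(chunks: Iterable[Chunk],
                                prompt_text: str,
                                count_tokens: Callable[[str], int]) -> Iterator[Chunk]:
        prompt_tokens = 0
        completion_tokens = 0
        full_text = ''
        final_chunk = Chunk()  # held back until the stream ends

        for chunk in chunks:
            if chunk.usage is not None:
                # usage-only chunk emitted when include_usage is enabled
                prompt_tokens = chunk.usage.prompt_tokens
                completion_tokens = chunk.usage.completion_tokens
                continue
            full_text += chunk.text
            yield chunk

        # fall back to a local estimate if the server reported nothing
        if not prompt_tokens:
            prompt_tokens = count_tokens(prompt_text)
        if not completion_tokens:
            completion_tokens = count_tokens(full_text)

        final_chunk.usage = Usage(prompt_tokens, completion_tokens)
        yield final_chunk

Yielding usage on a dedicated trailing chunk keeps content chunks flowing without buffering, and the `if not ...` guards let the API-reported numbers win whenever they arrived.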
@@ -531,6 +551,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
 
                 model_parameters["response_format"] = response_format
 
+
         extra_model_kwargs = {}
 
         if tools:
@@ -547,6 +568,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         if user:
             extra_model_kwargs['user'] = user
 
+        if stream:
+            extra_model_kwargs['stream_options'] = {
+                'include_usage': True
+            }
+
         # clear illegal prompt messages
         prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
 
@@ -630,8 +656,24 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         """
         full_assistant_content = ''
         delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
+        prompt_tokens = 0
+        completion_tokens = 0
+        final_tool_calls = []
+        final_chunk = LLMResultChunk(
+            model=model,
+            prompt_messages=prompt_messages,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=AssistantPromptMessage(content=''),
+            )
+        )
+
         for chunk in response:
             if len(chunk.choices) == 0:
+                if chunk.usage:
+                    # calculate num tokens
+                    prompt_tokens = chunk.usage.prompt_tokens
+                    completion_tokens = chunk.usage.completion_tokens
                 continue
 
             delta = chunk.choices[0]
@@ -667,6 +709,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                     # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
                     function_call = self._extract_response_function_call(assistant_message_function_call)
                     tool_calls = [function_call] if function_call else []
+                    if tool_calls:
+                        final_tool_calls.extend(tool_calls)
 
             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
@@ -677,19 +721,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
             full_assistant_content += delta.delta.content if delta.delta.content else ''
 
             if has_finish_reason:
-                # calculate num tokens
-                prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools)
-
-                full_assistant_prompt_message = AssistantPromptMessage(
-                    content=full_assistant_content,
-                    tool_calls=tool_calls
-                )
-                completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message])
-
-                # transform usage
-                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-                yield LLMResultChunk(
+                final_chunk = LLMResultChunk(
                     model=chunk.model,
                     prompt_messages=prompt_messages,
                     system_fingerprint=chunk.system_fingerprint,
||||||
@ -697,7 +729,6 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
|||||||
index=delta.index,
|
index=delta.index,
|
||||||
message=assistant_prompt_message,
|
message=assistant_prompt_message,
|
||||||
finish_reason=delta.finish_reason,
|
finish_reason=delta.finish_reason,
|
||||||
usage=usage
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -711,6 +742,22 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not prompt_tokens:
|
||||||
|
prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools)
|
||||||
|
|
||||||
|
if not completion_tokens:
|
||||||
|
full_assistant_prompt_message = AssistantPromptMessage(
|
||||||
|
content=full_assistant_content,
|
||||||
|
tool_calls=final_tool_calls
|
||||||
|
)
|
||||||
|
completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message])
|
||||||
|
|
||||||
|
# transform usage
|
||||||
|
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||||
|
final_chunk.delta.usage = usage
|
||||||
|
|
||||||
|
yield final_chunk
|
||||||
|
|
||||||
def _extract_response_tool_calls(self,
|
def _extract_response_tool_calls(self,
|
||||||
response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
|
response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
|
||||||
-> list[AssistantPromptMessage.ToolCall]:
|
-> list[AssistantPromptMessage.ToolCall]:
|
||||||
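
A design note on the chat variant: tool calls are accumulated into final_tool_calls as they stream past, so when the server never reports usage, the local completion_tokens fallback is computed over the full assistant message including its tool calls, not just the text content.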
@@ -9,7 +9,7 @@ flask-restful~=0.3.10
 flask-cors~=4.0.0
 gunicorn~=22.0.0
 gevent~=23.9.1
-openai~=1.13.3
+openai~=1.26.0
 tiktoken~=0.6.0
 psycopg2-binary~=2.9.6
 pycryptodome==3.19.1
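
The dependency bump appears to be a prerequisite rather than incidental: the stream_options parameter only landed in the openai Python SDK around the 1.26 line, so on openai~=1.13.3 the client would not accept the new argument.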