Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-20 19:59:11 +08:00)
Support streaming output for OpenAI o1-preview and o1-mini (#10890)
This commit is contained in:
parent 1be8365684
commit 4d6b45427c
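For context: before this change, the OpenAI and OpenRouter providers did not stream o1 responses. They forced a blocking request and replayed the full answer as a single chunk ("block as stream"). The diff below deletes that workaround now that o1-preview and o1-mini can stream. A minimal, self-contained sketch of the old pattern (illustrative only, not the exact dify implementation; the function name and stop-token trimming are assumptions):

from collections.abc import Callable, Generator


def block_as_stream(blocking_call: Callable[[], str], stop: list[str] | None = None) -> Generator[str, None, None]:
    """Run a non-streaming request and replay the result as a one-chunk stream."""
    text = blocking_call()  # full completion text from the blocking call
    if stop:
        # emulate the stop-word trimming a real streamed response would apply
        for token in stop:
            cut = text.find(token)
            if cut != -1:
                text = text[:cut]
    yield text  # a single chunk carrying the whole answer


# Example: callers that iterate over a stream keep working, but receive one chunk.
for chunk in block_as_stream(lambda: "Hello, world! END", stop=["END"]):
    print(chunk)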
@@ -615,19 +615,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
 
         # o1 compatibility
-        block_as_stream = False
         if model.startswith("o1"):
             if "max_tokens" in model_parameters:
                 model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
                 del model_parameters["max_tokens"]
 
-            if stream:
-                block_as_stream = True
-                stream = False
-
-                if "stream_options" in extra_model_kwargs:
-                    del extra_model_kwargs["stream_options"]
-
             if "stop" in extra_model_kwargs:
                 del extra_model_kwargs["stop"]
 
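Note what stays: o1-family models take max_completion_tokens rather than max_tokens, so the remap above is retained. A standalone sketch of that retained behaviour (parameter values are illustrative):

model = "o1-mini"
model_parameters = {"max_tokens": 512, "temperature": 1.0}

# o1 models expect max_completion_tokens, so a user-supplied max_tokens is renamed.
if model.startswith("o1") and "max_tokens" in model_parameters:
    model_parameters["max_completion_tokens"] = model_parameters.pop("max_tokens")

print(model_parameters)  # {'temperature': 1.0, 'max_completion_tokens': 512}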
@@ -644,47 +636,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         if stream:
             return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
 
-        block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
-
-        if block_as_stream:
-            return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
-
-        return block_result
-
-    def _handle_chat_block_as_stream_response(
-        self,
-        block_result: LLMResult,
-        prompt_messages: list[PromptMessage],
-        stop: Optional[list[str]] = None,
-    ) -> Generator[LLMResultChunk, None, None]:
-        """
-        Handle llm chat response
-
-        :param model: model name
-        :param credentials: credentials
-        :param response: response
-        :param prompt_messages: prompt messages
-        :param tools: tools for tool calling
-        :param stop: stop words
-        :return: llm response chunk generator
-        """
-        text = block_result.message.content
-        text = cast(str, text)
-
-        if stop:
-            text = self.enforce_stop_tokens(text, stop)
-
-        yield LLMResultChunk(
-            model=block_result.model,
-            prompt_messages=prompt_messages,
-            system_fingerprint=block_result.system_fingerprint,
-            delta=LLMResultChunkDelta(
-                index=0,
-                message=AssistantPromptMessage(content=text),
-                finish_reason="stop",
-                usage=block_result.usage,
-            ),
-        )
+        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
 
     def _handle_chat_generate_response(
         self,
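With the block-as-stream helper removed, a stream=True request for o1 models goes through the regular _handle_chat_generate_stream_response path. A hedged usage sketch against the plain OpenAI Python SDK (not dify's model runtime; assumes OPENAI_API_KEY is set) of what the upstream API now allows:

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
stream = client.chat.completions.create(
    model="o1-mini",
    messages=[{"role": "user", "content": "Explain streaming in one sentence."}],
    max_completion_tokens=256,  # o1 models take this instead of max_tokens
    stream=True,  # genuinely streamed, chunk by chunk
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)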
@@ -45,18 +45,6 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
         user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
         self._update_credential(model, credentials)
-
-        block_as_stream = False
-        if model.startswith("openai/o1"):
-            block_as_stream = True
-            stop = None
-
-        # invoke block as stream
-        if stream and block_as_stream:
-            return self._generate_block_as_stream(
-                model, credentials, prompt_messages, model_parameters, tools, stop, user
-            )
-        else:
-            return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+        return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
 
     def _generate_block_as_stream(
@@ -69,9 +57,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
         stop: Optional[list[str]] = None,
         user: Optional[str] = None,
     ) -> Generator:
-        resp: LLMResult = super()._generate(
-            model, credentials, prompt_messages, model_parameters, tools, stop, False, user
-        )
+        resp = super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, False, user)
 
         yield LLMResultChunk(
             model=model,
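The OpenRouter provider gets the same simplification: the openai/o1 special case and its block-as-stream dispatch are gone, and _generate simply forwards the stream flag to the OpenAI-compatible base class. A toy illustration of that delegation pattern (not dify code; the class and method names here are invented for the example):

from collections.abc import Generator
from typing import Union


class CompatBase:
    """Stand-in for the OpenAI-compatible base implementation."""

    def _generate(self, model: str, prompt: str, stream: bool) -> Union[str, Generator[str, None, None]]:
        if stream:
            return (word + " " for word in f"echo of {prompt}".split())
        return f"echo of {prompt}"


class OpenRouterLike(CompatBase):
    def _generate(self, model: str, prompt: str, stream: bool) -> Union[str, Generator[str, None, None]]:
        # No "openai/o1" branch any more: streaming is forwarded like any other model.
        return super()._generate(model, prompt, stream)


if __name__ == "__main__":
    for piece in OpenRouterLike()._generate("openai/o1-mini", "hello", stream=True):
        print(piece, end="")
    print()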