From de584807e1f7b9faea04cfa0e5c7f54c8cd0e1a1 Mon Sep 17 00:00:00 2001
From: Chenhe Gu
Date: Fri, 5 Jan 2024 15:03:54 +0800
Subject: [PATCH] fix streaming (#1944)

---
 .../openai_api_compatible/llm/llm.py        | 12 +++---------
 .../openai_api_compatible/test_llm.py       | 12 ++++++------
 .../test_text_embedding.py                  | 18 +++++++-----------
 3 files changed, 16 insertions(+), 26 deletions(-)

diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
index 71c15c7f88..cf694b940b 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
@@ -337,9 +337,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 )
             )
 
-        for chunk in response.iter_content(chunk_size=2048):
+        for chunk in response.iter_lines(decode_unicode=True, delimiter='\n\n'):
             if chunk:
-                decoded_chunk = chunk.decode('utf-8').strip().lstrip('data: ').lstrip()
+                decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
 
                 chunk_json = None
                 try:
@@ -356,7 +356,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     continue
 
                 choice = chunk_json['choices'][0]
-                chunk_index = choice['index'] if 'index' in choice else chunk_index
+                chunk_index += 1
 
                 if 'delta' in choice:
                     delta = choice['delta']
@@ -408,12 +408,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                             message=assistant_prompt_message,
                         )
                     )
-            else:
-                yield create_final_llm_result_chunk(
-                    index=chunk_index + 1,
-                    message=AssistantPromptMessage(content=""),
-                    finish_reason="End of stream."
-                )
 
             chunk_index += 1
 
diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py
index 1d53b7a3f0..8be19b7c6c 100644
--- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py
+++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py
@@ -22,7 +22,7 @@ def test_validate_credentials():
             model='mistralai/Mixtral-8x7B-Instruct-v0.1',
             credentials={
                 'api_key': 'invalid_key',
-                'endpoint_url': 'https://api.together.xyz/v1/chat/completions',
+                'endpoint_url': 'https://api.together.xyz/v1/',
                 'mode': 'chat'
             }
         )
@@ -31,7 +31,7 @@
         model='mistralai/Mixtral-8x7B-Instruct-v0.1',
         credentials={
             'api_key': os.environ.get('TOGETHER_API_KEY'),
-            'endpoint_url': 'https://api.together.xyz/v1/chat/completions',
+            'endpoint_url': 'https://api.together.xyz/v1/',
             'mode': 'chat'
         }
     )
@@ -43,7 +43,7 @@ def test_invoke_model():
         model='mistralai/Mixtral-8x7B-Instruct-v0.1',
         credentials={
             'api_key': os.environ.get('TOGETHER_API_KEY'),
-            'endpoint_url': 'https://api.together.xyz/v1/completions',
+            'endpoint_url': 'https://api.together.xyz/v1/',
             'mode': 'completion'
         },
         prompt_messages=[
@@ -74,7 +74,7 @@ def test_invoke_stream_model():
         model='mistralai/Mixtral-8x7B-Instruct-v0.1',
         credentials={
             'api_key': os.environ.get('TOGETHER_API_KEY'),
-            'endpoint_url': 'https://api.together.xyz/v1/chat/completions',
+            'endpoint_url': 'https://api.together.xyz/v1/',
             'mode': 'chat'
         },
         prompt_messages=[
@@ -110,7 +110,7 @@ def test_invoke_chat_model_with_tools():
         model='gpt-3.5-turbo',
         credentials={
             'api_key': os.environ.get('OPENAI_API_KEY'),
-            'endpoint_url': 'https://api.openai.com/v1/chat/completions',
+            'endpoint_url': 'https://api.openai.com/v1/',
             'mode': 'chat'
         },
         prompt_messages=[
@@ -165,7 +165,7 @@ def test_get_num_tokens():
         model='mistralai/Mixtral-8x7B-Instruct-v0.1',
         credentials={
             'api_key': os.environ.get('OPENAI_API_KEY'),
-            'endpoint_url': 'https://api.openai.com/v1/chat/completions'
+            'endpoint_url': 'https://api.openai.com/v1/'
         },
         prompt_messages=[
             SystemPromptMessage(
diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py
index d8c8a26a10..fbaa322881 100644
--- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py
+++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py
@@ -18,9 +18,8 @@ def test_validate_credentials():
             model='text-embedding-ada-002',
             credentials={
                 'api_key': 'invalid_key',
-                'endpoint_url': 'https://api.openai.com/v1/embeddings',
-                'context_size': 8184,
-                'max_chunks': 32
+                'endpoint_url': 'https://api.openai.com/v1/',
+                'context_size': 8184
             }
         )
 
@@ -29,9 +28,8 @@
         model='text-embedding-ada-002',
        credentials={
             'api_key': os.environ.get('OPENAI_API_KEY'),
-            'endpoint_url': 'https://api.openai.com/v1/embeddings',
-            'context_size': 8184,
-            'max_chunks': 32
+            'endpoint_url': 'https://api.openai.com/v1/',
+            'context_size': 8184
         }
     )
 
@@ -43,9 +41,8 @@ def test_invoke_model():
         model='text-embedding-ada-002',
         credentials={
             'api_key': os.environ.get('OPENAI_API_KEY'),
-            'endpoint_url': 'https://api.openai.com/v1/embeddings',
-            'context_size': 8184,
-            'max_chunks': 32
+            'endpoint_url': 'https://api.openai.com/v1/',
+            'context_size': 8184
         },
         texts=[
             "hello",
@@ -67,8 +64,7 @@ def test_get_num_tokens():
         credentials={
             'api_key': os.environ.get('OPENAI_API_KEY'),
             'endpoint_url': 'https://api.openai.com/v1/embeddings',
-            'context_size': 8184,
-            'max_chunks': 32
+            'context_size': 8184
         },
         texts=[
             "hello",
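
Note (not part of the patch): a minimal standalone sketch of the SSE-style parsing this change switches to. requests' iter_lines(decode_unicode=True, delimiter='\n\n') yields one complete server-sent event per iteration, so each event can be stripped of its 'data: ' prefix and parsed as JSON, instead of decoding arbitrary 2048-byte windows from iter_content. The endpoint URL, model name, and payload below are illustrative placeholders, not values taken from this repository.

import json

import requests

# Placeholder endpoint and payload; any OpenAI-compatible streaming chat endpoint would do.
response = requests.post(
    'https://example.invalid/v1/chat/completions',
    json={
        'model': 'example-model',
        'messages': [{'role': 'user', 'content': 'Hello'}],
        'stream': True,
    },
    stream=True,
)

chunk_index = 0
# delimiter='\n\n' splits the stream on the blank line that terminates each SSE event,
# so every iteration yields one full 'data: {...}' payload rather than a byte window.
for chunk in response.iter_lines(decode_unicode=True, delimiter='\n\n'):
    if not chunk:
        continue
    # Mirrors the patched code: drop the 'data: ' prefix, then parse the JSON body.
    decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
    if decoded_chunk == '[DONE]':
        break  # OpenAI-style end-of-stream marker
    try:
        chunk_json = json.loads(decoded_chunk)
    except json.JSONDecodeError:
        break  # non-JSON payload: treat as end of stream, as the patched handler does
    chunk_index += 1
    delta = chunk_json['choices'][0].get('delta', {})
    print(chunk_index, delta.get('content', ''))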