From 41d4c5b424aa17b854104bc6a0664ec49a9a6a6a Mon Sep 17 00:00:00 2001
From: takatost
Date: Mon, 2 Oct 2023 10:19:26 +0800
Subject: [PATCH] fix: count down thread in completion db not commit (#1267)

---
 api/core/conversation_message_task.py |  4 +--
 .../models/llm/openai_model.py        |  8 +----
 api/services/completion_service.py    | 36 +++++++++----------
 3 files changed, 21 insertions(+), 27 deletions(-)

diff --git a/api/core/conversation_message_task.py b/api/core/conversation_message_task.py
index 49c9cba082..ae98f91a88 100644
--- a/api/core/conversation_message_task.py
+++ b/api/core/conversation_message_task.py
@@ -94,7 +94,7 @@ class ConversationMessageTask:
         if not self.conversation:
             self.is_new_conversation = True
             self.conversation = Conversation(
-                app_id=self.app_model_config.app_id,
+                app_id=self.app.id,
                 app_model_config_id=self.app_model_config.id,
                 model_provider=self.provider_name,
                 model_id=self.model_name,
@@ -115,7 +115,7 @@ class ConversationMessageTask:
             db.session.commit()
 
         self.message = Message(
-            app_id=self.app_model_config.app_id,
+            app_id=self.app.id,
             model_provider=self.provider_name,
             model_id=self.model_name,
             override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
diff --git a/api/core/model_providers/models/llm/openai_model.py b/api/core/model_providers/models/llm/openai_model.py
index d88d3af555..d65efc9641 100644
--- a/api/core/model_providers/models/llm/openai_model.py
+++ b/api/core/model_providers/models/llm/openai_model.py
@@ -106,13 +106,7 @@ class OpenAIModel(BaseLLM):
             raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")
 
         prompts = self._get_prompt_from_messages(messages)
-
-        try:
-            return self._client.generate([prompts], stop, callbacks)
-        finally:
-            thread_context = api_requestor._thread_context
-            if hasattr(thread_context, "session") and thread_context.session:
-                thread_context.session.close()
+        return self._client.generate([prompts], stop, callbacks)
 
     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """
diff --git a/api/services/completion_service.py b/api/services/completion_service.py
index 252c85e93f..d8ffd02ed4 100644
--- a/api/services/completion_service.py
+++ b/api/services/completion_service.py
@@ -155,7 +155,7 @@ class CompletionService:
         generate_worker_thread.start()
 
         # wait for 10 minutes to close the thread
-        cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
+        cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
 
         return cls.compact_response(pubsub, streaming)
 
@@ -210,25 +210,26 @@ class CompletionService:
         db.session.commit()
 
     @classmethod
-    def countdown_and_close(cls, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
+    def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
         # wait for 10 minutes to close the thread
         timeout = 600
 
         def close_pubsub():
-            sleep_iterations = 0
-            while sleep_iterations < timeout and worker_thread.is_alive():
-                if sleep_iterations > 0 and sleep_iterations % 10 == 0:
-                    PubHandler.ping(user, generate_task_id)
+            with flask_app.app_context():
+                sleep_iterations = 0
+                while sleep_iterations < timeout and worker_thread.is_alive():
+                    if sleep_iterations > 0 and sleep_iterations % 10 == 0:
+                        PubHandler.ping(user, generate_task_id)
 
-                time.sleep(1)
-                sleep_iterations += 1
+                    time.sleep(1)
+                    sleep_iterations += 1
 
-            if worker_thread.is_alive():
-                PubHandler.stop(user, generate_task_id)
-                try:
-                    pubsub.close()
-                except:
-                    pass
+                if worker_thread.is_alive():
+                    PubHandler.stop(user, generate_task_id)
+                    try:
+                        pubsub.close()
+                    except:
+                        pass
 
         countdown_thread = threading.Thread(target=close_pubsub)
         countdown_thread.start()
@@ -288,7 +289,7 @@ class CompletionService:
 
         generate_worker_thread.start()
 
-        cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
+        cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
 
         return cls.compact_response(pubsub, streaming)
 
@@ -313,15 +314,14 @@ class CompletionService:
         except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError,
                 ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError) as e:
-            db.session.rollback()
             PubHandler.pub_error(user, generate_task_id, e)
         except LLMAuthorizationError:
-            db.session.rollback()
             PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
         except Exception as e:
-            db.session.rollback()
             logging.exception("Unknown Error in completion")
             PubHandler.pub_error(user, generate_task_id, e)
+        finally:
+            db.session.commit()
 
     @classmethod
     def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
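
Note on the pattern behind this fix: a plain threading.Thread spawned from a request handler does not inherit Flask's application context, so database work done through Flask-SQLAlchemy's db.session inside it either fails outright or is never committed. The patch therefore passes the concrete app object (current_app._get_current_object()) into countdown_and_close(), pushes flask_app.app_context() inside the thread, and commits in a finally block in generate_worker. Below is a minimal, self-contained sketch of that pattern using plain Flask only; countdown_worker and the print call are illustrative stand-ins, not the actual Dify code.

import threading

from flask import Flask, current_app


def countdown_worker(flask_app: Flask) -> None:
    # A freshly started thread has no Flask application context of its own,
    # so anything that relies on current_app (including Flask-SQLAlchemy's
    # db.session) is unusable until one is pushed explicitly, which is what
    # the patched countdown_and_close() now does.
    with flask_app.app_context():
        print("app context active for:", current_app.name)
        # Work that touches db.session would go here and, mirroring the
        # generate_worker change, would end with db.session.commit() in a
        # finally block so the thread never exits with uncommitted changes.


app = Flask(__name__)

# The caller must hand the thread the concrete app object; the current_app
# proxy is context-local and cannot cross thread boundaries, hence
# current_app._get_current_object() at the patched call sites.
with app.app_context():
    worker = threading.Thread(
        target=countdown_worker,
        args=(current_app._get_current_object(),),
    )
    worker.start()
    worker.join()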