mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 06:29:03 +08:00
fix: count down thread in completion db not commit (#1267)
This commit is contained in:
parent
86a9dea428
commit
41d4c5b424
@ -94,7 +94,7 @@ class ConversationMessageTask:
|
|||||||
if not self.conversation:
|
if not self.conversation:
|
||||||
self.is_new_conversation = True
|
self.is_new_conversation = True
|
||||||
self.conversation = Conversation(
|
self.conversation = Conversation(
|
||||||
app_id=self.app_model_config.app_id,
|
app_id=self.app.id,
|
||||||
app_model_config_id=self.app_model_config.id,
|
app_model_config_id=self.app_model_config.id,
|
||||||
model_provider=self.provider_name,
|
model_provider=self.provider_name,
|
||||||
model_id=self.model_name,
|
model_id=self.model_name,
|
||||||
@ -115,7 +115,7 @@ class ConversationMessageTask:
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
self.message = Message(
|
self.message = Message(
|
||||||
app_id=self.app_model_config.app_id,
|
app_id=self.app.id,
|
||||||
model_provider=self.provider_name,
|
model_provider=self.provider_name,
|
||||||
model_id=self.model_name,
|
model_id=self.model_name,
|
||||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||||
|
@ -106,13 +106,7 @@ class OpenAIModel(BaseLLM):
|
|||||||
raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")
|
raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")
|
||||||
|
|
||||||
prompts = self._get_prompt_from_messages(messages)
|
prompts = self._get_prompt_from_messages(messages)
|
||||||
|
return self._client.generate([prompts], stop, callbacks)
|
||||||
try:
|
|
||||||
return self._client.generate([prompts], stop, callbacks)
|
|
||||||
finally:
|
|
||||||
thread_context = api_requestor._thread_context
|
|
||||||
if hasattr(thread_context, "session") and thread_context.session:
|
|
||||||
thread_context.session.close()
|
|
||||||
|
|
||||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||||
"""
|
"""
|
||||||
|
@ -155,7 +155,7 @@ class CompletionService:
|
|||||||
generate_worker_thread.start()
|
generate_worker_thread.start()
|
||||||
|
|
||||||
# wait for 10 minutes to close the thread
|
# wait for 10 minutes to close the thread
|
||||||
cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
|
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
|
||||||
|
|
||||||
return cls.compact_response(pubsub, streaming)
|
return cls.compact_response(pubsub, streaming)
|
||||||
|
|
||||||
@ -210,25 +210,26 @@ class CompletionService:
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def countdown_and_close(cls, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
|
def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
|
||||||
# wait for 10 minutes to close the thread
|
# wait for 10 minutes to close the thread
|
||||||
timeout = 600
|
timeout = 600
|
||||||
|
|
||||||
def close_pubsub():
|
def close_pubsub():
|
||||||
sleep_iterations = 0
|
with flask_app.app_context():
|
||||||
while sleep_iterations < timeout and worker_thread.is_alive():
|
sleep_iterations = 0
|
||||||
if sleep_iterations > 0 and sleep_iterations % 10 == 0:
|
while sleep_iterations < timeout and worker_thread.is_alive():
|
||||||
PubHandler.ping(user, generate_task_id)
|
if sleep_iterations > 0 and sleep_iterations % 10 == 0:
|
||||||
|
PubHandler.ping(user, generate_task_id)
|
||||||
|
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
sleep_iterations += 1
|
sleep_iterations += 1
|
||||||
|
|
||||||
if worker_thread.is_alive():
|
if worker_thread.is_alive():
|
||||||
PubHandler.stop(user, generate_task_id)
|
PubHandler.stop(user, generate_task_id)
|
||||||
try:
|
try:
|
||||||
pubsub.close()
|
pubsub.close()
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
countdown_thread = threading.Thread(target=close_pubsub)
|
countdown_thread = threading.Thread(target=close_pubsub)
|
||||||
countdown_thread.start()
|
countdown_thread.start()
|
||||||
@ -288,7 +289,7 @@ class CompletionService:
|
|||||||
|
|
||||||
generate_worker_thread.start()
|
generate_worker_thread.start()
|
||||||
|
|
||||||
cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
|
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
|
||||||
|
|
||||||
return cls.compact_response(pubsub, streaming)
|
return cls.compact_response(pubsub, streaming)
|
||||||
|
|
||||||
@ -313,15 +314,14 @@ class CompletionService:
|
|||||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
|
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
|
||||||
ModelCurrentlyNotSupportError) as e:
|
ModelCurrentlyNotSupportError) as e:
|
||||||
db.session.rollback()
|
|
||||||
PubHandler.pub_error(user, generate_task_id, e)
|
PubHandler.pub_error(user, generate_task_id, e)
|
||||||
except LLMAuthorizationError:
|
except LLMAuthorizationError:
|
||||||
db.session.rollback()
|
|
||||||
PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
|
PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.session.rollback()
|
|
||||||
logging.exception("Unknown Error in completion")
|
logging.exception("Unknown Error in completion")
|
||||||
PubHandler.pub_error(user, generate_task_id, e)
|
PubHandler.pub_error(user, generate_task_id, e)
|
||||||
|
finally:
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
|
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user