fix: detached model in completion thread (#1269)

This commit is contained in:
takatost 2023-10-02 22:27:25 +08:00 committed by GitHub
parent 41d4c5b424
commit 373e90ee6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 21 deletions

View File

@ -132,8 +132,6 @@ class BaseLLM(BaseProviderModel):
if self.deduct_quota: if self.deduct_quota:
self.model_provider.check_quota_over_limit() self.model_provider.check_quota_over_limit()
db.session.commit()
if not callbacks: if not callbacks:
callbacks = self.callbacks callbacks = self.callbacks
else: else:

View File

@ -3,7 +3,7 @@ import logging
import threading import threading
import time import time
import uuid import uuid
from typing import Generator, Union, Any from typing import Generator, Union, Any, Optional
from flask import current_app, Flask from flask import current_app, Flask
from redis.client import PubSub from redis.client import PubSub
@ -141,12 +141,12 @@ class CompletionService:
generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={ generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'generate_task_id': generate_task_id, 'generate_task_id': generate_task_id,
'app_model': app_model, 'detached_app_model': app_model,
'app_model_config': app_model_config, 'app_model_config': app_model_config,
'query': query, 'query': query,
'inputs': inputs, 'inputs': inputs,
'user': user, 'detached_user': user,
'conversation': conversation, 'detached_conversation': conversation,
'streaming': streaming, 'streaming': streaming,
'is_model_config_override': is_model_config_override, 'is_model_config_override': is_model_config_override,
'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev' 'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev'
@ -171,18 +171,22 @@ class CompletionService:
return user return user
@classmethod @classmethod
def generate_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App, app_model_config: AppModelConfig, def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App, app_model_config: AppModelConfig,
query: str, inputs: dict, user: Union[Account, EndUser], query: str, inputs: dict, detached_user: Union[Account, EndUser],
conversation: Conversation, streaming: bool, is_model_config_override: bool, detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
retriever_from: str = 'dev'): retriever_from: str = 'dev'):
with flask_app.app_context(): with flask_app.app_context():
# re-attach the model objects to the current session: they arrive here detached from the original session
user = db.session.merge(detached_user)
app_model = db.session.merge(detached_app_model)
if detached_conversation:
conversation = db.session.merge(detached_conversation)
else:
conversation = None
try: try:
if conversation:
# fixed the state of the conversation object when it detached from the original session
conversation = db.session.query(Conversation).filter_by(id=conversation.id).first()
# run # run
Completion.generate( Completion.generate(
task_id=generate_task_id, task_id=generate_task_id,
app=app_model, app=app_model,
@ -210,12 +214,14 @@ class CompletionService:
db.session.commit() db.session.commit()
@classmethod @classmethod
def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, user, generate_task_id) -> threading.Thread: def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user, generate_task_id) -> threading.Thread:
# wait for 10 minutes to close the thread # wait for 10 minutes to close the thread
timeout = 600 timeout = 600
def close_pubsub(): def close_pubsub():
with flask_app.app_context(): with flask_app.app_context():
user = db.session.merge(detached_user)
sleep_iterations = 0 sleep_iterations = 0
while sleep_iterations < timeout and worker_thread.is_alive(): while sleep_iterations < timeout and worker_thread.is_alive():
if sleep_iterations > 0 and sleep_iterations % 10 == 0: if sleep_iterations > 0 and sleep_iterations % 10 == 0:
@ -279,11 +285,11 @@ class CompletionService:
generate_worker_thread = threading.Thread(target=cls.generate_more_like_this_worker, kwargs={ generate_worker_thread = threading.Thread(target=cls.generate_more_like_this_worker, kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'generate_task_id': generate_task_id, 'generate_task_id': generate_task_id,
'app_model': app_model, 'detached_app_model': app_model,
'app_model_config': app_model_config, 'app_model_config': app_model_config,
'message': message, 'detached_message': message,
'pre_prompt': pre_prompt, 'pre_prompt': pre_prompt,
'user': user, 'detached_user': user,
'streaming': streaming 'streaming': streaming
}) })
@ -294,10 +300,15 @@ class CompletionService:
return cls.compact_response(pubsub, streaming) return cls.compact_response(pubsub, streaming)
@classmethod @classmethod
def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App, def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
app_model_config: AppModelConfig, message: Message, pre_prompt: str, app_model_config: AppModelConfig, detached_message: Message, pre_prompt: str,
user: Union[Account, EndUser], streaming: bool): detached_user: Union[Account, EndUser], streaming: bool):
with flask_app.app_context(): with flask_app.app_context():
# re-attach the model objects to the current session: they arrive here detached from the original session
user = db.session.merge(detached_user)
app_model = db.session.merge(detached_app_model)
message = db.session.merge(detached_message)
try: try:
# run # run
Completion.generate_more_like_this( Completion.generate_more_like_this(