Fix/shared lock (#210)

This commit is contained in:
John Wang 2023-05-25 21:31:11 +08:00 committed by GitHub
parent 4ef6392de5
commit 1a5acf43aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 3 deletions

View File

@ -34,5 +34,9 @@ class DatasetIndexToolCallbackHandler(IndexToolCallbackHandler):
db.session.query(DocumentSegment).filter( db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset_id, DocumentSegment.dataset_id == self.dataset_id,
DocumentSegment.index_node_id == index_node_id DocumentSegment.index_node_id == index_node_id
).update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) ).update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False
)
db.session.commit()

View File

@ -1,14 +1,17 @@
import logging
from typing import Optional, List, Union, Tuple from typing import Optional, List, Union, Tuple
from langchain.callbacks import CallbackManager from langchain.callbacks import CallbackManager
from langchain.chat_models.base import BaseChatModel from langchain.chat_models.base import BaseChatModel
from langchain.llms import BaseLLM from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage
from requests.exceptions import ChunkedEncodingError
from core.constant import llm_constant from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \ from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
DifyStdOutCallbackHandler DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, PubHandler
from core.llm.error import LLMBadRequestError from core.llm.error import LLMBadRequestError
from core.llm.llm_builder import LLMBuilder from core.llm.llm_builder import LLMBuilder
from core.chain.main_chain_builder import MainChainBuilder from core.chain.main_chain_builder import MainChainBuilder
@ -84,6 +87,11 @@ class Completion:
) )
except ConversationTaskStoppedException: except ConversationTaskStoppedException:
return return
except ChunkedEncodingError as e:
# Interrupt by LLM (like OpenAI), handle it.
logging.warning(f'ChunkedEncodingError: {e}')
conversation_message_task.end()
return
@classmethod @classmethod
def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict, def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,

View File

@ -171,7 +171,7 @@ class ConversationMessageTask:
) )
if not by_stopped: if not by_stopped:
self._pub_handler.pub_end() self.end()
def update_provider_quota(self): def update_provider_quota(self):
llm_provider_service = LLMProviderService( llm_provider_service = LLMProviderService(
@ -268,6 +268,9 @@ class ConversationMessageTask:
total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def end(self):
self._pub_handler.pub_end()
class PubHandler: class PubHandler:
def __init__(self, user: Union[Account | EndUser], task_id: str, def __init__(self, user: Union[Account | EndUser], task_id: str,