mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-11 17:19:08 +08:00
chore: apply ruff's pyupgrade linter rules to modernize Python code with targeted version (#2419)
This commit is contained in:
parent
589099a005
commit
063191889d
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
|
||||
import dotenv
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
|
@ -1,7 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
import flask_login
|
||||
from flask import Response, stream_with_context
|
||||
@ -169,8 +169,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import current_user
|
||||
@ -246,8 +247,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import flask_login
|
||||
from flask import current_app, request
|
||||
from flask_restful import Resource, reqparse
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import flask_restful
|
||||
from flask import current_app, request
|
||||
from flask_login import current_user
|
||||
|
@ -1,6 +1,4 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
@ -71,7 +69,7 @@ class DocumentResource(Resource):
|
||||
|
||||
return document
|
||||
|
||||
def get_batch_documents(self, dataset_id: str, batch: str) -> List[Document]:
|
||||
def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
|
@ -1,8 +1,8 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Generator, Union
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import current_user
|
||||
@ -164,8 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from flask_restful import marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
|
||||
from flask_login import current_user
|
||||
|
@ -1,7 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import current_user
|
||||
@ -123,8 +123,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
|
||||
from flask import current_app
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, fields, marshal_with
|
||||
from sqlalchemy import and_
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from functools import wraps
|
||||
|
||||
from flask import current_app, request
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import current_app
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, abort, fields, marshal_with, reqparse
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
from functools import wraps
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
|
||||
from flask import current_app
|
||||
|
@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_restful import reqparse
|
||||
@ -182,8 +183,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import request
|
||||
from flask_restful import marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
|
||||
from flask import current_app
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
|
@ -1,7 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_restful import reqparse
|
||||
@ -154,8 +154,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_restful import marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
@ -160,8 +160,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import uuid
|
||||
|
||||
from flask import request
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
from flask import current_app
|
||||
from flask_restful import fields, marshal_with
|
||||
|
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from functools import wraps
|
||||
|
||||
from flask import request
|
||||
|
@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
@ -17,7 +17,7 @@ class AgentLLMCallback(Callback):
|
||||
|
||||
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Before invoke callback
|
||||
@ -38,7 +38,7 @@ class AgentLLMCallback(Callback):
|
||||
|
||||
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None):
|
||||
"""
|
||||
On new chunk callback
|
||||
@ -58,7 +58,7 @@ class AgentLLMCallback(Callback):
|
||||
|
||||
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
After invoke callback
|
||||
@ -80,7 +80,7 @@ class AgentLLMCallback(Callback):
|
||||
|
||||
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Invoke error callback
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import List, cast
|
||||
from typing import cast
|
||||
|
||||
from core.entities.application_entities import ModelConfigEntity
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
@ -8,7 +8,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
|
||||
|
||||
class CalcTokenMixin:
|
||||
|
||||
def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int:
|
||||
def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
|
||||
"""
|
||||
Got the rest tokens available for the model after excluding messages tokens and completion max tokens
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
|
||||
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
|
||||
@ -42,7 +43,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
@ -85,7 +86,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
|
||||
def real_plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
@ -146,7 +147,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
|
||||
async def aplan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
@ -158,7 +159,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
model_config: ModelConfigEntity,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
|
||||
system_message: Optional[SystemMessage] = SystemMessage(
|
||||
content="You are a helpful AI assistant."
|
||||
),
|
||||
|
@ -1,4 +1,5 @@
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
|
||||
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
|
||||
@ -51,7 +52,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
model_config: ModelConfigEntity,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
|
||||
system_message: Optional[SystemMessage] = SystemMessage(
|
||||
content="You are a helpful AI assistant."
|
||||
),
|
||||
@ -125,7 +126,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
@ -207,7 +208,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
def return_stopped_response(
|
||||
self,
|
||||
early_stopping_method: str,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
**kwargs: Any,
|
||||
) -> AgentFinish:
|
||||
try:
|
||||
@ -215,7 +216,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
except ValueError:
|
||||
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]:
|
||||
def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]:
|
||||
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
|
||||
rest_tokens = self.get_message_rest_tokens(
|
||||
self.model_config,
|
||||
@ -264,7 +265,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
return new_messages
|
||||
|
||||
def predict_new_summary(
|
||||
self, messages: List[BaseMessage], existing_summary: str
|
||||
self, messages: list[BaseMessage], existing_summary: str
|
||||
) -> str:
|
||||
new_lines = get_buffer_string(
|
||||
messages,
|
||||
@ -275,7 +276,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
|
||||
return chain.predict(summary=existing_summary, new_lines=new_lines)
|
||||
|
||||
def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage], **kwargs) -> int:
|
||||
def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
|
@ -1,5 +1,6 @@
|
||||
import re
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Union, cast
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from langchain import BasePromptTemplate, PromptTemplate
|
||||
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
|
||||
@ -68,7 +69,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
@ -125,8 +126,8 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
input_variables: Optional[list[str]] = None,
|
||||
memory_prompts: Optional[list[BasePromptTemplate]] = None,
|
||||
) -> BasePromptTemplate:
|
||||
tool_strings = []
|
||||
for tool in tools:
|
||||
@ -153,7 +154,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
tools: Sequence[BaseTool],
|
||||
prefix: str = PREFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
input_variables: Optional[list[str]] = None,
|
||||
) -> PromptTemplate:
|
||||
"""Create prompt in the style of the zero shot agent.
|
||||
|
||||
@ -180,7 +181,7 @@ Thought: {agent_scratchpad}
|
||||
return PromptTemplate(template=template, input_variables=input_variables)
|
||||
|
||||
def _construct_scratchpad(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
self, intermediate_steps: list[tuple[AgentAction, str]]
|
||||
) -> str:
|
||||
agent_scratchpad = ""
|
||||
for action, observation in intermediate_steps:
|
||||
@ -213,8 +214,8 @@ Thought: {agent_scratchpad}
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
input_variables: Optional[list[str]] = None,
|
||||
memory_prompts: Optional[list[BasePromptTemplate]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
|
@ -1,5 +1,6 @@
|
||||
import re
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Union, cast
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from langchain import BasePromptTemplate, PromptTemplate
|
||||
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
|
||||
@ -82,7 +83,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
@ -127,7 +128,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
|
||||
"I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
|
||||
def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs):
|
||||
if len(intermediate_steps) >= 2 and self.summary_model_config:
|
||||
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
|
||||
should_summary_messages = [AIMessage(content=observation)
|
||||
@ -154,7 +155,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
|
||||
|
||||
def predict_new_summary(
|
||||
self, messages: List[BaseMessage], existing_summary: str
|
||||
self, messages: list[BaseMessage], existing_summary: str
|
||||
) -> str:
|
||||
new_lines = get_buffer_string(
|
||||
messages,
|
||||
@ -173,8 +174,8 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
input_variables: Optional[list[str]] = None,
|
||||
memory_prompts: Optional[list[BasePromptTemplate]] = None,
|
||||
) -> BasePromptTemplate:
|
||||
tool_strings = []
|
||||
for tool in tools:
|
||||
@ -200,7 +201,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
tools: Sequence[BaseTool],
|
||||
prefix: str = PREFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
input_variables: Optional[list[str]] = None,
|
||||
) -> PromptTemplate:
|
||||
"""Create prompt in the style of the zero shot agent.
|
||||
|
||||
@ -227,7 +228,7 @@ Thought: {agent_scratchpad}
|
||||
return PromptTemplate(template=template, input_variables=input_variables)
|
||||
|
||||
def _construct_scratchpad(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
self, intermediate_steps: list[tuple[AgentAction, str]]
|
||||
) -> str:
|
||||
agent_scratchpad = ""
|
||||
for action, observation in intermediate_steps:
|
||||
@ -260,8 +261,8 @@ Thought: {agent_scratchpad}
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
input_variables: Optional[list[str]] = None,
|
||||
memory_prompts: Optional[list[BasePromptTemplate]] = None,
|
||||
agent_llm_callback: Optional[AgentLLMCallback] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
|
@ -1,5 +1,6 @@
|
||||
import time
|
||||
from typing import Generator, List, Optional, Tuple, Union, cast
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
|
||||
from core.entities.application_entities import (
|
||||
@ -84,7 +85,7 @@ class AppRunner:
|
||||
return rest_tokens
|
||||
|
||||
def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
|
||||
prompt_messages: List[PromptMessage]):
|
||||
prompt_messages: list[PromptMessage]):
|
||||
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
@ -126,7 +127,7 @@ class AppRunner:
|
||||
query: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
memory: Optional[TokenBufferMemory] = None) \
|
||||
-> Tuple[List[PromptMessage], Optional[List[str]]]:
|
||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
"""
|
||||
Organize prompt messages
|
||||
:param context:
|
||||
@ -295,7 +296,7 @@ class AppRunner:
|
||||
tenant_id: str,
|
||||
app_orchestration_config_entity: AppOrchestrationConfigEntity,
|
||||
inputs: dict,
|
||||
query: str) -> Tuple[bool, dict, str]:
|
||||
query: str) -> tuple[bool, dict, str]:
|
||||
"""
|
||||
Process sensitive_word_avoidance.
|
||||
:param app_id: app id
|
||||
|
@ -1,7 +1,8 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Generator, Optional, Union, cast
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -118,7 +119,7 @@ class GenerateTaskPipeline:
|
||||
}
|
||||
|
||||
self._task_state.llm_result.message.content = annotation.content
|
||||
elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
|
||||
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
|
||||
if isinstance(event, QueueMessageEndEvent):
|
||||
self._task_state.llm_result = event.llm_result
|
||||
else:
|
||||
@ -201,7 +202,7 @@ class GenerateTaskPipeline:
|
||||
data = self._error_to_stream_response_data(self._handle_error(event))
|
||||
yield self._yield_response(data)
|
||||
break
|
||||
elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
|
||||
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
|
||||
if isinstance(event, QueueMessageEndEvent):
|
||||
self._task_state.llm_result = event.llm_result
|
||||
else:
|
||||
@ -353,7 +354,7 @@ class GenerateTaskPipeline:
|
||||
|
||||
yield self._yield_response(response)
|
||||
|
||||
elif isinstance(event, (QueueMessageEvent, QueueAgentMessageEvent)):
|
||||
elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent):
|
||||
chunk = event.chunk
|
||||
delta_text = chunk.delta.message.content
|
||||
if delta_text is None:
|
||||
|
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import BaseModel
|
||||
@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ModerationRule(BaseModel):
|
||||
type: str
|
||||
config: Dict[str, Any]
|
||||
config: dict[str, Any]
|
||||
|
||||
|
||||
class OutputModerationHandler(BaseModel):
|
||||
|
@ -2,7 +2,8 @@ import json
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Any, Generator, Optional, Tuple, Union, cast
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@ -585,7 +586,7 @@ class ApplicationManager:
|
||||
return AppOrchestrationConfigEntity(**properties)
|
||||
|
||||
def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
|
||||
-> Tuple[Conversation, Message]:
|
||||
-> tuple[Conversation, Message]:
|
||||
"""
|
||||
Initialize generate records
|
||||
:param application_generate_entity: application generate entity
|
||||
|
@ -1,7 +1,8 @@
|
||||
import queue
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from enum import Enum
|
||||
from typing import Any, Generator
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import DeclarativeMeta
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
@ -37,7 +37,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
self._message_agent_thought = None
|
||||
|
||||
@property
|
||||
def agent_loops(self) -> List[AgentLoop]:
|
||||
def agent_loops(self) -> list[AgentLoop]:
|
||||
return self._agent_loops
|
||||
|
||||
def clear_agent_loops(self) -> None:
|
||||
@ -95,14 +95,14 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
pass
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@ -120,7 +120,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.input import print_text
|
||||
@ -21,7 +21,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
|
||||
def on_tool_start(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_inputs: Dict[str, Any],
|
||||
tool_inputs: dict[str, Any],
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)
|
||||
@ -29,7 +29,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
|
||||
def on_tool_end(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_inputs: Dict[str, Any],
|
||||
tool_inputs: dict[str, Any],
|
||||
tool_outputs: str,
|
||||
) -> None:
|
||||
"""If not the final action, print out observation."""
|
||||
|
@ -1,4 +1,3 @@
|
||||
from typing import List
|
||||
|
||||
from langchain.schema import Document
|
||||
|
||||
@ -40,7 +39,7 @@ class DatasetIndexToolCallbackHandler:
|
||||
db.session.add(dataset_query)
|
||||
db.session.commit()
|
||||
|
||||
def on_tool_end(self, documents: List[Document]) -> None:
|
||||
def on_tool_end(self, documents: list[Document]) -> None:
|
||||
"""Handle tool end."""
|
||||
for document in documents:
|
||||
doc_id = document.metadata['doc_id']
|
||||
@ -55,7 +54,7 @@ class DatasetIndexToolCallbackHandler:
|
||||
|
||||
db.session.commit()
|
||||
|
||||
def return_retriever_resource_info(self, resource: List):
|
||||
def return_retriever_resource_info(self, resource: list):
|
||||
"""Handle return_retriever_resource_info."""
|
||||
if resource and len(resource) > 0:
|
||||
for item in resource:
|
||||
|
@ -1,6 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.input import print_text
|
||||
@ -16,8 +16,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
print_text("\n[on_chat_model_start]\n", color='blue')
|
||||
@ -26,7 +26,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
|
||||
print_text(str(sub_message) + "\n", color='blue')
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out the prompts."""
|
||||
print_text("\n[on_llm_start]\n", color='blue')
|
||||
@ -48,13 +48,13 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
|
||||
print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue')
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out that we are entering a chain."""
|
||||
chain_type = serialized['id'][-1]
|
||||
print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Print out that we finished a chain."""
|
||||
print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink')
|
||||
|
||||
@ -66,7 +66,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
serialized: dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain import LLMChain as LCLLMChain
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
@ -16,12 +16,12 @@ class LLMChain(LCLLMChain):
|
||||
model_config: ModelConfigEntity
|
||||
"""The language model instance to use."""
|
||||
llm: BaseLanguageModel = FakeLLM(response="")
|
||||
parameters: Dict[str, Any] = {}
|
||||
parameters: dict[str, Any] = {}
|
||||
agent_llm_callback: Optional[AgentLLMCallback] = None
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
input_list: list[dict[str, Any]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> LLMResult:
|
||||
"""Generate LLM result from inputs."""
|
||||
|
@ -1,6 +1,6 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import requests
|
||||
from flask import current_app
|
||||
@ -28,7 +28,7 @@ USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTM
|
||||
|
||||
class FileExtractor:
|
||||
@classmethod
|
||||
def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document], str]:
|
||||
def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[list[Document], str]:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(upload_file.key).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
|
||||
@ -37,7 +37,7 @@ class FileExtractor:
|
||||
return cls.load_from_file(file_path, return_text, upload_file, is_automatic)
|
||||
|
||||
@classmethod
|
||||
def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document], str]:
|
||||
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
|
||||
response = requests.get(url, headers={
|
||||
"User-Agent": USER_AGENT
|
||||
})
|
||||
@ -53,7 +53,7 @@ class FileExtractor:
|
||||
@classmethod
|
||||
def load_from_file(cls, file_path: str, return_text: bool = False,
|
||||
upload_file: Optional[UploadFile] = None,
|
||||
is_automatic: bool = False) -> Union[List[Document], str]:
|
||||
is_automatic: bool = False) -> Union[list[Document], str]:
|
||||
input_file = Path(file_path)
|
||||
delimiter = '\n'
|
||||
file_extension = input_file.suffix.lower()
|
||||
|
@ -1,6 +1,6 @@
|
||||
import csv
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from langchain.document_loaders import CSVLoader as LCCSVLoader
|
||||
from langchain.document_loaders.helpers import detect_file_encodings
|
||||
@ -14,7 +14,7 @@ class CSVLoader(LCCSVLoader):
|
||||
self,
|
||||
file_path: str,
|
||||
source_column: Optional[str] = None,
|
||||
csv_args: Optional[Dict] = None,
|
||||
csv_args: Optional[dict] = None,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = True,
|
||||
):
|
||||
@ -24,7 +24,7 @@ class CSVLoader(LCCSVLoader):
|
||||
self.csv_args = csv_args or {}
|
||||
self.autodetect_encoding = autodetect_encoding
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
def load(self) -> list[Document]:
|
||||
"""Load data into document objects."""
|
||||
try:
|
||||
with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
|
||||
|
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
@ -23,7 +22,7 @@ class ExcelLoader(BaseLoader):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
def load(self) -> list[Document]:
|
||||
data = []
|
||||
keys = []
|
||||
wb = load_workbook(filename=self._file_path, read_only=True)
|
||||
|
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
@ -23,7 +22,7 @@ class HTMLLoader(BaseLoader):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
def load(self) -> list[Document]:
|
||||
return [Document(page_content=self._load_as_text())]
|
||||
|
||||
def _load_as_text(self) -> str:
|
||||
|
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Optional, Tuple, cast
|
||||
from typing import Optional, cast
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.document_loaders.helpers import detect_file_encodings
|
||||
@ -42,7 +42,7 @@ class MarkdownLoader(BaseLoader):
|
||||
self._encoding = encoding
|
||||
self._autodetect_encoding = autodetect_encoding
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
def load(self) -> list[Document]:
|
||||
tups = self.parse_tups(self._file_path)
|
||||
documents = []
|
||||
for header, value in tups:
|
||||
@ -54,13 +54,13 @@ class MarkdownLoader(BaseLoader):
|
||||
|
||||
return documents
|
||||
|
||||
def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
|
||||
def markdown_to_tups(self, markdown_text: str) -> list[tuple[Optional[str], str]]:
|
||||
"""Convert a markdown file to a dictionary.
|
||||
|
||||
The keys are the headers and the values are the text under each header.
|
||||
|
||||
"""
|
||||
markdown_tups: List[Tuple[Optional[str], str]] = []
|
||||
markdown_tups: list[tuple[Optional[str], str]] = []
|
||||
lines = markdown_text.split("\n")
|
||||
|
||||
current_header = None
|
||||
@ -103,11 +103,11 @@ class MarkdownLoader(BaseLoader):
|
||||
content = re.sub(pattern, r"\1", content)
|
||||
return content
|
||||
|
||||
def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]:
|
||||
def parse_tups(self, filepath: str) -> list[tuple[Optional[str], str]]:
|
||||
"""Parse file into tuples."""
|
||||
content = ""
|
||||
try:
|
||||
with open(filepath, "r", encoding=self._encoding) as f:
|
||||
with open(filepath, encoding=self._encoding) as f:
|
||||
content = f.read()
|
||||
except UnicodeDecodeError as e:
|
||||
if self._autodetect_encoding:
|
||||
|
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
from flask import current_app
|
||||
@ -67,7 +67,7 @@ class NotionLoader(BaseLoader):
|
||||
document_model=document_model
|
||||
)
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
def load(self) -> list[Document]:
|
||||
self.update_last_edited_time(
|
||||
self._document_model
|
||||
)
|
||||
@ -78,7 +78,7 @@ class NotionLoader(BaseLoader):
|
||||
|
||||
def _load_data_as_documents(
|
||||
self, notion_obj_id: str, notion_page_type: str
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
docs = []
|
||||
if notion_page_type == 'database':
|
||||
# get all the pages in the database
|
||||
@ -94,8 +94,8 @@ class NotionLoader(BaseLoader):
|
||||
return docs
|
||||
|
||||
def _get_notion_database_data(
|
||||
self, database_id: str, query_dict: Dict[str, Any] = {}
|
||||
) -> List[Document]:
|
||||
self, database_id: str, query_dict: dict[str, Any] = {}
|
||||
) -> list[Document]:
|
||||
"""Get all the pages from a Notion database."""
|
||||
res = requests.post(
|
||||
DATABASE_URL_TMPL.format(database_id=database_id),
|
||||
@ -149,12 +149,12 @@ class NotionLoader(BaseLoader):
|
||||
|
||||
return database_content_list
|
||||
|
||||
def _get_notion_block_data(self, page_id: str) -> List[str]:
|
||||
def _get_notion_block_data(self, page_id: str) -> list[str]:
|
||||
result_lines_arr = []
|
||||
cur_block_id = page_id
|
||||
while True:
|
||||
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
|
||||
query_dict: Dict[str, Any] = {}
|
||||
query_dict: dict[str, Any] = {}
|
||||
|
||||
res = requests.request(
|
||||
"GET",
|
||||
@ -216,7 +216,7 @@ class NotionLoader(BaseLoader):
|
||||
cur_block_id = block_id
|
||||
while True:
|
||||
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
|
||||
query_dict: Dict[str, Any] = {}
|
||||
query_dict: dict[str, Any] = {}
|
||||
|
||||
res = requests.request(
|
||||
"GET",
|
||||
@ -280,7 +280,7 @@ class NotionLoader(BaseLoader):
|
||||
cur_block_id = block_id
|
||||
while not done:
|
||||
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
|
||||
query_dict: Dict[str, Any] = {}
|
||||
query_dict: dict[str, Any] = {}
|
||||
|
||||
res = requests.request(
|
||||
"GET",
|
||||
@ -346,7 +346,7 @@ class NotionLoader(BaseLoader):
|
||||
else:
|
||||
retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)
|
||||
|
||||
query_dict: Dict[str, Any] = {}
|
||||
query_dict: dict[str, Any] = {}
|
||||
|
||||
res = requests.request(
|
||||
"GET",
|
||||
|
@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from langchain.document_loaders import PyPDFium2Loader
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
@ -28,7 +28,7 @@ class PdfLoader(BaseLoader):
|
||||
self._file_path = file_path
|
||||
self._upload_file = upload_file
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
def load(self) -> list[Document]:
|
||||
plaintext_file_key = ''
|
||||
plaintext_file_exists = False
|
||||
if self._upload_file:
|
||||
|
@ -1,6 +1,5 @@
|
||||
import base64
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
@ -24,7 +23,7 @@ class UnstructuredEmailLoader(BaseLoader):
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
def load(self) -> list[Document]:
|
||||
from unstructured.partition.email import partition_email
|
||||
elements = partition_email(filename=self._file_path, api_url=self._api_url)
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
@ -34,7 +33,7 @@ class UnstructuredMarkdownLoader(BaseLoader):
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
def load(self) -> list[Document]:
|
||||
from unstructured.partition.md import partition_md
|
||||
|
||||
elements = partition_md(filename=self._file_path, api_url=self._api_url)
|
||||
|
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
@ -24,7 +23,7 @@ class UnstructuredMsgLoader(BaseLoader):
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
def load(self) -> list[Document]:
|
||||
from unstructured.partition.msg import partition_msg
|
||||
|
||||
elements = partition_msg(filename=self._file_path, api_url=self._api_url)
|
||||
|
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
@ -23,7 +22,7 @@ class UnstructuredPPTLoader(BaseLoader):
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
def load(self) -> list[Document]:
|
||||
from unstructured.partition.ppt import partition_ppt
|
||||
|
||||
elements = partition_ppt(filename=self._file_path, api_url=self._api_url)
|
||||
|
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
@ -22,7 +21,7 @@ class UnstructuredPPTXLoader(BaseLoader):
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
def load(self) -> list[Document]:
|
||||
from unstructured.partition.pptx import partition_pptx
|
||||
|
||||
elements = partition_pptx(filename=self._file_path, api_url=self._api_url)
|
||||
|
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
@ -24,7 +23,7 @@ class UnstructuredTextLoader(BaseLoader):
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
def load(self) -> list[Document]:
|
||||
from unstructured.partition.text import partition_text
|
||||
|
||||
elements = partition_text(filename=self._file_path, api_url=self._api_url)
|
||||
|
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
@ -24,7 +23,7 @@ class UnstructuredXmlLoader(BaseLoader):
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
def load(self) -> list[Document]:
|
||||
from unstructured.partition.xml import partition_xml
|
||||
|
||||
elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url)
|
||||
|
@ -1,4 +1,5 @@
|
||||
from typing import Any, Dict, Optional, Sequence, cast
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from langchain.schema import Document
|
||||
from sqlalchemy import func
|
||||
@ -22,10 +23,10 @@ class DatasetDocumentStore:
|
||||
self._document_id = document_id
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict[str, Any]) -> "DatasetDocumentStore":
|
||||
def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore":
|
||||
return cls(**config_dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Serialize to dict."""
|
||||
return {
|
||||
"dataset_id": self._dataset.id,
|
||||
@ -40,7 +41,7 @@ class DatasetDocumentStore:
|
||||
return self._user_id
|
||||
|
||||
@property
|
||||
def docs(self) -> Dict[str, Document]:
|
||||
def docs(self) -> dict[str, Document]:
|
||||
document_segments = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == self._dataset.id
|
||||
).all()
|
||||
|
@ -1,6 +1,6 @@
|
||||
import base64
|
||||
import logging
|
||||
from typing import List, Optional, cast
|
||||
from typing import Optional, cast
|
||||
|
||||
import numpy as np
|
||||
from langchain.embeddings.base import Embeddings
|
||||
@ -21,7 +21,7 @@ class CacheEmbedding(Embeddings):
|
||||
self._model_instance = model_instance
|
||||
self._user = user
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed search docs in batches of 10."""
|
||||
text_embeddings = []
|
||||
try:
|
||||
@ -52,7 +52,7 @@ class CacheEmbedding(Embeddings):
|
||||
|
||||
return text_embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Embed query text."""
|
||||
# use doc embedding cache or store if not exists
|
||||
hash = helper.generate_text_hash(text)
|
||||
|
@ -1,8 +1,9 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from json import JSONDecodeError
|
||||
from typing import Dict, Iterator, List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -135,7 +136,7 @@ class ProviderConfiguration(BaseModel):
|
||||
if self.provider.provider_credential_schema else []
|
||||
)
|
||||
|
||||
def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]:
|
||||
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
|
||||
"""
|
||||
Validate custom credentials.
|
||||
:param credentials: provider credentials
|
||||
@ -282,7 +283,7 @@ class ProviderConfiguration(BaseModel):
|
||||
return None
|
||||
|
||||
def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
|
||||
-> Tuple[ProviderModel, dict]:
|
||||
-> tuple[ProviderModel, dict]:
|
||||
"""
|
||||
Validate custom model credentials.
|
||||
|
||||
@ -711,7 +712,7 @@ class ProviderConfigurations(BaseModel):
|
||||
Model class for provider configuration dict.
|
||||
"""
|
||||
tenant_id: str
|
||||
configurations: Dict[str, ProviderConfiguration] = {}
|
||||
configurations: dict[str, ProviderConfiguration] = {}
|
||||
|
||||
def __init__(self, tenant_id: str):
|
||||
super().__init__(tenant_id=tenant_id)
|
||||
@ -759,7 +760,7 @@ class ProviderConfigurations(BaseModel):
|
||||
|
||||
return all_models
|
||||
|
||||
def to_list(self) -> List[ProviderConfiguration]:
|
||||
def to_list(self) -> list[ProviderConfiguration]:
|
||||
"""
|
||||
Convert to list.
|
||||
|
||||
|
@ -61,7 +61,7 @@ class Extensible:
|
||||
|
||||
builtin_file_path = os.path.join(subdir_path, '__builtin__')
|
||||
if os.path.exists(builtin_file_path):
|
||||
with open(builtin_file_path, 'r', encoding='utf-8') as f:
|
||||
with open(builtin_file_path, encoding='utf-8') as f:
|
||||
position = int(f.read().strip())
|
||||
|
||||
if (extension_name + '.py') not in file_names:
|
||||
@ -93,7 +93,7 @@ class Extensible:
|
||||
json_path = os.path.join(subdir_path, 'schema.json')
|
||||
json_data = {}
|
||||
if os.path.exists(json_path):
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
with open(json_path, encoding='utf-8') as f:
|
||||
json_data = json.load(f)
|
||||
|
||||
extensions[extension_name] = ModuleExtension(
|
||||
|
@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from mimetypes import guess_extension
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from core.app_runner.app_runner import AppRunner
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
@ -50,7 +50,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
message: Message,
|
||||
user_id: str,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
prompt_messages: Optional[List[PromptMessage]] = None,
|
||||
prompt_messages: Optional[list[PromptMessage]] = None,
|
||||
variables_pool: Optional[ToolRuntimeVariablePool] = None,
|
||||
db_variables: Optional[ToolConversationVariables] = None,
|
||||
model_instance: ModelInstance = None
|
||||
@ -122,7 +122,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
|
||||
return app_orchestration_config
|
||||
|
||||
def _convert_tool_response_to_str(self, tool_response: List[ToolInvokeMessage]) -> str:
|
||||
def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
|
||||
"""
|
||||
Handle tool response
|
||||
"""
|
||||
@ -140,7 +140,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
|
||||
return result
|
||||
|
||||
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> Tuple[PromptMessageTool, Tool]:
|
||||
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
|
||||
"""
|
||||
convert tool to prompt message tool
|
||||
"""
|
||||
@ -325,7 +325,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
|
||||
return prompt_tool
|
||||
|
||||
def extract_tool_response_binary(self, tool_response: List[ToolInvokeMessage]) -> List[ToolInvokeMessageBinary]:
|
||||
def extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
|
||||
"""
|
||||
Extract tool response binary
|
||||
"""
|
||||
@ -356,7 +356,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
|
||||
return result
|
||||
|
||||
def create_message_files(self, messages: List[ToolInvokeMessageBinary]) -> List[Tuple[MessageFile, bool]]:
|
||||
def create_message_files(self, messages: list[ToolInvokeMessageBinary]) -> list[tuple[MessageFile, bool]]:
|
||||
"""
|
||||
Create message file
|
||||
|
||||
@ -404,7 +404,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
return result
|
||||
|
||||
def create_agent_thought(self, message_id: str, message: str,
|
||||
tool_name: str, tool_input: str, messages_ids: List[str]
|
||||
tool_name: str, tool_input: str, messages_ids: list[str]
|
||||
) -> MessageAgentThought:
|
||||
"""
|
||||
Create agent thought
|
||||
@ -449,7 +449,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
thought: str,
|
||||
observation: str,
|
||||
answer: str,
|
||||
messages_ids: List[str],
|
||||
messages_ids: list[str],
|
||||
llm_usage: LLMUsage = None) -> MessageAgentThought:
|
||||
"""
|
||||
Save agent thought
|
||||
@ -505,7 +505,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
|
||||
db.session.commit()
|
||||
|
||||
def get_history_prompt_messages(self) -> List[PromptMessage]:
|
||||
def get_history_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
Get history prompt messages
|
||||
"""
|
||||
@ -516,7 +516,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
|
||||
return self.history_prompt_messages
|
||||
|
||||
def transform_tool_invoke_messages(self, messages: List[ToolInvokeMessage]) -> List[ToolInvokeMessage]:
|
||||
def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
|
||||
"""
|
||||
Transform tool message into agent thought
|
||||
"""
|
||||
|
@ -1,6 +1,7 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, Generator, List, Literal, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Literal, Union
|
||||
|
||||
from core.application_queue_manager import PublishFrom
|
||||
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
|
||||
@ -29,7 +30,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
def run(self, conversation: Conversation,
|
||||
message: Message,
|
||||
query: str,
|
||||
inputs: Dict[str, str],
|
||||
inputs: dict[str, str],
|
||||
) -> Union[Generator, LLMResult]:
|
||||
"""
|
||||
Run Cot agent application
|
||||
@ -37,7 +38,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
app_orchestration_config = self.app_orchestration_config
|
||||
self._repack_app_orchestration_config(app_orchestration_config)
|
||||
|
||||
agent_scratchpad: List[AgentScratchpadUnit] = []
|
||||
agent_scratchpad: list[AgentScratchpadUnit] = []
|
||||
|
||||
# check model mode
|
||||
if self.app_orchestration_config.model_config.mode == "completion":
|
||||
@ -56,7 +57,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
prompt_messages = self.history_prompt_messages
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
prompt_messages_tools: List[PromptMessageTool] = []
|
||||
prompt_messages_tools: list[PromptMessageTool] = []
|
||||
tool_instances = {}
|
||||
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
|
||||
try:
|
||||
@ -83,7 +84,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
}
|
||||
final_answer = ''
|
||||
|
||||
def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
|
||||
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
|
||||
if not final_llm_usage_dict['usage']:
|
||||
final_llm_usage_dict['usage'] = usage
|
||||
else:
|
||||
@ -493,7 +494,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
if not next_iteration.find("{{observation}}") >= 0:
|
||||
raise ValueError("{{observation}} is required in next_iteration")
|
||||
|
||||
def _convert_scratchpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str:
|
||||
def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
|
||||
"""
|
||||
convert agent scratchpad list to str
|
||||
"""
|
||||
@ -506,13 +507,13 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
return result
|
||||
|
||||
def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: List[PromptMessageTool],
|
||||
agent_scratchpad: List[AgentScratchpadUnit],
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool],
|
||||
agent_scratchpad: list[AgentScratchpadUnit],
|
||||
agent_prompt_message: AgentPromptEntity,
|
||||
instruction: str,
|
||||
input: str,
|
||||
) -> List[PromptMessage]:
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
organize chain of thought prompt messages, a standard prompt message is like:
|
||||
Respond to the human as helpfully and accurately as possible.
|
||||
|
@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Generator, List, Tuple, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
|
||||
from core.application_queue_manager import PublishFrom
|
||||
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
|
||||
@ -44,7 +45,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
||||
)
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
prompt_messages_tools: List[PromptMessageTool] = []
|
||||
prompt_messages_tools: list[PromptMessageTool] = []
|
||||
tool_instances = {}
|
||||
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
|
||||
try:
|
||||
@ -70,13 +71,13 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
||||
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = True
|
||||
agent_thoughts: List[MessageAgentThought] = []
|
||||
agent_thoughts: list[MessageAgentThought] = []
|
||||
llm_usage = {
|
||||
'usage': None
|
||||
}
|
||||
final_answer = ''
|
||||
|
||||
def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
|
||||
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
|
||||
if not final_llm_usage_dict['usage']:
|
||||
final_llm_usage_dict['usage'] = usage
|
||||
else:
|
||||
@ -117,7 +118,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
tool_calls: List[Tuple[str, str, Dict[str, Any]]] = []
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
|
||||
# save full response
|
||||
response = ''
|
||||
@ -364,7 +365,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
||||
return True
|
||||
return False
|
||||
|
||||
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
|
||||
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||
"""
|
||||
Extract tool calls from llm result chunk
|
||||
|
||||
@ -381,7 +382,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
||||
|
||||
return tool_calls
|
||||
|
||||
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
|
||||
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||
"""
|
||||
Extract blocking tool calls from llm result
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import List, Optional, cast
|
||||
from typing import Optional, cast
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
@ -96,7 +96,7 @@ class DatasetRetrievalFeature:
|
||||
return_resource: bool,
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler) \
|
||||
-> Optional[List[BaseTool]]:
|
||||
-> Optional[list[BaseTool]]:
|
||||
"""
|
||||
A dataset tool is a tool that can be used to retrieve information from a dataset
|
||||
:param tenant_id: tenant id
|
||||
|
@ -2,7 +2,7 @@ import concurrent
|
||||
import json
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
@ -62,7 +62,7 @@ class ExternalDataFetchFeature:
|
||||
app_id: str,
|
||||
external_data_tool: ExternalDataVariableEntity,
|
||||
inputs: dict,
|
||||
query: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
query: str) -> tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Query external data tool.
|
||||
:param flask_app: flask app
|
||||
|
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
from core.entities.application_entities import AppOrchestrationConfigEntity
|
||||
from core.moderation.base import ModerationAction, ModerationException
|
||||
@ -13,7 +12,7 @@ class ModerationFeature:
|
||||
tenant_id: str,
|
||||
app_orchestration_config_entity: AppOrchestrationConfigEntity,
|
||||
inputs: dict,
|
||||
query: str) -> Tuple[bool, dict, str]:
|
||||
query: str) -> tuple[bool, dict, str]:
|
||||
"""
|
||||
Process sensitive_word_avoidance.
|
||||
:param app_id: app id
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
@ -15,8 +15,8 @@ class MessageFileParser:
|
||||
self.tenant_id = tenant_id
|
||||
self.app_id = app_id
|
||||
|
||||
def validate_and_transform_files_arg(self, files: List[dict], app_model_config: AppModelConfig,
|
||||
user: Union[Account, EndUser]) -> List[FileObj]:
|
||||
def validate_and_transform_files_arg(self, files: list[dict], app_model_config: AppModelConfig,
|
||||
user: Union[Account, EndUser]) -> list[FileObj]:
|
||||
"""
|
||||
validate and transform files arg
|
||||
|
||||
@ -96,7 +96,7 @@ class MessageFileParser:
|
||||
# return all file objs
|
||||
return new_files
|
||||
|
||||
def transform_message_files(self, files: List[MessageFile], app_model_config: Optional[AppModelConfig]) -> List[FileObj]:
|
||||
def transform_message_files(self, files: list[MessageFile], app_model_config: Optional[AppModelConfig]) -> list[FileObj]:
|
||||
"""
|
||||
transform message files
|
||||
|
||||
@ -110,8 +110,8 @@ class MessageFileParser:
|
||||
# return all file objs
|
||||
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
|
||||
|
||||
def _to_file_objs(self, files: List[Union[Dict, MessageFile]],
|
||||
file_upload_config: dict) -> Dict[FileType, List[FileObj]]:
|
||||
def _to_file_objs(self, files: list[Union[dict, MessageFile]],
|
||||
file_upload_config: dict) -> dict[FileType, list[FileObj]]:
|
||||
"""
|
||||
transform files to file objs
|
||||
|
||||
@ -119,7 +119,7 @@ class MessageFileParser:
|
||||
:param file_upload_config:
|
||||
:return:
|
||||
"""
|
||||
type_file_objs: Dict[FileType, List[FileObj]] = {
|
||||
type_file_objs: dict[FileType, list[FileObj]] = {
|
||||
# Currently only support image
|
||||
FileType.IMAGE: []
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List
|
||||
from typing import Any
|
||||
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
@ -53,7 +53,7 @@ class BaseIndex(ABC):
|
||||
def search(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
def delete(self) -> None:
|
||||
|
@ -1,5 +1,4 @@
|
||||
import re
|
||||
from typing import Set
|
||||
|
||||
import jieba
|
||||
from jieba.analyse import default_tfidf
|
||||
@ -12,7 +11,7 @@ class JiebaKeywordTableHandler:
|
||||
def __init__(self):
|
||||
default_tfidf.stop_words = STOPWORDS
|
||||
|
||||
def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]:
|
||||
def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> set[str]:
|
||||
"""Extract keywords with JIEBA tfidf."""
|
||||
keywords = jieba.analyse.extract_tags(
|
||||
sentence=text,
|
||||
@ -21,7 +20,7 @@ class JiebaKeywordTableHandler:
|
||||
|
||||
return set(self._expand_tokens_with_subtokens(keywords))
|
||||
|
||||
def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]:
|
||||
def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
|
||||
"""Get subtokens from a list of tokens., filtering for stopwords."""
|
||||
results = set()
|
||||
for token in tokens:
|
||||
|
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from pydantic import BaseModel, Extra, Field
|
||||
@ -116,7 +116,7 @@ class KeywordTableIndex(BaseIndex):
|
||||
def search(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
|
||||
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
|
||||
@ -221,7 +221,7 @@ class KeywordTableIndex(BaseIndex):
|
||||
keywords = keyword_table_handler.extract_keywords(query)
|
||||
|
||||
# go through text chunks in order of most matching keywords
|
||||
chunk_indices_count: Dict[str, int] = defaultdict(int)
|
||||
chunk_indices_count: dict[str, int] = defaultdict(int)
|
||||
keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
|
||||
for keyword in keywords:
|
||||
for node_id in keyword_table[keyword]:
|
||||
@ -235,7 +235,7 @@ class KeywordTableIndex(BaseIndex):
|
||||
|
||||
return sorted_chunk_indices[: k]
|
||||
|
||||
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: List[str]):
|
||||
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
|
||||
document_segment = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.index_node_id == node_id
|
||||
@ -244,7 +244,7 @@ class KeywordTableIndex(BaseIndex):
|
||||
document_segment.keywords = keywords
|
||||
db.session.commit()
|
||||
|
||||
def create_segment_keywords(self, node_id: str, keywords: List[str]):
|
||||
def create_segment_keywords(self, node_id: str, keywords: list[str]):
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
self._update_segment_keywords(self.dataset.id, node_id, keywords)
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
||||
@ -266,7 +266,7 @@ class KeywordTableIndex(BaseIndex):
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
def update_segment_keywords_index(self, node_id: str, keywords: List[str]):
|
||||
def update_segment_keywords_index(self, node_id: str, keywords: list[str]):
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
@ -282,7 +282,7 @@ class KeywordTableRetriever(BaseRetriever, BaseModel):
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def get_relevant_documents(self, query: str) -> list[Document]:
|
||||
"""Get documents relevant for a query.
|
||||
|
||||
Args:
|
||||
@ -293,7 +293,7 @@ class KeywordTableRetriever(BaseRetriever, BaseModel):
|
||||
"""
|
||||
return self.index.search(query, **self.search_kwargs)
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def aget_relevant_documents(self, query: str) -> list[Document]:
|
||||
raise NotImplementedError("KeywordTableRetriever does not support async")
|
||||
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import Any, List, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
@ -43,13 +43,13 @@ class BaseVectorIndex(BaseIndex):
|
||||
def search_by_full_text_index(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
def search(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, List, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document
|
||||
@ -160,6 +160,6 @@ class MilvusVectorIndex(BaseVectorIndex):
|
||||
],
|
||||
))
|
||||
|
||||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
|
||||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
# milvus/zilliz doesn't support bm25 search
|
||||
return []
|
||||
|
@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, List, Optional, cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import qdrant_client
|
||||
from langchain.embeddings.base import Embeddings
|
||||
@ -210,7 +210,7 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
|
||||
return False
|
||||
|
||||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
|
||||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Optional, cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import requests
|
||||
import weaviate
|
||||
@ -172,7 +172,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
|
||||
|
||||
return False
|
||||
|
||||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
|
||||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)
|
||||
|
@ -5,7 +5,7 @@ import re
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import List, Optional, cast
|
||||
from typing import Optional, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from flask_login import current_user
|
||||
@ -40,7 +40,7 @@ class IndexingRunner:
|
||||
self.storage = storage
|
||||
self.model_manager = ModelManager()
|
||||
|
||||
def run(self, dataset_documents: List[DatasetDocument]):
|
||||
def run(self, dataset_documents: list[DatasetDocument]):
|
||||
"""Run the indexing process."""
|
||||
for dataset_document in dataset_documents:
|
||||
try:
|
||||
@ -238,7 +238,7 @@ class IndexingRunner:
|
||||
dataset_document.stopped_at = datetime.datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
|
||||
def file_indexing_estimate(self, tenant_id: str, file_details: list[UploadFile], tmp_processing_rule: dict,
|
||||
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
|
||||
indexing_technique: str = 'economy') -> dict:
|
||||
"""
|
||||
@ -494,7 +494,7 @@ class IndexingRunner:
|
||||
"preview": preview_texts
|
||||
}
|
||||
|
||||
def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> List[Document]:
|
||||
def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> list[Document]:
|
||||
# load file
|
||||
if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
|
||||
return []
|
||||
@ -526,7 +526,7 @@ class IndexingRunner:
|
||||
)
|
||||
|
||||
# replace doc id to document model id
|
||||
text_docs = cast(List[Document], text_docs)
|
||||
text_docs = cast(list[Document], text_docs)
|
||||
for text_doc in text_docs:
|
||||
# remove invalid symbol
|
||||
text_doc.page_content = self.filter_string(text_doc.page_content)
|
||||
@ -540,7 +540,7 @@ class IndexingRunner:
|
||||
text = re.sub(r'\|>', '>', text)
|
||||
text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
|
||||
# Unicode U+FFFE
|
||||
text = re.sub(u'\uFFFE', '', text)
|
||||
text = re.sub('\uFFFE', '', text)
|
||||
return text
|
||||
|
||||
def _get_splitter(self, processing_rule: DatasetProcessRule,
|
||||
@ -577,9 +577,9 @@ class IndexingRunner:
|
||||
|
||||
return character_splitter
|
||||
|
||||
def _step_split(self, text_docs: List[Document], splitter: TextSplitter,
|
||||
def _step_split(self, text_docs: list[Document], splitter: TextSplitter,
|
||||
dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
|
||||
-> List[Document]:
|
||||
-> list[Document]:
|
||||
"""
|
||||
Split the text documents into documents and save them to the document segment.
|
||||
"""
|
||||
@ -624,9 +624,9 @@ class IndexingRunner:
|
||||
|
||||
return documents
|
||||
|
||||
def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
|
||||
def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter,
|
||||
processing_rule: DatasetProcessRule, tenant_id: str,
|
||||
document_form: str, document_language: str) -> List[Document]:
|
||||
document_form: str, document_language: str) -> list[Document]:
|
||||
"""
|
||||
Split the text documents into nodes.
|
||||
"""
|
||||
@ -699,8 +699,8 @@ class IndexingRunner:
|
||||
|
||||
all_qa_documents.extend(format_documents)
|
||||
|
||||
def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
|
||||
processing_rule: DatasetProcessRule) -> List[Document]:
|
||||
def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: TextSplitter,
|
||||
processing_rule: DatasetProcessRule) -> list[Document]:
|
||||
"""
|
||||
Split the text documents into nodes.
|
||||
"""
|
||||
@ -770,7 +770,7 @@ class IndexingRunner:
|
||||
for q, a in matches if q and a
|
||||
]
|
||||
|
||||
def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
|
||||
def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]) -> None:
|
||||
"""
|
||||
Build the index for the document.
|
||||
"""
|
||||
@ -877,7 +877,7 @@ class IndexingRunner:
|
||||
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
|
||||
db.session.commit()
|
||||
|
||||
def batch_add_segments(self, segments: List[DocumentSegment], dataset: Dataset):
|
||||
def batch_add_segments(self, segments: list[DocumentSegment], dataset: Dataset):
|
||||
"""
|
||||
Batch add segments index processing
|
||||
"""
|
||||
|
@ -1,4 +1,5 @@
|
||||
from typing import IO, Generator, List, Optional, Union, cast
|
||||
from collections.abc import Generator
|
||||
from typing import IO, Optional, Union, cast
|
||||
|
||||
from core.entities.provider_configuration import ProviderModelBundle
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
@ -47,7 +48,7 @@ class ModelInstance:
|
||||
return credentials
|
||||
|
||||
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
|
@ -1,5 +1,5 @@
|
||||
from abc import ABC
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
@ -23,7 +23,7 @@ class Callback(ABC):
|
||||
|
||||
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Before invoke callback
|
||||
@ -42,7 +42,7 @@ class Callback(ABC):
|
||||
|
||||
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None):
|
||||
"""
|
||||
On new chunk callback
|
||||
@ -62,7 +62,7 @@ class Callback(ABC):
|
||||
|
||||
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
After invoke callback
|
||||
@ -82,7 +82,7 @@ class Callback(ABC):
|
||||
|
||||
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Invoke error callback
|
||||
|
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
|
||||
class LoggingCallback(Callback):
|
||||
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Before invoke callback
|
||||
@ -60,7 +60,7 @@ class LoggingCallback(Callback):
|
||||
|
||||
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None):
|
||||
"""
|
||||
On new chunk callback
|
||||
@ -81,7 +81,7 @@ class LoggingCallback(Callback):
|
||||
|
||||
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
After invoke callback
|
||||
@ -113,7 +113,7 @@ class LoggingCallback(Callback):
|
||||
|
||||
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Invoke error callback
|
||||
|
@ -1,8 +1,7 @@
|
||||
from typing import Dict
|
||||
|
||||
from core.model_runtime.entities.model_entities import DefaultParameterName
|
||||
|
||||
PARAMETER_RULE_TEMPLATE: Dict[DefaultParameterName, dict] = {
|
||||
PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
|
||||
DefaultParameterName.TEMPERATURE: {
|
||||
'label': {
|
||||
'en_US': 'Temperature',
|
||||
|
@ -153,7 +153,7 @@ class AIModel(ABC):
|
||||
# read _position.yaml file
|
||||
position_map = {}
|
||||
if os.path.exists(position_file_path):
|
||||
with open(position_file_path, 'r', encoding='utf-8') as f:
|
||||
with open(position_file_path, encoding='utf-8') as f:
|
||||
positions = yaml.safe_load(f)
|
||||
# convert list to dict with key as model provider name, value as index
|
||||
position_map = {position: index for index, position in enumerate(positions)}
|
||||
@ -161,7 +161,7 @@ class AIModel(ABC):
|
||||
# traverse all model_schema_yaml_paths
|
||||
for model_schema_yaml_path in model_schema_yaml_paths:
|
||||
# read yaml data from yaml file
|
||||
with open(model_schema_yaml_path, 'r', encoding='utf-8') as f:
|
||||
with open(model_schema_yaml_path, encoding='utf-8') as f:
|
||||
yaml_data = yaml.safe_load(f)
|
||||
|
||||
new_parameter_rules = []
|
||||
|
@ -3,7 +3,8 @@ import os
|
||||
import re
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import Generator, List, Optional, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.callbacks.logging_callback import LoggingCallback
|
||||
@ -29,7 +30,7 @@ class LargeLanguageModel(AIModel):
|
||||
|
||||
def invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
@ -122,7 +123,7 @@ class LargeLanguageModel(AIModel):
|
||||
def _invoke_result_generator(self, model: str, result: Generator, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None, stream: bool = True,
|
||||
stop: Optional[list[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator:
|
||||
"""
|
||||
Invoke result generator
|
||||
@ -186,7 +187,7 @@ class LargeLanguageModel(AIModel):
|
||||
@abstractmethod
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
@ -218,7 +219,7 @@ class LargeLanguageModel(AIModel):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def enforce_stop_tokens(self, text: str, stop: List[str]) -> str:
|
||||
def enforce_stop_tokens(self, text: str, stop: list[str]) -> str:
|
||||
"""Cut off the text as soon as any stop words occur."""
|
||||
return re.split("|".join(stop), text, maxsplit=1)[0]
|
||||
|
||||
@ -329,7 +330,7 @@ class LargeLanguageModel(AIModel):
|
||||
def _trigger_before_invoke_callbacks(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None, stream: bool = True,
|
||||
stop: Optional[list[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
|
||||
"""
|
||||
Trigger before invoke callbacks
|
||||
@ -367,7 +368,7 @@ class LargeLanguageModel(AIModel):
|
||||
def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None, stream: bool = True,
|
||||
stop: Optional[list[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
|
||||
"""
|
||||
Trigger new chunk callbacks
|
||||
@ -406,7 +407,7 @@ class LargeLanguageModel(AIModel):
|
||||
def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None, stream: bool = True,
|
||||
stop: Optional[list[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
|
||||
"""
|
||||
Trigger after invoke callbacks
|
||||
@ -446,7 +447,7 @@ class LargeLanguageModel(AIModel):
|
||||
def _trigger_invoke_error_callbacks(self, model: str, ex: Exception, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None, stream: bool = True,
|
||||
stop: Optional[list[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
|
||||
"""
|
||||
Trigger invoke error callbacks
|
||||
@ -527,7 +528,7 @@ class LargeLanguageModel(AIModel):
|
||||
raise ValueError(
|
||||
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.")
|
||||
elif parameter_rule.type == ParameterType.FLOAT:
|
||||
if not isinstance(parameter_value, (float, int)):
|
||||
if not isinstance(parameter_value, float | int):
|
||||
raise ValueError(f"Model Parameter {parameter_name} should be float.")
|
||||
|
||||
# validate parameter value precision
|
||||
|
@ -1,7 +1,6 @@
|
||||
import importlib
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict
|
||||
|
||||
import yaml
|
||||
|
||||
@ -12,7 +11,7 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
class ModelProvider(ABC):
|
||||
provider_schema: ProviderEntity = None
|
||||
model_instance_map: Dict[str, AIModel] = {}
|
||||
model_instance_map: dict[str, AIModel] = {}
|
||||
|
||||
@abstractmethod
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
@ -47,7 +46,7 @@ class ModelProvider(ABC):
|
||||
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
|
||||
yaml_data = {}
|
||||
if os.path.exists(yaml_path):
|
||||
with open(yaml_path, 'r', encoding='utf-8') as f:
|
||||
with open(yaml_path, encoding='utf-8') as f:
|
||||
yaml_data = yaml.safe_load(f)
|
||||
|
||||
try:
|
||||
|
@ -1,4 +1,5 @@
|
||||
from typing import Generator, List, Optional, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
import anthropic
|
||||
from anthropic import Anthropic, Stream
|
||||
@ -29,7 +30,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
@ -90,7 +91,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
def _generate(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
stop: Optional[List[str]] = None, stream: bool = True,
|
||||
stop: Optional[list[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
@ -255,7 +256,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return message_text
|
||||
|
||||
def _convert_messages_to_prompt_anthropic(self, messages: List[PromptMessage]) -> str:
|
||||
def _convert_messages_to_prompt_anthropic(self, messages: list[PromptMessage]) -> str:
|
||||
"""
|
||||
Format a list of messages into a full prompt for the Anthropic model
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
import copy
|
||||
import logging
|
||||
from typing import Generator, List, Optional, Union, cast
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import tiktoken
|
||||
from openai import AzureOpenAI, Stream
|
||||
@ -34,7 +35,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
|
||||
@ -121,7 +122,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
return ai_model_entity.entity if ai_model_entity else None
|
||||
|
||||
def _generate(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
|
||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
||||
@ -239,7 +240,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
|
||||
def _chat_generate(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
|
||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
||||
@ -537,7 +538,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
|
||||
return num_tokens
|
||||
|
||||
def _num_tokens_from_messages(self, credentials: dict, messages: List[PromptMessage],
|
||||
def _num_tokens_from_messages(self, credentials: dict, messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user