mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 05:19:00 +08:00
refactor: remove unused codes, move core/agent module into dataset retrieval feature (#2614)
This commit is contained in:
parent d44b05a9e5
commit dd961985f0

@@ -1,49 +0,0 @@
from typing import cast

from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel


class CalcTokenMixin:

    def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
        """
        Get the rest tokens available for the model after excluding messages tokens and completion max tokens

        :param model_config:
        :param messages:
        :return:
        """
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                max_tokens = (model_config.parameters.get(parameter_rule.name)
                              or model_config.parameters.get(parameter_rule.use_template)) or 0

        if model_context_tokens is None:
            return 0

        if max_tokens is None:
            max_tokens = 0

        prompt_tokens = model_type_instance.get_num_tokens(
            model_config.model,
            model_config.credentials,
            messages
        )

        rest_tokens = model_context_tokens - max_tokens - prompt_tokens

        return rest_tokens


class ExceededLLMTokensLimitError(Exception):
    pass
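
Aside, not part of the commit: the budget check above reduces to one line of arithmetic, rest_tokens = context_size - max_tokens - prompt_tokens. A minimal sketch with made-up numbers:

context_size = 4096      # ModelPropertyKey.CONTEXT_SIZE from the model schema
max_tokens = 1500        # completion budget from the 'max_tokens' parameter rule
prompt_tokens = 2800     # counted by the provider's get_num_tokens()

rest_tokens = context_size - max_tokens - prompt_tokens
print(rest_tokens)       # -204: negative, so callers must summarize older messages
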
@@ -1,361 +0,0 @@
from collections.abc import Sequence
from typing import Any, Optional, Union

from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import (
    AgentAction,
    AgentFinish,
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    get_buffer_string,
)
from langchain.tools import BaseTool
from pydantic import root_validator

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM


class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_model_config: ModelConfigEntity = None
    model_config: ModelConfigEntity
    agent_llm_callback: Optional[AgentLLMCallback] = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator
    def validate_llm(cls, values: dict) -> dict:
        return values

    @classmethod
    def from_llm_and_tools(
            cls,
            model_config: ModelConfigEntity,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
            system_message: Optional[SystemMessage] = SystemMessage(
                content="You are a helpful AI assistant."
            ),
            agent_llm_callback: Optional[AgentLLMCallback] = None,
            **kwargs: Any,
    ) -> BaseSingleActionAgent:
        prompt = cls.create_prompt(
            extra_prompt_messages=extra_prompt_messages,
            system_message=system_message,
        )
        return cls(
            model_config=model_config,
            llm=FakeLLM(response=''),
            prompt=prompt,
            tools=tools,
            callback_manager=callback_manager,
            agent_llm_callback=agent_llm_callback,
            **kwargs,
        )

    def should_use_agent(self, query: str):
        """
        return should use agent

        :param query:
        :return:
        """
        original_max_tokens = 0
        for parameter_rule in self.model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
                                       or self.model_config.parameters.get(parameter_rule.use_template)) or 0

        self.model_config.parameters['max_tokens'] = 40

        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
        messages = prompt.to_messages()

        try:
            prompt_messages = lc_messages_to_prompt_messages(messages)
            model_instance = ModelInstance(
                provider_model_bundle=self.model_config.provider_model_bundle,
                model=self.model_config.model,
            )

            tools = []
            for function in self.functions:
                tool = PromptMessageTool(
                    **function
                )

                tools.append(tool)

            result = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                tools=tools,
                stream=False,
                model_parameters={
                    'temperature': 0.2,
                    'top_p': 0.3,
                    'max_tokens': 1500
                }
            )
        except Exception as e:
            raise e

        self.model_config.parameters['max_tokens'] = original_max_tokens

        return True if result.message.tool_calls else False

    def plan(
            self,
            intermediate_steps: list[tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()

        prompt_messages = lc_messages_to_prompt_messages(messages)

        # summarize messages if rest_tokens < 0
        try:
            prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
        except ExceededLLMTokensLimitError as e:
            return AgentFinish(return_values={"output": str(e)}, log=str(e))

        model_instance = ModelInstance(
            provider_model_bundle=self.model_config.provider_model_bundle,
            model=self.model_config.model,
        )

        tools = []
        for function in self.functions:
            tool = PromptMessageTool(
                **function
            )

            tools.append(tool)

        result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            tools=tools,
            stream=False,
            callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
            model_parameters={
                'temperature': 0.2,
                'top_p': 0.3,
                'max_tokens': 1500
            }
        )

        ai_message = AIMessage(
            content=result.message.content or "",
            additional_kwargs={
                'function_call': {
                    'id': result.message.tool_calls[0].id,
                    **result.message.tool_calls[0].function.dict()
                } if result.message.tool_calls else None
            }
        )
        agent_decision = _parse_ai_message(ai_message)

        if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
            tool_inputs = agent_decision.tool_input
            if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
                tool_inputs['query'] = kwargs['input']
                agent_decision.tool_input = tool_inputs

        return agent_decision

    @classmethod
    def get_system_message(cls):
        return SystemMessage(content="You are a helpful AI assistant.\n"
                                     "The current date or current time you know is wrong.\n"
                                     "Respond directly if appropriate.")

    def return_stopped_response(
            self,
            early_stopping_method: str,
            intermediate_steps: list[tuple[AgentAction, str]],
            **kwargs: Any,
    ) -> AgentFinish:
        try:
            return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
        except ValueError:
            return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")

    def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]:
        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
        rest_tokens = self.get_message_rest_tokens(
            self.model_config,
            messages,
            **kwargs
        )

        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
        if rest_tokens >= 0:
            return messages

        system_message = None
        human_message = None
        should_summary_messages = []
        for message in messages:
            if isinstance(message, SystemMessage):
                system_message = message
            elif isinstance(message, HumanMessage):
                human_message = message
            else:
                should_summary_messages.append(message)

        if len(should_summary_messages) > 2:
            ai_message = should_summary_messages[-2]
            function_message = should_summary_messages[-1]
            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
            self.moving_summary_index = len(should_summary_messages)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        new_messages = [system_message, human_message]

        if self.moving_summary_index == 0:
            should_summary_messages.insert(0, human_message)

        self.moving_summary_buffer = self.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        new_messages.append(AIMessage(content=self.moving_summary_buffer))
        new_messages.append(ai_message)
        new_messages.append(function_message)

        return new_messages

    def predict_new_summary(
        self, messages: list[BaseMessage], existing_summary: str
    ) -> str:
        new_lines = get_buffer_string(
            messages,
            human_prefix="Human",
            ai_prefix="AI",
        )

        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
        return chain.predict(summary=existing_summary, new_lines=new_lines)

    def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
        if model_config.provider == 'azure_openai':
            model = model_config.model
            model = model.replace("gpt-35", "gpt-3.5")
        else:
            model = model_config.credentials.get("base_model_name")

        tiktoken_ = _import_tiktoken()
        try:
            encoding = tiktoken_.encoding_for_model(model)
        except KeyError:
            model = "cl100k_base"
            encoding = tiktoken_.get_encoding(model)

        if model.startswith("gpt-3.5-turbo"):
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            tokens_per_message = 4
            # if there's a name, the role is omitted
            tokens_per_name = -1
        elif model.startswith("gpt-4"):
            tokens_per_message = 3
            tokens_per_name = 1
        else:
            raise NotImplementedError(
                f"get_num_tokens_from_messages() is not presently implemented "
                f"for model {model}."
                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
                "information on how messages are converted to tokens."
            )
        num_tokens = 0
        for m in messages:
            message = _convert_message_to_dict(m)
            num_tokens += tokens_per_message
            for key, value in message.items():
                if key == "function_call":
                    for f_key, f_value in value.items():
                        num_tokens += len(encoding.encode(f_key))
                        num_tokens += len(encoding.encode(f_value))
                else:
                    num_tokens += len(encoding.encode(value))

                if key == "name":
                    num_tokens += tokens_per_name
        # every reply is primed with <im_start>assistant
        num_tokens += 3

        if kwargs.get('functions'):
            for function in kwargs.get('functions'):
                num_tokens += len(encoding.encode('name'))
                num_tokens += len(encoding.encode(function.get("name")))
                num_tokens += len(encoding.encode('description'))
                num_tokens += len(encoding.encode(function.get("description")))
                parameters = function.get("parameters")
                num_tokens += len(encoding.encode('parameters'))
                if 'title' in parameters:
                    num_tokens += len(encoding.encode('title'))
                    num_tokens += len(encoding.encode(parameters.get("title")))
                num_tokens += len(encoding.encode('type'))
                num_tokens += len(encoding.encode(parameters.get("type")))
                if 'properties' in parameters:
                    num_tokens += len(encoding.encode('properties'))
                    for key, value in parameters.get('properties').items():
                        num_tokens += len(encoding.encode(key))
                        for field_key, field_value in value.items():
                            num_tokens += len(encoding.encode(field_key))
                            if field_key == 'enum':
                                for enum_field in field_value:
                                    num_tokens += 3
                                    num_tokens += len(encoding.encode(enum_field))
                            else:
                                num_tokens += len(encoding.encode(field_key))
                                num_tokens += len(encoding.encode(str(field_value)))
                if 'required' in parameters:
                    num_tokens += len(encoding.encode('required'))
                    for required_field in parameters['required']:
                        num_tokens += 3
                        num_tokens += len(encoding.encode(required_field))

        return num_tokens
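
Aside, not part of the commit: get_num_tokens_from_messages above follows the OpenAI cookbook accounting scheme. A condensed, runnable sketch of its core loop (assumes the tiktoken package is installed; the messages are made up):

import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")
messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "What's the weather like today?"},
]

num_tokens = 0
for message in messages:
    num_tokens += 3  # fixed per-message framing overhead (gpt-4 style)
    for value in message.values():
        num_tokens += len(encoding.encode(value))
num_tokens += 3      # every reply is primed with <|im_start|>assistant
print(num_tokens)
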
@@ -1,306 +0,0 @@
import re
from collections.abc import Sequence
from typing import Any, Optional, Union, cast

from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.schema import (
    AgentAction,
    AgentFinish,
    AIMessage,
    BaseMessage,
    HumanMessage,
    OutputParserException,
    get_buffer_string,
)
from langchain.tools import BaseTool

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages

FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}

Provide only ONE action per $JSON_BLOB, as shown:

```
{{{{
  "action": $TOOL_NAME,
  "action_input": $INPUT
}}}}
```

Follow this format:

Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
  "action": "Final Answer",
  "action_input": "Final response to human"
}}}}
```"""


class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_model_config: ModelConfigEntity = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def should_use_agent(self, query: str):
        """
        return should use agent
        Using the ReACT mode to determine whether an agent is needed is costly,
        so it's better to just use an Agent for reasoning, which is cheaper.

        :param query:
        :return:
        """
        return True

    def plan(
            self,
            intermediate_steps: list[tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
        prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])

        messages = []
        if prompts:
            messages = prompts[0].to_messages()

        prompt_messages = lc_messages_to_prompt_messages(messages)

        rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
        if rest_tokens < 0:
            full_inputs = self.summarize_messages(intermediate_steps, **kwargs)

        try:
            full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
        except Exception as e:
            raise e

        try:
            agent_decision = self.output_parser.parse(full_output)
            if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
                tool_inputs = agent_decision.tool_input
                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
                    tool_inputs['query'] = kwargs['input']
                    agent_decision.tool_input = tool_inputs
            return agent_decision
        except OutputParserException:
            return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
                                          "I don't know how to respond to that."}, "")

    def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs):
        if len(intermediate_steps) >= 2 and self.summary_model_config:
            should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
            should_summary_messages = [AIMessage(content=observation)
                                       for _, observation in should_summary_intermediate_steps]
            if self.moving_summary_index == 0:
                should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))

            self.moving_summary_index = len(intermediate_steps)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        if self.moving_summary_buffer and 'chat_history' in kwargs:
            kwargs["chat_history"].pop()

        self.moving_summary_buffer = self.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        if 'chat_history' in kwargs:
            kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))

        return self.get_full_inputs([intermediate_steps[-1]], **kwargs)

    def predict_new_summary(
        self, messages: list[BaseMessage], existing_summary: str
    ) -> str:
        new_lines = get_buffer_string(
            messages,
            human_prefix="Human",
            ai_prefix="AI",
        )

        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
        return chain.predict(summary=existing_summary, new_lines=new_lines)

    @classmethod
    def create_prompt(
            cls,
            tools: Sequence[BaseTool],
            prefix: str = PREFIX,
            suffix: str = SUFFIX,
            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
            format_instructions: str = FORMAT_INSTRUCTIONS,
            input_variables: Optional[list[str]] = None,
            memory_prompts: Optional[list[BasePromptTemplate]] = None,
    ) -> BasePromptTemplate:
        tool_strings = []
        for tool in tools:
            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
        formatted_tools = "\n".join(tool_strings)
        tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        _memory_prompts = memory_prompts or []
        messages = [
            SystemMessagePromptTemplate.from_template(template),
            *_memory_prompts,
            HumanMessagePromptTemplate.from_template(human_message_template),
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    @classmethod
    def create_completion_prompt(
            cls,
            tools: Sequence[BaseTool],
            prefix: str = PREFIX,
            format_instructions: str = FORMAT_INSTRUCTIONS,
            input_variables: Optional[list[str]] = None,
    ) -> PromptTemplate:
        """Create prompt in the style of the zero shot agent.

        Args:
            tools: List of tools the agent will have access to, used to format the
                prompt.
            prefix: String to put before the list of tools.
            input_variables: List of input variables the final prompt will expect.

        Returns:
            A PromptTemplate with the template assembled from the pieces here.
        """
        suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""

        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
        tool_names = ", ".join([tool.name for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        return PromptTemplate(template=template, input_variables=input_variables)

    def _construct_scratchpad(
        self, intermediate_steps: list[tuple[AgentAction, str]]
    ) -> str:
        agent_scratchpad = ""
        for action, observation in intermediate_steps:
            agent_scratchpad += action.log
            agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"

        if not isinstance(agent_scratchpad, str):
            raise ValueError("agent_scratchpad should be of type string.")
        if agent_scratchpad:
            llm_chain = cast(LLMChain, self.llm_chain)
            if llm_chain.model_config.mode == "chat":
                return (
                    f"This was your previous work "
                    f"(but I haven't seen any of it! I only see what "
                    f"you return as final answer):\n{agent_scratchpad}"
                )
            else:
                return agent_scratchpad
        else:
            return agent_scratchpad

    @classmethod
    def from_llm_and_tools(
            cls,
            model_config: ModelConfigEntity,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            output_parser: Optional[AgentOutputParser] = None,
            prefix: str = PREFIX,
            suffix: str = SUFFIX,
            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
            format_instructions: str = FORMAT_INSTRUCTIONS,
            input_variables: Optional[list[str]] = None,
            memory_prompts: Optional[list[BasePromptTemplate]] = None,
            agent_llm_callback: Optional[AgentLLMCallback] = None,
            **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools."""
        cls._validate_tools(tools)
        if model_config.mode == "chat":
            prompt = cls.create_prompt(
                tools,
                prefix=prefix,
                suffix=suffix,
                human_message_template=human_message_template,
                format_instructions=format_instructions,
                input_variables=input_variables,
                memory_prompts=memory_prompts,
            )
        else:
            prompt = cls.create_completion_prompt(
                tools,
                prefix=prefix,
                format_instructions=format_instructions,
                input_variables=input_variables,
            )
        llm_chain = LLMChain(
            model_config=model_config,
            prompt=prompt,
            callback_manager=callback_manager,
            agent_llm_callback=agent_llm_callback,
            parameters={
                'temperature': 0.2,
                'top_p': 0.3,
                'max_tokens': 1500
            }
        )
        tool_names = [tool.name for tool in tools]
        _output_parser = output_parser
        return cls(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            output_parser=_output_parser,
            **kwargs,
        )
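
Aside, not part of the commit: FORMAT_INSTRUCTIONS quadruples its braces because the string passes through str.format-style rendering twice, once when tool_names is substituted and once when the prompt template renders, and each pass halves the braces. A two-line demonstration:

s = '{{{{ "action": $TOOL_NAME }}}}'
once = s.format()       # '{{ "action": $TOOL_NAME }}'
twice = once.format()   # '{ "action": $TOOL_NAME }'
print(once)
print(twice)
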
@@ -1,4 +1,3 @@
-import json
 import logging
 from typing import cast

@@ -15,7 +14,7 @@ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.moderation.base import ModerationException
 from core.tools.entities.tool_entities import ToolRuntimeVariablePool
 from extensions.ext_database import db
-from models.model import App, Conversation, Message, MessageAgentThought, MessageChain
+from models.model import App, Conversation, Message, MessageAgentThought
 from models.tools import ToolConversationVariables

 logger = logging.getLogger(__name__)

@@ -174,11 +173,6 @@ class AssistantApplicationRunner(AppRunner):
         # convert db variables to tool variables
         tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)

-        message_chain = self._init_message_chain(
-            message=message,
-            query=query
-        )
-
         # init model instance
         model_instance = ModelInstance(
             provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,

@@ -290,38 +284,6 @@ class AssistantApplicationRunner(AppRunner):
             'pool': db_variables.variables
         })

-    def _init_message_chain(self, message: Message, query: str) -> MessageChain:
-        """
-        Init MessageChain
-        :param message: message
-        :param query: query
-        :return:
-        """
-        message_chain = MessageChain(
-            message_id=message.id,
-            type="AgentExecutor",
-            input=json.dumps({
-                "input": query
-            })
-        )
-
-        db.session.add(message_chain)
-        db.session.commit()
-
-        return message_chain
-
-    def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
-        """
-        Save MessageChain
-        :param message_chain: message chain
-        :param output_text: output text
-        :return:
-        """
-        message_chain.output = json.dumps({
-            "output": output_text
-        })
-        db.session.commit()
-
     def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
                                          message: Message) -> LLMUsage:
         """

@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
 from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity
-from core.features.dataset_retrieval import DatasetRetrievalFeature
+from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.moderation.base import ModerationException

api/core/entities/agent_entities.py (new file, 8 lines)
@@ -0,0 +1,8 @@
from enum import Enum


class PlanningStrategy(Enum):
    ROUTER = 'router'
    REACT_ROUTER = 'react_router'
    REACT = 'react'
    FUNCTION_CALL = 'function_call'
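
Aside, not part of the commit: a sketch of how a caller picks one of these strategies from a model's advertised features, mirroring the selection in AgentRunnerFeature.run below (the feature strings here are illustrative; the real code checks ModelFeature.TOOL_CALL and ModelFeature.MULTI_TOOL_CALL):

from enum import Enum


class PlanningStrategy(Enum):
    ROUTER = 'router'
    REACT_ROUTER = 'react_router'
    REACT = 'react'
    FUNCTION_CALL = 'function_call'


def choose_strategy(features: list[str]) -> PlanningStrategy:
    # models that can call (multiple) tools natively get function-call planning
    if 'tool-call' in features or 'multi-tool-call' in features:
        return PlanningStrategy.FUNCTION_CALL
    return PlanningStrategy.REACT


print(choose_strategy(['tool-call']))  # PlanningStrategy.FUNCTION_CALL
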
@@ -1,199 +0,0 @@
import logging
from typing import Optional, cast

from langchain.tools import BaseTool

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.entities.application_entities import (
    AgentEntity,
    AppOrchestrationConfigEntity,
    InvokeFrom,
    ModelConfigEntity,
)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import Message

logger = logging.getLogger(__name__)


class AgentRunnerFeature:
    def __init__(self, tenant_id: str,
                 app_orchestration_config: AppOrchestrationConfigEntity,
                 model_config: ModelConfigEntity,
                 config: AgentEntity,
                 queue_manager: ApplicationQueueManager,
                 message: Message,
                 user_id: str,
                 agent_llm_callback: AgentLLMCallback,
                 callback: AgentLoopGatherCallbackHandler,
                 memory: Optional[TokenBufferMemory] = None,) -> None:
        """
        Agent runner
        :param tenant_id: tenant id
        :param app_orchestration_config: app orchestration config
        :param model_config: model config
        :param config: dataset config
        :param queue_manager: queue manager
        :param message: message
        :param user_id: user id
        :param agent_llm_callback: agent llm callback
        :param callback: callback
        :param memory: memory
        """
        self.tenant_id = tenant_id
        self.app_orchestration_config = app_orchestration_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.agent_llm_callback = agent_llm_callback
        self.callback = callback
        self.memory = memory

    def run(self, query: str,
            invoke_from: InvokeFrom) -> Optional[str]:
        """
        Retrieve agent loop result.
        :param query: query
        :param invoke_from: invoke from
        :return:
        """
        provider = self.config.provider
        model = self.config.model
        tool_configs = self.config.tools

        # check model is support tool calling
        provider_instance = model_provider_factory.get_provider_instance(provider=provider)
        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        # get model schema
        model_schema = model_type_instance.get_model_schema(
            model=model,
            credentials=self.model_config.credentials
        )

        if not model_schema:
            return None

        planning_strategy = PlanningStrategy.REACT
        features = model_schema.features
        if features:
            if ModelFeature.TOOL_CALL in features \
                    or ModelFeature.MULTI_TOOL_CALL in features:
                planning_strategy = PlanningStrategy.FUNCTION_CALL

        tools = self.to_tools(
            tool_configs=tool_configs,
            invoke_from=invoke_from,
            callbacks=[self.callback, DifyStdOutCallbackHandler()],
        )

        if len(tools) == 0:
            return None

        agent_configuration = AgentConfiguration(
            strategy=planning_strategy,
            model_config=self.model_config,
            tools=tools,
            memory=self.memory,
            max_iterations=10,
            max_execution_time=400.0,
            early_stopping_method="generate",
            agent_llm_callback=self.agent_llm_callback,
            callbacks=[self.callback, DifyStdOutCallbackHandler()]
        )

        agent_executor = AgentExecutor(agent_configuration)

        try:
            # check if should use agent
            should_use_agent = agent_executor.should_use_agent(query)
            if not should_use_agent:
                return None

            result = agent_executor.run(query)
            return result.output
        except Exception as ex:
            logger.exception("agent_executor run failed")
            return None

    def to_dataset_retriever_tool(self, tool_config: dict,
                                  invoke_from: InvokeFrom) \
            -> Optional[BaseTool]:
        """
        A dataset tool is a tool that can be used to retrieve information from a dataset
        :param tool_config: tool config
        :param invoke_from: invoke from
        """
        show_retrieve_source = self.app_orchestration_config.show_retrieve_source

        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=self.queue_manager,
            app_id=self.message.app_id,
            message_id=self.message.id,
            user_id=self.user_id,
            invoke_from=invoke_from
        )

        # get dataset from dataset id
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == tool_config.get("id")
        ).first()

        # pass if dataset is not available
        if not dataset:
            return None

        # pass if dataset is not available
        if (dataset and dataset.available_document_count == 0
                and dataset.available_document_count == 0):
            return None

        # get retrieval model config
        default_retrieval_model = {
            'search_method': 'semantic_search',
            'reranking_enable': False,
            'reranking_model': {
                'reranking_provider_name': '',
                'reranking_model_name': ''
            },
            'top_k': 2,
            'score_threshold_enabled': False
        }

        retrieval_model_config = dataset.retrieval_model \
            if dataset.retrieval_model else default_retrieval_model

        # get top k
        top_k = retrieval_model_config['top_k']

        # get score threshold
        score_threshold = None
        score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
        if score_threshold_enabled:
            score_threshold = retrieval_model_config.get("score_threshold")

        tool = DatasetRetrieverTool.from_dataset(
            dataset=dataset,
            top_k=top_k,
            score_threshold=score_threshold,
            hit_callbacks=[hit_callback],
            return_resource=show_retrieve_source,
            retriever_from=invoke_from.to_source()
        )

        return tool
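
Aside, not part of the commit: the retrieval-config handling in to_dataset_retriever_tool boils down to a dict fallback plus two lookups. A self-contained sketch with an illustrative stored config:

default_retrieval_model = {
    'search_method': 'semantic_search',
    'reranking_enable': False,
    'top_k': 2,
    'score_threshold_enabled': False,
}
stored = {'top_k': 4, 'score_threshold_enabled': True, 'score_threshold': 0.6}

config = stored if stored else default_retrieval_model
top_k = config['top_k']
score_threshold = config.get('score_threshold') if config.get('score_threshold_enabled') else None
print(top_k, score_threshold)  # 4 0.6
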
@@ -5,11 +5,11 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.schema import Generation, LLMResult
 from langchain.schema.language_model import BaseLanguageModel

-from core.agent.agent.agent_llm_callback import AgentLLMCallback
 from core.entities.application_entities import ModelConfigEntity
 from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
+from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
 from core.model_manager import ModelInstance
-from core.third_party.langchain.llms.fake import FakeLLM


 class LLMChain(LCLLMChain):

@@ -12,9 +12,9 @@ from pydantic import root_validator

 from core.entities.application_entities import ModelConfigEntity
 from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.message_entities import PromptMessageTool
-from core.third_party.langchain.llms.fake import FakeLLM


 class MultiDatasetRouterAgent(OpenAIFunctionsAgent):

@@ -12,8 +12,8 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
 from langchain.schema import AgentAction, AgentFinish, OutputParserException
 from langchain.tools import BaseTool

-from core.chain.llm_chain import LLMChain
 from core.entities.application_entities import ModelConfigEntity
+from core.features.dataset_retrieval.agent.llm_chain import LLMChain

 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
 The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.

@@ -1,4 +1,3 @@
-import enum
 import logging
 from typing import Optional, Union

@@ -8,14 +7,13 @@ from langchain.callbacks.manager import Callbacks
 from langchain.tools import BaseTool
 from pydantic import BaseModel, Extra

-from core.agent.agent.agent_llm_callback import AgentLLMCallback
-from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
-from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
-from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
-from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
-from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
+from core.entities.agent_entities import PlanningStrategy
 from core.entities.application_entities import ModelConfigEntity
 from core.entities.message_entities import prompt_messages_to_lc_messages
+from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
+from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
+from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
+from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
 from core.helper import moderation
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.errors.invoke import InvokeError

@@ -23,13 +21,6 @@ from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
 from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool


-class PlanningStrategy(str, enum.Enum):
-    ROUTER = 'router'
-    REACT_ROUTER = 'react_router'
-    REACT = 'react'
-    FUNCTION_CALL = 'function_call'
-
-
 class AgentConfiguration(BaseModel):
     strategy: PlanningStrategy
     model_config: ModelConfigEntity

@@ -62,28 +53,7 @@ class AgentExecutor:
         self.agent = self._init_agent()

     def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
-        if self.configuration.strategy == PlanningStrategy.REACT:
-            agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
-                model_config=self.configuration.model_config,
-                tools=self.configuration.tools,
-                output_parser=StructuredChatOutputParser(),
-                summary_model_config=self.configuration.summary_model_config
-                if self.configuration.summary_model_config else None,
-                agent_llm_callback=self.configuration.agent_llm_callback,
-                verbose=True
-            )
-        elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
-            agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
-                model_config=self.configuration.model_config,
-                tools=self.configuration.tools,
-                extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
-                if self.configuration.memory else None,  # used for read chat histories memory
-                summary_model_config=self.configuration.summary_model_config
-                if self.configuration.summary_model_config else None,
-                agent_llm_callback=self.configuration.agent_llm_callback,
-                verbose=True
-            )
-        elif self.configuration.strategy == PlanningStrategy.ROUTER:
+        if self.configuration.strategy == PlanningStrategy.ROUTER:
             self.configuration.tools = [t for t in self.configuration.tools
                                         if isinstance(t, DatasetRetrieverTool)
                                         or isinstance(t, DatasetMultiRetrieverTool)]

@@ -2,9 +2,10 @@ from typing import Optional, cast

 from langchain.tools import BaseTool

-from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.entities.agent_entities import PlanningStrategy
 from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
+from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
189
api/core/third_party/spark/spark_llm.py
vendored
189
api/core/third_party/spark/spark_llm.py
vendored
@ -1,189 +0,0 @@
|
|||||||
import base64
|
|
||||||
import hashlib
|
|
||||||
import hmac
|
|
||||||
import json
|
|
||||||
import queue
|
|
||||||
import ssl
|
|
||||||
from datetime import datetime
|
|
||||||
from time import mktime
|
|
||||||
from typing import Optional
|
|
||||||
from urllib.parse import urlencode, urlparse
|
|
||||||
from wsgiref.handlers import format_date_time
|
|
||||||
|
|
||||||
import websocket
|
|
||||||
|
|
||||||
|
|
||||||
class SparkLLMClient:
|
|
||||||
def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
|
|
||||||
domain = 'spark-api.xf-yun.com'
|
|
||||||
endpoint = 'chat'
|
|
||||||
if api_domain:
|
        domain = api_domain
        if model_name == 'spark-v3':
            endpoint = 'multimodal'

        model_api_configs = {
            'spark': {
                'version': 'v1.1',
                'chat_domain': 'general'
            },
            'spark-v2': {
                'version': 'v2.1',
                'chat_domain': 'generalv2'
            },
            'spark-v3': {
                'version': 'v3.1',
                'chat_domain': 'generalv3'
            },
            'spark-v3.5': {
                'version': 'v3.5',
                'chat_domain': 'generalv3.5'
            }
        }

        api_version = model_api_configs[model_name]['version']

        self.chat_domain = model_api_configs[model_name]['chat_domain']
        self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
        self.app_id = app_id
        self.ws_url = self.create_url(
            urlparse(self.api_base).netloc,
            urlparse(self.api_base).path,
            self.api_base,
            api_key,
            api_secret
        )

        self.queue = queue.Queue()
        self.blocking_message = ''

    def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
        # generate timestamp by RFC1123
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        signature_origin = "host: " + host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + path + " HTTP/1.1"

        # encrypt using hmac-sha256
        signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()

        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        v = {
            "authorization": authorization,
            "date": date,
            "host": host
        }
        # generate url
        url = api_base + '?' + urlencode(v)
        return url

    def run(self, messages: list, user_id: str,
            model_kwargs: Optional[dict] = None, streaming: bool = False):
        websocket.enableTrace(False)
        ws = websocket.WebSocketApp(
            self.ws_url,
            on_message=self.on_message,
            on_error=self.on_error,
            on_close=self.on_close,
            on_open=self.on_open
        )
        ws.messages = messages
        ws.user_id = user_id
        ws.model_kwargs = model_kwargs
        ws.streaming = streaming
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

    def on_error(self, ws, error):
        self.queue.put({
            'status_code': error.status_code,
            'error': error.resp_body.decode('utf-8')
        })
        ws.close()

    def on_close(self, ws, close_status_code, close_reason):
        self.queue.put({'done': True})

    def on_open(self, ws):
        self.blocking_message = ''
        data = json.dumps(self.gen_params(
            messages=ws.messages,
            user_id=ws.user_id,
            model_kwargs=ws.model_kwargs
        ))
        ws.send(data)

    def on_message(self, ws, message):
        data = json.loads(message)
        code = data['header']['code']
        if code != 0:
            self.queue.put({
                'status_code': 400,
                'error': f"Code: {code}, Error: {data['header']['message']}"
            })
            ws.close()
        else:
            choices = data["payload"]["choices"]
            status = choices["status"]
            content = choices["text"][0]["content"]
            if ws.streaming:
                self.queue.put({'data': content})
            else:
                self.blocking_message += content

            if status == 2:
                if not ws.streaming:
                    self.queue.put({'data': self.blocking_message})
                ws.close()

    def gen_params(self, messages: list, user_id: str,
                   model_kwargs: Optional[dict] = None) -> dict:
        data = {
            "header": {
                "app_id": self.app_id,
                "uid": user_id
            },
            "parameter": {
                "chat": {
                    "domain": self.chat_domain
                }
            },
            "payload": {
                "message": {
                    "text": messages
                }
            }
        }

        if model_kwargs:
            data['parameter']['chat'].update(model_kwargs)

        return data

    def subscribe(self):
        while True:
            content = self.queue.get()
            if 'error' in content:
                if content['status_code'] == 401:
                    raise SparkError('[Spark] The credentials you provided are incorrect. '
                                     'Please double-check and fill them in again.')
                elif content['status_code'] == 403:
                    raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
                                     "Please try again after obtaining the necessary permissions.")
                else:
                    raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")

            if 'data' not in content:
                break

            yield content


class SparkError(Exception):
    pass
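Note: the `create_url` method above implements iFlytek Spark's WebSocket handshake auth: the host, an RFC1123 date, and the request line are HMAC-SHA256-signed with the API secret, and the base64-encoded authorization header is packed into the query string. A minimal standalone sketch of the same scheme, with hypothetical endpoint and credentials (for illustration only, not part of this commit):

import base64
import hashlib
import hmac
from datetime import datetime
from time import mktime
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time


def sign_spark_ws_url(api_base: str, api_key: str, api_secret: str) -> str:
    host, path = urlparse(api_base).netloc, urlparse(api_base).path
    date = format_date_time(mktime(datetime.now().timetuple()))  # RFC1123 timestamp
    raw = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
    digest = hmac.new(api_secret.encode('utf-8'), raw.encode('utf-8'), hashlib.sha256).digest()
    signature = base64.b64encode(digest).decode()
    authorization_origin = (f'api_key="{api_key}", algorithm="hmac-sha256", '
                            f'headers="host date request-line", signature="{signature}"')
    authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode()
    return api_base + '?' + urlencode({"authorization": authorization, "date": date, "host": host})


# Hypothetical values:
print(sign_spark_ws_url("wss://spark-api.xf-yun.com/v3.1/chat", "my-key", "my-secret"))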
@ -1,24 +0,0 @@
from datetime import datetime

from langchain.tools import BaseTool
from pydantic import BaseModel, Field


class DatetimeToolInput(BaseModel):
    type: str = Field(..., description="Type for current time, must be: datetime.")


class DatetimeTool(BaseTool):
    """Tool for querying current datetime."""
    name: str = "current_datetime"
    args_schema: type[BaseModel] = DatetimeToolInput
    description: str = "A tool when you want to get the current date, time, week, month or year, " \
                       "and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\"."

    def _run(self, type: str) -> str:
        # get current time
        current_time = datetime.utcnow()
        return current_time.strftime("%Y-%m-%d %H:%M:%S UTC+0000 %A")

    async def _arun(self, tool_input: str) -> str:
        raise NotImplementedError()
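Note: the strftime pattern in `_run` above produces the "<date> <time> <timezone> <week>" shape promised by the tool description. For example:

from datetime import datetime

print(datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC+0000 %A"))
# e.g. "2024-02-29 09:30:00 UTC+0000 Thursday"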
@ -1,63 +0,0 @@
import base64
from abc import ABC, abstractmethod
from typing import Optional

from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
from models.tool import ToolProvider, ToolProviderName


class BaseToolProvider(ABC):
    def __init__(self, tenant_id: str):
        self.tenant_id = tenant_id

    @abstractmethod
    def get_provider_name(self) -> ToolProviderName:
        raise NotImplementedError

    @abstractmethod
    def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
        raise NotImplementedError

    @abstractmethod
    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        raise NotImplementedError

    @abstractmethod
    def credentials_to_func_kwargs(self) -> Optional[dict]:
        raise NotImplementedError

    @abstractmethod
    def credentials_validate(self, credentials: dict):
        raise NotImplementedError

    def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]:
        """
        Returns the Provider instance for the given tenant_id and tool_name.
        """
        query = db.session.query(ToolProvider).filter(
            ToolProvider.tenant_id == self.tenant_id,
            ToolProvider.tool_name == self.get_provider_name().value
        )

        if must_enabled:
            query = query.filter(ToolProvider.is_enabled == True)

        return query.first()

    def encrypt_token(self, token) -> str:
        tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
        encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
        return base64.b64encode(encrypted_token).decode()

    def decrypt_token(self, token: str, obfuscated: bool = False) -> str:
        token = rsa.decrypt(base64.b64decode(token), self.tenant_id)

        if obfuscated:
            return self._obfuscated_token(token)

        return token

    def _obfuscated_token(self, token: str) -> str:
        return token[:6] + '*' * (len(token) - 8) + token[-2:]
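Note: `_obfuscated_token` keeps the first six and last two characters and masks the middle. A quick sketch of its behaviour with a hypothetical token:

def _obfuscated_token(token: str) -> str:
    return token[:6] + '*' * (len(token) - 8) + token[-2:]

print(_obfuscated_token("sk-abcdef1234567890xy"))  # 'sk-abc' + 13 asterisks + 'xy'

One caveat worth noting if this helper is reused: for tokens shorter than eight characters the star count collapses to zero and the head and tail slices overlap, so the mask degrades.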
@ -1,2 +0,0 @@
class ToolValidateFailedError(Exception):
    description = "Tool Provider Validate failed"
@ -1,77 +0,0 @@
from typing import Optional

from core.tool.provider.base import BaseToolProvider
from core.tool.provider.errors import ToolValidateFailedError
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper
from models.tool import ToolProviderName


class SerpAPIToolProvider(BaseToolProvider):
    def get_provider_name(self) -> ToolProviderName:
        """
        Returns the name of the provider.

        :return:
        """
        return ToolProviderName.SERPAPI

    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        """
        Returns the credentials for SerpAPI as a dictionary.

        :param obfuscated: obfuscate credentials if True
        :return:
        """
        tool_provider = self.get_provider(must_enabled=True)
        if not tool_provider:
            return None

        credentials = tool_provider.credentials
        if not credentials:
            return None

        if credentials.get('api_key'):
            credentials['api_key'] = self.decrypt_token(credentials.get('api_key'), obfuscated)

        return credentials

    def credentials_to_func_kwargs(self) -> Optional[dict]:
        """
        Returns the credentials function kwargs as a dictionary.

        :return:
        """
        credentials = self.get_credentials()
        if not credentials:
            return None

        return {
            'serpapi_api_key': credentials.get('api_key')
        }

    def credentials_validate(self, credentials: dict):
        """
        Validates the given credentials.

        :param credentials:
        :return:
        """
        if 'api_key' not in credentials or not credentials.get('api_key'):
            raise ToolValidateFailedError("SerpAPI api_key is required.")

        api_key = credentials.get('api_key')

        try:
            OptimizedSerpAPIWrapper(serpapi_api_key=api_key).run(query='test')
        except Exception as e:
            raise ToolValidateFailedError("SerpAPI api_key is invalid. {}".format(e))

    def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
        """
        Encrypts the given credentials.

        :param credentials:
        :return:
        """
        credentials['api_key'] = self.encrypt_token(credentials.get('api_key'))
        return credentials
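Note: combined with the base-class helpers above, `get_credentials(obfuscated=True)` decrypts the stored key and then masks it, which makes it safe to echo back to a UI. A hypothetical round trip (illustrative tenant id and output only):

provider = SerpAPIToolProvider(tenant_id="tenant-123")
credentials = provider.get_credentials(obfuscated=True)
# e.g. {'api_key': 'a1b2c3**********zz'} — decrypted, then masked for display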
@ -1,43 +0,0 @@
from typing import Optional

from core.tool.provider.base import BaseToolProvider
from core.tool.provider.serpapi_provider import SerpAPIToolProvider


class ToolProviderService:

    def __init__(self, tenant_id: str, provider_name: str):
        self.provider = self._init_provider(tenant_id, provider_name)

    def _init_provider(self, tenant_id: str, provider_name: str) -> BaseToolProvider:
        if provider_name == 'serpapi':
            return SerpAPIToolProvider(tenant_id)
        else:
            raise Exception('tool provider {} not found'.format(provider_name))

    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        """
        Returns the credentials for Tool as a dictionary.

        :param obfuscated:
        :return:
        """
        return self.provider.get_credentials(obfuscated)

    def credentials_validate(self, credentials: dict):
        """
        Validates the given credentials.

        :param credentials:
        :raises: ValidateFailedError
        """
        return self.provider.credentials_validate(credentials)

    def encrypt_credentials(self, credentials: dict):
        """
        Encrypts the given credentials.

        :param credentials:
        :return:
        """
        return self.provider.encrypt_credentials(credentials)
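Note: `ToolProviderService` is a thin dispatcher over the providers above, with only 'serpapi' wired up. A typical call sequence against the removed service (hypothetical tenant id and key; `credentials_validate` issues a live test search):

service = ToolProviderService(tenant_id="tenant-123", provider_name="serpapi")
credentials = {"api_key": "sk-hypothetical"}
service.credentials_validate(credentials)             # raises ToolValidateFailedError on a failing key
encrypted = service.encrypt_credentials(credentials)  # 'api_key' becomes RSA-encrypted and base64-encoded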
@ -1,51 +0,0 @@
from langchain import SerpAPIWrapper
from pydantic import BaseModel, Field


class OptimizedSerpAPIInput(BaseModel):
    query: str = Field(..., description="search query.")


class OptimizedSerpAPIWrapper(SerpAPIWrapper):

    @staticmethod
    def _process_response(res: dict, num_results: int = 5) -> str:
        """Process response from SerpAPI."""
        if "error" in res.keys():
            raise ValueError(f"Got error from SerpAPI: {res['error']}")
        if "answer_box" in res.keys() and type(res["answer_box"]) == list:
            res["answer_box"] = res["answer_box"][0]
        if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
            toret = res["answer_box"]["answer"]
        elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
            toret = res["answer_box"]["snippet"]
        elif (
            "answer_box" in res.keys()
            and "snippet_highlighted_words" in res["answer_box"].keys()
        ):
            toret = res["answer_box"]["snippet_highlighted_words"][0]
        elif (
            "sports_results" in res.keys()
            and "game_spotlight" in res["sports_results"].keys()
        ):
            toret = res["sports_results"]["game_spotlight"]
        elif (
            "shopping_results" in res.keys()
            and "title" in res["shopping_results"][0].keys()
        ):
            toret = res["shopping_results"][:3]
        elif (
            "knowledge_graph" in res.keys()
            and "description" in res["knowledge_graph"].keys()
        ):
            toret = res["knowledge_graph"]["description"]
        elif 'organic_results' in res.keys() and len(res['organic_results']) > 0:
            toret = ""
            for result in res["organic_results"][:num_results]:
                if "link" in result:
                    toret += "----------------\nlink: " + result["link"] + "\n"
                if "snippet" in result:
                    toret += "snippet: " + result["snippet"] + "\n"
        else:
            toret = "No good search result found"
        return "search result:\n" + toret
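Note: `_process_response` collapses a raw SerpAPI payload to a single string, preferring answer boxes over organic results. A hypothetical payload through the answer-box fast path (illustration only):

res = {"answer_box": {"answer": "Paris"}}
print(OptimizedSerpAPIWrapper._process_response(res))
# search result:
# Paris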
@ -1,443 +0,0 @@
import hashlib
import json
import os
import re
import site
import subprocess
import tempfile
import unicodedata
from contextlib import contextmanager
from typing import Any

import requests
from bs4 import BeautifulSoup, CData, Comment, NavigableString
from langchain.chains import RefineDocumentsChain
from langchain.chains.summarize import refine_prompts
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from newspaper import Article
from pydantic import BaseModel, Field
from regex import regex

from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.models.document import Document

FULL_TEMPLATE = """
TITLE: {title}
AUTHORS: {authors}
PUBLISH DATE: {publish_date}
TOP_IMAGE_URL: {top_image}
TEXT:

{text}
"""


class WebReaderToolInput(BaseModel):
    url: str = Field(..., description="URL of the website to read")
    summary: bool = Field(
        default=False,
        description="When the user's question requires extracting the summarizing content of the webpage, "
                    "set it to true."
    )
    cursor: int = Field(
        default=0,
        description="Start reading from this character."
                    "Use when the first response was truncated"
                    "and you want to continue reading the page."
                    "The value cannot exceed 24000.",
    )


class WebReaderTool(BaseTool):
    """Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""

    name: str = "web_reader"
    args_schema: type[BaseModel] = WebReaderToolInput
    description: str = "use this to read a website. " \
                       "If you can answer the question based on the information provided, " \
                       "there is no need to use."
    page_contents: str = None
    url: str = None
    max_chunk_length: int = 4000
    summary_chunk_tokens: int = 4000
    summary_chunk_overlap: int = 0
    summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
    continue_reading: bool = True
    model_config: ModelConfigEntity
    model_parameters: dict[str, Any]

    def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
        try:
            if not self.page_contents or self.url != url:
                page_contents = get_url(url)
                self.page_contents = page_contents
                self.url = url
            else:
                page_contents = self.page_contents
        except Exception as e:
            return f'Read this website failed, caused by: {str(e)}.'

        if summary:
            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=self.summary_chunk_tokens,
                chunk_overlap=self.summary_chunk_overlap,
                separators=self.summary_separators
            )

            texts = character_splitter.split_text(page_contents)
            docs = [Document(page_content=t) for t in texts]

            if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'):
                return "No content found."

            # only use first 5 docs
            if len(docs) > 5:
                docs = docs[:5]

            chain = self.get_summary_chain()
            try:
                page_contents = chain.run(docs)
            except Exception as e:
                return f'Read this website failed, caused by: {str(e)}.'
        else:
            page_contents = page_result(page_contents, cursor, self.max_chunk_length)

            if self.continue_reading and len(page_contents) >= self.max_chunk_length:
                page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
                                 f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
                                 f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."

        return page_contents

    async def _arun(self, url: str) -> str:
        raise NotImplementedError

    def get_summary_chain(self) -> RefineDocumentsChain:
        initial_chain = LLMChain(
            model_config=self.model_config,
            prompt=refine_prompts.PROMPT,
            parameters=self.model_parameters
        )
        refine_chain = LLMChain(
            model_config=self.model_config,
            prompt=refine_prompts.REFINE_PROMPT,
            parameters=self.model_parameters
        )
        return RefineDocumentsChain(
            initial_llm_chain=initial_chain,
            refine_llm_chain=refine_chain,
            document_variable_name="text",
            initial_response_name="existing_answer",
            callbacks=self.callbacks
        )


def page_result(text: str, cursor: int, max_length: int) -> str:
    """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
    return text[cursor: cursor + max_length]


def get_url(url: str) -> str:
    """Fetch URL and return the contents as a string."""
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
    }
    supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]

    head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))

    if head_response.status_code != 200:
        return "URL returned status code {}.".format(head_response.status_code)

    # check content-type
    main_content_type = head_response.headers.get('Content-Type').split(';')[0].strip()
    if main_content_type not in supported_content_types:
        return "Unsupported content-type [{}] of URL.".format(main_content_type)

    if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
        return ExtractProcessor.load_from_url(url, return_text=True)

    response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
    a = extract_using_readabilipy(response.text)

    if not a['plain_text'] or not a['plain_text'].strip():
        return get_url_from_newspaper3k(url)

    res = FULL_TEMPLATE.format(
        title=a['title'],
        authors=a['byline'],
        publish_date=a['date'],
        top_image="",
        text=a['plain_text'] if a['plain_text'] else "",
    )

    return res


def get_url_from_newspaper3k(url: str) -> str:

    a = Article(url)
    a.download()
    a.parse()

    res = FULL_TEMPLATE.format(
        title=a.title,
        authors=a.authors,
        publish_date=a.publish_date,
        top_image=a.top_image,
        text=a.text,
    )

    return res


def extract_using_readabilipy(html):
    with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html:
        f_html.write(html)
        f_html.close()
    html_path = f_html.name

    # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
    article_json_path = html_path + ".json"
    jsdir = os.path.join(find_module_path('readabilipy'), 'javascript')
    with chdir(jsdir):
        subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])

    # Read output of call to Readability.parse() from JSON file and return as Python dictionary
    with open(article_json_path, encoding="utf-8") as json_file:
        input_json = json.loads(json_file.read())

    # Deleting files after processing
    os.unlink(article_json_path)
    os.unlink(html_path)

    article_json = {
        "title": None,
        "byline": None,
        "date": None,
        "content": None,
        "plain_content": None,
        "plain_text": None
    }
    # Populate article fields from readability fields where present
    if input_json:
        if "title" in input_json and input_json["title"]:
            article_json["title"] = input_json["title"]
        if "byline" in input_json and input_json["byline"]:
            article_json["byline"] = input_json["byline"]
        if "date" in input_json and input_json["date"]:
            article_json["date"] = input_json["date"]
        if "content" in input_json and input_json["content"]:
            article_json["content"] = input_json["content"]
            article_json["plain_content"] = plain_content(article_json["content"], False, False)
            article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
        if "textContent" in input_json and input_json["textContent"]:
            article_json["plain_text"] = input_json["textContent"]
            article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"])

    return article_json


def find_module_path(module_name):
    for package_path in site.getsitepackages():
        potential_path = os.path.join(package_path, module_name)
        if os.path.exists(potential_path):
            return potential_path

    return None

@contextmanager
def chdir(path):
    """Change directory in context and return to original on exit"""
    # From https://stackoverflow.com/a/37996581, couldn't find a built-in
    original_path = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(original_path)


def extract_text_blocks_as_plain_text(paragraph_html):
    # Load article as DOM
    soup = BeautifulSoup(paragraph_html, 'html.parser')
    # Select all lists
    list_elements = soup.find_all(['ul', 'ol'])
    # Prefix text in all list items with "* " and make lists paragraphs
    for list_element in list_elements:
        plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')])))
        list_element.string = plain_items
        list_element.name = "p"
    # Select all text blocks
    text_blocks = [s.parent for s in soup.find_all(string=True)]
    text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
    # Drop empty paragraphs
    text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
    return text_blocks


def plain_text_leaf_node(element):
    # Extract all text, stripped of any child HTML elements and normalise it
    plain_text = normalise_text(element.get_text())
    if plain_text != "" and element.name == "li":
        plain_text = "* {}, ".format(plain_text)
    if plain_text == "":
        plain_text = None
    if "data-node-index" in element.attrs:
        plain = {"node_index": element["data-node-index"], "text": plain_text}
    else:
        plain = {"text": plain_text}
    return plain


def plain_content(readability_content, content_digests, node_indexes):
    # Load article as DOM
    soup = BeautifulSoup(readability_content, 'html.parser')
    # Make all elements plain
    elements = plain_elements(soup.contents, content_digests, node_indexes)
    if node_indexes:
        # Add node index attributes to nodes
        elements = [add_node_indexes(element) for element in elements]
    # Replace article contents with plain elements
    soup.contents = elements
    return str(soup)


def plain_elements(elements, content_digests, node_indexes):
    # Get plain content versions of all elements
    elements = [plain_element(element, content_digests, node_indexes)
                for element in elements]
    if content_digests:
        # Add content digest attribute to nodes
        elements = [add_content_digest(element) for element in elements]
    return elements


def plain_element(element, content_digests, node_indexes):
    # For lists, we make each item plain text
    if is_leaf(element):
        # For leaf node elements, extract the text content, discarding any HTML tags
        # 1. Get element contents as text
        plain_text = element.get_text()
        # 2. Normalise the extracted text string to a canonical representation
        plain_text = normalise_text(plain_text)
        # 3. Update element content to be plain text
        element.string = plain_text
    elif is_text(element):
        if is_non_printing(element):
            # The simplified HTML may have come from Readability.js so might
            # have non-printing text (e.g. Comment or CData). In this case, we
            # keep the structure, but ensure that the string is empty.
            element = type(element)("")
        else:
            plain_text = element.string
            plain_text = normalise_text(plain_text)
            element = type(element)(plain_text)
    else:
        # If not a leaf node or leaf type call recursively on child nodes, replacing
        element.contents = plain_elements(element.contents, content_digests, node_indexes)
    return element


def add_node_indexes(element, node_index="0"):
    # Can't add attributes to string types
    if is_text(element):
        return element
    # Add index to current element
    element["data-node-index"] = node_index
    # Add index to child elements
    for local_idx, child in enumerate(
            [c for c in element.contents if not is_text(c)], start=1):
        # Can't add attributes to leaf string types
        child_index = "{stem}.{local}".format(
            stem=node_index, local=local_idx)
        add_node_indexes(child, node_index=child_index)
    return element


def normalise_text(text):
    """Normalise unicode and whitespace."""
    # Normalise unicode first to try and standardise whitespace characters as much as possible before normalising them
    text = strip_control_characters(text)
    text = normalise_unicode(text)
    text = normalise_whitespace(text)
    return text


def strip_control_characters(text):
    """Strip out unicode control characters which might break the parsing."""
    # Unicode control characters
    #   [Cc]: Other, Control [includes new lines]
    #   [Cf]: Other, Format
    #   [Cn]: Other, Not Assigned
    #   [Co]: Other, Private Use
    #   [Cs]: Other, Surrogate
    control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs'])
    retained_chars = ['\t', '\n', '\r', '\f']

    # Remove non-printing control characters
    return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text])


def normalise_unicode(text):
    """Normalise unicode such that things that are visually equivalent map to the same unicode string where possible."""
    normal_form = "NFKC"
    text = unicodedata.normalize(normal_form, text)
    return text


def normalise_whitespace(text):
    """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
    text = regex.sub(r"\s+", " ", text)
    # Remove leading and trailing whitespace
    text = text.strip()
    return text

def is_leaf(element):
    return (element.name in ['p', 'li'])


def is_text(element):
    return isinstance(element, NavigableString)


def is_non_printing(element):
    return any(isinstance(element, _e) for _e in [Comment, CData])


def add_content_digest(element):
    if not is_text(element):
        element["data-content-digest"] = content_digest(element)
    return element


def content_digest(element):
    if is_text(element):
        # Hash
        trimmed_string = element.string.strip()
        if trimmed_string == "":
            digest = ""
        else:
            digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest()
    else:
        contents = element.contents
        num_contents = len(contents)
        if num_contents == 0:
            # No hash when no child elements exist
            digest = ""
        elif num_contents == 1:
            # If single child, use digest of child
            digest = content_digest(contents[0])
        else:
            # Build content digest from the "non-empty" digests of child nodes
            digest = hashlib.sha256()
            child_digests = list(
                filter(lambda x: x != "", [content_digest(content) for content in contents]))
            for child in child_digests:
                digest.update(child.encode('utf-8'))
            digest = digest.hexdigest()
    return digest
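Note: the non-summary path of the removed tool pages through `page_result` with a cursor, which is what the "PAGE WAS TRUNCATED ... CURSOR=" hint refers to. A small sketch using the removed helpers above (hypothetical URL):

text = get_url("https://example.com/article")           # full extracted page
first = page_result(text, cursor=0, max_length=4000)     # what the tool returns first
rest = page_result(text, cursor=4000, max_length=4000)   # the chunk requested via CURSOR=4000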
@ -4,7 +4,7 @@ from langchain.tools import BaseTool
 
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom
-from core.features.dataset_retrieval import DatasetRetrievalFeature
+from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
 from core.tools.tool.tool import Tool
@ -7,23 +7,14 @@ import subprocess
 import tempfile
 import unicodedata
 from contextlib import contextmanager
-from typing import Any
 
 import requests
 from bs4 import BeautifulSoup, CData, Comment, NavigableString
-from langchain.chains import RefineDocumentsChain
-from langchain.chains.summarize import refine_prompts
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.tools.base import BaseTool
 from newspaper import Article
-from pydantic import BaseModel, Field
 from regex import regex
 
-from core.chain.llm_chain import LLMChain
-from core.entities.application_entities import ModelConfigEntity
 from core.rag.extractor import extract_processor
 from core.rag.extractor.extract_processor import ExtractProcessor
-from core.rag.models.document import Document
 
 FULL_TEMPLATE = """
 TITLE: {title}
@ -36,106 +27,6 @@ TEXT:
 """
 
-
-class WebReaderToolInput(BaseModel):
-    url: str = Field(..., description="URL of the website to read")
-    summary: bool = Field(
-        default=False,
-        description="When the user's question requires extracting the summarizing content of the webpage, "
-                    "set it to true."
-    )
-    cursor: int = Field(
-        default=0,
-        description="Start reading from this character."
-                    "Use when the first response was truncated"
-                    "and you want to continue reading the page."
-                    "The value cannot exceed 24000.",
-    )
-
-
-class WebReaderTool(BaseTool):
-    """Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""
-
-    name: str = "web_reader"
-    args_schema: type[BaseModel] = WebReaderToolInput
-    description: str = "use this to read a website. " \
-                       "If you can answer the question based on the information provided, " \
-                       "there is no need to use."
-    page_contents: str = None
-    url: str = None
-    max_chunk_length: int = 4000
-    summary_chunk_tokens: int = 4000
-    summary_chunk_overlap: int = 0
-    summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
-    continue_reading: bool = True
-    model_config: ModelConfigEntity
-    model_parameters: dict[str, Any]
-
-    def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
-        try:
-            if not self.page_contents or self.url != url:
-                page_contents = get_url(url)
-                self.page_contents = page_contents
-                self.url = url
-            else:
-                page_contents = self.page_contents
-        except Exception as e:
-            return f'Read this website failed, caused by: {str(e)}.'
-
-        if summary:
-            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
-                chunk_size=self.summary_chunk_tokens,
-                chunk_overlap=self.summary_chunk_overlap,
-                separators=self.summary_separators
-            )
-
-            texts = character_splitter.split_text(page_contents)
-            docs = [Document(page_content=t) for t in texts]
-
-            if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'):
-                return "No content found."
-
-            # only use first 5 docs
-            if len(docs) > 5:
-                docs = docs[:5]
-
-            chain = self.get_summary_chain()
-            try:
-                page_contents = chain.run(docs)
-            except Exception as e:
-                return f'Read this website failed, caused by: {str(e)}.'
-        else:
-            page_contents = page_result(page_contents, cursor, self.max_chunk_length)
-
-            if self.continue_reading and len(page_contents) >= self.max_chunk_length:
-                page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
-                                 f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
-                                 f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."
-
-        return page_contents
-
-    async def _arun(self, url: str) -> str:
-        raise NotImplementedError
-
-    def get_summary_chain(self) -> RefineDocumentsChain:
-        initial_chain = LLMChain(
-            model_config=self.model_config,
-            prompt=refine_prompts.PROMPT,
-            parameters=self.model_parameters
-        )
-        refine_chain = LLMChain(
-            model_config=self.model_config,
-            prompt=refine_prompts.REFINE_PROMPT,
-            parameters=self.model_parameters
-        )
-        return RefineDocumentsChain(
-            initial_llm_chain=initial_chain,
-            refine_llm_chain=refine_chain,
-            document_variable_name="text",
-            initial_response_name="existing_answer",
-            callbacks=self.callbacks
-        )
-
-
 def page_result(text: str, cursor: int, max_length: int) -> str:
     """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
     return text[cursor: cursor + max_length]
 
@ -1,7 +1,7 @@
 import re
 import uuid
 
-from core.agent.agent_executor import PlanningStrategy
+from core.entities.agent_entities import PlanningStrategy
 from core.external_data_tool.factory import ExternalDataToolFactory
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from core.model_runtime.model_providers import model_provider_factory