jyong 2025-04-24 15:42:30 +08:00
parent c7f4b41920
commit b9ab1555fb
12 changed files with 41 additions and 1179 deletions

View File

@@ -6,31 +6,30 @@ from typing import TYPE_CHECKING, Any, Optional
 
 if TYPE_CHECKING:
     from models.model import File
 
-from core.tools.__base.tool_runtime import ToolRuntime
+from core.datasource.__base.datasource_runtime import DatasourceRuntime
+from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceProviderType
 from core.tools.entities.tool_entities import (
-    ToolEntity,
     ToolInvokeMessage,
     ToolParameter,
-    ToolProviderType,
 )
 
 
-class Tool(ABC):
+class Datasource(ABC):
     """
-    The base class of a tool
+    The base class of a datasource
     """
 
-    entity: ToolEntity
-    runtime: ToolRuntime
+    entity: DatasourceEntity
+    runtime: DatasourceRuntime
 
-    def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None:
+    def __init__(self, entity: DatasourceEntity, runtime: DatasourceRuntime) -> None:
         self.entity = entity
         self.runtime = runtime
 
-    def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool":
+    def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "Datasource":
        """
-        fork a new tool with metadata
-        :return: the new tool
+        fork a new datasource with metadata
+        :return: the new datasource
        """
        return self.__class__(
            entity=self.entity.model_copy(),
@@ -38,9 +37,9 @@ class Tool(ABC):
        )
 
    @abstractmethod
-    def tool_provider_type(self) -> ToolProviderType:
+    def datasource_provider_type(self) -> DatasourceProviderType:
        """
-        get the tool provider type
-        :return: the tool provider type
+        get the datasource provider type
+        :return: the datasource provider type
        """

View File

@@ -4,12 +4,13 @@ from openai import BaseModel
 from pydantic import Field
 
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
 from core.tools.entities.tool_entities import ToolInvokeFrom
 
 
-class ToolRuntime(BaseModel):
+class DatasourceRuntime(BaseModel):
     """
-    Meta data of a tool call processing
+    Metadata of a datasource call processing
     """
 
     tenant_id: str
@@ -20,17 +21,17 @@ class ToolRuntime(BaseModel):
     runtime_parameters: dict[str, Any] = Field(default_factory=dict)
 
 
-class FakeToolRuntime(ToolRuntime):
+class FakeDatasourceRuntime(DatasourceRuntime):
     """
-    Fake tool runtime for testing
+    Fake datasource runtime for testing
     """
 
     def __init__(self):
         super().__init__(
             tenant_id="fake_tenant_id",
-            tool_id="fake_tool_id",
+            datasource_id="fake_datasource_id",
             invoke_from=InvokeFrom.DEBUGGER,
-            tool_invoke_from=ToolInvokeFrom.AGENT,
+            datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE,
             credentials={},
             runtime_parameters={},
         )

View File

@@ -36,9 +36,9 @@ from models.enums import CreatedByRole
 from models.model import Message, MessageFile
 
 
-class ToolEngine:
+class DatasourceEngine:
     """
-    Tool runtime engine take care of the tool executions.
+    Datasource runtime engine that takes care of datasource executions.
     """
 
     @staticmethod

View File

@@ -1,7 +1,9 @@
 from collections.abc import Generator
 from typing import Any, Optional
 
+from core.datasource.__base.datasource_runtime import DatasourceRuntime
 from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceParameter, DatasourceProviderType
+from core.plugin.manager.datasource import PluginDatasourceManager
 from core.plugin.manager.tool import PluginToolManager
 from core.plugin.utils.converter import convert_parameters_to_plugin_format
 from core.tools.__base.tool import Tool
@@ -9,7 +11,7 @@ from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
 
 
-class DatasourceTool(Tool):
+class DatasourcePlugin(Datasource):
     tenant_id: str
     icon: str
     plugin_unique_identifier: str
@@ -31,53 +33,45 @@ class DatasourceTool(Tool):
         self,
         user_id: str,
         datasource_parameters: dict[str, Any],
-        conversation_id: Optional[str] = None,
         rag_pipeline_id: Optional[str] = None,
-        message_id: Optional[str] = None,
     ) -> Generator[ToolInvokeMessage, None, None]:
-        manager = PluginToolManager()
+        manager = PluginDatasourceManager()
         datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
 
         yield from manager.invoke_first_step(
             tenant_id=self.tenant_id,
             user_id=user_id,
-            tool_provider=self.entity.identity.provider,
-            tool_name=self.entity.identity.name,
+            datasource_provider=self.entity.identity.provider,
+            datasource_name=self.entity.identity.name,
             credentials=self.runtime.credentials,
-            tool_parameters=tool_parameters,
-            conversation_id=conversation_id,
-            app_id=app_id,
-            message_id=message_id,
+            datasource_parameters=datasource_parameters,
+            rag_pipeline_id=rag_pipeline_id,
         )
 
     def _invoke_second_step(
         self,
         user_id: str,
         datasource_parameters: dict[str, Any],
-        conversation_id: Optional[str] = None,
         rag_pipeline_id: Optional[str] = None,
-        message_id: Optional[str] = None,
     ) -> Generator[ToolInvokeMessage, None, None]:
-        manager = PluginToolManager()
+        manager = PluginDatasourceManager()
        datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
 
        yield from manager.invoke(
            tenant_id=self.tenant_id,
            user_id=user_id,
-            tool_provider=self.entity.identity.provider,
-            tool_name=self.entity.identity.name,
+            datasource_provider=self.entity.identity.provider,
+            datasource_name=self.entity.identity.name,
            credentials=self.runtime.credentials,
-            tool_parameters=tool_parameters,
-            conversation_id=conversation_id,
-            app_id=app_id,
-            message_id=message_id,
+            datasource_parameters=datasource_parameters,
+            rag_pipeline_id=rag_pipeline_id,
        )
 
-    def fork_tool_runtime(self, runtime: ToolRuntime) -> "DatasourceTool":
-        return DatasourceTool(
+    def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
+        return DatasourcePlugin(
            entity=self.entity,
            runtime=runtime,
            tenant_id=self.tenant_id,
@@ -87,9 +81,7 @@ class DatasourceTool(Tool):
    def get_runtime_parameters(
        self,
-        conversation_id: Optional[str] = None,
-        app_id: Optional[str] = None,
-        message_id: Optional[str] = None,
+        rag_pipeline_id: Optional[str] = None,
    ) -> list[DatasourceParameter]:
        """
        get the runtime parameters
@@ -100,16 +92,14 @@ class DatasourceTool(Tool):
        if self.runtime_parameters is not None:
            return self.runtime_parameters
 
-        manager = PluginToolManager()
+        manager = PluginDatasourceManager()
        self.runtime_parameters = manager.get_runtime_parameters(
            tenant_id=self.tenant_id,
            user_id="",
            provider=self.entity.identity.provider,
-            tool=self.entity.identity.name,
+            datasource=self.entity.identity.name,
            credentials=self.runtime.credentials,
-            conversation_id=conversation_id,
-            app_id=app_id,
-            message_id=message_id,
+            rag_pipeline_id=rag_pipeline_id,
        )
 
        return self.runtime_parameters
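Both invocation steps are generators, so a caller drains them rather than collecting a return value. A sketch of driving the first step; the identifiers are hypothetical, DatasourcePlugin is assumed imported from the module above, and the underscore-prefixed method is called directly only for illustration:

def run_first_step(plugin):
    messages = plugin._invoke_first_step(
        user_id="user-1",                         # hypothetical user
        datasource_parameters={"query": "demo"},  # hypothetical parameters
        rag_pipeline_id="pipeline-1",             # hypothetical pipeline id
    )
    for message in messages:  # each item is a ToolInvokeMessage
        print(message)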

View File

@@ -1,199 +0,0 @@
import threading
from typing import Any
from flask import Flask, current_app
from pydantic import BaseModel, Field
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}
class DatasetMultiRetrieverToolInput(BaseModel):
query: str = Field(..., description="dataset multi retriever and rerank")
class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
"""Tool for querying multi dataset."""
name: str = "dataset_"
args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
description: str = "dataset multi retriever and rerank. "
dataset_ids: list[str]
reranking_provider_name: str
reranking_model_name: str
@classmethod
def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
return cls(
name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs
)
def _run(self, query: str) -> str:
threads = []
all_documents: list[RagDocument] = []
for dataset_id in self.dataset_ids:
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset_id,
"query": query,
"all_documents": all_documents,
"hit_callbacks": self.hit_callbacks,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
# rerank the retrieved documents
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id,
provider=self.reranking_provider_name,
model_type=ModelType.RERANK,
model=self.reranking_model_name,
)
rerank_runner = RerankModelRunner(rerank_model_instance)
all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k)
for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(all_documents)
document_score_list = {}
for item in all_documents:
if item.metadata and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
if segment.answer:
document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
else:
document_context_list.append(segment.get_sign_content())
if self.return_resource:
context_list = []
resource_number = 1
for segment in sorted_segments:
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = Document.query.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
if dataset and document:
source = {
"position": resource_number,
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": self.retriever_from,
"score": document_score_list.get(segment.index_node_id, None),
"doc_metadata": document.doc_metadata,
}
if self.retriever_from == "dev":
source["hit_count"] = segment.hit_count
source["word_count"] = segment.word_count
source["segment_position"] = segment.position
source["index_node_hash"] = segment.index_node_hash
if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source["content"] = segment.content
context_list.append(source)
resource_number += 1
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))
return ""
raise RuntimeError("no segments found")
def _retriever(
self,
flask_app: Flask,
dataset_id: str,
query: str,
all_documents: list,
hit_callbacks: list[DatasetIndexToolCallbackHandler],
):
with flask_app.app_context():
dataset = (
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
)
if not dataset:
return []
for hit_callback in hit_callbacks:
hit_callback.on_query(query, dataset.id)
# get the retrieval model; if it is not set, use the default
retrieval_model = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method="keyword_search",
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
)
if documents:
all_documents.extend(documents)
else:
if self.top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
)
all_documents.extend(documents)

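The deleted _run above fans out one retrieval thread per dataset into a shared list, joins them all, and only then reranks. Stripped of the Dify specifics, the concurrency pattern it used looks like this sketch:

import threading

def fan_out(worker, dataset_ids: list[str]) -> list:
    all_documents: list = []  # shared accumulator, appended to by every thread
    threads = [
        threading.Thread(target=worker, kwargs={"dataset_id": dataset_id, "all_documents": all_documents})
        for dataset_id in dataset_ids
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()  # rerank only after every retriever has finished
    return all_documents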
View File

@@ -1,33 +0,0 @@
from abc import ABC, abstractmethod
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
class DatasetRetrieverBaseTool(BaseModel, ABC):
"""Tool for querying a Dataset."""
name: str = "dataset"
description: str = "use this to retrieve a dataset. "
tenant_id: str
top_k: int = 2
score_threshold: Optional[float] = None
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
return_resource: bool
retriever_from: str
model_config = ConfigDict(arbitrary_types_allowed=True)
@abstractmethod
def _run(
self,
*args: Any,
**kwargs: Any,
) -> Any:
"""Use the tool.
Add run_manager: Optional[CallbackManagerForToolRun] = None
to child implementations to enable tracing.
"""

View File

@@ -1,202 +0,0 @@
from typing import Any
from pydantic import BaseModel, Field
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.context_entities import DocumentContext
from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"reranking_mode": "reranking_model",
"top_k": 2,
"score_threshold_enabled": False,
}
class DatasetRetrieverToolInput(BaseModel):
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
class DatasetRetrieverTool(DatasetRetrieverBaseTool):
"""Tool for querying a Dataset."""
name: str = "dataset"
args_schema: type[BaseModel] = DatasetRetrieverToolInput
description: str = "use this to retrieve a dataset. "
dataset_id: str
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
description = dataset.description
if not description:
description = "useful for when you want to answer queries about the " + dataset.name
description = description.replace("\n", "").replace("\r", "")
return cls(
name=f"dataset_{dataset.id.replace('-', '_')}",
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
description=description,
**kwargs,
)
def _run(self, query: str) -> str:
dataset = (
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
)
if not dataset:
return ""
for hit_callback in self.hit_callbacks:
hit_callback.on_query(query, dataset.id)
if dataset.provider == "external":
results = []
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
query=query,
external_retrieval_parameters=dataset.retrieval_model,
)
for external_document in external_documents:
document = RetrievalDocument(
page_content=external_document.get("content"),
metadata=external_document.get("metadata"),
provider="external",
)
if document.metadata is not None:
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset.id
document.metadata["dataset_name"] = dataset.name
results.append(document)
# deal with external documents
context_list = []
for position, item in enumerate(results, start=1):
if item.metadata is not None:
source = {
"position": position,
"dataset_id": item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"),
"document_name": item.metadata.get("title"),
"data_source_type": "external",
"retriever_from": self.retriever_from,
"score": item.metadata.get("score"),
"title": item.metadata.get("title"),
"content": item.page_content,
}
context_list.append(source)
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join([item.page_content for item in results]))
else:
# get the retrieval model; if it is not set, use the default
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
)
return str("\n".join([document.page_content for document in documents]))
else:
if self.top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
dataset_id=dataset.id,
query=query,
top_k=self.top_k,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model")
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights"),
)
else:
documents = []
for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(documents)
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
if item.metadata is not None and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
records = RetrievalService.format_retrieval_documents(documents)
if records:
for record in records:
segment = record.segment
if segment.answer:
document_context_list.append(
DocumentContext(
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
score=record.score,
)
)
else:
document_context_list.append(
DocumentContext(
content=segment.get_sign_content(),
score=record.score,
)
)
retrieval_resource_list = []
if self.return_resource:
for record in records:
segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = DatasetDocument.query.filter(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).first()
if dataset and document:
source = {
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id, # type: ignore
"document_name": document.name, # type: ignore
"data_source_type": document.data_source_type, # type: ignore
"segment_id": segment.id,
"retriever_from": self.retriever_from,
"score": record.score or 0.0,
"doc_metadata": document.doc_metadata, # type: ignore
}
if self.retriever_from == "dev":
source["hit_count"] = segment.hit_count
source["word_count"] = segment.word_count
source["segment_position"] = segment.position
source["index_node_hash"] = segment.index_node_hash
if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source["content"] = segment.content
retrieval_resource_list.append(source)
if self.return_resource and retrieval_resource_list:
retrieval_resource_list = sorted(
retrieval_resource_list,
key=lambda x: x.get("score") or 0.0,
reverse=True,
)
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
item["position"] = position # type: ignore
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(retrieval_resource_list)
if document_context_list:
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
return str("\n".join([document_context.content for document_context in document_context_list]))
return ""

View File

@@ -1,134 +0,0 @@
from collections.abc import Generator
from typing import Any, Optional
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolDescription,
ToolEntity,
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
class DatasetRetrieverTool(Tool):
retrieval_tool: DatasetRetrieverBaseTool
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None:
super().__init__(entity, runtime)
self.retrieval_tool = retrieval_tool
@staticmethod
def get_dataset_tools(
tenant_id: str,
dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity | None,
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
) -> list["DatasetRetrieverTool"]:
"""
get dataset tools
"""
# check if retrieve_config is valid
if dataset_ids is None or len(dataset_ids) == 0:
return []
if retrieve_config is None:
return []
feature = DatasetRetrieval()
# save original retrieve strategy, and set retrieve strategy to SINGLE
# Agent only supports SINGLE mode
original_retriever_mode = retrieve_config.retrieve_strategy
retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
retrieval_tools = feature.to_dataset_retriever_tool(
tenant_id=tenant_id,
dataset_ids=dataset_ids,
retrieve_config=retrieve_config,
return_resource=return_resource,
invoke_from=invoke_from,
hit_callback=hit_callback,
)
if retrieval_tools is None or len(retrieval_tools) == 0:
return []
# restore retrieve strategy
retrieve_config.retrieve_strategy = original_retriever_mode
# convert retrieval tools to Tools
tools = []
for retrieval_tool in retrieval_tools:
tool = DatasetRetrieverTool(
retrieval_tool=retrieval_tool,
entity=ToolEntity(
identity=ToolIdentity(
provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="")
),
parameters=[],
description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description),
),
runtime=ToolRuntime(tenant_id=tenant_id),
)
tools.append(tool)
return tools
def get_runtime_parameters(
self,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
return [
ToolParameter(
name="query",
label=I18nObject(en_US="", zh_Hans=""),
human_description=I18nObject(en_US="", zh_Hans=""),
type=ToolParameter.ToolParameterType.STRING,
form=ToolParameter.ToolParameterForm.LLM,
llm_description="Query for the dataset to be used to retrieve the dataset.",
required=True,
default="",
placeholder=I18nObject(en_US="", zh_Hans=""),
),
]
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.DATASET_RETRIEVAL
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
invoke dataset retriever tool
"""
query = tool_parameters.get("query")
if not query:
yield self.create_text_message(text="please input query")
else:
# invoke dataset retriever tool
result = self.retrieval_tool._run(query=query)
yield self.create_text_message(text=result)
def validate_credentials(
self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
) -> str | None:
"""
validate the credentials for dataset retriever tool
"""
pass
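For reference, the deleted factory above was the entry point that turned a retrieve config into LLM-callable tools. A usage sketch with hypothetical arguments (retrieve_config and hit_callback stand in for real instances):

tools = DatasetRetrieverTool.get_dataset_tools(
    tenant_id="tenant-1",             # hypothetical tenant
    dataset_ids=["ds-1", "ds-2"],     # hypothetical datasets
    retrieve_config=retrieve_config,  # a DatasetRetrieveConfigEntity
    return_resource=True,
    invoke_from=InvokeFrom.DEBUGGER,
    hit_callback=hit_callback,        # a DatasetIndexToolCallbackHandler
)
for tool in tools:
    print(tool.entity.identity.name)  # one wrapper per underlying retriever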

View File

@@ -1,169 +0,0 @@
"""
For some reason, a model will be used in tools like WebScraperTool, WikipediaSearchTool, etc.
Therefore, a model manager is needed to list/invoke/validate models.
"""
import json
from typing import Optional, cast
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from models.tools import ToolModelInvoke
class InvokeModelError(Exception):
pass
class ModelInvocationUtils:
@staticmethod
def get_max_llm_context_tokens(
tenant_id: str,
) -> int:
"""
get max llm context tokens of the model
"""
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
)
if not model_instance:
raise InvokeModelError("Model not found")
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if not schema:
raise InvokeModelError("No model schema found")
max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
if max_tokens is None:
return 2048
return max_tokens
@staticmethod
def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
"""
calculate tokens from prompt messages and model parameters
"""
# get model instance
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM)
if not model_instance:
raise InvokeModelError("Model not found")
# get tokens
tokens = model_instance.get_llm_num_tokens(prompt_messages)
return tokens
@staticmethod
def invoke(
user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
invoke model with parameters in user's own context
:param user_id: user id
:param tenant_id: tenant id of the creator of the tool
:param tool_type: tool type
:param tool_name: tool name
:param prompt_messages: prompt messages
:return: LLMResult
"""
# get model manager
model_manager = ModelManager()
# get model instance
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
)
# get prompt tokens
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
model_parameters = {
"temperature": 0.8,
"top_p": 0.8,
}
# create tool model invoke
tool_model_invoke = ToolModelInvoke(
user_id=user_id,
tenant_id=tenant_id,
provider=model_instance.provider,
tool_type=tool_type,
tool_name=tool_name,
model_parameters=json.dumps(model_parameters),
prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
model_response="",
prompt_tokens=prompt_tokens,
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
provider_response_latency=0,
total_price=0,
currency="USD",
)
db.session.add(tool_model_invoke)
db.session.commit()
try:
response: LLMResult = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
),
)
except InvokeRateLimitError as e:
raise InvokeModelError(f"Invoke rate limit error: {e}")
except InvokeBadRequestError as e:
raise InvokeModelError(f"Invoke bad request error: {e}")
except InvokeConnectionError as e:
raise InvokeModelError(f"Invoke connection error: {e}")
except InvokeAuthorizationError as e:
raise InvokeModelError("Invoke authorization error")
except InvokeServerUnavailableError as e:
raise InvokeModelError(f"Invoke server unavailable error: {e}")
except Exception as e:
raise InvokeModelError(f"Invoke error: {e}")
# update tool model invoke
tool_model_invoke.model_response = response.message.content
if response.usage:
tool_model_invoke.answer_tokens = response.usage.completion_tokens
tool_model_invoke.answer_unit_price = response.usage.completion_unit_price
tool_model_invoke.answer_price_unit = response.usage.completion_price_unit
tool_model_invoke.provider_response_latency = response.usage.latency
tool_model_invoke.total_price = response.usage.total_price
tool_model_invoke.currency = response.usage.currency
db.session.commit()
return response
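The two static helpers above were typically paired to budget a prompt before invoking. A sketch, assuming ModelInvocationUtils is in scope and prompt_messages is a list[PromptMessage] (both hypothetical here):

max_tokens = ModelInvocationUtils.get_max_llm_context_tokens(tenant_id="tenant-1")
prompt_tokens = ModelInvocationUtils.calculate_tokens(tenant_id="tenant-1", prompt_messages=prompt_messages)
if prompt_tokens >= max_tokens:
    raise InvokeModelError("prompt exceeds the model's context window")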

View File

@@ -1,17 +0,0 @@
import re
def get_image_upload_file_ids(content):
pattern = r"!\[image\]\((https?://.*?(file-preview|image-preview))\)"
matches = re.findall(pattern, content)
image_upload_file_ids = []
for match in matches:
if match[1] == "file-preview":
content_pattern = r"files/([^/]+)/file-preview"
else:
content_pattern = r"files/([^/]+)/image-preview"
content_match = re.search(content_pattern, match[0])
if content_match:
image_upload_file_id = content_match.group(1)
image_upload_file_ids.append(image_upload_file_id)
return image_upload_file_ids
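A quick check of the helper on a sample markdown image link (the UUID is made up); the second capture group decides which of the two path patterns is applied:

sample = "![image](http://example.com/files/0f3f09b6-5c1b-4d7a-9f3e-1a2b3c4d5e6f/image-preview)"
print(get_image_upload_file_ids(sample))
# ['0f3f09b6-5c1b-4d7a-9f3e-1a2b3c4d5e6f']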

View File

@@ -1,375 +0,0 @@
import hashlib
import json
import mimetypes
import os
import re
import site
import subprocess
import tempfile
import unicodedata
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Literal, Optional, cast
from urllib.parse import unquote
import chardet
import cloudscraper # type: ignore
from bs4 import BeautifulSoup, CData, Comment, NavigableString # type: ignore
from regex import regex # type: ignore
from core.helper import ssrf_proxy
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor
FULL_TEMPLATE = """
TITLE: {title}
AUTHORS: {authors}
PUBLISH DATE: {publish_date}
TOP_IMAGE_URL: {top_image}
TEXT:
{text}
"""
def page_result(text: str, cursor: int, max_length: int) -> str:
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
return text[cursor : cursor + max_length]
def get_url(url: str, user_agent: Optional[str] = None) -> str:
"""Fetch URL and return the contents as a string."""
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
" Chrome/91.0.4472.124 Safari/537.36"
}
if user_agent:
headers["User-Agent"] = user_agent
main_content_type = None
supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10))
if response.status_code == 200:
# check content-type
content_type = response.headers.get("Content-Type")
if content_type:
main_content_type = response.headers.get("Content-Type").split(";")[0].strip()
else:
content_disposition = response.headers.get("Content-Disposition", "")
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
if filename_match:
filename = unquote(filename_match.group(1))
extension = re.search(r"\.(\w+)$", filename)
if extension:
main_content_type = mimetypes.guess_type(filename)[0]
if main_content_type not in supported_content_types:
return "Unsupported content-type [{}] of URL.".format(main_content_type)
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
return cast(str, ExtractProcessor.load_from_url(url, return_text=True))
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
elif response.status_code == 403:
scraper = cloudscraper.create_scraper()
scraper.perform_request = ssrf_proxy.make_request
response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
if response.status_code != 200:
return "URL returned status code {}.".format(response.status_code)
# Detect encoding using chardet
detected_encoding = chardet.detect(response.content)
encoding = detected_encoding["encoding"]
if encoding:
try:
content = response.content.decode(encoding)
except (UnicodeDecodeError, TypeError):
content = response.text
else:
content = response.text
a = extract_using_readabilipy(content)
if not a["plain_text"] or not a["plain_text"].strip():
return ""
res = FULL_TEMPLATE.format(
title=a["title"],
authors=a["byline"],
publish_date=a["date"],
top_image="",
text=a["plain_text"] or "",
)
return res
def extract_using_readabilipy(html):
with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html:
f_html.write(html)
f_html.close()
html_path = f_html.name
# Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
article_json_path = html_path + ".json"
jsdir = os.path.join(find_module_path("readabilipy"), "javascript")
with chdir(jsdir):
subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])
# Read output of call to Readability.parse() from JSON file and return as Python dictionary
input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8"))
# Deleting files after processing
os.unlink(article_json_path)
os.unlink(html_path)
article_json: dict[str, Any] = {
"title": None,
"byline": None,
"date": None,
"content": None,
"plain_content": None,
"plain_text": None,
}
# Populate article fields from readability fields where present
if input_json:
if input_json.get("title"):
article_json["title"] = input_json["title"]
if input_json.get("byline"):
article_json["byline"] = input_json["byline"]
if input_json.get("date"):
article_json["date"] = input_json["date"]
if input_json.get("content"):
article_json["content"] = input_json["content"]
article_json["plain_content"] = plain_content(article_json["content"], False, False)
article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
if input_json.get("textContent"):
article_json["plain_text"] = input_json["textContent"]
article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"])
return article_json
def find_module_path(module_name):
for package_path in site.getsitepackages():
potential_path = os.path.join(package_path, module_name)
if os.path.exists(potential_path):
return potential_path
return None
@contextmanager
def chdir(path):
"""Change directory in context and return to original on exit"""
# From https://stackoverflow.com/a/37996581, couldn't find a built-in
original_path = os.getcwd()
os.chdir(path)
try:
yield
finally:
os.chdir(original_path)
def extract_text_blocks_as_plain_text(paragraph_html):
# Load article as DOM
soup = BeautifulSoup(paragraph_html, "html.parser")
# Select all lists
list_elements = soup.find_all(["ul", "ol"])
# Prefix text in all list items with "* " and make lists paragraphs
for list_element in list_elements:
plain_items = "".join(
list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")]))
)
list_element.string = plain_items
list_element.name = "p"
# Select all text blocks
text_blocks = [s.parent for s in soup.find_all(string=True)]
text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
# Drop empty paragraphs
text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
return text_blocks
def plain_text_leaf_node(element):
# Extract all text, stripped of any child HTML elements and normalize it
plain_text = normalize_text(element.get_text())
if plain_text != "" and element.name == "li":
plain_text = "* {}, ".format(plain_text)
if plain_text == "":
plain_text = None
if "data-node-index" in element.attrs:
plain = {"node_index": element["data-node-index"], "text": plain_text}
else:
plain = {"text": plain_text}
return plain
def plain_content(readability_content, content_digests, node_indexes):
# Load article as DOM
soup = BeautifulSoup(readability_content, "html.parser")
# Make all elements plain
elements = plain_elements(soup.contents, content_digests, node_indexes)
if node_indexes:
# Add node index attributes to nodes
elements = [add_node_indexes(element) for element in elements]
# Replace article contents with plain elements
soup.contents = elements
return str(soup)
def plain_elements(elements, content_digests, node_indexes):
# Get plain content versions of all elements
elements = [plain_element(element, content_digests, node_indexes) for element in elements]
if content_digests:
# Add content digest attribute to nodes
elements = [add_content_digest(element) for element in elements]
return elements
def plain_element(element, content_digests, node_indexes):
# For lists, we make each item plain text
if is_leaf(element):
# For leaf node elements, extract the text content, discarding any HTML tags
# 1. Get element contents as text
plain_text = element.get_text()
# 2. Normalize the extracted text string to a canonical representation
plain_text = normalize_text(plain_text)
# 3. Update element content to be plain text
element.string = plain_text
elif is_text(element):
if is_non_printing(element):
# The simplified HTML may have come from Readability.js so might
# have non-printing text (e.g. Comment or CData). In this case, we
# keep the structure, but ensure that the string is empty.
element = type(element)("")
else:
plain_text = element.string
plain_text = normalize_text(plain_text)
element = type(element)(plain_text)
else:
# If neither a leaf node nor a text node, recurse into the child nodes, replacing them with plain versions
element.contents = plain_elements(element.contents, content_digests, node_indexes)
return element
def add_node_indexes(element, node_index="0"):
# Can't add attributes to string types
if is_text(element):
return element
# Add index to current element
element["data-node-index"] = node_index
# Add index to child elements
for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1):
# Can't add attributes to leaf string types
child_index = "{stem}.{local}".format(stem=node_index, local=local_idx)
add_node_indexes(child, node_index=child_index)
return element
def normalize_text(text):
"""Normalize unicode and whitespace."""
# Normalize unicode first to try and standardize whitespace characters as much as possible before normalizing them
text = strip_control_characters(text)
text = normalize_unicode(text)
text = normalize_whitespace(text)
return text
def strip_control_characters(text):
"""Strip out unicode control characters which might break the parsing."""
# Unicode control characters
# [Cc]: Other, Control [includes new lines]
# [Cf]: Other, Format
# [Cn]: Other, Not Assigned
# [Co]: Other, Private Use
# [Cs]: Other, Surrogate
control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"}
retained_chars = ["\t", "\n", "\r", "\f"]
# Remove non-printing control characters
return "".join(
[
"" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char
for char in text
]
)
def normalize_unicode(text):
"""Normalize unicode such that things that are visually equivalent map to the same unicode string where possible."""
normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC"
text = unicodedata.normalize(normal_form, text)
return text
def normalize_whitespace(text):
"""Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
text = regex.sub(r"\s+", " ", text)
# Remove leading and trailing whitespace
text = text.strip()
return text
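normalize_text composes the three helpers above: control characters are stripped first, then unicode is canonicalized, then whitespace runs collapse to single spaces. For example:

# U+200B (zero-width space) is category Cf, so it is stripped; the trailing
# newline survives strip_control_characters but is removed by normalize_whitespace.
print(normalize_text("Caf\u00e9\u200b   menu\n"))  # -> "Café menu"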
def is_leaf(element):
return element.name in {"p", "li"}
def is_text(element):
return isinstance(element, NavigableString)
def is_non_printing(element):
return any(isinstance(element, _e) for _e in [Comment, CData])
def add_content_digest(element):
if not is_text(element):
element["data-content-digest"] = content_digest(element)
return element
def content_digest(element):
digest: Any
if is_text(element):
# Hash
trimmed_string = element.string.strip()
if trimmed_string == "":
digest = ""
else:
digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest()
else:
contents = element.contents
num_contents = len(contents)
if num_contents == 0:
# No hash when no child elements exist
digest = ""
elif num_contents == 1:
# If single child, use digest of child
digest = content_digest(contents[0])
else:
# Build content digest from the "non-empty" digests of child nodes
digest = hashlib.sha256()
child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents]))
for child in child_digests:
digest.update(child.encode("utf-8"))
digest = digest.hexdigest()
return digest
def get_image_upload_file_ids(content):
pattern = r"!\[image\]\((https?://.*?(file-preview|image-preview))\)"
matches = re.findall(pattern, content)
image_upload_file_ids = []
for match in matches:
if match[1] == "file-preview":
content_pattern = r"files/([^/]+)/file-preview"
else:
content_pattern = r"files/([^/]+)/image-preview"
content_match = re.search(content_pattern, match[0])
if content_match:
image_upload_file_id = content_match.group(1)
image_upload_file_ids.append(image_upload_file_id)
return image_upload_file_ids

View File

@@ -0,0 +1 @@
{"not_installed": [], "plugin_install_failed": []}