Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
Synced 2025-08-15 14:25:53 +08:00

Commit b9ab1555fb ("r2"), parent c7f4b41920
@@ -6,31 +6,30 @@ from typing import TYPE_CHECKING, Any, Optional
 if TYPE_CHECKING:
     from models.model import File

-from core.tools.__base.tool_runtime import ToolRuntime
+from core.datasource.__base.datasource_runtime import DatasourceRuntime
+from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceProviderType
 from core.tools.entities.tool_entities import (
-    ToolEntity,
     ToolInvokeMessage,
     ToolParameter,
-    ToolProviderType,
 )


-class Tool(ABC):
+class Datasource(ABC):
     """
-    The base class of a tool
+    The base class of a datasource
     """

-    entity: ToolEntity
+    entity: DatasourceEntity
-    runtime: ToolRuntime
+    runtime: DatasourceRuntime

-    def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None:
+    def __init__(self, entity: DatasourceEntity, runtime: DatasourceRuntime) -> None:
         self.entity = entity
         self.runtime = runtime

-    def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool":
+    def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "Datasource":
         """
-        fork a new tool with metadata
+        fork a new datasource with metadata
-        :return: the new tool
+        :return: the new datasource
         """
         return self.__class__(
             entity=self.entity.model_copy(),
@@ -38,9 +37,9 @@ class Tool(ABC):
         )

     @abstractmethod
-    def tool_provider_type(self) -> ToolProviderType:
+    def datasource_provider_type(self) -> DatasourceProviderType:
         """
-        get the tool provider type
+        get the datasource provider type

         :return: the tool provider type
         """
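A minimal sketch (not part of this commit) of how the renamed Datasource base class above would be subclassed. The module paths, the ExampleDatasource name, and the enum member are assumptions, not taken from the diff.

    from core.datasource.__base.datasource import Datasource  # assumed location of the new base class
    from core.datasource.entities.datasource_entities import DatasourceProviderType


    class ExampleDatasource(Datasource):
        # The abstract surface mirrors the old Tool base class, so a concrete
        # datasource mainly declares which provider type it belongs to.
        def datasource_provider_type(self) -> DatasourceProviderType:
            # Hypothetical member; the real values are defined in datasource_entities.
            return DatasourceProviderType.ONLINE_DOCUMENT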
@@ -4,12 +4,13 @@ from openai import BaseModel
 from pydantic import Field

 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
 from core.tools.entities.tool_entities import ToolInvokeFrom


-class ToolRuntime(BaseModel):
+class DatasourceRuntime(BaseModel):
     """
-    Meta data of a tool call processing
+    Meta data of a datasource call processing
     """

     tenant_id: str
@@ -20,17 +21,17 @@ class ToolRuntime(BaseModel):
     runtime_parameters: dict[str, Any] = Field(default_factory=dict)


-class FakeToolRuntime(ToolRuntime):
+class FakeDatasourceRuntime(DatasourceRuntime):
     """
-    Fake tool runtime for testing
+    Fake datasource runtime for testing
     """

     def __init__(self):
         super().__init__(
             tenant_id="fake_tenant_id",
-            tool_id="fake_tool_id",
+            datasource_id="fake_datasource_id",
             invoke_from=InvokeFrom.DEBUGGER,
-            tool_invoke_from=ToolInvokeFrom.AGENT,
+            datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE,
             credentials={},
             runtime_parameters={},
         )
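FakeDatasourceRuntime pre-fills fake identifiers so unit tests do not need a real tenant, datasource, or credential set. A minimal usage sketch (the import path and test name are assumptions; the asserted values mirror what the constructor in the hunk passes to super().__init__()):

    from core.datasource.__base.datasource_runtime import FakeDatasourceRuntime  # assumed module path


    def test_fake_datasource_runtime_defaults():
        runtime = FakeDatasourceRuntime()
        assert runtime.tenant_id == "fake_tenant_id"
        assert runtime.datasource_id == "fake_datasource_id"
        assert runtime.credentials == {}
        assert runtime.runtime_parameters == {}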
@@ -36,9 +36,9 @@ from models.enums import CreatedByRole
 from models.model import Message, MessageFile


-class ToolEngine:
+class DatasourceEngine:
     """
-    Tool runtime engine take care of the tool executions.
+    Datasource runtime engine take care of the datasource executions.
     """

     @staticmethod
@@ -1,7 +1,9 @@
 from collections.abc import Generator
 from typing import Any, Optional

+from core.datasource.__base.datasource_runtime import DatasourceRuntime
 from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceParameter, DatasourceProviderType
+from core.plugin.manager.datasource import PluginDatasourceManager
 from core.plugin.manager.tool import PluginToolManager
 from core.plugin.utils.converter import convert_parameters_to_plugin_format
 from core.tools.__base.tool import Tool
@@ -9,7 +11,7 @@ from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType


-class DatasourceTool(Tool):
+class DatasourcePlugin(Datasource):
     tenant_id: str
     icon: str
     plugin_unique_identifier: str
@@ -31,53 +33,45 @@ class DatasourceTool(Tool):
         self,
         user_id: str,
         datasource_parameters: dict[str, Any],
-        conversation_id: Optional[str] = None,
         rag_pipeline_id: Optional[str] = None,
-        message_id: Optional[str] = None,
     ) -> Generator[ToolInvokeMessage, None, None]:
-        manager = PluginToolManager()
+        manager = PluginDatasourceManager()

         datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)

         yield from manager.invoke_first_step(
             tenant_id=self.tenant_id,
             user_id=user_id,
-            tool_provider=self.entity.identity.provider,
+            datasource_provider=self.entity.identity.provider,
-            tool_name=self.entity.identity.name,
+            datasource_name=self.entity.identity.name,
             credentials=self.runtime.credentials,
-            tool_parameters=tool_parameters,
+            datasource_parameters=datasource_parameters,
-            conversation_id=conversation_id,
+            rag_pipeline_id=rag_pipeline_id,
-            app_id=app_id,
-            message_id=message_id,
         )

     def _invoke_second_step(
         self,
         user_id: str,
         datasource_parameters: dict[str, Any],
-        conversation_id: Optional[str] = None,
         rag_pipeline_id: Optional[str] = None,
-        message_id: Optional[str] = None,
     ) -> Generator[ToolInvokeMessage, None, None]:
-        manager = PluginToolManager()
+        manager = PluginDatasourceManager()

         datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)

         yield from manager.invoke(
             tenant_id=self.tenant_id,
             user_id=user_id,
-            tool_provider=self.entity.identity.provider,
+            datasource_provider=self.entity.identity.provider,
-            tool_name=self.entity.identity.name,
+            datasource_name=self.entity.identity.name,
             credentials=self.runtime.credentials,
-            tool_parameters=tool_parameters,
+            datasource_parameters=datasource_parameters,
-            conversation_id=conversation_id,
+            rag_pipeline_id=rag_pipeline_id,
-            app_id=app_id,
-            message_id=message_id,
         )


-    def fork_tool_runtime(self, runtime: ToolRuntime) -> "DatasourceTool":
+    def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
-        return DatasourceTool(
+        return DatasourcePlugin(
             entity=self.entity,
             runtime=runtime,
             tenant_id=self.tenant_id,
@@ -87,9 +81,7 @@ class DatasourceTool(Tool):

     def get_runtime_parameters(
         self,
-        conversation_id: Optional[str] = None,
+        rag_pipeline_id: Optional[str] = None,
-        app_id: Optional[str] = None,
-        message_id: Optional[str] = None,
     ) -> list[DatasourceParameter]:
         """
         get the runtime parameters
@@ -100,16 +92,14 @@ class DatasourceTool(Tool):
         if self.runtime_parameters is not None:
             return self.runtime_parameters

-        manager = PluginToolManager()
+        manager = PluginDatasourceManager()
         self.runtime_parameters = manager.get_runtime_parameters(
             tenant_id=self.tenant_id,
             user_id="",
             provider=self.entity.identity.provider,
-            tool=self.entity.identity.name,
+            datasource=self.entity.identity.name,
             credentials=self.runtime.credentials,
-            conversation_id=conversation_id,
+            rag_pipeline_id=rag_pipeline_id,
-            app_id=app_id,
-            message_id=message_id,
         )

         return self.runtime_parameters
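The hunks above replace the conversation/app/message identifiers with a single rag_pipeline_id. An illustrative call (not from the commit; the datasource_plugin instance and parameter values are hypothetical):

    for message in datasource_plugin._invoke_second_step(
        user_id="user-1",
        datasource_parameters={"query": "quarterly report"},
        rag_pipeline_id="pipeline-1",
    ):
        # each item is a ToolInvokeMessage yielded by PluginDatasourceManager.invoke()
        print(message)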
@@ -1,199 +0,0 @@
import threading
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from flask import Flask, current_app
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
|
||||||
from core.model_manager import ModelManager
|
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
|
||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
|
||||||
from core.rag.models.document import Document as RagDocument
|
|
||||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
|
||||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.dataset import Dataset, Document, DocumentSegment
|
|
||||||
|
|
||||||
default_retrieval_model: dict[str, Any] = {
|
|
||||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
|
||||||
"reranking_enable": False,
|
|
||||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
|
||||||
"top_k": 2,
|
|
||||||
"score_threshold_enabled": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetMultiRetrieverToolInput(BaseModel):
|
|
||||||
query: str = Field(..., description="dataset multi retriever and rerank")
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|
||||||
"""Tool for querying multi dataset."""
|
|
||||||
|
|
||||||
name: str = "dataset_"
|
|
||||||
args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
|
|
||||||
description: str = "dataset multi retriever and rerank. "
|
|
||||||
dataset_ids: list[str]
|
|
||||||
reranking_provider_name: str
|
|
||||||
reranking_model_name: str
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
|
|
||||||
return cls(
|
|
||||||
name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run(self, query: str) -> str:
|
|
||||||
threads = []
|
|
||||||
all_documents: list[RagDocument] = []
|
|
||||||
for dataset_id in self.dataset_ids:
|
|
||||||
retrieval_thread = threading.Thread(
|
|
||||||
target=self._retriever,
|
|
||||||
kwargs={
|
|
||||||
"flask_app": current_app._get_current_object(), # type: ignore
|
|
||||||
"dataset_id": dataset_id,
|
|
||||||
"query": query,
|
|
||||||
"all_documents": all_documents,
|
|
||||||
"hit_callbacks": self.hit_callbacks,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
threads.append(retrieval_thread)
|
|
||||||
retrieval_thread.start()
|
|
||||||
for thread in threads:
|
|
||||||
thread.join()
|
|
||||||
# do rerank for searched documents
|
|
||||||
model_manager = ModelManager()
|
|
||||||
rerank_model_instance = model_manager.get_model_instance(
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
provider=self.reranking_provider_name,
|
|
||||||
model_type=ModelType.RERANK,
|
|
||||||
model=self.reranking_model_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
rerank_runner = RerankModelRunner(rerank_model_instance)
|
|
||||||
all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k)
|
|
||||||
|
|
||||||
for hit_callback in self.hit_callbacks:
|
|
||||||
hit_callback.on_tool_end(all_documents)
|
|
||||||
|
|
||||||
document_score_list = {}
|
|
||||||
for item in all_documents:
|
|
||||||
if item.metadata and item.metadata.get("score"):
|
|
||||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
|
||||||
|
|
||||||
document_context_list = []
|
|
||||||
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
|
|
||||||
segments = DocumentSegment.query.filter(
|
|
||||||
DocumentSegment.dataset_id.in_(self.dataset_ids),
|
|
||||||
DocumentSegment.completed_at.isnot(None),
|
|
||||||
DocumentSegment.status == "completed",
|
|
||||||
DocumentSegment.enabled == True,
|
|
||||||
DocumentSegment.index_node_id.in_(index_node_ids),
|
|
||||||
).all()
|
|
||||||
|
|
||||||
if segments:
|
|
||||||
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
|
|
||||||
sorted_segments = sorted(
|
|
||||||
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
|
|
||||||
)
|
|
||||||
for segment in sorted_segments:
|
|
||||||
if segment.answer:
|
|
||||||
document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
|
|
||||||
else:
|
|
||||||
document_context_list.append(segment.get_sign_content())
|
|
||||||
if self.return_resource:
|
|
||||||
context_list = []
|
|
||||||
resource_number = 1
|
|
||||||
for segment in sorted_segments:
|
|
||||||
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
|
||||||
document = Document.query.filter(
|
|
||||||
Document.id == segment.document_id,
|
|
||||||
Document.enabled == True,
|
|
||||||
Document.archived == False,
|
|
||||||
).first()
|
|
||||||
if dataset and document:
|
|
||||||
source = {
|
|
||||||
"position": resource_number,
|
|
||||||
"dataset_id": dataset.id,
|
|
||||||
"dataset_name": dataset.name,
|
|
||||||
"document_id": document.id,
|
|
||||||
"document_name": document.name,
|
|
||||||
"data_source_type": document.data_source_type,
|
|
||||||
"segment_id": segment.id,
|
|
||||||
"retriever_from": self.retriever_from,
|
|
||||||
"score": document_score_list.get(segment.index_node_id, None),
|
|
||||||
"doc_metadata": document.doc_metadata,
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.retriever_from == "dev":
|
|
||||||
source["hit_count"] = segment.hit_count
|
|
||||||
source["word_count"] = segment.word_count
|
|
||||||
source["segment_position"] = segment.position
|
|
||||||
source["index_node_hash"] = segment.index_node_hash
|
|
||||||
if segment.answer:
|
|
||||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
|
||||||
else:
|
|
||||||
source["content"] = segment.content
|
|
||||||
context_list.append(source)
|
|
||||||
resource_number += 1
|
|
||||||
|
|
||||||
for hit_callback in self.hit_callbacks:
|
|
||||||
hit_callback.return_retriever_resource_info(context_list)
|
|
||||||
|
|
||||||
return str("\n".join(document_context_list))
|
|
||||||
return ""
|
|
||||||
|
|
||||||
raise RuntimeError("not segments found")
|
|
||||||
|
|
||||||
def _retriever(
|
|
||||||
self,
|
|
||||||
flask_app: Flask,
|
|
||||||
dataset_id: str,
|
|
||||||
query: str,
|
|
||||||
all_documents: list,
|
|
||||||
hit_callbacks: list[DatasetIndexToolCallbackHandler],
|
|
||||||
):
|
|
||||||
with flask_app.app_context():
|
|
||||||
dataset = (
|
|
||||||
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not dataset:
|
|
||||||
return []
|
|
||||||
|
|
||||||
for hit_callback in hit_callbacks:
|
|
||||||
hit_callback.on_query(query, dataset.id)
|
|
||||||
|
|
||||||
# get retrieval model , if the model is not setting , using default
|
|
||||||
retrieval_model = dataset.retrieval_model or default_retrieval_model
|
|
||||||
|
|
||||||
if dataset.indexing_technique == "economy":
|
|
||||||
# use keyword table query
|
|
||||||
documents = RetrievalService.retrieve(
|
|
||||||
retrieval_method="keyword_search",
|
|
||||||
dataset_id=dataset.id,
|
|
||||||
query=query,
|
|
||||||
top_k=retrieval_model.get("top_k") or 2,
|
|
||||||
)
|
|
||||||
if documents:
|
|
||||||
all_documents.extend(documents)
|
|
||||||
else:
|
|
||||||
if self.top_k > 0:
|
|
||||||
# retrieval source
|
|
||||||
documents = RetrievalService.retrieve(
|
|
||||||
retrieval_method=retrieval_model["search_method"],
|
|
||||||
dataset_id=dataset.id,
|
|
||||||
query=query,
|
|
||||||
top_k=retrieval_model.get("top_k") or 2,
|
|
||||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
|
||||||
if retrieval_model["score_threshold_enabled"]
|
|
||||||
else 0.0,
|
|
||||||
reranking_model=retrieval_model.get("reranking_model", None)
|
|
||||||
if retrieval_model["reranking_enable"]
|
|
||||||
else None,
|
|
||||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
|
||||||
weights=retrieval_model.get("weights", None),
|
|
||||||
)
|
|
||||||
|
|
||||||
all_documents.extend(documents)
|
|
@@ -1,33 +0,0 @@
from abc import abstractmethod
from typing import Any, Optional

from msal_extensions.persistence import ABC  # type: ignore
from pydantic import BaseModel, ConfigDict

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler


class DatasetRetrieverBaseTool(BaseModel, ABC):
    """Tool for querying a Dataset."""

    name: str = "dataset"
    description: str = "use this to retrieve a dataset. "
    tenant_id: str
    top_k: int = 2
    score_threshold: Optional[float] = None
    hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
    return_resource: bool
    retriever_from: str
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @abstractmethod
    def _run(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Use the tool.

        Add run_manager: Optional[CallbackManagerForToolRun] = None
        to child implementations to enable tracing,
        """
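For reference, the contract this (now removed) base class imposed on its subclasses: declare the required pydantic fields and implement _run. A standalone stub, not code from the repository:

    class EchoRetrieverTool(DatasetRetrieverBaseTool):
        def _run(self, query: str) -> str:
            # Real subclasses query the dataset and return concatenated segment text;
            # this stub only echoes the query.
            return f"no segments retrieved for: {query}"


    tool = EchoRetrieverTool(tenant_id="tenant-1", return_resource=False, retriever_from="dev")
    print(tool._run("hello"))  # -> "no segments retrieved for: hello"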
@@ -1,202 +0,0 @@
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
|
||||||
from core.rag.entities.context_entities import DocumentContext
|
|
||||||
from core.rag.models.document import Document as RetrievalDocument
|
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
|
||||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.dataset import Dataset
|
|
||||||
from models.dataset import Document as DatasetDocument
|
|
||||||
from services.external_knowledge_service import ExternalDatasetService
|
|
||||||
|
|
||||||
default_retrieval_model = {
|
|
||||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
|
||||||
"reranking_enable": False,
|
|
||||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
|
||||||
"reranking_mode": "reranking_model",
|
|
||||||
"top_k": 2,
|
|
||||||
"score_threshold_enabled": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetRetrieverToolInput(BaseModel):
|
|
||||||
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|
||||||
"""Tool for querying a Dataset."""
|
|
||||||
|
|
||||||
name: str = "dataset"
|
|
||||||
args_schema: type[BaseModel] = DatasetRetrieverToolInput
|
|
||||||
description: str = "use this to retrieve a dataset. "
|
|
||||||
dataset_id: str
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dataset(cls, dataset: Dataset, **kwargs):
|
|
||||||
description = dataset.description
|
|
||||||
if not description:
|
|
||||||
description = "useful for when you want to answer queries about the " + dataset.name
|
|
||||||
|
|
||||||
description = description.replace("\n", "").replace("\r", "")
|
|
||||||
return cls(
|
|
||||||
name=f"dataset_{dataset.id.replace('-', '_')}",
|
|
||||||
tenant_id=dataset.tenant_id,
|
|
||||||
dataset_id=dataset.id,
|
|
||||||
description=description,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run(self, query: str) -> str:
|
|
||||||
dataset = (
|
|
||||||
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not dataset:
|
|
||||||
return ""
|
|
||||||
for hit_callback in self.hit_callbacks:
|
|
||||||
hit_callback.on_query(query, dataset.id)
|
|
||||||
if dataset.provider == "external":
|
|
||||||
results = []
|
|
||||||
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
|
||||||
tenant_id=dataset.tenant_id,
|
|
||||||
dataset_id=dataset.id,
|
|
||||||
query=query,
|
|
||||||
external_retrieval_parameters=dataset.retrieval_model,
|
|
||||||
)
|
|
||||||
for external_document in external_documents:
|
|
||||||
document = RetrievalDocument(
|
|
||||||
page_content=external_document.get("content"),
|
|
||||||
metadata=external_document.get("metadata"),
|
|
||||||
provider="external",
|
|
||||||
)
|
|
||||||
if document.metadata is not None:
|
|
||||||
document.metadata["score"] = external_document.get("score")
|
|
||||||
document.metadata["title"] = external_document.get("title")
|
|
||||||
document.metadata["dataset_id"] = dataset.id
|
|
||||||
document.metadata["dataset_name"] = dataset.name
|
|
||||||
results.append(document)
|
|
||||||
# deal with external documents
|
|
||||||
context_list = []
|
|
||||||
for position, item in enumerate(results, start=1):
|
|
||||||
if item.metadata is not None:
|
|
||||||
source = {
|
|
||||||
"position": position,
|
|
||||||
"dataset_id": item.metadata.get("dataset_id"),
|
|
||||||
"dataset_name": item.metadata.get("dataset_name"),
|
|
||||||
"document_name": item.metadata.get("title"),
|
|
||||||
"data_source_type": "external",
|
|
||||||
"retriever_from": self.retriever_from,
|
|
||||||
"score": item.metadata.get("score"),
|
|
||||||
"title": item.metadata.get("title"),
|
|
||||||
"content": item.page_content,
|
|
||||||
}
|
|
||||||
context_list.append(source)
|
|
||||||
for hit_callback in self.hit_callbacks:
|
|
||||||
hit_callback.return_retriever_resource_info(context_list)
|
|
||||||
|
|
||||||
return str("\n".join([item.page_content for item in results]))
|
|
||||||
else:
|
|
||||||
# get retrieval model , if the model is not setting , using default
|
|
||||||
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
|
|
||||||
if dataset.indexing_technique == "economy":
|
|
||||||
# use keyword table query
|
|
||||||
documents = RetrievalService.retrieve(
|
|
||||||
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
|
|
||||||
)
|
|
||||||
return str("\n".join([document.page_content for document in documents]))
|
|
||||||
else:
|
|
||||||
if self.top_k > 0:
|
|
||||||
# retrieval source
|
|
||||||
documents = RetrievalService.retrieve(
|
|
||||||
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
|
|
||||||
dataset_id=dataset.id,
|
|
||||||
query=query,
|
|
||||||
top_k=self.top_k,
|
|
||||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
|
||||||
if retrieval_model["score_threshold_enabled"]
|
|
||||||
else 0.0,
|
|
||||||
reranking_model=retrieval_model.get("reranking_model")
|
|
||||||
if retrieval_model["reranking_enable"]
|
|
||||||
else None,
|
|
||||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
|
||||||
weights=retrieval_model.get("weights"),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
documents = []
|
|
||||||
for hit_callback in self.hit_callbacks:
|
|
||||||
hit_callback.on_tool_end(documents)
|
|
||||||
document_score_list = {}
|
|
||||||
if dataset.indexing_technique != "economy":
|
|
||||||
for item in documents:
|
|
||||||
if item.metadata is not None and item.metadata.get("score"):
|
|
||||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
|
||||||
document_context_list = []
|
|
||||||
records = RetrievalService.format_retrieval_documents(documents)
|
|
||||||
if records:
|
|
||||||
for record in records:
|
|
||||||
segment = record.segment
|
|
||||||
if segment.answer:
|
|
||||||
document_context_list.append(
|
|
||||||
DocumentContext(
|
|
||||||
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
|
|
||||||
score=record.score,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
document_context_list.append(
|
|
||||||
DocumentContext(
|
|
||||||
content=segment.get_sign_content(),
|
|
||||||
score=record.score,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
retrieval_resource_list = []
|
|
||||||
if self.return_resource:
|
|
||||||
for record in records:
|
|
||||||
segment = record.segment
|
|
||||||
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
|
||||||
document = DatasetDocument.query.filter(
|
|
||||||
DatasetDocument.id == segment.document_id,
|
|
||||||
DatasetDocument.enabled == True,
|
|
||||||
DatasetDocument.archived == False,
|
|
||||||
).first()
|
|
||||||
if dataset and document:
|
|
||||||
source = {
|
|
||||||
"dataset_id": dataset.id,
|
|
||||||
"dataset_name": dataset.name,
|
|
||||||
"document_id": document.id, # type: ignore
|
|
||||||
"document_name": document.name, # type: ignore
|
|
||||||
"data_source_type": document.data_source_type, # type: ignore
|
|
||||||
"segment_id": segment.id,
|
|
||||||
"retriever_from": self.retriever_from,
|
|
||||||
"score": record.score or 0.0,
|
|
||||||
"doc_metadata": document.doc_metadata, # type: ignore
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.retriever_from == "dev":
|
|
||||||
source["hit_count"] = segment.hit_count
|
|
||||||
source["word_count"] = segment.word_count
|
|
||||||
source["segment_position"] = segment.position
|
|
||||||
source["index_node_hash"] = segment.index_node_hash
|
|
||||||
if segment.answer:
|
|
||||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
|
||||||
else:
|
|
||||||
source["content"] = segment.content
|
|
||||||
retrieval_resource_list.append(source)
|
|
||||||
|
|
||||||
if self.return_resource and retrieval_resource_list:
|
|
||||||
retrieval_resource_list = sorted(
|
|
||||||
retrieval_resource_list,
|
|
||||||
key=lambda x: x.get("score") or 0.0,
|
|
||||||
reverse=True,
|
|
||||||
)
|
|
||||||
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
|
|
||||||
item["position"] = position # type: ignore
|
|
||||||
for hit_callback in self.hit_callbacks:
|
|
||||||
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
|
||||||
if document_context_list:
|
|
||||||
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
|
|
||||||
return str("\n".join([document_context.content for document_context in document_context_list]))
|
|
||||||
return ""
|
|
@@ -1,134 +0,0 @@
from collections.abc import Generator
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
|
||||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
|
||||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
|
||||||
from core.tools.__base.tool import Tool
|
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
|
||||||
from core.tools.entities.common_entities import I18nObject
|
|
||||||
from core.tools.entities.tool_entities import (
|
|
||||||
ToolDescription,
|
|
||||||
ToolEntity,
|
|
||||||
ToolIdentity,
|
|
||||||
ToolInvokeMessage,
|
|
||||||
ToolParameter,
|
|
||||||
ToolProviderType,
|
|
||||||
)
|
|
||||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetRetrieverTool(Tool):
|
|
||||||
retrieval_tool: DatasetRetrieverBaseTool
|
|
||||||
|
|
||||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None:
|
|
||||||
super().__init__(entity, runtime)
|
|
||||||
self.retrieval_tool = retrieval_tool
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_dataset_tools(
|
|
||||||
tenant_id: str,
|
|
||||||
dataset_ids: list[str],
|
|
||||||
retrieve_config: DatasetRetrieveConfigEntity | None,
|
|
||||||
return_resource: bool,
|
|
||||||
invoke_from: InvokeFrom,
|
|
||||||
hit_callback: DatasetIndexToolCallbackHandler,
|
|
||||||
) -> list["DatasetRetrieverTool"]:
|
|
||||||
"""
|
|
||||||
get dataset tool
|
|
||||||
"""
|
|
||||||
# check if retrieve_config is valid
|
|
||||||
if dataset_ids is None or len(dataset_ids) == 0:
|
|
||||||
return []
|
|
||||||
if retrieve_config is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
feature = DatasetRetrieval()
|
|
||||||
|
|
||||||
# save original retrieve strategy, and set retrieve strategy to SINGLE
|
|
||||||
# Agent only support SINGLE mode
|
|
||||||
original_retriever_mode = retrieve_config.retrieve_strategy
|
|
||||||
retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
|
|
||||||
retrieval_tools = feature.to_dataset_retriever_tool(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
dataset_ids=dataset_ids,
|
|
||||||
retrieve_config=retrieve_config,
|
|
||||||
return_resource=return_resource,
|
|
||||||
invoke_from=invoke_from,
|
|
||||||
hit_callback=hit_callback,
|
|
||||||
)
|
|
||||||
if retrieval_tools is None or len(retrieval_tools) == 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# restore retrieve strategy
|
|
||||||
retrieve_config.retrieve_strategy = original_retriever_mode
|
|
||||||
|
|
||||||
# convert retrieval tools to Tools
|
|
||||||
tools = []
|
|
||||||
for retrieval_tool in retrieval_tools:
|
|
||||||
tool = DatasetRetrieverTool(
|
|
||||||
retrieval_tool=retrieval_tool,
|
|
||||||
entity=ToolEntity(
|
|
||||||
identity=ToolIdentity(
|
|
||||||
provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="")
|
|
||||||
),
|
|
||||||
parameters=[],
|
|
||||||
description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description),
|
|
||||||
),
|
|
||||||
runtime=ToolRuntime(tenant_id=tenant_id),
|
|
||||||
)
|
|
||||||
|
|
||||||
tools.append(tool)
|
|
||||||
|
|
||||||
return tools
|
|
||||||
|
|
||||||
def get_runtime_parameters(
|
|
||||||
self,
|
|
||||||
conversation_id: Optional[str] = None,
|
|
||||||
app_id: Optional[str] = None,
|
|
||||||
message_id: Optional[str] = None,
|
|
||||||
) -> list[ToolParameter]:
|
|
||||||
return [
|
|
||||||
ToolParameter(
|
|
||||||
name="query",
|
|
||||||
label=I18nObject(en_US="", zh_Hans=""),
|
|
||||||
human_description=I18nObject(en_US="", zh_Hans=""),
|
|
||||||
type=ToolParameter.ToolParameterType.STRING,
|
|
||||||
form=ToolParameter.ToolParameterForm.LLM,
|
|
||||||
llm_description="Query for the dataset to be used to retrieve the dataset.",
|
|
||||||
required=True,
|
|
||||||
default="",
|
|
||||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
def tool_provider_type(self) -> ToolProviderType:
|
|
||||||
return ToolProviderType.DATASET_RETRIEVAL
|
|
||||||
|
|
||||||
def _invoke(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
tool_parameters: dict[str, Any],
|
|
||||||
conversation_id: Optional[str] = None,
|
|
||||||
app_id: Optional[str] = None,
|
|
||||||
message_id: Optional[str] = None,
|
|
||||||
) -> Generator[ToolInvokeMessage, None, None]:
|
|
||||||
"""
|
|
||||||
invoke dataset retriever tool
|
|
||||||
"""
|
|
||||||
query = tool_parameters.get("query")
|
|
||||||
if not query:
|
|
||||||
yield self.create_text_message(text="please input query")
|
|
||||||
else:
|
|
||||||
# invoke dataset retriever tool
|
|
||||||
result = self.retrieval_tool._run(query=query)
|
|
||||||
yield self.create_text_message(text=result)
|
|
||||||
|
|
||||||
def validate_credentials(
|
|
||||||
self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
|
|
||||||
) -> str | None:
|
|
||||||
"""
|
|
||||||
validate the credentials for dataset retriever tool
|
|
||||||
"""
|
|
||||||
pass
|
|
@@ -1,169 +0,0 @@
"""
|
|
||||||
For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc.
|
|
||||||
|
|
||||||
Therefore, a model manager is needed to list/invoke/validate models.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Optional, cast
|
|
||||||
|
|
||||||
from core.model_manager import ModelManager
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult
|
|
||||||
from core.model_runtime.entities.message_entities import PromptMessage
|
|
||||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
|
||||||
from core.model_runtime.errors.invoke import (
|
|
||||||
InvokeAuthorizationError,
|
|
||||||
InvokeBadRequestError,
|
|
||||||
InvokeConnectionError,
|
|
||||||
InvokeRateLimitError,
|
|
||||||
InvokeServerUnavailableError,
|
|
||||||
)
|
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.tools import ToolModelInvoke
|
|
||||||
|
|
||||||
|
|
||||||
class InvokeModelError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ModelInvocationUtils:
|
|
||||||
@staticmethod
|
|
||||||
def get_max_llm_context_tokens(
|
|
||||||
tenant_id: str,
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
get max llm context tokens of the model
|
|
||||||
"""
|
|
||||||
model_manager = ModelManager()
|
|
||||||
model_instance = model_manager.get_default_model_instance(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
model_type=ModelType.LLM,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not model_instance:
|
|
||||||
raise InvokeModelError("Model not found")
|
|
||||||
|
|
||||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
|
||||||
schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
|
||||||
|
|
||||||
if not schema:
|
|
||||||
raise InvokeModelError("No model schema found")
|
|
||||||
|
|
||||||
max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
|
|
||||||
if max_tokens is None:
|
|
||||||
return 2048
|
|
||||||
|
|
||||||
return max_tokens
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
|
|
||||||
"""
|
|
||||||
calculate tokens from prompt messages and model parameters
|
|
||||||
"""
|
|
||||||
|
|
||||||
# get model instance
|
|
||||||
model_manager = ModelManager()
|
|
||||||
model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM)
|
|
||||||
|
|
||||||
if not model_instance:
|
|
||||||
raise InvokeModelError("Model not found")
|
|
||||||
|
|
||||||
# get tokens
|
|
||||||
tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
|
||||||
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def invoke(
|
|
||||||
user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
|
|
||||||
) -> LLMResult:
|
|
||||||
"""
|
|
||||||
invoke model with parameters in user's own context
|
|
||||||
|
|
||||||
:param user_id: user id
|
|
||||||
:param tenant_id: tenant id, the tenant id of the creator of the tool
|
|
||||||
:param tool_type: tool type
|
|
||||||
:param tool_name: tool name
|
|
||||||
:param prompt_messages: prompt messages
|
|
||||||
:return: AssistantPromptMessage
|
|
||||||
"""
|
|
||||||
|
|
||||||
# get model manager
|
|
||||||
model_manager = ModelManager()
|
|
||||||
# get model instance
|
|
||||||
model_instance = model_manager.get_default_model_instance(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
model_type=ModelType.LLM,
|
|
||||||
)
|
|
||||||
|
|
||||||
# get prompt tokens
|
|
||||||
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
|
||||||
|
|
||||||
model_parameters = {
|
|
||||||
"temperature": 0.8,
|
|
||||||
"top_p": 0.8,
|
|
||||||
}
|
|
||||||
|
|
||||||
# create tool model invoke
|
|
||||||
tool_model_invoke = ToolModelInvoke(
|
|
||||||
user_id=user_id,
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
provider=model_instance.provider,
|
|
||||||
tool_type=tool_type,
|
|
||||||
tool_name=tool_name,
|
|
||||||
model_parameters=json.dumps(model_parameters),
|
|
||||||
prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
|
|
||||||
model_response="",
|
|
||||||
prompt_tokens=prompt_tokens,
|
|
||||||
answer_tokens=0,
|
|
||||||
answer_unit_price=0,
|
|
||||||
answer_price_unit=0,
|
|
||||||
provider_response_latency=0,
|
|
||||||
total_price=0,
|
|
||||||
currency="USD",
|
|
||||||
)
|
|
||||||
|
|
||||||
db.session.add(tool_model_invoke)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
try:
|
|
||||||
response: LLMResult = cast(
|
|
||||||
LLMResult,
|
|
||||||
model_instance.invoke_llm(
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
model_parameters=model_parameters,
|
|
||||||
tools=[],
|
|
||||||
stop=[],
|
|
||||||
stream=False,
|
|
||||||
user=user_id,
|
|
||||||
callbacks=[],
|
|
||||||
),
|
|
||||||
)
|
|
||||||
except InvokeRateLimitError as e:
|
|
||||||
raise InvokeModelError(f"Invoke rate limit error: {e}")
|
|
||||||
except InvokeBadRequestError as e:
|
|
||||||
raise InvokeModelError(f"Invoke bad request error: {e}")
|
|
||||||
except InvokeConnectionError as e:
|
|
||||||
raise InvokeModelError(f"Invoke connection error: {e}")
|
|
||||||
except InvokeAuthorizationError as e:
|
|
||||||
raise InvokeModelError("Invoke authorization error")
|
|
||||||
except InvokeServerUnavailableError as e:
|
|
||||||
raise InvokeModelError(f"Invoke server unavailable error: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
raise InvokeModelError(f"Invoke error: {e}")
|
|
||||||
|
|
||||||
# update tool model invoke
|
|
||||||
tool_model_invoke.model_response = response.message.content
|
|
||||||
if response.usage:
|
|
||||||
tool_model_invoke.answer_tokens = response.usage.completion_tokens
|
|
||||||
tool_model_invoke.answer_unit_price = response.usage.completion_unit_price
|
|
||||||
tool_model_invoke.answer_price_unit = response.usage.completion_price_unit
|
|
||||||
tool_model_invoke.provider_response_latency = response.usage.latency
|
|
||||||
tool_model_invoke.total_price = response.usage.total_price
|
|
||||||
tool_model_invoke.currency = response.usage.currency
|
|
||||||
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
return response
|
|
@@ -1,17 +0,0 @@
import re


def get_image_upload_file_ids(content):
    pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)"
    matches = re.findall(pattern, content)
    image_upload_file_ids = []
    for match in matches:
        if match[1] == "file-preview":
            content_pattern = r"files/([^/]+)/file-preview"
        else:
            content_pattern = r"files/([^/]+)/image-preview"
        content_match = re.search(content_pattern, match[0])
        if content_match:
            image_upload_file_id = content_match.group(1)
            image_upload_file_ids.append(image_upload_file_id)
    return image_upload_file_ids
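A self-contained check of what the removed helper extracted: the upload-file id embedded in ![image](…/files/<id>/file-preview) or image-preview links. The sample content is made up:

    import re

    content = "see ![image](http://localhost/files/abc-123/image-preview) for the chart"
    pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)"
    ids = []
    for url, kind in re.findall(pattern, content):
        inner = re.search(rf"files/([^/]+)/{kind}", url)
        if inner:
            ids.append(inner.group(1))
    print(ids)  # ['abc-123']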
@@ -1,375 +0,0 @@
import hashlib
|
|
||||||
import json
|
|
||||||
import mimetypes
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import site
|
|
||||||
import subprocess
|
|
||||||
import tempfile
|
|
||||||
import unicodedata
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Literal, Optional, cast
|
|
||||||
from urllib.parse import unquote
|
|
||||||
|
|
||||||
import chardet
|
|
||||||
import cloudscraper # type: ignore
|
|
||||||
from bs4 import BeautifulSoup, CData, Comment, NavigableString # type: ignore
|
|
||||||
from regex import regex # type: ignore
|
|
||||||
|
|
||||||
from core.helper import ssrf_proxy
|
|
||||||
from core.rag.extractor import extract_processor
|
|
||||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
|
||||||
|
|
||||||
FULL_TEMPLATE = """
|
|
||||||
TITLE: {title}
|
|
||||||
AUTHORS: {authors}
|
|
||||||
PUBLISH DATE: {publish_date}
|
|
||||||
TOP_IMAGE_URL: {top_image}
|
|
||||||
TEXT:
|
|
||||||
|
|
||||||
{text}
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def page_result(text: str, cursor: int, max_length: int) -> str:
|
|
||||||
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
|
|
||||||
return text[cursor : cursor + max_length]
|
|
||||||
|
|
||||||
|
|
||||||
def get_url(url: str, user_agent: Optional[str] = None) -> str:
|
|
||||||
"""Fetch URL and return the contents as a string."""
|
|
||||||
headers = {
|
|
||||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
|
|
||||||
" Chrome/91.0.4472.124 Safari/537.36"
|
|
||||||
}
|
|
||||||
if user_agent:
|
|
||||||
headers["User-Agent"] = user_agent
|
|
||||||
|
|
||||||
main_content_type = None
|
|
||||||
supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
|
|
||||||
response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10))
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
# check content-type
|
|
||||||
content_type = response.headers.get("Content-Type")
|
|
||||||
if content_type:
|
|
||||||
main_content_type = response.headers.get("Content-Type").split(";")[0].strip()
|
|
||||||
else:
|
|
||||||
content_disposition = response.headers.get("Content-Disposition", "")
|
|
||||||
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
|
|
||||||
if filename_match:
|
|
||||||
filename = unquote(filename_match.group(1))
|
|
||||||
extension = re.search(r"\.(\w+)$", filename)
|
|
||||||
if extension:
|
|
||||||
main_content_type = mimetypes.guess_type(filename)[0]
|
|
||||||
|
|
||||||
if main_content_type not in supported_content_types:
|
|
||||||
return "Unsupported content-type [{}] of URL.".format(main_content_type)
|
|
||||||
|
|
||||||
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
|
|
||||||
return cast(str, ExtractProcessor.load_from_url(url, return_text=True))
|
|
||||||
|
|
||||||
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
|
|
||||||
elif response.status_code == 403:
|
|
||||||
scraper = cloudscraper.create_scraper()
|
|
||||||
scraper.perform_request = ssrf_proxy.make_request
|
|
||||||
response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
|
|
||||||
|
|
||||||
if response.status_code != 200:
|
|
||||||
return "URL returned status code {}.".format(response.status_code)
|
|
||||||
|
|
||||||
# Detect encoding using chardet
|
|
||||||
detected_encoding = chardet.detect(response.content)
|
|
||||||
encoding = detected_encoding["encoding"]
|
|
||||||
if encoding:
|
|
||||||
try:
|
|
||||||
content = response.content.decode(encoding)
|
|
||||||
except (UnicodeDecodeError, TypeError):
|
|
||||||
content = response.text
|
|
||||||
else:
|
|
||||||
content = response.text
|
|
||||||
|
|
||||||
a = extract_using_readabilipy(content)
|
|
||||||
|
|
||||||
if not a["plain_text"] or not a["plain_text"].strip():
|
|
||||||
return ""
|
|
||||||
|
|
||||||
res = FULL_TEMPLATE.format(
|
|
||||||
title=a["title"],
|
|
||||||
authors=a["byline"],
|
|
||||||
publish_date=a["date"],
|
|
||||||
top_image="",
|
|
||||||
text=a["plain_text"] or "",
|
|
||||||
)
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def extract_using_readabilipy(html):
|
|
||||||
with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html:
|
|
||||||
f_html.write(html)
|
|
||||||
f_html.close()
|
|
||||||
html_path = f_html.name
|
|
||||||
|
|
||||||
# Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
|
|
||||||
article_json_path = html_path + ".json"
|
|
||||||
jsdir = os.path.join(find_module_path("readabilipy"), "javascript")
|
|
||||||
with chdir(jsdir):
|
|
||||||
subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])
|
|
||||||
|
|
||||||
# Read output of call to Readability.parse() from JSON file and return as Python dictionary
|
|
||||||
input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8"))
|
|
||||||
|
|
||||||
# Deleting files after processing
|
|
||||||
os.unlink(article_json_path)
|
|
||||||
os.unlink(html_path)
|
|
||||||
|
|
||||||
article_json: dict[str, Any] = {
|
|
||||||
"title": None,
|
|
||||||
"byline": None,
|
|
||||||
"date": None,
|
|
||||||
"content": None,
|
|
||||||
"plain_content": None,
|
|
||||||
"plain_text": None,
|
|
||||||
}
|
|
||||||
# Populate article fields from readability fields where present
|
|
||||||
if input_json:
|
|
||||||
if input_json.get("title"):
|
|
||||||
article_json["title"] = input_json["title"]
|
|
||||||
if input_json.get("byline"):
|
|
||||||
article_json["byline"] = input_json["byline"]
|
|
||||||
if input_json.get("date"):
|
|
||||||
article_json["date"] = input_json["date"]
|
|
||||||
if input_json.get("content"):
|
|
||||||
article_json["content"] = input_json["content"]
|
|
||||||
article_json["plain_content"] = plain_content(article_json["content"], False, False)
|
|
||||||
article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
|
|
||||||
if input_json.get("textContent"):
|
|
||||||
article_json["plain_text"] = input_json["textContent"]
|
|
||||||
article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"])
|
|
||||||
|
|
||||||
return article_json
|
|
||||||
|
|
||||||
|
|
||||||
def find_module_path(module_name):
|
|
||||||
for package_path in site.getsitepackages():
|
|
||||||
potential_path = os.path.join(package_path, module_name)
|
|
||||||
if os.path.exists(potential_path):
|
|
||||||
return potential_path
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def chdir(path):
|
|
||||||
"""Change directory in context and return to original on exit"""
|
|
||||||
# From https://stackoverflow.com/a/37996581, couldn't find a built-in
|
|
||||||
original_path = os.getcwd()
|
|
||||||
os.chdir(path)
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
os.chdir(original_path)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_text_blocks_as_plain_text(paragraph_html):
|
|
||||||
# Load article as DOM
|
|
||||||
soup = BeautifulSoup(paragraph_html, "html.parser")
|
|
||||||
# Select all lists
|
|
||||||
list_elements = soup.find_all(["ul", "ol"])
|
|
||||||
# Prefix text in all list items with "* " and make lists paragraphs
|
|
||||||
for list_element in list_elements:
|
|
||||||
plain_items = "".join(
|
|
||||||
list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")]))
|
|
||||||
)
|
|
||||||
list_element.string = plain_items
|
|
||||||
list_element.name = "p"
|
|
||||||
# Select all text blocks
|
|
||||||
text_blocks = [s.parent for s in soup.find_all(string=True)]
|
|
||||||
text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
|
|
||||||
# Drop empty paragraphs
|
|
||||||
text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
|
|
||||||
return text_blocks
|
|
||||||
|
|
||||||
|
|
||||||
def plain_text_leaf_node(element):
|
|
||||||
# Extract all text, stripped of any child HTML elements and normalize it
|
|
||||||
plain_text = normalize_text(element.get_text())
|
|
||||||
if plain_text != "" and element.name == "li":
|
|
||||||
plain_text = "* {}, ".format(plain_text)
|
|
||||||
if plain_text == "":
|
|
||||||
plain_text = None
|
|
||||||
if "data-node-index" in element.attrs:
|
|
||||||
plain = {"node_index": element["data-node-index"], "text": plain_text}
|
|
||||||
else:
|
|
||||||
plain = {"text": plain_text}
|
|
||||||
return plain
|
|
||||||
|
|
||||||
|
|
||||||
def plain_content(readability_content, content_digests, node_indexes):
|
|
||||||
# Load article as DOM
|
|
||||||
soup = BeautifulSoup(readability_content, "html.parser")
|
|
||||||
# Make all elements plain
|
|
||||||
elements = plain_elements(soup.contents, content_digests, node_indexes)
|
|
||||||
if node_indexes:
|
|
||||||
# Add node index attributes to nodes
|
|
||||||
elements = [add_node_indexes(element) for element in elements]
|
|
||||||
# Replace article contents with plain elements
|
|
||||||
soup.contents = elements
|
|
||||||
return str(soup)
|
|
||||||
|
|
||||||
|
|
||||||
def plain_elements(elements, content_digests, node_indexes):
|
|
||||||
# Get plain content versions of all elements
|
|
||||||
elements = [plain_element(element, content_digests, node_indexes) for element in elements]
|
|
||||||
if content_digests:
|
|
||||||
# Add content digest attribute to nodes
|
|
||||||
elements = [add_content_digest(element) for element in elements]
|
|
||||||
return elements
|
|
||||||
|
|
||||||
|
|
||||||
def plain_element(element, content_digests, node_indexes):
|
|
||||||
# For lists, we make each item plain text
|
|
||||||
if is_leaf(element):
|
|
||||||
# For leaf node elements, extract the text content, discarding any HTML tags
|
|
||||||
# 1. Get element contents as text
|
|
||||||
plain_text = element.get_text()
|
|
||||||
# 2. Normalize the extracted text string to a canonical representation
|
|
||||||
plain_text = normalize_text(plain_text)
|
|
||||||
# 3. Update element content to be plain text
|
|
||||||
element.string = plain_text
|
|
||||||
elif is_text(element):
|
|
||||||
if is_non_printing(element):
|
|
||||||
# The simplified HTML may have come from Readability.js so might
|
|
||||||
# have non-printing text (e.g. Comment or CData). In this case, we
|
|
||||||
# keep the structure, but ensure that the string is empty.
|
|
||||||
element = type(element)("")
|
|
||||||
else:
|
|
||||||
plain_text = element.string
|
|
||||||
plain_text = normalize_text(plain_text)
|
|
||||||
element = type(element)(plain_text)
|
|
||||||
else:
|
|
||||||
# If not a leaf node or leaf type call recursively on child nodes, replacing
|
|
||||||
element.contents = plain_elements(element.contents, content_digests, node_indexes)
|
|
||||||
return element
|
|
||||||
|
|
||||||
|
|
||||||
def add_node_indexes(element, node_index="0"):
|
|
||||||
# Can't add attributes to string types
|
|
||||||
if is_text(element):
|
|
||||||
return element
|
|
||||||
# Add index to current element
|
|
||||||
element["data-node-index"] = node_index
|
|
||||||
# Add index to child elements
|
|
||||||
for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1):
|
|
||||||
# Can't add attributes to leaf string types
|
|
||||||
child_index = "{stem}.{local}".format(stem=node_index, local=local_idx)
|
|
||||||
add_node_indexes(child, node_index=child_index)
|
|
||||||
return element
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_text(text):
|
|
||||||
"""Normalize unicode and whitespace."""
|
|
||||||
# Normalize unicode first to try and standardize whitespace characters as much as possible before normalizing them
|
|
||||||
text = strip_control_characters(text)
|
|
||||||
text = normalize_unicode(text)
|
|
||||||
text = normalize_whitespace(text)
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def strip_control_characters(text):
|
|
||||||
"""Strip out unicode control characters which might break the parsing."""
|
|
||||||
# Unicode control characters
|
|
||||||
# [Cc]: Other, Control [includes new lines]
|
|
||||||
# [Cf]: Other, Format
|
|
||||||
# [Cn]: Other, Not Assigned
|
|
||||||
# [Co]: Other, Private Use
|
|
||||||
# [Cs]: Other, Surrogate
|
|
||||||
control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"}
|
|
||||||
retained_chars = ["\t", "\n", "\r", "\f"]
|
|
||||||
|
|
||||||
# Remove non-printing control characters
|
|
||||||
return "".join(
|
|
||||||
[
|
|
||||||
"" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char
|
|
||||||
for char in text
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_unicode(text):
|
|
||||||
"""Normalize unicode such that things that are visually equivalent map to the same unicode string where possible."""
|
|
||||||
normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC"
|
|
||||||
text = unicodedata.normalize(normal_form, text)
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_whitespace(text):
|
|
||||||
"""Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
|
|
||||||
text = regex.sub(r"\s+", " ", text)
|
|
||||||
# Remove leading and trailing whitespace
|
|
||||||
text = text.strip()
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def is_leaf(element):
|
|
||||||
return element.name in {"p", "li"}
|
|
||||||
|
|
||||||
|
|
||||||
def is_text(element):
|
|
||||||
return isinstance(element, NavigableString)
|
|
||||||
|
|
||||||
|
|
||||||
def is_non_printing(element):
|
|
||||||
return any(isinstance(element, _e) for _e in [Comment, CData])
|
|
||||||
|
|
||||||
|
|
||||||
def add_content_digest(element):
|
|
||||||
if not is_text(element):
|
|
||||||
element["data-content-digest"] = content_digest(element)
|
|
||||||
return element
|
|
||||||
|
|
||||||
|
|
||||||
def content_digest(element):
|
|
||||||
digest: Any
|
|
||||||
if is_text(element):
|
|
||||||
# Hash
|
|
||||||
trimmed_string = element.string.strip()
|
|
||||||
if trimmed_string == "":
|
|
||||||
digest = ""
|
|
||||||
else:
|
|
||||||
digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest()
|
|
||||||
else:
|
|
||||||
contents = element.contents
|
|
||||||
num_contents = len(contents)
|
|
||||||
if num_contents == 0:
|
|
||||||
# No hash when no child elements exist
|
|
||||||
digest = ""
|
|
||||||
elif num_contents == 1:
|
|
||||||
# If single child, use digest of child
|
|
||||||
digest = content_digest(contents[0])
|
|
||||||
else:
|
|
||||||
# Build content digest from the "non-empty" digests of child nodes
|
|
||||||
digest = hashlib.sha256()
|
|
||||||
child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents]))
|
|
||||||
for child in child_digests:
|
|
||||||
digest.update(child.encode("utf-8"))
|
|
||||||
digest = digest.hexdigest()
|
|
||||||
return digest
|
|
||||||
|
|
||||||
|
|
||||||
def get_image_upload_file_ids(content):
|
|
||||||
pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)"
|
|
||||||
matches = re.findall(pattern, content)
|
|
||||||
image_upload_file_ids = []
|
|
||||||
for match in matches:
|
|
||||||
if match[1] == "file-preview":
|
|
||||||
content_pattern = r"files/([^/]+)/file-preview"
|
|
||||||
else:
|
|
||||||
content_pattern = r"files/([^/]+)/image-preview"
|
|
||||||
content_match = re.search(content_pattern, match[0])
|
|
||||||
if content_match:
|
|
||||||
image_upload_file_id = content_match.group(1)
|
|
||||||
image_upload_file_ids.append(image_upload_file_id)
|
|
||||||
return image_upload_file_ids
|
|
api/installed_plugins.jsonl (new file, 1 line)
@@ -0,0 +1 @@
+{"not_installed": [], "plugin_install_failed": []}