diff --git a/api/core/datasource/__base/tool.py b/api/core/datasource/__base/datasource.py similarity index 90% rename from api/core/datasource/__base/tool.py rename to api/core/datasource/__base/datasource.py index 35e16b5c8f..3a67b56e32 100644 --- a/api/core/datasource/__base/tool.py +++ b/api/core/datasource/__base/datasource.py @@ -6,31 +6,30 @@ from typing import TYPE_CHECKING, Any, Optional if TYPE_CHECKING: from models.model import File -from core.tools.__base.tool_runtime import ToolRuntime +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceProviderType from core.tools.entities.tool_entities import ( - ToolEntity, ToolInvokeMessage, ToolParameter, - ToolProviderType, ) -class Tool(ABC): +class Datasource(ABC): """ - The base class of a tool + The base class of a datasource """ - entity: ToolEntity - runtime: ToolRuntime + entity: DatasourceEntity + runtime: DatasourceRuntime - def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None: + def __init__(self, entity: DatasourceEntity, runtime: DatasourceRuntime) -> None: self.entity = entity self.runtime = runtime - def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool": + def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "Datasource": """ - fork a new tool with metadata - :return: the new tool + fork a new datasource with metadata + :return: the new datasource """ return self.__class__( entity=self.entity.model_copy(), @@ -38,9 +37,9 @@ class Tool(ABC): ) @abstractmethod - def tool_provider_type(self) -> ToolProviderType: + def datasource_provider_type(self) -> DatasourceProviderType: """ - get the tool provider type + get the datasource provider type :return: the tool provider type """ diff --git a/api/core/datasource/__base/tool_runtime.py b/api/core/datasource/__base/datasource_runtime.py similarity index 66% rename from api/core/datasource/__base/tool_runtime.py rename to api/core/datasource/__base/datasource_runtime.py index c9e157cb77..51ff1fc6c1 100644 --- a/api/core/datasource/__base/tool_runtime.py +++ b/api/core/datasource/__base/datasource_runtime.py @@ -4,12 +4,13 @@ from openai import BaseModel from pydantic import Field from core.app.entities.app_invoke_entities import InvokeFrom +from core.datasource.entities.datasource_entities import DatasourceInvokeFrom from core.tools.entities.tool_entities import ToolInvokeFrom -class ToolRuntime(BaseModel): +class DatasourceRuntime(BaseModel): """ - Meta data of a tool call processing + Meta data of a datasource call processing """ tenant_id: str @@ -20,17 +21,17 @@ class ToolRuntime(BaseModel): runtime_parameters: dict[str, Any] = Field(default_factory=dict) -class FakeToolRuntime(ToolRuntime): +class FakeDatasourceRuntime(DatasourceRuntime): """ - Fake tool runtime for testing + Fake datasource runtime for testing """ def __init__(self): super().__init__( tenant_id="fake_tenant_id", - tool_id="fake_tool_id", + datasource_id="fake_datasource_id", invoke_from=InvokeFrom.DEBUGGER, - tool_invoke_from=ToolInvokeFrom.AGENT, + datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE, credentials={}, runtime_parameters={}, ) diff --git a/api/core/datasource/tool_engine.py b/api/core/datasource/datasource_engine.py similarity index 99% rename from api/core/datasource/tool_engine.py rename to api/core/datasource/datasource_engine.py index ad0c62537c..423f78a787 100644 --- a/api/core/datasource/tool_engine.py +++ 
b/api/core/datasource/datasource_engine.py @@ -36,9 +36,9 @@ from models.enums import CreatedByRole from models.model import Message, MessageFile -class ToolEngine: +class DatasourceEngine: """ - Tool runtime engine take care of the tool executions. + Datasource runtime engine takes care of datasource executions. """ @staticmethod diff --git a/api/core/datasource/datasource_tool/tool.py b/api/core/datasource/datasource_tool/tool.py index b69b2368a4..1c8572c2c5 100644 --- a/api/core/datasource/datasource_tool/tool.py +++ b/api/core/datasource/datasource_tool/tool.py @@ -1,7 +1,9 @@ from collections.abc import Generator from typing import Any, Optional +from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceParameter, DatasourceProviderType +from core.plugin.manager.datasource import PluginDatasourceManager from core.plugin.manager.tool import PluginToolManager from core.plugin.utils.converter import convert_parameters_to_plugin_format from core.tools.__base.tool import Tool @@ -9,7 +11,7 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType -class DatasourceTool(Tool): +class DatasourcePlugin(Datasource): tenant_id: str icon: str plugin_unique_identifier: str @@ -31,53 +33,45 @@ class DatasourceTool(Tool): self, user_id: str, datasource_parameters: dict[str, Any], - conversation_id: Optional[str] = None, rag_pipeline_id: Optional[str] = None, - message_id: Optional[str] = None, ) -> Generator[ToolInvokeMessage, None, None]: - manager = PluginToolManager() + manager = PluginDatasourceManager() datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) yield from manager.invoke_first_step( tenant_id=self.tenant_id, user_id=user_id, - tool_provider=self.entity.identity.provider, - tool_name=self.entity.identity.name, + datasource_provider=self.entity.identity.provider, + datasource_name=self.entity.identity.name, credentials=self.runtime.credentials, - tool_parameters=tool_parameters, - conversation_id=conversation_id, - app_id=app_id, - message_id=message_id, + datasource_parameters=datasource_parameters, + rag_pipeline_id=rag_pipeline_id, ) def _invoke_second_step( self, user_id: str, datasource_parameters: dict[str, Any], - conversation_id: Optional[str] = None, rag_pipeline_id: Optional[str] = None, - message_id: Optional[str] = None, ) -> Generator[ToolInvokeMessage, None, None]: - manager = PluginToolManager() + manager = PluginDatasourceManager() datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) yield from manager.invoke( tenant_id=self.tenant_id, user_id=user_id, - tool_provider=self.entity.identity.provider, - tool_name=self.entity.identity.name, + datasource_provider=self.entity.identity.provider, + datasource_name=self.entity.identity.name, credentials=self.runtime.credentials, - tool_parameters=tool_parameters, - conversation_id=conversation_id, - app_id=app_id, - message_id=message_id, + datasource_parameters=datasource_parameters, + rag_pipeline_id=rag_pipeline_id, ) - def fork_tool_runtime(self, runtime: ToolRuntime) -> "DatasourceTool": - return DatasourceTool( + def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": + return DatasourcePlugin( entity=self.entity, runtime=runtime, tenant_id=self.tenant_id, @@ -87,9 +81,7 @@ class DatasourceTool(Tool): def get_runtime_parameters
self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + rag_pipeline_id: Optional[str] = None, ) -> list[DatasourceParameter]: """ get the runtime parameters @@ -100,16 +92,14 @@ class DatasourceTool(Tool): if self.runtime_parameters is not None: return self.runtime_parameters - manager = PluginToolManager() + manager = PluginDatasourceManager() self.runtime_parameters = manager.get_runtime_parameters( tenant_id=self.tenant_id, user_id="", provider=self.entity.identity.provider, - tool=self.entity.identity.name, + datasource=self.entity.identity.name, credentials=self.runtime.credentials, - conversation_id=conversation_id, - app_id=app_id, - message_id=message_id, + rag_pipeline_id=rag_pipeline_id, ) return self.runtime_parameters diff --git a/api/core/datasource/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/datasource/utils/dataset_retriever/dataset_multi_retriever_tool.py deleted file mode 100644 index 032274b87e..0000000000 --- a/api/core/datasource/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ /dev/null @@ -1,199 +0,0 @@ -import threading -from typing import Any - -from flask import Flask, current_app -from pydantic import BaseModel, Field - -from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType -from core.rag.datasource.retrieval_service import RetrievalService -from core.rag.models.document import Document as RagDocument -from core.rag.rerank.rerank_model import RerankModelRunner -from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool -from extensions.ext_database import db -from models.dataset import Dataset, Document, DocumentSegment - -default_retrieval_model: dict[str, Any] = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, - "reranking_enable": False, - "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, - "top_k": 2, - "score_threshold_enabled": False, -} - - -class DatasetMultiRetrieverToolInput(BaseModel): - query: str = Field(..., description="dataset multi retriever and rerank") - - -class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): - """Tool for querying multi dataset.""" - - name: str = "dataset_" - args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput - description: str = "dataset multi retriever and rerank. 
" - dataset_ids: list[str] - reranking_provider_name: str - reranking_model_name: str - - @classmethod - def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs): - return cls( - name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs - ) - - def _run(self, query: str) -> str: - threads = [] - all_documents: list[RagDocument] = [] - for dataset_id in self.dataset_ids: - retrieval_thread = threading.Thread( - target=self._retriever, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "dataset_id": dataset_id, - "query": query, - "all_documents": all_documents, - "hit_callbacks": self.hit_callbacks, - }, - ) - threads.append(retrieval_thread) - retrieval_thread.start() - for thread in threads: - thread.join() - # do rerank for searched documents - model_manager = ModelManager() - rerank_model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, - provider=self.reranking_provider_name, - model_type=ModelType.RERANK, - model=self.reranking_model_name, - ) - - rerank_runner = RerankModelRunner(rerank_model_instance) - all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k) - - for hit_callback in self.hit_callbacks: - hit_callback.on_tool_end(all_documents) - - document_score_list = {} - for item in all_documents: - if item.metadata and item.metadata.get("score"): - document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - - document_context_list = [] - index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] - segments = DocumentSegment.query.filter( - DocumentSegment.dataset_id.in_(self.dataset_ids), - DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids), - ).all() - - if segments: - index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted( - segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) - ) - for segment in sorted_segments: - if segment.answer: - document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") - else: - document_context_list.append(segment.get_sign_content()) - if self.return_resource: - context_list = [] - resource_number = 1 - for segment in sorted_segments: - dataset = Dataset.query.filter_by(id=segment.dataset_id).first() - document = Document.query.filter( - Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() - if dataset and document: - source = { - "position": resource_number, - "dataset_id": dataset.id, - "dataset_name": dataset.name, - "document_id": document.id, - "document_name": document.name, - "data_source_type": document.data_source_type, - "segment_id": segment.id, - "retriever_from": self.retriever_from, - "score": document_score_list.get(segment.index_node_id, None), - "doc_metadata": document.doc_metadata, - } - - if self.retriever_from == "dev": - source["hit_count"] = segment.hit_count - source["word_count"] = segment.word_count - source["segment_position"] = segment.position - source["index_node_hash"] = segment.index_node_hash - if segment.answer: - source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" - else: - source["content"] = segment.content - context_list.append(source) - resource_number += 1 - - for hit_callback in 
self.hit_callbacks: - hit_callback.return_retriever_resource_info(context_list) - - return str("\n".join(document_context_list)) - return "" - - raise RuntimeError("not segments found") - - def _retriever( - self, - flask_app: Flask, - dataset_id: str, - query: str, - all_documents: list, - hit_callbacks: list[DatasetIndexToolCallbackHandler], - ): - with flask_app.app_context(): - dataset = ( - db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first() - ) - - if not dataset: - return [] - - for hit_callback in hit_callbacks: - hit_callback.on_query(query, dataset.id) - - # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model or default_retrieval_model - - if dataset.indexing_technique == "economy": - # use keyword table query - documents = RetrievalService.retrieve( - retrieval_method="keyword_search", - dataset_id=dataset.id, - query=query, - top_k=retrieval_model.get("top_k") or 2, - ) - if documents: - all_documents.extend(documents) - else: - if self.top_k > 0: - # retrieval source - documents = RetrievalService.retrieve( - retrieval_method=retrieval_model["search_method"], - dataset_id=dataset.id, - query=query, - top_k=retrieval_model.get("top_k") or 2, - score_threshold=retrieval_model.get("score_threshold", 0.0) - if retrieval_model["score_threshold_enabled"] - else 0.0, - reranking_model=retrieval_model.get("reranking_model", None) - if retrieval_model["reranking_enable"] - else None, - reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", - weights=retrieval_model.get("weights", None), - ) - - all_documents.extend(documents) diff --git a/api/core/datasource/utils/dataset_retriever/dataset_retriever_base_tool.py b/api/core/datasource/utils/dataset_retriever/dataset_retriever_base_tool.py deleted file mode 100644 index a4d2de3b1c..0000000000 --- a/api/core/datasource/utils/dataset_retriever/dataset_retriever_base_tool.py +++ /dev/null @@ -1,33 +0,0 @@ -from abc import abstractmethod -from typing import Any, Optional - -from msal_extensions.persistence import ABC # type: ignore -from pydantic import BaseModel, ConfigDict - -from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler - - -class DatasetRetrieverBaseTool(BaseModel, ABC): - """Tool for querying a Dataset.""" - - name: str = "dataset" - description: str = "use this to retrieve a dataset. " - tenant_id: str - top_k: int = 2 - score_threshold: Optional[float] = None - hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] - return_resource: bool - retriever_from: str - model_config = ConfigDict(arbitrary_types_allowed=True) - - @abstractmethod - def _run( - self, - *args: Any, - **kwargs: Any, - ) -> Any: - """Use the tool. 
- - Add run_manager: Optional[CallbackManagerForToolRun] = None - to child implementations to enable tracing, - """ diff --git a/api/core/datasource/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/datasource/utils/dataset_retriever/dataset_retriever_tool.py deleted file mode 100644 index 63260cfac3..0000000000 --- a/api/core/datasource/utils/dataset_retriever/dataset_retriever_tool.py +++ /dev/null @@ -1,202 +0,0 @@ -from typing import Any - -from pydantic import BaseModel, Field - -from core.rag.datasource.retrieval_service import RetrievalService -from core.rag.entities.context_entities import DocumentContext -from core.rag.models.document import Document as RetrievalDocument -from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool -from extensions.ext_database import db -from models.dataset import Dataset -from models.dataset import Document as DatasetDocument -from services.external_knowledge_service import ExternalDatasetService - -default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, - "reranking_enable": False, - "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, - "reranking_mode": "reranking_model", - "top_k": 2, - "score_threshold_enabled": False, -} - - -class DatasetRetrieverToolInput(BaseModel): - query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.") - - -class DatasetRetrieverTool(DatasetRetrieverBaseTool): - """Tool for querying a Dataset.""" - - name: str = "dataset" - args_schema: type[BaseModel] = DatasetRetrieverToolInput - description: str = "use this to retrieve a dataset. " - dataset_id: str - - @classmethod - def from_dataset(cls, dataset: Dataset, **kwargs): - description = dataset.description - if not description: - description = "useful for when you want to answer queries about the " + dataset.name - - description = description.replace("\n", "").replace("\r", "") - return cls( - name=f"dataset_{dataset.id.replace('-', '_')}", - tenant_id=dataset.tenant_id, - dataset_id=dataset.id, - description=description, - **kwargs, - ) - - def _run(self, query: str) -> str: - dataset = ( - db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first() - ) - - if not dataset: - return "" - for hit_callback in self.hit_callbacks: - hit_callback.on_query(query, dataset.id) - if dataset.provider == "external": - results = [] - external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( - tenant_id=dataset.tenant_id, - dataset_id=dataset.id, - query=query, - external_retrieval_parameters=dataset.retrieval_model, - ) - for external_document in external_documents: - document = RetrievalDocument( - page_content=external_document.get("content"), - metadata=external_document.get("metadata"), - provider="external", - ) - if document.metadata is not None: - document.metadata["score"] = external_document.get("score") - document.metadata["title"] = external_document.get("title") - document.metadata["dataset_id"] = dataset.id - document.metadata["dataset_name"] = dataset.name - results.append(document) - # deal with external documents - context_list = [] - for position, item in enumerate(results, start=1): - if item.metadata is not None: - source = { - "position": position, - "dataset_id": item.metadata.get("dataset_id"), - "dataset_name": item.metadata.get("dataset_name"), - "document_name": 
item.metadata.get("title"), - "data_source_type": "external", - "retriever_from": self.retriever_from, - "score": item.metadata.get("score"), - "title": item.metadata.get("title"), - "content": item.page_content, - } - context_list.append(source) - for hit_callback in self.hit_callbacks: - hit_callback.return_retriever_resource_info(context_list) - - return str("\n".join([item.page_content for item in results])) - else: - # get retrieval model , if the model is not setting , using default - retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model - if dataset.indexing_technique == "economy": - # use keyword table query - documents = RetrievalService.retrieve( - retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k - ) - return str("\n".join([document.page_content for document in documents])) - else: - if self.top_k > 0: - # retrieval source - documents = RetrievalService.retrieve( - retrieval_method=retrieval_model.get("search_method", "semantic_search"), - dataset_id=dataset.id, - query=query, - top_k=self.top_k, - score_threshold=retrieval_model.get("score_threshold", 0.0) - if retrieval_model["score_threshold_enabled"] - else 0.0, - reranking_model=retrieval_model.get("reranking_model") - if retrieval_model["reranking_enable"] - else None, - reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", - weights=retrieval_model.get("weights"), - ) - else: - documents = [] - for hit_callback in self.hit_callbacks: - hit_callback.on_tool_end(documents) - document_score_list = {} - if dataset.indexing_technique != "economy": - for item in documents: - if item.metadata is not None and item.metadata.get("score"): - document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - document_context_list = [] - records = RetrievalService.format_retrieval_documents(documents) - if records: - for record in records: - segment = record.segment - if segment.answer: - document_context_list.append( - DocumentContext( - content=f"question:{segment.get_sign_content()} answer:{segment.answer}", - score=record.score, - ) - ) - else: - document_context_list.append( - DocumentContext( - content=segment.get_sign_content(), - score=record.score, - ) - ) - retrieval_resource_list = [] - if self.return_resource: - for record in records: - segment = record.segment - dataset = Dataset.query.filter_by(id=segment.dataset_id).first() - document = DatasetDocument.query.filter( - DatasetDocument.id == segment.document_id, - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).first() - if dataset and document: - source = { - "dataset_id": dataset.id, - "dataset_name": dataset.name, - "document_id": document.id, # type: ignore - "document_name": document.name, # type: ignore - "data_source_type": document.data_source_type, # type: ignore - "segment_id": segment.id, - "retriever_from": self.retriever_from, - "score": record.score or 0.0, - "doc_metadata": document.doc_metadata, # type: ignore - } - - if self.retriever_from == "dev": - source["hit_count"] = segment.hit_count - source["word_count"] = segment.word_count - source["segment_position"] = segment.position - source["index_node_hash"] = segment.index_node_hash - if segment.answer: - source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" - else: - source["content"] = segment.content - retrieval_resource_list.append(source) - - if self.return_resource and retrieval_resource_list: - retrieval_resource_list = sorted( - retrieval_resource_list, - 
key=lambda x: x.get("score") or 0.0, - reverse=True, - ) - for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore - item["position"] = position # type: ignore - for hit_callback in self.hit_callbacks: - hit_callback.return_retriever_resource_info(retrieval_resource_list) - if document_context_list: - document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) - return str("\n".join([document_context.content for document_context in document_context_list])) - return "" diff --git a/api/core/datasource/utils/dataset_retriever_tool.py b/api/core/datasource/utils/dataset_retriever_tool.py deleted file mode 100644 index b73dec4ebc..0000000000 --- a/api/core/datasource/utils/dataset_retriever_tool.py +++ /dev/null @@ -1,134 +0,0 @@ -from collections.abc import Generator -from typing import Any, Optional - -from core.app.app_config.entities import DatasetRetrieveConfigEntity -from core.app.entities.app_invoke_entities import InvokeFrom -from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.tools.__base.tool import Tool -from core.tools.__base.tool_runtime import ToolRuntime -from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ( - ToolDescription, - ToolEntity, - ToolIdentity, - ToolInvokeMessage, - ToolParameter, - ToolProviderType, -) -from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool - - -class DatasetRetrieverTool(Tool): - retrieval_tool: DatasetRetrieverBaseTool - - def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None: - super().__init__(entity, runtime) - self.retrieval_tool = retrieval_tool - - @staticmethod - def get_dataset_tools( - tenant_id: str, - dataset_ids: list[str], - retrieve_config: DatasetRetrieveConfigEntity | None, - return_resource: bool, - invoke_from: InvokeFrom, - hit_callback: DatasetIndexToolCallbackHandler, - ) -> list["DatasetRetrieverTool"]: - """ - get dataset tool - """ - # check if retrieve_config is valid - if dataset_ids is None or len(dataset_ids) == 0: - return [] - if retrieve_config is None: - return [] - - feature = DatasetRetrieval() - - # save original retrieve strategy, and set retrieve strategy to SINGLE - # Agent only support SINGLE mode - original_retriever_mode = retrieve_config.retrieve_strategy - retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE - retrieval_tools = feature.to_dataset_retriever_tool( - tenant_id=tenant_id, - dataset_ids=dataset_ids, - retrieve_config=retrieve_config, - return_resource=return_resource, - invoke_from=invoke_from, - hit_callback=hit_callback, - ) - if retrieval_tools is None or len(retrieval_tools) == 0: - return [] - - # restore retrieve strategy - retrieve_config.retrieve_strategy = original_retriever_mode - - # convert retrieval tools to Tools - tools = [] - for retrieval_tool in retrieval_tools: - tool = DatasetRetrieverTool( - retrieval_tool=retrieval_tool, - entity=ToolEntity( - identity=ToolIdentity( - provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="") - ), - parameters=[], - description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description), - ), - runtime=ToolRuntime(tenant_id=tenant_id), - ) - - tools.append(tool) - - return tools - - def get_runtime_parameters( - 
self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - ) -> list[ToolParameter]: - return [ - ToolParameter( - name="query", - label=I18nObject(en_US="", zh_Hans=""), - human_description=I18nObject(en_US="", zh_Hans=""), - type=ToolParameter.ToolParameterType.STRING, - form=ToolParameter.ToolParameterForm.LLM, - llm_description="Query for the dataset to be used to retrieve the dataset.", - required=True, - default="", - placeholder=I18nObject(en_US="", zh_Hans=""), - ), - ] - - def tool_provider_type(self) -> ToolProviderType: - return ToolProviderType.DATASET_RETRIEVAL - - def _invoke( - self, - user_id: str, - tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - ) -> Generator[ToolInvokeMessage, None, None]: - """ - invoke dataset retriever tool - """ - query = tool_parameters.get("query") - if not query: - yield self.create_text_message(text="please input query") - else: - # invoke dataset retriever tool - result = self.retrieval_tool._run(query=query) - yield self.create_text_message(text=result) - - def validate_credentials( - self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False - ) -> str | None: - """ - validate the credentials for dataset retriever tool - """ - pass diff --git a/api/core/datasource/utils/model_invocation_utils.py b/api/core/datasource/utils/model_invocation_utils.py deleted file mode 100644 index 3f59b3f472..0000000000 --- a/api/core/datasource/utils/model_invocation_utils.py +++ /dev/null @@ -1,169 +0,0 @@ -""" -For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc. - -Therefore, a model manager is needed to list/invoke/validate models. 
-""" - -import json -from typing import Optional, cast - -from core.model_manager import ModelManager -from core.model_runtime.entities.llm_entities import LLMResult -from core.model_runtime.entities.message_entities import PromptMessage -from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from core.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.utils.encoders import jsonable_encoder -from extensions.ext_database import db -from models.tools import ToolModelInvoke - - -class InvokeModelError(Exception): - pass - - -class ModelInvocationUtils: - @staticmethod - def get_max_llm_context_tokens( - tenant_id: str, - ) -> int: - """ - get max llm context tokens of the model - """ - model_manager = ModelManager() - model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.LLM, - ) - - if not model_instance: - raise InvokeModelError("Model not found") - - llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) - schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) - - if not schema: - raise InvokeModelError("No model schema found") - - max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) - if max_tokens is None: - return 2048 - - return max_tokens - - @staticmethod - def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int: - """ - calculate tokens from prompt messages and model parameters - """ - - # get model instance - model_manager = ModelManager() - model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM) - - if not model_instance: - raise InvokeModelError("Model not found") - - # get tokens - tokens = model_instance.get_llm_num_tokens(prompt_messages) - - return tokens - - @staticmethod - def invoke( - user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage] - ) -> LLMResult: - """ - invoke model with parameters in user's own context - - :param user_id: user id - :param tenant_id: tenant id, the tenant id of the creator of the tool - :param tool_type: tool type - :param tool_name: tool name - :param prompt_messages: prompt messages - :return: AssistantPromptMessage - """ - - # get model manager - model_manager = ModelManager() - # get model instance - model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.LLM, - ) - - # get prompt tokens - prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) - - model_parameters = { - "temperature": 0.8, - "top_p": 0.8, - } - - # create tool model invoke - tool_model_invoke = ToolModelInvoke( - user_id=user_id, - tenant_id=tenant_id, - provider=model_instance.provider, - tool_type=tool_type, - tool_name=tool_name, - model_parameters=json.dumps(model_parameters), - prompt_messages=json.dumps(jsonable_encoder(prompt_messages)), - model_response="", - prompt_tokens=prompt_tokens, - answer_tokens=0, - answer_unit_price=0, - answer_price_unit=0, - provider_response_latency=0, - total_price=0, - currency="USD", - ) - - db.session.add(tool_model_invoke) - db.session.commit() - - try: - response: LLMResult = cast( - LLMResult, - model_instance.invoke_llm( - 
prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=[], - stop=[], - stream=False, - user=user_id, - callbacks=[], - ), - ) - except InvokeRateLimitError as e: - raise InvokeModelError(f"Invoke rate limit error: {e}") - except InvokeBadRequestError as e: - raise InvokeModelError(f"Invoke bad request error: {e}") - except InvokeConnectionError as e: - raise InvokeModelError(f"Invoke connection error: {e}") - except InvokeAuthorizationError as e: - raise InvokeModelError("Invoke authorization error") - except InvokeServerUnavailableError as e: - raise InvokeModelError(f"Invoke server unavailable error: {e}") - except Exception as e: - raise InvokeModelError(f"Invoke error: {e}") - - # update tool model invoke - tool_model_invoke.model_response = response.message.content - if response.usage: - tool_model_invoke.answer_tokens = response.usage.completion_tokens - tool_model_invoke.answer_unit_price = response.usage.completion_unit_price - tool_model_invoke.answer_price_unit = response.usage.completion_price_unit - tool_model_invoke.provider_response_latency = response.usage.latency - tool_model_invoke.total_price = response.usage.total_price - tool_model_invoke.currency = response.usage.currency - - db.session.commit() - - return response diff --git a/api/core/datasource/utils/rag_web_reader.py b/api/core/datasource/utils/rag_web_reader.py deleted file mode 100644 index 22c47fa814..0000000000 --- a/api/core/datasource/utils/rag_web_reader.py +++ /dev/null @@ -1,17 +0,0 @@ -import re - - -def get_image_upload_file_ids(content): - pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)" - matches = re.findall(pattern, content) - image_upload_file_ids = [] - for match in matches: - if match[1] == "file-preview": - content_pattern = r"files/([^/]+)/file-preview" - else: - content_pattern = r"files/([^/]+)/image-preview" - content_match = re.search(content_pattern, match[0]) - if content_match: - image_upload_file_id = content_match.group(1) - image_upload_file_ids.append(image_upload_file_id) - return image_upload_file_ids diff --git a/api/core/datasource/utils/web_reader_tool.py b/api/core/datasource/utils/web_reader_tool.py deleted file mode 100644 index d42fd99fce..0000000000 --- a/api/core/datasource/utils/web_reader_tool.py +++ /dev/null @@ -1,375 +0,0 @@ -import hashlib -import json -import mimetypes -import os -import re -import site -import subprocess -import tempfile -import unicodedata -from contextlib import contextmanager -from pathlib import Path -from typing import Any, Literal, Optional, cast -from urllib.parse import unquote - -import chardet -import cloudscraper # type: ignore -from bs4 import BeautifulSoup, CData, Comment, NavigableString # type: ignore -from regex import regex # type: ignore - -from core.helper import ssrf_proxy -from core.rag.extractor import extract_processor -from core.rag.extractor.extract_processor import ExtractProcessor - -FULL_TEMPLATE = """ -TITLE: {title} -AUTHORS: {authors} -PUBLISH DATE: {publish_date} -TOP_IMAGE_URL: {top_image} -TEXT: - -{text} -""" - - -def page_result(text: str, cursor: int, max_length: int) -> str: - """Page through `text` and return a substring of `max_length` characters starting from `cursor`.""" - return text[cursor : cursor + max_length] - - -def get_url(url: str, user_agent: Optional[str] = None) -> str: - """Fetch URL and return the contents as a string.""" - headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)" - " 
Chrome/91.0.4472.124 Safari/537.36" - } - if user_agent: - headers["User-Agent"] = user_agent - - main_content_type = None - supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] - response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10)) - - if response.status_code == 200: - # check content-type - content_type = response.headers.get("Content-Type") - if content_type: - main_content_type = response.headers.get("Content-Type").split(";")[0].strip() - else: - content_disposition = response.headers.get("Content-Disposition", "") - filename_match = re.search(r'filename="([^"]+)"', content_disposition) - if filename_match: - filename = unquote(filename_match.group(1)) - extension = re.search(r"\.(\w+)$", filename) - if extension: - main_content_type = mimetypes.guess_type(filename)[0] - - if main_content_type not in supported_content_types: - return "Unsupported content-type [{}] of URL.".format(main_content_type) - - if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: - return cast(str, ExtractProcessor.load_from_url(url, return_text=True)) - - response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) - elif response.status_code == 403: - scraper = cloudscraper.create_scraper() - scraper.perform_request = ssrf_proxy.make_request - response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) - - if response.status_code != 200: - return "URL returned status code {}.".format(response.status_code) - - # Detect encoding using chardet - detected_encoding = chardet.detect(response.content) - encoding = detected_encoding["encoding"] - if encoding: - try: - content = response.content.decode(encoding) - except (UnicodeDecodeError, TypeError): - content = response.text - else: - content = response.text - - a = extract_using_readabilipy(content) - - if not a["plain_text"] or not a["plain_text"].strip(): - return "" - - res = FULL_TEMPLATE.format( - title=a["title"], - authors=a["byline"], - publish_date=a["date"], - top_image="", - text=a["plain_text"] or "", - ) - - return res - - -def extract_using_readabilipy(html): - with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html: - f_html.write(html) - f_html.close() - html_path = f_html.name - - # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file - article_json_path = html_path + ".json" - jsdir = os.path.join(find_module_path("readabilipy"), "javascript") - with chdir(jsdir): - subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path]) - - # Read output of call to Readability.parse() from JSON file and return as Python dictionary - input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8")) - - # Deleting files after processing - os.unlink(article_json_path) - os.unlink(html_path) - - article_json: dict[str, Any] = { - "title": None, - "byline": None, - "date": None, - "content": None, - "plain_content": None, - "plain_text": None, - } - # Populate article fields from readability fields where present - if input_json: - if input_json.get("title"): - article_json["title"] = input_json["title"] - if input_json.get("byline"): - article_json["byline"] = input_json["byline"] - if input_json.get("date"): - article_json["date"] = input_json["date"] - if input_json.get("content"): - article_json["content"] = input_json["content"] - article_json["plain_content"] = plain_content(article_json["content"], False, 
False) - article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"]) - if input_json.get("textContent"): - article_json["plain_text"] = input_json["textContent"] - article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"]) - - return article_json - - -def find_module_path(module_name): - for package_path in site.getsitepackages(): - potential_path = os.path.join(package_path, module_name) - if os.path.exists(potential_path): - return potential_path - - return None - - -@contextmanager -def chdir(path): - """Change directory in context and return to original on exit""" - # From https://stackoverflow.com/a/37996581, couldn't find a built-in - original_path = os.getcwd() - os.chdir(path) - try: - yield - finally: - os.chdir(original_path) - - -def extract_text_blocks_as_plain_text(paragraph_html): - # Load article as DOM - soup = BeautifulSoup(paragraph_html, "html.parser") - # Select all lists - list_elements = soup.find_all(["ul", "ol"]) - # Prefix text in all list items with "* " and make lists paragraphs - for list_element in list_elements: - plain_items = "".join( - list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")])) - ) - list_element.string = plain_items - list_element.name = "p" - # Select all text blocks - text_blocks = [s.parent for s in soup.find_all(string=True)] - text_blocks = [plain_text_leaf_node(block) for block in text_blocks] - # Drop empty paragraphs - text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks)) - return text_blocks - - -def plain_text_leaf_node(element): - # Extract all text, stripped of any child HTML elements and normalize it - plain_text = normalize_text(element.get_text()) - if plain_text != "" and element.name == "li": - plain_text = "* {}, ".format(plain_text) - if plain_text == "": - plain_text = None - if "data-node-index" in element.attrs: - plain = {"node_index": element["data-node-index"], "text": plain_text} - else: - plain = {"text": plain_text} - return plain - - -def plain_content(readability_content, content_digests, node_indexes): - # Load article as DOM - soup = BeautifulSoup(readability_content, "html.parser") - # Make all elements plain - elements = plain_elements(soup.contents, content_digests, node_indexes) - if node_indexes: - # Add node index attributes to nodes - elements = [add_node_indexes(element) for element in elements] - # Replace article contents with plain elements - soup.contents = elements - return str(soup) - - -def plain_elements(elements, content_digests, node_indexes): - # Get plain content versions of all elements - elements = [plain_element(element, content_digests, node_indexes) for element in elements] - if content_digests: - # Add content digest attribute to nodes - elements = [add_content_digest(element) for element in elements] - return elements - - -def plain_element(element, content_digests, node_indexes): - # For lists, we make each item plain text - if is_leaf(element): - # For leaf node elements, extract the text content, discarding any HTML tags - # 1. Get element contents as text - plain_text = element.get_text() - # 2. Normalize the extracted text string to a canonical representation - plain_text = normalize_text(plain_text) - # 3. Update element content to be plain text - element.string = plain_text - elif is_text(element): - if is_non_printing(element): - # The simplified HTML may have come from Readability.js so might - # have non-printing text (e.g. Comment or CData). 
In this case, we - # keep the structure, but ensure that the string is empty. - element = type(element)("") - else: - plain_text = element.string - plain_text = normalize_text(plain_text) - element = type(element)(plain_text) - else: - # If not a leaf node or leaf type call recursively on child nodes, replacing - element.contents = plain_elements(element.contents, content_digests, node_indexes) - return element - - -def add_node_indexes(element, node_index="0"): - # Can't add attributes to string types - if is_text(element): - return element - # Add index to current element - element["data-node-index"] = node_index - # Add index to child elements - for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1): - # Can't add attributes to leaf string types - child_index = "{stem}.{local}".format(stem=node_index, local=local_idx) - add_node_indexes(child, node_index=child_index) - return element - - -def normalize_text(text): - """Normalize unicode and whitespace.""" - # Normalize unicode first to try and standardize whitespace characters as much as possible before normalizing them - text = strip_control_characters(text) - text = normalize_unicode(text) - text = normalize_whitespace(text) - return text - - -def strip_control_characters(text): - """Strip out unicode control characters which might break the parsing.""" - # Unicode control characters - # [Cc]: Other, Control [includes new lines] - # [Cf]: Other, Format - # [Cn]: Other, Not Assigned - # [Co]: Other, Private Use - # [Cs]: Other, Surrogate - control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"} - retained_chars = ["\t", "\n", "\r", "\f"] - - # Remove non-printing control characters - return "".join( - [ - "" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char - for char in text - ] - ) - - -def normalize_unicode(text): - """Normalize unicode such that things that are visually equivalent map to the same unicode string where possible.""" - normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC" - text = unicodedata.normalize(normal_form, text) - return text - - -def normalize_whitespace(text): - """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed.""" - text = regex.sub(r"\s+", " ", text) - # Remove leading and trailing whitespace - text = text.strip() - return text - - -def is_leaf(element): - return element.name in {"p", "li"} - - -def is_text(element): - return isinstance(element, NavigableString) - - -def is_non_printing(element): - return any(isinstance(element, _e) for _e in [Comment, CData]) - - -def add_content_digest(element): - if not is_text(element): - element["data-content-digest"] = content_digest(element) - return element - - -def content_digest(element): - digest: Any - if is_text(element): - # Hash - trimmed_string = element.string.strip() - if trimmed_string == "": - digest = "" - else: - digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest() - else: - contents = element.contents - num_contents = len(contents) - if num_contents == 0: - # No hash when no child elements exist - digest = "" - elif num_contents == 1: - # If single child, use digest of child - digest = content_digest(contents[0]) - else: - # Build content digest from the "non-empty" digests of child nodes - digest = hashlib.sha256() - child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents])) - for child in child_digests: - digest.update(child.encode("utf-8")) - digest = 
digest.hexdigest() - return digest - - -def get_image_upload_file_ids(content): - pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)" - matches = re.findall(pattern, content) - image_upload_file_ids = [] - for match in matches: - if match[1] == "file-preview": - content_pattern = r"files/([^/]+)/file-preview" - else: - content_pattern = r"files/([^/]+)/image-preview" - content_match = re.search(content_pattern, match[0]) - if content_match: - image_upload_file_id = content_match.group(1) - image_upload_file_ids.append(image_upload_file_id) - return image_upload_file_ids diff --git a/api/installed_plugins.jsonl b/api/installed_plugins.jsonl new file mode 100644 index 0000000000..463e24ae64 --- /dev/null +++ b/api/installed_plugins.jsonl @@ -0,0 +1 @@ +{"not_installed": [], "plugin_install_failed": []} \ No newline at end of file