diff --git a/api/commands.py b/api/commands.py
index 105f936562..35b5c5d5f8 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -8,6 +8,8 @@
 import time
 import uuid
 import click
+import qdrant_client
+from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
 from tqdm import tqdm
 from flask import current_app, Flask
 from langchain.embeddings import OpenAIEmbeddings
@@ -484,6 +486,38 @@ def normalization_collections():
     click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))


+@click.command('add-qdrant-full-text-index', help='add qdrant full text index')
+def add_qdrant_full_text_index():
+    click.echo(click.style('Start adding full text index.', fg='green'))
+    binds = db.session.query(DatasetCollectionBinding).all()
+    if binds and current_app.config['VECTOR_STORE'] == 'qdrant':
+        qdrant_url = current_app.config['QDRANT_URL']
+        qdrant_api_key = current_app.config['QDRANT_API_KEY']
+        client = qdrant_client.QdrantClient(
+            qdrant_url,
+            api_key=qdrant_api_key,  # For Qdrant Cloud, None for local instance
+        )
+        for bind in binds:
+            try:
+                text_index_params = TextIndexParams(
+                    type=TextIndexType.TEXT,
+                    tokenizer=TokenizerType.MULTILINGUAL,
+                    min_token_len=2,
+                    max_token_len=20,
+                    lowercase=True
+                )
+                client.create_payload_index(bind.collection_name, 'page_content',
+                                            field_schema=text_index_params)
+            except Exception as e:
+                click.echo(
+                    click.style('Create full text index error: {} {}'.format(e.__class__.__name__, str(e)),
+                                fg='red'))
+            click.echo(
+                click.style(
+                    'Congratulations! Added full text index for collection {}.'.format(bind.collection_name),
+                    fg='green'))
+
+
 def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
     with flask_app.app_context():
         try:
@@ -647,10 +681,10 @@ def update_app_model_configs(batch_size):
             pbar.update(len(data_batch))

+
 @click.command('migrate_default_input_to_dataset_query_variable')
 @click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
 def migrate_default_input_to_dataset_query_variable(batch_size):
-
     click.secho("Starting...", fg='green')

     total_records = db.session.query(AppModelConfig) \
@@ -658,13 +692,13 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
         .filter(App.mode == 'completion') \
         .filter(AppModelConfig.dataset_query_variable == None) \
         .count()
-    
+
     if total_records == 0:
         click.secho("No data to migrate.", fg='green')
         return

     num_batches = (total_records + batch_size - 1) // batch_size
-    
+
     with tqdm(total=total_records, desc="Migrating Data") as pbar:
         for i in range(num_batches):
             offset = i * batch_size
@@ -697,14 +731,14 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
                         for form in user_input_form:
                             paragraph = form.get('paragraph')
                             if paragraph \
-                                and paragraph.get('variable') == 'query':
-                                data.dataset_query_variable = 'query'
-                                break
-
+                                    and paragraph.get('variable') == 'query':
+                                data.dataset_query_variable = 'query'
+                                break
+
                             if paragraph \
-                                and paragraph.get('variable') == 'default_input':
-                                data.dataset_query_variable = 'default_input'
-                                break
+                                    and paragraph.get('variable') == 'default_input':
+                                data.dataset_query_variable = 'default_input'
+                                break

                         db.session.commit()
@@ -712,7 +746,7 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
                 click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
                             fg='red')
                 continue
-    
+
             click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')

             pbar.update(len(data_batch))
@@ -731,3 +765,4 @@ def register_commands(app):
     app.cli.add_command(update_app_model_configs)
     app.cli.add_command(normalization_collections)
     app.cli.add_command(migrate_default_input_to_dataset_query_variable)
+    app.cli.add_command(add_qdrant_full_text_index)
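After running the new CLI entry point (`flask add-qdrant-full-text-index`), each bound collection should carry a payload index of type "text" on its `page_content` field. A minimal verification sketch, assuming a local Qdrant instance and an illustrative collection name (neither is part of this patch):

```python
import qdrant_client

# Hypothetical URL/collection name; api_key=None is fine for a local instance.
client = qdrant_client.QdrantClient("http://localhost:6333", api_key=None)
info = client.get_collection("vector_index_example_collection")
# payload_schema maps field name -> index info; expect an entry for "page_content".
print(info.payload_schema.get("page_content"))
```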
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 7aa2a7bfc4..9a417b3660 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -170,6 +170,7 @@ class DatasetApi(Resource):
                             help='Invalid indexing technique.')
         parser.add_argument('permission', type=str, location='json', choices=(
             'only_me', 'all_team_members'), help='Invalid permission.')
+        parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
         args = parser.parse_args()

         # The role of the current user in the ta table must be admin or owner
@@ -401,6 +402,7 @@ class DatasetApiKeyApi(Resource):

 class DatasetApiDeleteApi(Resource):
     resource_type = 'dataset'
+
     @setup_required
     @login_required
     @account_initialization_required
@@ -436,6 +438,50 @@ class DatasetApiBaseUrlApi(Resource):
         }


+class DatasetRetrievalSettingApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        vector_type = current_app.config['VECTOR_STORE']
+        if vector_type == 'milvus':
+            return {
+                'retrieval_method': [
+                    'semantic_search'
+                ]
+            }
+        elif vector_type == 'qdrant' or vector_type == 'weaviate':
+            return {
+                'retrieval_method': [
+                    'semantic_search', 'full_text_search', 'hybrid_search'
+                ]
+            }
+        else:
+            raise ValueError("Unsupported vector db type.")
+
+
+class DatasetRetrievalSettingMockApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, vector_type):
+
+        if vector_type == 'milvus':
+            return {
+                'retrieval_method': [
+                    'semantic_search'
+                ]
+            }
+        elif vector_type == 'qdrant' or vector_type == 'weaviate':
+            return {
+                'retrieval_method': [
+                    'semantic_search', 'full_text_search', 'hybrid_search'
+                ]
+            }
+        else:
+            raise ValueError("Unsupported vector db type.")
+
+
 api.add_resource(DatasetListApi, '/datasets')
 api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
 api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
@@ -445,3 +491,5 @@ api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing
 api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
 api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
 api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
+api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
+api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
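A quick sanity check of the new endpoint; the host, console route prefix, and token below are illustrative assumptions, only the `/datasets/retrieval-setting` path comes from this patch:

```python
import requests

# Hypothetical local console session.
resp = requests.get(
    "http://localhost:5001/console/api/datasets/retrieval-setting",
    headers={"Authorization": "Bearer <console-access-token>"},
)
# With VECTOR_STORE=qdrant or weaviate this should return:
# {"retrieval_method": ["semantic_search", "full_text_search", "hybrid_search"]}
print(resp.json())
```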
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index e85433c83f..0f5634de4d 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -221,6 +221,8 @@ class DatasetDocumentListApi(Resource):
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                             location='json')
+        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
+                            location='json')
         args = parser.parse_args()

         if not dataset.indexing_technique and not args['indexing_technique']:
@@ -263,6 +265,8 @@ class DatasetInitApi(Resource):
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                             location='json')
+        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
+                            location='json')
         args = parser.parse_args()
         if args['indexing_technique'] == 'high_quality':
             try:
diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py
index 6d3397e16f..82b036ad7a 100644
--- a/api/controllers/console/datasets/hit_testing.py
+++ b/api/controllers/console/datasets/hit_testing.py
@@ -42,19 +42,18 @@ class HitTestingApi(Resource):

         parser = reqparse.RequestParser()
         parser.add_argument('query', type=str, location='json')
+        parser.add_argument('retrieval_model', type=dict, required=False, location='json')
         args = parser.parse_args()

-        query = args['query']
-
-        if not query or len(query) > 250:
-            raise ValueError('Query is required and cannot exceed 250 characters')
+        HitTestingService.hit_testing_args_check(args)

         try:
             response = HitTestingService.retrieve(
                 dataset=dataset,
-                query=query,
+                query=args['query'],
                 account=current_user,
-                limit=10,
+                retrieval_model=args['retrieval_model'],
+                limit=10
             )

             return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
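The `retrieval_model` dict threaded through these endpoints follows the shape used elsewhere in this patch (see `default_retrieval_model` in `orchestrator_rule_parser.py`); a representative payload:

```python
# Representative retrieval_model payload; field names mirror
# default_retrieval_model defined later in this patch.
retrieval_model = {
    'search_method': 'hybrid_search',  # semantic_search | full_text_search | hybrid_search
    'reranking_enable': True,
    'reranking_model': {
        'reranking_provider_name': 'cohere',
        'reranking_model_name': 'rerank-multilingual-v2.0'
    },
    'top_k': 2,
    'score_threshold_enable': False
}
```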
diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py
index 7099b8f23d..dd10ae1ce2 100644
--- a/api/controllers/console/workspace/models.py
+++ b/api/controllers/console/workspace/models.py
@@ -19,7 +19,7 @@ class DefaultModelApi(Resource):
     def get(self):
         parser = reqparse.RequestParser()
         parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text'], location='args')
+                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
         args = parser.parse_args()

         tenant_id = current_user.current_tenant_id
@@ -71,19 +71,18 @@ class DefaultModelApi(Resource):
     @account_initialization_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text'], location='json')
-        parser.add_argument('provider_name', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
         args = parser.parse_args()

         provider_service = ProviderService()
-        provider_service.update_default_model_of_model_type(
-            tenant_id=current_user.current_tenant_id,
-            model_type=args['model_type'],
-            provider_name=args['provider_name'],
-            model_name=args['model_name']
-        )
+        model_settings = args['model_settings']
+        for model_setting in model_settings:
+            provider_service.update_default_model_of_model_type(
+                tenant_id=current_user.current_tenant_id,
+                model_type=model_setting['model_type'],
+                provider_name=model_setting['provider_name'],
+                model_name=model_setting['model_name']
+            )

         return {'result': 'success'}
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index 28545a36ab..e900e84a01 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -36,6 +36,8 @@ class DocumentAddByTextApi(DatasetApiResource):
                             location='json')
         parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
                             location='json')
+        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
+                            location='json')
         args = parser.parse_args()
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
@@ -95,6 +97,8 @@ class DocumentUpdateByTextApi(DatasetApiResource):
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                             location='json')
+        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
+                            location='json')
         args = parser.parse_args()
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
diff --git a/api/core/agent/agent/multi_dataset_router_agent.py b/api/core/agent/agent/multi_dataset_router_agent.py
index 16b4a2ab24..4fb7121127 100644
--- a/api/core/agent/agent/multi_dataset_router_agent.py
+++ b/api/core/agent/agent/multi_dataset_router_agent.py
@@ -14,7 +14,6 @@ from pydantic import root_validator
 from core.model_providers.models.entity.message import to_prompt_messages
 from core.model_providers.models.llm.base import BaseLLM
 from core.third_party.langchain.llms.fake import FakeLLM
-from core.tool.dataset_retriever_tool import DatasetRetrieverTool


 class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
@@ -60,7 +59,6 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
             return AgentFinish(return_values={"output": ''}, log='')
         elif len(self.tools) == 1:
             tool = next(iter(self.tools))
-            tool = cast(DatasetRetrieverTool, tool)
             rst = tool.run(tool_input={'query': kwargs['input']})
             # output = ''
             # rst_json = json.loads(rst)
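The default-model endpoint now accepts a batch of settings instead of a single model; a sketch of the new request body (values illustrative, shape taken from the reworked `post` handler):

```python
# Body for the reworked POST handler: one entry per model type to set as default.
payload = {
    "model_settings": [
        {"model_type": "text-generation", "provider_name": "openai", "model_name": "gpt-3.5-turbo"},
        {"model_type": "reranking", "provider_name": "cohere", "model_name": "rerank-multilingual-v2.0"},
    ]
}
```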
+ """ + model_instance: BaseLLM + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + @root_validator + def validate_llm(cls, values: dict) -> dict: + return values + + def should_use_agent(self, query: str): + """ + return should use agent + + :param query: + :return: + """ + return True + + def plan( + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, + ) -> Union[AgentAction, AgentFinish]: + """Given input, decided what to do. + + Args: + intermediate_steps: Steps the LLM has taken to date, along with observations + **kwargs: User inputs. + + Returns: + Action specifying what tool to use. + """ + if len(self.tools) == 0: + return AgentFinish(return_values={"output": ''}, log='') + elif len(self.tools) == 1: + tool = next(iter(self.tools)) + tool = cast(DatasetRetrieverTool, tool) + rst = tool.run(tool_input={'query': kwargs['input']}) + # output = '' + # rst_json = json.loads(rst) + # for item in rst_json: + # output += f'{item["content"]}\n' + return AgentFinish(return_values={"output": rst}, log=rst) + + if intermediate_steps: + _, observation = intermediate_steps[-1] + return AgentFinish(return_values={"output": observation}, log=observation) + + try: + agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs) + if isinstance(agent_decision, AgentAction): + tool_inputs = agent_decision.tool_input + if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs: + tool_inputs['query'] = kwargs['input'] + agent_decision.tool_input = tool_inputs + else: + agent_decision.return_values['output'] = '' + return agent_decision + except Exception as e: + new_exception = self.model_instance.handle_exceptions(e) + raise new_exception + + def real_plan( + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, + ) -> Union[AgentAction, AgentFinish]: + """Given input, decided what to do. + + Args: + intermediate_steps: Steps the LLM has taken to date, along with observations + **kwargs: User inputs. + + Returns: + Action specifying what tool to use. + """ + agent_scratchpad = _format_intermediate_steps(intermediate_steps) + selected_inputs = { + k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" + } + full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad) + prompt = self.prompt.format_prompt(**full_inputs) + messages = prompt.to_messages() + prompt_messages = to_prompt_messages(messages) + result = self.model_instance.run( + messages=prompt_messages, + functions=self.functions, + ) + + ai_message = AIMessage( + content=result.content, + additional_kwargs={ + 'function_call': result.function_call + } + ) + + agent_decision = _parse_ai_message(ai_message) + return agent_decision + + async def aplan( + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, + ) -> Union[AgentAction, AgentFinish]: + raise NotImplementedError() + + @classmethod + def from_llm_and_tools( + cls, + model_instance: BaseLLM, + tools: Sequence[BaseTool], + callback_manager: Optional[BaseCallbackManager] = None, + extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None, + system_message: Optional[SystemMessage] = SystemMessage( + content="You are a helpful AI assistant." 
diff --git a/api/core/agent/agent/structed_multi_dataset_router_agent.py b/api/core/agent/agent/structed_multi_dataset_router_agent.py
index 84c0553625..115ed69d17 100644
--- a/api/core/agent/agent/structed_multi_dataset_router_agent.py
+++ b/api/core/agent/agent/structed_multi_dataset_router_agent.py
@@ -89,7 +89,6 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
             return AgentFinish(return_values={"output": ''}, log='')
         elif len(self.dataset_tools) == 1:
             tool = next(iter(self.dataset_tools))
-            tool = cast(DatasetRetrieverTool, tool)
             rst = tool.run(tool_input={'query': kwargs['input']})
             return AgentFinish(return_values={"output": rst}, log=rst)

diff --git a/api/core/agent/agent_executor.py b/api/core/agent/agent_executor.py
index 05c4b632ff..579f3d5d90 100644
--- a/api/core/agent/agent_executor.py
+++ b/api/core/agent/agent_executor.py
@@ -18,6 +18,7 @@ from langchain.agents import AgentExecutor as LCAgentExecutor
 from core.helper import moderation
 from core.model_providers.error import LLMError
 from core.model_providers.models.llm.base import BaseLLM
+from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool


@@ -78,7 +79,7 @@ class AgentExecutor:
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.ROUTER:
-            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
+            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
             agent = MultiDatasetRouterAgent.from_llm_and_tools(
                 model_instance=self.configuration.model_instance,
                 tools=self.configuration.tools,
@@ -86,7 +87,7 @@ class AgentExecutor:
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
-            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
+            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
             agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
                 model_instance=self.configuration.model_instance,
                 tools=self.configuration.tools,
diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py
index ec02bdae9e..ec91d67290 100644
--- a/api/core/callback_handler/index_tool_callback_handler.py
+++ b/api/core/callback_handler/index_tool_callback_handler.py
@@ -10,8 +10,7 @@ from models.dataset import DocumentSegment
 class DatasetIndexToolCallbackHandler:
     """Callback handler for dataset tool."""

-    def __init__(self, dataset_id: str, conversation_message_task: ConversationMessageTask) -> None:
-        self.dataset_id = dataset_id
+    def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
         self.conversation_message_task = conversation_message_task

     def on_tool_end(self, documents: List[Document]) -> None:
@@ -21,7 +20,6 @@ class DatasetIndexToolCallbackHandler:

             # add hit count to document segment
             db.session.query(DocumentSegment).filter(
-                DocumentSegment.dataset_id == self.dataset_id,
                 DocumentSegment.index_node_id == doc_id
             ).update(
                 {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
diff --git a/api/core/completion.py b/api/core/completion.py
index 64db2ea4ce..30ad23e629 100644
--- a/api/core/completion.py
+++ b/api/core/completion.py
@@ -127,6 +127,7 @@ class Completion:
             memory=memory,
             rest_tokens=rest_tokens_for_context_and_memory,
             chain_callback=chain_callback,
+            tenant_id=app.tenant_id,
             retriever_from=retriever_from
         )

diff --git a/api/core/data_loader/file_extractor.py b/api/core/data_loader/file_extractor.py
index a603bee749..40f0c1f201 100644
--- a/api/core/data_loader/file_extractor.py
+++ b/api/core/data_loader/file_extractor.py
@@ -3,7 +3,7 @@ from pathlib import Path
 from typing import List, Union, Optional

 import requests
-from langchain.document_loaders import TextLoader, Docx2txtLoader
+from langchain.document_loaders import TextLoader, Docx2txtLoader, UnstructuredFileLoader, UnstructuredAPIFileLoader
 from langchain.schema import Document

 from core.data_loader.loader.csv_loader import CSVLoader
@@ -20,13 +20,13 @@ USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTM
 class FileExtractor:

     @classmethod
-    def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document] | str]:
+    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document] | str]:
         with tempfile.TemporaryDirectory() as temp_dir:
             suffix = Path(upload_file.key).suffix
             file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
             storage.download(upload_file.key, file_path)

-            return cls.load_from_file(file_path, return_text, upload_file)
+            return cls.load_from_file(file_path, return_text, upload_file, is_automatic)

     @classmethod
     def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document] | str]:
@@ -44,24 +44,34 @@ class FileExtractor:

     @classmethod
     def load_from_file(cls, file_path: str, return_text: bool = False,
-                       upload_file: Optional[UploadFile] = None) -> Union[List[Document] | str]:
+                       upload_file: Optional[UploadFile] = None,
+                       is_automatic: bool = False) -> Union[List[Document] | str]:
         input_file = Path(file_path)
         delimiter = '\n'
         file_extension = input_file.suffix.lower()
-        if file_extension == '.xlsx':
-            loader = ExcelLoader(file_path)
-        elif file_extension == '.pdf':
-            loader = PdfLoader(file_path, upload_file=upload_file)
-        elif file_extension in ['.md', '.markdown']:
-            loader = MarkdownLoader(file_path, autodetect_encoding=True)
-        elif file_extension in ['.htm', '.html']:
-            loader = HTMLLoader(file_path)
-        elif file_extension == '.docx':
-            loader = Docx2txtLoader(file_path)
-        elif file_extension == '.csv':
-            loader = CSVLoader(file_path, autodetect_encoding=True)
+        if is_automatic:
+            loader = UnstructuredFileLoader(
+                file_path, strategy="hi_res", mode="elements"
+            )
+            # loader = UnstructuredAPIFileLoader(
+            #     file_path=filenames[0],
+            #     api_key="FAKE_API_KEY",
+            # )
         else:
-            # txt
-            loader = TextLoader(file_path, autodetect_encoding=True)
+            if file_extension == '.xlsx':
+                loader = ExcelLoader(file_path)
+            elif file_extension == '.pdf':
+                loader = PdfLoader(file_path, upload_file=upload_file)
+            elif file_extension in ['.md', '.markdown']:
+                loader = MarkdownLoader(file_path, autodetect_encoding=True)
+            elif file_extension in ['.htm', '.html']:
+                loader = HTMLLoader(file_path)
+            elif file_extension == '.docx':
+                loader = Docx2txtLoader(file_path)
+            elif file_extension == '.csv':
+                loader = CSVLoader(file_path, autodetect_encoding=True)
+            else:
+                # txt
+                loader = TextLoader(file_path, autodetect_encoding=True)

         return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
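A minimal sketch of the new `is_automatic` path, assuming a file already on local disk (the "hi_res" strategy requires the unstructured package and its system dependencies to be installed):

```python
from core.data_loader.file_extractor import FileExtractor

# With is_automatic=True the extension-based dispatch is bypassed and
# UnstructuredFileLoader(strategy="hi_res", mode="elements") is used instead.
documents = FileExtractor.load_from_file("/tmp/example.pdf", is_automatic=True)
```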
diff --git a/api/core/index/vector_index/base.py b/api/core/index/vector_index/base.py
index 60f092d409..bc7811a0e2 100644
--- a/api/core/index/vector_index/base.py
+++ b/api/core/index/vector_index/base.py
@@ -40,6 +40,13 @@ class BaseVectorIndex(BaseIndex):
     def _get_vector_store_class(self) -> type:
         raise NotImplementedError

+    @abstractmethod
+    def search_by_full_text_index(
+            self, query: str,
+            **kwargs: Any
+    ) -> List[Document]:
+        raise NotImplementedError
+
     def search(
             self, query: str,
             **kwargs: Any
diff --git a/api/core/index/vector_index/milvus_vector_index.py b/api/core/index/vector_index/milvus_vector_index.py
index d0a9c19ea0..a8bba763d4 100644
--- a/api/core/index/vector_index/milvus_vector_index.py
+++ b/api/core/index/vector_index/milvus_vector_index.py
@@ -1,16 +1,14 @@
-from typing import Optional, cast
+from typing import cast, Any, List

 from langchain.embeddings.base import Embeddings
-from langchain.schema import Document, BaseRetriever
-from langchain.vectorstores import VectorStore, milvus
+from langchain.schema import Document
+from langchain.vectorstores import VectorStore
 from pydantic import BaseModel, root_validator

 from core.index.base import BaseIndex
 from core.index.vector_index.base import BaseVectorIndex
 from core.vector_store.milvus_vector_store import MilvusVectorStore
-from core.vector_store.weaviate_vector_store import WeaviateVectorStore
-from extensions.ext_database import db
-from models.dataset import Dataset, DatasetCollectionBinding
+from models.dataset import Dataset


 class MilvusConfig(BaseModel):
@@ -74,7 +72,7 @@ class MilvusVectorIndex(BaseVectorIndex):
         index_params = {
             'metric_type': 'IP',
             'index_type': "HNSW",
-            'params': {"M": 8, "efConstruction": 64}
+            'params': {"M": 8, "efConstruction": 64}
         }
         self._vector_store = MilvusVectorStore.from_documents(
             texts,
@@ -152,3 +150,7 @@ class MilvusVectorIndex(BaseVectorIndex):
                 ),
             ],
         ))
+
+    def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
+        # milvus/zilliz doesn't support bm25 search
+        return []
diff --git a/api/core/index/vector_index/qdrant_vector_index.py b/api/core/index/vector_index/qdrant_vector_index.py
index 732a10b0ae..dbadab118e 100644
--- a/api/core/index/vector_index/qdrant_vector_index.py
+++ b/api/core/index/vector_index/qdrant_vector_index.py
@@ -191,3 +191,21 @@ class QdrantVectorIndex(BaseVectorIndex):
                 return True

         return False
+
+    def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        from qdrant_client.http import models
+        return vector_store.similarity_search_by_bm25(models.Filter(
+            must=[
+                models.FieldCondition(
+                    key="group_id",
+                    match=models.MatchValue(value=self.dataset.id),
+                ),
+                models.FieldCondition(
+                    key="page_content",
+                    match=models.MatchText(text=query),
+                )
+            ],
+        ), kwargs.get('top_k', 2))
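How the new hook is meant to be called; a sketch only, with index construction elided (`similarity_search_by_bm25` is the helper added to the vendored Qdrant wrapper at the end of this patch, not a stock langchain method):

```python
# `index` is any BaseVectorIndex subclass instance (construction elided).
docs = index.search_by_full_text_index("how do refunds work", top_k=4)
for doc in docs:
    print(doc.metadata.get('doc_id'), doc.page_content[:80])
```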
diff --git a/api/core/index/vector_index/weaviate_vector_index.py b/api/core/index/vector_index/weaviate_vector_index.py
index 1432a70707..3e8d9ae1bf 100644
--- a/api/core/index/vector_index/weaviate_vector_index.py
+++ b/api/core/index/vector_index/weaviate_vector_index.py
@@ -1,4 +1,4 @@
-from typing import Optional, cast
+from typing import Optional, cast, Any, List

 import requests
 import weaviate
@@ -26,6 +26,7 @@ class WeaviateConfig(BaseModel):

 class WeaviateVectorIndex(BaseVectorIndex):
+
     def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
         super().__init__(dataset, embeddings)
         self._client = self._init_client(config)
@@ -148,3 +149,9 @@ class WeaviateVectorIndex(BaseVectorIndex):
                 return True

         return False
+
+    def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+        return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)
+
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 9978397428..8132f2a05d 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -49,14 +49,14 @@ class IndexingRunner:
             if not dataset:
                 raise ValueError("no dataset found")

-            # load file
-            text_docs = self._load_data(dataset_document)
-
             # get the process rule
             processing_rule = db.session.query(DatasetProcessRule). \
                 filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                 first()

+            # load file
+            text_docs = self._load_data(dataset_document)
+
             # get splitter
             splitter = self._get_splitter(processing_rule)

@@ -380,7 +380,7 @@ class IndexingRunner:
             "preview": preview_texts
         }

-    def _load_data(self, dataset_document: DatasetDocument) -> List[Document]:
+    def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> List[Document]:
         # load file
         if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
             return []
@@ -396,7 +396,7 @@ class IndexingRunner:
                 one_or_none()

             if file_detail:
-                text_docs = FileExtractor.load(file_detail)
+                text_docs = FileExtractor.load(file_detail, is_automatic=False)
         elif dataset_document.data_source_type == 'notion_import':
             loader = NotionLoader.from_document(dataset_document)
             text_docs = loader.load()
diff --git a/api/core/model_providers/model_factory.py b/api/core/model_providers/model_factory.py
index f7577b392f..3a4c422a83 100644
--- a/api/core/model_providers/model_factory.py
+++ b/api/core/model_providers/model_factory.py
@@ -9,6 +9,7 @@ from core.model_providers.models.embedding.base import BaseEmbedding
 from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.moderation.base import BaseModeration
+from core.model_providers.models.reranking.base import BaseReranking
 from core.model_providers.models.speech2text.base import BaseSpeech2Text
 from extensions.ext_database import db
 from models.provider import TenantDefaultModel
@@ -140,6 +141,44 @@ class ModelFactory:
             name=model_name
         )

+    @classmethod
+    def get_reranking_model(cls,
+                            tenant_id: str,
+                            model_provider_name: Optional[str] = None,
+                            model_name: Optional[str] = None) -> Optional[BaseReranking]:
+        """
+        get reranking model.
+
+        :param tenant_id: a string representing the ID of the tenant.
+        :param model_provider_name:
+        :param model_name:
+        :return:
+        """
+        if model_provider_name is None and model_name is None:
+            default_model = cls.get_default_model(tenant_id, ModelType.RERANKING)
+
+            if not default_model:
+                raise LLMBadRequestError(f"Default model is not available. "
+                                         f"Please configure a Default Reranking Model "
+                                         f"in the Settings -> Model Provider.")
+
+            model_provider_name = default_model.provider_name
+            model_name = default_model.model_name
+
+        # get model provider
+        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
+
+        if not model_provider:
+            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials are not initialized.")
+
+        # init reranking model
+        model_class = model_provider.get_model_class(model_type=ModelType.RERANKING)
+        return model_class(
+            model_provider=model_provider,
+            name=model_name
+        )
+
     @classmethod
     def get_speech2text_model(cls,
                               tenant_id: str,
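Usage mirrors the other factory getters; a sketch (tenant id and candidate documents are illustrative placeholders):

```python
from core.model_providers.model_factory import ModelFactory

# Falls back to the tenant's default reranking model when provider/name are omitted.
rerank = ModelFactory.get_reranking_model(
    tenant_id="tenant-uuid",  # hypothetical
    model_provider_name="cohere",
    model_name="rerank-multilingual-v2.0",
)
top_docs = rerank.rerank(query="refund policy", documents=candidate_docs,
                         score_threshold=None, top_k=3)
```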
" + f"Please configure a Default Reranking Model " + f"in the Settings -> Model Provider.") + + model_provider_name = default_model.provider_name + model_name = default_model.model_name + + # get model provider + model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) + + if not model_provider: + raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.") + + # init reranking model + model_class = model_provider.get_model_class(model_type=ModelType.RERANKING) + return model_class( + model_provider=model_provider, + name=model_name + ) + @classmethod def get_speech2text_model(cls, tenant_id: str, diff --git a/api/core/model_providers/model_provider_factory.py b/api/core/model_providers/model_provider_factory.py index 985dc274b3..ea17f212c1 100644 --- a/api/core/model_providers/model_provider_factory.py +++ b/api/core/model_providers/model_provider_factory.py @@ -72,6 +72,9 @@ class ModelProviderFactory: elif provider_name == 'localai': from core.model_providers.providers.localai_provider import LocalAIProvider return LocalAIProvider + elif provider_name == 'cohere': + from core.model_providers.providers.cohere_provider import CohereProvider + return CohereProvider else: raise NotImplementedError diff --git a/api/core/model_providers/models/entity/model_params.py b/api/core/model_providers/models/entity/model_params.py index 225a5cc674..0effa75e6e 100644 --- a/api/core/model_providers/models/entity/model_params.py +++ b/api/core/model_providers/models/entity/model_params.py @@ -17,7 +17,7 @@ class ModelType(enum.Enum): IMAGE = 'image' VIDEO = 'video' MODERATION = 'moderation' - + RERANKING = 'reranking' @staticmethod def value_of(value): for member in ModelType: diff --git a/api/core/model_providers/models/reranking/__init__.py b/api/core/model_providers/models/reranking/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_providers/models/reranking/base.py b/api/core/model_providers/models/reranking/base.py new file mode 100644 index 0000000000..85863895f4 --- /dev/null +++ b/api/core/model_providers/models/reranking/base.py @@ -0,0 +1,36 @@ +from abc import abstractmethod +from typing import Any, Optional, List +from langchain.schema import Document + +from core.model_providers.models.base import BaseProviderModel +from core.model_providers.models.entity.model_params import ModelType +from core.model_providers.providers.base import BaseModelProvider +import logging + +logger = logging.getLogger(__name__) + + +class BaseReranking(BaseProviderModel): + name: str + type: ModelType = ModelType.RERANKING + + def __init__(self, model_provider: BaseModelProvider, client: Any, name: str): + super().__init__(model_provider, client) + self.name = name + + @property + def base_model_name(self) -> str: + """ + get base model name + + :return: str + """ + return self.name + + @abstractmethod + def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]: + raise NotImplementedError + + @abstractmethod + def handle_exceptions(self, ex: Exception) -> Exception: + raise NotImplementedError diff --git a/api/core/model_providers/models/reranking/cohere_reranking.py b/api/core/model_providers/models/reranking/cohere_reranking.py new file mode 100644 index 0000000000..3119caeae1 --- /dev/null +++ b/api/core/model_providers/models/reranking/cohere_reranking.py @@ -0,0 +1,73 @@ +import logging +from typing import 
diff --git a/api/core/model_providers/models/reranking/cohere_reranking.py b/api/core/model_providers/models/reranking/cohere_reranking.py
new file mode 100644
index 0000000000..3119caeae1
--- /dev/null
+++ b/api/core/model_providers/models/reranking/cohere_reranking.py
@@ -0,0 +1,73 @@
+import logging
+from typing import Optional, List
+
+import cohere
+import openai
+from langchain.schema import Document
+
+from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
+    LLMRateLimitError, LLMAuthorizationError
+from core.model_providers.models.reranking.base import BaseReranking
+from core.model_providers.providers.base import BaseModelProvider
+
+
+class CohereReranking(BaseReranking):
+
+    def __init__(self, model_provider: BaseModelProvider, name: str):
+        self.credentials = model_provider.get_model_credentials(
+            model_name=name,
+            model_type=self.type
+        )
+
+        client = cohere.Client(self.credentials.get('api_key'))
+
+        super().__init__(model_provider, client, name)
+
+    def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
+        docs = []
+        doc_id = []
+        for document in documents:
+            # deduplicate candidates by doc_id before sending them to the API
+            if document.metadata['doc_id'] not in doc_id:
+                doc_id.append(document.metadata['doc_id'])
+                docs.append(document.page_content)
+        results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k)
+        rerank_documents = []
+
+        for idx, result in enumerate(results):
+            # format document
+            rerank_document = Document(
+                page_content=result.document['text'],
+                metadata={
+                    "doc_id": documents[result.index].metadata['doc_id'],
+                    "doc_hash": documents[result.index].metadata['doc_hash'],
+                    "document_id": documents[result.index].metadata['document_id'],
+                    "dataset_id": documents[result.index].metadata['dataset_id'],
+                    'score': result.relevance_score
+                }
+            )
+            # score threshold check
+            if score_threshold is not None:
+                if result.relevance_score >= score_threshold:
+                    rerank_documents.append(rerank_document)
+            else:
+                rerank_documents.append(rerank_document)
+        return rerank_documents
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, openai.error.InvalidRequestError):
+            logging.warning("Invalid request to OpenAI API.")
+            return LLMBadRequestError(str(ex))
+        elif isinstance(ex, openai.error.APIConnectionError):
+            logging.warning("Failed to connect to OpenAI API.")
+            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
+            logging.warning("OpenAI service unavailable.")
+            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, openai.error.RateLimitError):
+            return LLMRateLimitError(str(ex))
+        elif isinstance(ex, openai.error.AuthenticationError):
+            return LLMAuthorizationError(str(ex))
+        elif isinstance(ex, openai.error.OpenAIError):
+            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
+        else:
+            return ex
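For context, the underlying Cohere call looks like this; a standalone sketch against the cohere SDK, with the API key and texts as placeholders:

```python
import cohere

co = cohere.Client("<api-key>")  # placeholder
results = co.rerank(
    query="What is the refund window?",
    documents=["Refunds are accepted within 30 days.", "Shipping takes 5 days."],
    model="rerank-english-v2.0",
    top_n=1,
)
# Each result carries the original index, a relevance score, and the document text.
for r in results:
    print(r.index, r.relevance_score, r.document["text"])
```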
diff --git a/api/core/model_providers/providers/cohere_provider.py b/api/core/model_providers/providers/cohere_provider.py
new file mode 100644
index 0000000000..9fa77dfff2
--- /dev/null
+++ b/api/core/model_providers/providers/cohere_provider.py
@@ -0,0 +1,152 @@
+import json
+from json import JSONDecodeError
+from typing import Type
+
+from langchain.schema import HumanMessage
+
+from core.helper import encrypter
+from core.model_providers.models.base import BaseProviderModel
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
+from core.model_providers.models.reranking.cohere_reranking import CohereReranking
+from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
+from models.provider import ProviderType
+
+
+class CohereProvider(BaseModelProvider):
+
+    @property
+    def provider_name(self):
+        """
+        Returns the name of a provider.
+        """
+        return 'cohere'
+
+    def _get_text_generation_model_mode(self, model_name) -> str:
+        return ModelMode.CHAT.value
+
+    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
+        if model_type == ModelType.RERANKING:
+            return [
+                {
+                    'id': 'rerank-english-v2.0',
+                    'name': 'rerank-english-v2.0'
+                },
+                {
+                    'id': 'rerank-multilingual-v2.0',
+                    'name': 'rerank-multilingual-v2.0'
+                }
+            ]
+        else:
+            return []
+
+    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
+        """
+        Returns the model class.
+
+        :param model_type:
+        :return:
+        """
+        if model_type == ModelType.RERANKING:
+            model_class = CohereReranking
+        else:
+            raise NotImplementedError
+
+        return model_class
+
+    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
+        """
+        get model parameter rules.
+
+        :param model_name:
+        :param model_type:
+        :return:
+        """
+        return ModelKwargsRules(
+            temperature=KwargRule[float](min=0, max=1, default=0.3, precision=2),
+            top_p=KwargRule[float](min=0, max=0.99, default=0.85, precision=2),
+            presence_penalty=KwargRule[float](enabled=False),
+            frequency_penalty=KwargRule[float](enabled=False),
+            max_tokens=KwargRule[int](enabled=False),
+        )
+
+    @classmethod
+    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
+        """
+        Validates the given credentials.
+        """
+        if 'api_key' not in credentials:
+            raise CredentialsValidateFailedError('Cohere api_key must be provided.')
+
+        try:
+            credential_kwargs = {
+                'api_key': credentials['api_key'],
+            }
+            # todo validate
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @classmethod
+    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
+        credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
+        return credentials
+
+    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
+        if self.provider.provider_type == ProviderType.CUSTOM.value:
+            try:
+                credentials = json.loads(self.provider.encrypted_config)
+            except JSONDecodeError:
+                credentials = {
+                    'api_key': None,
+                }
+
+            if credentials['api_key']:
+                credentials['api_key'] = encrypter.decrypt_token(
+                    self.provider.tenant_id,
+                    credentials['api_key']
+                )
+
+                if obfuscated:
+                    credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
+
+            return credentials
+        else:
+            return {}
+
+    def should_deduct_quota(self):
+        return True
+
+    @classmethod
+    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
+        """
+        check model credentials valid.
+
+        :param model_name:
+        :param model_type:
+        :param credentials:
+        """
+        return
+
+    @classmethod
+    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
+                                  credentials: dict) -> dict:
+        """
+        encrypt model credentials for save.
+
+        :param tenant_id:
+        :param model_name:
+        :param model_type:
+        :param credentials:
+        :return:
+        """
+        return {}
+    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
+        """
+        get credentials for llm use.
+
+        :param model_name:
+        :param model_type:
+        :param obfuscated:
+        :return:
+        """
+        return self.get_provider_credentials(obfuscated)
diff --git a/api/core/model_providers/rules/_providers.json b/api/core/model_providers/rules/_providers.json
index 92d56be824..0e549828bb 100644
--- a/api/core/model_providers/rules/_providers.json
+++ b/api/core/model_providers/rules/_providers.json
@@ -13,5 +13,6 @@
   "huggingface_hub",
   "xinference",
   "openllm",
-  "localai"
+  "localai",
+  "cohere"
 ]
diff --git a/api/core/model_providers/rules/cohere.json b/api/core/model_providers/rules/cohere.json
new file mode 100644
index 0000000000..0af3e61ec7
--- /dev/null
+++ b/api/core/model_providers/rules/cohere.json
@@ -0,0 +1,7 @@
+{
+  "support_provider_types": [
+    "custom"
+  ],
+  "system_config": null,
+  "model_flexibility": "fixed"
+}
\ No newline at end of file
diff --git a/api/core/orchestrator_rule_parser.py b/api/core/orchestrator_rule_parser.py
index d13282419a..aef020a246 100644
--- a/api/core/orchestrator_rule_parser.py
+++ b/api/core/orchestrator_rule_parser.py
@@ -1,11 +1,17 @@
-from typing import Optional
+import json
+import threading
+from typing import Optional, List

+from flask import Flask
 from langchain import WikipediaAPIWrapper
 from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.tools import BaseTool, Tool, WikipediaQueryRun
 from pydantic import BaseModel, Field

+from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
+from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
+from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
 from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
@@ -17,6 +23,7 @@ from core.model_providers.model_factory import ModelFactory
 from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
 from core.model_providers.models.llm.base import BaseLLM
 from core.tool.current_datetime_tool import DatetimeTool
+from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 from core.tool.provider.serpapi_provider import SerpAPIToolProvider
 from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
@@ -25,6 +32,16 @@ from extensions.ext_database import db
 from models.dataset import Dataset, DatasetProcessRule
 from models.model import AppModelConfig

+default_retrieval_model = {
+    'search_method': 'semantic_search',
+    'reranking_enable': False,
+    'reranking_model': {
+        'reranking_provider_name': '',
+        'reranking_model_name': ''
+    },
+    'top_k': 2,
+    'score_threshold_enable': False
+}

 class OrchestratorRuleParser:
     """Parse the orchestrator rule to entities."""
@@ -34,7 +51,7 @@ class OrchestratorRuleParser:
         self.app_model_config = app_model_config

     def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
-                          rest_tokens: int, chain_callback: MainChainGatherCallbackHandler,
+                          rest_tokens: int, chain_callback: MainChainGatherCallbackHandler, tenant_id: str,
                           retriever_from: str = 'dev') -> Optional[AgentExecutor]:
         if not self.app_model_config.agent_mode_dict:
             return None
@@ -101,7 +118,8 @@ class OrchestratorRuleParser:
                 rest_tokens=rest_tokens,
                 return_resource=return_resource,
                 retriever_from=retriever_from,
-                dataset_configs=dataset_configs
+                dataset_configs=dataset_configs,
+                tenant_id=tenant_id
             )

             if len(tools) == 0:
@@ -123,7 +141,7 @@ class OrchestratorRuleParser:

         return chain

-    def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
+    def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
         """
         Convert app agent tool configs to tools
@@ -132,6 +150,7 @@ class OrchestratorRuleParser:
         :return:
         """
         tools = []
+        dataset_tools = []
         for tool_config in tool_configs:
             tool_type = list(tool_config.keys())[0]
             tool_val = list(tool_config.values())[0]
@@ -140,7 +159,7 @@ class OrchestratorRuleParser:
             tool = None
             if tool_type == "dataset":
-                tool = self.to_dataset_retriever_tool(tool_config=tool_val, **kwargs)
+                dataset_tools.append(tool_config)
             elif tool_type == "web_reader":
                 tool = self.to_web_reader_tool(tool_config=tool_val, **kwargs)
             elif tool_type == "google_search":
@@ -156,57 +175,81 @@ class OrchestratorRuleParser:
             else:
                 tool.callbacks = callbacks
                 tools.append(tool)
-
+        # format dataset tool
+        if len(dataset_tools) > 0:
+            dataset_retriever_tools = self.to_dataset_retriever_tool(tool_configs=dataset_tools, **kwargs)
+            if dataset_retriever_tools:
+                tools.extend(dataset_retriever_tools)
         return tools

-    def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
-                                  dataset_configs: dict, rest_tokens: int,
+    def to_dataset_retriever_tool(self, tool_configs: List, conversation_message_task: ConversationMessageTask,
                                   return_resource: bool = False, retriever_from: str = 'dev',
                                   **kwargs) \
-            -> Optional[BaseTool]:
+            -> Optional[List[BaseTool]]:
         """
         A dataset tool is a tool that can be used to retrieve information from a dataset
-        :param rest_tokens:
-        :param tool_config:
-        :param dataset_configs:
+        :param tool_configs:
         :param conversation_message_task:
         :param return_resource:
         :param retriever_from:
         :return:
         """
-        # get dataset from dataset id
-        dataset = db.session.query(Dataset).filter(
-            Dataset.tenant_id == self.tenant_id,
-            Dataset.id == tool_config.get("id")
-        ).first()
+        dataset_configs = kwargs['dataset_configs']
+        retrieval_model = dataset_configs.get('retrieval_model', 'single')
+        tools = []
+        dataset_ids = []
+        tenant_id = None
+        for tool_config in tool_configs:
+            # get dataset from dataset id
+            dataset = db.session.query(Dataset).filter(
+                Dataset.tenant_id == self.tenant_id,
+                Dataset.id == tool_config.get('dataset').get("id")
+            ).first()

-        if not dataset:
-            return None
+            if not dataset:
+                return None

-        if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
-            return None
+            if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
+                return None
+            dataset_ids.append(dataset.id)
+            if retrieval_model == 'single':
+                retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
+                top_k = retrieval_model['top_k']

-        top_k = dataset_configs.get("top_k", 2)
+                # dynamically adjust top_k when the remaining token number is not enough to support top_k
+                # top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)

-        # dynamically adjust top_k when the remaining token number is not enough to support top_k
-        top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
-
-        score_threshold = None
-        score_threshold_config = dataset_configs.get("score_threshold")
-        if score_threshold_config and score_threshold_config.get("enable"):
-            score_threshold = score_threshold_config.get("value")
+                score_threshold = None
retrieval_model.get("score_threshold_enable") + if score_threshold_enable: + score_threshold = retrieval_model.get("score_threshold") - score_threshold = None - score_threshold_config = dataset_configs.get("score_threshold") - if score_threshold_config and score_threshold_config.get("enable"): - score_threshold = score_threshold_config.get("value") + tool = DatasetRetrieverTool.from_dataset( + dataset=dataset, + top_k=top_k, + score_threshold=score_threshold, + callbacks=[DatasetToolCallbackHandler(conversation_message_task)], + conversation_message_task=conversation_message_task, + return_resource=return_resource, + retriever_from=retriever_from + ) + tools.append(tool) + if retrieval_model == 'multiple': + tool = DatasetMultiRetrieverTool.from_dataset( + dataset_ids=dataset_ids, + tenant_id=kwargs['tenant_id'], + top_k=dataset_configs.get('top_k', 2), + score_threshold=dataset_configs.get('score_threshold', 0.5) if dataset_configs.get('score_threshold_enable', False) else None, + callbacks=[DatasetToolCallbackHandler(conversation_message_task)], + conversation_message_task=conversation_message_task, + return_resource=return_resource, + retriever_from=retriever_from, + reranking_provider_name=dataset_configs.get('reranking_model').get('reranking_provider_name'), + reranking_model_name=dataset_configs.get('reranking_model').get('reranking_model_name') + ) + tools.append(tool) - tool = DatasetRetrieverTool.from_dataset( - dataset=dataset, - top_k=top_k, - score_threshold=score_threshold, - callbacks=[DatasetToolCallbackHandler(conversation_message_task)], - conversation_message_task=conversation_message_task, - return_resource=return_resource, - retriever_from=retriever_from - ) - - return tool + return tools def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]: """ diff --git a/api/core/tool/dataset_multi_retriever_tool.py b/api/core/tool/dataset_multi_retriever_tool.py new file mode 100644 index 0000000000..11c32503ed --- /dev/null +++ b/api/core/tool/dataset_multi_retriever_tool.py @@ -0,0 +1,227 @@ +import json +import threading +from typing import Type, Optional, List + +from flask import current_app, Flask +from langchain.tools import BaseTool +from pydantic import Field, BaseModel + +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.conversation_message_task import ConversationMessageTask +from core.embedding.cached_embedding import CacheEmbedding +from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig +from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError +from core.model_providers.model_factory import ModelFactory +from extensions.ext_database import db +from models.dataset import Dataset, DocumentSegment, Document +from services.retrieval_service import RetrievalService + +default_retrieval_model = { + 'search_method': 'semantic_search', + 'reranking_enable': False, + 'reranking_model': { + 'reranking_provider_name': '', + 'reranking_model_name': '' + }, + 'top_k': 2, + 'score_threshold_enable': False +} + + +class DatasetMultiRetrieverToolInput(BaseModel): + query: str = Field(..., description="dataset multi retriever and rerank") + + +class DatasetMultiRetrieverTool(BaseTool): + """Tool for querying multi dataset.""" + name: str = "dataset-" + args_schema: Type[BaseModel] = DatasetMultiRetrieverToolInput + description: str = "dataset multi retriever and rerank. 
" + tenant_id: str + dataset_ids: List[str] + top_k: int = 2 + score_threshold: Optional[float] = None + reranking_provider_name: str + reranking_model_name: str + conversation_message_task: ConversationMessageTask + return_resource: bool + retriever_from: str + + @classmethod + def from_dataset(cls, dataset_ids: List[str], tenant_id: str, **kwargs): + return cls( + name=f'dataset-{tenant_id}', + tenant_id=tenant_id, + dataset_ids=dataset_ids, + **kwargs + ) + + def _run(self, query: str) -> str: + threads = [] + all_documents = [] + for dataset_id in self.dataset_ids: + retrieval_thread = threading.Thread(target=self._retriever, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset_id': dataset_id, + 'query': query, + 'all_documents': all_documents + }) + threads.append(retrieval_thread) + retrieval_thread.start() + for thread in threads: + thread.join() + # do rerank for searched documents + rerank = ModelFactory.get_reranking_model( + tenant_id=self.tenant_id, + model_provider_name=self.reranking_provider_name, + model_name=self.reranking_model_name + ) + all_documents = rerank.rerank(query, all_documents, self.score_threshold, self.top_k) + + hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task) + hit_callback.on_tool_end(all_documents) + + document_context_list = [] + index_node_ids = [document.metadata['doc_id'] for document in all_documents] + segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == 'completed', + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids) + ).all() + + if segments: + index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} + sorted_segments = sorted(segments, + key=lambda segment: index_node_id_to_position.get(segment.index_node_id, + float('inf'))) + for segment in sorted_segments: + if segment.answer: + document_context_list.append(f'question:{segment.content} answer:{segment.answer}') + else: + document_context_list.append(segment.content) + if self.return_resource: + context_list = [] + resource_number = 1 + for segment in sorted_segments: + dataset = Dataset.query.filter_by( + id=segment.dataset_id + ).first() + document = Document.query.filter(Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ).first() + if dataset and document: + source = { + 'position': resource_number, + 'dataset_id': dataset.id, + 'dataset_name': dataset.name, + 'document_id': document.id, + 'document_name': document.name, + 'data_source_type': document.data_source_type, + 'segment_id': segment.id, + 'retriever_from': self.retriever_from + } + if self.retriever_from == 'dev': + source['hit_count'] = segment.hit_count + source['word_count'] = segment.word_count + source['segment_position'] = segment.position + source['index_node_hash'] = segment.index_node_hash + if segment.answer: + source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + else: + source['content'] = segment.content + context_list.append(source) + resource_number += 1 + hit_callback.return_retriever_resource_info(context_list) + + return str("\n".join(document_context_list)) + + async def _arun(self, tool_input: str) -> str: + raise NotImplementedError() + + def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: List): + with flask_app.app_context(): + dataset = db.session.query(Dataset).filter( + Dataset.tenant_id == self.tenant_id, + Dataset.id == dataset_id + 
+            ).first()
+
+            if not dataset:
+                return []
+            # get retrieval model; if it is not set, use the default
+            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
+
+            if dataset.indexing_technique == "economy":
+                # use keyword table query
+                kw_table_index = KeywordTableIndex(
+                    dataset=dataset,
+                    config=KeywordTableConfig(
+                        max_keywords_per_chunk=5
+                    )
+                )
+
+                documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
+                if documents:
+                    all_documents.extend(documents)
+            else:
+
+                try:
+                    embedding_model = ModelFactory.get_embedding_model(
+                        tenant_id=dataset.tenant_id,
+                        model_provider_name=dataset.embedding_model_provider,
+                        model_name=dataset.embedding_model
+                    )
+                except LLMBadRequestError:
+                    return []
+                except ProviderTokenNotInitError:
+                    return []
+
+                embeddings = CacheEmbedding(embedding_model)
+
+                documents = []
+                threads = []
+                if self.top_k > 0:
+                    # retrieval_model source with semantic
+                    if retrieval_model['search_method'] == 'semantic_search' or retrieval_model[
+                            'search_method'] == 'hybrid_search':
+                        embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
+                            'flask_app': current_app._get_current_object(),
+                            'dataset': dataset,
+                            'query': query,
+                            'top_k': self.top_k,
+                            'score_threshold': self.score_threshold,
+                            'reranking_model': None,
+                            'all_documents': documents,
+                            'search_method': 'hybrid_search',
+                            'embeddings': embeddings
+                        })
+                        threads.append(embedding_thread)
+                        embedding_thread.start()
+
+                    # retrieval_model source with full text
+                    if retrieval_model['search_method'] == 'full_text_search' or retrieval_model[
+                            'search_method'] == 'hybrid_search':
+                        full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search,
+                                                                  kwargs={
+                                                                      'flask_app': current_app._get_current_object(),
+                                                                      'dataset': dataset,
+                                                                      'query': query,
+                                                                      'search_method': 'hybrid_search',
+                                                                      'embeddings': embeddings,
+                                                                      'score_threshold': retrieval_model[
+                                                                          'score_threshold'] if retrieval_model[
+                                                                          'score_threshold_enable'] else None,
+                                                                      'top_k': self.top_k,
+                                                                      'reranking_model': retrieval_model[
+                                                                          'reranking_model'] if retrieval_model[
+                                                                          'reranking_enable'] else None,
+                                                                      'all_documents': documents
+                                                                  })
+                        threads.append(full_text_index_thread)
+                        full_text_index_thread.start()
+
+                    for thread in threads:
+                        thread.join()
+
+                    all_documents.extend(documents)
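Worth noting: each worker thread re-enters the Flask app context before touching the database, which is why `current_app._get_current_object()` is passed into the thread kwargs. A distilled sketch of that pattern:

```python
import threading
from flask import Flask, current_app

def worker(flask_app: Flask, results: list):
    # SQLAlchemy sessions and current_app are only usable inside an app
    # context, so each thread pushes its own.
    with flask_app.app_context():
        results.append("queried inside app context")

results: list = []
t = threading.Thread(target=worker,
                     kwargs={'flask_app': current_app._get_current_object(),
                             'results': results})
t.start()
t.join()
```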
return '' + # get the retrieval model; fall back to the default if none is set + retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model if dataset.indexing_technique == "economy": # use keyword table query @@ -83,28 +98,62 @@ class DatasetRetrieverTool(BaseTool): return '' embeddings = CacheEmbedding(embedding_model) - vector_index = VectorIndex( - dataset=dataset, - config=current_app.config, - embeddings=embeddings - ) + documents = [] + threads = [] if self.top_k > 0: - documents = vector_index.search( - query, - search_type='similarity_score_threshold', - search_kwargs={ - 'k': self.top_k, - 'score_threshold': self.score_threshold, - 'filter': { - 'group_id': [dataset.id] - } - } - ) + # retrieval via semantic search + if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': + embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset': dataset, + 'query': query, + 'top_k': self.top_k, + 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ 'score_threshold_enable'] else None, + 'reranking_model': retrieval_model['reranking_model'] if retrieval_model[ 'reranking_enable'] else None, + 'all_documents': documents, + 'search_method': retrieval_model['search_method'], + 'embeddings': embeddings + }) + threads.append(embedding_thread) + embedding_thread.start() + + # retrieval via full-text search + if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': + full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset': dataset, + 'query': query, + 'search_method': retrieval_model['search_method'], + 'embeddings': embeddings, + 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ 'score_threshold_enable'] else None, + 'top_k': self.top_k, + 'reranking_model': retrieval_model['reranking_model'] if retrieval_model[ 'reranking_enable'] else None, + 'all_documents': documents + }) + threads.append(full_text_index_thread) + full_text_index_thread.start() + + for thread in threads: + thread.join() + # hybrid search: rerank after all documents have been searched + if retrieval_model['search_method'] == 'hybrid_search': + hybrid_rerank = ModelFactory.get_reranking_model( + tenant_id=dataset.tenant_id, + model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'], + model_name=retrieval_model['reranking_model']['reranking_model_name'] + ) + documents = hybrid_rerank.rerank(query, documents, + retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, + self.top_k) else: documents = [] - hit_callback = DatasetIndexToolCallbackHandler(dataset.id, self.conversation_message_task) + hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task) hit_callback.on_tool_end(documents) document_score_list = {} if dataset.indexing_technique != "economy": diff --git a/api/core/vector_store/milvus_vector_store.py b/api/core/vector_store/milvus_vector_store.py index a70445dd4c..0055d76c94 100644 --- a/api/core/vector_store/milvus_vector_store.py +++ b/api/core/vector_store/milvus_vector_store.py @@ -1,4 +1,4 @@ -from core.index.vector_index.milvus import Milvus +from core.vector_store.vector.milvus import Milvus class MilvusVectorStore(Milvus):
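Note: the retrieval flow introduced above (two search threads feeding one shared list, then a single rerank pass for hybrid search) is the same pattern in both retriever tools. A minimal, self-contained sketch of that pattern, with the Dify-specific kwargs elided and `search_fns` standing in as a hypothetical list of callables such as `RetrievalService.embedding_search` / `full_text_index_search`:

    import threading

    def hybrid_retrieve(query, search_fns, rerank_model, score_threshold, top_k):
        documents = []  # shared list; each search thread extends it with its hits
        threads = [
            threading.Thread(target=fn, kwargs={'query': query, 'all_documents': documents})
            for fn in search_fns
        ]
        for thread in threads:
            thread.start()
        # wait for every retrieval path before touching the merged candidates
        for thread in threads:
            thread.join()
        # hybrid search: one rerank pass orders the merged result set
        return rerank_model.rerank(query, documents, score_threshold, top_k)

Appending to the shared list from worker threads is safe here because list mutation is atomic under CPython's GIL, and the join() barrier guarantees both paths have finished before reranking begins.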
diff --git a/api/core/vector_store/qdrant_vector_store.py b/api/core/vector_store/qdrant_vector_store.py index dc92b8cb24..e4f6c2c78f 100644 --- a/api/core/vector_store/qdrant_vector_store.py +++ b/api/core/vector_store/qdrant_vector_store.py @@ -4,7 +4,7 @@ from langchain.schema import Document from qdrant_client.http.models import Filter, PointIdsList, FilterSelector from qdrant_client.local.qdrant_local import QdrantLocal -from core.index.vector_index.qdrant import Qdrant +from core.vector_store.vector.qdrant import Qdrant class QdrantVectorStore(Qdrant): @@ -73,3 +73,4 @@ class QdrantVectorStore(Qdrant): if isinstance(self.client, QdrantLocal): self.client = cast(QdrantLocal, self.client) self.client._load() + diff --git a/api/core/index/vector_index/milvus.py b/api/core/vector_store/vector/milvus.py similarity index 100% rename from api/core/index/vector_index/milvus.py rename to api/core/vector_store/vector/milvus.py diff --git a/api/core/index/vector_index/qdrant.py b/api/core/vector_store/vector/qdrant.py similarity index 97% rename from api/core/index/vector_index/qdrant.py rename to api/core/vector_store/vector/qdrant.py index 5b9736a0b5..33ba0908dd 100644 --- a/api/core/index/vector_index/qdrant.py +++ b/api/core/vector_store/vector/qdrant.py @@ -28,7 +28,7 @@ from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings from langchain.vectorstores import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance -from qdrant_client.http.models import PayloadSchemaType +from qdrant_client.http.models import PayloadSchemaType, FilterSelector, TextIndexParams, TokenizerType, TextIndexType if TYPE_CHECKING: from qdrant_client import grpc # noqa @@ -189,14 +189,25 @@ class Qdrant(VectorStore): texts, metadatas, ids, batch_size ): self.client.upsert( - collection_name=self.collection_name, points=points, **kwargs + collection_name=self.collection_name, points=points ) added_ids.extend(batch_ids) # if is new collection, create payload index on group_id if self.is_new_collection: + # create keyword payload index on group_id self.client.create_payload_index(self.collection_name, self.group_payload_key, field_schema=PayloadSchemaType.KEYWORD, field_type=PayloadSchemaType.KEYWORD) + # create full text index on page content + text_index_params = TextIndexParams( + type=TextIndexType.TEXT, + tokenizer=TokenizerType.MULTILINGUAL, + min_token_len=2, + max_token_len=20, + lowercase=True + ) + self.client.create_payload_index(self.collection_name, self.content_payload_key, + field_schema=text_index_params) return added_ids @sync_call_fallback @@ -600,7 +611,7 @@ class Qdrant(VectorStore): limit=k, offset=offset, with_payload=True, - with_vectors=True, # Langchain does not expect vectors to be returned + with_vectors=True, score_threshold=score_threshold, consistency=consistency, **kwargs, @@ -615,6 +626,39 @@ class Qdrant(VectorStore): for result in results ] + def similarity_search_by_bm25( + self, + filter: Optional[MetadataFilter] = None, + k: int = 4 + ) -> List[Document]: + """Return docs matched by Qdrant's full-text payload filter. + + Args: + filter: Filter by metadata and full-text match conditions. Defaults to None. + k: Number of Documents to return. Defaults to 4. + Returns: + List of matching documents.
+ """ + response = self.client.scroll( + collection_name=self.collection_name, + scroll_filter=filter, + limit=k, + with_payload=True, + with_vectors=True + + ) + results = response[0] + documents = [] + for result in results: + if result: + documents.append(self._document_from_scored_point( + result, self.content_payload_key, self.metadata_payload_key + )) + + return documents + @sync_call_fallback async def asimilarity_search_with_score_by_vector( self, diff --git a/api/core/vector_store/vector/weaviate.py b/api/core/vector_store/vector/weaviate.py new file mode 100644 index 0000000000..afbf68db68 --- /dev/null +++ b/api/core/vector_store/vector/weaviate.py @@ -0,0 +1,505 @@ +"""Wrapper around weaviate vector database.""" +from __future__ import annotations + +import datetime +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type +from uuid import uuid4 + +import numpy as np + +from langchain.docstore.document import Document +from langchain.embeddings.base import Embeddings +from langchain.utils import get_from_dict_or_env +from langchain.vectorstores.base import VectorStore +from langchain.vectorstores.utils import maximal_marginal_relevance + + +def _default_schema(index_name: str) -> Dict: + return { + "class": index_name, + "properties": [ + { + "name": "text", + "dataType": ["text"], + } + ], + } + + +def _create_weaviate_client(**kwargs: Any) -> Any: + client = kwargs.get("client") + if client is not None: + return client + + weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL") + + try: + # the weaviate api key param should not be mandatory + weaviate_api_key = get_from_dict_or_env( + kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None + ) + except ValueError: + weaviate_api_key = None + + try: + import weaviate + except ImportError: + raise ValueError( + "Could not import weaviate python package. " + "Please install it with `pip install weaviate-client`" + ) + + auth = ( + weaviate.auth.AuthApiKey(api_key=weaviate_api_key) + if weaviate_api_key is not None + else None + ) + client = weaviate.Client(weaviate_url, auth_client_secret=auth) + + return client + + +def _default_score_normalizer(val: float) -> float: + return 1 - 1 / (1 + np.exp(val)) + + +def _json_serializable(value: Any) -> Any: + if isinstance(value, datetime.datetime): + return value.isoformat() + return value + + +class Weaviate(VectorStore): + """Wrapper around Weaviate vector database. + + To use, you should have the ``weaviate-client`` python package installed. + + Example: + .. code-block:: python + + import weaviate + from langchain.vectorstores import Weaviate + client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...) + weaviate = Weaviate(client, index_name, text_key) + + """ + + def __init__( + self, + client: Any, + index_name: str, + text_key: str, + embedding: Optional[Embeddings] = None, + attributes: Optional[List[str]] = None, + relevance_score_fn: Optional[ + Callable[[float], float] + ] = _default_score_normalizer, + by_text: bool = True, + ): + """Initialize with Weaviate client.""" + try: + import weaviate + except ImportError: + raise ValueError( + "Could not import weaviate python package. " + "Please install it with `pip install weaviate-client`." 
+ ) + if not isinstance(client, weaviate.Client): + raise ValueError( + f"client should be an instance of weaviate.Client, got {type(client)}" + ) + self._client = client + self._index_name = index_name + self._embedding = embedding + self._text_key = text_key + self._query_attrs = [self._text_key] + self.relevance_score_fn = relevance_score_fn + self._by_text = by_text + if attributes is not None: + self._query_attrs.extend(attributes) + + @property + def embeddings(self) -> Optional[Embeddings]: + return self._embedding + + def _select_relevance_score_fn(self) -> Callable[[float], float]: + return ( + self.relevance_score_fn + if self.relevance_score_fn + else _default_score_normalizer + ) + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + **kwargs: Any, + ) -> List[str]: + """Upload texts with metadata (properties) to Weaviate.""" + from weaviate.util import get_valid_uuid + + ids = [] + embeddings: Optional[List[List[float]]] = None + if self._embedding: + if not isinstance(texts, list): + texts = list(texts) + embeddings = self._embedding.embed_documents(texts) + + with self._client.batch as batch: + for i, text in enumerate(texts): + data_properties = {self._text_key: text} + if metadatas is not None: + for key, val in metadatas[i].items(): + data_properties[key] = _json_serializable(val) + + # Allow for ids (consistent w/ other methods) + # or uuids (backwards compatible w/ existing arg) + # If the UUID of one of the objects already exists + # then the existing object will be replaced by the new object. + _id = get_valid_uuid(uuid4()) + if "uuids" in kwargs: + _id = kwargs["uuids"][i] + elif "ids" in kwargs: + _id = kwargs["ids"][i] + + batch.add_data_object( + data_object=data_properties, + class_name=self._index_name, + uuid=_id, + vector=embeddings[i] if embeddings else None, + ) + ids.append(_id) + return ids + + def similarity_search( + self, query: str, k: int = 4, **kwargs: Any + ) -> List[Document]: + """Return docs most similar to query. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + + Returns: + List of Documents most similar to the query. + """ + if self._by_text: + return self.similarity_search_by_text(query, k, **kwargs) + else: + if self._embedding is None: + raise ValueError( + "_embedding cannot be None for similarity_search when " + "_by_text=False" + ) + embedding = self._embedding.embed_query(query) + return self.similarity_search_by_vector(embedding, k, **kwargs) + + def similarity_search_by_text( + self, query: str, k: int = 4, **kwargs: Any + ) -> List[Document]: + """Return docs most similar to query. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + + Returns: + List of Documents most similar to the query.
+ """ + content: Dict[str, Any] = {"concepts": [query]} + if kwargs.get("search_distance"): + content["certainty"] = kwargs.get("search_distance") + query_obj = self._client.query.get(self._index_name, self._query_attrs) + if kwargs.get("where_filter"): + query_obj = query_obj.with_where(kwargs.get("where_filter")) + if kwargs.get("additional"): + query_obj = query_obj.with_additional(kwargs.get("additional")) + result = query_obj.with_near_text(content).with_limit(k).do() + if "errors" in result: + raise ValueError(f"Error during query: {result['errors']}") + docs = [] + for res in result["data"]["Get"][self._index_name]: + text = res.pop(self._text_key) + docs.append(Document(page_content=text, metadata=res)) + return docs + + def similarity_search_by_bm25( + self, query: str, k: int = 4, **kwargs: Any + ) -> List[Document]: + """Return docs using BM25F. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + + Returns: + List of Documents most similar to the query. + """ + content: Dict[str, Any] = {"concepts": [query]} + if kwargs.get("search_distance"): + content["certainty"] = kwargs.get("search_distance") + query_obj = self._client.query.get(self._index_name, self._query_attrs) + if kwargs.get("where_filter"): + query_obj = query_obj.with_where(kwargs.get("where_filter")) + if kwargs.get("additional"): + query_obj = query_obj.with_additional(kwargs.get("additional")) + result = query_obj.with_bm25(query=content).with_limit(k).do() + if "errors" in result: + raise ValueError(f"Error during query: {result['errors']}") + docs = [] + for res in result["data"]["Get"][self._index_name]: + text = res.pop(self._text_key) + docs.append(Document(page_content=text, metadata=res)) + return docs + + def similarity_search_by_vector( + self, embedding: List[float], k: int = 4, **kwargs: Any + ) -> List[Document]: + """Look up similar documents by embedding vector in Weaviate.""" + vector = {"vector": embedding} + query_obj = self._client.query.get(self._index_name, self._query_attrs) + if kwargs.get("where_filter"): + query_obj = query_obj.with_where(kwargs.get("where_filter")) + if kwargs.get("additional"): + query_obj = query_obj.with_additional(kwargs.get("additional")) + result = query_obj.with_near_vector(vector).with_limit(k).do() + if "errors" in result: + raise ValueError(f"Error during query: {result['errors']}") + docs = [] + for res in result["data"]["Get"][self._index_name]: + text = res.pop(self._text_key) + docs.append(Document(page_content=text, metadata=res)) + return docs + + def max_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch to pass to MMR algorithm. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + + Returns: + List of Documents selected by maximal marginal relevance. 
+ """ + if self._embedding is not None: + embedding = self._embedding.embed_query(query) + else: + raise ValueError( + "max_marginal_relevance_search requires a suitable Embeddings object" + ) + + return self.max_marginal_relevance_search_by_vector( + embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs + ) + + def max_marginal_relevance_search_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch to pass to MMR algorithm. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + + Returns: + List of Documents selected by maximal marginal relevance. + """ + vector = {"vector": embedding} + query_obj = self._client.query.get(self._index_name, self._query_attrs) + if kwargs.get("where_filter"): + query_obj = query_obj.with_where(kwargs.get("where_filter")) + results = ( + query_obj.with_additional("vector") + .with_near_vector(vector) + .with_limit(fetch_k) + .do() + ) + + payload = results["data"]["Get"][self._index_name] + embeddings = [result["_additional"]["vector"] for result in payload] + mmr_selected = maximal_marginal_relevance( + np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult + ) + + docs = [] + for idx in mmr_selected: + text = payload[idx].pop(self._text_key) + payload[idx].pop("_additional") + meta = payload[idx] + docs.append(Document(page_content=text, metadata=meta)) + return docs + + def similarity_search_with_score( + self, query: str, k: int = 4, **kwargs: Any + ) -> List[Tuple[Document, float]]: + """ + Return list of documents most similar to the query + text and cosine distance in float for each. + Lower score represents more similarity. + """ + if self._embedding is None: + raise ValueError( + "_embedding cannot be None for similarity_search_with_score" + ) + content: Dict[str, Any] = {"concepts": [query]} + if kwargs.get("search_distance"): + content["certainty"] = kwargs.get("search_distance") + query_obj = self._client.query.get(self._index_name, self._query_attrs) + + embedded_query = self._embedding.embed_query(query) + if not self._by_text: + vector = {"vector": embedded_query} + result = ( + query_obj.with_near_vector(vector) + .with_limit(k) + .with_additional("vector") + .do() + ) + else: + result = ( + query_obj.with_near_text(content) + .with_limit(k) + .with_additional("vector") + .do() + ) + + if "errors" in result: + raise ValueError(f"Error during query: {result['errors']}") + + docs_and_scores = [] + for res in result["data"]["Get"][self._index_name]: + text = res.pop(self._text_key) + score = np.dot(res["_additional"]["vector"], embedded_query) + docs_and_scores.append((Document(page_content=text, metadata=res), score)) + return docs_and_scores + + @classmethod + def from_texts( + cls: Type[Weaviate], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + **kwargs: Any, + ) -> Weaviate: + """Construct Weaviate wrapper from raw documents. + + This is a user-friendly interface that: + 1. Embeds documents. + 2. 
Creates a new index for the embeddings in the Weaviate instance. + 3. Adds the documents to the newly created Weaviate index. + + This is intended to be a quick way to get started. + + Example: + .. code-block:: python + + from langchain.vectorstores.weaviate import Weaviate + from langchain.embeddings import OpenAIEmbeddings + embeddings = OpenAIEmbeddings() + weaviate = Weaviate.from_texts( + texts, + embeddings, + weaviate_url="http://localhost:8080" + ) + """ + + client = _create_weaviate_client(**kwargs) + + from weaviate.util import get_valid_uuid + + index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}") + embeddings = embedding.embed_documents(texts) if embedding else None + text_key = "text" + schema = _default_schema(index_name) + attributes = list(metadatas[0].keys()) if metadatas else None + + # check whether the index already exists + if not client.schema.contains(schema): + client.schema.create_class(schema) + + with client.batch as batch: + for i, text in enumerate(texts): + data_properties = { + text_key: text, + } + if metadatas is not None: + for key in metadatas[i].keys(): + data_properties[key] = metadatas[i][key] + + # If the UUID of one of the objects already exists + # then the existing object will be replaced by the new object. + if "uuids" in kwargs: + _id = kwargs["uuids"][i] + else: + _id = get_valid_uuid(uuid4()) + + # if an embedding strategy is not provided, we let + # weaviate create the embedding. Note that this will only + # work if weaviate has been installed with a vectorizer module + # like text2vec-contextionary for example + params = { + "uuid": _id, + "data_object": data_properties, + "class_name": index_name, + } + if embeddings is not None: + params["vector"] = embeddings[i] + + batch.add_data_object(**params) + + batch.flush() + + relevance_score_fn = kwargs.get("relevance_score_fn") + by_text: bool = kwargs.get("by_text", False) + + return cls( + client, + index_name, + text_key, + embedding=embedding, + attributes=attributes, + relevance_score_fn=relevance_score_fn, + by_text=by_text, + ) + + def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None: + """Delete by vector IDs. + + Args: + ids: List of ids to delete. + """ + + if ids is None: + raise ValueError("No ids provided to delete.") + + # TODO: Check if this can be done in bulk + for id in ids: + self._client.data_object.delete(uuid=id)
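For reference, `similarity_search_by_bm25` above drives Weaviate's keyword (BM25F) scoring through the same GraphQL query builder that the vector searches use. A minimal standalone sketch against weaviate-client ~3.21, assuming a local instance and a pre-existing `Article` class with a `text` property (both hypothetical):

    import weaviate

    # hypothetical local deployment and schema; adjust to your setup
    client = weaviate.Client("http://localhost:8080")

    result = (
        client.query
        .get("Article", ["text"])
        .with_bm25(query="hybrid retrieval")  # keyword scoring; no embedding required
        .with_limit(4)
        .do()
    )

    for obj in result["data"]["Get"]["Article"]:
        print(obj["text"])

Because BM25 operates on Weaviate's inverted index, no embedding call is made on this query path, which is what lets the full-text search branch skip the embedding model entirely.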
+ """ + + if ids is None: + raise ValueError("No ids provided to delete.") + + # TODO: Check if this can be done in bulk + for id in ids: + self._client.data_object.delete(uuid=id) diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 90af9e1fdd..d7be65be01 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -12,6 +12,21 @@ dataset_fields = { 'created_at': TimestampField, } +reranking_model_fields = { + 'reranking_provider_name': fields.String, + 'reranking_model_name': fields.String +} + +dataset_retrieval_model_fields = { + 'search_method': fields.String, + 'reranking_enable': fields.Boolean, + 'reranking_model': fields.Nested(reranking_model_fields), + 'top_k': fields.Integer, + 'score_threshold_enable': fields.Boolean, + 'score_threshold': fields.Float +} + + dataset_detail_fields = { 'id': fields.String, 'name': fields.String, @@ -29,7 +44,8 @@ dataset_detail_fields = { 'updated_at': TimestampField, 'embedding_model': fields.String, 'embedding_model_provider': fields.String, - 'embedding_available': fields.Boolean + 'embedding_available': fields.Boolean, + 'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields) } dataset_query_detail_fields = { @@ -41,3 +57,5 @@ dataset_query_detail_fields = { "created_by": fields.String, "created_at": TimestampField } + + diff --git a/api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py b/api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py new file mode 100644 index 0000000000..c16781c15d --- /dev/null +++ b/api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py @@ -0,0 +1,43 @@ +"""add-dataset-retrival-model + +Revision ID: fca025d3b60f +Revises: b3a09c049e8e +Create Date: 2023-11-03 13:08:23.246396 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'fca025d3b60f' +down_revision = '8fe468ba0ca5' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('sessions') + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('retrieval_model', postgresql.JSONB(astext_type=sa.Text()), nullable=True)) + batch_op.create_index('retrieval_model_idx', ['retrieval_model'], unique=False, postgresql_using='gin') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.drop_index('retrieval_model_idx', postgresql_using='gin') + batch_op.drop_column('retrieval_model') + + op.create_table('sessions', + sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column('session_id', sa.VARCHAR(length=255), autoincrement=False, nullable=True), + sa.Column('data', postgresql.BYTEA(), autoincrement=False, nullable=True), + sa.Column('expiry', postgresql.TIMESTAMP(), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('id', name='sessions_pkey'), + sa.UniqueConstraint('session_id', name='sessions_session_id_key') + ) + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index a9a33cc1a7..5fbf035f84 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -3,7 +3,7 @@ import pickle from json import JSONDecodeError from sqlalchemy import func -from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.dialects.postgresql import UUID, JSONB from extensions.ext_database import db from models.account import Account @@ -15,6 +15,7 @@ class Dataset(db.Model): __table_args__ = ( db.PrimaryKeyConstraint('id', name='dataset_pkey'), db.Index('dataset_tenant_idx', 'tenant_id'), + db.Index('retrieval_model_idx', "retrieval_model", postgresql_using='gin') ) INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy'] @@ -39,7 +40,7 @@ class Dataset(db.Model): embedding_model = db.Column(db.String(255), nullable=True) embedding_model_provider = db.Column(db.String(255), nullable=True) collection_binding_id = db.Column(UUID, nullable=True) - + retrieval_model = db.Column(JSONB, nullable=True) @property def dataset_keyword_table(self): @@ -93,6 +94,20 @@ class Dataset(db.Model): return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \ .filter(Document.dataset_id == self.id).scalar() + @property + def retrieval_model_dict(self): + default_retrieval_model = { + 'search_method': 'semantic_search', + 'reranking_enable': False, + 'reranking_model': { + 'reranking_provider_name': '', + 'reranking_model_name': '' + }, + 'top_k': 2, + 'score_threshold_enable': False + } + return self.retrieval_model if self.retrieval_model else default_retrieval_model + class DatasetProcessRule(db.Model): __tablename__ = 'dataset_process_rules' @@ -120,7 +135,7 @@ class DatasetProcessRule(db.Model): ], 'segmentation': { 'delimiter': '\n', - 'max_tokens': 1000 + 'max_tokens': 512 } } @@ -462,4 +477,3 @@ class DatasetCollectionBinding(db.Model): model_name = db.Column(db.String(40), nullable=False) collection_name = db.Column(db.String(64), nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - diff --git a/api/models/model.py b/api/models/model.py index b7cd428839..b3570f7f42 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -160,7 +160,13 @@ class AppModelConfig(db.Model): @property def dataset_configs_dict(self) -> dict: - return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}} + if self.dataset_configs: + dataset_configs = json.loads(self.dataset_configs) + if 'retrieval_model' not in dataset_configs: + return {'retrieval_model': 'single'} + else: + return dataset_configs + return {'retrieval_model': 'single'} @property def file_upload_dict(self) -> dict: diff --git a/api/requirements.txt b/api/requirements.txt index 91b373d406..6e32064a05 100644 --- a/api/requirements.txt +++ 
b/api/requirements.txt @@ -23,7 +23,6 @@ boto3==1.28.17 tenacity==8.2.2 cachetools~=5.3.0 weaviate-client~=3.21.0 -qdrant_client~=1.1.6 mailchimp-transactional~=1.0.50 scikit-learn==1.2.2 sentry-sdk[flask]~=1.21.1 @@ -53,4 +52,6 @@ xinference-client~=0.5.4 safetensors==0.3.2 zhipuai==1.0.7 werkzeug==2.3.7 -pymilvus==2.3.0 \ No newline at end of file +pymilvus==2.3.0 +qdrant-client==1.6.4 +cohere~=4.32 \ No newline at end of file diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index be7947d7f6..3ffd8b0431 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -470,7 +470,16 @@ class AppModelConfigService: # dataset_configs if 'dataset_configs' not in config or not config["dataset_configs"]: - config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}} + config["dataset_configs"] = {'retrieval_model': 'single'} + + if not isinstance(config["dataset_configs"], dict): + raise ValueError("dataset_configs must be of object type") + + if config["dataset_configs"]['retrieval_model'] == 'multiple': + if not config["dataset_configs"]['reranking_model']: + raise ValueError("reranking_model has not been set") + if not isinstance(config["dataset_configs"]['reranking_model'], dict): + raise ValueError("reranking_model must be of object type") if not isinstance(config["dataset_configs"], dict): raise ValueError("dataset_configs must be of object type") diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index ede3124694..defe539ae9 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -173,6 +173,9 @@ class DatasetService: filtered_data['updated_by'] = user.id filtered_data['updated_at'] = datetime.datetime.now() + # update Retrieval model + filtered_data['retrieval_model'] = data['retrieval_model'] + dataset.query.filter_by(id=dataset_id).update(filtered_data) db.session.commit() @@ -473,7 +476,19 @@ class DocumentService: embedding_model.name ) dataset.collection_binding_id = dataset_collection_binding.id + if not dataset.retrieval_model: + default_retrieval_model = { + 'search_method': 'semantic_search', + 'reranking_enable': False, + 'reranking_model': { + 'reranking_provider_name': '', + 'reranking_model_name': '' + }, + 'top_k': 2, + 'score_threshold_enable': False + } + dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model documents = [] batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) @@ -733,6 +748,7 @@ class DocumentService: raise ValueError(f"All your documents have overed limit {tenant_document_count}.") embedding_model = None dataset_collection_binding_id = None + retrieval_model = None if document_data['indexing_technique'] == 'high_quality': embedding_model = ModelFactory.get_embedding_model( tenant_id=tenant_id @@ -742,6 +758,20 @@ class DocumentService: embedding_model.name ) dataset_collection_binding_id = dataset_collection_binding.id + if 'retrieval_model' in document_data and document_data['retrieval_model']: + retrieval_model = document_data['retrieval_model'] + else: + default_retrieval_model = { + 'search_method': 'semantic_search', + 'reranking_enable': False, + 'reranking_model': { + 'reranking_provider_name': '', + 'reranking_model_name': '' + }, + 'top_k': 2, + 'score_threshold_enable': False + } + retrieval_model = default_retrieval_model # save dataset dataset = Dataset( tenant_id=tenant_id, @@ 
-751,7 +781,8 @@ class DocumentService: created_by=account.id, embedding_model=embedding_model.name if embedding_model else None, embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None, - collection_binding_id=dataset_collection_binding_id + collection_binding_id=dataset_collection_binding_id, + retrieval_model=retrieval_model ) db.session.add(dataset) @@ -768,7 +799,7 @@ class DocumentService: return dataset, documents, batch @classmethod - def document_create_args_validate(cls, args: dict): + def document_create_args_validate(cls, args: dict): if 'original_document_id' not in args or not args['original_document_id']: DocumentService.data_source_args_validate(args) DocumentService.process_rule_args_validate(args) diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 063292969c..d9725a66d8 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -1,4 +1,6 @@ +import json import logging +import threading import time from typing import List @@ -9,16 +11,26 @@ from langchain.schema import Document from sklearn.manifold import TSNE from core.embedding.cached_embedding import CacheEmbedding -from core.index.vector_index.vector_index import VectorIndex from core.model_providers.model_factory import ModelFactory from extensions.ext_database import db from models.account import Account from models.dataset import Dataset, DocumentSegment, DatasetQuery +from services.retrieval_service import RetrievalService +default_retrieval_model = { + 'search_method': 'semantic_search', + 'reranking_enable': False, + 'reranking_model': { + 'reranking_provider_name': '', + 'reranking_model_name': '' + }, + 'top_k': 2, + 'score_threshold_enable': False +} class HitTestingService: @classmethod - def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict: + def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict: if dataset.available_document_count == 0 or dataset.available_segment_count == 0: return { "query": { @@ -28,31 +40,68 @@ "records": [] } + start = time.perf_counter() + + # get the retrieval model; fall back to the default if none is set + if not retrieval_model: + retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model + + # get embedding model embedding_model = ModelFactory.get_embedding_model( tenant_id=dataset.tenant_id, model_provider_name=dataset.embedding_model_provider, model_name=dataset.embedding_model ) - embeddings = CacheEmbedding(embedding_model) - vector_index = VectorIndex( - dataset=dataset, - config=current_app.config, - embeddings=embeddings - ) + all_documents = [] + threads = [] + + # retrieval source with semantic search + if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': + embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset': dataset, + 'query': query, + 'top_k': retrieval_model['top_k'], + 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, + 'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None, + 'all_documents': all_documents, + 'search_method': retrieval_model['search_method'], + 'embeddings': embeddings + }) + threads.append(embedding_thread) +
embedding_thread.start() + + # retrieval source with full text + if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': + full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset': dataset, + 'query': query, + 'search_method': retrieval_model['search_method'], + 'embeddings': embeddings, + 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, + 'top_k': retrieval_model['top_k'], + 'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None, + 'all_documents': all_documents + }) + threads.append(full_text_index_thread) + full_text_index_thread.start() + + for thread in threads: + thread.join() + + if retrieval_model['search_method'] == 'hybrid_search': + hybrid_rerank = ModelFactory.get_reranking_model( + tenant_id=dataset.tenant_id, + model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'], + model_name=retrieval_model['reranking_model']['reranking_model_name'] + ) + all_documents = hybrid_rerank.rerank(query, all_documents, + retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, + retrieval_model['top_k']) - start = time.perf_counter() - documents = vector_index.search( - query, - search_type='similarity_score_threshold', - search_kwargs={ - 'k': 10, - 'filter': { - 'group_id': [dataset.id] - } - } - ) end = time.perf_counter() logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") @@ -67,7 +116,7 @@ class HitTestingService: db.session.add(dataset_query) db.session.commit() - return cls.compact_retrieve_response(dataset, embeddings, query, documents) + return cls.compact_retrieve_response(dataset, embeddings, query, all_documents) @classmethod def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]): @@ -99,7 +148,7 @@ class HitTestingService: record = { "segment": segment, - "score": document.metadata['score'], + "score": document.metadata.get('score', None), "tsne_position": tsne_position_data[i] } @@ -136,3 +185,11 @@ class HitTestingService: tsne_position_data.append({'x': float(data_tsne[i][0]), 'y': float(data_tsne[i][1])}) return tsne_position_data + + @classmethod + def hit_testing_args_check(cls, args): + query = args['query'] + + if not query or len(query) > 250: + raise ValueError('Query is required and cannot exceed 250 characters') + diff --git a/api/services/retrieval_service.py b/api/services/retrieval_service.py new file mode 100644 index 0000000000..3e6b93f862 --- /dev/null +++ b/api/services/retrieval_service.py @@ -0,0 +1,88 @@ + +from typing import Optional +from flask import current_app, Flask +from langchain.embeddings.base import Embeddings +from core.index.vector_index.vector_index import VectorIndex +from core.model_providers.model_factory import ModelFactory +from models.dataset import Dataset + +default_retrieval_model = { + 'search_method': 'semantic_search', + 'reranking_enable': False, + 'reranking_model': { + 'reranking_provider_name': '', + 'reranking_model_name': '' + }, + 'top_k': 2, + 'score_threshold_enable': False +} + + +class RetrievalService: + + @classmethod + def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str, + top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], + all_documents: list, search_method: str, 
embeddings: Embeddings): + with flask_app.app_context(): + + vector_index = VectorIndex( + dataset=dataset, + config=current_app.config, + embeddings=embeddings + ) + + documents = vector_index.search( + query, + search_type='similarity_score_threshold', + search_kwargs={ + 'k': top_k, + 'score_threshold': score_threshold, + 'filter': { + 'group_id': [dataset.id] + } + } + ) + + if documents: + if reranking_model and search_method == 'semantic_search': + rerank = ModelFactory.get_reranking_model( + tenant_id=dataset.tenant_id, + model_provider_name=reranking_model['reranking_provider_name'], + model_name=reranking_model['reranking_model_name'] + ) + all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents))) + else: + all_documents.extend(documents) + + @classmethod + def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str, + top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], + all_documents: list, search_method: str, embeddings: Embeddings): + with flask_app.app_context(): + + vector_index = VectorIndex( + dataset=dataset, + config=current_app.config, + embeddings=embeddings + ) + + documents = vector_index.search_by_full_text_index( + query, + search_type='similarity_score_threshold', + top_k=top_k + ) + if documents: + if reranking_model and search_method == 'full_text_search': + rerank = ModelFactory.get_reranking_model( + tenant_id=dataset.tenant_id, + model_provider_name=reranking_model['reranking_provider_name'], + model_name=reranking_model['reranking_model_name'] + ) + all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents))) + else: + all_documents.extend(documents) + + + +
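Taken together, the new RetrievalService is driven the same way from the hit-testing service and both retriever tools. A minimal sketch of a caller, mirroring those call sites (the surrounding Flask app, `dataset`, `embeddings` and `retrieval_model` objects are assumed to be set up as they are above):

    import threading
    from flask import current_app
    from services.retrieval_service import RetrievalService

    def retrieve(dataset, query, embeddings, retrieval_model):
        all_documents = []
        # the two classmethods share the same keyword arguments
        kwargs = {
            'flask_app': current_app._get_current_object(),
            'dataset': dataset,
            'query': query,
            'top_k': retrieval_model['top_k'],
            'score_threshold': retrieval_model['score_threshold']
                if retrieval_model['score_threshold_enable'] else None,
            'reranking_model': retrieval_model['reranking_model']
                if retrieval_model['reranking_enable'] else None,
            'all_documents': all_documents,
            'search_method': retrieval_model['search_method'],
            'embeddings': embeddings,
        }
        threads = []
        if retrieval_model['search_method'] in ('semantic_search', 'hybrid_search'):
            threads.append(threading.Thread(target=RetrievalService.embedding_search, kwargs=kwargs))
        if retrieval_model['search_method'] in ('full_text_search', 'hybrid_search'):
            threads.append(threading.Thread(target=RetrievalService.full_text_index_search, kwargs=kwargs))
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        # for hybrid_search the caller then reranks all_documents,
        # as the hit-testing service and retriever tools above do
        return all_documents

Each classmethod re-enters the application context itself via `with flask_app.app_context():`, which is why it is safe to call them from worker threads that do not inherit the request context.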