mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 17:58:59 +08:00
Feat/add retriever rerank (#1560)
Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
parent
a4f37220a0
commit
4588831bff
@ -8,6 +8,8 @@ import time
|
||||
import uuid
|
||||
|
||||
import click
|
||||
import qdrant_client
|
||||
from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
|
||||
from tqdm import tqdm
|
||||
from flask import current_app, Flask
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
@ -484,6 +486,38 @@ def normalization_collections():
|
||||
click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
|
||||
|
||||
|
||||
@click.command('add-qdrant-full-text-index', help='add qdrant full text index')
|
||||
def add_qdrant_full_text_index():
|
||||
click.echo(click.style('Start add full text index.', fg='green'))
|
||||
binds = db.session.query(DatasetCollectionBinding).all()
|
||||
if binds and current_app.config['VECTOR_STORE'] == 'qdrant':
|
||||
qdrant_url = current_app.config['QDRANT_URL']
|
||||
qdrant_api_key = current_app.config['QDRANT_API_KEY']
|
||||
client = qdrant_client.QdrantClient(
|
||||
qdrant_url,
|
||||
api_key=qdrant_api_key, # For Qdrant Cloud, None for local instance
|
||||
)
|
||||
for bind in binds:
|
||||
try:
|
||||
text_index_params = TextIndexParams(
|
||||
type=TextIndexType.TEXT,
|
||||
tokenizer=TokenizerType.MULTILINGUAL,
|
||||
min_token_len=2,
|
||||
max_token_len=20,
|
||||
lowercase=True
|
||||
)
|
||||
client.create_payload_index(bind.collection_name, 'page_content',
|
||||
field_schema=text_index_params)
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Create full text index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
click.echo(
|
||||
click.style(
|
||||
'Congratulations! add collection {} full text index successful.'.format(bind.collection_name),
|
||||
fg='green'))
|
||||
|
||||
|
||||
def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
@ -647,10 +681,10 @@ def update_app_model_configs(batch_size):
|
||||
|
||||
pbar.update(len(data_batch))
|
||||
|
||||
|
||||
@click.command('migrate_default_input_to_dataset_query_variable')
|
||||
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
|
||||
def migrate_default_input_to_dataset_query_variable(batch_size):
|
||||
|
||||
click.secho("Starting...", fg='green')
|
||||
|
||||
total_records = db.session.query(AppModelConfig) \
|
||||
@ -658,13 +692,13 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
|
||||
.filter(App.mode == 'completion') \
|
||||
.filter(AppModelConfig.dataset_query_variable == None) \
|
||||
.count()
|
||||
|
||||
|
||||
if total_records == 0:
|
||||
click.secho("No data to migrate.", fg='green')
|
||||
return
|
||||
|
||||
num_batches = (total_records + batch_size - 1) // batch_size
|
||||
|
||||
|
||||
with tqdm(total=total_records, desc="Migrating Data") as pbar:
|
||||
for i in range(num_batches):
|
||||
offset = i * batch_size
|
||||
@ -697,14 +731,14 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
|
||||
for form in user_input_form:
|
||||
paragraph = form.get('paragraph')
|
||||
if paragraph \
|
||||
and paragraph.get('variable') == 'query':
|
||||
data.dataset_query_variable = 'query'
|
||||
break
|
||||
|
||||
and paragraph.get('variable') == 'query':
|
||||
data.dataset_query_variable = 'query'
|
||||
break
|
||||
|
||||
if paragraph \
|
||||
and paragraph.get('variable') == 'default_input':
|
||||
data.dataset_query_variable = 'default_input'
|
||||
break
|
||||
and paragraph.get('variable') == 'default_input':
|
||||
data.dataset_query_variable = 'default_input'
|
||||
break
|
||||
|
||||
db.session.commit()
|
||||
|
||||
@ -712,7 +746,7 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
|
||||
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
|
||||
fg='red')
|
||||
continue
|
||||
|
||||
|
||||
click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
|
||||
|
||||
pbar.update(len(data_batch))
|
||||
@ -731,3 +765,4 @@ def register_commands(app):
|
||||
app.cli.add_command(update_app_model_configs)
|
||||
app.cli.add_command(normalization_collections)
|
||||
app.cli.add_command(migrate_default_input_to_dataset_query_variable)
|
||||
app.cli.add_command(add_qdrant_full_text_index)
|
||||
|
@ -170,6 +170,7 @@ class DatasetApi(Resource):
|
||||
help='Invalid indexing technique.')
|
||||
parser.add_argument('permission', type=str, location='json', choices=(
|
||||
'only_me', 'all_team_members'), help='Invalid permission.')
|
||||
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
@ -401,6 +402,7 @@ class DatasetApiKeyApi(Resource):
|
||||
|
||||
class DatasetApiDeleteApi(Resource):
|
||||
resource_type = 'dataset'
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -436,6 +438,50 @@ class DatasetApiBaseUrlApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
class DatasetRetrievalSettingApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
vector_type = current_app.config['VECTOR_STORE']
|
||||
if vector_type == 'milvus':
|
||||
return {
|
||||
'retrieval_method': [
|
||||
'semantic_search'
|
||||
]
|
||||
}
|
||||
elif vector_type == 'qdrant' or vector_type == 'weaviate':
|
||||
return {
|
||||
'retrieval_method': [
|
||||
'semantic_search', 'full_text_search', 'hybrid_search'
|
||||
]
|
||||
}
|
||||
else:
|
||||
raise ValueError("Unsupported vector db type.")
|
||||
|
||||
|
||||
class DatasetRetrievalSettingMockApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, vector_type):
|
||||
|
||||
if vector_type == 'milvus':
|
||||
return {
|
||||
'retrieval_method': [
|
||||
'semantic_search'
|
||||
]
|
||||
}
|
||||
elif vector_type == 'qdrant' or vector_type == 'weaviate':
|
||||
return {
|
||||
'retrieval_method': [
|
||||
'semantic_search', 'full_text_search', 'hybrid_search'
|
||||
]
|
||||
}
|
||||
else:
|
||||
raise ValueError("Unsupported vector db type.")
|
||||
|
||||
|
||||
api.add_resource(DatasetListApi, '/datasets')
|
||||
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
|
||||
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
|
||||
@ -445,3 +491,5 @@ api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing
|
||||
api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
|
||||
api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
|
||||
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
|
||||
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
|
||||
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
|
||||
|
@ -221,6 +221,8 @@ class DatasetDocumentListApi(Resource):
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
|
||||
location='json')
|
||||
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
|
||||
location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
if not dataset.indexing_technique and not args['indexing_technique']:
|
||||
@ -263,6 +265,8 @@ class DatasetInitApi(Resource):
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
|
||||
location='json')
|
||||
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
|
||||
location='json')
|
||||
args = parser.parse_args()
|
||||
if args['indexing_technique'] == 'high_quality':
|
||||
try:
|
||||
|
@ -42,19 +42,18 @@ class HitTestingApi(Resource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('query', type=str, location='json')
|
||||
parser.add_argument('retrieval_model', type=dict, required=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
query = args['query']
|
||||
|
||||
if not query or len(query) > 250:
|
||||
raise ValueError('Query is required and cannot exceed 250 characters')
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=query,
|
||||
query=args['query'],
|
||||
account=current_user,
|
||||
limit=10,
|
||||
retrieval_model=args['retrieval_model'],
|
||||
limit=10
|
||||
)
|
||||
|
||||
return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
|
||||
|
@ -19,7 +19,7 @@ class DefaultModelApi(Resource):
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=['text-generation', 'embeddings', 'speech2text'], location='args')
|
||||
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
@ -71,19 +71,18 @@ class DefaultModelApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('model_type', type=str, required=True, nullable=False,
|
||||
choices=['text-generation', 'embeddings', 'speech2text'], location='json')
|
||||
parser.add_argument('provider_name', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
provider_service = ProviderService()
|
||||
provider_service.update_default_model_of_model_type(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_type=args['model_type'],
|
||||
provider_name=args['provider_name'],
|
||||
model_name=args['model_name']
|
||||
)
|
||||
model_settings = args['model_settings']
|
||||
for model_setting in model_settings:
|
||||
provider_service.update_default_model_of_model_type(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_type=model_setting['model_type'],
|
||||
provider_name=model_setting['provider_name'],
|
||||
model_name=model_setting['model_name']
|
||||
)
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
|
@ -36,6 +36,8 @@ class DocumentAddByTextApi(DatasetApiResource):
|
||||
location='json')
|
||||
parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
|
||||
location='json')
|
||||
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
|
||||
location='json')
|
||||
args = parser.parse_args()
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
@ -95,6 +97,8 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
|
||||
location='json')
|
||||
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
|
||||
location='json')
|
||||
args = parser.parse_args()
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
|
@ -14,7 +14,6 @@ from pydantic import root_validator
|
||||
from core.model_providers.models.entity.message import to_prompt_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
|
||||
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
@ -60,7 +59,6 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
return AgentFinish(return_values={"output": ''}, log='')
|
||||
elif len(self.tools) == 1:
|
||||
tool = next(iter(self.tools))
|
||||
tool = cast(DatasetRetrieverTool, tool)
|
||||
rst = tool.run(tool_input={'query': kwargs['input']})
|
||||
# output = ''
|
||||
# rst_json = json.loads(rst)
|
||||
|
158
api/core/agent/agent/output_parser/retirver_dataset_agent.py
Normal file
158
api/core/agent/agent/output_parser/retirver_dataset_agent.py
Normal file
@ -0,0 +1,158 @@
|
||||
import json
|
||||
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
|
||||
|
||||
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
|
||||
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.model_providers.models.entity.message import to_prompt_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
|
||||
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
"""
|
||||
An Multi Dataset Retrieve Agent driven by Router.
|
||||
"""
|
||||
model_instance: BaseLLM
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator
|
||||
def validate_llm(cls, values: dict) -> dict:
|
||||
return values
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
return should use agent
|
||||
|
||||
:param query:
|
||||
:return:
|
||||
"""
|
||||
return True
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
if len(self.tools) == 0:
|
||||
return AgentFinish(return_values={"output": ''}, log='')
|
||||
elif len(self.tools) == 1:
|
||||
tool = next(iter(self.tools))
|
||||
tool = cast(DatasetRetrieverTool, tool)
|
||||
rst = tool.run(tool_input={'query': kwargs['input']})
|
||||
# output = ''
|
||||
# rst_json = json.loads(rst)
|
||||
# for item in rst_json:
|
||||
# output += f'{item["content"]}\n'
|
||||
return AgentFinish(return_values={"output": rst}, log=rst)
|
||||
|
||||
if intermediate_steps:
|
||||
_, observation = intermediate_steps[-1]
|
||||
return AgentFinish(return_values={"output": observation}, log=observation)
|
||||
|
||||
try:
|
||||
agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
|
||||
if isinstance(agent_decision, AgentAction):
|
||||
tool_inputs = agent_decision.tool_input
|
||||
if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
|
||||
tool_inputs['query'] = kwargs['input']
|
||||
agent_decision.tool_input = tool_inputs
|
||||
else:
|
||||
agent_decision.return_values['output'] = ''
|
||||
return agent_decision
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
def real_plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||
selected_inputs = {
|
||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||
}
|
||||
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
||||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
messages = prompt.to_messages()
|
||||
prompt_messages = to_prompt_messages(messages)
|
||||
result = self.model_instance.run(
|
||||
messages=prompt_messages,
|
||||
functions=self.functions,
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content=result.content,
|
||||
additional_kwargs={
|
||||
'function_call': result.function_call
|
||||
}
|
||||
)
|
||||
|
||||
agent_decision = _parse_ai_message(ai_message)
|
||||
return agent_decision
|
||||
|
||||
async def aplan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
model_instance: BaseLLM,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
system_message: Optional[SystemMessage] = SystemMessage(
|
||||
content="You are a helpful AI assistant."
|
||||
),
|
||||
**kwargs: Any,
|
||||
) -> BaseSingleActionAgent:
|
||||
prompt = cls.create_prompt(
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=system_message,
|
||||
)
|
||||
return cls(
|
||||
model_instance=model_instance,
|
||||
llm=FakeLLM(response=''),
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
)
|
@ -89,7 +89,6 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
return AgentFinish(return_values={"output": ''}, log='')
|
||||
elif len(self.dataset_tools) == 1:
|
||||
tool = next(iter(self.dataset_tools))
|
||||
tool = cast(DatasetRetrieverTool, tool)
|
||||
rst = tool.run(tool_input={'query': kwargs['input']})
|
||||
return AgentFinish(return_values={"output": rst}, log=rst)
|
||||
|
||||
|
@ -18,6 +18,7 @@ from langchain.agents import AgentExecutor as LCAgentExecutor
|
||||
from core.helper import moderation
|
||||
from core.model_providers.error import LLMError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
|
||||
@ -78,7 +79,7 @@ class AgentExecutor:
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.ROUTER:
|
||||
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
|
||||
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
|
||||
agent = MultiDatasetRouterAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
tools=self.configuration.tools,
|
||||
@ -86,7 +87,7 @@ class AgentExecutor:
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
|
||||
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
|
||||
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
|
||||
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
tools=self.configuration.tools,
|
||||
|
@ -10,8 +10,7 @@ from models.dataset import DocumentSegment
|
||||
class DatasetIndexToolCallbackHandler:
|
||||
"""Callback handler for dataset tool."""
|
||||
|
||||
def __init__(self, dataset_id: str, conversation_message_task: ConversationMessageTask) -> None:
|
||||
self.dataset_id = dataset_id
|
||||
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
|
||||
self.conversation_message_task = conversation_message_task
|
||||
|
||||
def on_tool_end(self, documents: List[Document]) -> None:
|
||||
@ -21,7 +20,6 @@ class DatasetIndexToolCallbackHandler:
|
||||
|
||||
# add hit count to document segment
|
||||
db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == self.dataset_id,
|
||||
DocumentSegment.index_node_id == doc_id
|
||||
).update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||
|
@ -127,6 +127,7 @@ class Completion:
|
||||
memory=memory,
|
||||
rest_tokens=rest_tokens_for_context_and_memory,
|
||||
chain_callback=chain_callback,
|
||||
tenant_id=app.tenant_id,
|
||||
retriever_from=retriever_from
|
||||
)
|
||||
|
||||
|
@ -3,7 +3,7 @@ from pathlib import Path
|
||||
from typing import List, Union, Optional
|
||||
|
||||
import requests
|
||||
from langchain.document_loaders import TextLoader, Docx2txtLoader
|
||||
from langchain.document_loaders import TextLoader, Docx2txtLoader, UnstructuredFileLoader, UnstructuredAPIFileLoader
|
||||
from langchain.schema import Document
|
||||
|
||||
from core.data_loader.loader.csv_loader import CSVLoader
|
||||
@ -20,13 +20,13 @@ USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTM
|
||||
|
||||
class FileExtractor:
|
||||
@classmethod
|
||||
def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document] | str]:
|
||||
def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document] | str]:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(upload_file.key).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
|
||||
storage.download(upload_file.key, file_path)
|
||||
|
||||
return cls.load_from_file(file_path, return_text, upload_file)
|
||||
return cls.load_from_file(file_path, return_text, upload_file, is_automatic)
|
||||
|
||||
@classmethod
|
||||
def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document] | str]:
|
||||
@ -44,24 +44,34 @@ class FileExtractor:
|
||||
|
||||
@classmethod
|
||||
def load_from_file(cls, file_path: str, return_text: bool = False,
|
||||
upload_file: Optional[UploadFile] = None) -> Union[List[Document] | str]:
|
||||
upload_file: Optional[UploadFile] = None,
|
||||
is_automatic: bool = False) -> Union[List[Document] | str]:
|
||||
input_file = Path(file_path)
|
||||
delimiter = '\n'
|
||||
file_extension = input_file.suffix.lower()
|
||||
if file_extension == '.xlsx':
|
||||
loader = ExcelLoader(file_path)
|
||||
elif file_extension == '.pdf':
|
||||
loader = PdfLoader(file_path, upload_file=upload_file)
|
||||
elif file_extension in ['.md', '.markdown']:
|
||||
loader = MarkdownLoader(file_path, autodetect_encoding=True)
|
||||
elif file_extension in ['.htm', '.html']:
|
||||
loader = HTMLLoader(file_path)
|
||||
elif file_extension == '.docx':
|
||||
loader = Docx2txtLoader(file_path)
|
||||
elif file_extension == '.csv':
|
||||
loader = CSVLoader(file_path, autodetect_encoding=True)
|
||||
if is_automatic:
|
||||
loader = UnstructuredFileLoader(
|
||||
file_path, strategy="hi_res", mode="elements"
|
||||
)
|
||||
# loader = UnstructuredAPIFileLoader(
|
||||
# file_path=filenames[0],
|
||||
# api_key="FAKE_API_KEY",
|
||||
# )
|
||||
else:
|
||||
# txt
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
if file_extension == '.xlsx':
|
||||
loader = ExcelLoader(file_path)
|
||||
elif file_extension == '.pdf':
|
||||
loader = PdfLoader(file_path, upload_file=upload_file)
|
||||
elif file_extension in ['.md', '.markdown']:
|
||||
loader = MarkdownLoader(file_path, autodetect_encoding=True)
|
||||
elif file_extension in ['.htm', '.html']:
|
||||
loader = HTMLLoader(file_path)
|
||||
elif file_extension == '.docx':
|
||||
loader = Docx2txtLoader(file_path)
|
||||
elif file_extension == '.csv':
|
||||
loader = CSVLoader(file_path, autodetect_encoding=True)
|
||||
else:
|
||||
# txt
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
|
||||
return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
|
||||
|
@ -40,6 +40,13 @@ class BaseVectorIndex(BaseIndex):
|
||||
def _get_vector_store_class(self) -> type:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def search_by_full_text_index(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
def search(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
|
@ -1,16 +1,14 @@
|
||||
from typing import Optional, cast
|
||||
from typing import cast, Any, List
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document, BaseRetriever
|
||||
from langchain.vectorstores import VectorStore, milvus
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores import VectorStore
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from core.index.base import BaseIndex
|
||||
from core.index.vector_index.base import BaseVectorIndex
|
||||
from core.vector_store.milvus_vector_store import MilvusVectorStore
|
||||
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetCollectionBinding
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class MilvusConfig(BaseModel):
|
||||
@ -74,7 +72,7 @@ class MilvusVectorIndex(BaseVectorIndex):
|
||||
index_params = {
|
||||
'metric_type': 'IP',
|
||||
'index_type': "HNSW",
|
||||
'params': {"M": 8, "efConstruction": 64}
|
||||
'params': {"M": 8, "efConstruction": 64}
|
||||
}
|
||||
self._vector_store = MilvusVectorStore.from_documents(
|
||||
texts,
|
||||
@ -152,3 +150,7 @@ class MilvusVectorIndex(BaseVectorIndex):
|
||||
),
|
||||
],
|
||||
))
|
||||
|
||||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
|
||||
# milvus/zilliz doesn't support bm25 search
|
||||
return []
|
||||
|
@ -191,3 +191,21 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
from qdrant_client.http import models
|
||||
return vector_store.similarity_search_by_bm25(models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=self.dataset.id),
|
||||
),
|
||||
models.FieldCondition(
|
||||
key="page_content",
|
||||
match=models.MatchText(text=query),
|
||||
)
|
||||
],
|
||||
), kwargs.get('top_k', 2))
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Optional, cast
|
||||
from typing import Optional, cast, Any, List
|
||||
|
||||
import requests
|
||||
import weaviate
|
||||
@ -26,6 +26,7 @@ class WeaviateConfig(BaseModel):
|
||||
|
||||
|
||||
class WeaviateVectorIndex(BaseVectorIndex):
|
||||
|
||||
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
|
||||
super().__init__(dataset, embeddings)
|
||||
self._client = self._init_client(config)
|
||||
@ -148,3 +149,9 @@ class WeaviateVectorIndex(BaseVectorIndex):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)
|
||||
|
||||
|
@ -49,14 +49,14 @@ class IndexingRunner:
|
||||
if not dataset:
|
||||
raise ValueError("no dataset found")
|
||||
|
||||
# load file
|
||||
text_docs = self._load_data(dataset_document)
|
||||
|
||||
# get the process rule
|
||||
processing_rule = db.session.query(DatasetProcessRule). \
|
||||
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
|
||||
first()
|
||||
|
||||
# load file
|
||||
text_docs = self._load_data(dataset_document)
|
||||
|
||||
# get splitter
|
||||
splitter = self._get_splitter(processing_rule)
|
||||
|
||||
@ -380,7 +380,7 @@ class IndexingRunner:
|
||||
"preview": preview_texts
|
||||
}
|
||||
|
||||
def _load_data(self, dataset_document: DatasetDocument) -> List[Document]:
|
||||
def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> List[Document]:
|
||||
# load file
|
||||
if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
|
||||
return []
|
||||
@ -396,7 +396,7 @@ class IndexingRunner:
|
||||
one_or_none()
|
||||
|
||||
if file_detail:
|
||||
text_docs = FileExtractor.load(file_detail)
|
||||
text_docs = FileExtractor.load(file_detail, is_automatic=False)
|
||||
elif dataset_document.data_source_type == 'notion_import':
|
||||
loader = NotionLoader.from_document(dataset_document)
|
||||
text_docs = loader.load()
|
||||
|
@ -9,6 +9,7 @@ from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.moderation.base import BaseModeration
|
||||
from core.model_providers.models.reranking.base import BaseReranking
|
||||
from core.model_providers.models.speech2text.base import BaseSpeech2Text
|
||||
from extensions.ext_database import db
|
||||
from models.provider import TenantDefaultModel
|
||||
@ -140,6 +141,44 @@ class ModelFactory:
|
||||
name=model_name
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_reranking_model(cls,
|
||||
tenant_id: str,
|
||||
model_provider_name: Optional[str] = None,
|
||||
model_name: Optional[str] = None) -> Optional[BaseReranking]:
|
||||
"""
|
||||
get reranking model.
|
||||
|
||||
:param tenant_id: a string representing the ID of the tenant.
|
||||
:param model_provider_name:
|
||||
:param model_name:
|
||||
:return:
|
||||
"""
|
||||
if model_provider_name is None and model_name is None:
|
||||
default_model = cls.get_default_model(tenant_id, ModelType.RERANKING)
|
||||
|
||||
if not default_model:
|
||||
raise LLMBadRequestError(f"Default model is not available. "
|
||||
f"Please configure a Default Reranking Model "
|
||||
f"in the Settings -> Model Provider.")
|
||||
|
||||
model_provider_name = default_model.provider_name
|
||||
model_name = default_model.model_name
|
||||
|
||||
# get model provider
|
||||
model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
|
||||
|
||||
if not model_provider:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
|
||||
|
||||
# init reranking model
|
||||
model_class = model_provider.get_model_class(model_type=ModelType.RERANKING)
|
||||
return model_class(
|
||||
model_provider=model_provider,
|
||||
name=model_name
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_speech2text_model(cls,
|
||||
tenant_id: str,
|
||||
|
@ -72,6 +72,9 @@ class ModelProviderFactory:
|
||||
elif provider_name == 'localai':
|
||||
from core.model_providers.providers.localai_provider import LocalAIProvider
|
||||
return LocalAIProvider
|
||||
elif provider_name == 'cohere':
|
||||
from core.model_providers.providers.cohere_provider import CohereProvider
|
||||
return CohereProvider
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -17,7 +17,7 @@ class ModelType(enum.Enum):
|
||||
IMAGE = 'image'
|
||||
VIDEO = 'video'
|
||||
MODERATION = 'moderation'
|
||||
|
||||
RERANKING = 'reranking'
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
for member in ModelType:
|
||||
|
36
api/core/model_providers/models/reranking/base.py
Normal file
36
api/core/model_providers/models/reranking/base.py
Normal file
@ -0,0 +1,36 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Optional, List
|
||||
from langchain.schema import Document
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseReranking(BaseProviderModel):
|
||||
name: str
|
||||
type: ModelType = ModelType.RERANKING
|
||||
|
||||
def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
|
||||
super().__init__(model_provider, client)
|
||||
self.name = name
|
||||
|
||||
@property
|
||||
def base_model_name(self) -> str:
|
||||
"""
|
||||
get base model name
|
||||
|
||||
:return: str
|
||||
"""
|
||||
return self.name
|
||||
|
||||
@abstractmethod
|
||||
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
raise NotImplementedError
|
@ -0,0 +1,73 @@
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
|
||||
import cohere
|
||||
import openai
|
||||
from langchain.schema import Document
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
|
||||
LLMRateLimitError, LLMAuthorizationError
|
||||
from core.model_providers.models.reranking.base import BaseReranking
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
|
||||
|
||||
class CohereReranking(BaseReranking):
|
||||
|
||||
def __init__(self, model_provider: BaseModelProvider, name: str):
|
||||
self.credentials = model_provider.get_model_credentials(
|
||||
model_name=name,
|
||||
model_type=self.type
|
||||
)
|
||||
|
||||
client = cohere.Client(self.credentials.get('api_key'))
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
|
||||
docs = []
|
||||
doc_id = []
|
||||
for document in documents:
|
||||
if document.metadata['doc_id'] not in doc_id:
|
||||
doc_id.append(document.metadata['doc_id'])
|
||||
docs.append(document.page_content)
|
||||
results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k)
|
||||
rerank_documents = []
|
||||
|
||||
for idx, result in enumerate(results):
|
||||
# format document
|
||||
rerank_document = Document(
|
||||
page_content=result.document['text'],
|
||||
metadata={
|
||||
"doc_id": documents[result.index].metadata['doc_id'],
|
||||
"doc_hash": documents[result.index].metadata['doc_hash'],
|
||||
"document_id": documents[result.index].metadata['document_id'],
|
||||
"dataset_id": documents[result.index].metadata['dataset_id'],
|
||||
'score': result.relevance_score
|
||||
}
|
||||
)
|
||||
# score threshold check
|
||||
if score_threshold is not None:
|
||||
if result.relevance_score >= score_threshold:
|
||||
rerank_documents.append(rerank_document)
|
||||
else:
|
||||
rerank_documents.append(rerank_document)
|
||||
return rerank_documents
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, openai.error.InvalidRequestError):
|
||||
logging.warning("Invalid request to OpenAI API.")
|
||||
return LLMBadRequestError(str(ex))
|
||||
elif isinstance(ex, openai.error.APIConnectionError):
|
||||
logging.warning("Failed to connect to OpenAI API.")
|
||||
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
|
||||
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
|
||||
logging.warning("OpenAI service unavailable.")
|
||||
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
|
||||
elif isinstance(ex, openai.error.RateLimitError):
|
||||
return LLMRateLimitError(str(ex))
|
||||
elif isinstance(ex, openai.error.AuthenticationError):
|
||||
return LLMAuthorizationError(str(ex))
|
||||
elif isinstance(ex, openai.error.OpenAIError):
|
||||
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
|
||||
else:
|
||||
return ex
|
152
api/core/model_providers/providers/cohere_provider.py
Normal file
152
api/core/model_providers/providers/cohere_provider.py
Normal file
@ -0,0 +1,152 @@
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Type
|
||||
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
|
||||
from core.model_providers.models.reranking.cohere_reranking import CohereReranking
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class CohereProvider(BaseModelProvider):
|
||||
|
||||
@property
|
||||
def provider_name(self):
|
||||
"""
|
||||
Returns the name of a provider.
|
||||
"""
|
||||
return 'cohere'
|
||||
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
return ModelMode.CHAT.value
|
||||
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
if model_type == ModelType.RERANKING:
|
||||
return [
|
||||
{
|
||||
'id': 'rerank-english-v2.0',
|
||||
'name': 'rerank-english-v2.0'
|
||||
},
|
||||
{
|
||||
'id': 'rerank-multilingual-v2.0',
|
||||
'name': 'rerank-multilingual-v2.0'
|
||||
}
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
if model_type == ModelType.RERANKING:
|
||||
model_class = CohereReranking
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return model_class
|
||||
|
||||
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
|
||||
"""
|
||||
get model parameter rules.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=1, default=0.3, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=0.99, default=0.85, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](enabled=False),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
|
||||
"""
|
||||
Validates the given credentials.
|
||||
"""
|
||||
if 'api_key' not in credentials:
|
||||
raise CredentialsValidateFailedError('Cohere api_key must be provided.')
|
||||
|
||||
try:
|
||||
credential_kwargs = {
|
||||
'api_key': credentials['api_key'],
|
||||
}
|
||||
# todo validate
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@classmethod
|
||||
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
|
||||
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
|
||||
return credentials
|
||||
|
||||
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
|
||||
if self.provider.provider_type == ProviderType.CUSTOM.value:
|
||||
try:
|
||||
credentials = json.loads(self.provider.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
credentials = {
|
||||
'api_key': None,
|
||||
}
|
||||
|
||||
if credentials['api_key']:
|
||||
credentials['api_key'] = encrypter.decrypt_token(
|
||||
self.provider.tenant_id,
|
||||
credentials['api_key']
|
||||
)
|
||||
|
||||
if obfuscated:
|
||||
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
|
||||
|
||||
return credentials
|
||||
else:
|
||||
return {}
|
||||
|
||||
def should_deduct_quota(self):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||
"""
|
||||
check model credentials valid.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
"""
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
|
||||
credentials: dict) -> dict:
|
||||
"""
|
||||
encrypt model credentials for save.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
|
||||
"""
|
||||
get credentials for llm use.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param obfuscated:
|
||||
:return:
|
||||
"""
|
||||
return self.get_provider_credentials(obfuscated)
|
@ -13,5 +13,6 @@
|
||||
"huggingface_hub",
|
||||
"xinference",
|
||||
"openllm",
|
||||
"localai"
|
||||
"localai",
|
||||
"cohere"
|
||||
]
|
||||
|
7
api/core/model_providers/rules/cohere.json
Normal file
7
api/core/model_providers/rules/cohere.json
Normal file
@ -0,0 +1,7 @@
|
||||
{
|
||||
"support_provider_types": [
|
||||
"custom"
|
||||
],
|
||||
"system_config": null,
|
||||
"model_flexibility": "fixed"
|
||||
}
|
@ -1,11 +1,17 @@
|
||||
from typing import Optional
|
||||
import json
|
||||
import threading
|
||||
from typing import Optional, List
|
||||
|
||||
from flask import Flask
|
||||
from langchain import WikipediaAPIWrapper
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.tools import BaseTool, Tool, WikipediaQueryRun
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
|
||||
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
|
||||
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
|
||||
from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
|
||||
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
||||
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
|
||||
@ -17,6 +23,7 @@ from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.tool.current_datetime_tool import DatetimeTool
|
||||
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
|
||||
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
|
||||
@ -25,6 +32,16 @@ from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
from models.model import AppModelConfig
|
||||
|
||||
default_retrieval_model = {
|
||||
'search_method': 'semantic_search',
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
}
|
||||
|
||||
class OrchestratorRuleParser:
|
||||
"""Parse the orchestrator rule to entities."""
|
||||
@ -34,7 +51,7 @@ class OrchestratorRuleParser:
|
||||
self.app_model_config = app_model_config
|
||||
|
||||
def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
|
||||
rest_tokens: int, chain_callback: MainChainGatherCallbackHandler,
|
||||
rest_tokens: int, chain_callback: MainChainGatherCallbackHandler, tenant_id: str,
|
||||
retriever_from: str = 'dev') -> Optional[AgentExecutor]:
|
||||
if not self.app_model_config.agent_mode_dict:
|
||||
return None
|
||||
@ -101,7 +118,8 @@ class OrchestratorRuleParser:
|
||||
rest_tokens=rest_tokens,
|
||||
return_resource=return_resource,
|
||||
retriever_from=retriever_from,
|
||||
dataset_configs=dataset_configs
|
||||
dataset_configs=dataset_configs,
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
if len(tools) == 0:
|
||||
@ -123,7 +141,7 @@ class OrchestratorRuleParser:
|
||||
|
||||
return chain
|
||||
|
||||
def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
|
||||
def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
|
||||
"""
|
||||
Convert app agent tool configs to tools
|
||||
|
||||
@ -132,6 +150,7 @@ class OrchestratorRuleParser:
|
||||
:return:
|
||||
"""
|
||||
tools = []
|
||||
dataset_tools = []
|
||||
for tool_config in tool_configs:
|
||||
tool_type = list(tool_config.keys())[0]
|
||||
tool_val = list(tool_config.values())[0]
|
||||
@ -140,7 +159,7 @@ class OrchestratorRuleParser:
|
||||
|
||||
tool = None
|
||||
if tool_type == "dataset":
|
||||
tool = self.to_dataset_retriever_tool(tool_config=tool_val, **kwargs)
|
||||
dataset_tools.append(tool_config)
|
||||
elif tool_type == "web_reader":
|
||||
tool = self.to_web_reader_tool(tool_config=tool_val, **kwargs)
|
||||
elif tool_type == "google_search":
|
||||
@ -156,57 +175,81 @@ class OrchestratorRuleParser:
|
||||
else:
|
||||
tool.callbacks = callbacks
|
||||
tools.append(tool)
|
||||
|
||||
# format dataset tool
|
||||
if len(dataset_tools) > 0:
|
||||
dataset_retriever_tools = self.to_dataset_retriever_tool(tool_configs=dataset_tools, **kwargs)
|
||||
if dataset_retriever_tools:
|
||||
tools.extend(dataset_retriever_tools)
|
||||
return tools
|
||||
|
||||
def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
|
||||
dataset_configs: dict, rest_tokens: int,
|
||||
def to_dataset_retriever_tool(self, tool_configs: List, conversation_message_task: ConversationMessageTask,
|
||||
return_resource: bool = False, retriever_from: str = 'dev',
|
||||
**kwargs) \
|
||||
-> Optional[BaseTool]:
|
||||
-> Optional[List[BaseTool]]:
|
||||
"""
|
||||
A dataset tool is a tool that can be used to retrieve information from a dataset
|
||||
:param rest_tokens:
|
||||
:param tool_config:
|
||||
:param dataset_configs:
|
||||
:param tool_configs:
|
||||
:param conversation_message_task:
|
||||
:param return_resource:
|
||||
:param retriever_from:
|
||||
:return:
|
||||
"""
|
||||
# get dataset from dataset id
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == self.tenant_id,
|
||||
Dataset.id == tool_config.get("id")
|
||||
).first()
|
||||
dataset_configs = kwargs['dataset_configs']
|
||||
retrieval_model = dataset_configs.get('retrieval_model', 'single')
|
||||
tools = []
|
||||
dataset_ids = []
|
||||
tenant_id = None
|
||||
for tool_config in tool_configs:
|
||||
# get dataset from dataset id
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == self.tenant_id,
|
||||
Dataset.id == tool_config.get('dataset').get("id")
|
||||
).first()
|
||||
|
||||
if not dataset:
|
||||
return None
|
||||
if not dataset:
|
||||
return None
|
||||
|
||||
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
|
||||
return None
|
||||
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
|
||||
return None
|
||||
dataset_ids.append(dataset.id)
|
||||
if retrieval_model == 'single':
|
||||
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||
top_k = retrieval_model['top_k']
|
||||
|
||||
top_k = dataset_configs.get("top_k", 2)
|
||||
# dynamically adjust top_k when the remaining token number is not enough to support top_k
|
||||
# top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
|
||||
|
||||
# dynamically adjust top_k when the remaining token number is not enough to support top_k
|
||||
top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
|
||||
score_threshold = None
|
||||
score_threshold_enable = retrieval_model.get("score_threshold_enable")
|
||||
if score_threshold_enable:
|
||||
score_threshold = retrieval_model.get("score_threshold")
|
||||
|
||||
score_threshold = None
|
||||
score_threshold_config = dataset_configs.get("score_threshold")
|
||||
if score_threshold_config and score_threshold_config.get("enable"):
|
||||
score_threshold = score_threshold_config.get("value")
|
||||
tool = DatasetRetrieverTool.from_dataset(
|
||||
dataset=dataset,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
|
||||
conversation_message_task=conversation_message_task,
|
||||
return_resource=return_resource,
|
||||
retriever_from=retriever_from
|
||||
)
|
||||
tools.append(tool)
|
||||
if retrieval_model == 'multiple':
|
||||
tool = DatasetMultiRetrieverTool.from_dataset(
|
||||
dataset_ids=dataset_ids,
|
||||
tenant_id=kwargs['tenant_id'],
|
||||
top_k=dataset_configs.get('top_k', 2),
|
||||
score_threshold=dataset_configs.get('score_threshold', 0.5) if dataset_configs.get('score_threshold_enable', False) else None,
|
||||
callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
|
||||
conversation_message_task=conversation_message_task,
|
||||
return_resource=return_resource,
|
||||
retriever_from=retriever_from,
|
||||
reranking_provider_name=dataset_configs.get('reranking_model').get('reranking_provider_name'),
|
||||
reranking_model_name=dataset_configs.get('reranking_model').get('reranking_model_name')
|
||||
)
|
||||
tools.append(tool)
|
||||
|
||||
tool = DatasetRetrieverTool.from_dataset(
|
||||
dataset=dataset,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
|
||||
conversation_message_task=conversation_message_task,
|
||||
return_resource=return_resource,
|
||||
retriever_from=retriever_from
|
||||
)
|
||||
|
||||
return tool
|
||||
return tools
|
||||
|
||||
def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]:
|
||||
"""
|
||||
|
227
api/core/tool/dataset_multi_retriever_tool.py
Normal file
227
api/core/tool/dataset_multi_retriever_tool.py
Normal file
@ -0,0 +1,227 @@
|
||||
import json
|
||||
import threading
|
||||
from typing import Type, Optional, List
|
||||
|
||||
from flask import current_app, Flask
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import Field, BaseModel
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
|
||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DocumentSegment, Document
|
||||
from services.retrieval_service import RetrievalService
|
||||
|
||||
default_retrieval_model = {
|
||||
'search_method': 'semantic_search',
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
}
|
||||
|
||||
|
||||
class DatasetMultiRetrieverToolInput(BaseModel):
|
||||
query: str = Field(..., description="dataset multi retriever and rerank")
|
||||
|
||||
|
||||
class DatasetMultiRetrieverTool(BaseTool):
|
||||
"""Tool for querying multi dataset."""
|
||||
name: str = "dataset-"
|
||||
args_schema: Type[BaseModel] = DatasetMultiRetrieverToolInput
|
||||
description: str = "dataset multi retriever and rerank. "
|
||||
tenant_id: str
|
||||
dataset_ids: List[str]
|
||||
top_k: int = 2
|
||||
score_threshold: Optional[float] = None
|
||||
reranking_provider_name: str
|
||||
reranking_model_name: str
|
||||
conversation_message_task: ConversationMessageTask
|
||||
return_resource: bool
|
||||
retriever_from: str
|
||||
|
||||
@classmethod
|
||||
def from_dataset(cls, dataset_ids: List[str], tenant_id: str, **kwargs):
|
||||
return cls(
|
||||
name=f'dataset-{tenant_id}',
|
||||
tenant_id=tenant_id,
|
||||
dataset_ids=dataset_ids,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
threads = []
|
||||
all_documents = []
|
||||
for dataset_id in self.dataset_ids:
|
||||
retrieval_thread = threading.Thread(target=self._retriever, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': dataset_id,
|
||||
'query': query,
|
||||
'all_documents': all_documents
|
||||
})
|
||||
threads.append(retrieval_thread)
|
||||
retrieval_thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
# do rerank for searched documents
|
||||
rerank = ModelFactory.get_reranking_model(
|
||||
tenant_id=self.tenant_id,
|
||||
model_provider_name=self.reranking_provider_name,
|
||||
model_name=self.reranking_model_name
|
||||
)
|
||||
all_documents = rerank.rerank(query, all_documents, self.score_threshold, self.top_k)
|
||||
|
||||
hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task)
|
||||
hit_callback.on_tool_end(all_documents)
|
||||
|
||||
document_context_list = []
|
||||
index_node_ids = [document.metadata['doc_id'] for document in all_documents]
|
||||
segments = DocumentSegment.query.filter(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.index_node_id.in_(index_node_ids)
|
||||
).all()
|
||||
|
||||
if segments:
|
||||
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
|
||||
sorted_segments = sorted(segments,
|
||||
key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
|
||||
float('inf')))
|
||||
for segment in sorted_segments:
|
||||
if segment.answer:
|
||||
document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
|
||||
else:
|
||||
document_context_list.append(segment.content)
|
||||
if self.return_resource:
|
||||
context_list = []
|
||||
resource_number = 1
|
||||
for segment in sorted_segments:
|
||||
dataset = Dataset.query.filter_by(
|
||||
id=segment.dataset_id
|
||||
).first()
|
||||
document = Document.query.filter(Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
).first()
|
||||
if dataset and document:
|
||||
source = {
|
||||
'position': resource_number,
|
||||
'dataset_id': dataset.id,
|
||||
'dataset_name': dataset.name,
|
||||
'document_id': document.id,
|
||||
'document_name': document.name,
|
||||
'data_source_type': document.data_source_type,
|
||||
'segment_id': segment.id,
|
||||
'retriever_from': self.retriever_from
|
||||
}
|
||||
if self.retriever_from == 'dev':
|
||||
source['hit_count'] = segment.hit_count
|
||||
source['word_count'] = segment.word_count
|
||||
source['segment_position'] = segment.position
|
||||
source['index_node_hash'] = segment.index_node_hash
|
||||
if segment.answer:
|
||||
source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
|
||||
else:
|
||||
source['content'] = segment.content
|
||||
context_list.append(source)
|
||||
resource_number += 1
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
||||
return str("\n".join(document_context_list))
|
||||
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: List):
|
||||
with flask_app.app_context():
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == self.tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
|
||||
if not dataset:
|
||||
return []
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
kw_table_index = KeywordTableIndex(
|
||||
dataset=dataset,
|
||||
config=KeywordTableConfig(
|
||||
max_keywords_per_chunk=5
|
||||
)
|
||||
)
|
||||
|
||||
documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
else:
|
||||
|
||||
try:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
return []
|
||||
except ProviderTokenNotInitError:
|
||||
return []
|
||||
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
documents = []
|
||||
threads = []
|
||||
if self.top_k > 0:
|
||||
# retrieval_model source with semantic
|
||||
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model[
|
||||
'search_method'] == 'hybrid_search':
|
||||
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'query': query,
|
||||
'top_k': self.top_k,
|
||||
'score_threshold': self.score_threshold,
|
||||
'reranking_model': None,
|
||||
'all_documents': documents,
|
||||
'search_method': 'hybrid_search',
|
||||
'embeddings': embeddings
|
||||
})
|
||||
threads.append(embedding_thread)
|
||||
embedding_thread.start()
|
||||
|
||||
# retrieval_model source with full text
|
||||
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model[
|
||||
'search_method'] == 'hybrid_search':
|
||||
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search,
|
||||
kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'query': query,
|
||||
'search_method': 'hybrid_search',
|
||||
'embeddings': embeddings,
|
||||
'score_threshold': retrieval_model[
|
||||
'score_threshold'] if retrieval_model[
|
||||
'score_threshold_enable'] else None,
|
||||
'top_k': self.top_k,
|
||||
'reranking_model': retrieval_model[
|
||||
'reranking_model'] if retrieval_model[
|
||||
'reranking_enable'] else None,
|
||||
'all_documents': documents
|
||||
})
|
||||
threads.append(full_text_index_thread)
|
||||
full_text_index_thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
all_documents.extend(documents)
|
@ -1,5 +1,6 @@
|
||||
import json
|
||||
from typing import Type, Optional
|
||||
import threading
|
||||
from typing import Type, Optional, List
|
||||
|
||||
from flask import current_app
|
||||
from langchain.tools import BaseTool
|
||||
@ -14,6 +15,18 @@ from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitE
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DocumentSegment, Document
|
||||
from services.retrieval_service import RetrievalService
|
||||
|
||||
default_retrieval_model = {
|
||||
'search_method': 'semantic_search',
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
}
|
||||
|
||||
|
||||
class DatasetRetrieverToolInput(BaseModel):
|
||||
@ -56,7 +69,9 @@ class DatasetRetrieverTool(BaseTool):
|
||||
).first()
|
||||
|
||||
if not dataset:
|
||||
return f'[{self.name} failed to find dataset with id {self.dataset_id}.]'
|
||||
return ''
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
@ -83,28 +98,62 @@ class DatasetRetrieverTool(BaseTool):
|
||||
return ''
|
||||
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
vector_index = VectorIndex(
|
||||
dataset=dataset,
|
||||
config=current_app.config,
|
||||
embeddings=embeddings
|
||||
)
|
||||
|
||||
documents = []
|
||||
threads = []
|
||||
if self.top_k > 0:
|
||||
documents = vector_index.search(
|
||||
query,
|
||||
search_type='similarity_score_threshold',
|
||||
search_kwargs={
|
||||
'k': self.top_k,
|
||||
'score_threshold': self.score_threshold,
|
||||
'filter': {
|
||||
'group_id': [dataset.id]
|
||||
}
|
||||
}
|
||||
)
|
||||
# retrieval source with semantic
|
||||
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
|
||||
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'query': query,
|
||||
'top_k': self.top_k,
|
||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
|
||||
'score_threshold_enable'] else None,
|
||||
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
|
||||
'reranking_enable'] else None,
|
||||
'all_documents': documents,
|
||||
'search_method': retrieval_model['search_method'],
|
||||
'embeddings': embeddings
|
||||
})
|
||||
threads.append(embedding_thread)
|
||||
embedding_thread.start()
|
||||
|
||||
# retrieval_model source with full text
|
||||
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
|
||||
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'query': query,
|
||||
'search_method': retrieval_model['search_method'],
|
||||
'embeddings': embeddings,
|
||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
|
||||
'score_threshold_enable'] else None,
|
||||
'top_k': self.top_k,
|
||||
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
|
||||
'reranking_enable'] else None,
|
||||
'all_documents': documents
|
||||
})
|
||||
threads.append(full_text_index_thread)
|
||||
full_text_index_thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
# hybrid search: rerank after all documents have been searched
|
||||
if retrieval_model['search_method'] == 'hybrid_search':
|
||||
hybrid_rerank = ModelFactory.get_reranking_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
|
||||
model_name=retrieval_model['reranking_model']['reranking_model_name']
|
||||
)
|
||||
documents = hybrid_rerank.rerank(query, documents,
|
||||
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
|
||||
self.top_k)
|
||||
else:
|
||||
documents = []
|
||||
|
||||
hit_callback = DatasetIndexToolCallbackHandler(dataset.id, self.conversation_message_task)
|
||||
hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task)
|
||||
hit_callback.on_tool_end(documents)
|
||||
document_score_list = {}
|
||||
if dataset.indexing_technique != "economy":
|
||||
|
@ -1,4 +1,4 @@
|
||||
from core.index.vector_index.milvus import Milvus
|
||||
from core.vector_store.vector.milvus import Milvus
|
||||
|
||||
|
||||
class MilvusVectorStore(Milvus):
|
||||
|
@ -4,7 +4,7 @@ from langchain.schema import Document
|
||||
from qdrant_client.http.models import Filter, PointIdsList, FilterSelector
|
||||
from qdrant_client.local.qdrant_local import QdrantLocal
|
||||
|
||||
from core.index.vector_index.qdrant import Qdrant
|
||||
from core.vector_store.vector.qdrant import Qdrant
|
||||
|
||||
|
||||
class QdrantVectorStore(Qdrant):
|
||||
@ -73,3 +73,4 @@ class QdrantVectorStore(Qdrant):
|
||||
if isinstance(self.client, QdrantLocal):
|
||||
self.client = cast(QdrantLocal, self.client)
|
||||
self.client._load()
|
||||
|
||||
|
@ -28,7 +28,7 @@ from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores import VectorStore
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
from qdrant_client.http.models import PayloadSchemaType
|
||||
from qdrant_client.http.models import PayloadSchemaType, FilterSelector, TextIndexParams, TokenizerType, TextIndexType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client import grpc # noqa
|
||||
@ -189,14 +189,25 @@ class Qdrant(VectorStore):
|
||||
texts, metadatas, ids, batch_size
|
||||
):
|
||||
self.client.upsert(
|
||||
collection_name=self.collection_name, points=points, **kwargs
|
||||
collection_name=self.collection_name, points=points
|
||||
)
|
||||
added_ids.extend(batch_ids)
|
||||
# if is new collection, create payload index on group_id
|
||||
if self.is_new_collection:
|
||||
# create payload index
|
||||
self.client.create_payload_index(self.collection_name, self.group_payload_key,
|
||||
field_schema=PayloadSchemaType.KEYWORD,
|
||||
field_type=PayloadSchemaType.KEYWORD)
|
||||
# creat full text index
|
||||
text_index_params = TextIndexParams(
|
||||
type=TextIndexType.TEXT,
|
||||
tokenizer=TokenizerType.MULTILINGUAL,
|
||||
min_token_len=2,
|
||||
max_token_len=20,
|
||||
lowercase=True
|
||||
)
|
||||
self.client.create_payload_index(self.collection_name, self.content_payload_key,
|
||||
field_schema=text_index_params)
|
||||
return added_ids
|
||||
|
||||
@sync_call_fallback
|
||||
@ -600,7 +611,7 @@ class Qdrant(VectorStore):
|
||||
limit=k,
|
||||
offset=offset,
|
||||
with_payload=True,
|
||||
with_vectors=True, # Langchain does not expect vectors to be returned
|
||||
with_vectors=True,
|
||||
score_threshold=score_threshold,
|
||||
consistency=consistency,
|
||||
**kwargs,
|
||||
@ -615,6 +626,39 @@ class Qdrant(VectorStore):
|
||||
for result in results
|
||||
]
|
||||
|
||||
def similarity_search_by_bm25(
|
||||
self,
|
||||
filter: Optional[MetadataFilter] = None,
|
||||
k: int = 4
|
||||
) -> List[Document]:
|
||||
"""Return docs most similar by bm25.
|
||||
|
||||
Args:
|
||||
embedding: Embedding vector to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter: Filter by metadata. Defaults to None.
|
||||
search_params: Additional search params
|
||||
Returns:
|
||||
List of documents most similar to the query text and distance for each.
|
||||
"""
|
||||
response = self.client.scroll(
|
||||
collection_name=self.collection_name,
|
||||
scroll_filter=filter,
|
||||
limit=k,
|
||||
with_payload=True,
|
||||
with_vectors=True
|
||||
|
||||
)
|
||||
results = response[0]
|
||||
documents = []
|
||||
for result in results:
|
||||
if result:
|
||||
documents.append(self._document_from_scored_point(
|
||||
result, self.content_payload_key, self.metadata_payload_key
|
||||
))
|
||||
|
||||
return documents
|
||||
|
||||
@sync_call_fallback
|
||||
async def asimilarity_search_with_score_by_vector(
|
||||
self,
|
505
api/core/vector_store/vector/weaviate.py
Normal file
505
api/core/vector_store/vector/weaviate.py
Normal file
@ -0,0 +1,505 @@
|
||||
"""Wrapper around weaviate vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
|
||||
def _default_schema(index_name: str) -> Dict:
|
||||
return {
|
||||
"class": index_name,
|
||||
"properties": [
|
||||
{
|
||||
"name": "text",
|
||||
"dataType": ["text"],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _create_weaviate_client(**kwargs: Any) -> Any:
|
||||
client = kwargs.get("client")
|
||||
if client is not None:
|
||||
return client
|
||||
|
||||
weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL")
|
||||
|
||||
try:
|
||||
# the weaviate api key param should not be mandatory
|
||||
weaviate_api_key = get_from_dict_or_env(
|
||||
kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None
|
||||
)
|
||||
except ValueError:
|
||||
weaviate_api_key = None
|
||||
|
||||
try:
|
||||
import weaviate
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import weaviate python package. "
|
||||
"Please install it with `pip install weaviate-client`"
|
||||
)
|
||||
|
||||
auth = (
|
||||
weaviate.auth.AuthApiKey(api_key=weaviate_api_key)
|
||||
if weaviate_api_key is not None
|
||||
else None
|
||||
)
|
||||
client = weaviate.Client(weaviate_url, auth_client_secret=auth)
|
||||
|
||||
return client
|
||||
|
||||
|
||||
def _default_score_normalizer(val: float) -> float:
|
||||
return 1 - 1 / (1 + np.exp(val))
|
||||
|
||||
|
||||
def _json_serializable(value: Any) -> Any:
|
||||
if isinstance(value, datetime.datetime):
|
||||
return value.isoformat()
|
||||
return value
|
||||
|
||||
|
||||
class Weaviate(VectorStore):
|
||||
"""Wrapper around Weaviate vector database.
|
||||
|
||||
To use, you should have the ``weaviate-client`` python package installed.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
import weaviate
|
||||
from langchain.vectorstores import Weaviate
|
||||
client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...)
|
||||
weaviate = Weaviate(client, index_name, text_key)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: Any,
|
||||
index_name: str,
|
||||
text_key: str,
|
||||
embedding: Optional[Embeddings] = None,
|
||||
attributes: Optional[List[str]] = None,
|
||||
relevance_score_fn: Optional[
|
||||
Callable[[float], float]
|
||||
] = _default_score_normalizer,
|
||||
by_text: bool = True,
|
||||
):
|
||||
"""Initialize with Weaviate client."""
|
||||
try:
|
||||
import weaviate
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import weaviate python package. "
|
||||
"Please install it with `pip install weaviate-client`."
|
||||
)
|
||||
if not isinstance(client, weaviate.Client):
|
||||
raise ValueError(
|
||||
f"client should be an instance of weaviate.Client, got {type(client)}"
|
||||
)
|
||||
self._client = client
|
||||
self._index_name = index_name
|
||||
self._embedding = embedding
|
||||
self._text_key = text_key
|
||||
self._query_attrs = [self._text_key]
|
||||
self.relevance_score_fn = relevance_score_fn
|
||||
self._by_text = by_text
|
||||
if attributes is not None:
|
||||
self._query_attrs.extend(attributes)
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
return self._embedding
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
return (
|
||||
self.relevance_score_fn
|
||||
if self.relevance_score_fn
|
||||
else _default_score_normalizer
|
||||
)
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Upload texts with metadata (properties) to Weaviate."""
|
||||
from weaviate.util import get_valid_uuid
|
||||
|
||||
ids = []
|
||||
embeddings: Optional[List[List[float]]] = None
|
||||
if self._embedding:
|
||||
if not isinstance(texts, list):
|
||||
texts = list(texts)
|
||||
embeddings = self._embedding.embed_documents(texts)
|
||||
|
||||
with self._client.batch as batch:
|
||||
for i, text in enumerate(texts):
|
||||
data_properties = {self._text_key: text}
|
||||
if metadatas is not None:
|
||||
for key, val in metadatas[i].items():
|
||||
data_properties[key] = _json_serializable(val)
|
||||
|
||||
# Allow for ids (consistent w/ other methods)
|
||||
# # Or uuids (backwards compatble w/ existing arg)
|
||||
# If the UUID of one of the objects already exists
|
||||
# then the existing object will be replaced by the new object.
|
||||
_id = get_valid_uuid(uuid4())
|
||||
if "uuids" in kwargs:
|
||||
_id = kwargs["uuids"][i]
|
||||
elif "ids" in kwargs:
|
||||
_id = kwargs["ids"][i]
|
||||
|
||||
batch.add_data_object(
|
||||
data_object=data_properties,
|
||||
class_name=self._index_name,
|
||||
uuid=_id,
|
||||
vector=embeddings[i] if embeddings else None,
|
||||
)
|
||||
ids.append(_id)
|
||||
return ids
|
||||
|
||||
def similarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Return docs most similar to query.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the query.
|
||||
"""
|
||||
if self._by_text:
|
||||
return self.similarity_search_by_text(query, k, **kwargs)
|
||||
else:
|
||||
if self._embedding is None:
|
||||
raise ValueError(
|
||||
"_embedding cannot be None for similarity_search when "
|
||||
"_by_text=False"
|
||||
)
|
||||
embedding = self._embedding.embed_query(query)
|
||||
return self.similarity_search_by_vector(embedding, k, **kwargs)
|
||||
|
||||
def similarity_search_by_text(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Return docs most similar to query.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the query.
|
||||
"""
|
||||
content: Dict[str, Any] = {"concepts": [query]}
|
||||
if kwargs.get("search_distance"):
|
||||
content["certainty"] = kwargs.get("search_distance")
|
||||
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
if kwargs.get("additional"):
|
||||
query_obj = query_obj.with_additional(kwargs.get("additional"))
|
||||
result = query_obj.with_near_text(content).with_limit(k).do()
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
docs = []
|
||||
for res in result["data"]["Get"][self._index_name]:
|
||||
text = res.pop(self._text_key)
|
||||
docs.append(Document(page_content=text, metadata=res))
|
||||
return docs
|
||||
|
||||
def similarity_search_by_bm25(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Return docs using BM25F.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the query.
|
||||
"""
|
||||
content: Dict[str, Any] = {"concepts": [query]}
|
||||
if kwargs.get("search_distance"):
|
||||
content["certainty"] = kwargs.get("search_distance")
|
||||
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
if kwargs.get("additional"):
|
||||
query_obj = query_obj.with_additional(kwargs.get("additional"))
|
||||
result = query_obj.with_bm25(query=content).with_limit(k).do()
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
docs = []
|
||||
for res in result["data"]["Get"][self._index_name]:
|
||||
text = res.pop(self._text_key)
|
||||
docs.append(Document(page_content=text, metadata=res))
|
||||
return docs
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self, embedding: List[float], k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Look up similar documents by embedding vector in Weaviate."""
|
||||
vector = {"vector": embedding}
|
||||
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
if kwargs.get("additional"):
|
||||
query_obj = query_obj.with_additional(kwargs.get("additional"))
|
||||
result = query_obj.with_near_vector(vector).with_limit(k).do()
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
docs = []
|
||||
for res in result["data"]["Get"][self._index_name]:
|
||||
text = res.pop(self._text_key)
|
||||
docs.append(Document(page_content=text, metadata=res))
|
||||
return docs
|
||||
|
||||
def max_marginal_relevance_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
if self._embedding is not None:
|
||||
embedding = self._embedding.embed_query(query)
|
||||
else:
|
||||
raise ValueError(
|
||||
"max_marginal_relevance_search requires a suitable Embeddings object"
|
||||
)
|
||||
|
||||
return self.max_marginal_relevance_search_by_vector(
|
||||
embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs
|
||||
)
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
Args:
|
||||
embedding: Embedding to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
vector = {"vector": embedding}
|
||||
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
results = (
|
||||
query_obj.with_additional("vector")
|
||||
.with_near_vector(vector)
|
||||
.with_limit(fetch_k)
|
||||
.do()
|
||||
)
|
||||
|
||||
payload = results["data"]["Get"][self._index_name]
|
||||
embeddings = [result["_additional"]["vector"] for result in payload]
|
||||
mmr_selected = maximal_marginal_relevance(
|
||||
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
|
||||
)
|
||||
|
||||
docs = []
|
||||
for idx in mmr_selected:
|
||||
text = payload[idx].pop(self._text_key)
|
||||
payload[idx].pop("_additional")
|
||||
meta = payload[idx]
|
||||
docs.append(Document(page_content=text, metadata=meta))
|
||||
return docs
|
||||
|
||||
def similarity_search_with_score(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""
|
||||
Return list of documents most similar to the query
|
||||
text and cosine distance in float for each.
|
||||
Lower score represents more similarity.
|
||||
"""
|
||||
if self._embedding is None:
|
||||
raise ValueError(
|
||||
"_embedding cannot be None for similarity_search_with_score"
|
||||
)
|
||||
content: Dict[str, Any] = {"concepts": [query]}
|
||||
if kwargs.get("search_distance"):
|
||||
content["certainty"] = kwargs.get("search_distance")
|
||||
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||
|
||||
embedded_query = self._embedding.embed_query(query)
|
||||
if not self._by_text:
|
||||
vector = {"vector": embedded_query}
|
||||
result = (
|
||||
query_obj.with_near_vector(vector)
|
||||
.with_limit(k)
|
||||
.with_additional("vector")
|
||||
.do()
|
||||
)
|
||||
else:
|
||||
result = (
|
||||
query_obj.with_near_text(content)
|
||||
.with_limit(k)
|
||||
.with_additional("vector")
|
||||
.do()
|
||||
)
|
||||
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
|
||||
docs_and_scores = []
|
||||
for res in result["data"]["Get"][self._index_name]:
|
||||
text = res.pop(self._text_key)
|
||||
score = np.dot(res["_additional"]["vector"], embedded_query)
|
||||
docs_and_scores.append((Document(page_content=text, metadata=res), score))
|
||||
return docs_and_scores
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls: Type[Weaviate],
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Weaviate:
|
||||
"""Construct Weaviate wrapper from raw documents.
|
||||
|
||||
This is a user-friendly interface that:
|
||||
1. Embeds documents.
|
||||
2. Creates a new index for the embeddings in the Weaviate instance.
|
||||
3. Adds the documents to the newly created Weaviate index.
|
||||
|
||||
This is intended to be a quick way to get started.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.vectorstores.weaviate import Weaviate
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
embeddings = OpenAIEmbeddings()
|
||||
weaviate = Weaviate.from_texts(
|
||||
texts,
|
||||
embeddings,
|
||||
weaviate_url="http://localhost:8080"
|
||||
)
|
||||
"""
|
||||
|
||||
client = _create_weaviate_client(**kwargs)
|
||||
|
||||
from weaviate.util import get_valid_uuid
|
||||
|
||||
index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}")
|
||||
embeddings = embedding.embed_documents(texts) if embedding else None
|
||||
text_key = "text"
|
||||
schema = _default_schema(index_name)
|
||||
attributes = list(metadatas[0].keys()) if metadatas else None
|
||||
|
||||
# check whether the index already exists
|
||||
if not client.schema.contains(schema):
|
||||
client.schema.create_class(schema)
|
||||
|
||||
with client.batch as batch:
|
||||
for i, text in enumerate(texts):
|
||||
data_properties = {
|
||||
text_key: text,
|
||||
}
|
||||
if metadatas is not None:
|
||||
for key in metadatas[i].keys():
|
||||
data_properties[key] = metadatas[i][key]
|
||||
|
||||
# If the UUID of one of the objects already exists
|
||||
# then the existing objectwill be replaced by the new object.
|
||||
if "uuids" in kwargs:
|
||||
_id = kwargs["uuids"][i]
|
||||
else:
|
||||
_id = get_valid_uuid(uuid4())
|
||||
|
||||
# if an embedding strategy is not provided, we let
|
||||
# weaviate create the embedding. Note that this will only
|
||||
# work if weaviate has been installed with a vectorizer module
|
||||
# like text2vec-contextionary for example
|
||||
params = {
|
||||
"uuid": _id,
|
||||
"data_object": data_properties,
|
||||
"class_name": index_name,
|
||||
}
|
||||
if embeddings is not None:
|
||||
params["vector"] = embeddings[i]
|
||||
|
||||
batch.add_data_object(**params)
|
||||
|
||||
batch.flush()
|
||||
|
||||
relevance_score_fn = kwargs.get("relevance_score_fn")
|
||||
by_text: bool = kwargs.get("by_text", False)
|
||||
|
||||
return cls(
|
||||
client,
|
||||
index_name,
|
||||
text_key,
|
||||
embedding=embedding,
|
||||
attributes=attributes,
|
||||
relevance_score_fn=relevance_score_fn,
|
||||
by_text=by_text,
|
||||
)
|
||||
|
||||
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
|
||||
"""Delete by vector IDs.
|
||||
|
||||
Args:
|
||||
ids: List of ids to delete.
|
||||
"""
|
||||
|
||||
if ids is None:
|
||||
raise ValueError("No ids provided to delete.")
|
||||
|
||||
# TODO: Check if this can be done in bulk
|
||||
for id in ids:
|
||||
self._client.data_object.delete(uuid=id)
|
@ -12,6 +12,21 @@ dataset_fields = {
|
||||
'created_at': TimestampField,
|
||||
}
|
||||
|
||||
reranking_model_fields = {
|
||||
'reranking_provider_name': fields.String,
|
||||
'reranking_model_name': fields.String
|
||||
}
|
||||
|
||||
dataset_retrieval_model_fields = {
|
||||
'search_method': fields.String,
|
||||
'reranking_enable': fields.Boolean,
|
||||
'reranking_model': fields.Nested(reranking_model_fields),
|
||||
'top_k': fields.Integer,
|
||||
'score_threshold_enable': fields.Boolean,
|
||||
'score_threshold': fields.Float
|
||||
}
|
||||
|
||||
|
||||
dataset_detail_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
@ -29,7 +44,8 @@ dataset_detail_fields = {
|
||||
'updated_at': TimestampField,
|
||||
'embedding_model': fields.String,
|
||||
'embedding_model_provider': fields.String,
|
||||
'embedding_available': fields.Boolean
|
||||
'embedding_available': fields.Boolean,
|
||||
'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields)
|
||||
}
|
||||
|
||||
dataset_query_detail_fields = {
|
||||
@ -41,3 +57,5 @@ dataset_query_detail_fields = {
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField
|
||||
}
|
||||
|
||||
|
||||
|
@ -0,0 +1,43 @@
|
||||
"""add-dataset-retrival-model
|
||||
|
||||
Revision ID: fca025d3b60f
|
||||
Revises: b3a09c049e8e
|
||||
Create Date: 2023-11-03 13:08:23.246396
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'fca025d3b60f'
|
||||
down_revision = '8fe468ba0ca5'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('sessions')
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('retrieval_model', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
|
||||
batch_op.create_index('retrieval_model_idx', ['retrieval_model'], unique=False, postgresql_using='gin')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.drop_index('retrieval_model_idx', postgresql_using='gin')
|
||||
batch_op.drop_column('retrieval_model')
|
||||
|
||||
op.create_table('sessions',
|
||||
sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False),
|
||||
sa.Column('session_id', sa.VARCHAR(length=255), autoincrement=False, nullable=True),
|
||||
sa.Column('data', postgresql.BYTEA(), autoincrement=False, nullable=True),
|
||||
sa.Column('expiry', postgresql.TIMESTAMP(), autoincrement=False, nullable=True),
|
||||
sa.PrimaryKeyConstraint('id', name='sessions_pkey'),
|
||||
sa.UniqueConstraint('session_id', name='sessions_session_id_key')
|
||||
)
|
||||
# ### end Alembic commands ###
|
@ -3,7 +3,7 @@ import pickle
|
||||
from json import JSONDecodeError
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
@ -15,6 +15,7 @@ class Dataset(db.Model):
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint('id', name='dataset_pkey'),
|
||||
db.Index('dataset_tenant_idx', 'tenant_id'),
|
||||
db.Index('retrieval_model_idx', "retrieval_model", postgresql_using='gin')
|
||||
)
|
||||
|
||||
INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy']
|
||||
@ -39,7 +40,7 @@ class Dataset(db.Model):
|
||||
embedding_model = db.Column(db.String(255), nullable=True)
|
||||
embedding_model_provider = db.Column(db.String(255), nullable=True)
|
||||
collection_binding_id = db.Column(UUID, nullable=True)
|
||||
|
||||
retrieval_model = db.Column(JSONB, nullable=True)
|
||||
|
||||
@property
|
||||
def dataset_keyword_table(self):
|
||||
@ -93,6 +94,20 @@ class Dataset(db.Model):
|
||||
return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \
|
||||
.filter(Document.dataset_id == self.id).scalar()
|
||||
|
||||
@property
|
||||
def retrieval_model_dict(self):
|
||||
default_retrieval_model = {
|
||||
'search_method': 'semantic_search',
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
}
|
||||
return self.retrieval_model if self.retrieval_model else default_retrieval_model
|
||||
|
||||
|
||||
class DatasetProcessRule(db.Model):
|
||||
__tablename__ = 'dataset_process_rules'
|
||||
@ -120,7 +135,7 @@ class DatasetProcessRule(db.Model):
|
||||
],
|
||||
'segmentation': {
|
||||
'delimiter': '\n',
|
||||
'max_tokens': 1000
|
||||
'max_tokens': 512
|
||||
}
|
||||
}
|
||||
|
||||
@ -462,4 +477,3 @@ class DatasetCollectionBinding(db.Model):
|
||||
model_name = db.Column(db.String(40), nullable=False)
|
||||
collection_name = db.Column(db.String(64), nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
|
||||
|
@ -160,7 +160,13 @@ class AppModelConfig(db.Model):
|
||||
|
||||
@property
|
||||
def dataset_configs_dict(self) -> dict:
|
||||
return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}}
|
||||
if self.dataset_configs:
|
||||
dataset_configs = json.loads(self.dataset_configs)
|
||||
if 'retrieval_model' not in dataset_configs:
|
||||
return {'retrieval_model': 'single'}
|
||||
else:
|
||||
return dataset_configs
|
||||
return {'retrieval_model': 'single'}
|
||||
|
||||
@property
|
||||
def file_upload_dict(self) -> dict:
|
||||
|
@ -23,7 +23,6 @@ boto3==1.28.17
|
||||
tenacity==8.2.2
|
||||
cachetools~=5.3.0
|
||||
weaviate-client~=3.21.0
|
||||
qdrant_client~=1.1.6
|
||||
mailchimp-transactional~=1.0.50
|
||||
scikit-learn==1.2.2
|
||||
sentry-sdk[flask]~=1.21.1
|
||||
@ -53,4 +52,6 @@ xinference-client~=0.5.4
|
||||
safetensors==0.3.2
|
||||
zhipuai==1.0.7
|
||||
werkzeug==2.3.7
|
||||
pymilvus==2.3.0
|
||||
pymilvus==2.3.0
|
||||
qdrant-client==1.6.4
|
||||
cohere~=4.32
|
@ -470,7 +470,16 @@ class AppModelConfigService:
|
||||
|
||||
# dataset_configs
|
||||
if 'dataset_configs' not in config or not config["dataset_configs"]:
|
||||
config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}}
|
||||
config["dataset_configs"] = {'retrieval_model': 'single'}
|
||||
|
||||
if not isinstance(config["dataset_configs"], dict):
|
||||
raise ValueError("dataset_configs must be of object type")
|
||||
|
||||
if config["dataset_configs"]['retrieval_model'] == 'multiple':
|
||||
if not config["dataset_configs"]['reranking_model']:
|
||||
raise ValueError("reranking_model has not been set")
|
||||
if not isinstance(config["dataset_configs"]['reranking_model'], dict):
|
||||
raise ValueError("reranking_model must be of object type")
|
||||
|
||||
if not isinstance(config["dataset_configs"], dict):
|
||||
raise ValueError("dataset_configs must be of object type")
|
||||
|
@ -173,6 +173,9 @@ class DatasetService:
|
||||
filtered_data['updated_by'] = user.id
|
||||
filtered_data['updated_at'] = datetime.datetime.now()
|
||||
|
||||
# update Retrieval model
|
||||
filtered_data['retrieval_model'] = data['retrieval_model']
|
||||
|
||||
dataset.query.filter_by(id=dataset_id).update(filtered_data)
|
||||
|
||||
db.session.commit()
|
||||
@ -473,7 +476,19 @@ class DocumentService:
|
||||
embedding_model.name
|
||||
)
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
if not dataset.retrieval_model:
|
||||
default_retrieval_model = {
|
||||
'search_method': 'semantic_search',
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
}
|
||||
|
||||
dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model
|
||||
|
||||
documents = []
|
||||
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
|
||||
@ -733,6 +748,7 @@ class DocumentService:
|
||||
raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
|
||||
embedding_model = None
|
||||
dataset_collection_binding_id = None
|
||||
retrieval_model = None
|
||||
if document_data['indexing_technique'] == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
@ -742,6 +758,20 @@ class DocumentService:
|
||||
embedding_model.name
|
||||
)
|
||||
dataset_collection_binding_id = dataset_collection_binding.id
|
||||
if 'retrieval_model' in document_data and document_data['retrieval_model']:
|
||||
retrieval_model = document_data['retrieval_model']
|
||||
else:
|
||||
default_retrieval_model = {
|
||||
'search_method': 'semantic_search',
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
}
|
||||
retrieval_model = default_retrieval_model
|
||||
# save dataset
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
@ -751,7 +781,8 @@ class DocumentService:
|
||||
created_by=account.id,
|
||||
embedding_model=embedding_model.name if embedding_model else None,
|
||||
embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None,
|
||||
collection_binding_id=dataset_collection_binding_id
|
||||
collection_binding_id=dataset_collection_binding_id,
|
||||
retrieval_model=retrieval_model
|
||||
)
|
||||
|
||||
db.session.add(dataset)
|
||||
@ -768,7 +799,7 @@ class DocumentService:
|
||||
return dataset, documents, batch
|
||||
|
||||
@classmethod
|
||||
def document_create_args_validate(cls, args: dict):
|
||||
def document_create_args_validate(cls, args: dict):
|
||||
if 'original_document_id' not in args or not args['original_document_id']:
|
||||
DocumentService.data_source_args_validate(args)
|
||||
DocumentService.process_rule_args_validate(args)
|
||||
|
@ -1,4 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
@ -9,16 +11,26 @@ from langchain.schema import Document
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DocumentSegment, DatasetQuery
|
||||
from services.retrieval_service import RetrievalService
|
||||
|
||||
default_retrieval_model = {
|
||||
'search_method': 'semantic_search',
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
}
|
||||
|
||||
class HitTestingService:
|
||||
@classmethod
|
||||
def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
|
||||
def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict:
|
||||
if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
|
||||
return {
|
||||
"query": {
|
||||
@ -28,31 +40,68 @@ class HitTestingService:
|
||||
"records": []
|
||||
}
|
||||
|
||||
start = time.perf_counter()
|
||||
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
if not retrieval_model:
|
||||
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||
|
||||
# get embedding model
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
vector_index = VectorIndex(
|
||||
dataset=dataset,
|
||||
config=current_app.config,
|
||||
embeddings=embeddings
|
||||
)
|
||||
all_documents = []
|
||||
threads = []
|
||||
|
||||
# retrieval_model source with semantic
|
||||
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
|
||||
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'query': query,
|
||||
'top_k': retrieval_model['top_k'],
|
||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
|
||||
'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
|
||||
'all_documents': all_documents,
|
||||
'search_method': retrieval_model['search_method'],
|
||||
'embeddings': embeddings
|
||||
})
|
||||
threads.append(embedding_thread)
|
||||
embedding_thread.start()
|
||||
|
||||
# retrieval source with full text
|
||||
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
|
||||
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'query': query,
|
||||
'search_method': retrieval_model['search_method'],
|
||||
'embeddings': embeddings,
|
||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
|
||||
'top_k': retrieval_model['top_k'],
|
||||
'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
|
||||
'all_documents': all_documents
|
||||
})
|
||||
threads.append(full_text_index_thread)
|
||||
full_text_index_thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
if retrieval_model['search_method'] == 'hybrid_search':
|
||||
hybrid_rerank = ModelFactory.get_reranking_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
|
||||
model_name=retrieval_model['reranking_model']['reranking_model_name']
|
||||
)
|
||||
all_documents = hybrid_rerank.rerank(query, all_documents,
|
||||
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
|
||||
retrieval_model['top_k'])
|
||||
|
||||
start = time.perf_counter()
|
||||
documents = vector_index.search(
|
||||
query,
|
||||
search_type='similarity_score_threshold',
|
||||
search_kwargs={
|
||||
'k': 10,
|
||||
'filter': {
|
||||
'group_id': [dataset.id]
|
||||
}
|
||||
}
|
||||
)
|
||||
end = time.perf_counter()
|
||||
logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
|
||||
|
||||
@ -67,7 +116,7 @@ class HitTestingService:
|
||||
db.session.add(dataset_query)
|
||||
db.session.commit()
|
||||
|
||||
return cls.compact_retrieve_response(dataset, embeddings, query, documents)
|
||||
return cls.compact_retrieve_response(dataset, embeddings, query, all_documents)
|
||||
|
||||
@classmethod
|
||||
def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]):
|
||||
@ -99,7 +148,7 @@ class HitTestingService:
|
||||
|
||||
record = {
|
||||
"segment": segment,
|
||||
"score": document.metadata['score'],
|
||||
"score": document.metadata.get('score', None),
|
||||
"tsne_position": tsne_position_data[i]
|
||||
}
|
||||
|
||||
@ -136,3 +185,11 @@ class HitTestingService:
|
||||
tsne_position_data.append({'x': float(data_tsne[i][0]), 'y': float(data_tsne[i][1])})
|
||||
|
||||
return tsne_position_data
|
||||
|
||||
@classmethod
|
||||
def hit_testing_args_check(cls, args):
|
||||
query = args['query']
|
||||
|
||||
if not query or len(query) > 250:
|
||||
raise ValueError('Query is required and cannot exceed 250 characters')
|
||||
|
||||
|
88
api/services/retrieval_service.py
Normal file
88
api/services/retrieval_service.py
Normal file
@ -0,0 +1,88 @@
|
||||
|
||||
from typing import Optional
|
||||
from flask import current_app, Flask
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from models.dataset import Dataset
|
||||
|
||||
default_retrieval_model = {
|
||||
'search_method': 'semantic_search',
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
}
|
||||
|
||||
|
||||
class RetrievalService:
|
||||
|
||||
@classmethod
|
||||
def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str,
|
||||
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
|
||||
all_documents: list, search_method: str, embeddings: Embeddings):
|
||||
with flask_app.app_context():
|
||||
|
||||
vector_index = VectorIndex(
|
||||
dataset=dataset,
|
||||
config=current_app.config,
|
||||
embeddings=embeddings
|
||||
)
|
||||
|
||||
documents = vector_index.search(
|
||||
query,
|
||||
search_type='similarity_score_threshold',
|
||||
search_kwargs={
|
||||
'k': top_k,
|
||||
'score_threshold': score_threshold,
|
||||
'filter': {
|
||||
'group_id': [dataset.id]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if documents:
|
||||
if reranking_model and search_method == 'semantic_search':
|
||||
rerank = ModelFactory.get_reranking_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=reranking_model['reranking_provider_name'],
|
||||
model_name=reranking_model['reranking_model_name']
|
||||
)
|
||||
all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
|
||||
else:
|
||||
all_documents.extend(documents)
|
||||
|
||||
@classmethod
|
||||
def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str,
|
||||
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
|
||||
all_documents: list, search_method: str, embeddings: Embeddings):
|
||||
with flask_app.app_context():
|
||||
|
||||
vector_index = VectorIndex(
|
||||
dataset=dataset,
|
||||
config=current_app.config,
|
||||
embeddings=embeddings
|
||||
)
|
||||
|
||||
documents = vector_index.search_by_full_text_index(
|
||||
query,
|
||||
search_type='similarity_score_threshold',
|
||||
top_k=top_k
|
||||
)
|
||||
if documents:
|
||||
if reranking_model and search_method == 'full_text_search':
|
||||
rerank = ModelFactory.get_reranking_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=reranking_model['reranking_provider_name'],
|
||||
model_name=reranking_model['reranking_model_name']
|
||||
)
|
||||
all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
|
||||
else:
|
||||
all_documents.extend(documents)
|
||||
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user