diff --git a/api/app.py b/api/app.py index 4e217c6aac..3beb2cf706 100644 --- a/api/app.py +++ b/api/app.py @@ -14,7 +14,7 @@ from flask import Flask, request, Response, session import flask_login from flask_cors import CORS -from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_vector_store, ext_migrate, \ +from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \ ext_database, ext_storage from extensions.ext_database import db from extensions.ext_login import login_manager @@ -79,7 +79,6 @@ def initialize_extensions(app): ext_database.init_app(app) ext_migrate.init(app, db) ext_redis.init_app(app) - ext_vector_store.init_app(app) ext_storage.init_app(app) ext_celery.init_app(app) ext_session.init_app(app) diff --git a/api/commands.py b/api/commands.py index 10beec6c9a..544b3110e7 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1,15 +1,19 @@ import datetime +import logging import random import string import click from flask import current_app +from werkzeug.exceptions import NotFound +from core.index.index import IndexBuilder from libs.password import password_pattern, valid_password, hash_password from libs.helper import email as email_validate from extensions.ext_database import db from libs.rsa import generate_key_pair from models.account import InvitationCode, Tenant +from models.dataset import Dataset from models.model import Account import secrets import base64 @@ -159,8 +163,39 @@ def generate_upper_string(): return result +@click.command('recreate-all-dataset-indexes', help='Recreate all dataset indexes.') +def recreate_all_dataset_indexes(): + click.echo(click.style('Start recreate all dataset indexes.', fg='green')) + recreate_count = 0 + + page = 1 + while True: + try: + datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality')\ + .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50) + except NotFound: + break + + page += 1 + for dataset in datasets: + try: + click.echo('Recreating dataset index: {}'.format(dataset.id)) + index = IndexBuilder.get_index(dataset, 'high_quality') + if index and index._is_origin(): + index.recreate_dataset(dataset) + recreate_count += 1 + else: + click.echo('passed.') + except Exception as e: + click.echo(click.style('Recreate dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red')) + continue + + click.echo(click.style('Congratulations! 
Recreate {} dataset indexes.'.format(recreate_count), fg='green')) + + def register_commands(app): app.cli.add_command(reset_password) app.cli.add_command(reset_email) app.cli.add_command(generate_invitation_codes) app.cli.add_command(reset_encrypt_key_pair) + app.cli.add_command(recreate_all_dataset_indexes) diff --git a/api/config.py b/api/config.py index d892b75c78..e4c188821e 100644 --- a/api/config.py +++ b/api/config.py @@ -187,11 +187,13 @@ class Config: # For temp use only # set default LLM provider, default is 'openai', support `azure_openai` self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER') + # notion import setting self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID') self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET') self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE') self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET') + self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN') class CloudEditionConfig(Config): diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index f0efc0504f..4151132621 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -10,11 +10,10 @@ from werkzeug.exceptions import NotFound from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.data_source.notion import NotionPageReader +from core.data_loader.loader.notion import NotionLoader from core.indexing_runner import IndexingRunner from extensions.ext_database import db from libs.helper import TimestampField -from libs.oauth_data_source import NotionOAuth from models.dataset import Document from models.source import DataSourceBinding from services.dataset_service import DatasetService, DocumentService @@ -232,15 +231,17 @@ class DataSourceNotionApi(Resource): ).first() if not data_source_binding: raise NotFound('Data source binding not found.') - reader = NotionPageReader(integration_token=data_source_binding.access_token) - if page_type == 'page': - page_content = reader.read_page(page_id) - elif page_type == 'database': - page_content = reader.query_database_data(page_id) - else: - page_content = "" + + loader = NotionLoader( + notion_access_token=data_source_binding.access_token, + notion_workspace_id=workspace_id, + notion_obj_id=page_id, + notion_page_type=page_type + ) + + text_docs = loader.load() return { - 'content': page_content + 'content': "\n".join([doc.page_content for doc in text_docs]) }, 200 @setup_required diff --git a/api/controllers/console/datasets/file.py b/api/controllers/console/datasets/file.py index 2b13e9e09e..2d6a25e91b 100644 --- a/api/controllers/console/datasets/file.py +++ b/api/controllers/console/datasets/file.py @@ -17,9 +17,7 @@ from controllers.console.datasets.error import NoFileUploadedError, TooManyFiles UnsupportedFileTypeError from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.index.readers.html_parser import HTMLParser -from core.index.readers.pdf_parser import PDFParser -from core.index.readers.xlsx_parser import XLSXParser +from core.data_loader.file_extractor import FileExtractor from extensions.ext_storage import storage from libs.helper import TimestampField from extensions.ext_database import db @@ -123,31 +121,7 @@ class FilePreviewApi(Resource): if extension not in ALLOWED_EXTENSIONS: raise 
UnsupportedFileTypeError() - with tempfile.TemporaryDirectory() as temp_dir: - suffix = Path(upload_file.key).suffix - filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" - storage.download(upload_file.key, filepath) - - if extension == 'pdf': - parser = PDFParser({'upload_file': upload_file}) - text = parser.parse_file(Path(filepath)) - elif extension in ['html', 'htm']: - # Use BeautifulSoup to extract text - parser = HTMLParser() - text = parser.parse_file(Path(filepath)) - elif extension == 'xlsx': - parser = XLSXParser() - text = parser.parse_file(filepath) - else: - # ['txt', 'markdown', 'md'] - with open(filepath, "rb") as fp: - data = fp.read() - encoding = chardet.detect(data)['encoding'] - if encoding: - text = data.decode(encoding=encoding).strip() if data else '' - else: - text = data.decode(encoding='utf-8').strip() if data else '' - + text = FileExtractor.load(upload_file, return_text=True) text = text[0:PREVIEW_WORDS_LIMIT] if text else '' return {'content': text} diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 285cdcba5c..2561b06b5d 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -32,8 +32,13 @@ class VersionApi(Resource): 'current_version': args.get('current_version') }) except Exception as error: - logging.exception("Check update error.") - raise InternalServerError() + logging.warning("Check update version error: {}.".format(str(error))) + return { + 'version': args.get('current_version'), + 'release_date': '', + 'release_notes': '', + 'can_auto_update': False + } content = json.loads(response.content) return { diff --git a/api/core/__init__.py b/api/core/__init__.py index f6257d8b36..d7e00f73fa 100644 --- a/api/core/__init__.py +++ b/api/core/__init__.py @@ -3,19 +3,11 @@ from typing import Optional import langchain from flask import Flask -from jieba.analyse import default_tfidf -from langchain import set_handler from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING -from llama_index import IndexStructType, QueryMode -from llama_index.indices.registry import INDEX_STRUT_TYPE_TO_QUERY_MAP from pydantic import BaseModel from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler -from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex -from core.index.keyword_table.stopwords import STOPWORDS from core.prompt.prompt_template import OneLineFormatter -from core.vector_store.vector_store import VectorStore -from core.vector_store.vector_store_index_query import EnhanceGPTVectorStoreIndexQuery class HostedOpenAICredential(BaseModel): @@ -32,21 +24,9 @@ hosted_llm_credentials = HostedLLMCredentials() def init_app(app: Flask): formatter = OneLineFormatter() DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format - INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.KEYWORD_TABLE] = GPTJIEBAKeywordTableIndex.get_query_map() - INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.WEAVIATE] = { - QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery, - QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery, - } - INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.QDRANT] = { - QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery, - QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery, - } - - default_tfidf.stop_words = STOPWORDS if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': langchain.verbose = True - set_handler(DifyStdOutCallbackHandler()) if app.config.get("OPENAI_API_KEY"): hosted_llm_credentials.openai = 
HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY")) diff --git a/api/core/agent/agent_builder.py b/api/core/agent/agent_builder.py index b1d6948467..9f290e8f67 100644 --- a/api/core/agent/agent_builder.py +++ b/api/core/agent/agent_builder.py @@ -2,7 +2,7 @@ from typing import Optional from langchain import LLMChain from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent -from langchain.callbacks import CallbackManager +from langchain.callbacks.manager import CallbackManager from langchain.memory.chat_memory import BaseChatMemory from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler @@ -16,23 +16,20 @@ class AgentBuilder: def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory], dataset_tool_callback_handler: DatasetToolCallbackHandler, agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler): - llm_callback_manager = CallbackManager([agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]) llm = LLMBuilder.to_llm( tenant_id=tenant_id, model_name=agent_loop_gather_callback_handler.model_name, temperature=0, max_tokens=1024, - callback_manager=llm_callback_manager + callbacks=[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()] ) - tool_callback_manager = CallbackManager([ - agent_loop_gather_callback_handler, - dataset_tool_callback_handler, - DifyStdOutCallbackHandler() - ]) - for tool in tools: - tool.callback_manager = tool_callback_manager + tool.callbacks = [ + agent_loop_gather_callback_handler, + dataset_tool_callback_handler, + DifyStdOutCallbackHandler() + ] prompt = cls.build_agent_prompt_template( tools=tools, @@ -54,7 +51,7 @@ class AgentBuilder: tools=tools, agent=agent, memory=memory, - callback_manager=agent_callback_manager, + callbacks=agent_callback_manager, max_iterations=6, early_stopping_method="generate", # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit diff --git a/api/core/callback_handler/agent_loop_gather_callback_handler.py b/api/core/callback_handler/agent_loop_gather_callback_handler.py index f37411cacc..600d4e65d6 100644 --- a/api/core/callback_handler/agent_loop_gather_callback_handler.py +++ b/api/core/callback_handler/agent_loop_gather_callback_handler.py @@ -12,6 +12,7 @@ from core.conversation_message_task import ConversationMessageTask class AgentLoopGatherCallbackHandler(BaseCallbackHandler): """Callback Handler that prints to std out.""" + raise_error: bool = True def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None: """Initialize callback handler.""" @@ -64,10 +65,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): self._current_loop.completion = response.generations[0][0].text self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens'] - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Do nothing.""" - pass - def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any ) -> None: @@ -75,21 +72,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): self._agent_loops = [] self._current_loop = None - def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> None: - """Print out that we are entering a chain.""" - pass - - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Print out that we finished a chain.""" - pass - - def on_chain_error( - self, error: 
Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - logging.error(error) - def on_tool_start( self, serialized: Dict[str, Any], @@ -151,16 +133,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): self._agent_loops = [] self._current_loop = None - def on_text( - self, - text: str, - color: Optional[str] = None, - end: str = "", - **kwargs: Optional[str], - ) -> None: - """Run on additional input from chains and agents.""" - pass - def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: """Run on agent end.""" # Final Answer diff --git a/api/core/callback_handler/dataset_tool_callback_handler.py b/api/core/callback_handler/dataset_tool_callback_handler.py index e3fce66511..b7edc3b9c6 100644 --- a/api/core/callback_handler/dataset_tool_callback_handler.py +++ b/api/core/callback_handler/dataset_tool_callback_handler.py @@ -3,7 +3,6 @@ import logging from typing import Any, Dict, List, Union, Optional from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish, LLMResult from core.callback_handler.entity.dataset_query import DatasetQueryObj from core.conversation_message_task import ConversationMessageTask @@ -11,6 +10,7 @@ from core.conversation_message_task import ConversationMessageTask class DatasetToolCallbackHandler(BaseCallbackHandler): """Callback Handler that prints to std out.""" + raise_error: bool = True def __init__(self, conversation_message_task: ConversationMessageTask) -> None: """Initialize callback handler.""" @@ -66,52 +66,3 @@ class DatasetToolCallbackHandler(BaseCallbackHandler): ) -> None: """Do nothing.""" logging.error(error) - - def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> None: - pass - - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - pass - - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - pass - - def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> None: - pass - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - pass - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - logging.error(error) - - def on_agent_action( - self, action: AgentAction, color: Optional[str] = None, **kwargs: Any - ) -> Any: - pass - - def on_text( - self, - text: str, - color: Optional[str] = None, - end: str = "", - **kwargs: Optional[str], - ) -> None: - """Run on additional input from chains and agents.""" - pass - - def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: - """Run on agent end.""" - pass diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index db430efe08..59518667a0 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,39 +1,26 @@ -from llama_index import Response +from typing import List + +from langchain.schema import Document from extensions.ext_database import db from models.dataset import DocumentSegment -class IndexToolCallbackHandler: - - def __init__(self) -> None: - self._response = None - - @property - def response(self) -> Response: - return self._response - - def on_tool_end(self, response: Response) -> None: - """Handle tool end.""" - self._response = response 
- - -class DatasetIndexToolCallbackHandler(IndexToolCallbackHandler): +class DatasetIndexToolCallbackHandler: """Callback handler for dataset tool.""" def __init__(self, dataset_id: str) -> None: - super().__init__() self.dataset_id = dataset_id - def on_tool_end(self, response: Response) -> None: + def on_tool_end(self, documents: List[Document]) -> None: """Handle tool end.""" - for node in response.source_nodes: - index_node_id = node.node.doc_id + for document in documents: + doc_id = document.metadata['doc_id'] # add hit count to document segment db.session.query(DocumentSegment).filter( DocumentSegment.dataset_id == self.dataset_id, - DocumentSegment.index_node_id == index_node_id + DocumentSegment.index_node_id == doc_id ).update( {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False diff --git a/api/core/callback_handler/llm_callback_handler.py b/api/core/callback_handler/llm_callback_handler.py index 89a7630737..c9db70fa76 100644 --- a/api/core/callback_handler/llm_callback_handler.py +++ b/api/core/callback_handler/llm_callback_handler.py @@ -3,7 +3,7 @@ import time from typing import Any, Dict, List, Union, Optional from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage +from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage, BaseMessage from core.callback_handler.entity.llm_message import LLMMessage from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException @@ -12,6 +12,7 @@ from core.llm.streamable_open_ai import StreamableOpenAI class LLMCallbackHandler(BaseCallbackHandler): + raise_error: bool = True def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI], conversation_message_task: ConversationMessageTask): @@ -25,41 +26,41 @@ class LLMCallbackHandler(BaseCallbackHandler): """Whether to call verbose callbacks even if verbose is False.""" return True + def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + **kwargs: Any + ) -> Any: + self.start_at = time.perf_counter() + real_prompts = [] + for message in messages[0]: + if message.type == 'human': + role = 'user' + elif message.type == 'ai': + role = 'assistant' + else: + role = 'system' + + real_prompts.append({ + "role": role, + "text": message.content + }) + + self.llm_message.prompt = real_prompts + self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0]) + def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: self.start_at = time.perf_counter() - if 'Chat' in serialized['name']: - real_prompts = [] - messages = [] - for prompt in prompts: - role, content = prompt.split(': ', maxsplit=1) - if role == 'human': - role = 'user' - message = HumanMessage(content=content) - elif role == 'ai': - role = 'assistant' - message = AIMessage(content=content) - else: - message = SystemMessage(content=content) + self.llm_message.prompt = [{ + "role": 'user', + "text": prompts[0] + }] - real_prompt = { - "role": role, - "text": content - } - real_prompts.append(real_prompt) - messages.append(message) - - self.llm_message.prompt = real_prompts - self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages) - else: - self.llm_message.prompt = [{ - "role": 'user', - "text": prompts[0] - }] - - self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0]) + 
self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0]) def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: end_at = time.perf_counter() @@ -95,58 +96,3 @@ class LLMCallbackHandler(BaseCallbackHandler): self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True) else: logging.error(error) - - def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> None: - pass - - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - pass - - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - pass - - def on_tool_start( - self, - serialized: Dict[str, Any], - input_str: str, - **kwargs: Any, - ) -> None: - pass - - def on_agent_action( - self, action: AgentAction, color: Optional[str] = None, **kwargs: Any - ) -> Any: - pass - - def on_tool_end( - self, - output: str, - color: Optional[str] = None, - observation_prefix: Optional[str] = None, - llm_prefix: Optional[str] = None, - **kwargs: Any, - ) -> None: - pass - - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - pass - - def on_text( - self, - text: str, - color: Optional[str] = None, - end: str = "", - **kwargs: Optional[str], - ) -> None: - pass - - def on_agent_finish( - self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any - ) -> None: - pass diff --git a/api/core/callback_handler/main_chain_gather_callback_handler.py b/api/core/callback_handler/main_chain_gather_callback_handler.py index 1bd41edd6c..bb294072d9 100644 --- a/api/core/callback_handler/main_chain_gather_callback_handler.py +++ b/api/core/callback_handler/main_chain_gather_callback_handler.py @@ -1,10 +1,9 @@ import logging import time -from typing import Any, Dict, List, Union, Optional +from typing import Any, Dict, Union from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish, LLMResult from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler from core.callback_handler.entity.chain_result import ChainResult @@ -14,6 +13,7 @@ from core.conversation_message_task import ConversationMessageTask class MainChainGatherCallbackHandler(BaseCallbackHandler): """Callback Handler that prints to std out.""" + raise_error: bool = True def __init__(self, conversation_message_task: ConversationMessageTask) -> None: """Initialize callback handler.""" @@ -50,13 +50,15 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler): ) -> None: """Print out that we are entering a chain.""" if not self._current_chain_result: - self._current_chain_result = ChainResult( - type=serialized['name'], - prompt=inputs, - started_at=time.perf_counter() - ) - self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result) - self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message + chain_type = serialized['id'][-1] + if chain_type: + self._current_chain_result = ChainResult( + type=chain_type, + prompt=inputs, + started_at=time.perf_counter() + ) + self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result) + self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: """Print out that we finished a chain.""" @@ -74,64 +76,4 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler): 
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any ) -> None: logging.error(error) - self.clear_chain_results() - - def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> None: - pass - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - pass - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Do nothing.""" - pass - - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - logging.error(error) - - def on_tool_start( - self, - serialized: Dict[str, Any], - input_str: str, - **kwargs: Any, - ) -> None: - pass - - def on_agent_action( - self, action: AgentAction, color: Optional[str] = None, **kwargs: Any - ) -> Any: - pass - - def on_tool_end( - self, - output: str, - color: Optional[str] = None, - observation_prefix: Optional[str] = None, - llm_prefix: Optional[str] = None, - **kwargs: Any, - ) -> None: - pass - - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - logging.error(error) - - def on_text( - self, - text: str, - color: Optional[str] = None, - end: str = "", - **kwargs: Optional[str], - ) -> None: - """Run on additional input from chains and agents.""" - pass - - def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: - """Run on agent end.""" - pass + self.clear_chain_results() \ No newline at end of file diff --git a/api/core/callback_handler/std_out_callback_handler.py b/api/core/callback_handler/std_out_callback_handler.py index 352e6cb4d8..e0ff7950e5 100644 --- a/api/core/callback_handler/std_out_callback_handler.py +++ b/api/core/callback_handler/std_out_callback_handler.py @@ -1,9 +1,10 @@ +import os import sys from typing import Any, Dict, List, Optional, Union from langchain.callbacks.base import BaseCallbackHandler from langchain.input import print_text -from langchain.schema import AgentAction, AgentFinish, LLMResult +from langchain.schema import AgentAction, AgentFinish, LLMResult, BaseMessage class DifyStdOutCallbackHandler(BaseCallbackHandler): @@ -13,17 +14,23 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler): """Initialize callback handler.""" self.color = color + def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + **kwargs: Any + ) -> Any: + print_text("\n[on_chat_model_start]\n", color='blue') + for sub_messages in messages: + for sub_message in sub_messages: + print_text(str(sub_message) + "\n", color='blue') + def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: """Print out the prompts.""" print_text("\n[on_llm_start]\n", color='blue') - - if 'Chat' in serialized['name']: - for prompt in prompts: - print_text(prompt + "\n", color='blue') - else: - print_text(prompts[0] + "\n", color='blue') + print_text(prompts[0] + "\n", color='blue') def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: """Do nothing.""" @@ -44,8 +51,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler): self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any ) -> None: """Print out that we are entering a chain.""" - class_name = serialized["name"] - print_text("\n[on_chain_start]\nChain: " + class_name + "\nInputs: " + str(inputs) + "\n", color='pink') + chain_type = serialized['id'][-1] + print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink') def on_chain_end(self, outputs: Dict[str, Any], 
**kwargs: Any) -> None: """Print out that we finished a chain.""" @@ -117,6 +124,26 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler): """Run on agent end.""" print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n") + @property + def ignore_llm(self) -> bool: + """Whether to ignore LLM callbacks.""" + return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' + + @property + def ignore_chain(self) -> bool: + """Whether to ignore chain callbacks.""" + return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' + + @property + def ignore_agent(self) -> bool: + """Whether to ignore agent callbacks.""" + return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' + + @property + def ignore_chat_model(self) -> bool: + """Whether to ignore chat model callbacks.""" + return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' + class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler): """Callback handler for streaming. Only works with LLMs that support streaming.""" diff --git a/api/core/chain/chain_builder.py b/api/core/chain/chain_builder.py index b7583ed890..5e75093111 100644 --- a/api/core/chain/chain_builder.py +++ b/api/core/chain/chain_builder.py @@ -1,7 +1,5 @@ from typing import Optional -from langchain.callbacks import CallbackManager - from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain from core.chain.tool_chain import ToolChain @@ -14,7 +12,7 @@ class ChainBuilder: tool=tool, input_key=kwargs.get('input_key', 'input'), output_key=kwargs.get('output_key', 'tool_output'), - callback_manager=CallbackManager([DifyStdOutCallbackHandler()]) + callbacks=[DifyStdOutCallbackHandler()] ) @classmethod @@ -27,7 +25,7 @@ class ChainBuilder: sensitive_words=sensitive_words.split(","), canned_response=tool_config.get("canned_response", ''), output_key="sensitive_word_avoidance_output", - callback_manager=CallbackManager([DifyStdOutCallbackHandler()]), + callbacks=[DifyStdOutCallbackHandler()], **kwargs ) diff --git a/api/core/chain/llm_router_chain.py b/api/core/chain/llm_router_chain.py index e3779c3612..21b3c2a525 100644 --- a/api/core/chain/llm_router_chain.py +++ b/api/core/chain/llm_router_chain.py @@ -1,15 +1,16 @@ """Base classes for LLM-powered router chains.""" from __future__ import annotations -import json from typing import Any, Dict, List, Optional, Type, cast, NamedTuple +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from pydantic import root_validator from langchain.chains import LLMChain from langchain.prompts import BasePromptTemplate -from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel +from langchain.schema import BaseOutputParser, OutputParserException from libs.json_in_md_parser import parse_and_check_json_markdown @@ -51,8 +52,9 @@ class LLMRouterChain(Chain): raise ValueError def _call( - self, - inputs: Dict[str, Any] + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Dict[str, Any]: output = cast( Dict[str, Any], diff --git a/api/core/chain/main_chain_builder.py b/api/core/chain/main_chain_builder.py index 4cb6205fcb..c36178ffd8 100644 --- a/api/core/chain/main_chain_builder.py +++ b/api/core/chain/main_chain_builder.py @@ -1,11 
+1,9 @@ -from typing import Optional, List +from typing import Optional, List, cast -from langchain.callbacks import SharedCallbackManager, CallbackManager from langchain.chains import SequentialChain from langchain.chains.base import Chain from langchain.memory.chat_memory import BaseChatMemory -from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler from core.chain.chain_builder import ChainBuilder @@ -18,6 +16,7 @@ from models.dataset import Dataset class MainChainBuilder: @classmethod def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory], + rest_tokens: int, conversation_message_task: ConversationMessageTask): first_input_key = "input" final_output_key = "output" @@ -30,6 +29,7 @@ class MainChainBuilder: tool_chains, chains_output_key = cls.get_agent_chains( tenant_id=tenant_id, agent_mode=agent_mode, + rest_tokens=rest_tokens, memory=memory, conversation_message_task=conversation_message_task ) @@ -42,9 +42,8 @@ class MainChainBuilder: return None for chain in chains: - # do not add handler into singleton callback manager - if not isinstance(chain.callback_manager, SharedCallbackManager): - chain.callback_manager.add_handler(chain_callback_handler) + chain = cast(Chain, chain) + chain.callbacks.append(chain_callback_handler) # build main chain overall_chain = SequentialChain( @@ -57,7 +56,9 @@ class MainChainBuilder: return overall_chain @classmethod - def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory], + def get_agent_chains(cls, tenant_id: str, agent_mode: dict, + rest_tokens: int, + memory: Optional[BaseChatMemory], conversation_message_task: ConversationMessageTask): # agent mode chains = [] @@ -93,7 +94,8 @@ class MainChainBuilder: tenant_id=tenant_id, datasets=datasets, conversation_message_task=conversation_message_task, - callback_manager=CallbackManager([DifyStdOutCallbackHandler()]) + rest_tokens=rest_tokens, + callbacks=[DifyStdOutCallbackHandler()] ) chains.append(multi_dataset_router_chain) diff --git a/api/core/chain/multi_dataset_router_chain.py b/api/core/chain/multi_dataset_router_chain.py index fb0bc35f93..edbf07e87d 100644 --- a/api/core/chain/multi_dataset_router_chain.py +++ b/api/core/chain/multi_dataset_router_chain.py @@ -1,9 +1,9 @@ +import math from typing import Mapping, List, Dict, Any, Optional -from langchain import LLMChain, PromptTemplate, ConversationChain -from langchain.callbacks import CallbackManager +from langchain import PromptTemplate +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain -from langchain.schema import BaseLanguageModel from pydantic import Extra from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler @@ -11,10 +11,11 @@ from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHan from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser from core.conversation_message_task import ConversationMessageTask from core.llm.llm_builder import LLMBuilder -from core.tool.dataset_tool_builder import DatasetToolBuilder -from core.tool.llama_index_tool import EnhanceLlamaIndexTool -from models.dataset import Dataset +from core.tool.dataset_index_tool import DatasetTool +from models.dataset import Dataset, 
DatasetProcessRule +DEFAULT_K = 2 +CONTEXT_TOKENS_PERCENT = 0.3 MULTI_PROMPT_ROUTER_TEMPLATE = """ Given a raw text input to a language model select the model prompt best suited for \ the input. You will be given the names of the available prompts and a description of \ @@ -52,7 +53,7 @@ class MultiDatasetRouterChain(Chain): router_chain: LLMRouterChain """Chain for deciding a destination chain and the input to it.""" - dataset_tools: Mapping[str, EnhanceLlamaIndexTool] + dataset_tools: Mapping[str, DatasetTool] """Map of name to candidate chains that inputs can be routed to.""" class Config: @@ -79,41 +80,56 @@ class MultiDatasetRouterChain(Chain): tenant_id: str, datasets: List[Dataset], conversation_message_task: ConversationMessageTask, + rest_tokens: int, **kwargs: Any, ): """Convenience constructor for instantiating from destination prompts.""" - llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()]) llm = LLMBuilder.to_llm( tenant_id=tenant_id, model_name='gpt-3.5-turbo', temperature=0, max_tokens=1024, - callback_manager=llm_callback_manager + callbacks=[DifyStdOutCallbackHandler()] ) - destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ') if d.description + destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description else ('useful for when you want to answer queries about the ' + d.name)) for d in datasets] destinations_str = "\n".join(destinations) router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format( destinations=destinations_str ) + router_prompt = PromptTemplate( template=router_template, input_variables=["input"], output_parser=RouterOutputParser(), ) + router_chain = LLMRouterChain.from_llm(llm, router_prompt) dataset_tools = {} for dataset in datasets: - dataset_tool = DatasetToolBuilder.build_dataset_tool( + # fulfill description when it is empty + if dataset.available_document_count == 0 or dataset.available_document_count == 0: + continue + + description = dataset.description + if not description: + description = 'useful for when you want to answer queries about the ' + dataset.name + + k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens) + if k == 0: + continue + + dataset_tool = DatasetTool( + name=f"dataset-{dataset.id}", + description=description, + k=k, dataset=dataset, - response_mode='no_synthesizer', # "compact" - callback_handler=DatasetToolCallbackHandler(conversation_message_task) + callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()] ) - if dataset_tool: - dataset_tools[dataset.id] = dataset_tool + dataset_tools[str(dataset.id)] = dataset_tool return cls( router_chain=router_chain, @@ -121,9 +137,39 @@ class MultiDatasetRouterChain(Chain): **kwargs, ) + @classmethod + def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int: + processing_rule = dataset.latest_process_rule + if not processing_rule: + return DEFAULT_K + + if processing_rule.mode == "custom": + rules = processing_rule.rules_dict + if not rules: + return DEFAULT_K + + segmentation = rules["segmentation"] + segment_max_tokens = segmentation["max_tokens"] + else: + segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'] + + # when rest_tokens is less than default context tokens + if rest_tokens < segment_max_tokens * DEFAULT_K: + return rest_tokens // segment_max_tokens + + context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT) + + # when context_limit_tokens is less than default context tokens, use default_k + if 
context_limit_tokens <= segment_max_tokens * DEFAULT_K: + return DEFAULT_K + + # Expand the k value when there's still some room left in the 30% rest tokens space + return context_limit_tokens // segment_max_tokens + def _call( self, - inputs: Dict[str, Any] + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Dict[str, Any]: if len(self.dataset_tools) == 0: return {"text": ''} diff --git a/api/core/chain/sensitive_word_avoidance_chain.py b/api/core/chain/sensitive_word_avoidance_chain.py index a552551c0f..3820840912 100644 --- a/api/core/chain/sensitive_word_avoidance_chain.py +++ b/api/core/chain/sensitive_word_avoidance_chain.py @@ -1,5 +1,6 @@ -from typing import List, Dict +from typing import List, Dict, Optional, Any +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain @@ -36,7 +37,11 @@ class SensitiveWordAvoidanceChain(Chain): return self.canned_response return text - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: text = inputs[self.input_key] output = self._check_sensitive_word(text) return {self.output_key: output} diff --git a/api/core/chain/tool_chain.py b/api/core/chain/tool_chain.py index 458a35eb82..5d6f2bc88a 100644 --- a/api/core/chain/tool_chain.py +++ b/api/core/chain/tool_chain.py @@ -1,5 +1,6 @@ -from typing import List, Dict +from typing import List, Dict, Optional, Any +from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun from langchain.chains.base import Chain from langchain.tools import BaseTool @@ -30,12 +31,20 @@ class ToolChain(Chain): """ return [self.output_key] - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: input = inputs[self.input_key] output = self.tool.run(input, self.verbose) return {self.output_key: output} - async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: """Run the logic of this chain and return the output.""" input = inputs[self.input_key] output = await self.tool.arun(input, self.verbose) diff --git a/api/core/completion.py b/api/core/completion.py index a4bc5a2498..a999d34c5c 100644 --- a/api/core/completion.py +++ b/api/core/completion.py @@ -1,17 +1,18 @@ import logging from typing import Optional, List, Union, Tuple -from langchain.callbacks import CallbackManager +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.base import BaseCallbackHandler from langchain.chat_models.base import BaseChatModel from langchain.llms import BaseLLM -from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage +from langchain.schema import BaseMessage, HumanMessage from requests.exceptions import ChunkedEncodingError from core.constant import llm_constant from core.callback_handler.llm_callback_handler import LLMCallbackHandler from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \ DifyStdOutCallbackHandler -from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, PubHandler +from core.conversation_message_task import ConversationMessageTask, 
ConversationTaskStoppedException from core.llm.error import LLMBadRequestError from core.llm.llm_builder import LLMBuilder from core.chain.main_chain_builder import MainChainBuilder @@ -34,8 +35,6 @@ class Completion: """ errors: ProviderTokenNotInitError """ - cls.validate_query_tokens(app.tenant_id, app_model_config, query) - memory = None if conversation: # get memory of conversation (read-only) @@ -48,6 +47,14 @@ class Completion: inputs = conversation.inputs + rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens( + mode=app.mode, + tenant_id=app.tenant_id, + app_model_config=app_model_config, + query=query, + inputs=inputs + ) + conversation_message_task = ConversationMessageTask( task_id=task_id, app=app, @@ -64,6 +71,7 @@ class Completion: main_chain = MainChainBuilder.to_langchain_components( tenant_id=app.tenant_id, agent_mode=app_model_config.agent_mode_dict, + rest_tokens=rest_tokens_for_context_and_memory, memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None, conversation_message_task=conversation_message_task ) @@ -115,7 +123,7 @@ class Completion: memory=memory ) - final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task) + final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task) cls.recale_llm_max_tokens( final_llm=final_llm, @@ -247,16 +255,14 @@ And answer according to the language of the user's question. return messages, ['\nHuman:'] @classmethod - def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI], - streaming: bool, - conversation_message_task: ConversationMessageTask) -> CallbackManager: + def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI], + streaming: bool, + conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]: llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task) if streaming: - callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()] + return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()] else: - callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()] - - return CallbackManager(callback_handlers) + return [llm_callback_handler, DifyStdOutCallbackHandler()] @classmethod def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory, @@ -293,7 +299,8 @@ And answer according to the language of the user's question. return memory @classmethod - def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str): + def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig, + query: str, inputs: dict) -> int: llm = LLMBuilder.to_llm_from_model( tenant_id=tenant_id, model=app_model_config.model_dict @@ -302,8 +309,26 @@ And answer according to the language of the user's question. 
model_limited_tokens = llm_constant.max_context_token_length[llm.model_name] max_tokens = llm.max_tokens - if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0: - raise LLMBadRequestError("Query is too long") + # get prompt without memory and context + prompt, _ = cls.get_main_llm_prompt( + mode=mode, + llm=llm, + pre_prompt=app_model_config.pre_prompt, + query=query, + inputs=inputs, + chain_output=None, + memory=None + ) + + prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \ + else llm.get_num_tokens_from_messages(prompt) + + rest_tokens = model_limited_tokens - max_tokens - prompt_tokens + if rest_tokens < 0: + raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, " + "or shrink the max token, or switch to a llm with a larger token limit size.") + + return rest_tokens @classmethod def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI], @@ -360,7 +385,7 @@ And answer according to the language of the user's question. streaming=streaming ) - llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task) + llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task) cls.recale_llm_max_tokens( final_llm=llm, diff --git a/api/core/conversation_message_task.py b/api/core/conversation_message_task.py index 3b73b19d1f..6057e4b63b 100644 --- a/api/core/conversation_message_task.py +++ b/api/core/conversation_message_task.py @@ -293,12 +293,12 @@ class PubHandler: if not user: raise ValueError("user is required") - user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id + user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id) return "generate_result:{}-{}".format(user_str, task_id) @classmethod def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str): - user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id + user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id) return "generate_result_stopped:{}-{}".format(user_str, task_id) def pub_text(self, text: str): @@ -306,10 +306,10 @@ class PubHandler: 'event': 'message', 'data': { 'task_id': self._task_id, - 'message_id': self._message.id, + 'message_id': str(self._message.id), 'text': text, 'mode': self._conversation.mode, - 'conversation_id': self._conversation.id + 'conversation_id': str(self._conversation.id) } } diff --git a/api/core/data_loader/file_extractor.py b/api/core/data_loader/file_extractor.py new file mode 100644 index 0000000000..cb24fa9eb8 --- /dev/null +++ b/api/core/data_loader/file_extractor.py @@ -0,0 +1,43 @@ +import tempfile +from pathlib import Path +from typing import List, Union + +from langchain.document_loaders import TextLoader, Docx2txtLoader +from langchain.schema import Document + +from core.data_loader.loader.csv import CSVLoader +from core.data_loader.loader.excel import ExcelLoader +from core.data_loader.loader.html import HTMLLoader +from core.data_loader.loader.markdown import MarkdownLoader +from core.data_loader.loader.pdf import PdfLoader +from extensions.ext_storage import storage +from models.model import UploadFile + + +class FileExtractor: + @classmethod + def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document] | str]: + with tempfile.TemporaryDirectory() as temp_dir: + suffix = Path(upload_file.key).suffix + file_path = 
f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
+            storage.download(upload_file.key, file_path)
+
+            input_file = Path(file_path)
+            delimiter = '\n'
+            if input_file.suffix == '.xlsx':
+                loader = ExcelLoader(file_path)
+            elif input_file.suffix == '.pdf':
+                loader = PdfLoader(file_path, upload_file=upload_file)
+            elif input_file.suffix in ['.md', '.markdown']:
+                loader = MarkdownLoader(file_path, autodetect_encoding=True)
+            elif input_file.suffix in ['.htm', '.html']:
+                loader = HTMLLoader(file_path)
+            elif input_file.suffix == '.docx':
+                loader = Docx2txtLoader(file_path)
+            elif input_file.suffix == '.csv':
+                loader = CSVLoader(file_path, autodetect_encoding=True)
+            else:
+                # txt
+                loader = TextLoader(file_path, autodetect_encoding=True)
+
+            return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
diff --git a/api/core/data_loader/loader/csv.py b/api/core/data_loader/loader/csv.py
new file mode 100644
index 0000000000..ff57ef61e7
--- /dev/null
+++ b/api/core/data_loader/loader/csv.py
@@ -0,0 +1,68 @@
+import csv
+import logging
+from typing import Optional, Dict, List
+
+from langchain.document_loaders import CSVLoader as LCCSVLoader
+from langchain.document_loaders.helpers import detect_file_encodings
+
+from langchain.schema import Document
+
+logger = logging.getLogger(__name__)
+
+
+class CSVLoader(LCCSVLoader):
+    def __init__(
+            self,
+            file_path: str,
+            source_column: Optional[str] = None,
+            csv_args: Optional[Dict] = None,
+            encoding: Optional[str] = None,
+            autodetect_encoding: bool = True,
+    ):
+        self.file_path = file_path
+        self.source_column = source_column
+        self.encoding = encoding
+        self.csv_args = csv_args or {}
+        self.autodetect_encoding = autodetect_encoding
+
+    def load(self) -> List[Document]:
+        """Load data into document objects."""
+        try:
+            with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
+                docs = self._read_from_file(csvfile)
+        except UnicodeDecodeError as e:
+            if self.autodetect_encoding:
+                detected_encodings = detect_file_encodings(self.file_path)
+                for encoding in detected_encodings:
+                    logger.debug("Trying encoding: ", encoding.encoding)
+                    try:
+                        with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile:
+                            docs = self._read_from_file(csvfile)
+                            break
+                    except UnicodeDecodeError:
+                        continue
+            else:
+                raise RuntimeError(f"Error loading {self.file_path}") from e
+
+        return docs
+
+    def _read_from_file(self, csvfile):
+        docs = []
+        csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
+        for i, row in enumerate(csv_reader):
+            content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
+            try:
+                source = (
+                    row[self.source_column]
+                    if self.source_column is not None
+                    else ''
+                )
+            except KeyError:
+                raise ValueError(
+                    f"Source column '{self.source_column}' not found in CSV file."
+                )
+            metadata = {"source": source, "row": i}
+            doc = Document(page_content=content, metadata=metadata)
+            docs.append(doc)
+
+        return docs
diff --git a/api/core/data_loader/loader/excel.py b/api/core/data_loader/loader/excel.py
new file mode 100644
index 0000000000..7dd62f5094
--- /dev/null
+++ b/api/core/data_loader/loader/excel.py
@@ -0,0 +1,43 @@
+import json
+import logging
+from typing import List
+
+from langchain.document_loaders.base import BaseLoader
+from langchain.schema import Document
+from openpyxl.reader.excel import load_workbook
+
+logger = logging.getLogger(__name__)
+
+
+class ExcelLoader(BaseLoader):
+    """Load xlsx files.
+
+
+    Args:
+        file_path: Path to the file to load.
+    """
+
+    def __init__(
+            self,
+            file_path: str
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
+
+    def load(self) -> List[Document]:
+        data = []
+        keys = []
+        wb = load_workbook(filename=self._file_path, read_only=True)
+        # loop over all sheets
+        for sheet in wb:
+            for row in sheet.iter_rows(values_only=True):
+                if all(v is None for v in row):
+                    continue
+                if keys == []:
+                    keys = list(map(str, row))
+                else:
+                    row_dict = dict(zip(keys, row))
+                    row_dict = {k: v for k, v in row_dict.items() if v}
+                    data.append(json.dumps(row_dict, ensure_ascii=False))
+
+        return [Document(page_content='\n\n'.join(data))]
diff --git a/api/core/data_loader/loader/html.py b/api/core/data_loader/loader/html.py
new file mode 100644
index 0000000000..414975007b
--- /dev/null
+++ b/api/core/data_loader/loader/html.py
@@ -0,0 +1,35 @@
+import logging
+from typing import List
+
+from bs4 import BeautifulSoup
+from langchain.document_loaders.base import BaseLoader
+from langchain.schema import Document
+
+logger = logging.getLogger(__name__)
+
+
+class HTMLLoader(BaseLoader):
+    """Load html files.
+
+
+    Args:
+        file_path: Path to the file to load.
+    """
+
+    def __init__(
+            self,
+            file_path: str
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
+
+    def load(self) -> List[Document]:
+        return [Document(page_content=self._load_as_text())]
+
+    def _load_as_text(self) -> str:
+        with open(self._file_path, "rb") as fp:
+            soup = BeautifulSoup(fp, 'html.parser')
+            text = soup.get_text()
+            text = text.strip() if text else ''
+
+        return text
diff --git a/api/core/data_loader/loader/markdown.py b/api/core/data_loader/loader/markdown.py
new file mode 100644
index 0000000000..4e6c0d5637
--- /dev/null
+++ b/api/core/data_loader/loader/markdown.py
@@ -0,0 +1,134 @@
+import logging
+import re
+from typing import Optional, List, Tuple, cast
+
+from langchain.document_loaders.base import BaseLoader
+from langchain.document_loaders.helpers import detect_file_encodings
+from langchain.schema import Document
+
+logger = logging.getLogger(__name__)
+
+
+class MarkdownLoader(BaseLoader):
+    """Load md files.
+
+
+    Args:
+        file_path: Path to the file to load.
+
+        remove_hyperlinks: Whether to remove hyperlinks from the text.
+
+        remove_images: Whether to remove images from the text.
+
+        encoding: File encoding to use. If `None`, the file will be loaded
+            with the default system encoding.
+
+        autodetect_encoding: Whether to try to autodetect the file encoding
+            if the specified encoding fails.
+    """
+
+    def __init__(
+            self,
+            file_path: str,
+            remove_hyperlinks: bool = True,
+            remove_images: bool = True,
+            encoding: Optional[str] = None,
+            autodetect_encoding: bool = True,
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
+        self._remove_hyperlinks = remove_hyperlinks
+        self._remove_images = remove_images
+        self._encoding = encoding
+        self._autodetect_encoding = autodetect_encoding
+
+    def load(self) -> List[Document]:
+        tups = self.parse_tups(self._file_path)
+        documents = []
+        for header, value in tups:
+            value = value.strip()
+            if header is None:
+                documents.append(Document(page_content=value))
+            else:
+                documents.append(Document(page_content=f"\n\n{header}\n{value}"))
+
+        return documents
+
+    def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
+        """Convert a markdown file to a dictionary.
+
+        The keys are the headers and the values are the text under each header.
+ + """ + markdown_tups: List[Tuple[Optional[str], str]] = [] + lines = markdown_text.split("\n") + + current_header = None + current_text = "" + + for line in lines: + header_match = re.match(r"^#+\s", line) + if header_match: + if current_header is not None: + markdown_tups.append((current_header, current_text)) + + current_header = line + current_text = "" + else: + current_text += line + "\n" + markdown_tups.append((current_header, current_text)) + + if current_header is not None: + # pass linting, assert keys are defined + markdown_tups = [ + (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) + for key, value in markdown_tups + ] + else: + markdown_tups = [ + (key, re.sub("\n", "", value)) for key, value in markdown_tups + ] + + return markdown_tups + + def remove_images(self, content: str) -> str: + """Get a dictionary of a markdown file from its path.""" + pattern = r"!{1}\[\[(.*)\]\]" + content = re.sub(pattern, "", content) + return content + + def remove_hyperlinks(self, content: str) -> str: + """Get a dictionary of a markdown file from its path.""" + pattern = r"\[(.*?)\]\((.*?)\)" + content = re.sub(pattern, r"\1", content) + return content + + def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]: + """Parse file into tuples.""" + content = "" + try: + with open(filepath, "r", encoding=self._encoding) as f: + content = f.read() + except UnicodeDecodeError as e: + if self._autodetect_encoding: + detected_encodings = detect_file_encodings(filepath) + for encoding in detected_encodings: + logger.debug("Trying encoding: ", encoding.encoding) + try: + with open(filepath, encoding=encoding.encoding) as f: + content = f.read() + break + except UnicodeDecodeError: + continue + else: + raise RuntimeError(f"Error loading {filepath}") from e + except Exception as e: + raise RuntimeError(f"Error loading {filepath}") from e + + if self._remove_hyperlinks: + content = self.remove_hyperlinks(content) + + if self._remove_images: + content = self.remove_images(content) + + return self.markdown_to_tups(content) diff --git a/api/core/data_source/notion.py b/api/core/data_loader/loader/notion.py similarity index 59% rename from api/core/data_source/notion.py rename to api/core/data_loader/loader/notion.py index 8307af3835..913128d9fe 100644 --- a/api/core/data_source/notion.py +++ b/api/core/data_loader/loader/notion.py @@ -1,67 +1,224 @@ -"""Notion reader.""" import json import logging -import os -from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import List, Dict, Any, Optional -import requests # type: ignore +import requests +from flask import current_app +from langchain.document_loaders.base import BaseLoader +from langchain.schema import Document -from llama_index.readers.base import BaseReader -from llama_index.readers.schema.base import Document +from extensions.ext_database import db +from models.dataset import Document as DocumentModel +from models.source import DataSourceBinding + +logger = logging.getLogger(__name__) -INTEGRATION_TOKEN_NAME = "NOTION_INTEGRATION_TOKEN" BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children" DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query" SEARCH_URL = "https://api.notion.com/v1/search" RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}" RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}" HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3'] -logger = logging.getLogger(__name__) 
-# TODO: Notion DB reader coming soon! -class NotionPageReader(BaseReader): - """Notion Page reader. +class NotionLoader(BaseLoader): + def __init__( + self, + notion_access_token: str, + notion_workspace_id: str, + notion_obj_id: str, + notion_page_type: str, + document_model: Optional[DocumentModel] = None + ): + self._document_model = document_model + self._notion_workspace_id = notion_workspace_id + self._notion_obj_id = notion_obj_id + self._notion_page_type = notion_page_type + self._notion_access_token = notion_access_token - Reads a set of Notion pages. - - Args: - integration_token (str): Notion integration token. - - """ - - def __init__(self, integration_token: Optional[str] = None) -> None: - """Initialize with parameters.""" - if integration_token is None: - integration_token = os.getenv(INTEGRATION_TOKEN_NAME) + if not self._notion_access_token: + integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN') if integration_token is None: raise ValueError( "Must specify `integration_token` or set environment " "variable `NOTION_INTEGRATION_TOKEN`." ) - self.token = integration_token - self.headers = { - "Authorization": "Bearer " + self.token, - "Content-Type": "application/json", - "Notion-Version": "2022-06-28", - } - def _read_block(self, block_id: str, num_tabs: int = 0) -> str: - """Read a block.""" - done = False + self._notion_access_token = integration_token + + @classmethod + def from_document(cls, document_model: DocumentModel): + data_source_info = document_model.data_source_info_dict + if not data_source_info or 'notion_page_id' not in data_source_info \ + or 'notion_workspace_id' not in data_source_info: + raise ValueError("no notion page found") + + notion_workspace_id = data_source_info['notion_workspace_id'] + notion_obj_id = data_source_info['notion_page_id'] + notion_page_type = data_source_info['type'] + notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id) + + return cls( + notion_access_token=notion_access_token, + notion_workspace_id=notion_workspace_id, + notion_obj_id=notion_obj_id, + notion_page_type=notion_page_type, + document_model=document_model + ) + + def load(self) -> List[Document]: + self.update_last_edited_time( + self._document_model + ) + + text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type) + + return text_docs + + def _load_data_as_documents( + self, notion_obj_id: str, notion_page_type: str + ) -> List[Document]: + docs = [] + if notion_page_type == 'database': + # get all the pages in the database + page_text = self._get_notion_database_data(notion_obj_id) + docs.append(Document(page_content=page_text)) + elif notion_page_type == 'page': + page_text_list = self._get_notion_block_data(notion_obj_id) + for page_text in page_text_list: + docs.append(Document(page_content=page_text)) + else: + raise ValueError("notion page type not supported") + + return docs + + def _get_notion_database_data( + self, database_id: str, query_dict: Dict[str, Any] = {} + ) -> str: + """Get all the pages from a Notion database.""" + res = requests.post( + DATABASE_URL_TMPL.format(database_id=database_id), + headers={ + "Authorization": "Bearer " + self._notion_access_token, + "Content-Type": "application/json", + "Notion-Version": "2022-06-28", + }, + json=query_dict, + ) + + data = res.json() + + database_content_list = [] + if 'results' not in data or data["results"] is None: + return "" + for result in data["results"]: + properties = result['properties'] + data = {} + for 
property_name, property_value in properties.items(): + type = property_value['type'] + if type == 'multi_select': + value = [] + multi_select_list = property_value[type] + for multi_select in multi_select_list: + value.append(multi_select['name']) + elif type == 'rich_text' or type == 'title': + if len(property_value[type]) > 0: + value = property_value[type][0]['plain_text'] + else: + value = '' + elif type == 'select' or type == 'status': + if property_value[type]: + value = property_value[type]['name'] + else: + value = '' + else: + value = property_value[type] + data[property_name] = value + database_content_list.append(json.dumps(data, ensure_ascii=False)) + + return "\n\n".join(database_content_list) + + def _get_notion_block_data(self, page_id: str) -> List[str]: result_lines_arr = [] - cur_block_id = block_id - while not done: + cur_block_id = page_id + while True: block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) query_dict: Dict[str, Any] = {} res = requests.request( - "GET", block_url, headers=self.headers, json=query_dict + "GET", + block_url, + headers={ + "Authorization": "Bearer " + self._notion_access_token, + "Content-Type": "application/json", + "Notion-Version": "2022-06-28", + }, + json=query_dict + ) + data = res.json() + # current block's heading + heading = '' + for result in data["results"]: + result_type = result["type"] + result_obj = result[result_type] + cur_result_text_arr = [] + if result_type == 'table': + result_block_id = result["id"] + text = self._read_table_rows(result_block_id) + text += "\n\n" + result_lines_arr.append(text) + else: + if "rich_text" in result_obj: + for rich_text in result_obj["rich_text"]: + # skip if doesn't have text object + if "text" in rich_text: + text = rich_text["text"]["content"] + cur_result_text_arr.append(text) + if result_type in HEADING_TYPE: + heading = text + + result_block_id = result["id"] + has_children = result["has_children"] + block_type = result["type"] + if has_children and block_type != 'child_page': + children_text = self._read_block( + result_block_id, num_tabs=1 + ) + cur_result_text_arr.append(children_text) + + cur_result_text = "\n".join(cur_result_text_arr) + cur_result_text += "\n\n" + if result_type in HEADING_TYPE: + result_lines_arr.append(cur_result_text) + else: + result_lines_arr.append(f'{heading}\n{cur_result_text}') + + if data["next_cursor"] is None: + break + else: + cur_block_id = data["next_cursor"] + return result_lines_arr + + def _read_block(self, block_id: str, num_tabs: int = 0) -> str: + """Read a block.""" + result_lines_arr = [] + cur_block_id = block_id + while True: + block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) + query_dict: Dict[str, Any] = {} + + res = requests.request( + "GET", + block_url, + headers={ + "Authorization": "Bearer " + self._notion_access_token, + "Content-Type": "application/json", + "Notion-Version": "2022-06-28", + }, + json=query_dict ) data = res.json() if 'results' not in data or data["results"] is None: - done = True break heading = '' for result in data["results"]: @@ -98,7 +255,6 @@ class NotionPageReader(BaseReader): result_lines_arr.append(f'{heading}\n{cur_result_text}') if data["next_cursor"] is None: - done = True break else: cur_block_id = data["next_cursor"] @@ -116,7 +272,14 @@ class NotionPageReader(BaseReader): query_dict: Dict[str, Any] = {} res = requests.request( - "GET", block_url, headers=self.headers, json=query_dict + "GET", + block_url, + headers={ + "Authorization": "Bearer " + self._notion_access_token, 
+ "Content-Type": "application/json", + "Notion-Version": "2022-06-28", + }, + json=query_dict ) data = res.json() # get table headers text @@ -129,9 +292,9 @@ class NotionPageReader(BaseReader): table_header_cell_texts.append(text) # get table columns text and format results = data["results"] - for i in range(len(results)-1): + for i in range(len(results) - 1): column_texts = [] - tabel_column_cells = data["results"][i+1]['table_row']['cells'] + tabel_column_cells = data["results"][i + 1]['table_row']['cells'] for j in range(len(tabel_column_cells)): if tabel_column_cells[j]: for table_column_cell_text in tabel_column_cells[j]: @@ -149,221 +312,58 @@ class NotionPageReader(BaseReader): result_lines = "\n".join(result_lines_arr) return result_lines - def _read_parent_blocks(self, block_id: str, num_tabs: int = 0) -> List[str]: - """Read a block.""" - done = False - result_lines_arr = [] - cur_block_id = block_id - while not done: - block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) - query_dict: Dict[str, Any] = {} - res = requests.request( - "GET", block_url, headers=self.headers, json=query_dict + def update_last_edited_time(self, document_model: DocumentModel): + if not document_model: + return + + last_edited_time = self.get_notion_last_edited_time() + data_source_info = document_model.data_source_info_dict + data_source_info['last_edited_time'] = last_edited_time + update_params = { + DocumentModel.data_source_info: json.dumps(data_source_info) + } + + DocumentModel.query.filter_by(id=document_model.id).update(update_params) + db.session.commit() + + def get_notion_last_edited_time(self) -> str: + obj_id = self._notion_obj_id + page_type = self._notion_page_type + if page_type == 'database': + retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id) + else: + retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id) + + query_dict: Dict[str, Any] = {} + + res = requests.request( + "GET", + retrieve_page_url, + headers={ + "Authorization": "Bearer " + self._notion_access_token, + "Content-Type": "application/json", + "Notion-Version": "2022-06-28", + }, + json=query_dict + ) + + data = res.json() + return data["last_edited_time"] + + @classmethod + def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: + data_source_binding = DataSourceBinding.query.filter( + db.and_( + DataSourceBinding.tenant_id == tenant_id, + DataSourceBinding.provider == 'notion', + DataSourceBinding.disabled == False, + DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"' ) - data = res.json() - # current block's heading - heading = '' - for result in data["results"]: - result_type = result["type"] - result_obj = result[result_type] - cur_result_text_arr = [] - if result_type == 'table': - result_block_id = result["id"] - text = self._read_table_rows(result_block_id) - text += "\n\n" - result_lines_arr.append(text) - else: - if "rich_text" in result_obj: - for rich_text in result_obj["rich_text"]: - # skip if doesn't have text object - if "text" in rich_text: - text = rich_text["text"]["content"] - cur_result_text_arr.append(text) - if result_type in HEADING_TYPE: - heading = text + ).first() - result_block_id = result["id"] - has_children = result["has_children"] - block_type = result["type"] - if has_children and block_type != 'child_page': - children_text = self._read_block( - result_block_id, num_tabs=num_tabs + 1 - ) - cur_result_text_arr.append(children_text) + if not data_source_binding: + raise Exception(f'No notion data 
source binding found for tenant {tenant_id} ' + f'and notion workspace {notion_workspace_id}') - cur_result_text = "\n".join(cur_result_text_arr) - cur_result_text += "\n\n" - if result_type in HEADING_TYPE: - result_lines_arr.append(cur_result_text) - else: - result_lines_arr.append(f'{heading}\n{cur_result_text}') - - if data["next_cursor"] is None: - done = True - break - else: - cur_block_id = data["next_cursor"] - return result_lines_arr - - def read_page(self, page_id: str) -> str: - """Read a page.""" - return self._read_block(page_id) - - def read_page_as_documents(self, page_id: str) -> List[str]: - """Read a page as documents.""" - return self._read_parent_blocks(page_id) - - def query_database_data( - self, database_id: str, query_dict: Dict[str, Any] = {} - ) -> str: - """Get all the pages from a Notion database.""" - res = requests.post\ - ( - DATABASE_URL_TMPL.format(database_id=database_id), - headers=self.headers, - json=query_dict, - ) - data = res.json() - database_content_list = [] - if 'results' not in data or data["results"] is None: - return "" - for result in data["results"]: - properties = result['properties'] - data = {} - for property_name, property_value in properties.items(): - type = property_value['type'] - if type == 'multi_select': - value = [] - multi_select_list = property_value[type] - for multi_select in multi_select_list: - value.append(multi_select['name']) - elif type == 'rich_text' or type == 'title': - if len(property_value[type]) > 0: - value = property_value[type][0]['plain_text'] - else: - value = '' - elif type == 'select' or type == 'status': - if property_value[type]: - value = property_value[type]['name'] - else: - value = '' - else: - value = property_value[type] - data[property_name] = value - database_content_list.append(json.dumps(data, ensure_ascii=False)) - - return "\n\n".join(database_content_list) - - def query_database( - self, database_id: str, query_dict: Dict[str, Any] = {} - ) -> List[str]: - """Get all the pages from a Notion database.""" - res = requests.post\ - ( - DATABASE_URL_TMPL.format(database_id=database_id), - headers=self.headers, - json=query_dict, - ) - data = res.json() - page_ids = [] - for result in data["results"]: - page_id = result["id"] - page_ids.append(page_id) - - return page_ids - - def search(self, query: str) -> List[str]: - """Search Notion page given a text query.""" - done = False - next_cursor: Optional[str] = None - page_ids = [] - while not done: - query_dict = { - "query": query, - } - if next_cursor is not None: - query_dict["start_cursor"] = next_cursor - res = requests.post(SEARCH_URL, headers=self.headers, json=query_dict) - data = res.json() - for result in data["results"]: - page_id = result["id"] - page_ids.append(page_id) - - if data["next_cursor"] is None: - done = True - break - else: - next_cursor = data["next_cursor"] - return page_ids - - def load_data( - self, page_ids: List[str] = [], database_id: Optional[str] = None - ) -> List[Document]: - """Load data from the input directory. - - Args: - page_ids (List[str]): List of page ids to load. - - Returns: - List[Document]: List of documents. 
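A hedged usage sketch for the new NotionLoader; the token and ids below are placeholders, and an importable app environment providing extensions.ext_database is assumed:

    from core.data_loader.loader.notion import NotionLoader

    loader = NotionLoader(
        notion_access_token='secret-placeholder-token',
        notion_workspace_id='workspace-id-placeholder',
        notion_obj_id='page-id-placeholder',
        notion_page_type='page'                      # or 'database'
    )
    docs = loader.load()                             # list of langchain Documents
    content = '\n'.join(doc.page_content for doc in docs)

    # when re-indexing an existing dataset document, the access token is instead
    # resolved from the tenant's DataSourceBinding:
    # loader = NotionLoader.from_document(document_model)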
- - """ - if not page_ids and not database_id: - raise ValueError("Must specify either `page_ids` or `database_id`.") - docs = [] - if database_id is not None: - # get all the pages in the database - page_ids = self.query_database(database_id) - for page_id in page_ids: - page_text = self.read_page(page_id) - docs.append(Document(page_text)) - else: - for page_id in page_ids: - page_text = self.read_page(page_id) - docs.append(Document(page_text)) - - return docs - - def load_data_as_documents( - self, page_ids: List[str] = [], database_id: Optional[str] = None - ) -> List[Document]: - if not page_ids and not database_id: - raise ValueError("Must specify either `page_ids` or `database_id`.") - docs = [] - if database_id is not None: - # get all the pages in the database - page_text = self.query_database_data(database_id) - docs.append(Document(page_text)) - else: - for page_id in page_ids: - page_text_list = self.read_page_as_documents(page_id) - for page_text in page_text_list: - docs.append(Document(page_text)) - - return docs - - def get_page_last_edited_time(self, page_id: str) -> str: - retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=page_id) - query_dict: Dict[str, Any] = {} - - res = requests.request( - "GET", retrieve_page_url, headers=self.headers, json=query_dict - ) - data = res.json() - return data["last_edited_time"] - - def get_database_last_edited_time(self, database_id: str) -> str: - retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=database_id) - query_dict: Dict[str, Any] = {} - - res = requests.request( - "GET", retrieve_page_url, headers=self.headers, json=query_dict - ) - data = res.json() - return data["last_edited_time"] - - -if __name__ == "__main__": - reader = NotionPageReader() - logger.info(reader.search("What I")) + return data_source_binding.access_token diff --git a/api/core/data_loader/loader/pdf.py b/api/core/data_loader/loader/pdf.py new file mode 100644 index 0000000000..881d0026b5 --- /dev/null +++ b/api/core/data_loader/loader/pdf.py @@ -0,0 +1,55 @@ +import logging +from typing import List, Optional + +from langchain.document_loaders import PyPDFium2Loader +from langchain.document_loaders.base import BaseLoader +from langchain.schema import Document + +from extensions.ext_storage import storage +from models.model import UploadFile + +logger = logging.getLogger(__name__) + + +class PdfLoader(BaseLoader): + """Load pdf files. + + + Args: + file_path: Path to the file to load. 
+ """ + + def __init__( + self, + file_path: str, + upload_file: Optional[UploadFile] = None + ): + """Initialize with file path.""" + self._file_path = file_path + self._upload_file = upload_file + + def load(self) -> List[Document]: + plaintext_file_key = '' + plaintext_file_exists = False + if self._upload_file: + if self._upload_file.hash: + plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \ + + self._upload_file.hash + '.0625.plaintext' + try: + text = storage.load(plaintext_file_key).decode('utf-8') + plaintext_file_exists = True + return [Document(page_content=text)] + except FileNotFoundError: + pass + documents = PyPDFium2Loader(file_path=self._file_path).load() + text_list = [] + for document in documents: + text_list.append(document.page_content) + text = "\n\n".join(text_list) + + # save plaintext file for caching + if not plaintext_file_exists and plaintext_file_key: + storage.save(plaintext_file_key, text.encode('utf-8')) + + return documents + diff --git a/api/core/docstore/dataset_docstore.py b/api/core/docstore/dataset_docstore.py index b3b968532c..b8af3bf01b 100644 --- a/api/core/docstore/dataset_docstore.py +++ b/api/core/docstore/dataset_docstore.py @@ -1,10 +1,6 @@ from typing import Any, Dict, Optional, Sequence -import tiktoken -from llama_index.data_structs import Node -from llama_index.docstore.types import BaseDocumentStore -from llama_index.docstore.utils import json_to_doc -from llama_index.schema import BaseDocument +from langchain.schema import Document from sqlalchemy import func from core.llm.token_calculator import TokenCalculator @@ -12,7 +8,7 @@ from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment -class DatesetDocumentStore(BaseDocumentStore): +class DatesetDocumentStore: def __init__( self, dataset: Dataset, @@ -48,7 +44,7 @@ class DatesetDocumentStore(BaseDocumentStore): return self._embedding_model_name @property - def docs(self) -> Dict[str, BaseDocument]: + def docs(self) -> Dict[str, Document]: document_segments = db.session.query(DocumentSegment).filter( DocumentSegment.dataset_id == self._dataset.id ).all() @@ -56,13 +52,20 @@ class DatesetDocumentStore(BaseDocumentStore): output = {} for document_segment in document_segments: doc_id = document_segment.index_node_id - result = self.segment_to_dict(document_segment) - output[doc_id] = json_to_doc(result) + output[doc_id] = Document( + page_content=document_segment.content, + metadata={ + "doc_id": document_segment.index_node_id, + "doc_hash": document_segment.index_node_hash, + "document_id": document_segment.document_id, + "dataset_id": document_segment.dataset_id, + } + ) return output def add_documents( - self, docs: Sequence[BaseDocument], allow_update: bool = True + self, docs: Sequence[Document], allow_update: bool = True ) -> None: max_position = db.session.query(func.max(DocumentSegment.position)).filter( DocumentSegment.document == self._document_id @@ -72,23 +75,20 @@ class DatesetDocumentStore(BaseDocumentStore): max_position = 0 for doc in docs: - if doc.is_doc_id_none: - raise ValueError("doc_id not set") + if not isinstance(doc, Document): + raise ValueError("doc must be a Document") - if not isinstance(doc, Node): - raise ValueError("doc must be a Node") - - segment_document = self.get_document(doc_id=doc.get_doc_id(), raise_error=False) + segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False) # NOTE: doc could already exist in the store, but we overwrite it if not allow_update and 
segment_document: raise ValueError( - f"doc_id {doc.get_doc_id()} already exists. " + f"doc_id {doc.metadata['doc_id']} already exists. " "Set allow_update to True to overwrite." ) # calc embedding use tokens - tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.get_text()) + tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content) if not segment_document: max_position += 1 @@ -97,19 +97,19 @@ class DatesetDocumentStore(BaseDocumentStore): tenant_id=self._dataset.tenant_id, dataset_id=self._dataset.id, document_id=self._document_id, - index_node_id=doc.get_doc_id(), - index_node_hash=doc.get_doc_hash(), + index_node_id=doc.metadata['doc_id'], + index_node_hash=doc.metadata['doc_hash'], position=max_position, - content=doc.get_text(), - word_count=len(doc.get_text()), + content=doc.page_content, + word_count=len(doc.page_content), tokens=tokens, created_by=self._user_id, ) db.session.add(segment_document) else: - segment_document.content = doc.get_text() - segment_document.index_node_hash = doc.get_doc_hash() - segment_document.word_count = len(doc.get_text()) + segment_document.content = doc.page_content + segment_document.index_node_hash = doc.metadata['doc_hash'] + segment_document.word_count = len(doc.page_content) segment_document.tokens = tokens db.session.commit() @@ -121,7 +121,7 @@ class DatesetDocumentStore(BaseDocumentStore): def get_document( self, doc_id: str, raise_error: bool = True - ) -> Optional[BaseDocument]: + ) -> Optional[Document]: document_segment = self.get_document_segment(doc_id) if document_segment is None: @@ -130,8 +130,15 @@ class DatesetDocumentStore(BaseDocumentStore): else: return None - result = self.segment_to_dict(document_segment) - return json_to_doc(result) + return Document( + page_content=document_segment.content, + metadata={ + "doc_id": document_segment.index_node_id, + "doc_hash": document_segment.index_node_hash, + "document_id": document_segment.document_id, + "dataset_id": document_segment.dataset_id, + } + ) def delete_document(self, doc_id: str, raise_error: bool = True) -> None: document_segment = self.get_document_segment(doc_id) @@ -164,15 +171,6 @@ class DatesetDocumentStore(BaseDocumentStore): return document_segment.index_node_hash - def update_docstore(self, other: "BaseDocumentStore") -> None: - """Update docstore. 
- - Args: - other (BaseDocumentStore): docstore to update from - - """ - self.add_documents(list(other.docs.values())) - def get_document_segment(self, doc_id: str) -> DocumentSegment: document_segment = db.session.query(DocumentSegment).filter( DocumentSegment.dataset_id == self._dataset.id, @@ -180,11 +178,3 @@ class DatesetDocumentStore(BaseDocumentStore): ).first() return document_segment - - def segment_to_dict(self, segment: DocumentSegment) -> Dict[str, Any]: - return { - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "text": segment.content, - "__type__": Node.get_type() - } diff --git a/api/core/docstore/empty_docstore.py b/api/core/docstore/empty_docstore.py deleted file mode 100644 index e19f1824cb..0000000000 --- a/api/core/docstore/empty_docstore.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Any, Dict, Optional, Sequence -from llama_index.docstore.types import BaseDocumentStore -from llama_index.schema import BaseDocument - - -class EmptyDocumentStore(BaseDocumentStore): - @classmethod - def from_dict(cls, config_dict: Dict[str, Any]) -> "EmptyDocumentStore": - return cls() - - def to_dict(self) -> Dict[str, Any]: - """Serialize to dict.""" - return {} - - @property - def docs(self) -> Dict[str, BaseDocument]: - return {} - - def add_documents( - self, docs: Sequence[BaseDocument], allow_update: bool = True - ) -> None: - pass - - def document_exists(self, doc_id: str) -> bool: - """Check if document exists.""" - return False - - def get_document( - self, doc_id: str, raise_error: bool = True - ) -> Optional[BaseDocument]: - return None - - def delete_document(self, doc_id: str, raise_error: bool = True) -> None: - pass - - def set_document_hash(self, doc_id: str, doc_hash: str) -> None: - """Set the hash for a given doc_id.""" - pass - - def get_document_hash(self, doc_id: str) -> Optional[str]: - """Get the stored hash for a document, if it exists.""" - return None - - def update_docstore(self, other: "BaseDocumentStore") -> None: - """Update docstore. 
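For clarity, a small sketch of the metadata contract the reworked DatesetDocumentStore relies on; the values are placeholders:

    from langchain.schema import Document

    segment_doc = Document(
        page_content='segment text',
        metadata={
            'doc_id': 'index-node-uuid',     # stored as DocumentSegment.index_node_id
            'doc_hash': 'hash-of-content',   # stored as DocumentSegment.index_node_hash
        }
    )
    # add_documents() reads page_content plus metadata['doc_id'] / ['doc_hash'];
    # a Document missing either key raises a KeyError.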
- - Args: - other (BaseDocumentStore): docstore to update from - - """ - self.add_documents(list(other.docs.values())) diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py new file mode 100644 index 0000000000..4030eb158c --- /dev/null +++ b/api/core/embedding/cached_embedding.py @@ -0,0 +1,72 @@ +import logging +from typing import List + +from langchain.embeddings.base import Embeddings +from sqlalchemy.exc import IntegrityError + +from extensions.ext_database import db +from libs import helper +from models.dataset import Embedding + + +class CacheEmbedding(Embeddings): + def __init__(self, embeddings: Embeddings): + self._embeddings = embeddings + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Embed search docs.""" + # use doc embedding cache or store if not exists + text_embeddings = [] + embedding_queue_texts = [] + for text in texts: + hash = helper.generate_text_hash(text) + embedding = db.session.query(Embedding).filter_by(hash=hash).first() + if embedding: + text_embeddings.append(embedding.get_embedding()) + else: + embedding_queue_texts.append(text) + + embedding_results = self._embeddings.embed_documents(embedding_queue_texts) + + i = 0 + for text in embedding_queue_texts: + hash = helper.generate_text_hash(text) + + try: + embedding = Embedding(hash=hash) + embedding.set_embedding(embedding_results[i]) + db.session.add(embedding) + db.session.commit() + except IntegrityError: + db.session.rollback() + continue + except: + logging.exception('Failed to add embedding to db') + continue + + i += 1 + + text_embeddings.extend(embedding_results) + return text_embeddings + + def embed_query(self, text: str) -> List[float]: + """Embed query text.""" + # use doc embedding cache or store if not exists + hash = helper.generate_text_hash(text) + embedding = db.session.query(Embedding).filter_by(hash=hash).first() + if embedding: + return embedding.get_embedding() + + embedding_results = self._embeddings.embed_query(text) + + try: + embedding = Embedding(hash=hash) + embedding.set_embedding(embedding_results) + db.session.add(embedding) + db.session.commit() + except IntegrityError: + db.session.rollback() + except: + logging.exception('Failed to add embedding to db') + + return embedding_results diff --git a/api/core/embedding/openai_embedding.py b/api/core/embedding/openai_embedding.py deleted file mode 100644 index d1179180f6..0000000000 --- a/api/core/embedding/openai_embedding.py +++ /dev/null @@ -1,214 +0,0 @@ -from typing import Optional, Any, List - -import openai -from llama_index.embeddings.base import BaseEmbedding -from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \ - _TEXT_MODE_MODEL_DICT -from tenacity import wait_random_exponential, retry, stop_after_attempt - -from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async - - -@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) -def get_embedding( - text: str, - engine: Optional[str] = None, - api_key: Optional[str] = None, - **kwargs -) -> List[float]: - """Get embedding. - - NOTE: Copied from OpenAI's embedding utils: - https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py - - Copied here to avoid importing unnecessary dependencies - like matplotlib, plotly, scipy, sklearn. 
- - """ - text = text.replace("\n", " ") - return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"] - - -@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) -async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[ - float]: - """Asynchronously get embedding. - - NOTE: Copied from OpenAI's embedding utils: - https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py - - Copied here to avoid importing unnecessary dependencies - like matplotlib, plotly, scipy, sklearn. - - """ - # replace newlines, which can negatively affect performance. - text = text.replace("\n", " ") - - return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][ - "embedding" - ] - - -@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) -def get_embeddings( - list_of_text: List[str], - engine: Optional[str] = None, - api_key: Optional[str] = None, - **kwargs -) -> List[List[float]]: - """Get embeddings. - - NOTE: Copied from OpenAI's embedding utils: - https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py - - Copied here to avoid importing unnecessary dependencies - like matplotlib, plotly, scipy, sklearn. - - """ - assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048." - - # replace newlines, which can negatively affect performance. - list_of_text = [text.replace("\n", " ") for text in list_of_text] - - data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data - data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input. - return [d["embedding"] for d in data] - - -@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) -async def aget_embeddings( - list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs -) -> List[List[float]]: - """Asynchronously get embeddings. - - NOTE: Copied from OpenAI's embedding utils: - https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py - - Copied here to avoid importing unnecessary dependencies - like matplotlib, plotly, scipy, sklearn. - - """ - assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048." - - # replace newlines, which can negatively affect performance. - list_of_text = [text.replace("\n", " ") for text in list_of_text] - - data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data - data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input. 
- return [d["embedding"] for d in data] - - -class OpenAIEmbedding(BaseEmbedding): - - def __init__( - self, - mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE, - model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002, - deployment_name: Optional[str] = None, - openai_api_key: Optional[str] = None, - **kwargs: Any, - ) -> None: - """Init params.""" - new_kwargs = {} - - if 'embed_batch_size' in kwargs: - new_kwargs['embed_batch_size'] = kwargs['embed_batch_size'] - - if 'tokenizer' in kwargs: - new_kwargs['tokenizer'] = kwargs['tokenizer'] - - super().__init__(**new_kwargs) - self.mode = OpenAIEmbeddingMode(mode) - self.model = OpenAIEmbeddingModelType(model) - self.deployment_name = deployment_name - self.openai_api_key = openai_api_key - self.openai_api_type = kwargs.get('openai_api_type') - self.openai_api_version = kwargs.get('openai_api_version') - self.openai_api_base = kwargs.get('openai_api_base') - - @handle_llm_exceptions - def _get_query_embedding(self, query: str) -> List[float]: - """Get query embedding.""" - if self.deployment_name is not None: - engine = self.deployment_name - else: - key = (self.mode, self.model) - if key not in _QUERY_MODE_MODEL_DICT: - raise ValueError(f"Invalid mode, model combination: {key}") - engine = _QUERY_MODE_MODEL_DICT[key] - return get_embedding(query, engine=engine, api_key=self.openai_api_key, - api_type=self.openai_api_type, api_version=self.openai_api_version, - api_base=self.openai_api_base) - - def _get_text_embedding(self, text: str) -> List[float]: - """Get text embedding.""" - if self.deployment_name is not None: - engine = self.deployment_name - else: - key = (self.mode, self.model) - if key not in _TEXT_MODE_MODEL_DICT: - raise ValueError(f"Invalid mode, model combination: {key}") - engine = _TEXT_MODE_MODEL_DICT[key] - return get_embedding(text, engine=engine, api_key=self.openai_api_key, - api_type=self.openai_api_type, api_version=self.openai_api_version, - api_base=self.openai_api_base) - - async def _aget_text_embedding(self, text: str) -> List[float]: - """Asynchronously get text embedding.""" - if self.deployment_name is not None: - engine = self.deployment_name - else: - key = (self.mode, self.model) - if key not in _TEXT_MODE_MODEL_DICT: - raise ValueError(f"Invalid mode, model combination: {key}") - engine = _TEXT_MODE_MODEL_DICT[key] - return await aget_embedding(text, engine=engine, api_key=self.openai_api_key, - api_type=self.openai_api_type, api_version=self.openai_api_version, - api_base=self.openai_api_base) - - def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: - """Get text embeddings. - - By default, this is a wrapper around _get_text_embedding. - Can be overriden for batch queries. 
- - """ - if self.openai_api_type and self.openai_api_type == 'azure': - embeddings = [] - for text in texts: - embeddings.append(self._get_text_embedding(text)) - - return embeddings - - if self.deployment_name is not None: - engine = self.deployment_name - else: - key = (self.mode, self.model) - if key not in _TEXT_MODE_MODEL_DICT: - raise ValueError(f"Invalid mode, model combination: {key}") - engine = _TEXT_MODE_MODEL_DICT[key] - embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key, - api_type=self.openai_api_type, api_version=self.openai_api_version, - api_base=self.openai_api_base) - return embeddings - - async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: - """Asynchronously get text embeddings.""" - if self.openai_api_type and self.openai_api_type == 'azure': - embeddings = [] - for text in texts: - embeddings.append(await self._aget_text_embedding(text)) - - return embeddings - - if self.deployment_name is not None: - engine = self.deployment_name - else: - key = (self.mode, self.model) - if key not in _TEXT_MODE_MODEL_DICT: - raise ValueError(f"Invalid mode, model combination: {key}") - engine = _TEXT_MODE_MODEL_DICT[key] - embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key, - api_type=self.openai_api_type, api_version=self.openai_api_version, - api_base=self.openai_api_base) - return embeddings diff --git a/api/core/index/base.py b/api/core/index/base.py new file mode 100644 index 0000000000..a8755f9182 --- /dev/null +++ b/api/core/index/base.py @@ -0,0 +1,59 @@ +from __future__ import annotations +from abc import abstractmethod, ABC +from typing import List, Any + +from langchain.schema import Document, BaseRetriever + +from models.dataset import Dataset + + +class BaseIndex(ABC): + + def __init__(self, dataset: Dataset): + self.dataset = dataset + + @abstractmethod + def create(self, texts: list[Document], **kwargs) -> BaseIndex: + raise NotImplementedError + + @abstractmethod + def add_texts(self, texts: list[Document], **kwargs): + raise NotImplementedError + + @abstractmethod + def text_exists(self, id: str) -> bool: + raise NotImplementedError + + @abstractmethod + def delete_by_ids(self, ids: list[str]) -> None: + raise NotImplementedError + + @abstractmethod + def delete_by_document_id(self, document_id: str): + raise NotImplementedError + + @abstractmethod + def get_retriever(self, **kwargs: Any) -> BaseRetriever: + raise NotImplementedError + + @abstractmethod + def search( + self, query: str, + **kwargs: Any + ) -> List[Document]: + raise NotImplementedError + + def delete(self) -> None: + raise NotImplementedError + + def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: + for text in texts: + doc_id = text.metadata['doc_id'] + exists_duplicate_node = self.text_exists(doc_id) + if exists_duplicate_node: + texts.remove(text) + + return texts + + def _get_uuids(self, texts: list[Document]) -> list[str]: + return [text.metadata['doc_id'] for text in texts] diff --git a/api/core/index/index.py b/api/core/index/index.py new file mode 100644 index 0000000000..617b763982 --- /dev/null +++ b/api/core/index/index.py @@ -0,0 +1,41 @@ +from flask import current_app +from langchain.embeddings import OpenAIEmbeddings + +from core.embedding.cached_embedding import CacheEmbedding +from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig +from core.index.vector_index.vector_index import VectorIndex +from core.llm.llm_builder import 
LLMBuilder +from models.dataset import Dataset + + +class IndexBuilder: + @classmethod + def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False): + if indexing_technique == "high_quality": + if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality': + return None + + model_credentials = LLMBuilder.get_model_credentials( + tenant_id=dataset.tenant_id, + model_provider=LLMBuilder.get_default_provider(dataset.tenant_id), + model_name='text-embedding-ada-002' + ) + + embeddings = CacheEmbedding(OpenAIEmbeddings( + **model_credentials + )) + + return VectorIndex( + dataset=dataset, + config=current_app.config, + embeddings=embeddings + ) + elif indexing_technique == "economy": + return KeywordTableIndex( + dataset=dataset, + config=KeywordTableConfig( + max_keywords_per_chunk=10 + ) + ) + else: + raise ValueError('Unknown indexing technique') \ No newline at end of file diff --git a/api/core/index/index_builder.py b/api/core/index/index_builder.py deleted file mode 100644 index 05f08075d4..0000000000 --- a/api/core/index/index_builder.py +++ /dev/null @@ -1,60 +0,0 @@ -from langchain.callbacks import CallbackManager -from llama_index import ServiceContext, PromptHelper, LLMPredictor -from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler -from core.embedding.openai_embedding import OpenAIEmbedding -from core.llm.llm_builder import LLMBuilder - - -class IndexBuilder: - @classmethod - def get_default_service_context(cls, tenant_id: str) -> ServiceContext: - # set number of output tokens - num_output = 512 - - # only for verbose - callback_manager = CallbackManager([DifyStdOutCallbackHandler()]) - - llm = LLMBuilder.to_llm( - tenant_id=tenant_id, - model_name='text-davinci-003', - temperature=0, - max_tokens=num_output, - callback_manager=callback_manager, - ) - - llm_predictor = LLMPredictor(llm=llm) - - # These parameters here will affect the logic of segmenting the final synthesized response. - # The number of refinement iterations in the synthesis process depends - # on whether the length of the segmented output exceeds the max_input_size. 
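A hedged sketch of how the new IndexBuilder is expected to be called; `dataset`, `documents` and `query` are placeholders, and a Flask app context is assumed because the vector index reads current_app.config:

    from core.index.index import IndexBuilder

    # returns None when a 'high_quality' index is requested for a dataset that is
    # not configured for it (unless ignore_high_quality_check=True is passed)
    index = IndexBuilder.get_index(dataset, 'high_quality')
    if index:
        index.add_texts(documents)   # langchain Documents with 'doc_id' metadata, per BaseIndex
        hits = index.search(query)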
- prompt_helper = PromptHelper( - max_input_size=3500, - num_output=num_output, - max_chunk_overlap=20 - ) - - provider = LLMBuilder.get_default_provider(tenant_id) - - model_credentials = LLMBuilder.get_model_credentials( - tenant_id=tenant_id, - model_provider=provider, - model_name='text-embedding-ada-002' - ) - - return ServiceContext.from_defaults( - llm_predictor=llm_predictor, - prompt_helper=prompt_helper, - embed_model=OpenAIEmbedding(**model_credentials), - ) - - @classmethod - def get_fake_llm_service_context(cls, tenant_id: str) -> ServiceContext: - llm = LLMBuilder.to_llm( - tenant_id=tenant_id, - model_name='fake' - ) - - return ServiceContext.from_defaults( - llm_predictor=LLMPredictor(llm=llm), - embed_model=OpenAIEmbedding() - ) diff --git a/api/core/index/keyword_table/jieba_keyword_table.py b/api/core/index/keyword_table/jieba_keyword_table.py deleted file mode 100644 index 89dcca5802..0000000000 --- a/api/core/index/keyword_table/jieba_keyword_table.py +++ /dev/null @@ -1,159 +0,0 @@ -import re -from typing import ( - Any, - Dict, - List, - Set, - Optional -) - -import jieba.analyse - -from core.index.keyword_table.stopwords import STOPWORDS -from llama_index.indices.query.base import IS -from llama_index import QueryMode -from llama_index.indices.base import QueryMap -from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex -from llama_index.indices.keyword_table.query import BaseGPTKeywordTableQuery -from llama_index.docstore import BaseDocumentStore -from llama_index.indices.postprocessor.node import ( - BaseNodePostprocessor, -) -from llama_index.indices.response.response_builder import ResponseMode -from llama_index.indices.service_context import ServiceContext -from llama_index.optimization.optimizer import BaseTokenUsageOptimizer -from llama_index.prompts.prompts import ( - QuestionAnswerPrompt, - RefinePrompt, - SimpleInputPrompt, -) - -from core.index.query.synthesizer import EnhanceResponseSynthesizer - - -def jieba_extract_keywords( - text_chunk: str, - max_keywords: Optional[int] = None, - expand_with_subtokens: bool = True, -) -> Set[str]: - """Extract keywords with JIEBA tfidf.""" - keywords = jieba.analyse.extract_tags( - sentence=text_chunk, - topK=max_keywords, - ) - - if expand_with_subtokens: - return set(expand_tokens_with_subtokens(keywords)) - else: - return set(keywords) - - -def expand_tokens_with_subtokens(tokens: Set[str]) -> Set[str]: - """Get subtokens from a list of tokens., filtering for stopwords.""" - results = set() - for token in tokens: - results.add(token) - sub_tokens = re.findall(r"\w+", token) - if len(sub_tokens) > 1: - results.update({w for w in sub_tokens if w not in list(STOPWORDS)}) - - return results - - -class GPTJIEBAKeywordTableIndex(BaseGPTKeywordTableIndex): - """GPT JIEBA Keyword Table Index. - - This index uses a JIEBA keyword extractor to extract keywords from the text. 
- - """ - - def _extract_keywords(self, text: str) -> Set[str]: - """Extract keywords from text.""" - return jieba_extract_keywords(text, max_keywords=self.max_keywords_per_chunk) - - @classmethod - def get_query_map(self) -> QueryMap: - """Get query map.""" - super_map = super().get_query_map() - super_map[QueryMode.DEFAULT] = GPTKeywordTableJIEBAQuery - return super_map - - def _delete(self, doc_id: str, **delete_kwargs: Any) -> None: - """Delete a document.""" - # get set of ids that correspond to node - node_idxs_to_delete = {doc_id} - - # delete node_idxs from keyword to node idxs mapping - keywords_to_delete = set() - for keyword, node_idxs in self._index_struct.table.items(): - if node_idxs_to_delete.intersection(node_idxs): - self._index_struct.table[keyword] = node_idxs.difference( - node_idxs_to_delete - ) - if not self._index_struct.table[keyword]: - keywords_to_delete.add(keyword) - - for keyword in keywords_to_delete: - del self._index_struct.table[keyword] - - -class GPTKeywordTableJIEBAQuery(BaseGPTKeywordTableQuery): - """GPT Keyword Table Index JIEBA Query. - - Extracts keywords using JIEBA keyword extractor. - Set when `mode="jieba"` in `query` method of `GPTKeywordTableIndex`. - - .. code-block:: python - - response = index.query("", mode="jieba") - - See BaseGPTKeywordTableQuery for arguments. - - """ - - @classmethod - def from_args( - cls, - index_struct: IS, - service_context: ServiceContext, - docstore: Optional[BaseDocumentStore] = None, - node_postprocessors: Optional[List[BaseNodePostprocessor]] = None, - verbose: bool = False, - # response synthesizer args - response_mode: ResponseMode = ResponseMode.DEFAULT, - text_qa_template: Optional[QuestionAnswerPrompt] = None, - refine_template: Optional[RefinePrompt] = None, - simple_template: Optional[SimpleInputPrompt] = None, - response_kwargs: Optional[Dict] = None, - use_async: bool = False, - streaming: bool = False, - optimizer: Optional[BaseTokenUsageOptimizer] = None, - # class-specific args - **kwargs: Any, - ) -> "BaseGPTIndexQuery": - response_synthesizer = EnhanceResponseSynthesizer.from_args( - service_context=service_context, - text_qa_template=text_qa_template, - refine_template=refine_template, - simple_template=simple_template, - response_mode=response_mode, - response_kwargs=response_kwargs, - use_async=use_async, - streaming=streaming, - optimizer=optimizer, - ) - return cls( - index_struct=index_struct, - service_context=service_context, - response_synthesizer=response_synthesizer, - docstore=docstore, - node_postprocessors=node_postprocessors, - verbose=verbose, - **kwargs, - ) - - def _get_keywords(self, query_str: str) -> List[str]: - """Extract keywords.""" - return list( - jieba_extract_keywords(query_str, max_keywords=self.max_keywords_per_query) - ) diff --git a/api/core/index/keyword_table_index.py b/api/core/index/keyword_table_index.py deleted file mode 100644 index f0b3905557..0000000000 --- a/api/core/index/keyword_table_index.py +++ /dev/null @@ -1,135 +0,0 @@ -import json -from typing import List, Optional - -from llama_index import ServiceContext, LLMPredictor, OpenAIEmbedding -from llama_index.data_structs import KeywordTable, Node -from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex -from llama_index.indices.registry import load_index_struct_from_dict - -from core.docstore.dataset_docstore import DatesetDocumentStore -from core.docstore.empty_docstore import EmptyDocumentStore -from core.index.index_builder import IndexBuilder -from 
core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex -from core.llm.llm_builder import LLMBuilder -from extensions.ext_database import db -from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment - - -class KeywordTableIndex: - - def __init__(self, dataset: Dataset): - self._dataset = dataset - - def add_nodes(self, nodes: List[Node]): - llm = LLMBuilder.to_llm( - tenant_id=self._dataset.tenant_id, - model_name='fake' - ) - - service_context = ServiceContext.from_defaults( - llm_predictor=LLMPredictor(llm=llm), - embed_model=OpenAIEmbedding() - ) - - dataset_keyword_table = self.get_keyword_table() - if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict: - index_struct = KeywordTable() - else: - index_struct_dict = dataset_keyword_table.keyword_table_dict - index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict) - - # create index - index = GPTJIEBAKeywordTableIndex( - index_struct=index_struct, - docstore=EmptyDocumentStore(), - service_context=service_context - ) - - for node in nodes: - keywords = index._extract_keywords(node.get_text()) - self.update_segment_keywords(node.doc_id, list(keywords)) - index._index_struct.add_node(list(keywords), node) - - index_struct_dict = index.index_struct.to_dict() - - if not dataset_keyword_table: - dataset_keyword_table = DatasetKeywordTable( - dataset_id=self._dataset.id, - keyword_table=json.dumps(index_struct_dict) - ) - db.session.add(dataset_keyword_table) - else: - dataset_keyword_table.keyword_table = json.dumps(index_struct_dict) - - db.session.commit() - - def del_nodes(self, node_ids: List[str]): - llm = LLMBuilder.to_llm( - tenant_id=self._dataset.tenant_id, - model_name='fake' - ) - - service_context = ServiceContext.from_defaults( - llm_predictor=LLMPredictor(llm=llm), - embed_model=OpenAIEmbedding() - ) - - dataset_keyword_table = self.get_keyword_table() - if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict: - return - else: - index_struct_dict = dataset_keyword_table.keyword_table_dict - index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict) - - # create index - index = GPTJIEBAKeywordTableIndex( - index_struct=index_struct, - docstore=EmptyDocumentStore(), - service_context=service_context - ) - - for node_id in node_ids: - index.delete(node_id) - - index_struct_dict = index.index_struct.to_dict() - - if not dataset_keyword_table: - dataset_keyword_table = DatasetKeywordTable( - dataset_id=self._dataset.id, - keyword_table=json.dumps(index_struct_dict) - ) - db.session.add(dataset_keyword_table) - else: - dataset_keyword_table.keyword_table = json.dumps(index_struct_dict) - - db.session.commit() - - @property - def query_index(self) -> Optional[BaseGPTKeywordTableIndex]: - docstore = DatesetDocumentStore( - dataset=self._dataset, - user_id=self._dataset.created_by, - embedding_model_name="text-embedding-ada-002" - ) - - service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id) - - dataset_keyword_table = self.get_keyword_table() - if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict: - return None - - index_struct: KeywordTable = load_index_struct_from_dict(dataset_keyword_table.keyword_table_dict) - - return GPTJIEBAKeywordTableIndex(index_struct=index_struct, docstore=docstore, service_context=service_context) - - def get_keyword_table(self): - dataset_keyword_table = self._dataset.dataset_keyword_table - if dataset_keyword_table: - return 
dataset_keyword_table - return None - - def update_segment_keywords(self, node_id: str, keywords: List[str]): - document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first() - if document_segment: - document_segment.keywords = keywords - db.session.commit() diff --git a/api/core/index/keyword_table_index/jieba_keyword_table_handler.py b/api/core/index/keyword_table_index/jieba_keyword_table_handler.py new file mode 100644 index 0000000000..db9fd027a0 --- /dev/null +++ b/api/core/index/keyword_table_index/jieba_keyword_table_handler.py @@ -0,0 +1,33 @@ +import re +from typing import Set + +import jieba +from jieba.analyse import default_tfidf + +from core.index.keyword_table_index.stopwords import STOPWORDS + + +class JiebaKeywordTableHandler: + + def __init__(self): + default_tfidf.stop_words = STOPWORDS + + def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]: + """Extract keywords with JIEBA tfidf.""" + keywords = jieba.analyse.extract_tags( + sentence=text, + topK=max_keywords_per_chunk, + ) + + return set(self._expand_tokens_with_subtokens(keywords)) + + def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]: + """Get subtokens from a list of tokens., filtering for stopwords.""" + results = set() + for token in tokens: + results.add(token) + sub_tokens = re.findall(r"\w+", token) + if len(sub_tokens) > 1: + results.update({w for w in sub_tokens if w not in list(STOPWORDS)}) + + return results \ No newline at end of file diff --git a/api/core/index/keyword_table_index/keyword_table_index.py b/api/core/index/keyword_table_index/keyword_table_index.py new file mode 100644 index 0000000000..1a205cd572 --- /dev/null +++ b/api/core/index/keyword_table_index/keyword_table_index.py @@ -0,0 +1,238 @@ +import json +from collections import defaultdict +from typing import Any, List, Optional, Dict + +from langchain.schema import Document, BaseRetriever +from pydantic import BaseModel, Field, Extra + +from core.index.base import BaseIndex +from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler +from extensions.ext_database import db +from models.dataset import Dataset, DocumentSegment, DatasetKeywordTable + + +class KeywordTableConfig(BaseModel): + max_keywords_per_chunk: int = 10 + + +class KeywordTableIndex(BaseIndex): + def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()): + super().__init__(dataset) + self._config = config + + def create(self, texts: list[Document], **kwargs) -> BaseIndex: + keyword_table_handler = JiebaKeywordTableHandler() + keyword_table = {} + for text in texts: + keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) + self._update_segment_keywords(text.metadata['doc_id'], list(keywords)) + keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) + + dataset_keyword_table = DatasetKeywordTable( + dataset_id=self.dataset.id, + keyword_table=json.dumps({ + '__type__': 'keyword_table', + '__data__': { + "index_id": self.dataset.id, + "summary": None, + "table": {} + } + }, cls=SetEncoder) + ) + db.session.add(dataset_keyword_table) + db.session.commit() + + self._save_dataset_keyword_table(keyword_table) + + return self + + def add_texts(self, texts: list[Document], **kwargs): + keyword_table_handler = JiebaKeywordTableHandler() + + keyword_table = self._get_dataset_keyword_table() + for text in texts: + keywords 
= keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) + self._update_segment_keywords(text.metadata['doc_id'], list(keywords)) + keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) + + self._save_dataset_keyword_table(keyword_table) + + def text_exists(self, id: str) -> bool: + keyword_table = self._get_dataset_keyword_table() + return id in set.union(*keyword_table.values()) + + def delete_by_ids(self, ids: list[str]) -> None: + keyword_table = self._get_dataset_keyword_table() + keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) + + self._save_dataset_keyword_table(keyword_table) + + def delete_by_document_id(self, document_id: str): + # get segment ids by document_id + segments = db.session.query(DocumentSegment).filter( + DocumentSegment.dataset_id == self.dataset.id, + DocumentSegment.document_id == document_id + ).all() + + ids = [segment.id for segment in segments] + + keyword_table = self._get_dataset_keyword_table() + keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) + + self._save_dataset_keyword_table(keyword_table) + + def get_retriever(self, **kwargs: Any) -> BaseRetriever: + return KeywordTableRetriever(index=self, **kwargs) + + def search( + self, query: str, + **kwargs: Any + ) -> List[Document]: + keyword_table = self._get_dataset_keyword_table() + + search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {} + k = search_kwargs.get('k') if search_kwargs.get('k') else 4 + + sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k) + + documents = [] + for chunk_index in sorted_chunk_indices: + segment = db.session.query(DocumentSegment).filter( + DocumentSegment.dataset_id == self.dataset.id, + DocumentSegment.index_node_id == chunk_index + ).first() + + if segment: + documents.append(Document( + page_content=segment.content, + metadata={ + "doc_id": chunk_index, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + } + )) + + return documents + + def delete(self) -> None: + dataset_keyword_table = self.dataset.dataset_keyword_table + if dataset_keyword_table: + db.session.delete(dataset_keyword_table) + db.session.commit() + + def _save_dataset_keyword_table(self, keyword_table): + keyword_table_dict = { + '__type__': 'keyword_table', + '__data__': { + "index_id": self.dataset.id, + "summary": None, + "table": keyword_table + } + } + self.dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder) + db.session.commit() + + def _get_dataset_keyword_table(self) -> Optional[dict]: + dataset_keyword_table = self.dataset.dataset_keyword_table + if dataset_keyword_table: + if dataset_keyword_table.keyword_table_dict: + return dataset_keyword_table.keyword_table_dict['__data__']['table'] + else: + dataset_keyword_table = DatasetKeywordTable( + dataset_id=self.dataset.id, + keyword_table=json.dumps({ + '__type__': 'keyword_table', + '__data__': { + "index_id": self.dataset.id, + "summary": None, + "table": {} + } + }, cls=SetEncoder) + ) + db.session.add(dataset_keyword_table) + db.session.commit() + + return {} + + def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict: + for keyword in keywords: + if keyword not in keyword_table: + keyword_table[keyword] = set() + keyword_table[keyword].add(id) + return keyword_table + + def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> dict: + 
# get set of ids that correspond to node + node_idxs_to_delete = set(ids) + + # delete node_idxs from keyword to node idxs mapping + keywords_to_delete = set() + for keyword, node_idxs in keyword_table.items(): + if node_idxs_to_delete.intersection(node_idxs): + keyword_table[keyword] = node_idxs.difference( + node_idxs_to_delete + ) + if not keyword_table[keyword]: + keywords_to_delete.add(keyword) + + for keyword in keywords_to_delete: + del keyword_table[keyword] + + return keyword_table + + def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4): + keyword_table_handler = JiebaKeywordTableHandler() + keywords = keyword_table_handler.extract_keywords(query) + + # go through text chunks in order of most matching keywords + chunk_indices_count: Dict[str, int] = defaultdict(int) + keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())] + for keyword in keywords: + for node_id in keyword_table[keyword]: + chunk_indices_count[node_id] += 1 + + sorted_chunk_indices = sorted( + list(chunk_indices_count.keys()), + key=lambda x: chunk_indices_count[x], + reverse=True, + ) + + return sorted_chunk_indices[: k] + + def _update_segment_keywords(self, node_id: str, keywords: List[str]): + document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first() + if document_segment: + document_segment.keywords = keywords + db.session.commit() + + +class KeywordTableRetriever(BaseRetriever, BaseModel): + index: KeywordTableIndex + search_kwargs: dict = Field(default_factory=dict) + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + def get_relevant_documents(self, query: str) -> List[Document]: + """Get documents relevant for a query. 
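A hedged sketch of the economy-mode keyword index defined above; `dataset` is assumed to be an existing Dataset row with a database session available, and the Documents carry the 'doc_id' metadata described earlier:

    from langchain.schema import Document
    from core.index.keyword_table_index.keyword_table_index import (
        KeywordTableIndex, KeywordTableConfig
    )

    index = KeywordTableIndex(dataset, config=KeywordTableConfig(max_keywords_per_chunk=10))
    index.create([
        Document(page_content='segment text', metadata={'doc_id': 'node-uuid'})
    ])
    docs = index.search('query text', search_kwargs={'k': 4})
    retriever = index.get_retriever(search_kwargs={'k': 4})   # langchain BaseRetriever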
+ + Args: + query: string to find relevant documents for + + Returns: + List of relevant documents + """ + return self.index.search(query, **self.search_kwargs) + + async def aget_relevant_documents(self, query: str) -> List[Document]: + raise NotImplementedError("KeywordTableRetriever does not support async") + + +class SetEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, set): + return list(obj) + return super().default(obj) diff --git a/api/core/index/keyword_table/stopwords.py b/api/core/index/keyword_table_index/stopwords.py similarity index 100% rename from api/core/index/keyword_table/stopwords.py rename to api/core/index/keyword_table_index/stopwords.py diff --git a/api/core/index/query/synthesizer.py b/api/core/index/query/synthesizer.py deleted file mode 100644 index 7ab8b4a8ca..0000000000 --- a/api/core/index/query/synthesizer.py +++ /dev/null @@ -1,79 +0,0 @@ -from typing import ( - Any, - Dict, - Optional, Sequence, -) - -from llama_index.indices.response.response_synthesis import ResponseSynthesizer -from llama_index.indices.response.response_builder import ResponseMode, BaseResponseBuilder, get_response_builder -from llama_index.indices.service_context import ServiceContext -from llama_index.optimization.optimizer import BaseTokenUsageOptimizer -from llama_index.prompts.prompts import ( - QuestionAnswerPrompt, - RefinePrompt, - SimpleInputPrompt, -) -from llama_index.types import RESPONSE_TEXT_TYPE - - -class EnhanceResponseSynthesizer(ResponseSynthesizer): - @classmethod - def from_args( - cls, - service_context: ServiceContext, - streaming: bool = False, - use_async: bool = False, - text_qa_template: Optional[QuestionAnswerPrompt] = None, - refine_template: Optional[RefinePrompt] = None, - simple_template: Optional[SimpleInputPrompt] = None, - response_mode: ResponseMode = ResponseMode.DEFAULT, - response_kwargs: Optional[Dict] = None, - optimizer: Optional[BaseTokenUsageOptimizer] = None, - ) -> "ResponseSynthesizer": - response_builder: Optional[BaseResponseBuilder] = None - if response_mode != ResponseMode.NO_TEXT: - if response_mode == 'no_synthesizer': - response_builder = NoSynthesizer( - service_context=service_context, - simple_template=simple_template, - streaming=streaming, - ) - else: - response_builder = get_response_builder( - service_context, - text_qa_template, - refine_template, - simple_template, - response_mode, - use_async=use_async, - streaming=streaming, - ) - return cls(response_builder, response_mode, response_kwargs, optimizer) - - -class NoSynthesizer(BaseResponseBuilder): - def __init__( - self, - service_context: ServiceContext, - simple_template: Optional[SimpleInputPrompt] = None, - streaming: bool = False, - ) -> None: - super().__init__(service_context, streaming) - - async def aget_response( - self, - query_str: str, - text_chunks: Sequence[str], - prev_response: Optional[str] = None, - **response_kwargs: Any, - ) -> RESPONSE_TEXT_TYPE: - return "\n".join(text_chunks) - - def get_response( - self, - query_str: str, - text_chunks: Sequence[str], - prev_response: Optional[str] = None, - **response_kwargs: Any, - ) -> RESPONSE_TEXT_TYPE: - return "\n".join(text_chunks) \ No newline at end of file diff --git a/api/core/index/readers/html_parser.py b/api/core/index/readers/html_parser.py deleted file mode 100644 index 2afadb284c..0000000000 --- a/api/core/index/readers/html_parser.py +++ /dev/null @@ -1,22 +0,0 @@ -from pathlib import Path -from typing import Dict - -from bs4 import BeautifulSoup -from 
llama_index.readers.file.base_parser import BaseParser - - -class HTMLParser(BaseParser): - """HTML parser.""" - - def _init_parser(self) -> Dict: - """Init parser.""" - return {} - - def parse_file(self, file: Path, errors: str = "ignore") -> str: - """Parse file.""" - with open(file, "rb") as fp: - soup = BeautifulSoup(fp, 'html.parser') - text = soup.get_text() - text = text.strip() if text else '' - - return text diff --git a/api/core/index/readers/markdown_parser.py b/api/core/index/readers/markdown_parser.py deleted file mode 100644 index e12c06a78c..0000000000 --- a/api/core/index/readers/markdown_parser.py +++ /dev/null @@ -1,111 +0,0 @@ -"""Markdown parser. - -Contains parser for md files. - -""" -import re -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union, cast - -from llama_index.readers.file.base_parser import BaseParser - - -class MarkdownParser(BaseParser): - """Markdown parser. - - Extract text from markdown files. - Returns dictionary with keys as headers and values as the text between headers. - - """ - - def __init__( - self, - *args: Any, - remove_hyperlinks: bool = True, - remove_images: bool = True, - **kwargs: Any, - ) -> None: - """Init params.""" - super().__init__(*args, **kwargs) - self._remove_hyperlinks = remove_hyperlinks - self._remove_images = remove_images - - def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]: - """Convert a markdown file to a dictionary. - - The keys are the headers and the values are the text under each header. - - """ - markdown_tups: List[Tuple[Optional[str], str]] = [] - lines = markdown_text.split("\n") - - current_header = None - current_text = "" - - for line in lines: - header_match = re.match(r"^#+\s", line) - if header_match: - if current_header is not None: - markdown_tups.append((current_header, current_text)) - - current_header = line - current_text = "" - else: - current_text += line + "\n" - markdown_tups.append((current_header, current_text)) - - if current_header is not None: - # pass linting, assert keys are defined - markdown_tups = [ - (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) - for key, value in markdown_tups - ] - else: - markdown_tups = [ - (key, re.sub("\n", "", value)) for key, value in markdown_tups - ] - - return markdown_tups - - def remove_images(self, content: str) -> str: - """Get a dictionary of a markdown file from its path.""" - pattern = r"!{1}\[\[(.*)\]\]" - content = re.sub(pattern, "", content) - return content - - def remove_hyperlinks(self, content: str) -> str: - """Get a dictionary of a markdown file from its path.""" - pattern = r"\[(.*?)\]\((.*?)\)" - content = re.sub(pattern, r"\1", content) - return content - - def _init_parser(self) -> Dict: - """Initialize the parser with the config.""" - return {} - - def parse_tups( - self, filepath: Path, errors: str = "ignore" - ) -> List[Tuple[Optional[str], str]]: - """Parse file into tuples.""" - with open(filepath, "r", encoding="utf-8") as f: - content = f.read() - if self._remove_hyperlinks: - content = self.remove_hyperlinks(content) - if self._remove_images: - content = self.remove_images(content) - markdown_tups = self.markdown_to_tups(content) - return markdown_tups - - def parse_file( - self, filepath: Path, errors: str = "ignore" - ) -> Union[str, List[str]]: - """Parse file into string.""" - tups = self.parse_tups(filepath, errors=errors) - results = [] - # TODO: don't include headers right now - for header, value in tups: - if header is None: - 
results.append(value) - else: - results.append(f"\n\n{header}\n{value}") - return results diff --git a/api/core/index/readers/pdf_parser.py b/api/core/index/readers/pdf_parser.py deleted file mode 100644 index 81c4840c60..0000000000 --- a/api/core/index/readers/pdf_parser.py +++ /dev/null @@ -1,56 +0,0 @@ -from pathlib import Path -from typing import Dict - -from flask import current_app -from llama_index.readers.file.base_parser import BaseParser -from pypdf import PdfReader - -from extensions.ext_storage import storage -from models.model import UploadFile - - -class PDFParser(BaseParser): - """PDF parser.""" - - def _init_parser(self) -> Dict: - """Init parser.""" - return {} - - def parse_file(self, file: Path, errors: str = "ignore") -> str: - """Parse file.""" - if not current_app.config.get('PDF_PREVIEW', True): - return '' - - plaintext_file_key = '' - plaintext_file_exists = False - if self._parser_config and 'upload_file' in self._parser_config and self._parser_config['upload_file']: - upload_file: UploadFile = self._parser_config['upload_file'] - if upload_file.hash: - plaintext_file_key = 'upload_files/' + upload_file.tenant_id + '/' + upload_file.hash + '.plaintext' - try: - text = storage.load(plaintext_file_key).decode('utf-8') - plaintext_file_exists = True - return text - except FileNotFoundError: - pass - - text_list = [] - with open(file, "rb") as fp: - # Create a PDF object - pdf = PdfReader(fp) - - # Get the number of pages in the PDF document - num_pages = len(pdf.pages) - - # Iterate over every page - for page in range(num_pages): - # Extract the text from the page - page_text = pdf.pages[page].extract_text() - text_list.append(page_text) - text = "\n".join(text_list) - - # save plaintext file for caching - if not plaintext_file_exists and plaintext_file_key: - storage.save(plaintext_file_key, text.encode('utf-8')) - - return text diff --git a/api/core/index/readers/xlsx_parser.py b/api/core/index/readers/xlsx_parser.py deleted file mode 100644 index 6b38a1bf16..0000000000 --- a/api/core/index/readers/xlsx_parser.py +++ /dev/null @@ -1,33 +0,0 @@ -from pathlib import Path -import json -from typing import Dict -from openpyxl import load_workbook - -from llama_index.readers.file.base_parser import BaseParser -from flask import current_app - - -class XLSXParser(BaseParser): - """XLSX parser.""" - - def _init_parser(self) -> Dict: - """Init parser""" - return {} - - def parse_file(self, file: Path, errors: str = "ignore") -> str: - data = [] - keys = [] - with open(file, "r") as fp: - wb = load_workbook(filename=file, read_only=True) - # loop over all sheets - for sheet in wb: - for row in sheet.iter_rows(values_only=True): - if all(v is None for v in row): - continue - if keys == []: - keys = list(map(str, row)) - else: - row_dict = dict(zip(keys, row)) - row_dict = {k: v for k, v in row_dict.items() if v} - data.append(json.dumps(row_dict, ensure_ascii=False)) - return '\n\n'.join(data) diff --git a/api/core/index/vector_index.py b/api/core/index/vector_index.py deleted file mode 100644 index fa1c93cc06..0000000000 --- a/api/core/index/vector_index.py +++ /dev/null @@ -1,136 +0,0 @@ -import json -import logging -from typing import List, Optional - -from llama_index.data_structs import Node -from requests import ReadTimeout -from sqlalchemy.exc import IntegrityError -from tenacity import retry, stop_after_attempt, retry_if_exception_type - -from core.index.index_builder import IndexBuilder -from core.vector_store.base import BaseGPTVectorStoreIndex -from 
extensions.ext_vector_store import vector_store -from extensions.ext_database import db -from models.dataset import Dataset, Embedding - - -class VectorIndex: - - def __init__(self, dataset: Dataset): - self._dataset = dataset - - def add_nodes(self, nodes: List[Node], duplicate_check: bool = False): - if not self._dataset.index_struct_dict: - index_id = "Vector_index_" + self._dataset.id.replace("-", "_") - self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id)) - db.session.commit() - - service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id) - - index = vector_store.get_index( - service_context=service_context, - index_struct=self._dataset.index_struct_dict - ) - - if duplicate_check: - nodes = self._filter_duplicate_nodes(index, nodes) - - embedding_queue_nodes = [] - embedded_nodes = [] - for node in nodes: - node_hash = node.doc_hash - - # if node hash in cached embedding tables, use cached embedding - embedding = db.session.query(Embedding).filter_by(hash=node_hash).first() - if embedding: - node.embedding = embedding.get_embedding() - embedded_nodes.append(node) - else: - embedding_queue_nodes.append(node) - - if embedding_queue_nodes: - embedding_results = index._get_node_embedding_results( - embedding_queue_nodes, - set(), - ) - - # pre embed nodes for cached embedding - for embedding_result in embedding_results: - node = embedding_result.node - node.embedding = embedding_result.embedding - - try: - embedding = Embedding(hash=node.doc_hash) - embedding.set_embedding(node.embedding) - db.session.add(embedding) - db.session.commit() - except IntegrityError: - db.session.rollback() - continue - except: - logging.exception('Failed to add embedding to db') - continue - - embedded_nodes.append(node) - - self.index_insert_nodes(index, embedded_nodes) - - @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3)) - def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]): - index.insert_nodes(nodes) - - def del_nodes(self, node_ids: List[str]): - if not self._dataset.index_struct_dict: - return - - service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id) - - index = vector_store.get_index( - service_context=service_context, - index_struct=self._dataset.index_struct_dict - ) - - for node_id in node_ids: - self.index_delete_node(index, node_id) - - @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3)) - def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str): - index.delete_node(node_id) - - def del_doc(self, doc_id: str): - if not self._dataset.index_struct_dict: - return - - service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id) - - index = vector_store.get_index( - service_context=service_context, - index_struct=self._dataset.index_struct_dict - ) - - self.index_delete_doc(index, doc_id) - - @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3)) - def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str): - index.delete(doc_id) - - @property - def query_index(self) -> Optional[BaseGPTVectorStoreIndex]: - if not self._dataset.index_struct_dict: - return None - - service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id) - - return vector_store.get_index( - service_context=service_context, - index_struct=self._dataset.index_struct_dict - ) - - def 
_filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]: - for node in nodes: - node_id = node.doc_id - exists_duplicate_node = index.exists_by_node_id(node_id) - if exists_duplicate_node: - nodes.remove(node) - - return nodes diff --git a/api/core/index/vector_index/base.py b/api/core/index/vector_index/base.py new file mode 100644 index 0000000000..b33b3e8665 --- /dev/null +++ b/api/core/index/vector_index/base.py @@ -0,0 +1,175 @@ +import json +import logging +from abc import abstractmethod +from typing import List, Any, cast + +from langchain.embeddings.base import Embeddings +from langchain.schema import Document, BaseRetriever +from langchain.vectorstores import VectorStore +from weaviate import UnexpectedStatusCodeException + +from core.index.base import BaseIndex +from extensions.ext_database import db +from models.dataset import Dataset, DocumentSegment +from models.dataset import Document as DatasetDocument + + +class BaseVectorIndex(BaseIndex): + + def __init__(self, dataset: Dataset, embeddings: Embeddings): + super().__init__(dataset) + self._embeddings = embeddings + self._vector_store = None + + def get_type(self) -> str: + raise NotImplementedError + + @abstractmethod + def get_index_name(self, dataset: Dataset) -> str: + raise NotImplementedError + + @abstractmethod + def to_index_struct(self) -> dict: + raise NotImplementedError + + @abstractmethod + def _get_vector_store(self) -> VectorStore: + raise NotImplementedError + + @abstractmethod + def _get_vector_store_class(self) -> type: + raise NotImplementedError + + def search( + self, query: str, + **kwargs: Any + ) -> List[Document]: + vector_store = self._get_vector_store() + vector_store = cast(self._get_vector_store_class(), vector_store) + + search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity' + search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {} + + if search_type == 'similarity_score_threshold': + score_threshold = search_kwargs.get("score_threshold") + if (score_threshold is None) or (not isinstance(score_threshold, float)): + search_kwargs['score_threshold'] = .0 + + docs_with_similarity = vector_store.similarity_search_with_relevance_scores( + query, **search_kwargs + ) + + docs = [] + for doc, similarity in docs_with_similarity: + doc.metadata['score'] = similarity + docs.append(doc) + + return docs + + # similarity k + # mmr k, fetch_k, lambda_mult + # similarity_score_threshold k + return vector_store.as_retriever( + search_type=search_type, + search_kwargs=search_kwargs + ).get_relevant_documents(query) + + def get_retriever(self, **kwargs: Any) -> BaseRetriever: + vector_store = self._get_vector_store() + vector_store = cast(self._get_vector_store_class(), vector_store) + + return vector_store.as_retriever(**kwargs) + + def add_texts(self, texts: list[Document], **kwargs): + if self._is_origin(): + self.recreate_dataset(self.dataset) + + vector_store = self._get_vector_store() + vector_store = cast(self._get_vector_store_class(), vector_store) + + if kwargs.get('duplicate_check', False): + texts = self._filter_duplicate_texts(texts) + + uuids = self._get_uuids(texts) + vector_store.add_documents(texts, uuids=uuids) + + def text_exists(self, id: str) -> bool: + vector_store = self._get_vector_store() + vector_store = cast(self._get_vector_store_class(), vector_store) + + return vector_store.text_exists(id) + + def delete_by_ids(self, ids: list[str]) -> None: + if self._is_origin(): + 
self.recreate_dataset(self.dataset) + return + + vector_store = self._get_vector_store() + vector_store = cast(self._get_vector_store_class(), vector_store) + + for node_id in ids: + vector_store.del_text(node_id) + + def delete(self) -> None: + vector_store = self._get_vector_store() + vector_store = cast(self._get_vector_store_class(), vector_store) + + vector_store.delete() + + def _is_origin(self): + return False + + def recreate_dataset(self, dataset: Dataset): + logging.info(f"Recreating dataset {dataset.id}") + + try: + self.delete() + except UnexpectedStatusCodeException as e: + if e.status_code != 400: + # 400 means index not exists + raise e + + dataset_documents = db.session.query(DatasetDocument).filter( + DatasetDocument.dataset_id == dataset.id, + DatasetDocument.indexing_status == 'completed', + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ).all() + + documents = [] + for dataset_document in dataset_documents: + segments = db.session.query(DocumentSegment).filter( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.status == 'completed', + DocumentSegment.enabled == True + ).all() + + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + } + ) + + documents.append(document) + + origin_index_struct = self.dataset.index_struct + self.dataset.index_struct = None + + if documents: + try: + self.create(documents) + except Exception as e: + self.dataset.index_struct = origin_index_struct + raise e + + dataset.index_struct = json.dumps(self.to_index_struct()) + + db.session.commit() + + self.dataset = dataset + logging.info(f"Dataset {dataset.id} recreate successfully.") diff --git a/api/core/index/vector_index/qdrant_vector_index.py b/api/core/index/vector_index/qdrant_vector_index.py new file mode 100644 index 0000000000..f354a86692 --- /dev/null +++ b/api/core/index/vector_index/qdrant_vector_index.py @@ -0,0 +1,116 @@ +import os +from typing import Optional, Any, List, cast + +import qdrant_client +from langchain.embeddings.base import Embeddings +from langchain.schema import Document, BaseRetriever +from langchain.vectorstores import VectorStore +from pydantic import BaseModel + +from core.index.base import BaseIndex +from core.index.vector_index.base import BaseVectorIndex +from core.vector_store.qdrant_vector_store import QdrantVectorStore +from models.dataset import Dataset + + +class QdrantConfig(BaseModel): + endpoint: str + api_key: Optional[str] + root_path: Optional[str] + + def to_qdrant_params(self): + if self.endpoint and self.endpoint.startswith('path:'): + path = self.endpoint.replace('path:', '') + if not os.path.isabs(path): + path = os.path.join(self.root_path, path) + + return { + 'path': path + } + else: + return { + 'url': self.endpoint, + 'api_key': self.api_key, + } + + +class QdrantVectorIndex(BaseVectorIndex): + def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings): + super().__init__(dataset, embeddings) + self._client_config = config + + def get_type(self) -> str: + return 'qdrant' + + def get_index_name(self, dataset: Dataset) -> str: + if self.dataset.index_struct_dict: + return self.dataset.index_struct_dict['vector_store']['collection_name'] + + dataset_id = dataset.id + return "Index_" + dataset_id.replace("-", "_") + + def to_index_struct(self) -> dict: + return { + "type": 
self.get_type(), + "vector_store": {"collection_name": self.get_index_name(self.dataset)} + } + + def create(self, texts: list[Document], **kwargs) -> BaseIndex: + uuids = self._get_uuids(texts) + self._vector_store = QdrantVectorStore.from_documents( + texts, + self._embeddings, + collection_name=self.get_index_name(self.dataset), + ids=uuids, + content_payload_key='text', + **self._client_config.to_qdrant_params() + ) + + return self + + def _get_vector_store(self) -> VectorStore: + """Only for created index.""" + if self._vector_store: + return self._vector_store + + client = qdrant_client.QdrantClient( + **self._client_config.to_qdrant_params() + ) + + return QdrantVectorStore( + client=client, + collection_name=self.get_index_name(self.dataset), + embeddings=self._embeddings, + content_payload_key='text' + ) + + def _get_vector_store_class(self) -> type: + return QdrantVectorStore + + def delete_by_document_id(self, document_id: str): + if self._is_origin(): + self.recreate_dataset(self.dataset) + return + + vector_store = self._get_vector_store() + vector_store = cast(self._get_vector_store_class(), vector_store) + + from qdrant_client.http import models + + vector_store.del_texts(models.Filter( + must=[ + models.FieldCondition( + key="metadata.document_id", + match=models.MatchValue(value=document_id), + ), + ], + )) + + def _is_origin(self): + if self.dataset.index_struct_dict: + class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name'] + if class_prefix.startswith('Vector_'): + # original class_prefix + return True + + return False diff --git a/api/core/index/vector_index/vector_index.py b/api/core/index/vector_index/vector_index.py new file mode 100644 index 0000000000..ffc7aa17b6 --- /dev/null +++ b/api/core/index/vector_index/vector_index.py @@ -0,0 +1,69 @@ +import json + +from flask import current_app +from langchain.embeddings.base import Embeddings + +from core.index.vector_index.base import BaseVectorIndex +from extensions.ext_database import db +from models.dataset import Dataset, Document + + +class VectorIndex: + def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings): + self._dataset = dataset + self._embeddings = embeddings + self._vector_index = self._init_vector_index(dataset, config, embeddings) + + def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex: + vector_type = config.get('VECTOR_STORE') + + if self._dataset.index_struct_dict: + vector_type = self._dataset.index_struct_dict['type'] + + if not vector_type: + raise ValueError(f"Vector store must be specified.") + + if vector_type == "weaviate": + from core.index.vector_index.weaviate_vector_index import WeaviateVectorIndex, WeaviateConfig + + return WeaviateVectorIndex( + dataset=dataset, + config=WeaviateConfig( + endpoint=config.get('WEAVIATE_ENDPOINT'), + api_key=config.get('WEAVIATE_API_KEY'), + batch_size=int(config.get('WEAVIATE_BATCH_SIZE')) + ), + embeddings=embeddings + ) + elif vector_type == "qdrant": + from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig + + return QdrantVectorIndex( + dataset=dataset, + config=QdrantConfig( + endpoint=config.get('QDRANT_URL'), + api_key=config.get('QDRANT_API_KEY'), + root_path=current_app.root_path + ), + embeddings=embeddings + ) + else: + raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") + + def add_texts(self, texts: list[Document], **kwargs): + if not self._dataset.index_struct_dict: + 
self._vector_index.create(texts, **kwargs) + self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct()) + db.session.commit() + return + + self._vector_index.add_texts(texts, **kwargs) + + def __getattr__(self, name): + if self._vector_index is not None: + method = getattr(self._vector_index, name) + if callable(method): + return method + + raise AttributeError(f"'VectorIndex' object has no attribute '{name}'") + diff --git a/api/core/index/vector_index/weaviate_vector_index.py b/api/core/index/vector_index/weaviate_vector_index.py new file mode 100644 index 0000000000..e9eae4468c --- /dev/null +++ b/api/core/index/vector_index/weaviate_vector_index.py @@ -0,0 +1,132 @@ +from typing import Optional, cast + +import weaviate +from langchain.embeddings.base import Embeddings +from langchain.schema import Document, BaseRetriever +from langchain.vectorstores import VectorStore +from pydantic import BaseModel, root_validator + +from core.index.base import BaseIndex +from core.index.vector_index.base import BaseVectorIndex +from core.vector_store.weaviate_vector_store import WeaviateVectorStore +from models.dataset import Dataset + + +class WeaviateConfig(BaseModel): + endpoint: str + api_key: Optional[str] + batch_size: int = 100 + + @root_validator() + def validate_config(cls, values: dict) -> dict: + if not values['endpoint']: + raise ValueError("config WEAVIATE_ENDPOINT is required") + return values + + +class WeaviateVectorIndex(BaseVectorIndex): + def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings): + super().__init__(dataset, embeddings) + self._client = self._init_client(config) + + def _init_client(self, config: WeaviateConfig) -> weaviate.Client: + auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key) + + weaviate.connect.connection.has_grpc = False + + client = weaviate.Client( + url=config.endpoint, + auth_client_secret=auth_config, + timeout_config=(5, 60), + startup_period=None + ) + + client.batch.configure( + # `batch_size` takes an `int` value to enable auto-batching + # (`None` is used for manual batching) + batch_size=config.batch_size, + # dynamically update the `batch_size` based on import speed + dynamic=True, + # `timeout_retries` takes an `int` value to retry on time outs + timeout_retries=3, + ) + + return client + + def get_type(self) -> str: + return 'weaviate' + + def get_index_name(self, dataset: Dataset) -> str: + if self.dataset.index_struct_dict: + class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] + if not class_prefix.endswith('_Node'): + # original class_prefix + class_prefix += '_Node' + + return class_prefix + + dataset_id = dataset.id + return "Vector_index_" + dataset_id.replace("-", "_") + '_Node' + + def to_index_struct(self) -> dict: + return { + "type": self.get_type(), + "vector_store": {"class_prefix": self.get_index_name(self.dataset)} + } + + def create(self, texts: list[Document], **kwargs) -> BaseIndex: + uuids = self._get_uuids(texts) + self._vector_store = WeaviateVectorStore.from_documents( + texts, + self._embeddings, + client=self._client, + index_name=self.get_index_name(self.dataset), + uuids=uuids, + by_text=False + ) + + return self + + def _get_vector_store(self) -> VectorStore: + """Only for created index.""" + if self._vector_store: + return self._vector_store + + attributes = ['doc_id', 'dataset_id', 'document_id'] + if self._is_origin(): + attributes = ['doc_id'] + + return WeaviateVectorStore( + client=self._client, + 
index_name=self.get_index_name(self.dataset), + text_key='text', + embedding=self._embeddings, + attributes=attributes, + by_text=False + ) + + def _get_vector_store_class(self) -> type: + return WeaviateVectorStore + + def delete_by_document_id(self, document_id: str): + if self._is_origin(): + self.recreate_dataset(self.dataset) + return + + vector_store = self._get_vector_store() + vector_store = cast(self._get_vector_store_class(), vector_store) + + vector_store.del_texts({ + "operator": "Equal", + "path": ["document_id"], + "valueText": document_id + }) + + def _is_origin(self): + if self.dataset.index_struct_dict: + class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] + if not class_prefix.endswith('_Node'): + # original class_prefix + return True + + return False diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 319a6bad11..f78f21a157 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -1,35 +1,34 @@ import datetime import json +import logging import re -import tempfile import time -from pathlib import Path -from typing import Optional, List +import uuid +from typing import Optional, List, cast +from flask import current_app from flask_login import current_user -from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain.embeddings import OpenAIEmbeddings +from langchain.schema import Document +from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter -from llama_index import SimpleDirectoryReader -from llama_index.data_structs import Node -from llama_index.data_structs.node_v2 import DocumentRelationship -from llama_index.node_parser import SimpleNodeParser, NodeParser -from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR -from llama_index.readers.file.markdown_parser import MarkdownParser - -from core.data_source.notion import NotionPageReader -from core.index.readers.xlsx_parser import XLSXParser +from core.data_loader.file_extractor import FileExtractor +from core.data_loader.loader.notion import NotionLoader from core.docstore.dataset_docstore import DatesetDocumentStore -from core.index.keyword_table_index import KeywordTableIndex -from core.index.readers.html_parser import HTMLParser -from core.index.readers.markdown_parser import MarkdownParser -from core.index.readers.pdf_parser import PDFParser -from core.index.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter -from core.index.vector_index import VectorIndex +from core.embedding.cached_embedding import CacheEmbedding +from core.index.index import IndexBuilder +from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig +from core.index.vector_index.vector_index import VectorIndex +from core.llm.error import ProviderTokenNotInitError +from core.llm.llm_builder import LLMBuilder +from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter from core.llm.token_calculator import TokenCalculator from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage -from models.dataset import Document, Dataset, DocumentSegment, DatasetProcessRule +from libs import helper +from models.dataset import Document as DatasetDocument +from models.dataset import Dataset, DocumentSegment, DatasetProcessRule from models.model import UploadFile from models.source import DataSourceBinding @@ -40,135 +39,171 @@ class IndexingRunner: self.storage = storage 
self.embedding_model_name = embedding_model_name - def run(self, documents: List[Document]): + def run(self, dataset_documents: List[DatasetDocument]): """Run the indexing process.""" - for document in documents: + for dataset_document in dataset_documents: + try: + # get dataset + dataset = Dataset.query.filter_by( + id=dataset_document.dataset_id + ).first() + + if not dataset: + raise ValueError("no dataset found") + + # load file + text_docs = self._load_data(dataset_document) + + # get the process rule + processing_rule = db.session.query(DatasetProcessRule). \ + filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ + first() + + # get splitter + splitter = self._get_splitter(processing_rule) + + # split to documents + documents = self._step_split( + text_docs=text_docs, + splitter=splitter, + dataset=dataset, + dataset_document=dataset_document, + processing_rule=processing_rule + ) + + # build index + self._build_index( + dataset=dataset, + dataset_document=dataset_document, + documents=documents + ) + except DocumentIsPausedException: + raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) + except ProviderTokenNotInitError as e: + dataset_document.indexing_status = 'error' + dataset_document.error = str(e.description) + dataset_document.stopped_at = datetime.datetime.utcnow() + db.session.commit() + except Exception as e: + logging.exception("consume document failed") + dataset_document.indexing_status = 'error' + dataset_document.error = str(e) + dataset_document.stopped_at = datetime.datetime.utcnow() + db.session.commit() + + def run_in_splitting_status(self, dataset_document: DatasetDocument): + """Run the indexing process when the index_status is splitting.""" + try: # get dataset dataset = Dataset.query.filter_by( - id=document.dataset_id + id=dataset_document.dataset_id ).first() if not dataset: raise ValueError("no dataset found") + # get exist document_segment list and delete + document_segments = DocumentSegment.query.filter_by( + dataset_id=dataset.id, + document_id=dataset_document.id + ).all() + + db.session.delete(document_segments) + db.session.commit() + # load file - text_docs = self._load_data(document) + text_docs = self._load_data(dataset_document) # get the process rule processing_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.id == document.dataset_process_rule_id). \ + filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). 
\ first() - # get node parser for splitting - node_parser = self._get_node_parser(processing_rule) + # get splitter + splitter = self._get_splitter(processing_rule) - # split to nodes - nodes = self._step_split( + # split to documents + documents = self._step_split( text_docs=text_docs, - node_parser=node_parser, + splitter=splitter, dataset=dataset, - document=document, + dataset_document=dataset_document, processing_rule=processing_rule ) # build index self._build_index( dataset=dataset, - document=document, - nodes=nodes + dataset_document=dataset_document, + documents=documents ) + except DocumentIsPausedException: + raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) + except ProviderTokenNotInitError as e: + dataset_document.indexing_status = 'error' + dataset_document.error = str(e.description) + dataset_document.stopped_at = datetime.datetime.utcnow() + db.session.commit() + except Exception as e: + logging.exception("consume document failed") + dataset_document.indexing_status = 'error' + dataset_document.error = str(e) + dataset_document.stopped_at = datetime.datetime.utcnow() + db.session.commit() - def run_in_splitting_status(self, document: Document): - """Run the indexing process when the index_status is splitting.""" - # get dataset - dataset = Dataset.query.filter_by( - id=document.dataset_id - ).first() - - if not dataset: - raise ValueError("no dataset found") - - # get exist document_segment list and delete - document_segments = DocumentSegment.query.filter_by( - dataset_id=dataset.id, - document_id=document.id - ).all() - db.session.delete(document_segments) - db.session.commit() - # load file - text_docs = self._load_data(document) - - # get the process rule - processing_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.id == document.dataset_process_rule_id). 
\ - first() - - # get node parser for splitting - node_parser = self._get_node_parser(processing_rule) - - # split to nodes - nodes = self._step_split( - text_docs=text_docs, - node_parser=node_parser, - dataset=dataset, - document=document, - processing_rule=processing_rule - ) - - # build index - self._build_index( - dataset=dataset, - document=document, - nodes=nodes - ) - - def run_in_indexing_status(self, document: Document): + def run_in_indexing_status(self, dataset_document: DatasetDocument): """Run the indexing process when the index_status is indexing.""" - # get dataset - dataset = Dataset.query.filter_by( - id=document.dataset_id - ).first() + try: + # get dataset + dataset = Dataset.query.filter_by( + id=dataset_document.dataset_id + ).first() - if not dataset: - raise ValueError("no dataset found") + if not dataset: + raise ValueError("no dataset found") - # get exist document_segment list and delete - document_segments = DocumentSegment.query.filter_by( - dataset_id=dataset.id, - document_id=document.id - ).all() - nodes = [] - if document_segments: - for document_segment in document_segments: - # transform segment to node - if document_segment.status != "completed": - relationships = { - DocumentRelationship.SOURCE: document_segment.document_id, - } + # get exist document_segment list and delete + document_segments = DocumentSegment.query.filter_by( + dataset_id=dataset.id, + document_id=dataset_document.id + ).all() - previous_segment = document_segment.previous_segment - if previous_segment: - relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id + documents = [] + if document_segments: + for document_segment in document_segments: + # transform segment to node + if document_segment.status != "completed": + document = Document( + page_content=document_segment.content, + metadata={ + "doc_id": document_segment.index_node_id, + "doc_hash": document_segment.index_node_hash, + "document_id": document_segment.document_id, + "dataset_id": document_segment.dataset_id, + } + ) - next_segment = document_segment.next_segment - if next_segment: - relationships[DocumentRelationship.NEXT] = next_segment.index_node_id - node = Node( - doc_id=document_segment.index_node_id, - doc_hash=document_segment.index_node_hash, - text=document_segment.content, - extra_info=None, - node_info=None, - relationships=relationships - ) - nodes.append(node) + documents.append(document) - # build index - self._build_index( - dataset=dataset, - document=document, - nodes=nodes - ) + # build index + self._build_index( + dataset=dataset, + dataset_document=dataset_document, + documents=documents + ) + except DocumentIsPausedException: + raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) + except ProviderTokenNotInitError as e: + dataset_document.indexing_status = 'error' + dataset_document.error = str(e.description) + dataset_document.stopped_at = datetime.datetime.utcnow() + db.session.commit() + except Exception as e: + logging.exception("consume document failed") + dataset_document.indexing_status = 'error' + dataset_document.error = str(e) + dataset_document.stopped_at = datetime.datetime.utcnow() + db.session.commit() def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict: """ @@ -179,28 +214,28 @@ class IndexingRunner: total_segments = 0 for file_detail in file_details: # load data from file - text_docs = self._load_data_from_file(file_detail) + text_docs = FileExtractor.load(file_detail) 
processing_rule = DatasetProcessRule( mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) ) - # get node parser for splitting - node_parser = self._get_node_parser(processing_rule) + # get splitter + splitter = self._get_splitter(processing_rule) - # split to nodes - nodes = self._split_to_nodes( + # split to documents + documents = self._split_to_documents( text_docs=text_docs, - node_parser=node_parser, + splitter=splitter, processing_rule=processing_rule ) - total_segments += len(nodes) - for node in nodes: + total_segments += len(documents) + for document in documents: if len(preview_texts) < 5: - preview_texts.append(node.get_text()) + preview_texts.append(document.page_content) - tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) + tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content) return { "total_segments": total_segments, @@ -230,35 +265,36 @@ class IndexingRunner: ).first() if not data_source_binding: raise ValueError('Data source binding not found.') - reader = NotionPageReader(integration_token=data_source_binding.access_token) + for page in notion_info['pages']: - if page['type'] == 'page': - page_ids = [page['page_id']] - documents = reader.load_data_as_documents(page_ids=page_ids) - elif page['type'] == 'database': - documents = reader.load_data_as_documents(database_id=page['page_id']) - else: - documents = [] + loader = NotionLoader( + notion_access_token=data_source_binding.access_token, + notion_workspace_id=workspace_id, + notion_obj_id=page['page_id'], + notion_page_type=page['type'] + ) + documents = loader.load() + processing_rule = DatasetProcessRule( mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) ) - # get node parser for splitting - node_parser = self._get_node_parser(processing_rule) + # get splitter + splitter = self._get_splitter(processing_rule) - # split to nodes - nodes = self._split_to_nodes( + # split to documents + documents = self._split_to_documents( text_docs=documents, - node_parser=node_parser, + splitter=splitter, processing_rule=processing_rule ) - total_segments += len(nodes) - for node in nodes: + total_segments += len(documents) + for document in documents: if len(preview_texts) < 5: - preview_texts.append(node.get_text()) + preview_texts.append(document.page_content) - tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) + tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content) return { "total_segments": total_segments, @@ -268,14 +304,14 @@ class IndexingRunner: "preview": preview_texts } - def _load_data(self, document: Document) -> List[Document]: + def _load_data(self, dataset_document: DatasetDocument) -> List[Document]: # load file - if document.data_source_type not in ["upload_file", "notion_import"]: + if dataset_document.data_source_type not in ["upload_file", "notion_import"]: return [] - data_source_info = document.data_source_info_dict + data_source_info = dataset_document.data_source_info_dict text_docs = [] - if document.data_source_type == 'upload_file': + if dataset_document.data_source_type == 'upload_file': if not data_source_info or 'upload_file_id' not in data_source_info: raise ValueError("no upload file found") @@ -283,47 +319,28 @@ class IndexingRunner: filter(UploadFile.id == data_source_info['upload_file_id']). 
\ one_or_none() - text_docs = self._load_data_from_file(file_detail) - elif document.data_source_type == 'notion_import': - if not data_source_info or 'notion_page_id' not in data_source_info \ - or 'notion_workspace_id' not in data_source_info: - raise ValueError("no notion page found") - workspace_id = data_source_info['notion_workspace_id'] - page_id = data_source_info['notion_page_id'] - page_type = data_source_info['type'] - data_source_binding = DataSourceBinding.query.filter( - db.and_( - DataSourceBinding.tenant_id == document.tenant_id, - DataSourceBinding.provider == 'notion', - DataSourceBinding.disabled == False, - DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"' - ) - ).first() - if not data_source_binding: - raise ValueError('Data source binding not found.') - if page_type == 'page': - # add page last_edited_time to data_source_info - self._get_notion_page_last_edited_time(page_id, data_source_binding.access_token, document) - text_docs = self._load_page_data_from_notion(page_id, data_source_binding.access_token) - elif page_type == 'database': - # add page last_edited_time to data_source_info - self._get_notion_database_last_edited_time(page_id, data_source_binding.access_token, document) - text_docs = self._load_database_data_from_notion(page_id, data_source_binding.access_token) + text_docs = FileExtractor.load(file_detail) + elif dataset_document.data_source_type == 'notion_import': + loader = NotionLoader.from_document(dataset_document) + text_docs = loader.load() + # update document status to splitting self._update_document_index_status( - document_id=document.id, + document_id=dataset_document.id, after_indexing_status="splitting", extra_update_params={ - Document.word_count: sum([len(text_doc.text) for text_doc in text_docs]), - Document.parsing_completed_at: datetime.datetime.utcnow() + DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]), + DatasetDocument.parsing_completed_at: datetime.datetime.utcnow() } ) # replace doc id to document model id + text_docs = cast(List[Document], text_docs) for text_doc in text_docs: # remove invalid symbol - text_doc.text = self.filter_string(text_doc.get_text()) - text_doc.doc_id = document.id + text_doc.page_content = self.filter_string(text_doc.page_content) + text_doc.metadata['document_id'] = dataset_document.id + text_doc.metadata['dataset_id'] = dataset_document.dataset_id return text_docs @@ -331,61 +348,7 @@ class IndexingRunner: pattern = re.compile('[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]') return pattern.sub('', text) - def _load_data_from_file(self, upload_file: UploadFile) -> List[Document]: - with tempfile.TemporaryDirectory() as temp_dir: - suffix = Path(upload_file.key).suffix - filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" - self.storage.download(upload_file.key, filepath) - - file_extractor = DEFAULT_FILE_EXTRACTOR.copy() - file_extractor[".markdown"] = MarkdownParser() - file_extractor[".md"] = MarkdownParser() - file_extractor[".html"] = HTMLParser() - file_extractor[".htm"] = HTMLParser() - file_extractor[".pdf"] = PDFParser({'upload_file': upload_file}) - file_extractor[".xlsx"] = XLSXParser() - - loader = SimpleDirectoryReader(input_files=[filepath], file_extractor=file_extractor) - text_docs = loader.load_data() - - return text_docs - - def _load_page_data_from_notion(self, page_id: str, access_token: str) -> List[Document]: - page_ids = [page_id] - reader = NotionPageReader(integration_token=access_token) - text_docs = 
reader.load_data_as_documents(page_ids=page_ids) - return text_docs - - def _load_database_data_from_notion(self, database_id: str, access_token: str) -> List[Document]: - reader = NotionPageReader(integration_token=access_token) - text_docs = reader.load_data_as_documents(database_id=database_id) - return text_docs - - def _get_notion_page_last_edited_time(self, page_id: str, access_token: str, document: Document): - reader = NotionPageReader(integration_token=access_token) - last_edited_time = reader.get_page_last_edited_time(page_id) - data_source_info = document.data_source_info_dict - data_source_info['last_edited_time'] = last_edited_time - update_params = { - Document.data_source_info: json.dumps(data_source_info) - } - - Document.query.filter_by(id=document.id).update(update_params) - db.session.commit() - - def _get_notion_database_last_edited_time(self, page_id: str, access_token: str, document: Document): - reader = NotionPageReader(integration_token=access_token) - last_edited_time = reader.get_database_last_edited_time(page_id) - data_source_info = document.data_source_info_dict - data_source_info['last_edited_time'] = last_edited_time - update_params = { - Document.data_source_info: json.dumps(data_source_info) - } - - Document.query.filter_by(id=document.id).update(update_params) - db.session.commit() - - def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser: + def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter: """ Get the NodeParser object according to the processing rule. """ @@ -414,68 +377,83 @@ class IndexingRunner: separators=["\n\n", "。", ".", " ", ""] ) - return SimpleNodeParser(text_splitter=character_splitter, include_extra_info=True) + return character_splitter - def _step_split(self, text_docs: List[Document], node_parser: NodeParser, - dataset: Dataset, document: Document, processing_rule: DatasetProcessRule) -> List[Node]: + def _step_split(self, text_docs: List[Document], splitter: TextSplitter, + dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \ + -> List[Document]: """ - Split the text documents into nodes and save them to the document segment. + Split the text documents into documents and save them to the document segment. 
""" - nodes = self._split_to_nodes( + documents = self._split_to_documents( text_docs=text_docs, - node_parser=node_parser, + splitter=splitter, processing_rule=processing_rule ) # save node to document segment doc_store = DatesetDocumentStore( dataset=dataset, - user_id=document.created_by, + user_id=dataset_document.created_by, embedding_model_name=self.embedding_model_name, - document_id=document.id + document_id=dataset_document.id ) + # add document segments - doc_store.add_documents(nodes) + doc_store.add_documents(documents) # update document status to indexing cur_time = datetime.datetime.utcnow() self._update_document_index_status( - document_id=document.id, + document_id=dataset_document.id, after_indexing_status="indexing", extra_update_params={ - Document.cleaning_completed_at: cur_time, - Document.splitting_completed_at: cur_time, + DatasetDocument.cleaning_completed_at: cur_time, + DatasetDocument.splitting_completed_at: cur_time, } ) # update segment status to indexing self._update_segments_by_document( - document_id=document.id, + dataset_document_id=dataset_document.id, update_params={ DocumentSegment.status: "indexing", DocumentSegment.indexing_at: datetime.datetime.utcnow() } ) - return nodes + return documents - def _split_to_nodes(self, text_docs: List[Document], node_parser: NodeParser, - processing_rule: DatasetProcessRule) -> List[Node]: + def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter, + processing_rule: DatasetProcessRule) -> List[Document]: """ Split the text documents into nodes. """ - all_nodes = [] + all_documents = [] for text_doc in text_docs: # document clean - document_text = self._document_clean(text_doc.get_text(), processing_rule) - text_doc.text = document_text + document_text = self._document_clean(text_doc.page_content, processing_rule) + text_doc.page_content = document_text # parse document to nodes - nodes = node_parser.get_nodes_from_documents([text_doc]) - nodes = [node for node in nodes if node.text is not None and node.text.strip()] - all_nodes.extend(nodes) + documents = splitter.split_documents([text_doc]) - return all_nodes + split_documents = [] + for document in documents: + if document.page_content is None or not document.page_content.strip(): + continue + + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(document.page_content) + + document.metadata['doc_id'] = doc_id + document.metadata['doc_hash'] = hash + + split_documents.append(document) + + all_documents.extend(split_documents) + + return all_documents def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str: """ @@ -506,37 +484,38 @@ class IndexingRunner: return text - def _build_index(self, dataset: Dataset, document: Document, nodes: List[Node]) -> None: + def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None: """ Build the index for the document. 
""" - vector_index = VectorIndex(dataset=dataset) - keyword_table_index = KeywordTableIndex(dataset=dataset) + vector_index = IndexBuilder.get_index(dataset, 'high_quality') + keyword_table_index = IndexBuilder.get_index(dataset, 'economy') # chunk nodes by chunk size indexing_start_at = time.perf_counter() tokens = 0 chunk_size = 100 - for i in range(0, len(nodes), chunk_size): + for i in range(0, len(documents), chunk_size): # check document is paused - self._check_document_paused_status(document.id) - chunk_nodes = nodes[i:i + chunk_size] + self._check_document_paused_status(dataset_document.id) + chunk_documents = documents[i:i + chunk_size] tokens += sum( - TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) for node in chunk_nodes + TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content) + for document in chunk_documents ) # save vector index - if dataset.indexing_technique == "high_quality": - vector_index.add_nodes(chunk_nodes) + if vector_index: + vector_index.add_texts(chunk_documents) # save keyword index - keyword_table_index.add_nodes(chunk_nodes) + keyword_table_index.add_texts(chunk_documents) - node_ids = [node.doc_id for node in chunk_nodes] + document_ids = [document.metadata['doc_id'] for document in chunk_documents] db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == document.id, - DocumentSegment.index_node_id.in_(node_ids), + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.index_node_id.in_(document_ids), DocumentSegment.status == "indexing" ).update({ DocumentSegment.status: "completed", @@ -549,12 +528,12 @@ class IndexingRunner: # update document status to completed self._update_document_index_status( - document_id=document.id, + document_id=dataset_document.id, after_indexing_status="completed", extra_update_params={ - Document.tokens: tokens, - Document.completed_at: datetime.datetime.utcnow(), - Document.indexing_latency: indexing_end_at - indexing_start_at, + DatasetDocument.tokens: tokens, + DatasetDocument.completed_at: datetime.datetime.utcnow(), + DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at, } ) @@ -569,25 +548,25 @@ class IndexingRunner: """ Update the document indexing status. """ - count = Document.query.filter_by(id=document_id, is_paused=True).count() + count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count() if count > 0: raise DocumentIsPausedException() update_params = { - Document.indexing_status: after_indexing_status + DatasetDocument.indexing_status: after_indexing_status } if extra_update_params: update_params.update(extra_update_params) - Document.query.filter_by(id=document_id).update(update_params) + DatasetDocument.query.filter_by(id=document_id).update(update_params) db.session.commit() - def _update_segments_by_document(self, document_id: str, update_params: dict) -> None: + def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None: """ Update the document segment by document id. 
""" - DocumentSegment.query.filter_by(document_id=document_id).update(update_params) + DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) db.session.commit() diff --git a/api/core/llm/llm_builder.py b/api/core/llm/llm_builder.py index 30b0a931b3..c2deda5351 100644 --- a/api/core/llm/llm_builder.py +++ b/api/core/llm/llm_builder.py @@ -1,7 +1,6 @@ -from typing import Union, Optional +from typing import Union, Optional, List -from langchain.callbacks import CallbackManager -from langchain.llms.fake import FakeListLLM +from langchain.callbacks.base import BaseCallbackHandler from core.constant import llm_constant from core.llm.error import ProviderTokenNotInitError @@ -32,12 +31,11 @@ class LLMBuilder: """ @classmethod - def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI, FakeListLLM]: - if model_name == 'fake': - return FakeListLLM(responses=[]) - + def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]: provider = cls.get_default_provider(tenant_id) + model_credentials = cls.get_model_credentials(tenant_id, provider, model_name) + mode = cls.get_mode_by_model(model_name) if mode == 'chat': if provider == 'openai': @@ -52,16 +50,21 @@ class LLMBuilder: else: raise ValueError(f"model name {model_name} is not supported.") - model_credentials = cls.get_model_credentials(tenant_id, provider, model_name) + + model_kwargs = { + 'top_p': kwargs.get('top_p', 1), + 'frequency_penalty': kwargs.get('frequency_penalty', 0), + 'presence_penalty': kwargs.get('presence_penalty', 0), + } + + model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs} return llm_cls( model_name=model_name, temperature=kwargs.get('temperature', 0), max_tokens=kwargs.get('max_tokens', 256), - top_p=kwargs.get('top_p', 1), - frequency_penalty=kwargs.get('frequency_penalty', 0), - presence_penalty=kwargs.get('presence_penalty', 0), - callback_manager=kwargs.get('callback_manager', None), + **model_extras_kwargs, + callbacks=kwargs.get('callbacks', None), streaming=kwargs.get('streaming', False), # request_timeout=None **model_credentials @@ -69,7 +72,7 @@ class LLMBuilder: @classmethod def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False, - callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]: + callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]: model_name = model.get("name") completion_params = model.get("completion_params", {}) @@ -82,7 +85,7 @@ class LLMBuilder: frequency_penalty=completion_params.get('frequency_penalty', 0.1), presence_penalty=completion_params.get('presence_penalty', 0.1), streaming=streaming, - callback_manager=callback_manager + callbacks=callbacks ) @classmethod diff --git a/api/core/llm/provider/azure_provider.py b/api/core/llm/provider/azure_provider.py index c64e785215..db831ace08 100644 --- a/api/core/llm/provider/azure_provider.py +++ b/api/core/llm/provider/azure_provider.py @@ -42,7 +42,10 @@ class AzureProvider(BaseProvider): """ config = self.get_provider_api_key(model_id=model_id) config['openai_api_type'] = 'azure' - config['deployment_name'] = model_id.replace('.', '') if model_id else None + if model_id == 'text-embedding-ada-002': + config['deployment'] = model_id.replace('.', '') if model_id else None + else: + config['deployment_name'] = model_id.replace('.', '') if model_id else None 
return config def get_provider_name(self): diff --git a/api/core/llm/streamable_azure_chat_open_ai.py b/api/core/llm/streamable_azure_chat_open_ai.py index f3d514cf58..4d1d5be0b3 100644 --- a/api/core/llm/streamable_azure_chat_open_ai.py +++ b/api/core/llm/streamable_azure_chat_open_ai.py @@ -1,3 +1,4 @@ +from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks from langchain.schema import BaseMessage, ChatResult, LLMResult from langchain.chat_models import AzureChatOpenAI from typing import Optional, List, Dict, Any @@ -68,60 +69,22 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI): return message_tokens - def _generate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None - ) -> ChatResult: - self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], - verbose=self.verbose - ) - - chat_result = super()._generate(messages, stop) - - result = LLMResult( - generations=[chat_result.generations], - llm_output=chat_result.llm_output - ) - self.callback_manager.on_llm_end(result, verbose=self.verbose) - - return chat_result - - async def _agenerate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None - ) -> ChatResult: - if self.callback_manager.is_async: - await self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], - verbose=self.verbose - ) - else: - self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], - verbose=self.verbose - ) - - chat_result = super()._generate(messages, stop) - - result = LLMResult( - generations=[chat_result.generations], - llm_output=chat_result.llm_output - ) - - if self.callback_manager.is_async: - await self.callback_manager.on_llm_end(result, verbose=self.verbose) - else: - self.callback_manager.on_llm_end(result, verbose=self.verbose) - - return chat_result - @handle_llm_exceptions def generate( - self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None + self, + messages: List[List[BaseMessage]], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs: Any, ) -> LLMResult: - return super().generate(messages, stop) + return super().generate(messages, stop, callbacks, **kwargs) @handle_llm_exceptions_async async def agenerate( - self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None + self, + messages: List[List[BaseMessage]], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs: Any, ) -> LLMResult: - return await super().agenerate(messages, stop) + return await super().agenerate(messages, stop, callbacks, **kwargs) diff --git a/api/core/llm/streamable_azure_open_ai.py b/api/core/llm/streamable_azure_open_ai.py index e383f8cf23..ac2258bb61 100644 --- a/api/core/llm/streamable_azure_open_ai.py +++ b/api/core/llm/streamable_azure_open_ai.py @@ -1,5 +1,4 @@ -import os - +from langchain.callbacks.manager import Callbacks from langchain.llms import AzureOpenAI from langchain.schema import LLMResult from typing import Optional, List, Dict, Mapping, Any @@ -53,12 +52,20 @@ class StreamableAzureOpenAI(AzureOpenAI): @handle_llm_exceptions def generate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs: Any, ) -> LLMResult: - return 
super().generate(prompts, stop) + return super().generate(prompts, stop, callbacks, **kwargs) @handle_llm_exceptions_async async def agenerate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs: Any, ) -> LLMResult: - return await super().agenerate(prompts, stop) + return await super().agenerate(prompts, stop, callbacks, **kwargs) diff --git a/api/core/llm/streamable_chat_open_ai.py b/api/core/llm/streamable_chat_open_ai.py index 582041ba09..a1fad702ab 100644 --- a/api/core/llm/streamable_chat_open_ai.py +++ b/api/core/llm/streamable_chat_open_ai.py @@ -1,6 +1,7 @@ import os -from langchain.schema import BaseMessage, ChatResult, LLMResult +from langchain.callbacks.manager import Callbacks +from langchain.schema import BaseMessage, LLMResult from langchain.chat_models import ChatOpenAI from typing import Optional, List, Dict, Any @@ -70,57 +71,22 @@ class StreamableChatOpenAI(ChatOpenAI): return message_tokens - def _generate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None - ) -> ChatResult: - self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose - ) - - chat_result = super()._generate(messages, stop) - - result = LLMResult( - generations=[chat_result.generations], - llm_output=chat_result.llm_output - ) - self.callback_manager.on_llm_end(result, verbose=self.verbose) - - return chat_result - - async def _agenerate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None - ) -> ChatResult: - if self.callback_manager.is_async: - await self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose - ) - else: - self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose - ) - - chat_result = super()._generate(messages, stop) - - result = LLMResult( - generations=[chat_result.generations], - llm_output=chat_result.llm_output - ) - - if self.callback_manager.is_async: - await self.callback_manager.on_llm_end(result, verbose=self.verbose) - else: - self.callback_manager.on_llm_end(result, verbose=self.verbose) - - return chat_result - @handle_llm_exceptions def generate( - self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None + self, + messages: List[List[BaseMessage]], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs: Any, ) -> LLMResult: - return super().generate(messages, stop) + return super().generate(messages, stop, callbacks, **kwargs) @handle_llm_exceptions_async async def agenerate( - self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None + self, + messages: List[List[BaseMessage]], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs: Any, ) -> LLMResult: - return await super().agenerate(messages, stop) + return await super().agenerate(messages, stop, callbacks, **kwargs) diff --git a/api/core/llm/streamable_open_ai.py b/api/core/llm/streamable_open_ai.py index 9cf1b4c4bb..a69e461d0d 100644 --- a/api/core/llm/streamable_open_ai.py +++ b/api/core/llm/streamable_open_ai.py @@ -1,5 +1,6 @@ import os +from langchain.callbacks.manager import Callbacks from langchain.schema import LLMResult from typing import Optional, List, Dict, Any, Mapping 
from langchain import OpenAI @@ -48,15 +49,22 @@ class StreamableOpenAI(OpenAI): "organization": self.openai_organization if self.openai_organization else None, }} - @handle_llm_exceptions def generate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs: Any, ) -> LLMResult: - return super().generate(prompts, stop) + return super().generate(prompts, stop, callbacks, **kwargs) @handle_llm_exceptions_async async def agenerate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs: Any, ) -> LLMResult: - return await super().agenerate(prompts, stop) + return await super().agenerate(prompts, stop, callbacks, **kwargs) diff --git a/api/core/memory/read_only_conversation_token_db_string_buffer_shared_memory.py b/api/core/memory/read_only_conversation_token_db_string_buffer_shared_memory.py index e5933931a2..0edd2445e2 100644 --- a/api/core/memory/read_only_conversation_token_db_string_buffer_shared_memory.py +++ b/api/core/memory/read_only_conversation_token_db_string_buffer_shared_memory.py @@ -1,7 +1,7 @@ from typing import Any, List, Dict from langchain.memory.chat_memory import BaseChatMemory -from langchain.schema import get_buffer_string, BaseMessage, BaseLanguageModel +from langchain.schema import get_buffer_string from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ ReadOnlyConversationTokenDBBufferSharedMemory diff --git a/api/core/prompt/prompts.py b/api/core/prompt/prompts.py index af17075408..330f473125 100644 --- a/api/core/prompt/prompts.py +++ b/api/core/prompt/prompts.py @@ -1,5 +1,3 @@ -from llama_index import QueryKeywordExtractPrompt - CONVERSATION_TITLE_PROMPT = ( "Human:{query}\n-----\n" "Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n" @@ -45,23 +43,6 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( "[\"question1\",\"question2\",\"question3\"]\n" ) -QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = ( - "A question is provided below. Given the question, extract up to {max_keywords} " - "keywords from the text. Focus on extracting the keywords that we can use " - "to best lookup answers to the question. Avoid stopwords." - "I am not sure which language the following question is in. " - "If the user asked the question in Chinese, please return the keywords in Chinese. " - "If the user asked the question in English, please return the keywords in English.\n" - "---------------------\n" - "{question}\n" - "---------------------\n" - "Provide keywords in the following comma-separated format: 'KEYWORDS: '\n" -) - -QUERY_KEYWORD_EXTRACT_TEMPLATE = QueryKeywordExtractPrompt( - QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL -) - RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \ the model prompt that best suits the input. You will be provided with the prompt, variables, and an opening statement. 
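A minimal caller sketch for the callbacks-based signatures introduced above (assumed usage, not part of this diff): the tenant id and model name below are placeholders, a provider must already be configured for that tenant, and StdOutCallbackHandler merely stands in for Dify's own handlers.

    from langchain.callbacks import StdOutCallbackHandler
    from core.llm.llm_builder import LLMBuilder

    # Build a completion-mode model; callbacks is now a plain list of BaseCallbackHandler
    # instances instead of a langchain CallbackManager.
    llm = LLMBuilder.to_llm(
        tenant_id='00000000-0000-0000-0000-000000000000',  # placeholder; requires a tenant with a configured provider
        model_name='text-davinci-003',                      # placeholder completion-mode model name
        streaming=False,
        callbacks=[StdOutCallbackHandler()],                # handler list replaces the old callback_manager argument
    )

    # generate() forwards stop words, callbacks and extra kwargs straight to langchain.
    result = llm.generate(['Say hello in one word.'])
    print(result.generations[0][0].text)

The same handler-list shape is what to_llm_from_model now forwards through its callbacks parameter in the llm_builder hunk above.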
diff --git a/api/core/index/spiltter/fixed_text_splitter.py b/api/core/spiltter/fixed_text_splitter.py similarity index 100% rename from api/core/index/spiltter/fixed_text_splitter.py rename to api/core/spiltter/fixed_text_splitter.py diff --git a/api/core/tool/dataset_index_tool.py b/api/core/tool/dataset_index_tool.py new file mode 100644 index 0000000000..2776c6f48a --- /dev/null +++ b/api/core/tool/dataset_index_tool.py @@ -0,0 +1,87 @@ +from flask import current_app +from langchain.embeddings import OpenAIEmbeddings +from langchain.tools import BaseTool + +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.embedding.cached_embedding import CacheEmbedding +from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig +from core.index.vector_index.vector_index import VectorIndex +from core.llm.llm_builder import LLMBuilder +from models.dataset import Dataset + + +class DatasetTool(BaseTool): + """Tool for querying a Dataset.""" + + dataset: Dataset + k: int = 2 + + def _run(self, tool_input: str) -> str: + if self.dataset.indexing_technique == "economy": + # use keyword table query + kw_table_index = KeywordTableIndex( + dataset=self.dataset, + config=KeywordTableConfig( + max_keywords_per_chunk=5 + ) + ) + + documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k}) + else: + model_credentials = LLMBuilder.get_model_credentials( + tenant_id=self.dataset.tenant_id, + model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id), + model_name='text-embedding-ada-002' + ) + + embeddings = CacheEmbedding(OpenAIEmbeddings( + **model_credentials + )) + + vector_index = VectorIndex( + dataset=self.dataset, + config=current_app.config, + embeddings=embeddings + ) + + documents = vector_index.search( + tool_input, + search_type='similarity', + search_kwargs={ + 'k': self.k + } + ) + + hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id) + hit_callback.on_tool_end(documents) + + return str("\n".join([document.page_content for document in documents])) + + async def _arun(self, tool_input: str) -> str: + model_credentials = LLMBuilder.get_model_credentials( + tenant_id=self.dataset.tenant_id, + model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id), + model_name='text-embedding-ada-002' + ) + + embeddings = CacheEmbedding(OpenAIEmbeddings( + **model_credentials + )) + + vector_index = VectorIndex( + dataset=self.dataset, + config=current_app.config, + embeddings=embeddings + ) + + documents = await vector_index.asearch( + tool_input, + search_type='similarity', + search_kwargs={ + 'k': 10 + } + ) + + hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id) + hit_callback.on_tool_end(documents) + return str("\n".join([document.page_content for document in documents])) diff --git a/api/core/tool/dataset_tool_builder.py b/api/core/tool/dataset_tool_builder.py deleted file mode 100644 index aa7a618b50..0000000000 --- a/api/core/tool/dataset_tool_builder.py +++ /dev/null @@ -1,73 +0,0 @@ -from typing import Optional - -from langchain.callbacks import CallbackManager -from llama_index.langchain_helpers.agents import IndexToolConfig - -from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler -from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler -from core.index.keyword_table_index import 
KeywordTableIndex -from core.index.vector_index import VectorIndex -from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE -from core.tool.llama_index_tool import EnhanceLlamaIndexTool -from models.dataset import Dataset - - -class DatasetToolBuilder: - @classmethod - def build_dataset_tool(cls, dataset: Dataset, - response_mode: str = "no_synthesizer", - callback_handler: Optional[DatasetToolCallbackHandler] = None): - if dataset.indexing_technique == "economy": - # use keyword table query - index = KeywordTableIndex(dataset=dataset).query_index - - if not index: - return None - - query_kwargs = { - "mode": "default", - "response_mode": response_mode, - "query_keyword_extract_template": QUERY_KEYWORD_EXTRACT_TEMPLATE, - "max_keywords_per_query": 5, - # If num_chunks_per_query is too large, - # it will slow down the synthesis process due to multiple iterations of refinement. - "num_chunks_per_query": 2 - } - else: - index = VectorIndex(dataset=dataset).query_index - - if not index: - return None - - query_kwargs = { - "mode": "default", - "response_mode": response_mode, - # If top_k is too large, - # it will slow down the synthesis process due to multiple iterations of refinement. - "similarity_top_k": 2 - } - - # fulfill description when it is empty - description = dataset.description - if not description: - description = 'useful for when you want to answer queries about the ' + dataset.name - - index_tool_config = IndexToolConfig( - index=index, - name=f"dataset-{dataset.id}", - description=description, - index_query_kwargs=query_kwargs, - tool_kwargs={ - "callback_manager": CallbackManager([callback_handler, DifyStdOutCallbackHandler()]) - }, - # tool_kwargs={"return_direct": True}, - # return_direct: Whether to return LLM results directly or process the output data with an Output Parser - ) - - index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset.id) - - return EnhanceLlamaIndexTool.from_tool_config( - tool_config=index_tool_config, - callback_handler=index_callback_handler - ) diff --git a/api/core/tool/llama_index_tool.py b/api/core/tool/llama_index_tool.py deleted file mode 100644 index ffb216771b..0000000000 --- a/api/core/tool/llama_index_tool.py +++ /dev/null @@ -1,43 +0,0 @@ -from typing import Dict - -from langchain.tools import BaseTool -from llama_index.indices.base import BaseGPTIndex -from llama_index.langchain_helpers.agents import IndexToolConfig -from pydantic import Field - -from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler - - -class EnhanceLlamaIndexTool(BaseTool): - """Tool for querying a LlamaIndex.""" - - # NOTE: name/description still needs to be set - index: BaseGPTIndex - query_kwargs: Dict = Field(default_factory=dict) - return_sources: bool = False - callback_handler: IndexToolCallbackHandler - - @classmethod - def from_tool_config(cls, tool_config: IndexToolConfig, - callback_handler: IndexToolCallbackHandler) -> "EnhanceLlamaIndexTool": - """Create a tool from a tool config.""" - return_sources = tool_config.tool_kwargs.pop("return_sources", False) - return cls( - index=tool_config.index, - callback_handler=callback_handler, - name=tool_config.name, - description=tool_config.description, - return_sources=return_sources, - query_kwargs=tool_config.index_query_kwargs, - **tool_config.tool_kwargs, - ) - - def _run(self, tool_input: str) -> str: - response = self.index.query(tool_input, **self.query_kwargs) - self.callback_handler.on_tool_end(response) - return str(response) - - async def 
_arun(self, tool_input: str) -> str: - response = await self.index.aquery(tool_input, **self.query_kwargs) - self.callback_handler.on_tool_end(response) - return str(response) diff --git a/api/core/vector_store/base.py b/api/core/vector_store/base.py deleted file mode 100644 index 526f83831d..0000000000 --- a/api/core/vector_store/base.py +++ /dev/null @@ -1,34 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Optional - -from llama_index import ServiceContext, GPTVectorStoreIndex -from llama_index.data_structs import Node -from llama_index.vector_stores.types import VectorStore - - -class BaseVectorStoreClient(ABC): - @abstractmethod - def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex: - raise NotImplementedError - - @abstractmethod - def to_index_config(self, index_id: str) -> dict: - raise NotImplementedError - - -class BaseGPTVectorStoreIndex(GPTVectorStoreIndex): - def delete_node(self, node_id: str): - self._vector_store.delete_node(node_id) - - def exists_by_node_id(self, node_id: str) -> bool: - return self._vector_store.exists_by_node_id(node_id) - - -class EnhanceVectorStore(ABC): - @abstractmethod - def delete_node(self, node_id: str): - pass - - @abstractmethod - def exists_by_node_id(self, node_id: str) -> bool: - pass diff --git a/api/core/vector_store/qdrant_vector_store.py b/api/core/vector_store/qdrant_vector_store.py new file mode 100644 index 0000000000..103a4b97c0 --- /dev/null +++ b/api/core/vector_store/qdrant_vector_store.py @@ -0,0 +1,69 @@ +from typing import cast, Any + +from langchain.schema import Document +from langchain.vectorstores import Qdrant +from qdrant_client.http.models import Filter, PointIdsList, FilterSelector +from qdrant_client.local.qdrant_local import QdrantLocal + + +class QdrantVectorStore(Qdrant): + def del_texts(self, filter: Filter): + if not filter: + raise ValueError('filter must not be empty') + + self._reload_if_needed() + + self.client.delete( + collection_name=self.collection_name, + points_selector=FilterSelector( + filter=filter + ), + ) + + def del_text(self, uuid: str) -> None: + self._reload_if_needed() + + self.client.delete( + collection_name=self.collection_name, + points_selector=PointIdsList( + points=[uuid], + ), + ) + + def text_exists(self, uuid: str) -> bool: + self._reload_if_needed() + + response = self.client.retrieve( + collection_name=self.collection_name, + ids=[uuid] + ) + + return len(response) > 0 + + def delete(self): + self._reload_if_needed() + + self.client.delete_collection(collection_name=self.collection_name) + + @classmethod + def _document_from_scored_point( + cls, + scored_point: Any, + content_payload_key: str, + metadata_payload_key: str, + ) -> Document: + if scored_point.payload.get('doc_id'): + return Document( + page_content=scored_point.payload.get(content_payload_key), + metadata={'doc_id': scored_point.id} + ) + + return Document( + page_content=scored_point.payload.get(content_payload_key), + metadata=scored_point.payload.get(metadata_payload_key) or {}, + ) + + def _reload_if_needed(self): + if isinstance(self.client, QdrantLocal): + self.client = cast(QdrantLocal, self.client) + self.client._load() diff --git a/api/core/vector_store/qdrant_vector_store_client.py b/api/core/vector_store/qdrant_vector_store_client.py deleted file mode 100644 index 1188c121e3..0000000000 --- a/api/core/vector_store/qdrant_vector_store_client.py +++ /dev/null @@ -1,147 +0,0 @@ -import os -from typing import cast, List - -from llama_index.data_structs import 
Node -from llama_index.data_structs.node_v2 import DocumentRelationship -from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult -from qdrant_client.http.models import Payload, Filter - -import qdrant_client -from llama_index import ServiceContext, GPTVectorStoreIndex, GPTQdrantIndex -from llama_index.data_structs.data_structs_v2 import QdrantIndexDict -from llama_index.vector_stores import QdrantVectorStore -from qdrant_client.local.qdrant_local import QdrantLocal - -from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore - - -class QdrantVectorStoreClient(BaseVectorStoreClient): - - def __init__(self, url: str, api_key: str, root_path: str): - self._client = self.init_from_config(url, api_key, root_path) - - @classmethod - def init_from_config(cls, url: str, api_key: str, root_path: str): - if url and url.startswith('path:'): - path = url.replace('path:', '') - if not os.path.isabs(path): - path = os.path.join(root_path, path) - - return qdrant_client.QdrantClient( - path=path - ) - else: - return qdrant_client.QdrantClient( - url=url, - api_key=api_key, - ) - - def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex: - index_struct = QdrantIndexDict() - - if self._client is None: - raise Exception("Vector client is not initialized.") - - # {"collection_name": "Gpt_index_xxx"} - collection_name = config.get('collection_name') - if not collection_name: - raise Exception("collection_name cannot be None.") - - return GPTQdrantEnhanceIndex( - service_context=service_context, - index_struct=index_struct, - vector_store=QdrantEnhanceVectorStore( - client=self._client, - collection_name=collection_name - ) - ) - - def to_index_config(self, index_id: str) -> dict: - return {"collection_name": index_id} - - -class GPTQdrantEnhanceIndex(GPTQdrantIndex, BaseGPTVectorStoreIndex): - pass - - -class QdrantEnhanceVectorStore(QdrantVectorStore, EnhanceVectorStore): - def delete_node(self, node_id: str): - """ - Delete node from the index. - - :param node_id: node id - """ - from qdrant_client.http import models as rest - - self._reload_if_needed() - - self._client.delete( - collection_name=self._collection_name, - points_selector=rest.Filter( - must=[ - rest.FieldCondition( - key="id", match=rest.MatchValue(value=node_id) - ) - ] - ), - ) - - def exists_by_node_id(self, node_id: str) -> bool: - """ - Get node from the index by node id. - - :param node_id: node id - """ - self._reload_if_needed() - - response = self._client.retrieve( - collection_name=self._collection_name, - ids=[node_id] - ) - - return len(response) > 0 - - def query( - self, - query: VectorStoreQuery, - ) -> VectorStoreQueryResult: - """Query index for top k most similar nodes. 
- - Args: - query (VectorStoreQuery): query - """ - query_embedding = cast(List[float], query.query_embedding) - - self._reload_if_needed() - - response = self._client.search( - collection_name=self._collection_name, - query_vector=query_embedding, - limit=cast(int, query.similarity_top_k), - query_filter=cast(Filter, self._build_query_filter(query)), - with_vectors=True - ) - - nodes = [] - similarities = [] - ids = [] - for point in response: - payload = cast(Payload, point.payload) - node = Node( - doc_id=str(point.id), - text=payload.get("text"), - embedding=point.vector, - extra_info=payload.get("extra_info"), - relationships={ - DocumentRelationship.SOURCE: payload.get("doc_id", "None"), - }, - ) - nodes.append(node) - similarities.append(point.score) - ids.append(str(point.id)) - - return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids) - - def _reload_if_needed(self): - if isinstance(self._client._client, QdrantLocal): - self._client._client._load() diff --git a/api/core/vector_store/vector_store.py b/api/core/vector_store/vector_store.py deleted file mode 100644 index 59a4c5060b..0000000000 --- a/api/core/vector_store/vector_store.py +++ /dev/null @@ -1,62 +0,0 @@ -from flask import Flask -from llama_index import ServiceContext, GPTVectorStoreIndex -from requests import ReadTimeout -from tenacity import retry, retry_if_exception_type, stop_after_attempt - -from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient -from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient - -SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant'] - - -class VectorStore: - - def __init__(self): - self._vector_store = None - self._client = None - - def init_app(self, app: Flask): - if not app.config['VECTOR_STORE']: - return - - self._vector_store = app.config['VECTOR_STORE'] - if self._vector_store not in SUPPORTED_VECTOR_STORES: - raise ValueError(f"Vector store {self._vector_store} is not supported.") - - if self._vector_store == 'weaviate': - self._client = WeaviateVectorStoreClient( - endpoint=app.config['WEAVIATE_ENDPOINT'], - api_key=app.config['WEAVIATE_API_KEY'], - grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED'], - batch_size=app.config['WEAVIATE_BATCH_SIZE'] - ) - elif self._vector_store == 'qdrant': - self._client = QdrantVectorStoreClient( - url=app.config['QDRANT_URL'], - api_key=app.config['QDRANT_API_KEY'], - root_path=app.root_path - ) - - app.extensions['vector_store'] = self - - @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3)) - def get_index(self, service_context: ServiceContext, index_struct: dict) -> GPTVectorStoreIndex: - vector_store_config: dict = index_struct.get('vector_store') - index = self.get_client().get_index( - service_context=service_context, - config=vector_store_config - ) - - return index - - def to_index_struct(self, index_id: str) -> dict: - return { - "type": self._vector_store, - "vector_store": self.get_client().to_index_config(index_id) - } - - def get_client(self): - if not self._client: - raise Exception("Vector store client is not initialized.") - - return self._client diff --git a/api/core/vector_store/vector_store_index_query.py b/api/core/vector_store/vector_store_index_query.py deleted file mode 100644 index f29de83f9e..0000000000 --- a/api/core/vector_store/vector_store_index_query.py +++ /dev/null @@ -1,66 +0,0 @@ -from llama_index.indices.query.base import IS -from typing import ( - Any, - Dict, - List, - Optional -) - -from 
llama_index.docstore import BaseDocumentStore -from llama_index.indices.postprocessor.node import ( - BaseNodePostprocessor, -) -from llama_index.indices.vector_store import GPTVectorStoreIndexQuery -from llama_index.indices.response.response_builder import ResponseMode -from llama_index.indices.service_context import ServiceContext -from llama_index.optimization.optimizer import BaseTokenUsageOptimizer -from llama_index.prompts.prompts import ( - QuestionAnswerPrompt, - RefinePrompt, - SimpleInputPrompt, -) - -from core.index.query.synthesizer import EnhanceResponseSynthesizer - - -class EnhanceGPTVectorStoreIndexQuery(GPTVectorStoreIndexQuery): - @classmethod - def from_args( - cls, - index_struct: IS, - service_context: ServiceContext, - docstore: Optional[BaseDocumentStore] = None, - node_postprocessors: Optional[List[BaseNodePostprocessor]] = None, - verbose: bool = False, - # response synthesizer args - response_mode: ResponseMode = ResponseMode.DEFAULT, - text_qa_template: Optional[QuestionAnswerPrompt] = None, - refine_template: Optional[RefinePrompt] = None, - simple_template: Optional[SimpleInputPrompt] = None, - response_kwargs: Optional[Dict] = None, - use_async: bool = False, - streaming: bool = False, - optimizer: Optional[BaseTokenUsageOptimizer] = None, - # class-specific args - **kwargs: Any, - ) -> "BaseGPTIndexQuery": - response_synthesizer = EnhanceResponseSynthesizer.from_args( - service_context=service_context, - text_qa_template=text_qa_template, - refine_template=refine_template, - simple_template=simple_template, - response_mode=response_mode, - response_kwargs=response_kwargs, - use_async=use_async, - streaming=streaming, - optimizer=optimizer, - ) - return cls( - index_struct=index_struct, - service_context=service_context, - response_synthesizer=response_synthesizer, - docstore=docstore, - node_postprocessors=node_postprocessors, - verbose=verbose, - **kwargs, - ) diff --git a/api/core/vector_store/weaviate_vector_store.py b/api/core/vector_store/weaviate_vector_store.py new file mode 100644 index 0000000000..6dae568827 --- /dev/null +++ b/api/core/vector_store/weaviate_vector_store.py @@ -0,0 +1,38 @@ +from langchain.vectorstores import Weaviate + + +class WeaviateVectorStore(Weaviate): + def del_texts(self, where_filter: dict): + if not where_filter: + raise ValueError('where_filter must not be empty') + + self._client.batch.delete_objects( + class_name=self._index_name, + where=where_filter, + output='minimal' + ) + + def del_text(self, uuid: str) -> None: + self._client.data_object.delete( + uuid, + class_name=self._index_name + ) + + def text_exists(self, uuid: str) -> bool: + result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({ + "path": ["doc_id"], + "operator": "Equal", + "valueText": uuid, + }).with_limit(1).do() + + if "errors" in result: + raise ValueError(f"Error during query: {result['errors']}") + + entries = result["data"]["Get"][self._index_name] + if len(entries) == 0: + return False + + return True + + def delete(self): + self._client.schema.delete_class(self._index_name) diff --git a/api/core/vector_store/weaviate_vector_store_client.py b/api/core/vector_store/weaviate_vector_store_client.py deleted file mode 100644 index 0fe120de71..0000000000 --- a/api/core/vector_store/weaviate_vector_store_client.py +++ /dev/null @@ -1,270 +0,0 @@ -import json -import weaviate -from dataclasses import field -from typing import List, Any, Dict, Optional - -from core.vector_store.base import BaseVectorStoreClient, 
BaseGPTVectorStoreIndex, EnhanceVectorStore -from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex -from llama_index.data_structs.data_structs_v2 import WeaviateIndexDict, Node -from llama_index.data_structs.node_v2 import DocumentRelationship -from llama_index.readers.weaviate.client import _class_name, NODE_SCHEMA, _logger -from llama_index.vector_stores import WeaviateVectorStore -from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode -from llama_index.readers.weaviate.utils import ( - parse_get_response, - validate_client, -) - - -class WeaviateVectorStoreClient(BaseVectorStoreClient): - - def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int): - self._client = self.init_from_config(endpoint, api_key, grpc_enabled, batch_size) - - def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int): - auth_config = weaviate.auth.AuthApiKey(api_key=api_key) - - weaviate.connect.connection.has_grpc = grpc_enabled - - client = weaviate.Client( - url=endpoint, - auth_client_secret=auth_config, - timeout_config=(5, 60), - startup_period=None - ) - - client.batch.configure( - # `batch_size` takes an `int` value to enable auto-batching - # (`None` is used for manual batching) - batch_size=batch_size, - # dynamically update the `batch_size` based on import speed - dynamic=True, - # `timeout_retries` takes an `int` value to retry on time outs - timeout_retries=3, - ) - - return client - - def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex: - index_struct = WeaviateIndexDict() - - if self._client is None: - raise Exception("Vector client is not initialized.") - - # {"class_prefix": "Gpt_index_xxx"} - class_prefix = config.get('class_prefix') - if not class_prefix: - raise Exception("class_prefix cannot be None.") - - return GPTWeaviateEnhanceIndex( - service_context=service_context, - index_struct=index_struct, - vector_store=WeaviateWithSimilaritiesVectorStore( - weaviate_client=self._client, - class_prefix=class_prefix - ) - ) - - def to_index_config(self, index_id: str) -> dict: - return {"class_prefix": index_id} - - -class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore): - def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult: - """Query index for top k most similar nodes.""" - nodes = self.weaviate_query( - self._client, - self._class_prefix, - query, - ) - nodes = nodes[: query.similarity_top_k] - node_idxs = [str(i) for i in range(len(nodes))] - - similarities = [] - for node in nodes: - similarities.append(node.extra_info['similarity']) - del node.extra_info['similarity'] - - return VectorStoreQueryResult(nodes=nodes, ids=node_idxs, similarities=similarities) - - def weaviate_query( - self, - client: Any, - class_prefix: str, - query_spec: VectorStoreQuery, - ) -> List[Node]: - """Convert to LlamaIndex list.""" - validate_client(client) - - class_name = _class_name(class_prefix) - prop_names = [p["name"] for p in NODE_SCHEMA] - vector = query_spec.query_embedding - - # build query - query = client.query.get(class_name, prop_names).with_additional(["id", "vector", "certainty"]) - if query_spec.mode == VectorStoreQueryMode.DEFAULT: - _logger.debug("Using vector search") - if vector is not None: - query = query.with_near_vector( - { - "vector": vector, - } - ) - elif query_spec.mode == VectorStoreQueryMode.HYBRID: - _logger.debug(f"Using hybrid search with alpha {query_spec.alpha}") 
- query = query.with_hybrid( - query=query_spec.query_str, - alpha=query_spec.alpha, - vector=vector, - ) - query = query.with_limit(query_spec.similarity_top_k) - _logger.debug(f"Using limit of {query_spec.similarity_top_k}") - - # execute query - query_result = query.do() - - # parse results - parsed_result = parse_get_response(query_result) - entries = parsed_result[class_name] - results = [self._to_node(entry) for entry in entries] - return results - - def _to_node(self, entry: Dict) -> Node: - """Convert to Node.""" - extra_info_str = entry["extra_info"] - if extra_info_str == "": - extra_info = None - else: - extra_info = json.loads(extra_info_str) - - if 'certainty' in entry['_additional']: - if extra_info: - extra_info['similarity'] = entry['_additional']['certainty'] - else: - extra_info = {'similarity': entry['_additional']['certainty']} - - node_info_str = entry["node_info"] - if node_info_str == "": - node_info = None - else: - node_info = json.loads(node_info_str) - - relationships_str = entry["relationships"] - relationships: Dict[DocumentRelationship, str] - if relationships_str == "": - relationships = field(default_factory=dict) - else: - relationships = { - DocumentRelationship(k): v for k, v in json.loads(relationships_str).items() - } - - return Node( - text=entry["text"], - doc_id=entry["doc_id"], - embedding=entry["_additional"]["vector"], - extra_info=extra_info, - node_info=node_info, - relationships=relationships, - ) - - def delete(self, doc_id: str, **delete_kwargs: Any) -> None: - """Delete a document. - - Args: - doc_id (str): document id - - """ - delete_document(self._client, doc_id, self._class_prefix) - - def delete_node(self, node_id: str): - """ - Delete node from the index. - - :param node_id: node id - """ - delete_node(self._client, node_id, self._class_prefix) - - def exists_by_node_id(self, node_id: str) -> bool: - """ - Get node from the index by node id. 
- - :param node_id: node id - """ - entry = get_by_node_id(self._client, node_id, self._class_prefix) - return True if entry else False - - -class GPTWeaviateEnhanceIndex(GPTWeaviateIndex, BaseGPTVectorStoreIndex): - pass - - -def delete_document(client: Any, ref_doc_id: str, class_prefix: str) -> None: - """Delete entry.""" - validate_client(client) - # make sure that each entry - class_name = _class_name(class_prefix) - where_filter = { - "path": ["ref_doc_id"], - "operator": "Equal", - "valueString": ref_doc_id, - } - query = ( - client.query.get(class_name).with_additional(["id"]).with_where(where_filter) - ) - - query_result = query.do() - parsed_result = parse_get_response(query_result) - entries = parsed_result[class_name] - for entry in entries: - client.data_object.delete(entry["_additional"]["id"], class_name) - - while len(entries) > 0: - query_result = query.do() - parsed_result = parse_get_response(query_result) - entries = parsed_result[class_name] - for entry in entries: - client.data_object.delete(entry["_additional"]["id"], class_name) - - -def delete_node(client: Any, node_id: str, class_prefix: str) -> None: - """Delete entry.""" - validate_client(client) - # make sure that each entry - class_name = _class_name(class_prefix) - where_filter = { - "path": ["doc_id"], - "operator": "Equal", - "valueString": node_id, - } - query = ( - client.query.get(class_name).with_additional(["id"]).with_where(where_filter) - ) - - query_result = query.do() - parsed_result = parse_get_response(query_result) - entries = parsed_result[class_name] - for entry in entries: - client.data_object.delete(entry["_additional"]["id"], class_name) - - -def get_by_node_id(client: Any, node_id: str, class_prefix: str) -> Optional[Dict]: - """Delete entry.""" - validate_client(client) - # make sure that each entry - class_name = _class_name(class_prefix) - where_filter = { - "path": ["doc_id"], - "operator": "Equal", - "valueString": node_id, - } - query = ( - client.query.get(class_name).with_additional(["id"]).with_where(where_filter) - ) - - query_result = query.do() - parsed_result = parse_get_response(query_result) - entries = parsed_result[class_name] - if len(entries) == 0: - return None - - return entries[0] diff --git a/api/extensions/ext_vector_store.py b/api/extensions/ext_vector_store.py deleted file mode 100644 index 4ed7a93422..0000000000 --- a/api/extensions/ext_vector_store.py +++ /dev/null @@ -1,7 +0,0 @@ -from core.vector_store.vector_store import VectorStore - -vector_store = VectorStore() - - -def init_app(app): - vector_store.init_app(app) diff --git a/api/libs/helper.py b/api/libs/helper.py index 767f368d33..b306fee2a7 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -3,6 +3,7 @@ import re import subprocess import uuid from datetime import datetime +from hashlib import sha256 from zoneinfo import available_timezones import random import string @@ -147,3 +148,8 @@ def get_remote_ip(request): return request.headers.getlist("X-Forwarded-For")[0] else: return request.remote_addr + + +def generate_text_hash(text: str) -> str: + hash_text = str(text) + 'None' + return sha256(hash_text.encode()).hexdigest() diff --git a/api/models/account.py b/api/models/account.py index bc15e86508..8f9d89cde4 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -38,8 +38,6 @@ class Account(UserMixin, db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, 
server_default=db.text('CURRENT_TIMESTAMP(0)')) - _current_tenant: db.Model = None - @property def current_tenant(self): return self._current_tenant diff --git a/api/models/dataset.py b/api/models/dataset.py index bbc5340bc2..345eea5f47 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -66,6 +66,23 @@ class Dataset(db.Model): def document_count(self): return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar() + @property + def available_document_count(self): + return db.session.query(func.count(Document.id)).filter( + Document.dataset_id == self.id, + Document.indexing_status == 'completed', + Document.enabled == True, + Document.archived == False + ).scalar() + + @property + def available_segment_count(self): + return db.session.query(func.count(DocumentSegment.id)).filter( + DocumentSegment.dataset_id == self.id, + DocumentSegment.status == 'completed', + DocumentSegment.enabled == True + ).scalar() + @property def word_count(self): return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \ @@ -260,7 +277,7 @@ class Document(db.Model): @property def dataset(self): - return Dataset.query.get(self.dataset_id) + return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none() @property def segment_count(self): @@ -395,7 +412,18 @@ class DatasetKeywordTable(db.Model): @property def keyword_table_dict(self): - return json.loads(self.keyword_table) if self.keyword_table else None + class SetDecoder(json.JSONDecoder): + def __init__(self, *args, **kwargs): + super().__init__(object_hook=self.object_hook, *args, **kwargs) + + def object_hook(self, dct): + if isinstance(dct, dict): + for keyword, node_idxs in dct.items(): + if isinstance(node_idxs, list): + dct[keyword] = set(node_idxs) + return dct + + return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None class Embedding(db.Model): diff --git a/api/requirements.txt b/api/requirements.txt index 95a1d1aa9d..e129eb4b37 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -2,6 +2,7 @@ coverage~=7.2.4 beautifulsoup4==4.12.2 flask~=2.3.2 Flask-SQLAlchemy~=3.0.3 +SQLAlchemy~=1.4.28 flask-login==0.6.2 flask-migrate~=4.0.4 flask-restful==0.3.9 @@ -9,8 +10,7 @@ flask-session2==1.3.1 flask-cors==3.0.10 gunicorn~=20.1.0 gevent~=22.10.2 -langchain==0.0.142 -llama-index==0.5.27 +langchain==0.0.209 openai~=0.27.5 psycopg2-binary~=2.9.6 pycryptodome==3.17 @@ -29,6 +29,7 @@ sentry-sdk[flask]~=1.21.1 jieba==0.42.1 celery==5.2.7 redis~=4.5.4 -pypdf==3.8.1 openpyxl==3.1.2 -chardet~=5.1.0 \ No newline at end of file +chardet~=5.1.0 +docx2txt==0.8 +pypdfium2==4.16.0 \ No newline at end of file diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 70b63fbe19..f2ab184573 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -4,7 +4,6 @@ import uuid from core.constant import llm_constant from models.account import Account from services.dataset_service import DatasetService -from services.errors.account import NoPermissionError class AppModelConfigService: diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 10f1fc2f35..7c4a59d355 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -7,7 +7,6 @@ from typing import Optional, List from extensions.ext_redis import redis_client from flask_login import current_user -from core.index.index_builder import IndexBuilder from 
events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted from extensions.ext_database import db @@ -386,8 +385,6 @@ class DocumentService: dataset.indexing_technique = document_data["indexing_technique"] - if dataset.indexing_technique == 'high_quality': - IndexBuilder.get_default_service_context(dataset.tenant_id) documents = [] batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) if 'original_document_id' in document_data and document_data["original_document_id"]: diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 619df1b873..b0029f80ad 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -3,47 +3,56 @@ import time from typing import List import numpy as np -from llama_index.data_structs.node_v2 import NodeWithScore -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.vector_store import GPTVectorStoreIndexQuery +from flask import current_app +from langchain.embeddings import OpenAIEmbeddings +from langchain.embeddings.base import Embeddings +from langchain.schema import Document from sklearn.manifold import TSNE -from core.docstore.empty_docstore import EmptyDocumentStore -from core.index.vector_index import VectorIndex +from core.embedding.cached_embedding import CacheEmbedding +from core.index.vector_index.vector_index import VectorIndex +from core.llm.llm_builder import LLMBuilder from extensions.ext_database import db from models.account import Account from models.dataset import Dataset, DocumentSegment, DatasetQuery -from services.errors.index import IndexNotInitializedError class HitTestingService: @classmethod def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict: - index = VectorIndex(dataset=dataset).query_index + if dataset.available_document_count == 0 or dataset.available_segment_count == 0: + return { + "query": { + "content": query, + "tsne_position": {'x': 0, 'y': 0}, + }, + "records": [] + } - if not index: - raise IndexNotInitializedError() - - index_query = GPTVectorStoreIndexQuery( - index_struct=index.index_struct, - service_context=index.service_context, - vector_store=index.query_context.get('vector_store'), - docstore=EmptyDocumentStore(), - response_synthesizer=None, - similarity_top_k=limit + model_credentials = LLMBuilder.get_model_credentials( + tenant_id=dataset.tenant_id, + model_provider=LLMBuilder.get_default_provider(dataset.tenant_id), + model_name='text-embedding-ada-002' ) - query_bundle = QueryBundle( - query_str=query, - custom_embedding_strs=[query], - ) + embeddings = CacheEmbedding(OpenAIEmbeddings( + **model_credentials + )) - query_bundle.embedding = index.service_context.embed_model.get_agg_embedding_from_queries( - query_bundle.embedding_strs + vector_index = VectorIndex( + dataset=dataset, + config=current_app.config, + embeddings=embeddings ) start = time.perf_counter() - nodes = index_query.retrieve(query_bundle=query_bundle) + documents = vector_index.search( + query, + search_type='similarity_score_threshold', + search_kwargs={ + 'k': 10 + } + ) end = time.perf_counter() logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") @@ -58,25 +67,24 @@ class HitTestingService: db.session.add(dataset_query) db.session.commit() - return cls.compact_retrieve_response(dataset, query_bundle, nodes) + return cls.compact_retrieve_response(dataset, embeddings, query, documents) @classmethod - def
compact_retrieve_response(cls, dataset: Dataset, query_bundle: QueryBundle, nodes: List[NodeWithScore]): - embeddings = [ - query_bundle.embedding + def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]): + text_embeddings = [ + embeddings.embed_query(query) ] - for node in nodes: - embeddings.append(node.node.embedding) + text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents])) - tsne_position_data = cls.get_tsne_positions_from_embeddings(embeddings) + tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings) query_position = tsne_position_data.pop(0) i = 0 records = [] - for node in nodes: - index_node_id = node.node.doc_id + for document in documents: + index_node_id = document.metadata['doc_id'] segment = db.session.query(DocumentSegment).filter( DocumentSegment.dataset_id == dataset.id, @@ -91,7 +99,7 @@ class HitTestingService: record = { "segment": segment, - "score": node.score, + "score": document.metadata['score'], "tsne_position": tsne_position_data[i] } @@ -101,7 +109,7 @@ class HitTestingService: return { "query": { - "content": query_bundle.query_str, + "content": query, "tsne_position": query_position, }, "records": records diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 9ea259227e..db802d208a 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -4,96 +4,81 @@ import time import click from celery import shared_task -from llama_index.data_structs import Node -from llama_index.data_structs.node_v2 import DocumentRelationship +from langchain.schema import Document from werkzeug.exceptions import NotFound -from core.index.keyword_table_index import KeywordTableIndex -from core.index.vector_index import VectorIndex +from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.dataset import DocumentSegment, Document +from models.dataset import DocumentSegment +from models.dataset import Document as DatasetDocument @shared_task -def add_document_to_index_task(document_id: str): +def add_document_to_index_task(dataset_document_id: str): """ Async Add document to index :param document_id: Usage: add_document_to_index.delay(document_id) """ - logging.info(click.style('Start add document to index: {}'.format(document_id), fg='green')) + logging.info(click.style('Start add document to index: {}'.format(dataset_document_id), fg='green')) start_at = time.perf_counter() - document = db.session.query(Document).filter(Document.id == document_id).first() - if not document: + dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first() + if not dataset_document: raise NotFound('Document not found') - if document.indexing_status != 'completed': + if dataset_document.indexing_status != 'completed': return - indexing_cache_key = 'document_{}_indexing'.format(document.id) + indexing_cache_key = 'document_{}_indexing'.format(dataset_document.id) try: segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == document.id, + DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True ) \ .order_by(DocumentSegment.position.asc()).all() - nodes = [] - previous_node = None + documents = [] for segment in segments: - relationships = { - DocumentRelationship.SOURCE: document.id - } - - if previous_node: - 
relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id - - previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id - - node = Node( - doc_id=segment.index_node_id, - doc_hash=segment.index_node_hash, - text=segment.content, - extra_info=None, - node_info=None, - relationships=relationships + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + } ) - previous_node = node + documents.append(document) - nodes.append(node) - - dataset = document.dataset + dataset = dataset_document.dataset if not dataset: raise Exception('Document has no dataset') - vector_index = VectorIndex(dataset=dataset) - keyword_table_index = KeywordTableIndex(dataset=dataset) - # save vector index - if dataset.indexing_technique == "high_quality": - vector_index.add_nodes( - nodes=nodes, - duplicate_check=True - ) + index = IndexBuilder.get_index(dataset, 'high_quality') + if index: + index.add_texts(documents) # save keyword index - keyword_table_index.add_nodes(nodes) + index = IndexBuilder.get_index(dataset, 'economy') + if index: + index.add_texts(documents) end_at = time.perf_counter() logging.info( - click.style('Document added to index: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) + click.style('Document added to index: {} latency: {}'.format(dataset_document.id, end_at - start_at), fg='green')) except Exception as e: logging.exception("add document to index failed") - document.enabled = False - document.disabled_at = datetime.datetime.utcnow() - document.status = 'error' - document.error = str(e) + dataset_document.enabled = False + dataset_document.disabled_at = datetime.datetime.utcnow() + dataset_document.status = 'error' + dataset_document.error = str(e) db.session.commit() finally: redis_client.delete(indexing_cache_key) diff --git a/api/tasks/add_segment_to_index_task.py b/api/tasks/add_segment_to_index_task.py index bd3cadfd3c..bf96a0dc0b 100644 --- a/api/tasks/add_segment_to_index_task.py +++ b/api/tasks/add_segment_to_index_task.py @@ -4,12 +4,10 @@ import time import click from celery import shared_task -from llama_index.data_structs import Node -from llama_index.data_structs.node_v2 import DocumentRelationship +from langchain.schema import Document from werkzeug.exceptions import NotFound -from core.index.keyword_table_index import KeywordTableIndex -from core.index.vector_index import VectorIndex +from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import DocumentSegment @@ -36,44 +34,41 @@ def add_segment_to_index_task(segment_id: str): indexing_cache_key = 'segment_{}_indexing'.format(segment.id) try: - relationships = { - DocumentRelationship.SOURCE: segment.document_id, - } - - previous_segment = segment.previous_segment - if previous_segment: - relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id - - next_segment = segment.next_segment - if next_segment: - relationships[DocumentRelationship.NEXT] = next_segment.index_node_id - - node = Node( - doc_id=segment.index_node_id, - doc_hash=segment.index_node_hash, - text=segment.content, - extra_info=None, - node_info=None, - relationships=relationships + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": 
segment.document_id, + "dataset_id": segment.dataset_id, + } ) dataset = segment.dataset if not dataset: - raise Exception('Segment has no dataset') + logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) + return - vector_index = VectorIndex(dataset=dataset) - keyword_table_index = KeywordTableIndex(dataset=dataset) + dataset_document = segment.document + + if not dataset_document: + logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan')) + return + + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': + logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) + return # save vector index - if dataset.indexing_technique == "high_quality": - vector_index.add_nodes( - nodes=[node], - duplicate_check=True - ) + index = IndexBuilder.get_index(dataset, 'high_quality') + if index: + index.add_texts([document], duplicate_check=True) # save keyword index - keyword_table_index.add_nodes([node]) + index = IndexBuilder.get_index(dataset, 'economy') + if index: + index.add_texts([document]) end_at = time.perf_counter() logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 3c5ea8eb95..1232a8df1f 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -4,8 +4,7 @@ import time import click from celery import shared_task -from core.index.keyword_table_index import KeywordTableIndex -from core.index.vector_index import VectorIndex +from core.index.index import IndexBuilder from extensions.ext_database import db from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \ AppDatasetJoin @@ -33,29 +32,24 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, index_struct=index_struct ) - vector_index = VectorIndex(dataset=dataset) - keyword_table_index = KeywordTableIndex(dataset=dataset) - documents = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() - index_doc_ids = [document.id for document in documents] segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() - index_node_ids = [segment.index_node_id for segment in segments] + + vector_index = IndexBuilder.get_index(dataset, 'high_quality') + kw_index = IndexBuilder.get_index(dataset, 'economy') # delete from vector index - if dataset.indexing_technique == "high_quality": - for index_doc_id in index_doc_ids: - try: - vector_index.del_doc(index_doc_id) - except Exception: - logging.exception("Delete doc index failed when dataset deleted.") - continue + if vector_index: + try: + vector_index.delete() + except Exception: + logging.exception("Delete doc index failed when dataset deleted.") # delete from keyword index - if index_node_ids: - try: - keyword_table_index.del_nodes(index_node_ids) - except Exception: - logging.exception("Delete nodes index failed when dataset deleted.") + try: + kw_index.delete() + except Exception: + logging.exception("Delete nodes index failed when dataset deleted.") for document in documents: db.session.delete(document) @@ -63,7 +57,6 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, for segment in segments: db.session.delete(segment) - 
db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == dataset_id).delete() db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete() db.session.query(DatasetQuery).filter(DatasetQuery.dataset_id == dataset_id).delete() db.session.query(AppDatasetJoin).filter(AppDatasetJoin.dataset_id == dataset_id).delete() diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 2d8a4fa75e..650037b6eb 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -4,8 +4,7 @@ import time import click from celery import shared_task -from core.index.keyword_table_index import KeywordTableIndex -from core.index.vector_index import VectorIndex +from core.index.index import IndexBuilder from extensions.ext_database import db from models.dataset import DocumentSegment, Dataset @@ -28,21 +27,23 @@ def clean_document_task(document_id: str, dataset_id: str): if not dataset: raise Exception('Document has no dataset') - vector_index = VectorIndex(dataset=dataset) - keyword_table_index = KeywordTableIndex(dataset=dataset) + vector_index = IndexBuilder.get_index(dataset, 'high_quality') + kw_index = IndexBuilder.get_index(dataset, 'economy') segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - vector_index.del_nodes(index_node_ids) + if vector_index: + vector_index.delete_by_document_id(document_id) # delete from keyword index if index_node_ids: - keyword_table_index.del_nodes(index_node_ids) + kw_index.delete_by_ids(index_node_ids) for segment in segments: db.session.delete(segment) + db.session.commit() end_at = time.perf_counter() logging.info( diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index 3fdc03884f..a516fd23c2 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -5,8 +5,7 @@ from typing import List import click from celery import shared_task -from core.index.keyword_table_index import KeywordTableIndex -from core.index.vector_index import VectorIndex +from core.index.index import IndexBuilder from extensions.ext_database import db from models.dataset import DocumentSegment, Dataset, Document @@ -29,22 +28,24 @@ def clean_notion_document_task(document_ids: List[str], dataset_id: str): if not dataset: raise Exception('Document has no dataset') - vector_index = VectorIndex(dataset=dataset) - keyword_table_index = KeywordTableIndex(dataset=dataset) + vector_index = IndexBuilder.get_index(dataset, 'high_quality') + kw_index = IndexBuilder.get_index(dataset, 'economy') for document_id in document_ids: document = db.session.query(Document).filter( Document.id == document_id ).first() db.session.delete(document) + segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - vector_index.del_nodes(index_node_ids) + if vector_index: + vector_index.delete_by_document_id(document_id) # delete from keyword index if index_node_ids: - keyword_table_index.del_nodes(index_node_ids) + kw_index.delete_by_ids(index_node_ids) for segment in segments: db.session.delete(segment) diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index f5f9129558..fac50510e5 100644 --- 
a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -3,10 +3,12 @@ import time import click from celery import shared_task -from llama_index.data_structs.node_v2 import DocumentRelationship, Node -from core.index.vector_index import VectorIndex +from langchain.schema import Document + +from core.index.index import IndexBuilder from extensions.ext_database import db -from models.dataset import DocumentSegment, Document, Dataset +from models.dataset import DocumentSegment, Dataset +from models.dataset import Document as DatasetDocument @shared_task @@ -24,49 +26,47 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): dataset = Dataset.query.filter_by( id=dataset_id ).first() + if not dataset: raise Exception('Dataset not found') - documents = Document.query.filter_by(dataset_id=dataset_id).all() - if documents: - vector_index = VectorIndex(dataset=dataset) - for document in documents: - # delete from vector index - if action == "remove": - vector_index.del_doc(document.id) - elif action == "add": + + if action == "remove": + index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True) + index.delete() + elif action == "add": + dataset_documents = db.session.query(DatasetDocument).filter( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.indexing_status == 'completed', + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ).all() + + if dataset_documents: + # save vector index + index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True) + for dataset_document in dataset_documents: + # delete from vector index segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == document.id, + DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True ) .order_by(DocumentSegment.position.asc()).all() - nodes = [] - previous_node = None + documents = [] for segment in segments: - relationships = { - DocumentRelationship.SOURCE: document.id - } - - if previous_node: - relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id - - previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id - - node = Node( - doc_id=segment.index_node_id, - doc_hash=segment.index_node_hash, - text=segment.content, - extra_info=None, - node_info=None, - relationships=relationships + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + } ) - previous_node = node - nodes.append(node) + documents.append(document) + # save vector index - vector_index.add_nodes( - nodes=nodes, - duplicate_check=True - ) + index.add_texts(documents) end_at = time.perf_counter() logging.info( diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 56869428df..dd50284f32 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -6,11 +6,9 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.data_source.notion import NotionPageReader -from core.index.keyword_table_index import KeywordTableIndex -from core.index.vector_index import VectorIndex +from core.data_loader.loader.notion import NotionLoader +from core.index.index import IndexBuilder from core.indexing_runner import IndexingRunner, DocumentIsPausedException -from 
core.llm.error import ProviderTokenNotInitError from extensions.ext_database import db from models.dataset import Document, Dataset, DocumentSegment from models.source import DataSourceBinding @@ -43,6 +41,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): raise ValueError("no notion page found") workspace_id = data_source_info['notion_workspace_id'] page_id = data_source_info['notion_page_id'] + page_type = data_source_info['type'] page_edited_time = data_source_info['last_edited_time'] data_source_binding = DataSourceBinding.query.filter( db.and_( @@ -54,8 +53,16 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): ).first() if not data_source_binding: raise ValueError('Data source binding not found.') - reader = NotionPageReader(integration_token=data_source_binding.access_token) - last_edited_time = reader.get_page_last_edited_time(page_id) + + loader = NotionLoader( + notion_access_token=data_source_binding.access_token, + notion_workspace_id=workspace_id, + notion_obj_id=page_id, + notion_page_type=page_type + ) + + last_edited_time = loader.get_notion_last_edited_time() + # check the page is updated if last_edited_time != page_edited_time: document.indexing_status = 'parsing' @@ -68,18 +75,19 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): if not dataset: raise Exception('Dataset not found') - vector_index = VectorIndex(dataset=dataset) - keyword_table_index = KeywordTableIndex(dataset=dataset) + vector_index = IndexBuilder.get_index(dataset, 'high_quality') + kw_index = IndexBuilder.get_index(dataset, 'economy') segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - vector_index.del_nodes(index_node_ids) + if vector_index: + vector_index.delete_by_document_id(document_id) # delete from keyword index if index_node_ids: - keyword_table_index.del_nodes(index_node_ids) + kw_index.delete_by_ids(index_node_ids) for segment in segments: db.session.delete(segment) @@ -89,21 +97,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) except Exception: logging.exception("Cleaned document when document update data source or process rule failed") + try: indexing_runner = IndexingRunner() indexing_runner.run([document]) end_at = time.perf_counter() logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) - except DocumentIsPausedException: - logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow')) - except ProviderTokenNotInitError as e: - document.indexing_status = 'error' - document.error = str(e.description) - document.stopped_at = datetime.datetime.utcnow() - db.session.commit() - except Exception as e: - logging.exception("consume update document failed") - document.indexing_status = 'error' - document.error = str(e) - document.stopped_at = datetime.datetime.utcnow() - db.session.commit() + except DocumentIsPausedException as ex: + logging.info(click.style(str(ex), fg='yellow')) + except Exception: + pass diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 211d110fa8..328dfdfc45 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -7,7 +7,6 @@ 
from celery import shared_task from werkzeug.exceptions import NotFound from core.indexing_runner import IndexingRunner, DocumentIsPausedException -from core.llm.error import ProviderTokenNotInitError from extensions.ext_database import db from models.dataset import Document @@ -22,9 +21,9 @@ def document_indexing_task(dataset_id: str, document_ids: list): Usage: document_indexing_task.delay(dataset_id, document_id) """ documents = [] + start_at = time.perf_counter() for document_id in document_ids: logging.info(click.style('Start process document: {}'.format(document_id), fg='green')) - start_at = time.perf_counter() document = db.session.query(Document).filter( Document.id == document_id, @@ -44,17 +43,8 @@ def document_indexing_task(dataset_id: str, document_ids: list): indexing_runner = IndexingRunner() indexing_runner.run(documents) end_at = time.perf_counter() - logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) - except DocumentIsPausedException: - logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow')) - except ProviderTokenNotInitError as e: - document.indexing_status = 'error' - document.error = str(e.description) - document.stopped_at = datetime.datetime.utcnow() - db.session.commit() - except Exception as e: - logging.exception("consume document failed") - document.indexing_status = 'error' - document.error = str(e) - document.stopped_at = datetime.datetime.utcnow() - db.session.commit() + logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + except DocumentIsPausedException as ex: + logging.info(click.style(str(ex), fg='yellow')) + except Exception: + pass diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 8fee81f321..ae9e2c3cda 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -6,10 +6,8 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.index.keyword_table_index import KeywordTableIndex -from core.index.vector_index import VectorIndex +from core.index.index import IndexBuilder from core.indexing_runner import IndexingRunner, DocumentIsPausedException -from core.llm.error import ProviderTokenNotInitError from extensions.ext_database import db from models.dataset import Document, Dataset, DocumentSegment @@ -44,18 +42,19 @@ def document_indexing_update_task(dataset_id: str, document_id: str): if not dataset: raise Exception('Dataset not found') - vector_index = VectorIndex(dataset=dataset) - keyword_table_index = KeywordTableIndex(dataset=dataset) + vector_index = IndexBuilder.get_index(dataset, 'high_quality') + kw_index = IndexBuilder.get_index(dataset, 'economy') segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - vector_index.del_nodes(index_node_ids) + if vector_index: + vector_index.delete_by_ids(index_node_ids) # delete from keyword index if index_node_ids: - keyword_table_index.del_nodes(index_node_ids) + kw_index.delete_by_ids(index_node_ids) for segment in segments: db.session.delete(segment) @@ -65,21 +64,13 @@ def document_indexing_update_task(dataset_id: str, document_id: str): click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - 
start_at), fg='green')) except Exception: logging.exception("Cleaned document when document update data source or process rule failed") + try: indexing_runner = IndexingRunner() indexing_runner.run([document]) end_at = time.perf_counter() logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) - except DocumentIsPausedException: - logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow')) - except ProviderTokenNotInitError as e: - document.indexing_status = 'error' - document.error = str(e.description) - document.stopped_at = datetime.datetime.utcnow() - db.session.commit() - except Exception as e: - logging.exception("consume update document failed") - document.indexing_status = 'error' - document.error = str(e) - document.stopped_at = datetime.datetime.utcnow() - db.session.commit() + except DocumentIsPausedException as ex: + logging.info(click.style(str(ex), fg='yellow')) + except Exception: + pass diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index 3ab48e8a46..bde8541bea 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -1,4 +1,3 @@ -import datetime import logging import time @@ -41,11 +40,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): indexing_runner.run_in_indexing_status(document) end_at = time.perf_counter() logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) - except DocumentIsPausedException: - logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow')) - except Exception as e: - logging.exception("consume document failed") - document.indexing_status = 'error' - document.error = str(e) - document.stopped_at = datetime.datetime.utcnow() - db.session.commit() + except DocumentIsPausedException as ex: + logging.info(click.style(str(ex), fg='yellow')) + except Exception: + pass diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index 99a4bd3ec3..81e8367b53 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -5,8 +5,7 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.index.keyword_table_index import KeywordTableIndex -from core.index.vector_index import VectorIndex +from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import DocumentSegment, Document @@ -38,17 +37,17 @@ def remove_document_from_index_task(document_id: str): if not dataset: raise Exception('Document has no dataset') - vector_index = VectorIndex(dataset=dataset) - keyword_table_index = KeywordTableIndex(dataset=dataset) + vector_index = IndexBuilder.get_index(dataset, 'high_quality') + kw_index = IndexBuilder.get_index(dataset, 'economy') # delete from vector index - vector_index.del_doc(document.id) + vector_index.delete_by_document_id(document.id) # delete from keyword index segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all() index_node_ids = [segment.index_node_id for segment in segments] if index_node_ids: - keyword_table_index.del_nodes(index_node_ids) + kw_index.delete_by_ids(index_node_ids) end_at = time.perf_counter() logging.info( diff --git 
a/api/tasks/remove_segment_from_index_task.py b/api/tasks/remove_segment_from_index_task.py index 48cebfc4d1..821d7dc3a7 100644 --- a/api/tasks/remove_segment_from_index_task.py +++ b/api/tasks/remove_segment_from_index_task.py @@ -5,8 +5,7 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.index.keyword_table_index import KeywordTableIndex -from core.index.vector_index import VectorIndex +from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import DocumentSegment @@ -36,17 +35,28 @@ def remove_segment_from_index_task(segment_id: str): dataset = segment.dataset if not dataset: - raise Exception('Segment has no dataset') + logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) + return - vector_index = VectorIndex(dataset=dataset) - keyword_table_index = KeywordTableIndex(dataset=dataset) + dataset_document = segment.document + + if not dataset_document: + logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan')) + return + + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': + logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) + return + + vector_index = IndexBuilder.get_index(dataset, 'high_quality') + kw_index = IndexBuilder.get_index(dataset, 'economy') # delete from vector index - if dataset.indexing_technique == "high_quality": - vector_index.del_nodes([segment.index_node_id]) + if vector_index: + vector_index.delete_by_ids([segment.index_node_id]) # delete from keyword index - keyword_table_index.del_nodes([segment.index_node_id]) + kw_index.delete_by_ids([segment.index_node_id]) end_at = time.perf_counter() logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
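
Reviewer note: the task refactors above all converge on the same indexing pattern (build langchain Documents from segments, then add to or delete from the indexes returned by IndexBuilder.get_index), so a condensed sketch follows for reference. It assumes IndexBuilder.get_index returns a falsy value when the requested technique does not apply to the dataset, which is what the "if index:" guards in the diff imply; the helper names segment_to_document, add_segments_to_indexes and remove_document_from_indexes are illustrative only and do not exist in the codebase.

from langchain.schema import Document

from core.index.index import IndexBuilder


def segment_to_document(segment) -> Document:
    # Per-segment payload used throughout the refactored tasks: the segment text
    # plus the identifiers the indexes need for later lookups and deletes.
    return Document(
        page_content=segment.content,
        metadata={
            "doc_id": segment.index_node_id,
            "doc_hash": segment.index_node_hash,
            "document_id": segment.document_id,
            "dataset_id": segment.dataset_id,
        }
    )


def add_segments_to_indexes(dataset, segments):
    documents = [segment_to_document(s) for s in segments]

    # Vector index: only available for 'high_quality' datasets, so guard on a
    # falsy return value before adding.
    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
    if vector_index:
        vector_index.add_texts(documents, duplicate_check=True)

    # Keyword index ('economy'); guarded the same way for symmetry.
    kw_index = IndexBuilder.get_index(dataset, 'economy')
    if kw_index:
        kw_index.add_texts(documents)


def remove_document_from_indexes(dataset, document_id, index_node_ids):
    # Vector index supports deleting everything belonging to a document id.
    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
    if vector_index:
        vector_index.delete_by_document_id(document_id)

    # Keyword index deletes by the individual segment node ids.
    kw_index = IndexBuilder.get_index(dataset, 'economy')
    if kw_index:
        kw_index.delete_by_ids(index_node_ids)

Usage mirrors the tasks above: the add_* tasks build the Document list and call add_texts on whichever indexes exist, while the clean/remove tasks resolve both indexes and delete by document id (vector) or by node ids (keyword) before deleting the DB rows.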