feat: upgrade langchain (#430)

Co-authored-by: jyong <718720800@qq.com>

parent 1dee5de9b4
commit 3241e4015b
@@ -14,7 +14,7 @@ from flask import Flask, request, Response, session
 import flask_login
 from flask_cors import CORS
 
-from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_vector_store, ext_migrate, \
+from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
     ext_database, ext_storage
 from extensions.ext_database import db
 from extensions.ext_login import login_manager
@@ -79,7 +79,6 @@ def initialize_extensions(app):
     ext_database.init_app(app)
     ext_migrate.init(app, db)
     ext_redis.init_app(app)
-    ext_vector_store.init_app(app)
     ext_storage.init_app(app)
     ext_celery.init_app(app)
     ext_session.init_app(app)
@@ -1,15 +1,19 @@
 import datetime
+import logging
 import random
 import string
 
 import click
 from flask import current_app
+from werkzeug.exceptions import NotFound
 
+from core.index.index import IndexBuilder
 from libs.password import password_pattern, valid_password, hash_password
 from libs.helper import email as email_validate
 from extensions.ext_database import db
 from libs.rsa import generate_key_pair
 from models.account import InvitationCode, Tenant
+from models.dataset import Dataset
 from models.model import Account
 import secrets
 import base64
@@ -159,8 +163,39 @@ def generate_upper_string():
     return result
 
 
+@click.command('recreate-all-dataset-indexes', help='Recreate all dataset indexes.')
+def recreate_all_dataset_indexes():
+    click.echo(click.style('Start recreate all dataset indexes.', fg='green'))
+    recreate_count = 0
+
+    page = 1
+    while True:
+        try:
+            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality')\
+                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
+        except NotFound:
+            break
+
+        page += 1
+        for dataset in datasets:
+            try:
+                click.echo('Recreating dataset index: {}'.format(dataset.id))
+                index = IndexBuilder.get_index(dataset, 'high_quality')
+                if index and index._is_origin():
+                    index.recreate_dataset(dataset)
+                    recreate_count += 1
+                else:
+                    click.echo('passed.')
+            except Exception as e:
+                click.echo(click.style('Recreate dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+                continue
+
+    click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
+
+
 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
     app.cli.add_command(generate_invitation_codes)
     app.cli.add_command(reset_encrypt_key_pair)
+    app.cli.add_command(recreate_all_dataset_indexes)
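Note: the new recreate-all-dataset-indexes command is registered on the Flask CLI, so it can also be exercised from Python through Flask's test CLI runner. A minimal sketch, assuming create_app is the factory that ends up calling register_commands (the import path is illustrative):

from app import create_app  # illustrative import; use the project's actual app factory

app = create_app()
runner = app.test_cli_runner()

# invokes the click command exactly as `flask recreate-all-dataset-indexes` would
result = runner.invoke(args=["recreate-all-dataset-indexes"])
print(result.output)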
@@ -187,11 +187,13 @@ class Config:
         # For temp use only
         # set default LLM provider, default is 'openai', support `azure_openai`
         self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
 
         # notion import setting
         self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
         self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
         self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
         self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
+        self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
+
 
 class CloudEditionConfig(Config):
@@ -10,11 +10,10 @@ from werkzeug.exceptions import NotFound
 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.data_source.notion import NotionPageReader
+from core.data_loader.loader.notion import NotionLoader
 from core.indexing_runner import IndexingRunner
 from extensions.ext_database import db
 from libs.helper import TimestampField
-from libs.oauth_data_source import NotionOAuth
 from models.dataset import Document
 from models.source import DataSourceBinding
 from services.dataset_service import DatasetService, DocumentService
@@ -232,15 +231,17 @@ class DataSourceNotionApi(Resource):
         ).first()
         if not data_source_binding:
             raise NotFound('Data source binding not found.')
-        reader = NotionPageReader(integration_token=data_source_binding.access_token)
-        if page_type == 'page':
-            page_content = reader.read_page(page_id)
-        elif page_type == 'database':
-            page_content = reader.query_database_data(page_id)
-        else:
-            page_content = ""
+
+        loader = NotionLoader(
+            notion_access_token=data_source_binding.access_token,
+            notion_workspace_id=workspace_id,
+            notion_obj_id=page_id,
+            notion_page_type=page_type
+        )
+
+        text_docs = loader.load()
         return {
-            'content': page_content
+            'content': "\n".join([doc.page_content for doc in text_docs])
         }, 200
 
     @setup_required
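Note: the loader now returns plain langchain documents instead of a llama_index response, and the endpoint simply concatenates their text. A minimal sketch of that shape (contents are illustrative):

from langchain.schema import Document

text_docs = [
    Document(page_content="Page intro", metadata={"page_id": "abc"}),
    Document(page_content="Section 1 body", metadata={"page_id": "abc"}),
]
preview = "\n".join(doc.page_content for doc in text_docs)
print(preview)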
@@ -17,9 +17,7 @@ from controllers.console.datasets.error import NoFileUploadedError, TooManyFiles
     UnsupportedFileTypeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.index.readers.html_parser import HTMLParser
-from core.index.readers.pdf_parser import PDFParser
-from core.index.readers.xlsx_parser import XLSXParser
+from core.data_loader.file_extractor import FileExtractor
 from extensions.ext_storage import storage
 from libs.helper import TimestampField
 from extensions.ext_database import db
@@ -123,31 +121,7 @@ class FilePreviewApi(Resource):
         if extension not in ALLOWED_EXTENSIONS:
             raise UnsupportedFileTypeError()
 
-        with tempfile.TemporaryDirectory() as temp_dir:
-            suffix = Path(upload_file.key).suffix
-            filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-            storage.download(upload_file.key, filepath)
-
-            if extension == 'pdf':
-                parser = PDFParser({'upload_file': upload_file})
-                text = parser.parse_file(Path(filepath))
-            elif extension in ['html', 'htm']:
-                # Use BeautifulSoup to extract text
-                parser = HTMLParser()
-                text = parser.parse_file(Path(filepath))
-            elif extension == 'xlsx':
-                parser = XLSXParser()
-                text = parser.parse_file(filepath)
-            else:
-                # ['txt', 'markdown', 'md']
-                with open(filepath, "rb") as fp:
-                    data = fp.read()
-                encoding = chardet.detect(data)['encoding']
-                if encoding:
-                    text = data.decode(encoding=encoding).strip() if data else ''
-                else:
-                    text = data.decode(encoding='utf-8').strip() if data else ''
+        text = FileExtractor.load(upload_file, return_text=True)
 
         text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
         return {'content': text}
@@ -32,8 +32,13 @@ class VersionApi(Resource):
                 'current_version': args.get('current_version')
             })
         except Exception as error:
-            logging.exception("Check update error.")
-            raise InternalServerError()
+            logging.warning("Check update version error: {}.".format(str(error)))
+            return {
+                'version': args.get('current_version'),
+                'release_date': '',
+                'release_notes': '',
+                'can_auto_update': False
+            }
 
         content = json.loads(response.content)
         return {
@@ -3,19 +3,11 @@ from typing import Optional
 
 import langchain
 from flask import Flask
-from jieba.analyse import default_tfidf
-from langchain import set_handler
 from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
-from llama_index import IndexStructType, QueryMode
-from llama_index.indices.registry import INDEX_STRUT_TYPE_TO_QUERY_MAP
 from pydantic import BaseModel
 
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
-from core.index.keyword_table.stopwords import STOPWORDS
 from core.prompt.prompt_template import OneLineFormatter
-from core.vector_store.vector_store import VectorStore
-from core.vector_store.vector_store_index_query import EnhanceGPTVectorStoreIndexQuery
 
 
 class HostedOpenAICredential(BaseModel):
@@ -32,21 +24,9 @@ hosted_llm_credentials = HostedLLMCredentials()
 def init_app(app: Flask):
     formatter = OneLineFormatter()
     DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.KEYWORD_TABLE] = GPTJIEBAKeywordTableIndex.get_query_map()
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.WEAVIATE] = {
-        QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
-        QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
-    }
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.QDRANT] = {
-        QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
-        QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
-    }
-
-    default_tfidf.stop_words = STOPWORDS
 
     if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
         langchain.verbose = True
-        set_handler(DifyStdOutCallbackHandler())
 
     if app.config.get("OPENAI_API_KEY"):
         hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
@@ -2,7 +2,7 @@ from typing import Optional
 
 from langchain import LLMChain
 from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
-from langchain.callbacks import CallbackManager
+from langchain.callbacks.manager import CallbackManager
 from langchain.memory.chat_memory import BaseChatMemory
 
 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
@@ -16,23 +16,20 @@ class AgentBuilder:
     def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
                        dataset_tool_callback_handler: DatasetToolCallbackHandler,
                        agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
-        llm_callback_manager = CallbackManager([agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()])
         llm = LLMBuilder.to_llm(
             tenant_id=tenant_id,
             model_name=agent_loop_gather_callback_handler.model_name,
             temperature=0,
             max_tokens=1024,
-            callback_manager=llm_callback_manager
+            callbacks=[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
         )
 
-        tool_callback_manager = CallbackManager([
-            agent_loop_gather_callback_handler,
-            dataset_tool_callback_handler,
-            DifyStdOutCallbackHandler()
-        ])
-
         for tool in tools:
-            tool.callback_manager = tool_callback_manager
+            tool.callbacks = [
+                agent_loop_gather_callback_handler,
+                dataset_tool_callback_handler,
+                DifyStdOutCallbackHandler()
+            ]
 
         prompt = cls.build_agent_prompt_template(
             tools=tools,
@@ -54,7 +51,7 @@ class AgentBuilder:
             tools=tools,
             agent=agent,
             memory=memory,
-            callback_manager=agent_callback_manager,
+            callbacks=agent_callback_manager,
             max_iterations=6,
             early_stopping_method="generate",
             # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
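Note: the recurring pattern in this commit is that the upgraded langchain no longer wants handlers wrapped in a CallbackManager; constructors and setters take a plain list via callbacks=[...]. A minimal sketch against langchain's own chat model (the handler and model choice are illustrative, and this assumes the pre-0.1 langchain package layout):

from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models import ChatOpenAI


class LoggingHandler(BaseCallbackHandler):
    """Illustrative handler: report how many prompts each LLM call receives."""

    def on_llm_start(self, serialized, prompts, **kwargs):
        print(f"LLM call starting with {len(prompts)} prompt(s)")


# old style: callback_manager=CallbackManager([LoggingHandler()])
# new style: a plain list of handlers
llm = ChatOpenAI(temperature=0, callbacks=[LoggingHandler()])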
@@ -12,6 +12,7 @@ from core.conversation_message_task import ConversationMessageTask
 
 class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
+    raise_error: bool = True
 
     def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
@@ -64,10 +65,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.completion = response.generations[0][0].text
             self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
 
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
     def on_llm_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
@@ -75,21 +72,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         self._agent_loops = []
         self._current_loop = None
 
-    def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
-        """Print out that we are entering a chain."""
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        """Print out that we finished a chain."""
-        pass
-
-    def on_chain_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        logging.error(error)
-
     def on_tool_start(
         self,
         serialized: Dict[str, Any],
@@ -151,16 +133,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         self._agent_loops = []
         self._current_loop = None
 
-    def on_text(
-        self,
-        text: str,
-        color: Optional[str] = None,
-        end: str = "",
-        **kwargs: Optional[str],
-    ) -> None:
-        """Run on additional input from chains and agents."""
-        pass
-
     def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
         """Run on agent end."""
        # Final Answer
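Note: the deleted methods above were no-op overrides; the upgraded BaseCallbackHandler already provides do-nothing defaults for every event, so only the events a handler cares about need overriding, and raise_error = True asks langchain to propagate exceptions raised inside the handler instead of swallowing them. A minimal sketch:

from typing import Any, Dict

from langchain.callbacks.base import BaseCallbackHandler


class ToolTimingHandler(BaseCallbackHandler):
    """Illustrative handler: only tool events are overridden; everything else
    falls back to the base class's default no-ops."""

    raise_error = True  # let handler errors bubble up instead of being logged and ignored

    def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> None:
        print("tool starting:", input_str[:50])

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        print("tool finished:", output[:50])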
@@ -3,7 +3,6 @@ import logging
 from typing import Any, Dict, List, Union, Optional
 
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult
 
 from core.callback_handler.entity.dataset_query import DatasetQueryObj
 from core.conversation_message_task import ConversationMessageTask
@@ -11,6 +10,7 @@ from core.conversation_message_task import ConversationMessageTask
 
 class DatasetToolCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
+    raise_error: bool = True
 
     def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
@@ -66,52 +66,3 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
     ) -> None:
         """Do nothing."""
         logging.error(error)
-
-    def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        pass
-
-    def on_chain_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        pass
-
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
-    def on_llm_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        logging.error(error)
-
-    def on_agent_action(
-        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
-    ) -> Any:
-        pass
-
-    def on_text(
-        self,
-        text: str,
-        color: Optional[str] = None,
-        end: str = "",
-        **kwargs: Optional[str],
-    ) -> None:
-        """Run on additional input from chains and agents."""
-        pass
-
-    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
-        """Run on agent end."""
-        pass
@@ -1,39 +1,26 @@
-from llama_index import Response
+from typing import List
+
+from langchain.schema import Document
 
 from extensions.ext_database import db
 from models.dataset import DocumentSegment
 
 
-class IndexToolCallbackHandler:
-
-    def __init__(self) -> None:
-        self._response = None
-
-    @property
-    def response(self) -> Response:
-        return self._response
-
-    def on_tool_end(self, response: Response) -> None:
-        """Handle tool end."""
-        self._response = response
-
-
-class DatasetIndexToolCallbackHandler(IndexToolCallbackHandler):
+class DatasetIndexToolCallbackHandler:
     """Callback handler for dataset tool."""
 
     def __init__(self, dataset_id: str) -> None:
-        super().__init__()
         self.dataset_id = dataset_id
 
-    def on_tool_end(self, response: Response) -> None:
+    def on_tool_end(self, documents: List[Document]) -> None:
         """Handle tool end."""
-        for node in response.source_nodes:
-            index_node_id = node.node.doc_id
+        for document in documents:
+            doc_id = document.metadata['doc_id']
 
             # add hit count to document segment
             db.session.query(DocumentSegment).filter(
                 DocumentSegment.dataset_id == self.dataset_id,
-                DocumentSegment.index_node_id == index_node_id
+                DocumentSegment.index_node_id == doc_id
             ).update(
                 {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
                 synchronize_session=False
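Note: the hit-count handler now consumes the retriever's output as plain langchain Documents whose metadata carries the segment's doc_id. A minimal sketch of the call shape (ids are illustrative; DatasetIndexToolCallbackHandler is the class shown in the hunk above):

from langchain.schema import Document

docs = [Document(page_content="segment text", metadata={"doc_id": "node-123"})]
handler = DatasetIndexToolCallbackHandler(dataset_id="dataset-abc")
handler.on_tool_end(docs)  # bumps hit_count on DocumentSegment rows whose index_node_id matches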
@@ -3,7 +3,7 @@ import time
 from typing import Any, Dict, List, Union, Optional
 
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage
+from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage, BaseMessage
 
 from core.callback_handler.entity.llm_message import LLMMessage
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
@@ -12,6 +12,7 @@ from core.llm.streamable_open_ai import StreamableOpenAI
 
 
 class LLMCallbackHandler(BaseCallbackHandler):
+    raise_error: bool = True
 
     def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                  conversation_message_task: ConversationMessageTask):
@@ -25,41 +26,41 @@ class LLMCallbackHandler(BaseCallbackHandler):
         """Whether to call verbose callbacks even if verbose is False."""
         return True
 
+    def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            **kwargs: Any
+    ) -> Any:
+        self.start_at = time.perf_counter()
+        real_prompts = []
+        for message in messages[0]:
+            if message.type == 'human':
+                role = 'user'
+            elif message.type == 'ai':
+                role = 'assistant'
+            else:
+                role = 'system'
+
+            real_prompts.append({
+                "role": role,
+                "text": message.content
+            })
+
+        self.llm_message.prompt = real_prompts
+        self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
+
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         self.start_at = time.perf_counter()
 
-        if 'Chat' in serialized['name']:
-            real_prompts = []
-            messages = []
-            for prompt in prompts:
-                role, content = prompt.split(': ', maxsplit=1)
-                if role == 'human':
-                    role = 'user'
-                    message = HumanMessage(content=content)
-                elif role == 'ai':
-                    role = 'assistant'
-                    message = AIMessage(content=content)
-                else:
-                    message = SystemMessage(content=content)
-
-                real_prompt = {
-                    "role": role,
-                    "text": content
-                }
-                real_prompts.append(real_prompt)
-                messages.append(message)
-
-            self.llm_message.prompt = real_prompts
-            self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages)
-        else:
-            self.llm_message.prompt = [{
-                "role": 'user',
-                "text": prompts[0]
-            }]
-
-            self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
+        self.llm_message.prompt = [{
+            "role": 'user',
+            "text": prompts[0]
+        }]
+
+        self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
 
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         end_at = time.perf_counter()
@@ -95,58 +96,3 @@ class LLMCallbackHandler(BaseCallbackHandler):
             self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
         else:
             logging.error(error)
-
-    def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        pass
-
-    def on_chain_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_tool_start(
-        self,
-        serialized: Dict[str, Any],
-        input_str: str,
-        **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_agent_action(
-        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
-    ) -> Any:
-        pass
-
-    def on_tool_end(
-        self,
-        output: str,
-        color: Optional[str] = None,
-        observation_prefix: Optional[str] = None,
-        llm_prefix: Optional[str] = None,
-        **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_tool_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_text(
-        self,
-        text: str,
-        color: Optional[str] = None,
-        end: str = "",
-        **kwargs: Optional[str],
-    ) -> None:
-        pass
-
-    def on_agent_finish(
-        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
-    ) -> None:
-        pass
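Note: with the upgrade, chat prompts no longer arrive in on_llm_start as "role: content" strings that need re-parsing; langchain calls on_chat_model_start with a batch of message lists, so messages[0] is the first request's conversation and each message's type maps onto the stored role. A minimal sketch of that payload shape:

from langchain.schema import HumanMessage, SystemMessage

# outer list = batch of requests, inner list = one conversation
messages = [[SystemMessage(content="You are helpful."), HumanMessage(content="Hi there")]]

for message in messages[0]:
    role = {'human': 'user', 'ai': 'assistant'}.get(message.type, 'system')
    print(role, "->", message.content)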
@@ -1,10 +1,9 @@
 import logging
 import time
 
-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult
 
 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.callback_handler.entity.chain_result import ChainResult
@@ -14,6 +13,7 @@ from core.conversation_message_task import ConversationMessageTask
 
 class MainChainGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
+    raise_error: bool = True
 
     def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
@@ -50,13 +50,15 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
     ) -> None:
         """Print out that we are entering a chain."""
         if not self._current_chain_result:
-            self._current_chain_result = ChainResult(
-                type=serialized['name'],
-                prompt=inputs,
-                started_at=time.perf_counter()
-            )
-            self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
-            self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
+            chain_type = serialized['id'][-1]
+            if chain_type:
+                self._current_chain_result = ChainResult(
+                    type=chain_type,
+                    prompt=inputs,
+                    started_at=time.perf_counter()
+                )
+                self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
+                self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
 
     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain."""
@@ -74,64 +76,4 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
         logging.error(error)
         self.clear_chain_results()
-
-    def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        pass
-
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
-    def on_llm_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        logging.error(error)
-
-    def on_tool_start(
-        self,
-        serialized: Dict[str, Any],
-        input_str: str,
-        **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_agent_action(
-        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
-    ) -> Any:
-        pass
-
-    def on_tool_end(
-        self,
-        output: str,
-        color: Optional[str] = None,
-        observation_prefix: Optional[str] = None,
-        llm_prefix: Optional[str] = None,
-        **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_tool_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        logging.error(error)
-
-    def on_text(
-        self,
-        text: str,
-        color: Optional[str] = None,
-        end: str = "",
-        **kwargs: Optional[str],
-    ) -> None:
-        """Run on additional input from chains and agents."""
-        pass
-
-    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
-        """Run on agent end."""
-        pass
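Note: the switch from serialized['name'] to serialized['id'][-1] follows the upgraded langchain's serialization format, where 'id' carries the fully qualified class path and its last element is the class name. An illustrative payload (other keys omitted):

# shape of the `serialized` argument passed to on_chain_start in this langchain version
serialized = {"id": ["langchain", "chains", "llm", "LLMChain"]}
chain_type = serialized["id"][-1]
print(chain_type)  # -> LLMChain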
@@ -1,9 +1,10 @@
+import os
 import sys
 from typing import Any, Dict, List, Optional, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.input import print_text
-from langchain.schema import AgentAction, AgentFinish, LLMResult
+from langchain.schema import AgentAction, AgentFinish, LLMResult, BaseMessage
 
 
 class DifyStdOutCallbackHandler(BaseCallbackHandler):
@@ -13,17 +14,23 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
         """Initialize callback handler."""
         self.color = color
 
+    def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            **kwargs: Any
+    ) -> Any:
+        print_text("\n[on_chat_model_start]\n", color='blue')
+        for sub_messages in messages:
+            for sub_message in sub_messages:
+                print_text(str(sub_message) + "\n", color='blue')
+
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         """Print out the prompts."""
         print_text("\n[on_llm_start]\n", color='blue')
-        if 'Chat' in serialized['name']:
-            for prompt in prompts:
-                print_text(prompt + "\n", color='blue')
-        else:
-            print_text(prompts[0] + "\n", color='blue')
+        print_text(prompts[0] + "\n", color='blue')
 
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Do nothing."""
@@ -44,8 +51,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
         self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
     ) -> None:
         """Print out that we are entering a chain."""
-        class_name = serialized["name"]
-        print_text("\n[on_chain_start]\nChain: " + class_name + "\nInputs: " + str(inputs) + "\n", color='pink')
+        chain_type = serialized['id'][-1]
+        print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')
 
     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain."""
@@ -117,6 +124,26 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
         """Run on agent end."""
         print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n")
 
+    @property
+    def ignore_llm(self) -> bool:
+        """Whether to ignore LLM callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
+    @property
+    def ignore_chain(self) -> bool:
+        """Whether to ignore chain callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
+    @property
+    def ignore_agent(self) -> bool:
+        """Whether to ignore agent callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
+    @property
+    def ignore_chat_model(self) -> bool:
+        """Whether to ignore chat model callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
 
 class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler):
     """Callback handler for streaming. Only works with LLMs that support streaming."""
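Note: ignore_llm, ignore_chain, ignore_agent and ignore_chat_model are properties the callback manager consults before dispatching events, so the handler above stays silent unless DEBUG=true. A minimal sketch of the same idea in a standalone handler:

import os
from typing import Any, Dict, List

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import BaseMessage


class DebugOnlyChatLogger(BaseCallbackHandler):
    """Illustrative handler: chat-model events are logged only when DEBUG=true."""

    def on_chat_model_start(self, serialized: Dict[str, Any],
                            messages: List[List[BaseMessage]], **kwargs: Any) -> Any:
        for conversation in messages:
            for message in conversation:
                print(str(message))

    @property
    def ignore_chat_model(self) -> bool:
        # True tells langchain to skip chat-model callbacks for this handler
        return os.environ.get("DEBUG", "").lower() != "true"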
@@ -1,7 +1,5 @@
 from typing import Optional
 
-from langchain.callbacks import CallbackManager
-
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
 from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
 from core.chain.tool_chain import ToolChain
@@ -14,7 +12,7 @@ class ChainBuilder:
             tool=tool,
             input_key=kwargs.get('input_key', 'input'),
             output_key=kwargs.get('output_key', 'tool_output'),
-            callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
+            callbacks=[DifyStdOutCallbackHandler()]
         )
 
     @classmethod
@@ -27,7 +25,7 @@ class ChainBuilder:
             sensitive_words=sensitive_words.split(","),
             canned_response=tool_config.get("canned_response", ''),
             output_key="sensitive_word_avoidance_output",
-            callback_manager=CallbackManager([DifyStdOutCallbackHandler()]),
+            callbacks=[DifyStdOutCallbackHandler()],
             **kwargs
         )
 
@@ -1,15 +1,16 @@
 """Base classes for LLM-powered router chains."""
 from __future__ import annotations
 
-import json
 from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
 
+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from pydantic import root_validator
 
 from langchain.chains import LLMChain
 from langchain.prompts import BasePromptTemplate
-from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
+from langchain.schema import BaseOutputParser, OutputParserException
 
 from libs.json_in_md_parser import parse_and_check_json_markdown
 
@@ -51,8 +52,9 @@ class LLMRouterChain(Chain):
             raise ValueError
 
     def _call(
         self,
-        inputs: Dict[str, Any]
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         output = cast(
             Dict[str, Any],
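Note: the extra run_manager parameter is how the upgraded langchain hands chains a per-run callback manager, so every custom Chain subclass in this commit adopts the new _call signature. A minimal self-contained sketch (the chain itself is illustrative):

from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain


class EchoChain(Chain):
    """Illustrative custom chain using the post-upgrade _call signature."""

    @property
    def input_keys(self) -> List[str]:
        return ["input"]

    @property
    def output_keys(self) -> List[str]:
        return ["output"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        # run_manager is supplied by langchain and can emit child callbacks
        return {"output": inputs["input"]}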
@@ -1,11 +1,9 @@
-from typing import Optional, List
+from typing import Optional, List, cast
 
-from langchain.callbacks import SharedCallbackManager, CallbackManager
 from langchain.chains import SequentialChain
 from langchain.chains.base import Chain
 from langchain.memory.chat_memory import BaseChatMemory
 
-from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
 from core.chain.chain_builder import ChainBuilder
@@ -18,6 +16,7 @@ from models.dataset import Dataset
 class MainChainBuilder:
     @classmethod
     def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
+                                rest_tokens: int,
                                 conversation_message_task: ConversationMessageTask):
         first_input_key = "input"
         final_output_key = "output"
@@ -30,6 +29,7 @@ class MainChainBuilder:
         tool_chains, chains_output_key = cls.get_agent_chains(
             tenant_id=tenant_id,
             agent_mode=agent_mode,
+            rest_tokens=rest_tokens,
             memory=memory,
             conversation_message_task=conversation_message_task
         )
@@ -42,9 +42,8 @@ class MainChainBuilder:
             return None
 
         for chain in chains:
-            # do not add handler into singleton callback manager
-            if not isinstance(chain.callback_manager, SharedCallbackManager):
-                chain.callback_manager.add_handler(chain_callback_handler)
+            chain = cast(Chain, chain)
+            chain.callbacks.append(chain_callback_handler)
 
         # build main chain
         overall_chain = SequentialChain(
@@ -57,7 +56,9 @@ class MainChainBuilder:
         return overall_chain
 
     @classmethod
-    def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
+    def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
+                         rest_tokens: int,
+                         memory: Optional[BaseChatMemory],
                          conversation_message_task: ConversationMessageTask):
         # agent mode
         chains = []
@@ -93,7 +94,8 @@ class MainChainBuilder:
                 tenant_id=tenant_id,
                 datasets=datasets,
                 conversation_message_task=conversation_message_task,
-                callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
+                rest_tokens=rest_tokens,
+                callbacks=[DifyStdOutCallbackHandler()]
             )
             chains.append(multi_dataset_router_chain)
 
@@ -1,9 +1,9 @@
+import math
 from typing import Mapping, List, Dict, Any, Optional
 
-from langchain import LLMChain, PromptTemplate, ConversationChain
-from langchain.callbacks import CallbackManager
+from langchain import PromptTemplate
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
-from langchain.schema import BaseLanguageModel
 from pydantic import Extra
 
 from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
@@ -11,10 +11,11 @@ from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHan
 from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
 from core.conversation_message_task import ConversationMessageTask
 from core.llm.llm_builder import LLMBuilder
-from core.tool.dataset_tool_builder import DatasetToolBuilder
-from core.tool.llama_index_tool import EnhanceLlamaIndexTool
-from models.dataset import Dataset
+from core.tool.dataset_index_tool import DatasetTool
+from models.dataset import Dataset, DatasetProcessRule
 
+DEFAULT_K = 2
+CONTEXT_TOKENS_PERCENT = 0.3
 MULTI_PROMPT_ROUTER_TEMPLATE = """
 Given a raw text input to a language model select the model prompt best suited for \
 the input. You will be given the names of the available prompts and a description of \
@@ -52,7 +53,7 @@ class MultiDatasetRouterChain(Chain):
 
     router_chain: LLMRouterChain
     """Chain for deciding a destination chain and the input to it."""
-    dataset_tools: Mapping[str, EnhanceLlamaIndexTool]
+    dataset_tools: Mapping[str, DatasetTool]
     """Map of name to candidate chains that inputs can be routed to."""
 
     class Config:
@@ -79,41 +80,56 @@ class MultiDatasetRouterChain(Chain):
             tenant_id: str,
             datasets: List[Dataset],
             conversation_message_task: ConversationMessageTask,
+            rest_tokens: int,
             **kwargs: Any,
     ):
         """Convenience constructor for instantiating from destination prompts."""
-        llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
         llm = LLMBuilder.to_llm(
             tenant_id=tenant_id,
             model_name='gpt-3.5-turbo',
             temperature=0,
             max_tokens=1024,
-            callback_manager=llm_callback_manager
+            callbacks=[DifyStdOutCallbackHandler()]
        )
 
-        destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ') if d.description
+        destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
                         else ('useful for when you want to answer queries about the ' + d.name))
                         for d in datasets]
         destinations_str = "\n".join(destinations)
         router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
             destinations=destinations_str
         )
 
         router_prompt = PromptTemplate(
             template=router_template,
             input_variables=["input"],
             output_parser=RouterOutputParser(),
         )
 
         router_chain = LLMRouterChain.from_llm(llm, router_prompt)
         dataset_tools = {}
         for dataset in datasets:
-            dataset_tool = DatasetToolBuilder.build_dataset_tool(
+            # fulfill description when it is empty
+            if dataset.available_document_count == 0 or dataset.available_document_count == 0:
+                continue
+
+            description = dataset.description
+            if not description:
+                description = 'useful for when you want to answer queries about the ' + dataset.name
+
+            k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
+            if k == 0:
+                continue
+
+            dataset_tool = DatasetTool(
+                name=f"dataset-{dataset.id}",
+                description=description,
+                k=k,
                 dataset=dataset,
-                response_mode='no_synthesizer',  # "compact"
-                callback_handler=DatasetToolCallbackHandler(conversation_message_task)
+                callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
             )
 
-            if dataset_tool:
-                dataset_tools[dataset.id] = dataset_tool
+            dataset_tools[str(dataset.id)] = dataset_tool
 
         return cls(
             router_chain=router_chain,
@@ -121,9 +137,39 @@ class MultiDatasetRouterChain(Chain):
             **kwargs,
         )
 
+    @classmethod
+    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
+        processing_rule = dataset.latest_process_rule
+        if not processing_rule:
+            return DEFAULT_K
+
+        if processing_rule.mode == "custom":
+            rules = processing_rule.rules_dict
+            if not rules:
+                return DEFAULT_K
+
+            segmentation = rules["segmentation"]
+            segment_max_tokens = segmentation["max_tokens"]
+        else:
+            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
+
+        # when rest_tokens is less than default context tokens
+        if rest_tokens < segment_max_tokens * DEFAULT_K:
+            return rest_tokens // segment_max_tokens
+
+        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
+
+        # when context_limit_tokens is less than default context tokens, use default_k
+        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
+            return DEFAULT_K
+
+        # Expand the k value when there's still some room left in the 30% rest tokens space
+        return context_limit_tokens // segment_max_tokens
+
     def _call(
         self,
-        inputs: Dict[str, Any]
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
         if len(self.dataset_tools) == 0:
             return {"text": ''}
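Note: the retrieval depth now adapts to the token budget left for context. A small standalone sketch of the same arithmetic, using segment_max_tokens=500 and the constants defined above:

import math

DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3

def calc_k(rest_tokens: int, segment_max_tokens: int) -> int:
    # mirrors _dynamic_calc_retrieve_k once segment_max_tokens is known
    if rest_tokens < segment_max_tokens * DEFAULT_K:
        return rest_tokens // segment_max_tokens
    context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
    if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
        return DEFAULT_K
    return context_limit_tokens // segment_max_tokens

print(calc_k(800, 500))    # 1 -> too little room, shrink below the default
print(calc_k(2000, 500))   # 2 -> the 30% budget (600) holds under 2 segments, keep DEFAULT_K
print(calc_k(10000, 500))  # 6 -> the 30% budget (3000) fits 6 segments of 500 tokens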
@@ -1,5 +1,6 @@
-from typing import List, Dict
+from typing import List, Dict, Optional, Any
 
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 
 
@@ -36,7 +37,11 @@ class SensitiveWordAvoidanceChain(Chain):
                 return self.canned_response
         return text
 
-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         text = inputs[self.input_key]
         output = self._check_sensitive_word(text)
         return {self.output_key: output}
@@ -1,5 +1,6 @@
-from typing import List, Dict
+from typing import List, Dict, Optional, Any

+from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.tools import BaseTool


@@ -30,12 +31,20 @@ class ToolChain(Chain):
         """
         return [self.output_key]

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         input = inputs[self.input_key]
         output = self.tool.run(input, self.verbose)
         return {self.output_key: output}

-    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    async def _acall(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         """Run the logic of this chain and return the output."""
         input = inputs[self.input_key]
         output = await self.tool.arun(input, self.verbose)
@@ -1,17 +1,18 @@
 import logging
 from typing import Optional, List, Union, Tuple

-from langchain.callbacks import CallbackManager
+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.base import BaseCallbackHandler
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms import BaseLLM
-from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage
+from langchain.schema import BaseMessage, HumanMessage
 from requests.exceptions import ChunkedEncodingError

 from core.constant import llm_constant
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
     DifyStdOutCallbackHandler
-from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, PubHandler
+from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
 from core.llm.error import LLMBadRequestError
 from core.llm.llm_builder import LLMBuilder
 from core.chain.main_chain_builder import MainChainBuilder
@@ -34,8 +35,6 @@ class Completion:
         """
         errors: ProviderTokenNotInitError
         """
-        cls.validate_query_tokens(app.tenant_id, app_model_config, query)
-
         memory = None
         if conversation:
             # get memory of conversation (read-only)
@@ -48,6 +47,14 @@ class Completion:

             inputs = conversation.inputs

+        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
+            mode=app.mode,
+            tenant_id=app.tenant_id,
+            app_model_config=app_model_config,
+            query=query,
+            inputs=inputs
+        )
+
         conversation_message_task = ConversationMessageTask(
             task_id=task_id,
             app=app,
@@ -64,6 +71,7 @@ class Completion:
         main_chain = MainChainBuilder.to_langchain_components(
             tenant_id=app.tenant_id,
             agent_mode=app_model_config.agent_mode_dict,
+            rest_tokens=rest_tokens_for_context_and_memory,
             memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
             conversation_message_task=conversation_message_task
         )
@@ -115,7 +123,7 @@ class Completion:
             memory=memory
         )

-        final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task)
+        final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)

         cls.recale_llm_max_tokens(
             final_llm=final_llm,
@@ -247,16 +255,14 @@ And answer according to the language of the user's question.
         return messages, ['\nHuman:']

     @classmethod
-    def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
+    def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                           streaming: bool,
-                          conversation_message_task: ConversationMessageTask) -> CallbackManager:
+                          conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
         llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
         if streaming:
-            callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
+            return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
         else:
-            callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()]
-
-        return CallbackManager(callback_handlers)
+            return [llm_callback_handler, DifyStdOutCallbackHandler()]

     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
@@ -293,7 +299,8 @@ And answer according to the language of the user's question.
         return memory

     @classmethod
-    def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str):
+    def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
+                                 query: str, inputs: dict) -> int:
         llm = LLMBuilder.to_llm_from_model(
             tenant_id=tenant_id,
             model=app_model_config.model_dict
@@ -302,8 +309,26 @@ And answer according to the language of the user's question.
         model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
         max_tokens = llm.max_tokens

-        if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0:
-            raise LLMBadRequestError("Query is too long")
+        # get prompt without memory and context
+        prompt, _ = cls.get_main_llm_prompt(
+            mode=mode,
+            llm=llm,
+            pre_prompt=app_model_config.pre_prompt,
+            query=query,
+            inputs=inputs,
+            chain_output=None,
+            memory=None
+        )
+
+        prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
+            else llm.get_num_tokens_from_messages(prompt)
+
+        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
+        if rest_tokens < 0:
+            raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
+                                     "or shrink the max token, or switch to a llm with a larger token limit size.")
+
+        return rest_tokens

     @classmethod
     def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
@@ -360,7 +385,7 @@ And answer according to the language of the user's question.
             streaming=streaming
         )

-        llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task)
+        llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)

         cls.recale_llm_max_tokens(
             final_llm=llm,
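As a reading aid (not part of the commit): a standalone sketch of the token budgeting that get_validate_rest_tokens now performs, i.e. rest = model context limit - completion max_tokens - prompt tokens. The numbers are illustrative, not values taken from the diff.

def rest_tokens_for_context_and_memory(model_limit: int, max_completion_tokens: int, prompt_tokens: int) -> int:
    # tokens left over for retrieved context and conversation memory
    rest = model_limit - max_completion_tokens - prompt_tokens
    if rest < 0:
        raise ValueError("Query or prefix prompt is too long for this model")
    return rest

print(rest_tokens_for_context_and_memory(4096, 512, 900))  # 2684 tokens left for context and memory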
@@ -293,12 +293,12 @@ class PubHandler:
         if not user:
             raise ValueError("user is required")

-        user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
+        user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
         return "generate_result:{}-{}".format(user_str, task_id)

     @classmethod
     def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str):
-        user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
+        user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
         return "generate_result_stopped:{}-{}".format(user_str, task_id)

     def pub_text(self, text: str):
@@ -306,10 +306,10 @@ class PubHandler:
             'event': 'message',
             'data': {
                 'task_id': self._task_id,
-                'message_id': self._message.id,
+                'message_id': str(self._message.id),
                 'text': text,
                 'mode': self._conversation.mode,
-                'conversation_id': self._conversation.id
+                'conversation_id': str(self._conversation.id)
             }
         }
api/core/data_loader/file_extractor.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import tempfile
from pathlib import Path
from typing import List, Union

from langchain.document_loaders import TextLoader, Docx2txtLoader
from langchain.schema import Document

from core.data_loader.loader.csv import CSVLoader
from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from extensions.ext_storage import storage
from models.model import UploadFile


class FileExtractor:
    @classmethod
    def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document] | str]:
        with tempfile.TemporaryDirectory() as temp_dir:
            suffix = Path(upload_file.key).suffix
            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
            storage.download(upload_file.key, file_path)

            input_file = Path(file_path)
            delimiter = '\n'
            if input_file.suffix == '.xlsx':
                loader = ExcelLoader(file_path)
            elif input_file.suffix == '.pdf':
                loader = PdfLoader(file_path, upload_file=upload_file)
            elif input_file.suffix in ['.md', '.markdown']:
                loader = MarkdownLoader(file_path, autodetect_encoding=True)
            elif input_file.suffix in ['.htm', '.html']:
                loader = HTMLLoader(file_path)
            elif input_file.suffix == '.docx':
                loader = Docx2txtLoader(file_path)
            elif input_file.suffix == '.csv':
                loader = CSVLoader(file_path, autodetect_encoding=True)
            else:
                # txt
                loader = TextLoader(file_path, autodetect_encoding=True)

            return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
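As a reading aid (not part of the commit): a standalone sketch of the suffix-based dispatch FileExtractor.load uses above. Plain strings stand in for the loader classes so the snippet runs on its own.

from pathlib import Path

def pick_loader_name(file_path: str) -> str:
    # same suffix routing as FileExtractor.load above
    suffix = Path(file_path).suffix
    if suffix == '.xlsx':
        return 'ExcelLoader'
    if suffix == '.pdf':
        return 'PdfLoader'
    if suffix in ('.md', '.markdown'):
        return 'MarkdownLoader'
    if suffix in ('.htm', '.html'):
        return 'HTMLLoader'
    if suffix == '.docx':
        return 'Docx2txtLoader'
    if suffix == '.csv':
        return 'CSVLoader'
    return 'TextLoader'  # fallback: treat everything else as plain text

print(pick_loader_name('report.pdf'))   # PdfLoader
print(pick_loader_name('notes.txt'))    # TextLoader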
api/core/data_loader/loader/csv.py (new file, 67 lines)
@@ -0,0 +1,67 @@
import logging
from typing import Optional, Dict, List

from langchain.document_loaders import CSVLoader as LCCSVLoader
from langchain.document_loaders.helpers import detect_file_encodings

from models.dataset import Document

logger = logging.getLogger(__name__)


class CSVLoader(LCCSVLoader):
    def __init__(
            self,
            file_path: str,
            source_column: Optional[str] = None,
            csv_args: Optional[Dict] = None,
            encoding: Optional[str] = None,
            autodetect_encoding: bool = True,
    ):
        self.file_path = file_path
        self.source_column = source_column
        self.encoding = encoding
        self.csv_args = csv_args or {}
        self.autodetect_encoding = autodetect_encoding

    def load(self) -> List[Document]:
        """Load data into document objects."""
        try:
            with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
                docs = self._read_from_file(csvfile)
        except UnicodeDecodeError as e:
            if self.autodetect_encoding:
                detected_encodings = detect_file_encodings(self.file_path)
                for encoding in detected_encodings:
                    logger.debug("Trying encoding: ", encoding.encoding)
                    try:
                        with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile:
                            docs = self._read_from_file(csvfile)
                        break
                    except UnicodeDecodeError:
                        continue
            else:
                raise RuntimeError(f"Error loading {self.file_path}") from e

        return docs

    def _read_from_file(self, csvfile):
        docs = []
        csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
        for i, row in enumerate(csv_reader):
            content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
            try:
                source = (
                    row[self.source_column]
                    if self.source_column is not None
                    else ''
                )
            except KeyError:
                raise ValueError(
                    f"Source column '{self.source_column}' not found in CSV file."
                )
            metadata = {"source": source, "row": i}
            doc = Document(page_content=content, metadata=metadata)
            docs.append(doc)

        return docs
api/core/data_loader/loader/excel.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import json
import logging
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from openpyxl.reader.excel import load_workbook

logger = logging.getLogger(__name__)


class ExcelLoader(BaseLoader):
    """Load xlxs files.


    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str
    ):
        """Initialize with file path."""
        self._file_path = file_path

    def load(self) -> List[Document]:
        data = []
        keys = []
        wb = load_workbook(filename=self._file_path, read_only=True)
        # loop over all sheets
        for sheet in wb:
            for row in sheet.iter_rows(values_only=True):
                if all(v is None for v in row):
                    continue
                if keys == []:
                    keys = list(map(str, row))
                else:
                    row_dict = dict(zip(keys, row))
                    row_dict = {k: v for k, v in row_dict.items() if v}
                    data.append(json.dumps(row_dict, ensure_ascii=False))

        return [Document(page_content='\n\n'.join(data))]
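As a reading aid (not part of the commit): a standalone sketch of how ExcelLoader above flattens sheet rows, where the first non-empty row becomes the keys and each later row becomes a JSON object. The row data is illustrative.

import json

rows = [("name", "price"), ("apple", 3), ("pear", None)]
keys, data = [], []
for row in rows:
    if all(v is None for v in row):
        continue
    if not keys:
        keys = list(map(str, row))          # header row becomes the keys
    else:
        row_dict = {k: v for k, v in zip(keys, row) if v}
        data.append(json.dumps(row_dict, ensure_ascii=False))

print('\n\n'.join(data))  # {"name": "apple", "price": 3} then {"name": "pear"}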
api/core/data_loader/loader/html.py (new file, 35 lines)
@@ -0,0 +1,35 @@
import logging
from typing import List

from bs4 import BeautifulSoup
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

logger = logging.getLogger(__name__)


class HTMLLoader(BaseLoader):
    """Load html files.


    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str
    ):
        """Initialize with file path."""
        self._file_path = file_path

    def load(self) -> List[Document]:
        return [Document(page_content=self._load_as_text())]

    def _load_as_text(self) -> str:
        with open(self._file_path, "rb") as fp:
            soup = BeautifulSoup(fp, 'html.parser')
            text = soup.get_text()
            text = text.strip() if text else ''

        return text
api/core/data_loader/loader/markdown.py (new file, 134 lines)
@@ -0,0 +1,134 @@
import logging
import re
from typing import Optional, List, Tuple, cast

from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document

logger = logging.getLogger(__name__)


class MarkdownLoader(BaseLoader):
    """Load md files.


    Args:
        file_path: Path to the file to load.

        remove_hyperlinks: Whether to remove hyperlinks from the text.

        remove_images: Whether to remove images from the text.

        encoding: File encoding to use. If `None`, the file will be loaded
        with the default system encoding.

        autodetect_encoding: Whether to try to autodetect the file encoding
            if the specified encoding fails.
    """

    def __init__(
        self,
        file_path: str,
        remove_hyperlinks: bool = True,
        remove_images: bool = True,
        encoding: Optional[str] = None,
        autodetect_encoding: bool = True,
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._remove_hyperlinks = remove_hyperlinks
        self._remove_images = remove_images
        self._encoding = encoding
        self._autodetect_encoding = autodetect_encoding

    def load(self) -> List[Document]:
        tups = self.parse_tups(self._file_path)
        documents = []
        for header, value in tups:
            value = value.strip()
            if header is None:
                documents.append(Document(page_content=value))
            else:
                documents.append(Document(page_content=f"\n\n{header}\n{value}"))

        return documents

    def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
        """Convert a markdown file to a dictionary.

        The keys are the headers and the values are the text under each header.

        """
        markdown_tups: List[Tuple[Optional[str], str]] = []
        lines = markdown_text.split("\n")

        current_header = None
        current_text = ""

        for line in lines:
            header_match = re.match(r"^#+\s", line)
            if header_match:
                if current_header is not None:
                    markdown_tups.append((current_header, current_text))

                current_header = line
                current_text = ""
            else:
                current_text += line + "\n"
        markdown_tups.append((current_header, current_text))

        if current_header is not None:
            # pass linting, assert keys are defined
            markdown_tups = [
                (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value))
                for key, value in markdown_tups
            ]
        else:
            markdown_tups = [
                (key, re.sub("\n", "", value)) for key, value in markdown_tups
            ]

        return markdown_tups

    def remove_images(self, content: str) -> str:
        """Get a dictionary of a markdown file from its path."""
        pattern = r"!{1}\[\[(.*)\]\]"
        content = re.sub(pattern, "", content)
        return content

    def remove_hyperlinks(self, content: str) -> str:
        """Get a dictionary of a markdown file from its path."""
        pattern = r"\[(.*?)\]\((.*?)\)"
        content = re.sub(pattern, r"\1", content)
        return content

    def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]:
        """Parse file into tuples."""
        content = ""
        try:
            with open(filepath, "r", encoding=self._encoding) as f:
                content = f.read()
        except UnicodeDecodeError as e:
            if self._autodetect_encoding:
                detected_encodings = detect_file_encodings(filepath)
                for encoding in detected_encodings:
                    logger.debug("Trying encoding: ", encoding.encoding)
                    try:
                        with open(filepath, encoding=encoding.encoding) as f:
                            content = f.read()
                        break
                    except UnicodeDecodeError:
                        continue
            else:
                raise RuntimeError(f"Error loading {filepath}") from e
        except Exception as e:
            raise RuntimeError(f"Error loading {filepath}") from e

        if self._remove_hyperlinks:
            content = self.remove_hyperlinks(content)

        if self._remove_images:
            content = self.remove_images(content)

        return self.markdown_to_tups(content)
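As a reading aid (not part of the commit): a standalone sketch of the header/body split performed by markdown_to_tups above, with a tiny inline document as illustrative input.

import re

text = "# Title\nintro line\n## Section\nbody line\n"
tups, header, buf = [], None, ""
for line in text.split("\n"):
    if re.match(r"^#+\s", line):        # a markdown heading starts a new tuple
        if header is not None:
            tups.append((header, buf))
        header, buf = line, ""
    else:
        buf += line + "\n"
tups.append((header, buf))

print(tups)  # [('# Title', 'intro line\n'), ('## Section', 'body line\n\n')]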
@@ -1,67 +1,224 @@
-"""Notion reader."""
 import json
 import logging
-import os
-from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import List, Dict, Any, Optional

-import requests  # type: ignore
+import requests
+from flask import current_app
+from langchain.document_loaders.base import BaseLoader
+from langchain.schema import Document

-from llama_index.readers.base import BaseReader
-from llama_index.readers.schema.base import Document
+from extensions.ext_database import db
+from models.dataset import Document as DocumentModel
+from models.source import DataSourceBinding
+
+logger = logging.getLogger(__name__)

-INTEGRATION_TOKEN_NAME = "NOTION_INTEGRATION_TOKEN"
 BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
 DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
 SEARCH_URL = "https://api.notion.com/v1/search"
 RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
 RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
 HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
-logger = logging.getLogger(__name__)


-# TODO: Notion DB reader coming soon!
-class NotionPageReader(BaseReader):
-    """Notion Page reader.
-
-    Reads a set of Notion pages.
-
-    Args:
-        integration_token (str): Notion integration token.
-
-    """
-
-    def __init__(self, integration_token: Optional[str] = None) -> None:
-        """Initialize with parameters."""
-        if integration_token is None:
-            integration_token = os.getenv(INTEGRATION_TOKEN_NAME)
+class NotionLoader(BaseLoader):
+    def __init__(
+            self,
+            notion_access_token: str,
+            notion_workspace_id: str,
+            notion_obj_id: str,
+            notion_page_type: str,
+            document_model: Optional[DocumentModel] = None
+    ):
+        self._document_model = document_model
+        self._notion_workspace_id = notion_workspace_id
+        self._notion_obj_id = notion_obj_id
+        self._notion_page_type = notion_page_type
+        self._notion_access_token = notion_access_token
+
+        if not self._notion_access_token:
+            integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
             if integration_token is None:
                 raise ValueError(
                     "Must specify `integration_token` or set environment "
                     "variable `NOTION_INTEGRATION_TOKEN`."
                 )
-        self.token = integration_token
-        self.headers = {
-            "Authorization": "Bearer " + self.token,
-            "Content-Type": "application/json",
-            "Notion-Version": "2022-06-28",
-        }

-    def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
-        """Read a block."""
-        done = False
+            self._notion_access_token = integration_token
+
+    @classmethod
+    def from_document(cls, document_model: DocumentModel):
+        data_source_info = document_model.data_source_info_dict
+        if not data_source_info or 'notion_page_id' not in data_source_info \
+                or 'notion_workspace_id' not in data_source_info:
+            raise ValueError("no notion page found")
+
+        notion_workspace_id = data_source_info['notion_workspace_id']
+        notion_obj_id = data_source_info['notion_page_id']
+        notion_page_type = data_source_info['type']
+        notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id)
+
+        return cls(
+            notion_access_token=notion_access_token,
+            notion_workspace_id=notion_workspace_id,
+            notion_obj_id=notion_obj_id,
+            notion_page_type=notion_page_type,
+            document_model=document_model
+        )
+
+    def load(self) -> List[Document]:
+        self.update_last_edited_time(
+            self._document_model
+        )
+
+        text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type)
+
+        return text_docs
+
+    def _load_data_as_documents(
+            self, notion_obj_id: str, notion_page_type: str
+    ) -> List[Document]:
+        docs = []
+        if notion_page_type == 'database':
+            # get all the pages in the database
+            page_text = self._get_notion_database_data(notion_obj_id)
+            docs.append(Document(page_content=page_text))
+        elif notion_page_type == 'page':
+            page_text_list = self._get_notion_block_data(notion_obj_id)
+            for page_text in page_text_list:
+                docs.append(Document(page_content=page_text))
+        else:
+            raise ValueError("notion page type not supported")
+
+        return docs
+
+    def _get_notion_database_data(
+            self, database_id: str, query_dict: Dict[str, Any] = {}
+    ) -> str:
+        """Get all the pages from a Notion database."""
+        res = requests.post(
+            DATABASE_URL_TMPL.format(database_id=database_id),
+            headers={
+                "Authorization": "Bearer " + self._notion_access_token,
+                "Content-Type": "application/json",
+                "Notion-Version": "2022-06-28",
+            },
+            json=query_dict,
+        )
+
+        data = res.json()
+
+        database_content_list = []
+        if 'results' not in data or data["results"] is None:
+            return ""
+        for result in data["results"]:
+            properties = result['properties']
+            data = {}
+            for property_name, property_value in properties.items():
+                type = property_value['type']
+                if type == 'multi_select':
+                    value = []
+                    multi_select_list = property_value[type]
+                    for multi_select in multi_select_list:
+                        value.append(multi_select['name'])
+                elif type == 'rich_text' or type == 'title':
+                    if len(property_value[type]) > 0:
+                        value = property_value[type][0]['plain_text']
+                    else:
+                        value = ''
+                elif type == 'select' or type == 'status':
+                    if property_value[type]:
+                        value = property_value[type]['name']
+                    else:
+                        value = ''
+                else:
+                    value = property_value[type]
+                data[property_name] = value
+            database_content_list.append(json.dumps(data, ensure_ascii=False))
+
+        return "\n\n".join(database_content_list)
+
+    def _get_notion_block_data(self, page_id: str) -> List[str]:
         result_lines_arr = []
-        cur_block_id = block_id
-        while not done:
+        cur_block_id = page_id
+        while True:
             block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
             query_dict: Dict[str, Any] = {}

             res = requests.request(
-                "GET", block_url, headers=self.headers, json=query_dict
+                "GET",
+                block_url,
+                headers={
+                    "Authorization": "Bearer " + self._notion_access_token,
+                    "Content-Type": "application/json",
+                    "Notion-Version": "2022-06-28",
+                },
+                json=query_dict
+            )
+            data = res.json()
+            # current block's heading
+            heading = ''
+            for result in data["results"]:
+                result_type = result["type"]
+                result_obj = result[result_type]
+                cur_result_text_arr = []
+                if result_type == 'table':
+                    result_block_id = result["id"]
+                    text = self._read_table_rows(result_block_id)
+                    text += "\n\n"
+                    result_lines_arr.append(text)
+                else:
+                    if "rich_text" in result_obj:
+                        for rich_text in result_obj["rich_text"]:
+                            # skip if doesn't have text object
+                            if "text" in rich_text:
+                                text = rich_text["text"]["content"]
+                                cur_result_text_arr.append(text)
+                                if result_type in HEADING_TYPE:
+                                    heading = text
+
+                    result_block_id = result["id"]
+                    has_children = result["has_children"]
+                    block_type = result["type"]
+                    if has_children and block_type != 'child_page':
+                        children_text = self._read_block(
+                            result_block_id, num_tabs=1
+                        )
+                        cur_result_text_arr.append(children_text)
+
+                    cur_result_text = "\n".join(cur_result_text_arr)
+                    cur_result_text += "\n\n"
+                    if result_type in HEADING_TYPE:
+                        result_lines_arr.append(cur_result_text)
+                    else:
+                        result_lines_arr.append(f'{heading}\n{cur_result_text}')
+
+            if data["next_cursor"] is None:
+                break
+            else:
+                cur_block_id = data["next_cursor"]
+        return result_lines_arr
+
+    def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
+        """Read a block."""
+        result_lines_arr = []
+        cur_block_id = block_id
+        while True:
+            block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
+            query_dict: Dict[str, Any] = {}
+
+            res = requests.request(
+                "GET",
+                block_url,
+                headers={
+                    "Authorization": "Bearer " + self._notion_access_token,
+                    "Content-Type": "application/json",
+                    "Notion-Version": "2022-06-28",
+                },
+                json=query_dict
             )
             data = res.json()
             if 'results' not in data or data["results"] is None:
-                done = True
                 break
             heading = ''
             for result in data["results"]:
@@ -98,7 +255,6 @@ class NotionPageReader(BaseReader):
                     result_lines_arr.append(f'{heading}\n{cur_result_text}')

             if data["next_cursor"] is None:
-                done = True
                 break
             else:
                 cur_block_id = data["next_cursor"]
@@ -116,7 +272,14 @@ class NotionPageReader(BaseReader):
             query_dict: Dict[str, Any] = {}

             res = requests.request(
-                "GET", block_url, headers=self.headers, json=query_dict
+                "GET",
+                block_url,
+                headers={
+                    "Authorization": "Bearer " + self._notion_access_token,
+                    "Content-Type": "application/json",
+                    "Notion-Version": "2022-06-28",
+                },
+                json=query_dict
             )
             data = res.json()
             # get table headers text
@@ -129,9 +292,9 @@ class NotionPageReader(BaseReader):
                 table_header_cell_texts.append(text)
         # get table columns text and format
         results = data["results"]
-        for i in range(len(results)-1):
+        for i in range(len(results) - 1):
             column_texts = []
-            tabel_column_cells = data["results"][i+1]['table_row']['cells']
+            tabel_column_cells = data["results"][i + 1]['table_row']['cells']
             for j in range(len(tabel_column_cells)):
                 if tabel_column_cells[j]:
                     for table_column_cell_text in tabel_column_cells[j]:
@@ -149,221 +312,58 @@ class NotionPageReader(BaseReader):

         result_lines = "\n".join(result_lines_arr)
         return result_lines
-    def _read_parent_blocks(self, block_id: str, num_tabs: int = 0) -> List[str]:
-        """Read a block."""
-        done = False
-        result_lines_arr = []
-        cur_block_id = block_id
-        while not done:
-            block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
-            query_dict: Dict[str, Any] = {}
-
-            res = requests.request(
-                "GET", block_url, headers=self.headers, json=query_dict
-            )
-            data = res.json()
-            # current block's heading
-            heading = ''
-            for result in data["results"]:
-                result_type = result["type"]
-                result_obj = result[result_type]
-                cur_result_text_arr = []
-                if result_type == 'table':
-                    result_block_id = result["id"]
-                    text = self._read_table_rows(result_block_id)
-                    text += "\n\n"
-                    result_lines_arr.append(text)
-                else:
-                    if "rich_text" in result_obj:
-                        for rich_text in result_obj["rich_text"]:
-                            # skip if doesn't have text object
-                            if "text" in rich_text:
-                                text = rich_text["text"]["content"]
-                                cur_result_text_arr.append(text)
-                                if result_type in HEADING_TYPE:
-                                    heading = text
-
-                    result_block_id = result["id"]
-                    has_children = result["has_children"]
-                    block_type = result["type"]
-                    if has_children and block_type != 'child_page':
-                        children_text = self._read_block(
-                            result_block_id, num_tabs=num_tabs + 1
-                        )
-                        cur_result_text_arr.append(children_text)
-
-                    cur_result_text = "\n".join(cur_result_text_arr)
-                    cur_result_text += "\n\n"
-                    if result_type in HEADING_TYPE:
-                        result_lines_arr.append(cur_result_text)
-                    else:
-                        result_lines_arr.append(f'{heading}\n{cur_result_text}')
-
-            if data["next_cursor"] is None:
-                done = True
-                break
-            else:
-                cur_block_id = data["next_cursor"]
-        return result_lines_arr
-
-    def read_page(self, page_id: str) -> str:
-        """Read a page."""
-        return self._read_block(page_id)
-
-    def read_page_as_documents(self, page_id: str) -> List[str]:
-        """Read a page as documents."""
-        return self._read_parent_blocks(page_id)
-
-    def query_database_data(
-        self, database_id: str, query_dict: Dict[str, Any] = {}
-    ) -> str:
-        """Get all the pages from a Notion database."""
-        res = requests.post\
-            (
-                DATABASE_URL_TMPL.format(database_id=database_id),
-                headers=self.headers,
-                json=query_dict,
-            )
-        data = res.json()
-        database_content_list = []
-        if 'results' not in data or data["results"] is None:
-            return ""
-        for result in data["results"]:
-            properties = result['properties']
-            data = {}
-            for property_name, property_value in properties.items():
-                type = property_value['type']
-                if type == 'multi_select':
-                    value = []
-                    multi_select_list = property_value[type]
-                    for multi_select in multi_select_list:
-                        value.append(multi_select['name'])
-                elif type == 'rich_text' or type == 'title':
-                    if len(property_value[type]) > 0:
-                        value = property_value[type][0]['plain_text']
-                    else:
-                        value = ''
-                elif type == 'select' or type == 'status':
-                    if property_value[type]:
-                        value = property_value[type]['name']
-                    else:
-                        value = ''
-                else:
-                    value = property_value[type]
-                data[property_name] = value
-            database_content_list.append(json.dumps(data, ensure_ascii=False))
-
-        return "\n\n".join(database_content_list)
-
-    def query_database(
-        self, database_id: str, query_dict: Dict[str, Any] = {}
-    ) -> List[str]:
-        """Get all the pages from a Notion database."""
-        res = requests.post\
-            (
-                DATABASE_URL_TMPL.format(database_id=database_id),
-                headers=self.headers,
-                json=query_dict,
-            )
-        data = res.json()
-        page_ids = []
-        for result in data["results"]:
-            page_id = result["id"]
-            page_ids.append(page_id)
-
-        return page_ids
-
-    def search(self, query: str) -> List[str]:
-        """Search Notion page given a text query."""
-        done = False
-        next_cursor: Optional[str] = None
-        page_ids = []
-        while not done:
-            query_dict = {
-                "query": query,
-            }
-            if next_cursor is not None:
-                query_dict["start_cursor"] = next_cursor
-            res = requests.post(SEARCH_URL, headers=self.headers, json=query_dict)
-            data = res.json()
-            for result in data["results"]:
-                page_id = result["id"]
-                page_ids.append(page_id)
-
-            if data["next_cursor"] is None:
-                done = True
-                break
-            else:
-                next_cursor = data["next_cursor"]
-        return page_ids
-
-    def load_data(
-        self, page_ids: List[str] = [], database_id: Optional[str] = None
-    ) -> List[Document]:
-        """Load data from the input directory.
-
-        Args:
-            page_ids (List[str]): List of page ids to load.
-
-        Returns:
-            List[Document]: List of documents.
-
-        """
-        if not page_ids and not database_id:
-            raise ValueError("Must specify either `page_ids` or `database_id`.")
-        docs = []
-        if database_id is not None:
-            # get all the pages in the database
-            page_ids = self.query_database(database_id)
-            for page_id in page_ids:
-                page_text = self.read_page(page_id)
-                docs.append(Document(page_text))
-        else:
-            for page_id in page_ids:
-                page_text = self.read_page(page_id)
-                docs.append(Document(page_text))
-
-        return docs
-
-    def load_data_as_documents(
-        self, page_ids: List[str] = [], database_id: Optional[str] = None
-    ) -> List[Document]:
-        if not page_ids and not database_id:
-            raise ValueError("Must specify either `page_ids` or `database_id`.")
-        docs = []
-        if database_id is not None:
-            # get all the pages in the database
-            page_text = self.query_database_data(database_id)
-            docs.append(Document(page_text))
-        else:
-            for page_id in page_ids:
-                page_text_list = self.read_page_as_documents(page_id)
-                for page_text in page_text_list:
-                    docs.append(Document(page_text))
-
-        return docs
-
-    def get_page_last_edited_time(self, page_id: str) -> str:
-        retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=page_id)
-        query_dict: Dict[str, Any] = {}
-
-        res = requests.request(
-            "GET", retrieve_page_url, headers=self.headers, json=query_dict
-        )
-        data = res.json()
-        return data["last_edited_time"]
-
-    def get_database_last_edited_time(self, database_id: str) -> str:
-        retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=database_id)
-        query_dict: Dict[str, Any] = {}
-
-        res = requests.request(
-            "GET", retrieve_page_url, headers=self.headers, json=query_dict
-        )
-        data = res.json()
-        return data["last_edited_time"]
-
-
-if __name__ == "__main__":
-    reader = NotionPageReader()
-    logger.info(reader.search("What I"))
+
+    def update_last_edited_time(self, document_model: DocumentModel):
+        if not document_model:
+            return
+
+        last_edited_time = self.get_notion_last_edited_time()
+        data_source_info = document_model.data_source_info_dict
+        data_source_info['last_edited_time'] = last_edited_time
+        update_params = {
+            DocumentModel.data_source_info: json.dumps(data_source_info)
+        }
+
+        DocumentModel.query.filter_by(id=document_model.id).update(update_params)
+        db.session.commit()
+
+    def get_notion_last_edited_time(self) -> str:
+        obj_id = self._notion_obj_id
+        page_type = self._notion_page_type
+        if page_type == 'database':
+            retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id)
+        else:
+            retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)
+
+        query_dict: Dict[str, Any] = {}
+
+        res = requests.request(
+            "GET",
+            retrieve_page_url,
+            headers={
+                "Authorization": "Bearer " + self._notion_access_token,
+                "Content-Type": "application/json",
+                "Notion-Version": "2022-06-28",
+            },
+            json=query_dict
+        )
+
+        data = res.json()
+        return data["last_edited_time"]
+
+    @classmethod
+    def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
+        data_source_binding = DataSourceBinding.query.filter(
+            db.and_(
+                DataSourceBinding.tenant_id == tenant_id,
+                DataSourceBinding.provider == 'notion',
+                DataSourceBinding.disabled == False,
+                DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
+            )
+        ).first()
+
+        if not data_source_binding:
+            raise Exception(f'No notion data source binding found for tenant {tenant_id} '
+                            f'and notion workspace {notion_workspace_id}')
+
+        return data_source_binding.access_token
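As a reading aid (not part of the commit): a standalone sketch of the next_cursor pagination loop the Notion block readers above rely on; fetch_page is a stand-in for the Notion API call, and the page data is illustrative.

def read_all(fetch_page) -> list:
    results, cursor = [], None
    while True:
        data = fetch_page(cursor)               # one Notion API response page
        results.extend(data.get("results") or [])
        if data.get("next_cursor") is None:     # no more pages to fetch
            return results
        cursor = data["next_cursor"]

pages = [{"results": [1, 2], "next_cursor": "abc"}, {"results": [3], "next_cursor": None}]
print(read_all(lambda cursor: pages.pop(0)))    # [1, 2, 3]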
api/core/data_loader/loader/pdf.py (new file, 55 lines)
@@ -0,0 +1,55 @@
import logging
from typing import List, Optional

from langchain.document_loaders import PyPDFium2Loader
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

from extensions.ext_storage import storage
from models.model import UploadFile

logger = logging.getLogger(__name__)


class PdfLoader(BaseLoader):
    """Load pdf files.


    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str,
        upload_file: Optional[UploadFile] = None
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._upload_file = upload_file

    def load(self) -> List[Document]:
        plaintext_file_key = ''
        plaintext_file_exists = False
        if self._upload_file:
            if self._upload_file.hash:
                plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
                                     + self._upload_file.hash + '.0625.plaintext'
                try:
                    text = storage.load(plaintext_file_key).decode('utf-8')
                    plaintext_file_exists = True
                    return [Document(page_content=text)]
                except FileNotFoundError:
                    pass
        documents = PyPDFium2Loader(file_path=self._file_path).load()
        text_list = []
        for document in documents:
            text_list.append(document.page_content)
        text = "\n\n".join(text_list)

        # save plaintext file for caching
        if not plaintext_file_exists and plaintext_file_key:
            storage.save(plaintext_file_key, text.encode('utf-8'))

        return documents
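As a reading aid (not part of the commit): a standalone sketch of the plaintext cache-key scheme PdfLoader.load uses above. The tenant id and hash are illustrative values; the '.0625.plaintext' suffix is taken from the diff.

def plaintext_cache_key(tenant_id: str, file_hash: str) -> str:
    # cached extracted text lives next to the uploaded file in object storage
    return 'upload_files/' + tenant_id + '/' + file_hash + '.0625.plaintext'

print(plaintext_cache_key('tenant-1', 'abc123'))  # upload_files/tenant-1/abc123.0625.plaintext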
@ -1,10 +1,6 @@
|
|||||||
from typing import Any, Dict, Optional, Sequence
|
from typing import Any, Dict, Optional, Sequence
|
||||||
|
|
||||||
import tiktoken
|
from langchain.schema import Document
|
||||||
from llama_index.data_structs import Node
|
|
||||||
from llama_index.docstore.types import BaseDocumentStore
|
|
||||||
from llama_index.docstore.utils import json_to_doc
|
|
||||||
from llama_index.schema import BaseDocument
|
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
|
|
||||||
from core.llm.token_calculator import TokenCalculator
|
from core.llm.token_calculator import TokenCalculator
|
||||||
@ -12,7 +8,7 @@ from extensions.ext_database import db
|
|||||||
from models.dataset import Dataset, DocumentSegment
|
from models.dataset import Dataset, DocumentSegment
|
||||||
|
|
||||||
|
|
||||||
class DatesetDocumentStore(BaseDocumentStore):
|
class DatesetDocumentStore:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dataset: Dataset,
|
dataset: Dataset,
|
||||||
@ -48,7 +44,7 @@ class DatesetDocumentStore(BaseDocumentStore):
|
|||||||
return self._embedding_model_name
|
return self._embedding_model_name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def docs(self) -> Dict[str, BaseDocument]:
|
def docs(self) -> Dict[str, Document]:
|
||||||
document_segments = db.session.query(DocumentSegment).filter(
|
document_segments = db.session.query(DocumentSegment).filter(
|
||||||
DocumentSegment.dataset_id == self._dataset.id
|
DocumentSegment.dataset_id == self._dataset.id
|
||||||
).all()
|
).all()
|
||||||
@ -56,13 +52,20 @@ class DatesetDocumentStore(BaseDocumentStore):
|
|||||||
output = {}
|
output = {}
|
||||||
for document_segment in document_segments:
|
for document_segment in document_segments:
|
||||||
doc_id = document_segment.index_node_id
|
doc_id = document_segment.index_node_id
|
||||||
result = self.segment_to_dict(document_segment)
|
output[doc_id] = Document(
|
||||||
output[doc_id] = json_to_doc(result)
|
page_content=document_segment.content,
|
||||||
|
metadata={
|
||||||
|
"doc_id": document_segment.index_node_id,
|
||||||
|
"doc_hash": document_segment.index_node_hash,
|
||||||
|
"document_id": document_segment.document_id,
|
||||||
|
"dataset_id": document_segment.dataset_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def add_documents(
|
def add_documents(
|
||||||
self, docs: Sequence[BaseDocument], allow_update: bool = True
|
self, docs: Sequence[Document], allow_update: bool = True
|
||||||
) -> None:
|
) -> None:
|
||||||
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
|
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
|
||||||
DocumentSegment.document == self._document_id
|
DocumentSegment.document == self._document_id
|
||||||
@ -72,23 +75,20 @@ class DatesetDocumentStore(BaseDocumentStore):
|
|||||||
max_position = 0
|
max_position = 0
|
||||||
|
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
if doc.is_doc_id_none:
|
if not isinstance(doc, Document):
|
||||||
raise ValueError("doc_id not set")
|
raise ValueError("doc must be a Document")
|
||||||
|
|
||||||
if not isinstance(doc, Node):
|
segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False)
|
||||||
raise ValueError("doc must be a Node")
|
|
||||||
|
|
||||||
segment_document = self.get_document(doc_id=doc.get_doc_id(), raise_error=False)
|
|
||||||
|
|
||||||
# NOTE: doc could already exist in the store, but we overwrite it
|
# NOTE: doc could already exist in the store, but we overwrite it
|
||||||
if not allow_update and segment_document:
|
if not allow_update and segment_document:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"doc_id {doc.get_doc_id()} already exists. "
|
f"doc_id {doc.metadata['doc_id']} already exists. "
|
||||||
"Set allow_update to True to overwrite."
|
"Set allow_update to True to overwrite."
|
||||||
)
|
)
|
||||||
|
|
||||||
# calc embedding use tokens
|
# calc embedding use tokens
|
||||||
tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.get_text())
|
tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content)
|
||||||
|
|
||||||
if not segment_document:
|
if not segment_document:
|
||||||
max_position += 1
|
max_position += 1
|
||||||
@ -97,19 +97,19 @@ class DatesetDocumentStore(BaseDocumentStore):
|
|||||||
tenant_id=self._dataset.tenant_id,
|
tenant_id=self._dataset.tenant_id,
|
||||||
dataset_id=self._dataset.id,
|
dataset_id=self._dataset.id,
|
||||||
document_id=self._document_id,
|
document_id=self._document_id,
|
||||||
index_node_id=doc.get_doc_id(),
|
index_node_id=doc.metadata['doc_id'],
|
||||||
index_node_hash=doc.get_doc_hash(),
|
index_node_hash=doc.metadata['doc_hash'],
|
||||||
position=max_position,
|
position=max_position,
|
||||||
content=doc.get_text(),
|
content=doc.page_content,
|
||||||
word_count=len(doc.get_text()),
|
word_count=len(doc.page_content),
|
||||||
tokens=tokens,
|
tokens=tokens,
|
||||||
created_by=self._user_id,
|
created_by=self._user_id,
|
||||||
)
|
)
|
||||||
db.session.add(segment_document)
|
db.session.add(segment_document)
|
||||||
else:
|
else:
|
||||||
segment_document.content = doc.get_text()
|
segment_document.content = doc.page_content
|
||||||
segment_document.index_node_hash = doc.get_doc_hash()
|
segment_document.index_node_hash = doc.metadata['doc_hash']
|
||||||
segment_document.word_count = len(doc.get_text())
|
segment_document.word_count = len(doc.page_content)
|
||||||
segment_document.tokens = tokens
|
segment_document.tokens = tokens
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
@ -121,7 +121,7 @@ class DatesetDocumentStore(BaseDocumentStore):
|
|||||||
|
|
||||||
def get_document(
|
def get_document(
|
||||||
self, doc_id: str, raise_error: bool = True
|
self, doc_id: str, raise_error: bool = True
|
||||||
) -> Optional[BaseDocument]:
|
) -> Optional[Document]:
|
||||||
document_segment = self.get_document_segment(doc_id)
|
document_segment = self.get_document_segment(doc_id)
|
||||||
|
|
||||||
if document_segment is None:
|
if document_segment is None:
|
||||||
@@ -130,8 +130,15 @@ class DatesetDocumentStore(BaseDocumentStore):
             else:
                 return None

-        result = self.segment_to_dict(document_segment)
-        return json_to_doc(result)
+        return Document(
+            page_content=document_segment.content,
+            metadata={
+                "doc_id": document_segment.index_node_id,
+                "doc_hash": document_segment.index_node_hash,
+                "document_id": document_segment.document_id,
+                "dataset_id": document_segment.dataset_id,
+            }
+        )

     def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
         document_segment = self.get_document_segment(doc_id)
@@ -164,15 +171,6 @@ class DatesetDocumentStore(BaseDocumentStore):

         return document_segment.index_node_hash

-    def update_docstore(self, other: "BaseDocumentStore") -> None:
-        """Update docstore.
-
-        Args:
-            other (BaseDocumentStore): docstore to update from
-
-        """
-        self.add_documents(list(other.docs.values()))
-
     def get_document_segment(self, doc_id: str) -> DocumentSegment:
         document_segment = db.session.query(DocumentSegment).filter(
             DocumentSegment.dataset_id == self._dataset.id,
@@ -180,11 +178,3 @@ class DatesetDocumentStore(BaseDocumentStore):
         ).first()

         return document_segment
-
-    def segment_to_dict(self, segment: DocumentSegment) -> Dict[str, Any]:
-        return {
-            "doc_id": segment.index_node_id,
-            "doc_hash": segment.index_node_hash,
-            "text": segment.content,
-            "__type__": Node.get_type()
-        }
@@ -1,51 +0,0 @@  (deleted file)
from typing import Any, Dict, Optional, Sequence

from llama_index.docstore.types import BaseDocumentStore
from llama_index.schema import BaseDocument


class EmptyDocumentStore(BaseDocumentStore):
    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "EmptyDocumentStore":
        return cls()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to dict."""
        return {}

    @property
    def docs(self) -> Dict[str, BaseDocument]:
        return {}

    def add_documents(
        self, docs: Sequence[BaseDocument], allow_update: bool = True
    ) -> None:
        pass

    def document_exists(self, doc_id: str) -> bool:
        """Check if document exists."""
        return False

    def get_document(
        self, doc_id: str, raise_error: bool = True
    ) -> Optional[BaseDocument]:
        return None

    def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
        pass

    def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
        """Set the hash for a given doc_id."""
        pass

    def get_document_hash(self, doc_id: str) -> Optional[str]:
        """Get the stored hash for a document, if it exists."""
        return None

    def update_docstore(self, other: "BaseDocumentStore") -> None:
        """Update docstore.

        Args:
            other (BaseDocumentStore): docstore to update from

        """
        self.add_documents(list(other.docs.values()))
api/core/embedding/cached_embedding.py  (new file, 72 lines)
@@ -0,0 +1,72 @@
import logging
from typing import List

from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError

from extensions.ext_database import db
from libs import helper
from models.dataset import Embedding


class CacheEmbedding(Embeddings):
    def __init__(self, embeddings: Embeddings):
        self._embeddings = embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""
        # use doc embedding cache or store if not exists
        text_embeddings = []
        embedding_queue_texts = []
        for text in texts:
            hash = helper.generate_text_hash(text)
            embedding = db.session.query(Embedding).filter_by(hash=hash).first()
            if embedding:
                text_embeddings.append(embedding.get_embedding())
            else:
                embedding_queue_texts.append(text)

        embedding_results = self._embeddings.embed_documents(embedding_queue_texts)

        i = 0
        for text in embedding_queue_texts:
            hash = helper.generate_text_hash(text)

            try:
                embedding = Embedding(hash=hash)
                embedding.set_embedding(embedding_results[i])
                db.session.add(embedding)
                db.session.commit()
            except IntegrityError:
                db.session.rollback()
                continue
            except:
                logging.exception('Failed to add embedding to db')
                continue

            i += 1

        text_embeddings.extend(embedding_results)
        return text_embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""
        # use doc embedding cache or store if not exists
        hash = helper.generate_text_hash(text)
        embedding = db.session.query(Embedding).filter_by(hash=hash).first()
        if embedding:
            return embedding.get_embedding()

        embedding_results = self._embeddings.embed_query(text)

        try:
            embedding = Embedding(hash=hash)
            embedding.set_embedding(embedding_results)
            db.session.add(embedding)
            db.session.commit()
        except IntegrityError:
            db.session.rollback()
        except:
            logging.exception('Failed to add embedding to db')

        return embedding_results
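For orientation, a minimal usage sketch of the cache-through wrapper above (not part of this commit). It assumes an active Flask app context so `db.session` resolves, and uses langchain's OpenAIEmbeddings purely as an example backend; any Embeddings implementation would work, and the API key is a placeholder.

# Sketch only: wrap a langchain Embeddings backend with CacheEmbedding so repeated
# texts are served from the Embedding table (keyed by content hash) instead of the
# provider API. Assumes a Flask app context for db.session.
from langchain.embeddings import OpenAIEmbeddings

from core.embedding.cached_embedding import CacheEmbedding

embeddings = CacheEmbedding(OpenAIEmbeddings(openai_api_key='sk-...'))  # placeholder key

# First call computes and stores the vectors; identical texts later hit the cache.
vectors = embeddings.embed_documents(["hello world", "hello world"])
query_vector = embeddings.embed_query("hello world")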
@@ -1,214 +0,0 @@  (deleted file)
from typing import Optional, Any, List

import openai
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \
    _TEXT_MODE_MODEL_DICT
from tenacity import wait_random_exponential, retry, stop_after_attempt

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(
        text: str,
        engine: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs
) -> List[float]:
    """Get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    """
    text = text.replace("\n", " ")
    return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[
    float]:
    """Asynchronously get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    """
    # replace newlines, which can negatively affect performance.
    text = text.replace("\n", " ")

    return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][
        "embedding"
    ]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
        list_of_text: List[str],
        engine: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs
) -> List[List[float]]:
    """Get embeddings.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings(
        list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs
) -> List[List[float]]:
    """Asynchronously get embeddings.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]


class OpenAIEmbedding(BaseEmbedding):

    def __init__(
            self,
            mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
            model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
            deployment_name: Optional[str] = None,
            openai_api_key: Optional[str] = None,
            **kwargs: Any,
    ) -> None:
        """Init params."""
        new_kwargs = {}

        if 'embed_batch_size' in kwargs:
            new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']

        if 'tokenizer' in kwargs:
            new_kwargs['tokenizer'] = kwargs['tokenizer']

        super().__init__(**new_kwargs)
        self.mode = OpenAIEmbeddingMode(mode)
        self.model = OpenAIEmbeddingModelType(model)
        self.deployment_name = deployment_name
        self.openai_api_key = openai_api_key
        self.openai_api_type = kwargs.get('openai_api_type')
        self.openai_api_version = kwargs.get('openai_api_version')
        self.openai_api_base = kwargs.get('openai_api_base')

    @handle_llm_exceptions
    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _QUERY_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _QUERY_MODE_MODEL_DICT[key]
        return get_embedding(query, engine=engine, api_key=self.openai_api_key,
                             api_type=self.openai_api_type, api_version=self.openai_api_version,
                             api_base=self.openai_api_base)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        return get_embedding(text, engine=engine, api_key=self.openai_api_key,
                             api_type=self.openai_api_type, api_version=self.openai_api_version,
                             api_base=self.openai_api_base)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
                                    api_base=self.openai_api_base)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        By default, this is a wrapper around _get_text_embedding.
        Can be overriden for batch queries.

        """
        if self.openai_api_type and self.openai_api_type == 'azure':
            embeddings = []
            for text in texts:
                embeddings.append(self._get_text_embedding(text))

            return embeddings

        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key,
                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
                                    api_base=self.openai_api_base)
        return embeddings

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings."""
        if self.openai_api_type and self.openai_api_type == 'azure':
            embeddings = []
            for text in texts:
                embeddings.append(await self._aget_text_embedding(text))

            return embeddings

        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key,
                                           api_type=self.openai_api_type, api_version=self.openai_api_version,
                                           api_base=self.openai_api_base)
        return embeddings
api/core/index/base.py  (new file, 59 lines)
@@ -0,0 +1,59 @@
from __future__ import annotations
from abc import abstractmethod, ABC
from typing import List, Any

from langchain.schema import Document, BaseRetriever

from models.dataset import Dataset


class BaseIndex(ABC):

    def __init__(self, dataset: Dataset):
        self.dataset = dataset

    @abstractmethod
    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        raise NotImplementedError

    @abstractmethod
    def add_texts(self, texts: list[Document], **kwargs):
        raise NotImplementedError

    @abstractmethod
    def text_exists(self, id: str) -> bool:
        raise NotImplementedError

    @abstractmethod
    def delete_by_ids(self, ids: list[str]) -> None:
        raise NotImplementedError

    @abstractmethod
    def delete_by_document_id(self, document_id: str):
        raise NotImplementedError

    @abstractmethod
    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        raise NotImplementedError

    @abstractmethod
    def search(
            self, query: str,
            **kwargs: Any
    ) -> List[Document]:
        raise NotImplementedError

    def delete(self) -> None:
        raise NotImplementedError

    def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
        for text in texts:
            doc_id = text.metadata['doc_id']
            exists_duplicate_node = self.text_exists(doc_id)
            if exists_duplicate_node:
                texts.remove(text)

        return texts

    def _get_uuids(self, texts: list[Document]) -> list[str]:
        return [text.metadata['doc_id'] for text in texts]
api/core/index/index.py  (new file, 41 lines)
@@ -0,0 +1,41 @@
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings

from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset


class IndexBuilder:
    @classmethod
    def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False):
        if indexing_technique == "high_quality":
            if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
                return None

            model_credentials = LLMBuilder.get_model_credentials(
                tenant_id=dataset.tenant_id,
                model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
                model_name='text-embedding-ada-002'
            )

            embeddings = CacheEmbedding(OpenAIEmbeddings(
                **model_credentials
            ))

            return VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings
            )
        elif indexing_technique == "economy":
            return KeywordTableIndex(
                dataset=dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=10
                )
            )
        else:
            raise ValueError('Unknown indexing technique')
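As a reading aid, a hedged sketch of how the new factory above is used by callers in this commit (for example the recreate-all-dataset-indexes command); the dataset lookup and query string are illustrative only, and a Flask app context is assumed.

# Sketch only: resolve an index implementation for a dataset. 'high_quality'
# datasets get a VectorIndex, 'economy' datasets get a KeywordTableIndex, and
# None is returned when the high-quality check fails.
from core.index.index import IndexBuilder
from extensions.ext_database import db
from models.dataset import Dataset

dataset = db.session.query(Dataset).first()  # illustrative lookup

index = IndexBuilder.get_index(dataset, dataset.indexing_technique)
if index:
    docs = index.search("example query")  # returns langchain Documents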
@@ -1,60 +0,0 @@  (deleted file)
from langchain.callbacks import CallbackManager
from llama_index import ServiceContext, PromptHelper, LLMPredictor
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.embedding.openai_embedding import OpenAIEmbedding
from core.llm.llm_builder import LLMBuilder


class IndexBuilder:
    @classmethod
    def get_default_service_context(cls, tenant_id: str) -> ServiceContext:
        # set number of output tokens
        num_output = 512

        # only for verbose
        callback_manager = CallbackManager([DifyStdOutCallbackHandler()])

        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='text-davinci-003',
            temperature=0,
            max_tokens=num_output,
            callback_manager=callback_manager,
        )

        llm_predictor = LLMPredictor(llm=llm)

        # These parameters here will affect the logic of segmenting the final synthesized response.
        # The number of refinement iterations in the synthesis process depends
        # on whether the length of the segmented output exceeds the max_input_size.
        prompt_helper = PromptHelper(
            max_input_size=3500,
            num_output=num_output,
            max_chunk_overlap=20
        )

        provider = LLMBuilder.get_default_provider(tenant_id)

        model_credentials = LLMBuilder.get_model_credentials(
            tenant_id=tenant_id,
            model_provider=provider,
            model_name='text-embedding-ada-002'
        )

        return ServiceContext.from_defaults(
            llm_predictor=llm_predictor,
            prompt_helper=prompt_helper,
            embed_model=OpenAIEmbedding(**model_credentials),
        )

    @classmethod
    def get_fake_llm_service_context(cls, tenant_id: str) -> ServiceContext:
        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='fake'
        )

        return ServiceContext.from_defaults(
            llm_predictor=LLMPredictor(llm=llm),
            embed_model=OpenAIEmbedding()
        )
@@ -1,159 +0,0 @@  (deleted file)
import re
from typing import (
    Any,
    Dict,
    List,
    Set,
    Optional
)

import jieba.analyse

from core.index.keyword_table.stopwords import STOPWORDS
from llama_index.indices.query.base import IS
from llama_index import QueryMode
from llama_index.indices.base import QueryMap
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.keyword_table.query import BaseGPTKeywordTableQuery
from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
    BaseNodePostprocessor,
)
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
    QuestionAnswerPrompt,
    RefinePrompt,
    SimpleInputPrompt,
)

from core.index.query.synthesizer import EnhanceResponseSynthesizer


def jieba_extract_keywords(
        text_chunk: str,
        max_keywords: Optional[int] = None,
        expand_with_subtokens: bool = True,
) -> Set[str]:
    """Extract keywords with JIEBA tfidf."""
    keywords = jieba.analyse.extract_tags(
        sentence=text_chunk,
        topK=max_keywords,
    )

    if expand_with_subtokens:
        return set(expand_tokens_with_subtokens(keywords))
    else:
        return set(keywords)


def expand_tokens_with_subtokens(tokens: Set[str]) -> Set[str]:
    """Get subtokens from a list of tokens., filtering for stopwords."""
    results = set()
    for token in tokens:
        results.add(token)
        sub_tokens = re.findall(r"\w+", token)
        if len(sub_tokens) > 1:
            results.update({w for w in sub_tokens if w not in list(STOPWORDS)})

    return results


class GPTJIEBAKeywordTableIndex(BaseGPTKeywordTableIndex):
    """GPT JIEBA Keyword Table Index.

    This index uses a JIEBA keyword extractor to extract keywords from the text.

    """

    def _extract_keywords(self, text: str) -> Set[str]:
        """Extract keywords from text."""
        return jieba_extract_keywords(text, max_keywords=self.max_keywords_per_chunk)

    @classmethod
    def get_query_map(self) -> QueryMap:
        """Get query map."""
        super_map = super().get_query_map()
        super_map[QueryMode.DEFAULT] = GPTKeywordTableJIEBAQuery
        return super_map

    def _delete(self, doc_id: str, **delete_kwargs: Any) -> None:
        """Delete a document."""
        # get set of ids that correspond to node
        node_idxs_to_delete = {doc_id}

        # delete node_idxs from keyword to node idxs mapping
        keywords_to_delete = set()
        for keyword, node_idxs in self._index_struct.table.items():
            if node_idxs_to_delete.intersection(node_idxs):
                self._index_struct.table[keyword] = node_idxs.difference(
                    node_idxs_to_delete
                )
                if not self._index_struct.table[keyword]:
                    keywords_to_delete.add(keyword)

        for keyword in keywords_to_delete:
            del self._index_struct.table[keyword]


class GPTKeywordTableJIEBAQuery(BaseGPTKeywordTableQuery):
    """GPT Keyword Table Index JIEBA Query.

    Extracts keywords using JIEBA keyword extractor.
    Set when `mode="jieba"` in `query` method of `GPTKeywordTableIndex`.

    .. code-block:: python

        response = index.query("<query_str>", mode="jieba")

    See BaseGPTKeywordTableQuery for arguments.

    """

    @classmethod
    def from_args(
        cls,
        index_struct: IS,
        service_context: ServiceContext,
        docstore: Optional[BaseDocumentStore] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        verbose: bool = False,
        # response synthesizer args
        response_mode: ResponseMode = ResponseMode.DEFAULT,
        text_qa_template: Optional[QuestionAnswerPrompt] = None,
        refine_template: Optional[RefinePrompt] = None,
        simple_template: Optional[SimpleInputPrompt] = None,
        response_kwargs: Optional[Dict] = None,
        use_async: bool = False,
        streaming: bool = False,
        optimizer: Optional[BaseTokenUsageOptimizer] = None,
        # class-specific args
        **kwargs: Any,
    ) -> "BaseGPTIndexQuery":
        response_synthesizer = EnhanceResponseSynthesizer.from_args(
            service_context=service_context,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
            simple_template=simple_template,
            response_mode=response_mode,
            response_kwargs=response_kwargs,
            use_async=use_async,
            streaming=streaming,
            optimizer=optimizer,
        )
        return cls(
            index_struct=index_struct,
            service_context=service_context,
            response_synthesizer=response_synthesizer,
            docstore=docstore,
            node_postprocessors=node_postprocessors,
            verbose=verbose,
            **kwargs,
        )

    def _get_keywords(self, query_str: str) -> List[str]:
        """Extract keywords."""
        return list(
            jieba_extract_keywords(query_str, max_keywords=self.max_keywords_per_query)
        )
@@ -1,135 +0,0 @@  (deleted file)
import json
from typing import List, Optional

from llama_index import ServiceContext, LLMPredictor, OpenAIEmbedding
from llama_index.data_structs import KeywordTable, Node
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.registry import load_index_struct_from_dict

from core.docstore.dataset_docstore import DatesetDocumentStore
from core.docstore.empty_docstore import EmptyDocumentStore
from core.index.index_builder import IndexBuilder
from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment


class KeywordTableIndex:

    def __init__(self, dataset: Dataset):
        self._dataset = dataset

    def add_nodes(self, nodes: List[Node]):
        llm = LLMBuilder.to_llm(
            tenant_id=self._dataset.tenant_id,
            model_name='fake'
        )

        service_context = ServiceContext.from_defaults(
            llm_predictor=LLMPredictor(llm=llm),
            embed_model=OpenAIEmbedding()
        )

        dataset_keyword_table = self.get_keyword_table()
        if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
            index_struct = KeywordTable()
        else:
            index_struct_dict = dataset_keyword_table.keyword_table_dict
            index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)

        # create index
        index = GPTJIEBAKeywordTableIndex(
            index_struct=index_struct,
            docstore=EmptyDocumentStore(),
            service_context=service_context
        )

        for node in nodes:
            keywords = index._extract_keywords(node.get_text())
            self.update_segment_keywords(node.doc_id, list(keywords))
            index._index_struct.add_node(list(keywords), node)

        index_struct_dict = index.index_struct.to_dict()

        if not dataset_keyword_table:
            dataset_keyword_table = DatasetKeywordTable(
                dataset_id=self._dataset.id,
                keyword_table=json.dumps(index_struct_dict)
            )
            db.session.add(dataset_keyword_table)
        else:
            dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)

        db.session.commit()

    def del_nodes(self, node_ids: List[str]):
        llm = LLMBuilder.to_llm(
            tenant_id=self._dataset.tenant_id,
            model_name='fake'
        )

        service_context = ServiceContext.from_defaults(
            llm_predictor=LLMPredictor(llm=llm),
            embed_model=OpenAIEmbedding()
        )

        dataset_keyword_table = self.get_keyword_table()
        if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
            return
        else:
            index_struct_dict = dataset_keyword_table.keyword_table_dict
            index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)

        # create index
        index = GPTJIEBAKeywordTableIndex(
            index_struct=index_struct,
            docstore=EmptyDocumentStore(),
            service_context=service_context
        )

        for node_id in node_ids:
            index.delete(node_id)

        index_struct_dict = index.index_struct.to_dict()

        if not dataset_keyword_table:
            dataset_keyword_table = DatasetKeywordTable(
                dataset_id=self._dataset.id,
                keyword_table=json.dumps(index_struct_dict)
            )
            db.session.add(dataset_keyword_table)
        else:
            dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)

        db.session.commit()

    @property
    def query_index(self) -> Optional[BaseGPTKeywordTableIndex]:
        docstore = DatesetDocumentStore(
            dataset=self._dataset,
            user_id=self._dataset.created_by,
            embedding_model_name="text-embedding-ada-002"
        )

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        dataset_keyword_table = self.get_keyword_table()
        if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
            return None

        index_struct: KeywordTable = load_index_struct_from_dict(dataset_keyword_table.keyword_table_dict)

        return GPTJIEBAKeywordTableIndex(index_struct=index_struct, docstore=docstore, service_context=service_context)

    def get_keyword_table(self):
        dataset_keyword_table = self._dataset.dataset_keyword_table
        if dataset_keyword_table:
            return dataset_keyword_table
        return None

    def update_segment_keywords(self, node_id: str, keywords: List[str]):
        document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
        if document_segment:
            document_segment.keywords = keywords
            db.session.commit()
@@ -0,0 +1,33 @@  (new file)
import re
from typing import Set

import jieba
from jieba.analyse import default_tfidf

from core.index.keyword_table_index.stopwords import STOPWORDS


class JiebaKeywordTableHandler:

    def __init__(self):
        default_tfidf.stop_words = STOPWORDS

    def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]:
        """Extract keywords with JIEBA tfidf."""
        keywords = jieba.analyse.extract_tags(
            sentence=text,
            topK=max_keywords_per_chunk,
        )

        return set(self._expand_tokens_with_subtokens(keywords))

    def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]:
        """Get subtokens from a list of tokens., filtering for stopwords."""
        results = set()
        for token in tokens:
            results.add(token)
            sub_tokens = re.findall(r"\w+", token)
            if len(sub_tokens) > 1:
                results.update({w for w in sub_tokens if w not in list(STOPWORDS)})

        return results
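A quick sketch of the handler above used in isolation (not part of the commit); the sample sentence is arbitrary.

# Sketch only: extract up to 10 keywords (plus sub-tokens) from a text chunk.
from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler

handler = JiebaKeywordTableHandler()
keywords = handler.extract_keywords("Dify is an LLM application development platform", max_keywords_per_chunk=10)
print(sorted(keywords))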
api/core/index/keyword_table_index/keyword_table_index.py  (new file, 238 lines)
@@ -0,0 +1,238 @@
import json
from collections import defaultdict
from typing import Any, List, Optional, Dict

from langchain.schema import Document, BaseRetriever
from pydantic import BaseModel, Field, Extra

from core.index.base import BaseIndex
from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, DatasetKeywordTable


class KeywordTableConfig(BaseModel):
    max_keywords_per_chunk: int = 10


class KeywordTableIndex(BaseIndex):
    def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()):
        super().__init__(dataset)
        self._config = config

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        keyword_table_handler = JiebaKeywordTableHandler()
        keyword_table = {}
        for text in texts:
            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
            self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

        dataset_keyword_table = DatasetKeywordTable(
            dataset_id=self.dataset.id,
            keyword_table=json.dumps({
                '__type__': 'keyword_table',
                '__data__': {
                    "index_id": self.dataset.id,
                    "summary": None,
                    "table": {}
                }
            }, cls=SetEncoder)
        )
        db.session.add(dataset_keyword_table)
        db.session.commit()

        self._save_dataset_keyword_table(keyword_table)

        return self

    def add_texts(self, texts: list[Document], **kwargs):
        keyword_table_handler = JiebaKeywordTableHandler()

        keyword_table = self._get_dataset_keyword_table()
        for text in texts:
            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
            self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

        self._save_dataset_keyword_table(keyword_table)

    def text_exists(self, id: str) -> bool:
        keyword_table = self._get_dataset_keyword_table()
        return id in set.union(*keyword_table.values())

    def delete_by_ids(self, ids: list[str]) -> None:
        keyword_table = self._get_dataset_keyword_table()
        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

        self._save_dataset_keyword_table(keyword_table)

    def delete_by_document_id(self, document_id: str):
        # get segment ids by document_id
        segments = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == self.dataset.id,
            DocumentSegment.document_id == document_id
        ).all()

        ids = [segment.id for segment in segments]

        keyword_table = self._get_dataset_keyword_table()
        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

        self._save_dataset_keyword_table(keyword_table)

    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        return KeywordTableRetriever(index=self, **kwargs)

    def search(
            self, query: str,
            **kwargs: Any
    ) -> List[Document]:
        keyword_table = self._get_dataset_keyword_table()

        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
        k = search_kwargs.get('k') if search_kwargs.get('k') else 4

        sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)

        documents = []
        for chunk_index in sorted_chunk_indices:
            segment = db.session.query(DocumentSegment).filter(
                DocumentSegment.dataset_id == self.dataset.id,
                DocumentSegment.index_node_id == chunk_index
            ).first()

            if segment:
                documents.append(Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": chunk_index,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                ))

        return documents

    def delete(self) -> None:
        dataset_keyword_table = self.dataset.dataset_keyword_table
        if dataset_keyword_table:
            db.session.delete(dataset_keyword_table)
            db.session.commit()

    def _save_dataset_keyword_table(self, keyword_table):
        keyword_table_dict = {
            '__type__': 'keyword_table',
            '__data__': {
                "index_id": self.dataset.id,
                "summary": None,
                "table": keyword_table
            }
        }
        self.dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
        db.session.commit()

    def _get_dataset_keyword_table(self) -> Optional[dict]:
        dataset_keyword_table = self.dataset.dataset_keyword_table
        if dataset_keyword_table:
            if dataset_keyword_table.keyword_table_dict:
                return dataset_keyword_table.keyword_table_dict['__data__']['table']
        else:
            dataset_keyword_table = DatasetKeywordTable(
                dataset_id=self.dataset.id,
                keyword_table=json.dumps({
                    '__type__': 'keyword_table',
                    '__data__': {
                        "index_id": self.dataset.id,
                        "summary": None,
                        "table": {}
                    }
                }, cls=SetEncoder)
            )
            db.session.add(dataset_keyword_table)
            db.session.commit()

        return {}

    def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
        for keyword in keywords:
            if keyword not in keyword_table:
                keyword_table[keyword] = set()
            keyword_table[keyword].add(id)
        return keyword_table

    def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> dict:
        # get set of ids that correspond to node
        node_idxs_to_delete = set(ids)

        # delete node_idxs from keyword to node idxs mapping
        keywords_to_delete = set()
        for keyword, node_idxs in keyword_table.items():
            if node_idxs_to_delete.intersection(node_idxs):
                keyword_table[keyword] = node_idxs.difference(
                    node_idxs_to_delete
                )
                if not keyword_table[keyword]:
                    keywords_to_delete.add(keyword)

        for keyword in keywords_to_delete:
            del keyword_table[keyword]

        return keyword_table

    def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4):
        keyword_table_handler = JiebaKeywordTableHandler()
        keywords = keyword_table_handler.extract_keywords(query)

        # go through text chunks in order of most matching keywords
        chunk_indices_count: Dict[str, int] = defaultdict(int)
        keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
        for keyword in keywords:
            for node_id in keyword_table[keyword]:
                chunk_indices_count[node_id] += 1

        sorted_chunk_indices = sorted(
            list(chunk_indices_count.keys()),
            key=lambda x: chunk_indices_count[x],
            reverse=True,
        )

        return sorted_chunk_indices[: k]

    def _update_segment_keywords(self, node_id: str, keywords: List[str]):
        document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
        if document_segment:
            document_segment.keywords = keywords
            db.session.commit()


class KeywordTableRetriever(BaseRetriever, BaseModel):
    index: KeywordTableIndex
    search_kwargs: dict = Field(default_factory=dict)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """
        return self.index.search(query, **self.search_kwargs)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        raise NotImplementedError("KeywordTableRetriever does not support async")


class SetEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, set):
            return list(obj)
        return super().default(obj)
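For orientation, a hedged sketch of the economy-mode index above used as a langchain-compatible retriever (not part of the commit). It assumes a Flask app context for db.session, a dataset whose segments have already been indexed, and an illustrative dataset lookup and query.

# Sketch only: build the keyword-table index for a dataset and retrieve segments.
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from extensions.ext_database import db
from models.dataset import Dataset

dataset = db.session.query(Dataset).first()  # illustrative lookup

index = KeywordTableIndex(dataset, KeywordTableConfig(max_keywords_per_chunk=10))
retriever = index.get_retriever(search_kwargs={'k': 4})
relevant_docs = retriever.get_relevant_documents("how do I configure a dataset?")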
@@ -1,79 +0,0 @@  (deleted file)
from typing import (
    Any,
    Dict,
    Optional, Sequence,
)

from llama_index.indices.response.response_synthesis import ResponseSynthesizer
from llama_index.indices.response.response_builder import ResponseMode, BaseResponseBuilder, get_response_builder
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
    QuestionAnswerPrompt,
    RefinePrompt,
    SimpleInputPrompt,
)
from llama_index.types import RESPONSE_TEXT_TYPE


class EnhanceResponseSynthesizer(ResponseSynthesizer):
    @classmethod
    def from_args(
        cls,
        service_context: ServiceContext,
        streaming: bool = False,
        use_async: bool = False,
        text_qa_template: Optional[QuestionAnswerPrompt] = None,
        refine_template: Optional[RefinePrompt] = None,
        simple_template: Optional[SimpleInputPrompt] = None,
        response_mode: ResponseMode = ResponseMode.DEFAULT,
        response_kwargs: Optional[Dict] = None,
        optimizer: Optional[BaseTokenUsageOptimizer] = None,
    ) -> "ResponseSynthesizer":
        response_builder: Optional[BaseResponseBuilder] = None
        if response_mode != ResponseMode.NO_TEXT:
            if response_mode == 'no_synthesizer':
                response_builder = NoSynthesizer(
                    service_context=service_context,
                    simple_template=simple_template,
                    streaming=streaming,
                )
            else:
                response_builder = get_response_builder(
                    service_context,
                    text_qa_template,
                    refine_template,
                    simple_template,
                    response_mode,
                    use_async=use_async,
                    streaming=streaming,
                )
        return cls(response_builder, response_mode, response_kwargs, optimizer)


class NoSynthesizer(BaseResponseBuilder):
    def __init__(
        self,
        service_context: ServiceContext,
        simple_template: Optional[SimpleInputPrompt] = None,
        streaming: bool = False,
    ) -> None:
        super().__init__(service_context, streaming)

    async def aget_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        prev_response: Optional[str] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        return "\n".join(text_chunks)

    def get_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        prev_response: Optional[str] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        return "\n".join(text_chunks)
@@ -1,22 +0,0 @@  (deleted file)
from pathlib import Path
from typing import Dict

from bs4 import BeautifulSoup
from llama_index.readers.file.base_parser import BaseParser


class HTMLParser(BaseParser):
    """HTML parser."""

    def _init_parser(self) -> Dict:
        """Init parser."""
        return {}

    def parse_file(self, file: Path, errors: str = "ignore") -> str:
        """Parse file."""
        with open(file, "rb") as fp:
            soup = BeautifulSoup(fp, 'html.parser')
            text = soup.get_text()
            text = text.strip() if text else ''

        return text
@@ -1,111 +0,0 @@  (deleted file)
"""Markdown parser.

Contains parser for md files.

"""
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast

from llama_index.readers.file.base_parser import BaseParser


class MarkdownParser(BaseParser):
    """Markdown parser.

    Extract text from markdown files.
    Returns dictionary with keys as headers and values as the text between headers.

    """

    def __init__(
        self,
        *args: Any,
        remove_hyperlinks: bool = True,
        remove_images: bool = True,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        super().__init__(*args, **kwargs)
        self._remove_hyperlinks = remove_hyperlinks
        self._remove_images = remove_images

    def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
        """Convert a markdown file to a dictionary.

        The keys are the headers and the values are the text under each header.

        """
        markdown_tups: List[Tuple[Optional[str], str]] = []
        lines = markdown_text.split("\n")

        current_header = None
        current_text = ""

        for line in lines:
            header_match = re.match(r"^#+\s", line)
            if header_match:
                if current_header is not None:
                    markdown_tups.append((current_header, current_text))

                current_header = line
                current_text = ""
            else:
                current_text += line + "\n"
        markdown_tups.append((current_header, current_text))

        if current_header is not None:
            # pass linting, assert keys are defined
            markdown_tups = [
                (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value))
                for key, value in markdown_tups
            ]
        else:
            markdown_tups = [
                (key, re.sub("\n", "", value)) for key, value in markdown_tups
            ]

        return markdown_tups

    def remove_images(self, content: str) -> str:
        """Get a dictionary of a markdown file from its path."""
        pattern = r"!{1}\[\[(.*)\]\]"
        content = re.sub(pattern, "", content)
        return content

    def remove_hyperlinks(self, content: str) -> str:
        """Get a dictionary of a markdown file from its path."""
        pattern = r"\[(.*?)\]\((.*?)\)"
        content = re.sub(pattern, r"\1", content)
        return content

    def _init_parser(self) -> Dict:
        """Initialize the parser with the config."""
        return {}

    def parse_tups(
        self, filepath: Path, errors: str = "ignore"
    ) -> List[Tuple[Optional[str], str]]:
        """Parse file into tuples."""
        with open(filepath, "r", encoding="utf-8") as f:
            content = f.read()
        if self._remove_hyperlinks:
            content = self.remove_hyperlinks(content)
        if self._remove_images:
            content = self.remove_images(content)
        markdown_tups = self.markdown_to_tups(content)
        return markdown_tups

    def parse_file(
        self, filepath: Path, errors: str = "ignore"
    ) -> Union[str, List[str]]:
        """Parse file into string."""
        tups = self.parse_tups(filepath, errors=errors)
        results = []
        # TODO: don't include headers right now
        for header, value in tups:
            if header is None:
                results.append(value)
            else:
                results.append(f"\n\n{header}\n{value}")
        return results
@@ -1,56 +0,0 @@  (deleted file)
from pathlib import Path
from typing import Dict

from flask import current_app
from llama_index.readers.file.base_parser import BaseParser
from pypdf import PdfReader

from extensions.ext_storage import storage
from models.model import UploadFile


class PDFParser(BaseParser):
    """PDF parser."""

    def _init_parser(self) -> Dict:
        """Init parser."""
        return {}

    def parse_file(self, file: Path, errors: str = "ignore") -> str:
        """Parse file."""
        if not current_app.config.get('PDF_PREVIEW', True):
            return ''

        plaintext_file_key = ''
        plaintext_file_exists = False
        if self._parser_config and 'upload_file' in self._parser_config and self._parser_config['upload_file']:
            upload_file: UploadFile = self._parser_config['upload_file']
            if upload_file.hash:
                plaintext_file_key = 'upload_files/' + upload_file.tenant_id + '/' + upload_file.hash + '.plaintext'
                try:
                    text = storage.load(plaintext_file_key).decode('utf-8')
                    plaintext_file_exists = True
                    return text
                except FileNotFoundError:
                    pass

        text_list = []
        with open(file, "rb") as fp:
            # Create a PDF object
            pdf = PdfReader(fp)

            # Get the number of pages in the PDF document
            num_pages = len(pdf.pages)

            # Iterate over every page
            for page in range(num_pages):
                # Extract the text from the page
                page_text = pdf.pages[page].extract_text()
                text_list.append(page_text)
        text = "\n".join(text_list)

        # save plaintext file for caching
        if not plaintext_file_exists and plaintext_file_key:
            storage.save(plaintext_file_key, text.encode('utf-8'))

        return text
@@ -1,33 +0,0 @@  (deleted file)
from pathlib import Path
import json
from typing import Dict
from openpyxl import load_workbook

from llama_index.readers.file.base_parser import BaseParser
from flask import current_app


class XLSXParser(BaseParser):
    """XLSX parser."""

    def _init_parser(self) -> Dict:
        """Init parser"""
        return {}

    def parse_file(self, file: Path, errors: str = "ignore") -> str:
        data = []
        keys = []
        with open(file, "r") as fp:
            wb = load_workbook(filename=file, read_only=True)
            # loop over all sheets
            for sheet in wb:
                for row in sheet.iter_rows(values_only=True):
                    if all(v is None for v in row):
                        continue
                    if keys == []:
                        keys = list(map(str, row))
                    else:
                        row_dict = dict(zip(keys, row))
                        row_dict = {k: v for k, v in row_dict.items() if v}
                        data.append(json.dumps(row_dict, ensure_ascii=False))
        return '\n\n'.join(data)
@@ -1,136 +0,0 @@
import json
import logging
from typing import List, Optional

from llama_index.data_structs import Node
from requests import ReadTimeout
from sqlalchemy.exc import IntegrityError
from tenacity import retry, stop_after_attempt, retry_if_exception_type

from core.index.index_builder import IndexBuilder
from core.vector_store.base import BaseGPTVectorStoreIndex
from extensions.ext_vector_store import vector_store
from extensions.ext_database import db
from models.dataset import Dataset, Embedding


class VectorIndex:

    def __init__(self, dataset: Dataset):
        self._dataset = dataset

    def add_nodes(self, nodes: List[Node], duplicate_check: bool = False):
        if not self._dataset.index_struct_dict:
            index_id = "Vector_index_" + self._dataset.id.replace("-", "_")
            self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id))
            db.session.commit()

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        if duplicate_check:
            nodes = self._filter_duplicate_nodes(index, nodes)

        embedding_queue_nodes = []
        embedded_nodes = []
        for node in nodes:
            node_hash = node.doc_hash

            # if node hash in cached embedding tables, use cached embedding
            embedding = db.session.query(Embedding).filter_by(hash=node_hash).first()
            if embedding:
                node.embedding = embedding.get_embedding()
                embedded_nodes.append(node)
            else:
                embedding_queue_nodes.append(node)

        if embedding_queue_nodes:
            embedding_results = index._get_node_embedding_results(
                embedding_queue_nodes,
                set(),
            )

            # pre embed nodes for cached embedding
            for embedding_result in embedding_results:
                node = embedding_result.node
                node.embedding = embedding_result.embedding

                try:
                    embedding = Embedding(hash=node.doc_hash)
                    embedding.set_embedding(node.embedding)
                    db.session.add(embedding)
                    db.session.commit()
                except IntegrityError:
                    db.session.rollback()
                    continue
                except:
                    logging.exception('Failed to add embedding to db')
                    continue

                embedded_nodes.append(node)

        self.index_insert_nodes(index, embedded_nodes)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]):
        index.insert_nodes(nodes)

    def del_nodes(self, node_ids: List[str]):
        if not self._dataset.index_struct_dict:
            return

        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        for node_id in node_ids:
            self.index_delete_node(index, node_id)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str):
        index.delete_node(node_id)

    def del_doc(self, doc_id: str):
        if not self._dataset.index_struct_dict:
            return

        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        self.index_delete_doc(index, doc_id)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str):
        index.delete(doc_id)

    @property
    def query_index(self) -> Optional[BaseGPTVectorStoreIndex]:
        if not self._dataset.index_struct_dict:
            return None

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        return vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

    def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]:
        for node in nodes:
            node_id = node.doc_id
            exists_duplicate_node = index.exists_by_node_id(node_id)
            if exists_duplicate_node:
                nodes.remove(node)

        return nodes
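The deleted class caches embeddings per content hash so identical chunks are embedded only once. A schematic version of that lookup-then-backfill flow (illustrative only; an in-memory dict stands in for the Embedding table, and any embedding function can be plugged in):

    from typing import Callable, Dict, List, Tuple

    # hash -> embedding vector; stands in for the Embedding table
    _embedding_cache: Dict[str, List[float]] = {}

    def embed_with_cache(chunks: List[Tuple[str, str]],
                         embed_fn: Callable[[List[str]], List[List[float]]]) -> Dict[str, List[float]]:
        """chunks is a list of (content_hash, text); returns hash -> vector."""
        results: Dict[str, List[float]] = {}
        misses: List[Tuple[str, str]] = []

        for content_hash, text in chunks:
            cached = _embedding_cache.get(content_hash)
            if cached is not None:
                results[content_hash] = cached   # reuse the cached vector
            else:
                misses.append((content_hash, text))

        if misses:
            vectors = embed_fn([text for _, text in misses])  # embed only the cache misses
            for (content_hash, _), vector in zip(misses, vectors):
                _embedding_cache[content_hash] = vector       # backfill the cache
                results[content_hash] = vector

        return results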
api/core/index/vector_index/base.py (new file)
@@ -0,0 +1,175 @@
import json
import logging
from abc import abstractmethod
from typing import List, Any, cast

from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from weaviate import UnexpectedStatusCodeException

from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument


class BaseVectorIndex(BaseIndex):

    def __init__(self, dataset: Dataset, embeddings: Embeddings):
        super().__init__(dataset)
        self._embeddings = embeddings
        self._vector_store = None

    def get_type(self) -> str:
        raise NotImplementedError

    @abstractmethod
    def get_index_name(self, dataset: Dataset) -> str:
        raise NotImplementedError

    @abstractmethod
    def to_index_struct(self) -> dict:
        raise NotImplementedError

    @abstractmethod
    def _get_vector_store(self) -> VectorStore:
        raise NotImplementedError

    @abstractmethod
    def _get_vector_store_class(self) -> type:
        raise NotImplementedError

    def search(
            self, query: str,
            **kwargs: Any
    ) -> List[Document]:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}

        if search_type == 'similarity_score_threshold':
            score_threshold = search_kwargs.get("score_threshold")
            if (score_threshold is None) or (not isinstance(score_threshold, float)):
                search_kwargs['score_threshold'] = .0

            docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
                query, **search_kwargs
            )

            docs = []
            for doc, similarity in docs_with_similarity:
                doc.metadata['score'] = similarity
                docs.append(doc)

            return docs

        # similarity k
        # mmr k, fetch_k, lambda_mult
        # similarity_score_threshold k
        return vector_store.as_retriever(
            search_type=search_type,
            search_kwargs=search_kwargs
        ).get_relevant_documents(query)

    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        return vector_store.as_retriever(**kwargs)

    def add_texts(self, texts: list[Document], **kwargs):
        if self._is_origin():
            self.recreate_dataset(self.dataset)

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        if kwargs.get('duplicate_check', False):
            texts = self._filter_duplicate_texts(texts)

        uuids = self._get_uuids(texts)
        vector_store.add_documents(texts, uuids=uuids)

    def text_exists(self, id: str) -> bool:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        return vector_store.text_exists(id)

    def delete_by_ids(self, ids: list[str]) -> None:
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        for node_id in ids:
            vector_store.del_text(node_id)

    def delete(self) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.delete()

    def _is_origin(self):
        return False

    def recreate_dataset(self, dataset: Dataset):
        logging.info(f"Recreating dataset {dataset.id}")

        try:
            self.delete()
        except UnexpectedStatusCodeException as e:
            if e.status_code != 400:
                # 400 means index not exists
                raise e

        dataset_documents = db.session.query(DatasetDocument).filter(
            DatasetDocument.dataset_id == dataset.id,
            DatasetDocument.indexing_status == 'completed',
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        ).all()

        documents = []
        for dataset_document in dataset_documents:
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True
            ).all()

            for segment in segments:
                document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                )

                documents.append(document)

        origin_index_struct = self.dataset.index_struct
        self.dataset.index_struct = None

        if documents:
            try:
                self.create(documents)
            except Exception as e:
                self.dataset.index_struct = origin_index_struct
                raise e

        dataset.index_struct = json.dumps(self.to_index_struct())

        db.session.commit()

        self.dataset = dataset
        logging.info(f"Dataset {dataset.id} recreate successfully.")
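Usage note (illustrative, not part of the diff): the `similarity_score_threshold` branch of `search` copies each relevance score into the document metadata, so callers can filter or display it. A hedged sketch against a hypothetical concrete subclass, with construction details omitted:

    # illustrative only: `index` is assumed to be a concrete BaseVectorIndex subclass
    docs = index.search(
        "how do I rotate my API key?",
        search_type='similarity_score_threshold',
        search_kwargs={'k': 4, 'score_threshold': 0.75}
    )

    for doc in docs:
        # each hit carries its relevance score in metadata['score']
        print(round(doc.metadata['score'], 3), doc.metadata.get('document_id'), doc.page_content[:80])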
api/core/index/vector_index/qdrant_vector_index.py (new file)
@@ -0,0 +1,116 @@
import os
from typing import Optional, Any, List, cast

import qdrant_client
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from models.dataset import Dataset


class QdrantConfig(BaseModel):
    endpoint: str
    api_key: Optional[str]
    root_path: Optional[str]

    def to_qdrant_params(self):
        if self.endpoint and self.endpoint.startswith('path:'):
            path = self.endpoint.replace('path:', '')
            if not os.path.isabs(path):
                path = os.path.join(self.root_path, path)

            return {
                'path': path
            }
        else:
            return {
                'url': self.endpoint,
                'api_key': self.api_key,
            }


class QdrantVectorIndex(BaseVectorIndex):
    def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
        super().__init__(dataset, embeddings)
        self._client_config = config

    def get_type(self) -> str:
        return 'qdrant'

    def get_index_name(self, dataset: Dataset) -> str:
        if self.dataset.index_struct_dict:
            return self.dataset.index_struct_dict['vector_store']['collection_name']

        dataset_id = dataset.id
        return "Index_" + dataset_id.replace("-", "_")

    def to_index_struct(self) -> dict:
        return {
            "type": self.get_type(),
            "vector_store": {"collection_name": self.get_index_name(self.dataset)}
        }

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = QdrantVectorStore.from_documents(
            texts,
            self._embeddings,
            collection_name=self.get_index_name(self.dataset),
            ids=uuids,
            content_payload_key='text',
            **self._client_config.to_qdrant_params()
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
            return self._vector_store

        client = qdrant_client.QdrantClient(
            **self._client_config.to_qdrant_params()
        )

        return QdrantVectorStore(
            client=client,
            collection_name=self.get_index_name(self.dataset),
            embeddings=self._embeddings,
            content_payload_key='text'
        )

    def _get_vector_store_class(self) -> type:
        return QdrantVectorStore

    def delete_by_document_id(self, document_id: str):
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        from qdrant_client.http import models

        vector_store.del_texts(models.Filter(
            must=[
                models.FieldCondition(
                    key="metadata.document_id",
                    match=models.MatchValue(value=document_id),
                ),
            ],
        ))

    def _is_origin(self):
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name']
            if class_prefix.startswith('Vector_'):
                # original class_prefix
                return True

        return False
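QdrantConfig lets the same setting point either at a local on-disk store (a `path:` prefix) or a remote server. A quick illustration of how `to_qdrant_params()` resolves both forms (the values are examples only):

    # local, file-backed Qdrant resolved relative to the app root
    local = QdrantConfig(endpoint='path:storage/qdrant', api_key=None, root_path='/app/api')
    print(local.to_qdrant_params())   # {'path': '/app/api/storage/qdrant'}

    # remote Qdrant server
    remote = QdrantConfig(endpoint='https://qdrant.example.com:6333', api_key='qd-xxx', root_path=None)
    print(remote.to_qdrant_params())  # {'url': 'https://qdrant.example.com:6333', 'api_key': 'qd-xxx'}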
api/core/index/vector_index/vector_index.py (new file)
@@ -0,0 +1,69 @@
import json

from flask import current_app
from langchain.embeddings.base import Embeddings

from core.index.vector_index.base import BaseVectorIndex
from extensions.ext_database import db
from models.dataset import Dataset, Document


class VectorIndex:
    def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings):
        self._dataset = dataset
        self._embeddings = embeddings
        self._vector_index = self._init_vector_index(dataset, config, embeddings)

    def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex:
        vector_type = config.get('VECTOR_STORE')

        if self._dataset.index_struct_dict:
            vector_type = self._dataset.index_struct_dict['type']

        if not vector_type:
            raise ValueError(f"Vector store must be specified.")

        if vector_type == "weaviate":
            from core.index.vector_index.weaviate_vector_index import WeaviateVectorIndex, WeaviateConfig

            return WeaviateVectorIndex(
                dataset=dataset,
                config=WeaviateConfig(
                    endpoint=config.get('WEAVIATE_ENDPOINT'),
                    api_key=config.get('WEAVIATE_API_KEY'),
                    batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
                ),
                embeddings=embeddings
            )
        elif vector_type == "qdrant":
            from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig

            return QdrantVectorIndex(
                dataset=dataset,
                config=QdrantConfig(
                    endpoint=config.get('QDRANT_URL'),
                    api_key=config.get('QDRANT_API_KEY'),
                    root_path=current_app.root_path
                ),
                embeddings=embeddings
            )
        else:
            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")

    def add_texts(self, texts: list[Document], **kwargs):
        if not self._dataset.index_struct_dict:
            self._vector_index.create(texts, **kwargs)
            self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
            db.session.commit()
            return

        self._vector_index.add_texts(texts, **kwargs)

    def __getattr__(self, name):
        if self._vector_index is not None:
            method = getattr(self._vector_index, name)
            if callable(method):
                return method

        raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")
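The factory reads `VECTOR_STORE` from config for new datasets, while an existing dataset keeps whichever store type is recorded in its index_struct. A hedged usage sketch (config values, `dataset`, `embeddings` and `documents` are illustrative assumptions, and the Qdrant branch expects a Flask app context for `current_app.root_path`):

    # illustrative wiring only
    config = {
        'VECTOR_STORE': 'qdrant',
        'QDRANT_URL': 'http://localhost:6333',
        'QDRANT_API_KEY': None,
    }

    index = VectorIndex(dataset=dataset, config=config, embeddings=embeddings)

    # the first call on a fresh dataset creates the collection and stores index_struct;
    # later calls append to the existing collection
    index.add_texts(documents, duplicate_check=True)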
api/core/index/vector_index/weaviate_vector_index.py (new file)
@@ -0,0 +1,132 @@
from typing import Optional, cast

import weaviate
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset


class WeaviateConfig(BaseModel):
    endpoint: str
    api_key: Optional[str]
    batch_size: int = 100

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        if not values['endpoint']:
            raise ValueError("config WEAVIATE_ENDPOINT is required")
        return values


class WeaviateVectorIndex(BaseVectorIndex):
    def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
        super().__init__(dataset, embeddings)
        self._client = self._init_client(config)

    def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
        auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)

        weaviate.connect.connection.has_grpc = False

        client = weaviate.Client(
            url=config.endpoint,
            auth_client_secret=auth_config,
            timeout_config=(5, 60),
            startup_period=None
        )

        client.batch.configure(
            # `batch_size` takes an `int` value to enable auto-batching
            # (`None` is used for manual batching)
            batch_size=config.batch_size,
            # dynamically update the `batch_size` based on import speed
            dynamic=True,
            # `timeout_retries` takes an `int` value to retry on time outs
            timeout_retries=3,
        )

        return client

    def get_type(self) -> str:
        return 'weaviate'

    def get_index_name(self, dataset: Dataset) -> str:
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                class_prefix += '_Node'

            return class_prefix

        dataset_id = dataset.id
        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

    def to_index_struct(self) -> dict:
        return {
            "type": self.get_type(),
            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
        }

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = WeaviateVectorStore.from_documents(
            texts,
            self._embeddings,
            client=self._client,
            index_name=self.get_index_name(self.dataset),
            uuids=uuids,
            by_text=False
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
            return self._vector_store

        attributes = ['doc_id', 'dataset_id', 'document_id']
        if self._is_origin():
            attributes = ['doc_id']

        return WeaviateVectorStore(
            client=self._client,
            index_name=self.get_index_name(self.dataset),
            text_key='text',
            embedding=self._embeddings,
            attributes=attributes,
            by_text=False
        )

    def _get_vector_store_class(self) -> type:
        return WeaviateVectorStore

    def delete_by_document_id(self, document_id: str):
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.del_texts({
            "operator": "Equal",
            "path": ["document_id"],
            "valueText": document_id
        })

    def _is_origin(self):
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                return True

        return False
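Class naming is what distinguishes pre-migration ("origin") Weaviate indexes from new ones: legacy datasets stored a class_prefix without the `_Node` suffix, while new indexes always end in `_Node`. A small sketch of the naming rule (the dataset id is an example):

    # illustrative: how a dataset id maps to a Weaviate class name
    dataset_id = '4c6b6f3c-8d7a-4f4e-9a1a-2b3c4d5e6f70'
    class_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
    print(class_name)  # Vector_index_4c6b6f3c_8d7a_4f4e_9a1a_2b3c4d5e6f70_Node

    # a stored class_prefix that does NOT end in '_Node' marks a legacy ('origin')
    # index, which triggers recreate_dataset() before writes and deletes
    legacy_prefix = 'Vector_index_4c6b6f3c_8d7a_4f4e_9a1a_2b3c4d5e6f70'
    print(not legacy_prefix.endswith('_Node'))  # True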
@@ -1,35 +1,34 @@
 import datetime
 import json
+import logging
 import re
-import tempfile
 import time
-from pathlib import Path
-from typing import Optional, List
+import uuid
+from typing import Optional, List, cast

+from flask import current_app
 from flask_login import current_user
-from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.schema import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter

-from llama_index import SimpleDirectoryReader
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
-from llama_index.node_parser import SimpleNodeParser, NodeParser
-from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR
-from llama_index.readers.file.markdown_parser import MarkdownParser
-
-from core.data_source.notion import NotionPageReader
-from core.index.readers.xlsx_parser import XLSXParser
+from core.data_loader.file_extractor import FileExtractor
+from core.data_loader.loader.notion import NotionLoader
 from core.docstore.dataset_docstore import DatesetDocumentStore
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.readers.html_parser import HTMLParser
-from core.index.readers.markdown_parser import MarkdownParser
-from core.index.readers.pdf_parser import PDFParser
-from core.index.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
-from core.index.vector_index import VectorIndex
+from core.embedding.cached_embedding import CacheEmbedding
+from core.index.index import IndexBuilder
+from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
+from core.index.vector_index.vector_index import VectorIndex
+from core.llm.error import ProviderTokenNotInitError
+from core.llm.llm_builder import LLMBuilder
+from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
 from core.llm.token_calculator import TokenCalculator
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
-from models.dataset import Document, Dataset, DocumentSegment, DatasetProcessRule
+from libs import helper
+from models.dataset import Document as DatasetDocument
+from models.dataset import Dataset, DocumentSegment, DatasetProcessRule
 from models.model import UploadFile
 from models.source import DataSourceBinding
@@ -40,135 +39,171 @@ class IndexingRunner:
         self.storage = storage
         self.embedding_model_name = embedding_model_name

-    def run(self, documents: List[Document]):
+    def run(self, dataset_documents: List[DatasetDocument]):
         """Run the indexing process."""
-        for document in documents:
+        for dataset_document in dataset_documents:
+            try:
                 # get dataset
                 dataset = Dataset.query.filter_by(
-                    id=document.dataset_id
+                    id=dataset_document.dataset_id
                 ).first()

                 if not dataset:
                     raise ValueError("no dataset found")

                 # load file
-                text_docs = self._load_data(document)
+                text_docs = self._load_data(dataset_document)

                 # get the process rule
                 processing_rule = db.session.query(DatasetProcessRule). \
-                    filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
+                    filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                     first()

-                # get node parser for splitting
-                node_parser = self._get_node_parser(processing_rule)
+                # get splitter
+                splitter = self._get_splitter(processing_rule)

-                # split to nodes
-                nodes = self._step_split(
+                # split to documents
+                documents = self._step_split(
                     text_docs=text_docs,
-                    node_parser=node_parser,
+                    splitter=splitter,
                     dataset=dataset,
-                    document=document,
+                    dataset_document=dataset_document,
                     processing_rule=processing_rule
                 )

                 # build index
                 self._build_index(
                     dataset=dataset,
-                    document=document,
-                    nodes=nodes
+                    dataset_document=dataset_document,
+                    documents=documents
                 )
+            except DocumentIsPausedException:
+                raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
+            except ProviderTokenNotInitError as e:
+                dataset_document.indexing_status = 'error'
+                dataset_document.error = str(e.description)
+                dataset_document.stopped_at = datetime.datetime.utcnow()
+                db.session.commit()
+            except Exception as e:
+                logging.exception("consume document failed")
+                dataset_document.indexing_status = 'error'
+                dataset_document.error = str(e)
+                dataset_document.stopped_at = datetime.datetime.utcnow()
+                db.session.commit()

-    def run_in_splitting_status(self, document: Document):
+    def run_in_splitting_status(self, dataset_document: DatasetDocument):
         """Run the indexing process when the index_status is splitting."""
-        # get dataset
-        dataset = Dataset.query.filter_by(
-            id=document.dataset_id
-        ).first()
+        try:
+            # get dataset
+            dataset = Dataset.query.filter_by(
+                id=dataset_document.dataset_id
+            ).first()

-        if not dataset:
-            raise ValueError("no dataset found")
+            if not dataset:
+                raise ValueError("no dataset found")

-        # get exist document_segment list and delete
-        document_segments = DocumentSegment.query.filter_by(
-            dataset_id=dataset.id,
-            document_id=document.id
-        ).all()
-        db.session.delete(document_segments)
-        db.session.commit()
-        # load file
-        text_docs = self._load_data(document)
+            # get exist document_segment list and delete
+            document_segments = DocumentSegment.query.filter_by(
+                dataset_id=dataset.id,
+                document_id=dataset_document.id
+            ).all()

-        # get the process rule
-        processing_rule = db.session.query(DatasetProcessRule). \
-            filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
-            first()
+            db.session.delete(document_segments)
+            db.session.commit()

-        # get node parser for splitting
-        node_parser = self._get_node_parser(processing_rule)
+            # load file
+            text_docs = self._load_data(dataset_document)

-        # split to nodes
-        nodes = self._step_split(
-            text_docs=text_docs,
-            node_parser=node_parser,
-            dataset=dataset,
-            document=document,
-            processing_rule=processing_rule
-        )
+            # get the process rule
+            processing_rule = db.session.query(DatasetProcessRule). \
+                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
+                first()

-        # build index
-        self._build_index(
-            dataset=dataset,
-            document=document,
-            nodes=nodes
-        )
+            # get splitter
+            splitter = self._get_splitter(processing_rule)
+
+            # split to documents
+            documents = self._step_split(
+                text_docs=text_docs,
+                splitter=splitter,
+                dataset=dataset,
+                dataset_document=dataset_document,
+                processing_rule=processing_rule
+            )
+
+            # build index
+            self._build_index(
+                dataset=dataset,
+                dataset_document=dataset_document,
+                documents=documents
+            )
+        except DocumentIsPausedException:
+            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
+        except ProviderTokenNotInitError as e:
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e.description)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()
+        except Exception as e:
+            logging.exception("consume document failed")
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()

-    def run_in_indexing_status(self, document: Document):
+    def run_in_indexing_status(self, dataset_document: DatasetDocument):
         """Run the indexing process when the index_status is indexing."""
-        # get dataset
-        dataset = Dataset.query.filter_by(
-            id=document.dataset_id
-        ).first()
+        try:
+            # get dataset
+            dataset = Dataset.query.filter_by(
+                id=dataset_document.dataset_id
+            ).first()

-        if not dataset:
-            raise ValueError("no dataset found")
+            if not dataset:
+                raise ValueError("no dataset found")

-        # get exist document_segment list and delete
-        document_segments = DocumentSegment.query.filter_by(
-            dataset_id=dataset.id,
-            document_id=document.id
-        ).all()
-        nodes = []
-        if document_segments:
-            for document_segment in document_segments:
-                # transform segment to node
-                if document_segment.status != "completed":
-                    relationships = {
-                        DocumentRelationship.SOURCE: document_segment.document_id,
-                    }
+            # get exist document_segment list and delete
+            document_segments = DocumentSegment.query.filter_by(
+                dataset_id=dataset.id,
+                document_id=dataset_document.id
+            ).all()

-                    previous_segment = document_segment.previous_segment
-                    if previous_segment:
-                        relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id
+            documents = []
+            if document_segments:
+                for document_segment in document_segments:
+                    # transform segment to node
+                    if document_segment.status != "completed":
+                        document = Document(
+                            page_content=document_segment.content,
+                            metadata={
+                                "doc_id": document_segment.index_node_id,
+                                "doc_hash": document_segment.index_node_hash,
+                                "document_id": document_segment.document_id,
+                                "dataset_id": document_segment.dataset_id,
+                            }
+                        )

-                    next_segment = document_segment.next_segment
-                    if next_segment:
-                        relationships[DocumentRelationship.NEXT] = next_segment.index_node_id
-                    node = Node(
-                        doc_id=document_segment.index_node_id,
-                        doc_hash=document_segment.index_node_hash,
-                        text=document_segment.content,
-                        extra_info=None,
-                        node_info=None,
-                        relationships=relationships
-                    )
-                    nodes.append(node)
+                        documents.append(document)

             # build index
             self._build_index(
                 dataset=dataset,
-                document=document,
-                nodes=nodes
+                dataset_document=dataset_document,
+                documents=documents
             )
+        except DocumentIsPausedException:
+            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
+        except ProviderTokenNotInitError as e:
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e.description)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()
+        except Exception as e:
+            logging.exception("consume document failed")
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()

     def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
         """
@@ -179,28 +214,28 @@ class IndexingRunner:
         total_segments = 0
         for file_detail in file_details:
             # load data from file
-            text_docs = self._load_data_from_file(file_detail)
+            text_docs = FileExtractor.load(file_detail)

             processing_rule = DatasetProcessRule(
                 mode=tmp_processing_rule["mode"],
                 rules=json.dumps(tmp_processing_rule["rules"])
             )

-            # get node parser for splitting
-            node_parser = self._get_node_parser(processing_rule)
+            # get splitter
+            splitter = self._get_splitter(processing_rule)

-            # split to nodes
-            nodes = self._split_to_nodes(
+            # split to documents
+            documents = self._split_to_documents(
                 text_docs=text_docs,
-                node_parser=node_parser,
+                splitter=splitter,
                 processing_rule=processing_rule
             )
-            total_segments += len(nodes)
-            for node in nodes:
+            total_segments += len(documents)
+            for document in documents:
                 if len(preview_texts) < 5:
-                    preview_texts.append(node.get_text())
+                    preview_texts.append(document.page_content)

-                tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
+                tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)

         return {
             "total_segments": total_segments,
@@ -230,35 +265,36 @@ class IndexingRunner:
             ).first()
             if not data_source_binding:
                 raise ValueError('Data source binding not found.')
-            reader = NotionPageReader(integration_token=data_source_binding.access_token)
             for page in notion_info['pages']:
-                if page['type'] == 'page':
-                    page_ids = [page['page_id']]
-                    documents = reader.load_data_as_documents(page_ids=page_ids)
-                elif page['type'] == 'database':
-                    documents = reader.load_data_as_documents(database_id=page['page_id'])
-                else:
-                    documents = []
+                loader = NotionLoader(
+                    notion_access_token=data_source_binding.access_token,
+                    notion_workspace_id=workspace_id,
+                    notion_obj_id=page['page_id'],
+                    notion_page_type=page['type']
+                )
+                documents = loader.load()

                 processing_rule = DatasetProcessRule(
                     mode=tmp_processing_rule["mode"],
                     rules=json.dumps(tmp_processing_rule["rules"])
                 )

-                # get node parser for splitting
-                node_parser = self._get_node_parser(processing_rule)
+                # get splitter
+                splitter = self._get_splitter(processing_rule)

-                # split to nodes
-                nodes = self._split_to_nodes(
+                # split to documents
+                documents = self._split_to_documents(
                     text_docs=documents,
-                    node_parser=node_parser,
+                    splitter=splitter,
                     processing_rule=processing_rule
                 )
-                total_segments += len(nodes)
-                for node in nodes:
+                total_segments += len(documents)
+                for document in documents:
                     if len(preview_texts) < 5:
-                        preview_texts.append(node.get_text())
+                        preview_texts.append(document.page_content)

-                    tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
+                    tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)

         return {
             "total_segments": total_segments,
@@ -268,14 +304,14 @@ class IndexingRunner:
             "preview": preview_texts
         }

-    def _load_data(self, document: Document) -> List[Document]:
+    def _load_data(self, dataset_document: DatasetDocument) -> List[Document]:
         # load file
-        if document.data_source_type not in ["upload_file", "notion_import"]:
+        if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
             return []

-        data_source_info = document.data_source_info_dict
+        data_source_info = dataset_document.data_source_info_dict
         text_docs = []
-        if document.data_source_type == 'upload_file':
+        if dataset_document.data_source_type == 'upload_file':
             if not data_source_info or 'upload_file_id' not in data_source_info:
                 raise ValueError("no upload file found")
@@ -283,47 +319,28 @@ class IndexingRunner:
                 filter(UploadFile.id == data_source_info['upload_file_id']). \
                 one_or_none()

-            text_docs = self._load_data_from_file(file_detail)
-        elif document.data_source_type == 'notion_import':
-            if not data_source_info or 'notion_page_id' not in data_source_info \
-                    or 'notion_workspace_id' not in data_source_info:
-                raise ValueError("no notion page found")
-            workspace_id = data_source_info['notion_workspace_id']
-            page_id = data_source_info['notion_page_id']
-            page_type = data_source_info['type']
-            data_source_binding = DataSourceBinding.query.filter(
-                db.and_(
-                    DataSourceBinding.tenant_id == document.tenant_id,
-                    DataSourceBinding.provider == 'notion',
-                    DataSourceBinding.disabled == False,
-                    DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
-                )
-            ).first()
-            if not data_source_binding:
-                raise ValueError('Data source binding not found.')
-            if page_type == 'page':
-                # add page last_edited_time to data_source_info
-                self._get_notion_page_last_edited_time(page_id, data_source_binding.access_token, document)
-                text_docs = self._load_page_data_from_notion(page_id, data_source_binding.access_token)
-            elif page_type == 'database':
-                # add page last_edited_time to data_source_info
-                self._get_notion_database_last_edited_time(page_id, data_source_binding.access_token, document)
-                text_docs = self._load_database_data_from_notion(page_id, data_source_binding.access_token)
+            text_docs = FileExtractor.load(file_detail)
+        elif dataset_document.data_source_type == 'notion_import':
+            loader = NotionLoader.from_document(dataset_document)
+            text_docs = loader.load()

         # update document status to splitting
         self._update_document_index_status(
-            document_id=document.id,
+            document_id=dataset_document.id,
             after_indexing_status="splitting",
             extra_update_params={
-                Document.word_count: sum([len(text_doc.text) for text_doc in text_docs]),
-                Document.parsing_completed_at: datetime.datetime.utcnow()
+                DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),
+                DatasetDocument.parsing_completed_at: datetime.datetime.utcnow()
             }
         )

         # replace doc id to document model id
+        text_docs = cast(List[Document], text_docs)
         for text_doc in text_docs:
             # remove invalid symbol
-            text_doc.text = self.filter_string(text_doc.get_text())
-            text_doc.doc_id = document.id
+            text_doc.page_content = self.filter_string(text_doc.page_content)
+            text_doc.metadata['document_id'] = dataset_document.id
+            text_doc.metadata['dataset_id'] = dataset_document.dataset_id

         return text_docs
@@ -331,61 +348,7 @@ class IndexingRunner:
         pattern = re.compile('[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]')
         return pattern.sub('', text)

-    def _load_data_from_file(self, upload_file: UploadFile) -> List[Document]:
-        with tempfile.TemporaryDirectory() as temp_dir:
-            suffix = Path(upload_file.key).suffix
-            filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-            self.storage.download(upload_file.key, filepath)
-
-            file_extractor = DEFAULT_FILE_EXTRACTOR.copy()
-            file_extractor[".markdown"] = MarkdownParser()
-            file_extractor[".md"] = MarkdownParser()
-            file_extractor[".html"] = HTMLParser()
-            file_extractor[".htm"] = HTMLParser()
-            file_extractor[".pdf"] = PDFParser({'upload_file': upload_file})
-            file_extractor[".xlsx"] = XLSXParser()
-
-            loader = SimpleDirectoryReader(input_files=[filepath], file_extractor=file_extractor)
-            text_docs = loader.load_data()
-
-            return text_docs
-
-    def _load_page_data_from_notion(self, page_id: str, access_token: str) -> List[Document]:
-        page_ids = [page_id]
-        reader = NotionPageReader(integration_token=access_token)
-        text_docs = reader.load_data_as_documents(page_ids=page_ids)
-        return text_docs
-
-    def _load_database_data_from_notion(self, database_id: str, access_token: str) -> List[Document]:
-        reader = NotionPageReader(integration_token=access_token)
-        text_docs = reader.load_data_as_documents(database_id=database_id)
-        return text_docs
-
-    def _get_notion_page_last_edited_time(self, page_id: str, access_token: str, document: Document):
-        reader = NotionPageReader(integration_token=access_token)
-        last_edited_time = reader.get_page_last_edited_time(page_id)
-        data_source_info = document.data_source_info_dict
-        data_source_info['last_edited_time'] = last_edited_time
-        update_params = {
-            Document.data_source_info: json.dumps(data_source_info)
-        }
-
-        Document.query.filter_by(id=document.id).update(update_params)
-        db.session.commit()
-
-    def _get_notion_database_last_edited_time(self, page_id: str, access_token: str, document: Document):
-        reader = NotionPageReader(integration_token=access_token)
-        last_edited_time = reader.get_database_last_edited_time(page_id)
-        data_source_info = document.data_source_info_dict
-        data_source_info['last_edited_time'] = last_edited_time
-        update_params = {
-            Document.data_source_info: json.dumps(data_source_info)
-        }
-
-        Document.query.filter_by(id=document.id).update(update_params)
-        db.session.commit()
-
-    def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser:
+    def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter:
         """
         Get the NodeParser object according to the processing rule.
         """
@@ -414,68 +377,83 @@ class IndexingRunner:
                 separators=["\n\n", "。", ".", " ", ""]
             )

-        return SimpleNodeParser(text_splitter=character_splitter, include_extra_info=True)
+        return character_splitter

-    def _step_split(self, text_docs: List[Document], node_parser: NodeParser,
-                    dataset: Dataset, document: Document, processing_rule: DatasetProcessRule) -> List[Node]:
+    def _step_split(self, text_docs: List[Document], splitter: TextSplitter,
+                    dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
+            -> List[Document]:
         """
-        Split the text documents into nodes and save them to the document segment.
+        Split the text documents into documents and save them to the document segment.
         """
-        nodes = self._split_to_nodes(
+        documents = self._split_to_documents(
             text_docs=text_docs,
-            node_parser=node_parser,
+            splitter=splitter,
             processing_rule=processing_rule
         )

         # save node to document segment
         doc_store = DatesetDocumentStore(
             dataset=dataset,
-            user_id=document.created_by,
+            user_id=dataset_document.created_by,
             embedding_model_name=self.embedding_model_name,
-            document_id=document.id
+            document_id=dataset_document.id
         )

         # add document segments
-        doc_store.add_documents(nodes)
+        doc_store.add_documents(documents)

         # update document status to indexing
         cur_time = datetime.datetime.utcnow()
         self._update_document_index_status(
-            document_id=document.id,
+            document_id=dataset_document.id,
             after_indexing_status="indexing",
             extra_update_params={
-                Document.cleaning_completed_at: cur_time,
-                Document.splitting_completed_at: cur_time,
+                DatasetDocument.cleaning_completed_at: cur_time,
+                DatasetDocument.splitting_completed_at: cur_time,
             }
         )

         # update segment status to indexing
         self._update_segments_by_document(
-            document_id=document.id,
+            dataset_document_id=dataset_document.id,
             update_params={
                 DocumentSegment.status: "indexing",
                 DocumentSegment.indexing_at: datetime.datetime.utcnow()
             }
         )

-        return nodes
+        return documents

-    def _split_to_nodes(self, text_docs: List[Document], node_parser: NodeParser,
-                        processing_rule: DatasetProcessRule) -> List[Node]:
+    def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
+                            processing_rule: DatasetProcessRule) -> List[Document]:
         """
         Split the text documents into nodes.
         """
-        all_nodes = []
+        all_documents = []
         for text_doc in text_docs:
             # document clean
-            document_text = self._document_clean(text_doc.get_text(), processing_rule)
-            text_doc.text = document_text
+            document_text = self._document_clean(text_doc.page_content, processing_rule)
+            text_doc.page_content = document_text

             # parse document to nodes
-            nodes = node_parser.get_nodes_from_documents([text_doc])
-            nodes = [node for node in nodes if node.text is not None and node.text.strip()]
-            all_nodes.extend(nodes)
-
-        return all_nodes
+            documents = splitter.split_documents([text_doc])
+
+            split_documents = []
+            for document in documents:
+                if document.page_content is None or not document.page_content.strip():
+                    continue
+
+                doc_id = str(uuid.uuid4())
+                hash = helper.generate_text_hash(document.page_content)
+
+                document.metadata['doc_id'] = doc_id
+                document.metadata['doc_hash'] = hash
+
+                split_documents.append(document)
+
+            all_documents.extend(split_documents)
+
+        return all_documents

     def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
         """
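The reworked `_split_to_documents` attaches a fresh `doc_id` and a content hash to every chunk. A standalone sketch of the same flow (illustrative only; a plain langchain splitter is used, and the project's libs.helper.generate_text_hash is approximated with hashlib):

    import hashlib
    import uuid

    from langchain.schema import Document
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)

    source = Document(page_content="...long cleaned text...", metadata={"document_id": "doc-1"})

    chunks = []
    for chunk in splitter.split_documents([source]):
        if not chunk.page_content or not chunk.page_content.strip():
            continue  # drop empty chunks, as the runner does

        chunk.metadata['doc_id'] = str(uuid.uuid4())
        chunk.metadata['doc_hash'] = hashlib.sha256(chunk.page_content.encode('utf-8')).hexdigest()
        chunks.append(chunk)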
@@ -506,37 +484,38 @@ class IndexingRunner:

         return text

-    def _build_index(self, dataset: Dataset, document: Document, nodes: List[Node]) -> None:
+    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
         """
         Build the index for the document.
         """
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        keyword_table_index = IndexBuilder.get_index(dataset, 'economy')

         # chunk nodes by chunk size
         indexing_start_at = time.perf_counter()
         tokens = 0
         chunk_size = 100
-        for i in range(0, len(nodes), chunk_size):
+        for i in range(0, len(documents), chunk_size):
             # check document is paused
-            self._check_document_paused_status(document.id)
-            chunk_nodes = nodes[i:i + chunk_size]
+            self._check_document_paused_status(dataset_document.id)
+            chunk_documents = documents[i:i + chunk_size]

             tokens += sum(
-                TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) for node in chunk_nodes
+                TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
+                for document in chunk_documents
             )

             # save vector index
-            if dataset.indexing_technique == "high_quality":
-                vector_index.add_nodes(chunk_nodes)
+            if vector_index:
+                vector_index.add_texts(chunk_documents)

             # save keyword index
-            keyword_table_index.add_nodes(chunk_nodes)
+            keyword_table_index.add_texts(chunk_documents)

-            node_ids = [node.doc_id for node in chunk_nodes]
+            document_ids = [document.metadata['doc_id'] for document in chunk_documents]
             db.session.query(DocumentSegment).filter(
-                DocumentSegment.document_id == document.id,
-                DocumentSegment.index_node_id.in_(node_ids),
+                DocumentSegment.document_id == dataset_document.id,
+                DocumentSegment.index_node_id.in_(document_ids),
                 DocumentSegment.status == "indexing"
             ).update({
                 DocumentSegment.status: "completed",
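Documents are pushed to the vector and keyword indexes in batches of 100; the batching itself is just slicing, as in this minimal illustrative sketch:

    def iter_batches(items, batch_size=100):
        # yields items[0:100], items[100:200], ... exactly like the loop above
        for i in range(0, len(items), batch_size):
            yield items[i:i + batch_size]

    # example: 250 documents -> batches of 100, 100, 50
    batches = list(iter_batches(list(range(250))))
    print([len(b) for b in batches])  # [100, 100, 50]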
@@ -549,12 +528,12 @@ class IndexingRunner:

         # update document status to completed
         self._update_document_index_status(
-            document_id=document.id,
+            document_id=dataset_document.id,
             after_indexing_status="completed",
             extra_update_params={
-                Document.tokens: tokens,
-                Document.completed_at: datetime.datetime.utcnow(),
-                Document.indexing_latency: indexing_end_at - indexing_start_at,
+                DatasetDocument.tokens: tokens,
+                DatasetDocument.completed_at: datetime.datetime.utcnow(),
+                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
             }
         )
@ -569,25 +548,25 @@ class IndexingRunner:
|
|||||||
"""
|
"""
|
||||||
Update the document indexing status.
|
Update the document indexing status.
|
||||||
"""
|
"""
|
||||||
count = Document.query.filter_by(id=document_id, is_paused=True).count()
|
count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
|
||||||
if count > 0:
|
if count > 0:
|
||||||
raise DocumentIsPausedException()
|
raise DocumentIsPausedException()
|
||||||
|
|
||||||
update_params = {
|
update_params = {
|
||||||
Document.indexing_status: after_indexing_status
|
DatasetDocument.indexing_status: after_indexing_status
|
||||||
}
|
}
|
||||||
|
|
||||||
if extra_update_params:
|
if extra_update_params:
|
||||||
update_params.update(extra_update_params)
|
update_params.update(extra_update_params)
|
||||||
|
|
||||||
Document.query.filter_by(id=document_id).update(update_params)
|
DatasetDocument.query.filter_by(id=document_id).update(update_params)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
def _update_segments_by_document(self, document_id: str, update_params: dict) -> None:
|
def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
|
||||||
"""
|
"""
|
||||||
Update the document segment by document id.
|
Update the document segment by document id.
|
||||||
"""
|
"""
|
||||||
DocumentSegment.query.filter_by(document_id=document_id).update(update_params)
|
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
@@ -1,7 +1,6 @@
-from typing import Union, Optional
+from typing import Union, Optional, List

-from langchain.callbacks import CallbackManager
-from langchain.llms.fake import FakeListLLM
+from langchain.callbacks.base import BaseCallbackHandler

from core.constant import llm_constant
from core.llm.error import ProviderTokenNotInitError
@@ -32,12 +31,11 @@ class LLMBuilder:
    """

    @classmethod
-    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI, FakeListLLM]:
-        if model_name == 'fake':
-            return FakeListLLM(responses=[])
-
+    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
        provider = cls.get_default_provider(tenant_id)

+        model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
+
        mode = cls.get_mode_by_model(model_name)
        if mode == 'chat':
            if provider == 'openai':
@@ -52,16 +50,21 @@ class LLMBuilder:
        else:
            raise ValueError(f"model name {model_name} is not supported.")

-        model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
+        model_kwargs = {
+            'top_p': kwargs.get('top_p', 1),
+            'frequency_penalty': kwargs.get('frequency_penalty', 0),
+            'presence_penalty': kwargs.get('presence_penalty', 0),
+        }
+
+        model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}

        return llm_cls(
            model_name=model_name,
            temperature=kwargs.get('temperature', 0),
            max_tokens=kwargs.get('max_tokens', 256),
-            top_p=kwargs.get('top_p', 1),
-            frequency_penalty=kwargs.get('frequency_penalty', 0),
-            presence_penalty=kwargs.get('presence_penalty', 0),
-            callback_manager=kwargs.get('callback_manager', None),
+            **model_extras_kwargs,
+            callbacks=kwargs.get('callbacks', None),
            streaming=kwargs.get('streaming', False),
            # request_timeout=None
            **model_credentials
@@ -69,7 +72,7 @@ class LLMBuilder:

    @classmethod
    def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
-                          callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
+                          callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
        model_name = model.get("name")
        completion_params = model.get("completion_params", {})

@@ -82,7 +85,7 @@ class LLMBuilder:
            frequency_penalty=completion_params.get('frequency_penalty', 0.1),
            presence_penalty=completion_params.get('presence_penalty', 0.1),
            streaming=streaming,
-            callback_manager=callback_manager
+            callbacks=callbacks
        )

    @classmethod
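Illustrative sketch only (not part of the commit): after this change a per-request handler list is passed through the new callbacks kwarg rather than a langchain CallbackManager; the tenant id below is a made-up placeholder.

from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.llm.llm_builder import LLMBuilder

# hypothetical tenant id; in the app this comes from the current workspace
llm = LLMBuilder.to_llm(
    tenant_id='tenant-1234',
    model_name='gpt-3.5-turbo',
    streaming=True,
    callbacks=[DifyStdOutCallbackHandler()],
)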
@@ -42,7 +42,10 @@ class AzureProvider(BaseProvider):
        """
        config = self.get_provider_api_key(model_id=model_id)
        config['openai_api_type'] = 'azure'
-        config['deployment_name'] = model_id.replace('.', '') if model_id else None
+        if model_id == 'text-embedding-ada-002':
+            config['deployment'] = model_id.replace('.', '') if model_id else None
+        else:
+            config['deployment_name'] = model_id.replace('.', '') if model_id else None
        return config

    def get_provider_name(self):
@@ -1,3 +1,4 @@
+from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks
from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.chat_models import AzureChatOpenAI
from typing import Optional, List, Dict, Any
@@ -68,60 +69,22 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):

        return message_tokens

-    def _generate(
-            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
-        self.callback_manager.on_llm_start(
-            {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
-            verbose=self.verbose
-        )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-        self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
-    async def _agenerate(
-            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
-                verbose=self.verbose
-            )
-        else:
-            self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
-                verbose=self.verbose
-            )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
-        else:
-            self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
    @handle_llm_exceptions
    def generate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
    ) -> LLMResult:
-        return super().generate(messages, stop)
+        return super().generate(messages, stop, callbacks, **kwargs)

    @handle_llm_exceptions_async
    async def agenerate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
    ) -> LLMResult:
-        return await super().agenerate(messages, stop)
+        return await super().agenerate(messages, stop, callbacks, **kwargs)
@@ -1,5 +1,4 @@
-import os
+from langchain.callbacks.manager import Callbacks

from langchain.llms import AzureOpenAI
from langchain.schema import LLMResult
from typing import Optional, List, Dict, Mapping, Any
@@ -53,12 +52,20 @@ class StreamableAzureOpenAI(AzureOpenAI):

    @handle_llm_exceptions
    def generate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
    ) -> LLMResult:
-        return super().generate(prompts, stop)
+        return super().generate(prompts, stop, callbacks, **kwargs)

    @handle_llm_exceptions_async
    async def agenerate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
    ) -> LLMResult:
-        return await super().agenerate(prompts, stop)
+        return await super().agenerate(prompts, stop, callbacks, **kwargs)
@@ -1,6 +1,7 @@
import os

-from langchain.schema import BaseMessage, ChatResult, LLMResult
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import BaseMessage, LLMResult
from langchain.chat_models import ChatOpenAI
from typing import Optional, List, Dict, Any

@@ -70,57 +71,22 @@ class StreamableChatOpenAI(ChatOpenAI):

        return message_tokens

-    def _generate(
-            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
-        self.callback_manager.on_llm_start(
-            {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
-        )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-        self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
-    async def _agenerate(
-            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
-            )
-        else:
-            self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
-            )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
-        else:
-            self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
    @handle_llm_exceptions
    def generate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
    ) -> LLMResult:
-        return super().generate(messages, stop)
+        return super().generate(messages, stop, callbacks, **kwargs)

    @handle_llm_exceptions_async
    async def agenerate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
    ) -> LLMResult:
-        return await super().agenerate(messages, stop)
+        return await super().agenerate(messages, stop, callbacks, **kwargs)
@@ -1,5 +1,6 @@
import os

+from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from typing import Optional, List, Dict, Any, Mapping
from langchain import OpenAI
@@ -48,15 +49,22 @@ class StreamableOpenAI(OpenAI):
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    @handle_llm_exceptions
    def generate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
    ) -> LLMResult:
-        return super().generate(prompts, stop)
+        return super().generate(prompts, stop, callbacks, **kwargs)

    @handle_llm_exceptions_async
    async def agenerate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
    ) -> LLMResult:
-        return await super().agenerate(prompts, stop)
+        return await super().agenerate(prompts, stop, callbacks, **kwargs)
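A minimal sketch (assumed setup, not part of the diff) of the kind of handler the new callbacks parameter accepts on these streamable wrappers:

from typing import Any
from langchain.callbacks.base import BaseCallbackHandler

class TokenPrinter(BaseCallbackHandler):
    """Minimal handler: print streamed tokens as they arrive."""
    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        print(token, end='', flush=True)

# llm = StreamableOpenAI(streaming=True, openai_api_key='sk-...')  # hypothetical credentials
# llm.generate(["Say hello"], callbacks=[TokenPrinter()])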
@@ -1,7 +1,7 @@
from typing import Any, List, Dict

from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import get_buffer_string, BaseMessage, BaseLanguageModel
+from langchain.schema import get_buffer_string

from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
@@ -1,5 +1,3 @@
-from llama_index import QueryKeywordExtractPrompt
-
CONVERSATION_TITLE_PROMPT = (
    "Human:{query}\n-----\n"
    "Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n"
@@ -45,23 +43,6 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
    "[\"question1\",\"question2\",\"question3\"]\n"
)

-QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = (
-    "A question is provided below. Given the question, extract up to {max_keywords} "
-    "keywords from the text. Focus on extracting the keywords that we can use "
-    "to best lookup answers to the question. Avoid stopwords."
-    "I am not sure which language the following question is in. "
-    "If the user asked the question in Chinese, please return the keywords in Chinese. "
-    "If the user asked the question in English, please return the keywords in English.\n"
-    "---------------------\n"
-    "{question}\n"
-    "---------------------\n"
-    "Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n"
-)
-
-QUERY_KEYWORD_EXTRACT_TEMPLATE = QueryKeywordExtractPrompt(
-    QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL
-)
-
RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
the model prompt that best suits the input.
You will be provided with the prompt, variables, and an opening statement.
api/core/tool/dataset_index_tool.py (new file)
@@ -0,0 +1,87 @@
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools import BaseTool

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset


class DatasetTool(BaseTool):
    """Tool for querying a Dataset."""

    dataset: Dataset
    k: int = 2

    def _run(self, tool_input: str) -> str:
        if self.dataset.indexing_technique == "economy":
            # use keyword table query
            kw_table_index = KeywordTableIndex(
                dataset=self.dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=5
                )
            )

            documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
        else:
            model_credentials = LLMBuilder.get_model_credentials(
                tenant_id=self.dataset.tenant_id,
                model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
                model_name='text-embedding-ada-002'
            )

            embeddings = CacheEmbedding(OpenAIEmbeddings(
                **model_credentials
            ))

            vector_index = VectorIndex(
                dataset=self.dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            documents = vector_index.search(
                tool_input,
                search_type='similarity',
                search_kwargs={
                    'k': self.k
                }
            )

        hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
        hit_callback.on_tool_end(documents)

        return str("\n".join([document.page_content for document in documents]))

    async def _arun(self, tool_input: str) -> str:
        model_credentials = LLMBuilder.get_model_credentials(
            tenant_id=self.dataset.tenant_id,
            model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
            model_name='text-embedding-ada-002'
        )

        embeddings = CacheEmbedding(OpenAIEmbeddings(
            **model_credentials
        ))

        vector_index = VectorIndex(
            dataset=self.dataset,
            config=current_app.config,
            embeddings=embeddings
        )

        documents = await vector_index.asearch(
            tool_input,
            search_type='similarity',
            search_kwargs={
                'k': 10
            }
        )

        hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
        hit_callback.on_tool_end(documents)
        return str("\n".join([document.page_content for document in documents]))
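A rough usage sketch of the new DatasetTool (not part of the commit); the dataset row, tool name, and query text are assumptions standing in for what the surrounding app supplies:

from core.tool.dataset_index_tool import DatasetTool

# `dataset` is assumed to be a models.dataset.Dataset row loaded elsewhere in the app.
def build_dataset_tool(dataset):
    return DatasetTool(
        name=f"dataset-{dataset.id}",
        description=dataset.description or f"useful for when you want to answer queries about the {dataset.name}",
        dataset=dataset,
        k=2,
    )

# tool = build_dataset_tool(dataset)
# print(tool.run("What does the handbook say about onboarding?"))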
@ -1,73 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from langchain.callbacks import CallbackManager
|
|
||||||
from llama_index.langchain_helpers.agents import IndexToolConfig
|
|
||||||
|
|
||||||
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
|
|
||||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
|
||||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
|
||||||
from core.index.keyword_table_index import KeywordTableIndex
|
|
||||||
from core.index.vector_index import VectorIndex
|
|
||||||
from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE
|
|
||||||
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
|
|
||||||
from models.dataset import Dataset
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetToolBuilder:
|
|
||||||
@classmethod
|
|
||||||
def build_dataset_tool(cls, dataset: Dataset,
|
|
||||||
response_mode: str = "no_synthesizer",
|
|
||||||
callback_handler: Optional[DatasetToolCallbackHandler] = None):
|
|
||||||
if dataset.indexing_technique == "economy":
|
|
||||||
# use keyword table query
|
|
||||||
index = KeywordTableIndex(dataset=dataset).query_index
|
|
||||||
|
|
||||||
if not index:
|
|
||||||
return None
|
|
||||||
|
|
||||||
query_kwargs = {
|
|
||||||
"mode": "default",
|
|
||||||
"response_mode": response_mode,
|
|
||||||
"query_keyword_extract_template": QUERY_KEYWORD_EXTRACT_TEMPLATE,
|
|
||||||
"max_keywords_per_query": 5,
|
|
||||||
# If num_chunks_per_query is too large,
|
|
||||||
# it will slow down the synthesis process due to multiple iterations of refinement.
|
|
||||||
"num_chunks_per_query": 2
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
index = VectorIndex(dataset=dataset).query_index
|
|
||||||
|
|
||||||
if not index:
|
|
||||||
return None
|
|
||||||
|
|
||||||
query_kwargs = {
|
|
||||||
"mode": "default",
|
|
||||||
"response_mode": response_mode,
|
|
||||||
# If top_k is too large,
|
|
||||||
# it will slow down the synthesis process due to multiple iterations of refinement.
|
|
||||||
"similarity_top_k": 2
|
|
||||||
}
|
|
||||||
|
|
||||||
# fulfill description when it is empty
|
|
||||||
description = dataset.description
|
|
||||||
if not description:
|
|
||||||
description = 'useful for when you want to answer queries about the ' + dataset.name
|
|
||||||
|
|
||||||
index_tool_config = IndexToolConfig(
|
|
||||||
index=index,
|
|
||||||
name=f"dataset-{dataset.id}",
|
|
||||||
description=description,
|
|
||||||
index_query_kwargs=query_kwargs,
|
|
||||||
tool_kwargs={
|
|
||||||
"callback_manager": CallbackManager([callback_handler, DifyStdOutCallbackHandler()])
|
|
||||||
},
|
|
||||||
# tool_kwargs={"return_direct": True},
|
|
||||||
# return_direct: Whether to return LLM results directly or process the output data with an Output Parser
|
|
||||||
)
|
|
||||||
|
|
||||||
index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset.id)
|
|
||||||
|
|
||||||
return EnhanceLlamaIndexTool.from_tool_config(
|
|
||||||
tool_config=index_tool_config,
|
|
||||||
callback_handler=index_callback_handler
|
|
||||||
)
|
|
@ -1,43 +0,0 @@
|
|||||||
from typing import Dict
|
|
||||||
|
|
||||||
from langchain.tools import BaseTool
|
|
||||||
from llama_index.indices.base import BaseGPTIndex
|
|
||||||
from llama_index.langchain_helpers.agents import IndexToolConfig
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler
|
|
||||||
|
|
||||||
|
|
||||||
class EnhanceLlamaIndexTool(BaseTool):
|
|
||||||
"""Tool for querying a LlamaIndex."""
|
|
||||||
|
|
||||||
# NOTE: name/description still needs to be set
|
|
||||||
index: BaseGPTIndex
|
|
||||||
query_kwargs: Dict = Field(default_factory=dict)
|
|
||||||
return_sources: bool = False
|
|
||||||
callback_handler: IndexToolCallbackHandler
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_tool_config(cls, tool_config: IndexToolConfig,
|
|
||||||
callback_handler: IndexToolCallbackHandler) -> "EnhanceLlamaIndexTool":
|
|
||||||
"""Create a tool from a tool config."""
|
|
||||||
return_sources = tool_config.tool_kwargs.pop("return_sources", False)
|
|
||||||
return cls(
|
|
||||||
index=tool_config.index,
|
|
||||||
callback_handler=callback_handler,
|
|
||||||
name=tool_config.name,
|
|
||||||
description=tool_config.description,
|
|
||||||
return_sources=return_sources,
|
|
||||||
query_kwargs=tool_config.index_query_kwargs,
|
|
||||||
**tool_config.tool_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run(self, tool_input: str) -> str:
|
|
||||||
response = self.index.query(tool_input, **self.query_kwargs)
|
|
||||||
self.callback_handler.on_tool_end(response)
|
|
||||||
return str(response)
|
|
||||||
|
|
||||||
async def _arun(self, tool_input: str) -> str:
|
|
||||||
response = await self.index.aquery(tool_input, **self.query_kwargs)
|
|
||||||
self.callback_handler.on_tool_end(response)
|
|
||||||
return str(response)
|
|
@ -1,34 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from llama_index import ServiceContext, GPTVectorStoreIndex
|
|
||||||
from llama_index.data_structs import Node
|
|
||||||
from llama_index.vector_stores.types import VectorStore
|
|
||||||
|
|
||||||
|
|
||||||
class BaseVectorStoreClient(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def to_index_config(self, index_id: str) -> dict:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class BaseGPTVectorStoreIndex(GPTVectorStoreIndex):
|
|
||||||
def delete_node(self, node_id: str):
|
|
||||||
self._vector_store.delete_node(node_id)
|
|
||||||
|
|
||||||
def exists_by_node_id(self, node_id: str) -> bool:
|
|
||||||
return self._vector_store.exists_by_node_id(node_id)
|
|
||||||
|
|
||||||
|
|
||||||
class EnhanceVectorStore(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def delete_node(self, node_id: str):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def exists_by_node_id(self, node_id: str) -> bool:
|
|
||||||
pass
|
|
api/core/vector_store/qdrant_vector_store.py (new file)
@@ -0,0 +1,69 @@
from typing import cast, Any

from langchain.schema import Document
from langchain.vectorstores import Qdrant
from qdrant_client.http.models import Filter, PointIdsList, FilterSelector
from qdrant_client.local.qdrant_local import QdrantLocal


class QdrantVectorStore(Qdrant):
    def del_texts(self, filter: Filter):
        if not filter:
            raise ValueError('filter must not be empty')

        self._reload_if_needed()

        self.client.delete(
            collection_name=self.collection_name,
            points_selector=FilterSelector(
                filter=filter
            ),
        )

    def del_text(self, uuid: str) -> None:
        self._reload_if_needed()

        self.client.delete(
            collection_name=self.collection_name,
            points_selector=PointIdsList(
                points=[uuid],
            ),
        )

    def text_exists(self, uuid: str) -> bool:
        self._reload_if_needed()

        response = self.client.retrieve(
            collection_name=self.collection_name,
            ids=[uuid]
        )

        return len(response) > 0

    def delete(self):
        self._reload_if_needed()

        self.client.delete_collection(collection_name=self.collection_name)

    @classmethod
    def _document_from_scored_point(
            cls,
            scored_point: Any,
            content_payload_key: str,
            metadata_payload_key: str,
    ) -> Document:
        if scored_point.payload.get('doc_id'):
            return Document(
                page_content=scored_point.payload.get(content_payload_key),
                metadata={'doc_id': scored_point.id}
            )

        return Document(
            page_content=scored_point.payload.get(content_payload_key),
            metadata=scored_point.payload.get(metadata_payload_key) or {},
        )

    def _reload_if_needed(self):
        if isinstance(self.client, QdrantLocal):
            self.client = cast(QdrantLocal, self.client)
            self.client._load()
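An illustrative sketch, not from the commit: the subclass keeps Qdrant's normal construction, so an existing qdrant_client and collection can be wrapped and then pruned by point id; the path, collection name, point id, and constructor keywords below are assumptions.

import qdrant_client
from langchain.embeddings import OpenAIEmbeddings
from core.vector_store.qdrant_vector_store import QdrantVectorStore

client = qdrant_client.QdrantClient(path='./qdrant_local')  # hypothetical on-disk location
store = QdrantVectorStore(
    client=client,
    collection_name='vector_index_example',  # hypothetical collection
    embeddings=OpenAIEmbeddings(),           # assumed kwargs; in the app VectorIndex wires this up
)

point_id = '1f2e3d4c-0000-0000-0000-000000000000'  # hypothetical point / doc_id
if store.text_exists(point_id):
    store.del_text(point_id)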
@ -1,147 +0,0 @@
|
|||||||
import os
|
|
||||||
from typing import cast, List
|
|
||||||
|
|
||||||
from llama_index.data_structs import Node
|
|
||||||
from llama_index.data_structs.node_v2 import DocumentRelationship
|
|
||||||
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult
|
|
||||||
from qdrant_client.http.models import Payload, Filter
|
|
||||||
|
|
||||||
import qdrant_client
|
|
||||||
from llama_index import ServiceContext, GPTVectorStoreIndex, GPTQdrantIndex
|
|
||||||
from llama_index.data_structs.data_structs_v2 import QdrantIndexDict
|
|
||||||
from llama_index.vector_stores import QdrantVectorStore
|
|
||||||
from qdrant_client.local.qdrant_local import QdrantLocal
|
|
||||||
|
|
||||||
from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
|
|
||||||
|
|
||||||
|
|
||||||
class QdrantVectorStoreClient(BaseVectorStoreClient):
|
|
||||||
|
|
||||||
def __init__(self, url: str, api_key: str, root_path: str):
|
|
||||||
self._client = self.init_from_config(url, api_key, root_path)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def init_from_config(cls, url: str, api_key: str, root_path: str):
|
|
||||||
if url and url.startswith('path:'):
|
|
||||||
path = url.replace('path:', '')
|
|
||||||
if not os.path.isabs(path):
|
|
||||||
path = os.path.join(root_path, path)
|
|
||||||
|
|
||||||
return qdrant_client.QdrantClient(
|
|
||||||
path=path
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return qdrant_client.QdrantClient(
|
|
||||||
url=url,
|
|
||||||
api_key=api_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
|
|
||||||
index_struct = QdrantIndexDict()
|
|
||||||
|
|
||||||
if self._client is None:
|
|
||||||
raise Exception("Vector client is not initialized.")
|
|
||||||
|
|
||||||
# {"collection_name": "Gpt_index_xxx"}
|
|
||||||
collection_name = config.get('collection_name')
|
|
||||||
if not collection_name:
|
|
||||||
raise Exception("collection_name cannot be None.")
|
|
||||||
|
|
||||||
return GPTQdrantEnhanceIndex(
|
|
||||||
service_context=service_context,
|
|
||||||
index_struct=index_struct,
|
|
||||||
vector_store=QdrantEnhanceVectorStore(
|
|
||||||
client=self._client,
|
|
||||||
collection_name=collection_name
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_index_config(self, index_id: str) -> dict:
|
|
||||||
return {"collection_name": index_id}
|
|
||||||
|
|
||||||
|
|
||||||
class GPTQdrantEnhanceIndex(GPTQdrantIndex, BaseGPTVectorStoreIndex):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class QdrantEnhanceVectorStore(QdrantVectorStore, EnhanceVectorStore):
|
|
||||||
def delete_node(self, node_id: str):
|
|
||||||
"""
|
|
||||||
Delete node from the index.
|
|
||||||
|
|
||||||
:param node_id: node id
|
|
||||||
"""
|
|
||||||
from qdrant_client.http import models as rest
|
|
||||||
|
|
||||||
self._reload_if_needed()
|
|
||||||
|
|
||||||
self._client.delete(
|
|
||||||
collection_name=self._collection_name,
|
|
||||||
points_selector=rest.Filter(
|
|
||||||
must=[
|
|
||||||
rest.FieldCondition(
|
|
||||||
key="id", match=rest.MatchValue(value=node_id)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def exists_by_node_id(self, node_id: str) -> bool:
|
|
||||||
"""
|
|
||||||
Get node from the index by node id.
|
|
||||||
|
|
||||||
:param node_id: node id
|
|
||||||
"""
|
|
||||||
self._reload_if_needed()
|
|
||||||
|
|
||||||
response = self._client.retrieve(
|
|
||||||
collection_name=self._collection_name,
|
|
||||||
ids=[node_id]
|
|
||||||
)
|
|
||||||
|
|
||||||
return len(response) > 0
|
|
||||||
|
|
||||||
def query(
|
|
||||||
self,
|
|
||||||
query: VectorStoreQuery,
|
|
||||||
) -> VectorStoreQueryResult:
|
|
||||||
"""Query index for top k most similar nodes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query (VectorStoreQuery): query
|
|
||||||
"""
|
|
||||||
query_embedding = cast(List[float], query.query_embedding)
|
|
||||||
|
|
||||||
self._reload_if_needed()
|
|
||||||
|
|
||||||
response = self._client.search(
|
|
||||||
collection_name=self._collection_name,
|
|
||||||
query_vector=query_embedding,
|
|
||||||
limit=cast(int, query.similarity_top_k),
|
|
||||||
query_filter=cast(Filter, self._build_query_filter(query)),
|
|
||||||
with_vectors=True
|
|
||||||
)
|
|
||||||
|
|
||||||
nodes = []
|
|
||||||
similarities = []
|
|
||||||
ids = []
|
|
||||||
for point in response:
|
|
||||||
payload = cast(Payload, point.payload)
|
|
||||||
node = Node(
|
|
||||||
doc_id=str(point.id),
|
|
||||||
text=payload.get("text"),
|
|
||||||
embedding=point.vector,
|
|
||||||
extra_info=payload.get("extra_info"),
|
|
||||||
relationships={
|
|
||||||
DocumentRelationship.SOURCE: payload.get("doc_id", "None"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
nodes.append(node)
|
|
||||||
similarities.append(point.score)
|
|
||||||
ids.append(str(point.id))
|
|
||||||
|
|
||||||
return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
|
|
||||||
|
|
||||||
def _reload_if_needed(self):
|
|
||||||
if isinstance(self._client._client, QdrantLocal):
|
|
||||||
self._client._client._load()
|
|
@ -1,62 +0,0 @@
|
|||||||
from flask import Flask
|
|
||||||
from llama_index import ServiceContext, GPTVectorStoreIndex
|
|
||||||
from requests import ReadTimeout
|
|
||||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt
|
|
||||||
|
|
||||||
from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient
|
|
||||||
from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient
|
|
||||||
|
|
||||||
SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant']
|
|
||||||
|
|
||||||
|
|
||||||
class VectorStore:
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._vector_store = None
|
|
||||||
self._client = None
|
|
||||||
|
|
||||||
def init_app(self, app: Flask):
|
|
||||||
if not app.config['VECTOR_STORE']:
|
|
||||||
return
|
|
||||||
|
|
||||||
self._vector_store = app.config['VECTOR_STORE']
|
|
||||||
if self._vector_store not in SUPPORTED_VECTOR_STORES:
|
|
||||||
raise ValueError(f"Vector store {self._vector_store} is not supported.")
|
|
||||||
|
|
||||||
if self._vector_store == 'weaviate':
|
|
||||||
self._client = WeaviateVectorStoreClient(
|
|
||||||
endpoint=app.config['WEAVIATE_ENDPOINT'],
|
|
||||||
api_key=app.config['WEAVIATE_API_KEY'],
|
|
||||||
grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED'],
|
|
||||||
batch_size=app.config['WEAVIATE_BATCH_SIZE']
|
|
||||||
)
|
|
||||||
elif self._vector_store == 'qdrant':
|
|
||||||
self._client = QdrantVectorStoreClient(
|
|
||||||
url=app.config['QDRANT_URL'],
|
|
||||||
api_key=app.config['QDRANT_API_KEY'],
|
|
||||||
root_path=app.root_path
|
|
||||||
)
|
|
||||||
|
|
||||||
app.extensions['vector_store'] = self
|
|
||||||
|
|
||||||
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
|
|
||||||
def get_index(self, service_context: ServiceContext, index_struct: dict) -> GPTVectorStoreIndex:
|
|
||||||
vector_store_config: dict = index_struct.get('vector_store')
|
|
||||||
index = self.get_client().get_index(
|
|
||||||
service_context=service_context,
|
|
||||||
config=vector_store_config
|
|
||||||
)
|
|
||||||
|
|
||||||
return index
|
|
||||||
|
|
||||||
def to_index_struct(self, index_id: str) -> dict:
|
|
||||||
return {
|
|
||||||
"type": self._vector_store,
|
|
||||||
"vector_store": self.get_client().to_index_config(index_id)
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_client(self):
|
|
||||||
if not self._client:
|
|
||||||
raise Exception("Vector store client is not initialized.")
|
|
||||||
|
|
||||||
return self._client
|
|
@ -1,66 +0,0 @@
|
|||||||
from llama_index.indices.query.base import IS
|
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Optional
|
|
||||||
)
|
|
||||||
|
|
||||||
from llama_index.docstore import BaseDocumentStore
|
|
||||||
from llama_index.indices.postprocessor.node import (
|
|
||||||
BaseNodePostprocessor,
|
|
||||||
)
|
|
||||||
from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
|
|
||||||
from llama_index.indices.response.response_builder import ResponseMode
|
|
||||||
from llama_index.indices.service_context import ServiceContext
|
|
||||||
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
|
|
||||||
from llama_index.prompts.prompts import (
|
|
||||||
QuestionAnswerPrompt,
|
|
||||||
RefinePrompt,
|
|
||||||
SimpleInputPrompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
from core.index.query.synthesizer import EnhanceResponseSynthesizer
|
|
||||||
|
|
||||||
|
|
||||||
class EnhanceGPTVectorStoreIndexQuery(GPTVectorStoreIndexQuery):
|
|
||||||
@classmethod
|
|
||||||
def from_args(
|
|
||||||
cls,
|
|
||||||
index_struct: IS,
|
|
||||||
service_context: ServiceContext,
|
|
||||||
docstore: Optional[BaseDocumentStore] = None,
|
|
||||||
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
|
|
||||||
verbose: bool = False,
|
|
||||||
# response synthesizer args
|
|
||||||
response_mode: ResponseMode = ResponseMode.DEFAULT,
|
|
||||||
text_qa_template: Optional[QuestionAnswerPrompt] = None,
|
|
||||||
refine_template: Optional[RefinePrompt] = None,
|
|
||||||
simple_template: Optional[SimpleInputPrompt] = None,
|
|
||||||
response_kwargs: Optional[Dict] = None,
|
|
||||||
use_async: bool = False,
|
|
||||||
streaming: bool = False,
|
|
||||||
optimizer: Optional[BaseTokenUsageOptimizer] = None,
|
|
||||||
# class-specific args
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> "BaseGPTIndexQuery":
|
|
||||||
response_synthesizer = EnhanceResponseSynthesizer.from_args(
|
|
||||||
service_context=service_context,
|
|
||||||
text_qa_template=text_qa_template,
|
|
||||||
refine_template=refine_template,
|
|
||||||
simple_template=simple_template,
|
|
||||||
response_mode=response_mode,
|
|
||||||
response_kwargs=response_kwargs,
|
|
||||||
use_async=use_async,
|
|
||||||
streaming=streaming,
|
|
||||||
optimizer=optimizer,
|
|
||||||
)
|
|
||||||
return cls(
|
|
||||||
index_struct=index_struct,
|
|
||||||
service_context=service_context,
|
|
||||||
response_synthesizer=response_synthesizer,
|
|
||||||
docstore=docstore,
|
|
||||||
node_postprocessors=node_postprocessors,
|
|
||||||
verbose=verbose,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
api/core/vector_store/weaviate_vector_store.py (new file)
@@ -0,0 +1,38 @@
from langchain.vectorstores import Weaviate


class WeaviateVectorStore(Weaviate):
    def del_texts(self, where_filter: dict):
        if not where_filter:
            raise ValueError('where_filter must not be empty')

        self._client.batch.delete_objects(
            class_name=self._index_name,
            where=where_filter,
            output='minimal'
        )

    def del_text(self, uuid: str) -> None:
        self._client.data_object.delete(
            uuid,
            class_name=self._index_name
        )

    def text_exists(self, uuid: str) -> bool:
        result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
            "path": ["doc_id"],
            "operator": "Equal",
            "valueText": uuid,
        }).with_limit(1).do()

        if "errors" in result:
            raise ValueError(f"Error during query: {result['errors']}")

        entries = result["data"]["Get"][self._index_name]
        if len(entries) == 0:
            return False

        return True

    def delete(self):
        self._client.schema.delete_class(self._index_name)
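A short sketch of the batched delete helper above (assumed store instance and metadata field, not from the commit):

# `store` is assumed to be a WeaviateVectorStore built from a weaviate.Client and index name elsewhere.
where_filter = {
    "path": ["document_id"],  # hypothetical metadata field
    "operator": "Equal",
    "valueText": "9a1b2c3d-0000-0000-0000-000000000000",  # hypothetical document id
}
# store.del_texts(where_filter)  # one batch call removes every object matching the filter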
@ -1,270 +0,0 @@
|
|||||||
import json
|
|
||||||
import weaviate
|
|
||||||
from dataclasses import field
|
|
||||||
from typing import List, Any, Dict, Optional
|
|
||||||
|
|
||||||
from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
|
|
||||||
from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex
|
|
||||||
from llama_index.data_structs.data_structs_v2 import WeaviateIndexDict, Node
|
|
||||||
from llama_index.data_structs.node_v2 import DocumentRelationship
|
|
||||||
from llama_index.readers.weaviate.client import _class_name, NODE_SCHEMA, _logger
|
|
||||||
from llama_index.vector_stores import WeaviateVectorStore
|
|
||||||
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode
|
|
||||||
from llama_index.readers.weaviate.utils import (
|
|
||||||
parse_get_response,
|
|
||||||
validate_client,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class WeaviateVectorStoreClient(BaseVectorStoreClient):
|
|
||||||
|
|
||||||
def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
|
|
||||||
self._client = self.init_from_config(endpoint, api_key, grpc_enabled, batch_size)
|
|
||||||
|
|
||||||
def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
|
|
||||||
auth_config = weaviate.auth.AuthApiKey(api_key=api_key)
|
|
||||||
|
|
||||||
weaviate.connect.connection.has_grpc = grpc_enabled
|
|
||||||
|
|
||||||
client = weaviate.Client(
|
|
||||||
url=endpoint,
|
|
||||||
auth_client_secret=auth_config,
|
|
||||||
timeout_config=(5, 60),
|
|
||||||
startup_period=None
|
|
||||||
)
|
|
||||||
|
|
||||||
client.batch.configure(
|
|
||||||
# `batch_size` takes an `int` value to enable auto-batching
|
|
||||||
# (`None` is used for manual batching)
|
|
||||||
batch_size=batch_size,
|
|
||||||
# dynamically update the `batch_size` based on import speed
|
|
||||||
dynamic=True,
|
|
||||||
# `timeout_retries` takes an `int` value to retry on time outs
|
|
||||||
timeout_retries=3,
|
|
||||||
)
|
|
||||||
|
|
||||||
return client
|
|
||||||
|
|
||||||
def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
|
|
||||||
index_struct = WeaviateIndexDict()
|
|
||||||
|
|
||||||
if self._client is None:
|
|
||||||
raise Exception("Vector client is not initialized.")
|
|
||||||
|
|
||||||
# {"class_prefix": "Gpt_index_xxx"}
|
|
||||||
class_prefix = config.get('class_prefix')
|
|
||||||
if not class_prefix:
|
|
||||||
raise Exception("class_prefix cannot be None.")
|
|
||||||
|
|
||||||
return GPTWeaviateEnhanceIndex(
|
|
||||||
service_context=service_context,
|
|
||||||
index_struct=index_struct,
|
|
||||||
vector_store=WeaviateWithSimilaritiesVectorStore(
|
|
||||||
weaviate_client=self._client,
|
|
||||||
class_prefix=class_prefix
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_index_config(self, index_id: str) -> dict:
|
|
||||||
return {"class_prefix": index_id}
|
|
||||||
|
|
||||||
|
|
||||||
class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore):
|
|
||||||
def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
|
|
||||||
"""Query index for top k most similar nodes."""
|
|
||||||
nodes = self.weaviate_query(
|
|
||||||
self._client,
|
|
||||||
self._class_prefix,
|
|
||||||
query,
|
|
||||||
)
|
|
||||||
nodes = nodes[: query.similarity_top_k]
|
|
||||||
node_idxs = [str(i) for i in range(len(nodes))]
|
|
||||||
|
|
||||||
similarities = []
|
|
||||||
for node in nodes:
|
|
||||||
similarities.append(node.extra_info['similarity'])
|
|
||||||
del node.extra_info['similarity']
|
|
||||||
|
|
||||||
return VectorStoreQueryResult(nodes=nodes, ids=node_idxs, similarities=similarities)
|
|
||||||
|
|
||||||
def weaviate_query(
|
|
||||||
self,
|
|
||||||
client: Any,
|
|
||||||
class_prefix: str,
|
|
||||||
query_spec: VectorStoreQuery,
|
|
||||||
) -> List[Node]:
|
|
||||||
"""Convert to LlamaIndex list."""
|
|
||||||
validate_client(client)
|
|
||||||
|
|
||||||
class_name = _class_name(class_prefix)
|
|
||||||
prop_names = [p["name"] for p in NODE_SCHEMA]
|
|
||||||
vector = query_spec.query_embedding
|
|
||||||
|
|
||||||
# build query
|
|
||||||
query = client.query.get(class_name, prop_names).with_additional(["id", "vector", "certainty"])
|
|
||||||
if query_spec.mode == VectorStoreQueryMode.DEFAULT:
|
|
||||||
_logger.debug("Using vector search")
|
|
||||||
if vector is not None:
|
|
||||||
query = query.with_near_vector(
|
|
||||||
{
|
|
||||||
"vector": vector,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
elif query_spec.mode == VectorStoreQueryMode.HYBRID:
|
|
||||||
_logger.debug(f"Using hybrid search with alpha {query_spec.alpha}")
|
|
||||||
query = query.with_hybrid(
|
|
||||||
query=query_spec.query_str,
|
|
||||||
alpha=query_spec.alpha,
|
|
||||||
vector=vector,
|
|
||||||
)
|
|
||||||
query = query.with_limit(query_spec.similarity_top_k)
|
|
||||||
_logger.debug(f"Using limit of {query_spec.similarity_top_k}")
|
|
||||||
|
|
||||||
# execute query
|
|
||||||
query_result = query.do()
|
|
||||||
|
|
||||||
# parse results
|
|
||||||
parsed_result = parse_get_response(query_result)
|
|
||||||
entries = parsed_result[class_name]
|
|
||||||
results = [self._to_node(entry) for entry in entries]
|
|
||||||
return results
|
|
||||||
|
|
||||||
def _to_node(self, entry: Dict) -> Node:
|
|
||||||
"""Convert to Node."""
|
|
||||||
extra_info_str = entry["extra_info"]
|
|
||||||
if extra_info_str == "":
|
|
||||||
extra_info = None
|
|
||||||
else:
|
|
||||||
extra_info = json.loads(extra_info_str)
|
|
||||||
|
|
||||||
if 'certainty' in entry['_additional']:
|
|
||||||
if extra_info:
|
|
||||||
extra_info['similarity'] = entry['_additional']['certainty']
|
|
||||||
else:
|
|
||||||
extra_info = {'similarity': entry['_additional']['certainty']}
|
|
||||||
|
|
||||||
node_info_str = entry["node_info"]
|
|
||||||
if node_info_str == "":
|
|
||||||
node_info = None
|
|
||||||
else:
|
|
||||||
node_info = json.loads(node_info_str)
|
|
||||||
|
|
||||||
relationships_str = entry["relationships"]
|
|
||||||
relationships: Dict[DocumentRelationship, str]
|
|
||||||
if relationships_str == "":
|
|
||||||
relationships = field(default_factory=dict)
|
|
||||||
else:
|
|
||||||
relationships = {
|
|
||||||
DocumentRelationship(k): v for k, v in json.loads(relationships_str).items()
|
|
||||||
}
|
|
||||||
|
|
||||||
return Node(
|
|
||||||
text=entry["text"],
|
|
||||||
doc_id=entry["doc_id"],
|
|
||||||
embedding=entry["_additional"]["vector"],
|
|
||||||
extra_info=extra_info,
|
|
||||||
node_info=node_info,
|
|
||||||
relationships=relationships,
|
|
||||||
)
|
|
||||||
|
|
||||||
def delete(self, doc_id: str, **delete_kwargs: Any) -> None:
|
|
||||||
"""Delete a document.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
doc_id (str): document id
|
|
||||||
|
|
||||||
"""
|
|
||||||
delete_document(self._client, doc_id, self._class_prefix)
|
|
||||||
|
|
||||||
def delete_node(self, node_id: str):
|
|
||||||
"""
|
|
||||||
Delete node from the index.
|
|
||||||
|
|
||||||
:param node_id: node id
|
|
||||||
"""
|
|
||||||
delete_node(self._client, node_id, self._class_prefix)
|
|
||||||
|
|
||||||
def exists_by_node_id(self, node_id: str) -> bool:
|
|
||||||
"""
|
|
||||||
Get node from the index by node id.
|
|
||||||
|
|
||||||
:param node_id: node id
|
|
||||||
"""
|
|
||||||
entry = get_by_node_id(self._client, node_id, self._class_prefix)
|
|
||||||
return True if entry else False
|
|
||||||
|
|
||||||
|
|
||||||
class GPTWeaviateEnhanceIndex(GPTWeaviateIndex, BaseGPTVectorStoreIndex):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def delete_document(client: Any, ref_doc_id: str, class_prefix: str) -> None:
|
|
||||||
"""Delete entry."""
|
|
||||||
validate_client(client)
|
|
||||||
# make sure that each entry
|
|
||||||
class_name = _class_name(class_prefix)
|
|
||||||
where_filter = {
|
|
||||||
"path": ["ref_doc_id"],
|
|
||||||
"operator": "Equal",
|
|
||||||
"valueString": ref_doc_id,
|
|
||||||
}
|
|
||||||
query = (
|
|
||||||
client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
|
|
||||||
)
|
|
||||||
|
|
||||||
query_result = query.do()
|
|
||||||
parsed_result = parse_get_response(query_result)
|
|
||||||
entries = parsed_result[class_name]
|
|
||||||
for entry in entries:
|
|
||||||
client.data_object.delete(entry["_additional"]["id"], class_name)
|
|
||||||
|
|
||||||
while len(entries) > 0:
|
|
||||||
query_result = query.do()
|
|
||||||
parsed_result = parse_get_response(query_result)
|
|
||||||
entries = parsed_result[class_name]
|
|
||||||
for entry in entries:
|
|
||||||
client.data_object.delete(entry["_additional"]["id"], class_name)
|
|
||||||
|
|
||||||
|
|
||||||
def delete_node(client: Any, node_id: str, class_prefix: str) -> None:
|
|
||||||
"""Delete entry."""
|
|
||||||
validate_client(client)
|
|
||||||
# make sure that each entry
|
|
||||||
class_name = _class_name(class_prefix)
|
|
||||||
where_filter = {
|
|
||||||
"path": ["doc_id"],
|
|
||||||
"operator": "Equal",
|
|
||||||
"valueString": node_id,
|
|
||||||
}
|
|
||||||
query = (
|
|
||||||
client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
|
|
||||||
)
|
|
||||||
|
|
||||||
query_result = query.do()
|
|
||||||
parsed_result = parse_get_response(query_result)
|
|
||||||
entries = parsed_result[class_name]
|
|
||||||
for entry in entries:
|
|
||||||
client.data_object.delete(entry["_additional"]["id"], class_name)
|
|
||||||
|
|
||||||
|
|
||||||
def get_by_node_id(client: Any, node_id: str, class_prefix: str) -> Optional[Dict]:
|
|
||||||
"""Delete entry."""
|
|
||||||
validate_client(client)
|
|
||||||
# make sure that each entry
|
|
||||||
class_name = _class_name(class_prefix)
|
|
||||||
where_filter = {
|
|
||||||
"path": ["doc_id"],
|
|
||||||
"operator": "Equal",
|
|
||||||
"valueString": node_id,
|
|
||||||
}
|
|
||||||
query = (
|
|
||||||
client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
|
|
||||||
)
|
|
||||||
|
|
||||||
query_result = query.do()
|
|
||||||
parsed_result = parse_get_response(query_result)
|
|
||||||
entries = parsed_result[class_name]
|
|
||||||
if len(entries) == 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return entries[0]
|
|
@@ -1,7 +0,0 @@
-from core.vector_store.vector_store import VectorStore
-
-vector_store = VectorStore()
-
-
-def init_app(app):
-    vector_store.init_app(app)
@@ -3,6 +3,7 @@ import re
import subprocess
import uuid
from datetime import datetime
+from hashlib import sha256
from zoneinfo import available_timezones
import random
import string
@@ -147,3 +148,8 @@ def get_remote_ip(request):
        return request.headers.getlist("X-Forwarded-For")[0]
    else:
        return request.remote_addr
+
+
+def generate_text_hash(text: str) -> str:
+    hash_text = str(text) + 'None'
+    return sha256(hash_text.encode()).hexdigest()
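A quick illustration (not in the diff) of how the new helper behaves: identical segment text always yields the same digest, so duplicate chunks can be detected on re-import.

from libs.helper import generate_text_hash

a = generate_text_hash("LangChain upgrade notes")
b = generate_text_hash("LangChain upgrade notes")
assert a == b
assert len(a) == 64  # hex-encoded SHA-256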
@@ -38,8 +38,6 @@ class Account(UserMixin, db.Model):
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

-    _current_tenant: db.Model = None
-
    @property
    def current_tenant(self):
        return self._current_tenant
@@ -66,6 +66,23 @@ class Dataset(db.Model):
     def document_count(self):
         return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()

+    @property
+    def available_document_count(self):
+        return db.session.query(func.count(Document.id)).filter(
+            Document.dataset_id == self.id,
+            Document.indexing_status == 'completed',
+            Document.enabled == True,
+            Document.archived == False
+        ).scalar()
+
+    @property
+    def available_segment_count(self):
+        return db.session.query(func.count(DocumentSegment.id)).filter(
+            DocumentSegment.dataset_id == self.id,
+            DocumentSegment.status == 'completed',
+            DocumentSegment.enabled == True
+        ).scalar()
+
     @property
     def word_count(self):
         return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \
@@ -260,7 +277,7 @@ class Document(db.Model):

     @property
     def dataset(self):
-        return Dataset.query.get(self.dataset_id)
+        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()

     @property
     def segment_count(self):
@@ -395,7 +412,18 @@ class DatasetKeywordTable(db.Model):

     @property
     def keyword_table_dict(self):
-        return json.loads(self.keyword_table) if self.keyword_table else None
+        class SetDecoder(json.JSONDecoder):
+            def __init__(self, *args, **kwargs):
+                super().__init__(object_hook=self.object_hook, *args, **kwargs)
+
+            def object_hook(self, dct):
+                if isinstance(dct, dict):
+                    for keyword, node_idxs in dct.items():
+                        if isinstance(node_idxs, list):
+                            dct[keyword] = set(node_idxs)
+                return dct
+
+        return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None


 class Embedding(db.Model):
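keyword_table_dict above now decodes the stored JSON with a decoder that turns node-id lists back into Python sets. A small self-contained sketch of that round trip; the encoding side (sets dumped as lists) is an assumption about the writer and is not shown in this diff:

    import json

    class SetDecoder(json.JSONDecoder):
        # convert node-id lists back into sets while decoding
        def __init__(self, *args, **kwargs):
            super().__init__(object_hook=self.object_hook, *args, **kwargs)

        def object_hook(self, dct):
            if isinstance(dct, dict):
                for keyword, node_idxs in dct.items():
                    if isinstance(node_idxs, list):
                        dct[keyword] = set(node_idxs)
            return dct

    # sets are not JSON-serializable, so they would be stored as lists...
    keyword_table = {"flask": {"node-1", "node-2"}, "celery": {"node-3"}}
    serialized = json.dumps({k: list(v) for k, v in keyword_table.items()})

    # ...and are restored as sets on the way back, as keyword_table_dict now does
    restored = json.loads(serialized, cls=SetDecoder)
    assert restored["flask"] == {"node-1", "node-2"}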
@@ -2,6 +2,7 @@ coverage~=7.2.4
 beautifulsoup4==4.12.2
 flask~=2.3.2
 Flask-SQLAlchemy~=3.0.3
+SQLAlchemy~=1.4.28
 flask-login==0.6.2
 flask-migrate~=4.0.4
 flask-restful==0.3.9
@@ -9,8 +10,7 @@ flask-session2==1.3.1
 flask-cors==3.0.10
 gunicorn~=20.1.0
 gevent~=22.10.2
-langchain==0.0.142
-llama-index==0.5.27
+langchain==0.0.209
 openai~=0.27.5
 psycopg2-binary~=2.9.6
 pycryptodome==3.17
@@ -29,6 +29,7 @@ sentry-sdk[flask]~=1.21.1
 jieba==0.42.1
 celery==5.2.7
 redis~=4.5.4
-pypdf==3.8.1
 openpyxl==3.1.2
 chardet~=5.1.0
+docx2txt==0.8
+pypdfium2==4.16.0
@@ -4,7 +4,6 @@ import uuid
 from core.constant import llm_constant
 from models.account import Account
 from services.dataset_service import DatasetService
-from services.errors.account import NoPermissionError


 class AppModelConfigService:
@@ -7,7 +7,6 @@ from typing import Optional, List
 from extensions.ext_redis import redis_client
 from flask_login import current_user

-from core.index.index_builder import IndexBuilder
 from events.dataset_event import dataset_was_deleted
 from events.document_event import document_was_deleted
 from extensions.ext_database import db
@@ -386,8 +385,6 @@ class DocumentService:

         dataset.indexing_technique = document_data["indexing_technique"]

-        if dataset.indexing_technique == 'high_quality':
-            IndexBuilder.get_default_service_context(dataset.tenant_id)
         documents = []
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
         if 'original_document_id' in document_data and document_data["original_document_id"]:
@@ -3,47 +3,56 @@ import time
 from typing import List

 import numpy as np
-from llama_index.data_structs.node_v2 import NodeWithScore
-from llama_index.indices.query.schema import QueryBundle
-from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
+from flask import current_app
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.embeddings.base import Embeddings
+from langchain.schema import Document
 from sklearn.manifold import TSNE

-from core.docstore.empty_docstore import EmptyDocumentStore
-from core.index.vector_index import VectorIndex
+from core.embedding.cached_embedding import CacheEmbedding
+from core.index.vector_index.vector_index import VectorIndex
+from core.llm.llm_builder import LLMBuilder
 from extensions.ext_database import db
 from models.account import Account
 from models.dataset import Dataset, DocumentSegment, DatasetQuery
-from services.errors.index import IndexNotInitializedError


 class HitTestingService:
     @classmethod
     def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
-        index = VectorIndex(dataset=dataset).query_index
+        if dataset.available_document_count == 0 or dataset.available_document_count == 0:
+            return {
+                "query": {
+                    "content": query,
+                    "tsne_position": {'x': 0, 'y': 0},
+                },
+                "records": []
+            }

-        if not index:
-            raise IndexNotInitializedError()
-
-        index_query = GPTVectorStoreIndexQuery(
-            index_struct=index.index_struct,
-            service_context=index.service_context,
-            vector_store=index.query_context.get('vector_store'),
-            docstore=EmptyDocumentStore(),
-            response_synthesizer=None,
-            similarity_top_k=limit
+        model_credentials = LLMBuilder.get_model_credentials(
+            tenant_id=dataset.tenant_id,
+            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
+            model_name='text-embedding-ada-002'
         )

-        query_bundle = QueryBundle(
-            query_str=query,
-            custom_embedding_strs=[query],
-        )
+        embeddings = CacheEmbedding(OpenAIEmbeddings(
+            **model_credentials
+        ))

-        query_bundle.embedding = index.service_context.embed_model.get_agg_embedding_from_queries(
-            query_bundle.embedding_strs
+        vector_index = VectorIndex(
+            dataset=dataset,
+            config=current_app.config,
+            embeddings=embeddings
         )

         start = time.perf_counter()
-        nodes = index_query.retrieve(query_bundle=query_bundle)
+        documents = vector_index.search(
+            query,
+            search_type='similarity_score_threshold',
+            search_kwargs={
+                'k': 10
+            }
+        )
         end = time.perf_counter()
         logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")

@@ -58,25 +67,24 @@ class HitTestingService:
         db.session.add(dataset_query)
         db.session.commit()

-        return cls.compact_retrieve_response(dataset, query_bundle, nodes)
+        return cls.compact_retrieve_response(dataset, embeddings, query, documents)

     @classmethod
-    def compact_retrieve_response(cls, dataset: Dataset, query_bundle: QueryBundle, nodes: List[NodeWithScore]):
-        embeddings = [
-            query_bundle.embedding
+    def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]):
+        text_embeddings = [
+            embeddings.embed_query(query)
         ]

-        for node in nodes:
-            embeddings.append(node.node.embedding)
+        text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents]))

-        tsne_position_data = cls.get_tsne_positions_from_embeddings(embeddings)
+        tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings)

         query_position = tsne_position_data.pop(0)

         i = 0
         records = []
-        for node in nodes:
-            index_node_id = node.node.doc_id
+        for document in documents:
+            index_node_id = document.metadata['doc_id']

             segment = db.session.query(DocumentSegment).filter(
                 DocumentSegment.dataset_id == dataset.id,
@@ -91,7 +99,7 @@ class HitTestingService:

             record = {
                 "segment": segment,
-                "score": node.score,
+                "score": document.metadata['score'],
                 "tsne_position": tsne_position_data[i]
             }

@@ -101,7 +109,7 @@ class HitTestingService:

         return {
             "query": {
-                "content": query_bundle.query_str,
+                "content": query,
                 "tsne_position": query_position,
             },
             "records": records
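HitTestingService projects the query embedding and the retrieved-document embeddings into two dimensions with t-SNE so they can be plotted. A rough standalone sketch of that step (the dimensions, perplexity and random embeddings here are illustrative, not the service's actual settings):

    import numpy as np
    from sklearn.manifold import TSNE

    # one query embedding followed by the embeddings of the retrieved documents
    embeddings = np.random.rand(6, 1536)  # 1536 dims, as with text-embedding-ada-002

    # perplexity must stay below the number of samples; keep it small for a handful of points
    tsne = TSNE(n_components=2, perplexity=2, random_state=0)
    positions = tsne.fit_transform(embeddings)

    tsne_position_data = [{'x': float(x), 'y': float(y)} for x, y in positions]
    query_position = tsne_position_data.pop(0)  # the first row was the query itself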
@@ -4,96 +4,81 @@ import time

 import click
 from celery import shared_task
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
+from langchain.schema import Document
 from werkzeug.exceptions import NotFound

-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
-from models.dataset import DocumentSegment, Document
+from models.dataset import DocumentSegment
+from models.dataset import Document as DatasetDocument


 @shared_task
-def add_document_to_index_task(document_id: str):
+def add_document_to_index_task(dataset_document_id: str):
     """
     Async Add document to index
     :param document_id:

     Usage: add_document_to_index.delay(document_id)
     """
-    logging.info(click.style('Start add document to index: {}'.format(document_id), fg='green'))
+    logging.info(click.style('Start add document to index: {}'.format(dataset_document_id), fg='green'))
     start_at = time.perf_counter()

-    document = db.session.query(Document).filter(Document.id == document_id).first()
-    if not document:
+    dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first()
+    if not dataset_document:
         raise NotFound('Document not found')

-    if document.indexing_status != 'completed':
+    if dataset_document.indexing_status != 'completed':
         return

-    indexing_cache_key = 'document_{}_indexing'.format(document.id)
+    indexing_cache_key = 'document_{}_indexing'.format(dataset_document.id)

     try:
         segments = db.session.query(DocumentSegment).filter(
-            DocumentSegment.document_id == document.id,
+            DocumentSegment.document_id == dataset_document.id,
             DocumentSegment.enabled == True
         ) \
             .order_by(DocumentSegment.position.asc()).all()

-        nodes = []
-        previous_node = None
+        documents = []
         for segment in segments:
-            relationships = {
-                DocumentRelationship.SOURCE: document.id
-            }
-
-            if previous_node:
-                relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id
-
-                previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id
-
-            node = Node(
-                doc_id=segment.index_node_id,
-                doc_hash=segment.index_node_hash,
-                text=segment.content,
-                extra_info=None,
-                node_info=None,
-                relationships=relationships
+            document = Document(
+                page_content=segment.content,
+                metadata={
+                    "doc_id": segment.index_node_id,
+                    "doc_hash": segment.index_node_hash,
+                    "document_id": segment.document_id,
+                    "dataset_id": segment.dataset_id,
+                }
             )

-            previous_node = node
-
-            nodes.append(node)
+            documents.append(document)

-        dataset = document.dataset
+        dataset = dataset_document.dataset

         if not dataset:
             raise Exception('Document has no dataset')

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
-
         # save vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.add_nodes(
-                nodes=nodes,
-                duplicate_check=True
-            )
+        index = IndexBuilder.get_index(dataset, 'high_quality')
+        if index:
+            index.add_texts(documents)

         # save keyword index
-        keyword_table_index.add_nodes(nodes)
+        index = IndexBuilder.get_index(dataset, 'economy')
+        if index:
+            index.add_texts(documents)

         end_at = time.perf_counter()
         logging.info(
-            click.style('Document added to index: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
+            click.style('Document added to index: {} latency: {}'.format(dataset_document.id, end_at - start_at), fg='green'))
     except Exception as e:
         logging.exception("add document to index failed")
-        document.enabled = False
-        document.disabled_at = datetime.datetime.utcnow()
-        document.status = 'error'
-        document.error = str(e)
+        dataset_document.enabled = False
+        dataset_document.disabled_at = datetime.datetime.utcnow()
+        dataset_document.status = 'error'
+        dataset_document.error = str(e)
         db.session.commit()
     finally:
         redis_client.delete(indexing_cache_key)
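The pattern in the rewritten task above recurs across the indexing tasks in this commit: wrap each DocumentSegment in a langchain Document whose metadata carries the ids the indexes need, then hand the batch to whichever index IndexBuilder returns. A condensed sketch of that pattern in isolation (IndexBuilder and the segment fields are the project APIs shown in the diff; the function name is illustrative):

    from langchain.schema import Document

    from core.index.index import IndexBuilder  # project module referenced in the diff

    def index_segments(dataset, segments):
        # build langchain Documents; the metadata keys mirror what the task stores
        documents = [
            Document(
                page_content=segment.content,
                metadata={
                    "doc_id": segment.index_node_id,
                    "doc_hash": segment.index_node_hash,
                    "document_id": segment.document_id,
                    "dataset_id": segment.dataset_id,
                }
            )
            for segment in segments
        ]

        # get_index may return None (e.g. no vector index for non-'high_quality'
        # datasets), hence the guard before add_texts
        for index_type in ('high_quality', 'economy'):
            index = IndexBuilder.get_index(dataset, index_type)
            if index:
                index.add_texts(documents)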
@@ -4,12 +4,10 @@ import time

 import click
 from celery import shared_task
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
+from langchain.schema import Document
 from werkzeug.exceptions import NotFound

-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment
@@ -36,44 +34,41 @@ def add_segment_to_index_task(segment_id: str):
     indexing_cache_key = 'segment_{}_indexing'.format(segment.id)

     try:
-        relationships = {
-            DocumentRelationship.SOURCE: segment.document_id,
-        }
-
-        previous_segment = segment.previous_segment
-        if previous_segment:
-            relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id
-
-        next_segment = segment.next_segment
-        if next_segment:
-            relationships[DocumentRelationship.NEXT] = next_segment.index_node_id
-
-        node = Node(
-            doc_id=segment.index_node_id,
-            doc_hash=segment.index_node_hash,
-            text=segment.content,
-            extra_info=None,
-            node_info=None,
-            relationships=relationships
+        document = Document(
+            page_content=segment.content,
+            metadata={
+                "doc_id": segment.index_node_id,
+                "doc_hash": segment.index_node_hash,
+                "document_id": segment.document_id,
+                "dataset_id": segment.dataset_id,
+            }
         )

         dataset = segment.dataset

         if not dataset:
-            raise Exception('Segment has no dataset')
+            logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
+            return

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        dataset_document = segment.document
+
+        if not dataset_document:
+            logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
+            return
+
+        if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
+            logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
+            return

         # save vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.add_nodes(
-                nodes=[node],
-                duplicate_check=True
-            )
+        index = IndexBuilder.get_index(dataset, 'high_quality')
+        if index:
+            index.add_texts([document], duplicate_check=True)

         # save keyword index
-        keyword_table_index.add_nodes([node])
+        index = IndexBuilder.get_index(dataset, 'economy')
+        if index:
+            index.add_texts([document])

         end_at = time.perf_counter()
         logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
@@ -4,8 +4,7 @@ import time
 import click
 from celery import shared_task

-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
     AppDatasetJoin
@@ -33,29 +32,24 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
         index_struct=index_struct
     )

-    vector_index = VectorIndex(dataset=dataset)
-    keyword_table_index = KeywordTableIndex(dataset=dataset)
-
     documents = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
-    index_doc_ids = [document.id for document in documents]
     segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
-    index_node_ids = [segment.index_node_id for segment in segments]
+
+    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+    kw_index = IndexBuilder.get_index(dataset, 'economy')

     # delete from vector index
-    if dataset.indexing_technique == "high_quality":
-        for index_doc_id in index_doc_ids:
-            try:
-                vector_index.del_doc(index_doc_id)
-            except Exception:
-                logging.exception("Delete doc index failed when dataset deleted.")
-                continue
+    if vector_index:
+        try:
+            vector_index.delete()
+        except Exception:
+            logging.exception("Delete doc index failed when dataset deleted.")

     # delete from keyword index
-    if index_node_ids:
-        try:
-            keyword_table_index.del_nodes(index_node_ids)
-        except Exception:
-            logging.exception("Delete nodes index failed when dataset deleted.")
+    try:
+        kw_index.delete()
+    except Exception:
+        logging.exception("Delete nodes index failed when dataset deleted.")

     for document in documents:
         db.session.delete(document)
@@ -63,7 +57,6 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
     for segment in segments:
         db.session.delete(segment)

-    db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == dataset_id).delete()
     db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete()
     db.session.query(DatasetQuery).filter(DatasetQuery.dataset_id == dataset_id).delete()
     db.session.query(AppDatasetJoin).filter(AppDatasetJoin.dataset_id == dataset_id).delete()
@@ -4,8 +4,7 @@ import time
 import click
 from celery import shared_task

-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset

@@ -28,21 +27,23 @@ def clean_document_task(document_id: str, dataset_id: str):
         if not dataset:
             raise Exception('Document has no dataset')

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')

         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         index_node_ids = [segment.index_node_id for segment in segments]

         # delete from vector index
-        vector_index.del_nodes(index_node_ids)
+        if vector_index:
+            vector_index.delete_by_document_id(document_id)

         # delete from keyword index
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)

         for segment in segments:
             db.session.delete(segment)

         db.session.commit()
         end_at = time.perf_counter()
         logging.info(
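Each cleanup task now follows the same shape: fetch both indexes from IndexBuilder, delete the whole document from the vector store by document id, and drop the individual node ids from the keyword index. A condensed sketch of that pattern (the function name is illustrative; the calls are the ones used in the hunks above):

    from core.index.index import IndexBuilder  # project module shown in the diff

    def clean_document_indexes(dataset, document_id, index_node_ids):
        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
        kw_index = IndexBuilder.get_index(dataset, 'economy')

        # the vector index may be None for datasets without 'high_quality' indexing
        if vector_index:
            vector_index.delete_by_document_id(document_id)

        # the keyword index is cleaned per index node id
        if index_node_ids:
            kw_index.delete_by_ids(index_node_ids)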
@@ -5,8 +5,7 @@ from typing import List
 import click
 from celery import shared_task

-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, Document

@@ -29,22 +28,24 @@ def clean_notion_document_task(document_ids: List[str], dataset_id: str):
         if not dataset:
             raise Exception('Document has no dataset')

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
         for document_id in document_ids:
             document = db.session.query(Document).filter(
                 Document.id == document_id
             ).first()
             db.session.delete(document)

             segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
             index_node_ids = [segment.index_node_id for segment in segments]

             # delete from vector index
-            vector_index.del_nodes(index_node_ids)
+            if vector_index:
+                vector_index.delete_by_document_id(document_id)

             # delete from keyword index
             if index_node_ids:
-                keyword_table_index.del_nodes(index_node_ids)
+                kw_index.delete_by_ids(index_node_ids)

             for segment in segments:
                 db.session.delete(segment)
@@ -3,10 +3,12 @@ import time

 import click
 from celery import shared_task
-from llama_index.data_structs.node_v2 import DocumentRelationship, Node
-from core.index.vector_index import VectorIndex
+from langchain.schema import Document
+
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
-from models.dataset import DocumentSegment, Document, Dataset
+from models.dataset import DocumentSegment, Dataset
+from models.dataset import Document as DatasetDocument


 @shared_task
@@ -24,49 +26,47 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
     dataset = Dataset.query.filter_by(
         id=dataset_id
     ).first()

     if not dataset:
         raise Exception('Dataset not found')
-    documents = Document.query.filter_by(dataset_id=dataset_id).all()
-    if documents:
-        vector_index = VectorIndex(dataset=dataset)
-        for document in documents:
-            # delete from vector index
-            if action == "remove":
-                vector_index.del_doc(document.id)
-            elif action == "add":
+
+    if action == "remove":
+        index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+        index.delete()
+    elif action == "add":
+        dataset_documents = db.session.query(DatasetDocument).filter(
+            DatasetDocument.dataset_id == dataset_id,
+            DatasetDocument.indexing_status == 'completed',
+            DatasetDocument.enabled == True,
+            DatasetDocument.archived == False,
+        ).all()
+
+        if dataset_documents:
+            # save vector index
+            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+            for dataset_document in dataset_documents:
+                # delete from vector index
                 segments = db.session.query(DocumentSegment).filter(
-                    DocumentSegment.document_id == document.id,
+                    DocumentSegment.document_id == dataset_document.id,
                     DocumentSegment.enabled == True
                 ) .order_by(DocumentSegment.position.asc()).all()

-                nodes = []
-                previous_node = None
+                documents = []
                 for segment in segments:
-                    relationships = {
-                        DocumentRelationship.SOURCE: document.id
-                    }
-
-                    if previous_node:
-                        relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id
-
-                        previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id
-
-                    node = Node(
-                        doc_id=segment.index_node_id,
-                        doc_hash=segment.index_node_hash,
-                        text=segment.content,
-                        extra_info=None,
-                        node_info=None,
-                        relationships=relationships
+                    document = Document(
+                        page_content=segment.content,
+                        metadata={
+                            "doc_id": segment.index_node_id,
+                            "doc_hash": segment.index_node_hash,
+                            "document_id": segment.document_id,
+                            "dataset_id": segment.dataset_id,
+                        }
                     )

-                    previous_node = node
-                    nodes.append(node)
+                    documents.append(document)
                 # save vector index
-                vector_index.add_nodes(
-                    nodes=nodes,
-                    duplicate_check=True
-                )
+                index.add_texts(documents)

     end_at = time.perf_counter()
     logging.info(
@@ -6,11 +6,9 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound

-from core.data_source.notion import NotionPageReader
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.data_loader.loader.notion import NotionLoader
+from core.index.index import IndexBuilder
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
-from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from models.dataset import Document, Dataset, DocumentSegment
 from models.source import DataSourceBinding
@@ -43,6 +41,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
             raise ValueError("no notion page found")
         workspace_id = data_source_info['notion_workspace_id']
         page_id = data_source_info['notion_page_id']
+        page_type = data_source_info['type']
         page_edited_time = data_source_info['last_edited_time']
         data_source_binding = DataSourceBinding.query.filter(
             db.and_(
@@ -54,8 +53,16 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
         ).first()
         if not data_source_binding:
             raise ValueError('Data source binding not found.')
-        reader = NotionPageReader(integration_token=data_source_binding.access_token)
-        last_edited_time = reader.get_page_last_edited_time(page_id)
+
+        loader = NotionLoader(
+            notion_access_token=data_source_binding.access_token,
+            notion_workspace_id=workspace_id,
+            notion_obj_id=page_id,
+            notion_page_type=page_type
+        )
+
+        last_edited_time = loader.get_notion_last_edited_time()

         # check the page is updated
         if last_edited_time != page_edited_time:
             document.indexing_status = 'parsing'
@@ -68,18 +75,19 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
                 if not dataset:
                     raise Exception('Dataset not found')

-                vector_index = VectorIndex(dataset=dataset)
-                keyword_table_index = KeywordTableIndex(dataset=dataset)
+                vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+                kw_index = IndexBuilder.get_index(dataset, 'economy')

                 segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
                 index_node_ids = [segment.index_node_id for segment in segments]

                 # delete from vector index
-                vector_index.del_nodes(index_node_ids)
+                if vector_index:
+                    vector_index.delete_by_document_id(document_id)

                 # delete from keyword index
                 if index_node_ids:
-                    keyword_table_index.del_nodes(index_node_ids)
+                    kw_index.delete_by_ids(index_node_ids)

                 for segment in segments:
                     db.session.delete(segment)
@@ -89,21 +97,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
                     click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
             except Exception:
                 logging.exception("Cleaned document when document update data source or process rule failed")

             try:
                 indexing_runner = IndexingRunner()
                 indexing_runner.run([document])
                 end_at = time.perf_counter()
                 logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-            except DocumentIsPausedException:
-                logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow'))
-            except ProviderTokenNotInitError as e:
-                document.indexing_status = 'error'
-                document.error = str(e.description)
-                document.stopped_at = datetime.datetime.utcnow()
-                db.session.commit()
-            except Exception as e:
-                logging.exception("consume update document failed")
-                document.indexing_status = 'error'
-                document.error = str(e)
-                document.stopped_at = datetime.datetime.utcnow()
-                db.session.commit()
+            except DocumentIsPausedException as ex:
+                logging.info(click.style(str(ex), fg='yellow'))
+            except Exception:
+                pass
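The Notion sync task above decides whether to re-index a page by comparing its last-edited timestamp against the one stored at import time. A stripped-down sketch of that check, assuming the NotionLoader constructor arguments shown in the hunk (the helper itself is illustrative, not part of the commit):

    from core.data_loader.loader.notion import NotionLoader  # project module from the diff

    def notion_page_changed(data_source_binding, data_source_info) -> bool:
        loader = NotionLoader(
            notion_access_token=data_source_binding.access_token,
            notion_workspace_id=data_source_info['notion_workspace_id'],
            notion_obj_id=data_source_info['notion_page_id'],
            notion_page_type=data_source_info['type']
        )

        # re-index only when Notion reports a newer edit than the stored timestamp
        return loader.get_notion_last_edited_time() != data_source_info['last_edited_time']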
@@ -7,7 +7,6 @@ from celery import shared_task
 from werkzeug.exceptions import NotFound

 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
-from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from models.dataset import Document

@@ -22,9 +21,9 @@ def document_indexing_task(dataset_id: str, document_ids: list):
     Usage: document_indexing_task.delay(dataset_id, document_id)
     """
     documents = []
+    start_at = time.perf_counter()
     for document_id in document_ids:
         logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
-        start_at = time.perf_counter()

         document = db.session.query(Document).filter(
             Document.id == document_id,
@@ -44,17 +43,8 @@ def document_indexing_task(dataset_id: str, document_ids: list):
         indexing_runner = IndexingRunner()
         indexing_runner.run(documents)
         end_at = time.perf_counter()
-        logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow'))
-    except ProviderTokenNotInitError as e:
-        document.indexing_status = 'error'
-        document.error = str(e.description)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
-    except Exception as e:
-        logging.exception("consume document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+        logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass
@@ -6,10 +6,8 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound

-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
-from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from models.dataset import Document, Dataset, DocumentSegment

@@ -44,18 +42,19 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
         if not dataset:
             raise Exception('Dataset not found')

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')

         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         index_node_ids = [segment.index_node_id for segment in segments]

         # delete from vector index
-        vector_index.del_nodes(index_node_ids)
+        if vector_index:
+            vector_index.delete_by_ids(index_node_ids)

         # delete from keyword index
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)

         for segment in segments:
             db.session.delete(segment)
@@ -65,21 +64,13 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
             click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
     except Exception:
         logging.exception("Cleaned document when document update data source or process rule failed")

     try:
         indexing_runner = IndexingRunner()
         indexing_runner.run([document])
         end_at = time.perf_counter()
         logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow'))
-    except ProviderTokenNotInitError as e:
-        document.indexing_status = 'error'
-        document.error = str(e.description)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
-    except Exception as e:
-        logging.exception("consume update document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass
@@ -1,4 +1,3 @@
-import datetime
 import logging
 import time

@@ -41,11 +40,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
         indexing_runner.run_in_indexing_status(document)
         end_at = time.perf_counter()
         logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow'))
-    except Exception as e:
-        logging.exception("consume document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass
@@ -5,8 +5,7 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound

-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment, Document
@@ -38,17 +37,17 @@ def remove_document_from_index_task(document_id: str):
         if not dataset:
             raise Exception('Document has no dataset')

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')

         # delete from vector index
-        vector_index.del_doc(document.id)
+        vector_index.delete_by_document_id(document.id)

         # delete from keyword index
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all()
         index_node_ids = [segment.index_node_id for segment in segments]
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)

         end_at = time.perf_counter()
         logging.info(
@@ -5,8 +5,7 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound

-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment
@@ -36,17 +35,28 @@ def remove_segment_from_index_task(segment_id: str):
         dataset = segment.dataset

         if not dataset:
-            raise Exception('Segment has no dataset')
+            logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
+            return

-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        dataset_document = segment.document
+
+        if not dataset_document:
+            logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
+            return
+
+        if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
+            logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
+            return
+
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')

         # delete from vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.del_nodes([segment.index_node_id])
+        if vector_index:
+            vector_index.delete_by_ids([segment.index_node_id])

         # delete from keyword index
-        keyword_table_index.del_nodes([segment.index_node_id])
+        kw_index.delete_by_ids([segment.index_node_id])

         end_at = time.perf_counter()
         logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))