fix keyword index error when storage source is S3 (#3182)

Jyong 2024-04-09 01:42:58 +08:00 committed by GitHub
parent a81c1ab6ae
commit 283979fc46
2 changed files with 110 additions and 76 deletions


@@ -19,6 +19,7 @@ from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.model_entities import ModelType, PriceType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -657,18 +658,25 @@ class IndexingRunner:
         if embedding_model_instance:
             embedding_model_type_instance = embedding_model_instance.model_type_instance
             embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)

-        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
-            futures = []
-            for i in range(0, len(documents), chunk_size):
-                chunk_documents = documents[i:i + chunk_size]
-                futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor,
-                                               chunk_documents, dataset,
-                                               dataset_document, embedding_model_instance,
-                                               embedding_model_type_instance))
-            for future in futures:
-                tokens += future.result()
+        # create keyword index
+        create_keyword_thread = threading.Thread(target=self._process_keyword_index,
+                                                 args=(current_app._get_current_object(),
+                                                       dataset, dataset_document, documents))
+        create_keyword_thread.start()
+        if dataset.indexing_technique == 'high_quality':
+            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
+                futures = []
+                for i in range(0, len(documents), chunk_size):
+                    chunk_documents = documents[i:i + chunk_size]
+                    futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor,
+                                                   chunk_documents, dataset,
+                                                   dataset_document, embedding_model_instance,
+                                                   embedding_model_type_instance))
+                for future in futures:
+                    tokens += future.result()
+        create_keyword_thread.join()

         indexing_end_at = time.perf_counter()

         # update document status to completed
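For context, a minimal self-contained sketch of the concurrency pattern this hunk introduces (build_keyword_index and embed_chunk are illustrative stand-ins, not Dify APIs): the keyword index is built on a background thread, the embedding work fans out to a thread pool only for 'high_quality' datasets, and the caller joins the keyword thread before finishing.

import concurrent.futures
import threading

def build_keyword_index(documents):
    # stand-in for the keyword thread's work (Keyword(dataset).create(documents) in the diff)
    print('keyword index built for %d documents' % len(documents))

def embed_chunk(chunk):
    # stand-in for _process_chunk(); pretend the token cost equals the chunk length
    return len(chunk)

def index_documents(documents, high_quality=True, chunk_size=10):
    tokens = 0
    # start keyword indexing in the background so it overlaps with embedding
    keyword_thread = threading.Thread(target=build_keyword_index, args=(documents,))
    keyword_thread.start()
    # only high-quality datasets need the embedding pass
    if high_quality:
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            futures = [
                executor.submit(embed_chunk, documents[i:i + chunk_size])
                for i in range(0, len(documents), chunk_size)
            ]
            for future in futures:
                tokens += future.result()
    # wait for the keyword index before reporting completion
    keyword_thread.join()
    return tokens

print(index_documents(list(range(25))))  # -> 25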
@@ -682,6 +690,24 @@ class IndexingRunner:
                 }
             )

+    def _process_keyword_index(self, flask_app, dataset, dataset_document, documents):
+        with flask_app.app_context():
+            keyword = Keyword(dataset)
+            keyword.create(documents)
+            if dataset.indexing_technique != 'high_quality':
+                document_ids = [document.metadata['doc_id'] for document in documents]
+                db.session.query(DocumentSegment).filter(
+                    DocumentSegment.document_id == dataset_document.id,
+                    DocumentSegment.index_node_id.in_(document_ids),
+                    DocumentSegment.status == "indexing"
+                ).update({
+                    DocumentSegment.status: "completed",
+                    DocumentSegment.enabled: True,
+                    DocumentSegment.completed_at: datetime.datetime.utcnow()
+                })
+                db.session.commit()
+
     def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document,
                        embedding_model_instance, embedding_model_type_instance):
         with flask_app.app_context():
@@ -700,7 +726,7 @@ class IndexingRunner:
             )

             # load index
-            index_processor.load(dataset, chunk_documents)
+            index_processor.load(dataset, chunk_documents, with_keywords=False)

             document_ids = [document.metadata['doc_id'] for document in chunk_documents]
             db.session.query(DocumentSegment).filter(
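The new _process_keyword_index runs under its own Flask app context and, for economy-mode datasets that skip the embedding pass, also marks the still-indexing segments as completed, while _process_chunk now loads only the vector index (with_keywords=False). A rough sketch of that bulk-update step with SQLAlchemy, using a simplified stand-in model rather than Dify's actual DocumentSegment:

import datetime

from sqlalchemy import Boolean, Column, DateTime, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Segment(Base):
    # simplified stand-in for Dify's DocumentSegment model
    __tablename__ = 'segments'
    index_node_id = Column(String, primary_key=True)
    document_id = Column(String)
    status = Column(String, default='indexing')
    enabled = Column(Boolean, default=False)
    completed_at = Column(DateTime)

engine = create_engine('sqlite://')
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Segment(index_node_id='node-%d' % i, document_id='doc-1') for i in range(3)])
    session.commit()

    # flip every still-indexing segment of the document to completed in one UPDATE
    session.query(Segment).filter(
        Segment.document_id == 'doc-1',
        Segment.index_node_id.in_(['node-0', 'node-1', 'node-2']),
        Segment.status == 'indexing',
    ).update({
        Segment.status: 'completed',
        Segment.enabled: True,
        Segment.completed_at: datetime.datetime.utcnow(),
    }, synchronize_session=False)
    session.commit()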


@@ -24,56 +24,64 @@ class Jieba(BaseKeyword):
         self._config = KeywordTableConfig()

     def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
-        keyword_table_handler = JiebaKeywordTableHandler()
-        keyword_table = self._get_dataset_keyword_table()
-        for text in texts:
-            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
-            self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
-            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
-
-        self._save_dataset_keyword_table(keyword_table)
-
-        return self
+        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        with redis_client.lock(lock_name, timeout=600):
+            keyword_table_handler = JiebaKeywordTableHandler()
+            keyword_table = self._get_dataset_keyword_table()
+            for text in texts:
+                keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
+                self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
+                keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
+
+            self._save_dataset_keyword_table(keyword_table)
+
+            return self

     def add_texts(self, texts: list[Document], **kwargs):
-        keyword_table_handler = JiebaKeywordTableHandler()
-
-        keyword_table = self._get_dataset_keyword_table()
-        keywords_list = kwargs.get('keywords_list', None)
-        for i in range(len(texts)):
-            text = texts[i]
-            if keywords_list:
-                keywords = keywords_list[i]
-            else:
-                keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
-            self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
-            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
-
-        self._save_dataset_keyword_table(keyword_table)
+        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        with redis_client.lock(lock_name, timeout=600):
+            keyword_table_handler = JiebaKeywordTableHandler()
+
+            keyword_table = self._get_dataset_keyword_table()
+            keywords_list = kwargs.get('keywords_list', None)
+            for i in range(len(texts)):
+                text = texts[i]
+                if keywords_list:
+                    keywords = keywords_list[i]
+                else:
+                    keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
+                self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
+                keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
+
+            self._save_dataset_keyword_table(keyword_table)

     def text_exists(self, id: str) -> bool:
         keyword_table = self._get_dataset_keyword_table()
         return id in set.union(*keyword_table.values())

     def delete_by_ids(self, ids: list[str]) -> None:
-        keyword_table = self._get_dataset_keyword_table()
-        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
-
-        self._save_dataset_keyword_table(keyword_table)
+        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        with redis_client.lock(lock_name, timeout=600):
+            keyword_table = self._get_dataset_keyword_table()
+            keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
+
+            self._save_dataset_keyword_table(keyword_table)

     def delete_by_document_id(self, document_id: str):
-        # get segment ids by document_id
-        segments = db.session.query(DocumentSegment).filter(
-            DocumentSegment.dataset_id == self.dataset.id,
-            DocumentSegment.document_id == document_id
-        ).all()
-
-        ids = [segment.index_node_id for segment in segments]
-
-        keyword_table = self._get_dataset_keyword_table()
-        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
-
-        self._save_dataset_keyword_table(keyword_table)
+        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        with redis_client.lock(lock_name, timeout=600):
+            # get segment ids by document_id
+            segments = db.session.query(DocumentSegment).filter(
+                DocumentSegment.dataset_id == self.dataset.id,
+                DocumentSegment.document_id == document_id
+            ).all()
+
+            ids = [segment.index_node_id for segment in segments]
+
+            keyword_table = self._get_dataset_keyword_table()
+            keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
+
+            self._save_dataset_keyword_table(keyword_table)

     def search(
             self, query: str,
@@ -106,13 +114,15 @@ class Jieba(BaseKeyword):
         return documents

     def delete(self) -> None:
-        dataset_keyword_table = self.dataset.dataset_keyword_table
-        if dataset_keyword_table:
-            db.session.delete(dataset_keyword_table)
-            db.session.commit()
-            if dataset_keyword_table.data_source_type != 'database':
-                file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
-                storage.delete(file_key)
+        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        with redis_client.lock(lock_name, timeout=600):
+            dataset_keyword_table = self.dataset.dataset_keyword_table
+            if dataset_keyword_table:
+                db.session.delete(dataset_keyword_table)
+                db.session.commit()
+                if dataset_keyword_table.data_source_type != 'database':
+                    file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
+                    storage.delete(file_key)

     def _save_dataset_keyword_table(self, keyword_table):
         keyword_table_dict = {
@@ -135,33 +145,31 @@ class Jieba(BaseKeyword):
             storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode('utf-8'))

     def _get_dataset_keyword_table(self) -> Optional[dict]:
-        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
-        with redis_client.lock(lock_name, timeout=20):
-            dataset_keyword_table = self.dataset.dataset_keyword_table
-            if dataset_keyword_table:
-                keyword_table_dict = dataset_keyword_table.keyword_table_dict
-                if keyword_table_dict:
-                    return keyword_table_dict['__data__']['table']
-            else:
-                keyword_data_source_type = current_app.config['KEYWORD_DATA_SOURCE_TYPE']
-                dataset_keyword_table = DatasetKeywordTable(
-                    dataset_id=self.dataset.id,
-                    keyword_table='',
-                    data_source_type=keyword_data_source_type,
-                )
-                if keyword_data_source_type == 'database':
-                    dataset_keyword_table.keyword_table = json.dumps({
-                        '__type__': 'keyword_table',
-                        '__data__': {
-                            "index_id": self.dataset.id,
-                            "summary": None,
-                            "table": {}
-                        }
-                    }, cls=SetEncoder)
-                db.session.add(dataset_keyword_table)
-                db.session.commit()
+        dataset_keyword_table = self.dataset.dataset_keyword_table
+        if dataset_keyword_table:
+            keyword_table_dict = dataset_keyword_table.keyword_table_dict
+            if keyword_table_dict:
+                return keyword_table_dict['__data__']['table']
+        else:
+            keyword_data_source_type = current_app.config['KEYWORD_DATA_SOURCE_TYPE']
+            dataset_keyword_table = DatasetKeywordTable(
+                dataset_id=self.dataset.id,
+                keyword_table='',
+                data_source_type=keyword_data_source_type,
+            )
+            if keyword_data_source_type == 'database':
+                dataset_keyword_table.keyword_table = json.dumps({
+                    '__type__': 'keyword_table',
+                    '__data__': {
+                        "index_id": self.dataset.id,
+                        "summary": None,
+                        "table": {}
+                    }
+                }, cls=SetEncoder)
+            db.session.add(dataset_keyword_table)
+            db.session.commit()

         return {}

     def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
         for keyword in keywords:
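The jieba.py changes move the Redis lock out of _get_dataset_keyword_table, where it only guarded the read, up into every public mutator, so the whole read-modify-write of the shared keyword table happens under one lock; the timeout also grows from 20s to 600s, presumably to cover slower external storage such as S3. A minimal sketch of that pattern with redis-py, storing an illustrative keyword table in Redis itself rather than in Dify's database or object storage:

import json

import redis

redis_client = redis.Redis(host='localhost', port=6379, db=0)

def add_keywords(dataset_id, doc_id, keywords):
    # one lock per dataset guards the whole read-modify-write cycle, so two
    # workers indexing the same dataset cannot overwrite each other's updates
    lock_name = 'keyword_indexing_lock_{}'.format(dataset_id)
    with redis_client.lock(lock_name, timeout=600):
        raw = redis_client.get('keyword_table_{}'.format(dataset_id))
        table = json.loads(raw) if raw else {}
        for keyword in keywords:
            table.setdefault(keyword, []).append(doc_id)
        redis_client.set('keyword_table_{}'.format(dataset_id), json.dumps(table))

add_keywords('dataset-1', 'doc-1', ['redis', 'lock'])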