diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py
index 977fa3b62a..0f4cbccff7 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba.py
@@ -48,6 +48,9 @@ class Jieba(BaseKeyword):
                 text = texts[i]
                 if keywords_list:
                     keywords = keywords_list[i]
+                    if not keywords:
+                        keywords = keyword_table_handler.extract_keywords(text.page_content,
+                                                                          self._config.max_keywords_per_chunk)
                 else:
                     keywords = keyword_table_handler.extract_keywords(text.page_content,
                                                                       self._config.max_keywords_per_chunk)
                 self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
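Note on the jieba.py hunk: the existing "if keywords_list:" only guards the outer list, so a segment whose own entry was empty previously kept an empty keyword set and was presumably unreachable through keyword retrieval; the added branch falls back to jieba extraction for exactly that case. A minimal sketch of the truthiness gap being closed, with a toy extractor standing in for keyword_table_handler.extract_keywords (an assumption; the real handler's return type is inferred from context):

    # Toy stand-in for JiebaKeywordTableHandler.extract_keywords.
    def extract_keywords(text, max_keywords):
        return set(text.lower().split()[:max_keywords])

    keywords_list = [['redis', 'lock'], []]  # second entry is empty
    texts = ['Redis locks', 'segment positions']
    for i, text in enumerate(texts):
        keywords = keywords_list[i] if keywords_list else None
        if not keywords:  # catches both [] and None, as in the patched branch
            keywords = extract_keywords(text, 10)
        print(i, sorted(keywords))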
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 44a48af58b..78b191ac31 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -1046,73 +1046,11 @@ class SegmentService:
                 credentials=embedding_model.credentials,
                 texts=[content]
             )
-        max_position = db.session.query(func.max(DocumentSegment.position)).filter(
-            DocumentSegment.document_id == document.id
-        ).scalar()
-        segment_document = DocumentSegment(
-            tenant_id=current_user.current_tenant_id,
-            dataset_id=document.dataset_id,
-            document_id=document.id,
-            index_node_id=doc_id,
-            index_node_hash=segment_hash,
-            position=max_position + 1 if max_position else 1,
-            content=content,
-            word_count=len(content),
-            tokens=tokens,
-            status='completed',
-            indexing_at=datetime.datetime.utcnow(),
-            completed_at=datetime.datetime.utcnow(),
-            created_by=current_user.id
-        )
-        if document.doc_form == 'qa_model':
-            segment_document.answer = args['answer']
-
-        db.session.add(segment_document)
-        db.session.commit()
-
-        # save vector index
-        try:
-            VectorService.create_segments_vector([args['keywords']], [segment_document], dataset)
-        except Exception as e:
-            logging.exception("create segment index failed")
-            segment_document.enabled = False
-            segment_document.disabled_at = datetime.datetime.utcnow()
-            segment_document.status = 'error'
-            segment_document.error = str(e)
-            db.session.commit()
-        segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first()
-        return segment
-
-    @classmethod
-    def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
-        embedding_model = None
-        if dataset.indexing_technique == 'high_quality':
-            model_manager = ModelManager()
-            embedding_model = model_manager.get_model_instance(
-                tenant_id=current_user.current_tenant_id,
-                provider=dataset.embedding_model_provider,
-                model_type=ModelType.TEXT_EMBEDDING,
-                model=dataset.embedding_model
-            )
-        max_position = db.session.query(func.max(DocumentSegment.position)).filter(
-            DocumentSegment.document_id == document.id
-        ).scalar()
-        pre_segment_data_list = []
-        segment_data_list = []
-        keywords_list = []
-        for segment_item in segments:
-            content = segment_item['content']
-            doc_id = str(uuid.uuid4())
-            segment_hash = helper.generate_text_hash(content)
-            tokens = 0
-            if dataset.indexing_technique == 'high_quality' and embedding_model:
-                # calc embedding use tokens
-                model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
-                tokens = model_type_instance.get_num_tokens(
-                    model=embedding_model.model,
-                    credentials=embedding_model.credentials,
-                    texts=[content]
-                )
+        lock_name = 'add_segment_lock_document_id_{}'.format(document.id)
+        with redis_client.lock(lock_name, timeout=600):
+            max_position = db.session.query(func.max(DocumentSegment.position)).filter(
+                DocumentSegment.document_id == document.id
+            ).scalar()
             segment_document = DocumentSegment(
                 tenant_id=current_user.current_tenant_id,
                 dataset_id=document.dataset_id,
@@ -1129,25 +1067,91 @@ class SegmentService:
                 created_by=current_user.id
             )
             if document.doc_form == 'qa_model':
-                segment_document.answer = segment_item['answer']
+                segment_document.answer = args['answer']
+
             db.session.add(segment_document)
-            segment_data_list.append(segment_document)
+            db.session.commit()
 
-            pre_segment_data_list.append(segment_document)
-            keywords_list.append(segment_item['keywords'])
-
-        try:
             # save vector index
-            VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset)
-        except Exception as e:
-            logging.exception("create segment index failed")
-            for segment_document in segment_data_list:
+            try:
+                VectorService.create_segments_vector([args['keywords']], [segment_document], dataset)
+            except Exception as e:
+                logging.exception("create segment index failed")
                 segment_document.enabled = False
                 segment_document.disabled_at = datetime.datetime.utcnow()
                 segment_document.status = 'error'
                 segment_document.error = str(e)
-        db.session.commit()
-        return segment_data_list
+                db.session.commit()
+            segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first()
+            return segment
+
+    @classmethod
+    def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
+        lock_name = 'multi_add_segment_lock_document_id_{}'.format(document.id)
+        with redis_client.lock(lock_name, timeout=600):
+            embedding_model = None
+            if dataset.indexing_technique == 'high_quality':
+                model_manager = ModelManager()
+                embedding_model = model_manager.get_model_instance(
+                    tenant_id=current_user.current_tenant_id,
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
+                )
+            max_position = db.session.query(func.max(DocumentSegment.position)).filter(
+                DocumentSegment.document_id == document.id
+            ).scalar()
+            pre_segment_data_list = []
+            segment_data_list = []
+            keywords_list = []
+            for segment_item in segments:
+                content = segment_item['content']
+                doc_id = str(uuid.uuid4())
+                segment_hash = helper.generate_text_hash(content)
+                tokens = 0
+                if dataset.indexing_technique == 'high_quality' and embedding_model:
+                    # calc embedding use tokens
+                    model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)
+                    tokens = model_type_instance.get_num_tokens(
+                        model=embedding_model.model,
+                        credentials=embedding_model.credentials,
+                        texts=[content]
+                    )
+                segment_document = DocumentSegment(
+                    tenant_id=current_user.current_tenant_id,
+                    dataset_id=document.dataset_id,
+                    document_id=document.id,
+                    index_node_id=doc_id,
+                    index_node_hash=segment_hash,
+                    position=max_position + 1 if max_position else 1,
+                    content=content,
+                    word_count=len(content),
+                    tokens=tokens,
+                    status='completed',
+                    indexing_at=datetime.datetime.utcnow(),
+                    completed_at=datetime.datetime.utcnow(),
+                    created_by=current_user.id
+                )
+                if document.doc_form == 'qa_model':
+                    segment_document.answer = segment_item['answer']
+                db.session.add(segment_document)
+                segment_data_list.append(segment_document)
+
+                pre_segment_data_list.append(segment_document)
+                keywords_list.append(segment_item['keywords'])
+
+            try:
+                # save vector index
+                VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset)
+            except Exception as e:
+                logging.exception("create segment index failed")
+                for segment_document in segment_data_list:
+                    segment_document.enabled = False
+                    segment_document.disabled_at = datetime.datetime.utcnow()
+                    segment_document.status = 'error'
+                    segment_document.error = str(e)
+            db.session.commit()
+            return segment_data_list
 
     @classmethod
     def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):
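Note on the dataset_service.py hunks: the bodies of create_segment and multi_create_segment are unchanged apart from being indented into a Redis lock. The lock serializes reading max(DocumentSegment.position) with the subsequent insert; without it, two concurrent requests can read the same maximum and create segments with duplicate positions. A minimal, self-contained sketch of the pattern, assuming a redis-py client and a plain dict standing in for the DocumentSegment table:

    import redis

    redis_client = redis.Redis()
    positions = {}  # toy stand-in for DocumentSegment rows

    def allocate_position(document_id: str) -> int:
        lock_name = 'add_segment_lock_document_id_{}'.format(document_id)
        # timeout=600 mirrors the diff: the lock auto-expires if its holder
        # crashes mid-request, so a dead worker cannot block a document forever.
        with redis_client.lock(lock_name, timeout=600):
            max_position = positions.get(document_id)
            position = max_position + 1 if max_position else 1
            positions[document_id] = position
            return position

One observation worth recording: the two methods use different lock names (add_segment_lock_document_id_{} vs. multi_add_segment_lock_document_id_{}), so a single-segment create and a bulk create racing on the same document still do not exclude each other.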
diff --git a/web/app/components/base/tag-input/index.tsx b/web/app/components/base/tag-input/index.tsx
index d974280ec0..dc6dfa98a1 100644
--- a/web/app/components/base/tag-input/index.tsx
+++ b/web/app/components/base/tag-input/index.tsx
@@ -56,7 +56,9 @@ const TagInput: FC<TagInputProps> = ({
     }
 
     onChange([...items, valueTrimed])
-    setValue('')
+    setTimeout(() => {
+      setValue('')
+    })
   }
 }
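Note on the tag-input hunk: clearing the controlled input is deferred to the next macrotask instead of running synchronously alongside onChange. The diff itself does not record the symptom being fixed, but the likely intent is to let the parent's onChange state update (and any in-flight IME composition event) settle before the field is emptied, rather than racing the two updates within the same handler.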