From 4f1a56f0f0322bdd132d32134a57ca96f948c1a3 Mon Sep 17 00:00:00 2001
From: Jyong <76649700+JohnJyong@users.noreply.github.com>
Date: Fri, 8 Nov 2024 17:32:27 +0800
Subject: [PATCH] update document and segment word count (#10449)

---
 api/services/dataset_service.py            | 38 +++++++++++++++++--
 .../batch_create_segment_to_index_task.py  |  7 +++-
 2 files changed, 41 insertions(+), 4 deletions(-)

diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index fcf7bffdc9..8562dad1d3 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -1414,9 +1414,13 @@ class SegmentService:
             created_by=current_user.id,
         )
         if document.doc_form == "qa_model":
+            segment_document.word_count += len(args["answer"])
             segment_document.answer = args["answer"]
 
         db.session.add(segment_document)
+        # update document word count
+        document.word_count += segment_document.word_count
+        db.session.add(document)
         db.session.commit()
 
         # save vector index
@@ -1435,6 +1439,7 @@ class SegmentService:
     @classmethod
     def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
         lock_name = "multi_add_segment_lock_document_id_{}".format(document.id)
+        increment_word_count = 0
         with redis_client.lock(lock_name, timeout=600):
             embedding_model = None
             if dataset.indexing_technique == "high_quality":
@@ -1460,7 +1465,10 @@ class SegmentService:
                 tokens = 0
                 if dataset.indexing_technique == "high_quality" and embedding_model:
                     # calc embedding use tokens
-                    tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
+                    if document.doc_form == "qa_model":
+                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment_item["answer"]])
+                    else:
+                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
                 segment_document = DocumentSegment(
                     tenant_id=current_user.current_tenant_id,
                     dataset_id=document.dataset_id,
@@ -1478,6 +1486,8 @@ class SegmentService:
                 )
                 if document.doc_form == "qa_model":
                     segment_document.answer = segment_item["answer"]
+                    segment_document.word_count += len(segment_item["answer"])
+                increment_word_count += segment_document.word_count
 
                 db.session.add(segment_document)
                 segment_data_list.append(segment_document)
@@ -1486,7 +1496,9 @@ class SegmentService:
                     keywords_list.append(segment_item["keywords"])
                 else:
                     keywords_list.append(None)
-
+            # update document word count
+            document.word_count += increment_word_count
+            db.session.add(document)
             try:
                 # save vector index
                 VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset)
@@ -1527,10 +1539,14 @@ class SegmentService:
             else:
                 raise ValueError("Can't update disabled segment")
         try:
+            word_count_change = segment.word_count
             content = segment_update_entity.content
             if segment.content == content:
+                segment.word_count = len(content)
                 if document.doc_form == "qa_model":
                     segment.answer = segment_update_entity.answer
+                    segment.word_count += len(segment_update_entity.answer)
+                word_count_change = segment.word_count - word_count_change
                 if segment_update_entity.keywords:
                     segment.keywords = segment_update_entity.keywords
                 segment.enabled = True
@@ -1538,7 +1554,11 @@ class SegmentService:
                 segment.disabled_at = None
                 segment.disabled_by = None
                 db.session.add(segment)
                 db.session.commit()
+                # update document word count
+                if word_count_change != 0:
+                    document.word_count = max(0, document.word_count + word_count_change)
+                    db.session.add(document)
                 # update segment index task
                 if segment_update_entity.enabled:
                     VectorService.create_segments_vector([segment_update_entity.keywords], [segment], dataset)
@@ -1554,7 +1574,10 @@ class SegmentService:
                 )
 
                 # calc embedding use tokens
-                tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
+                if document.doc_form == "qa_model":
+                    tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])
+                else:
+                    tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
                 segment.content = content
                 segment.index_node_hash = segment_hash
                 segment.word_count = len(content)
@@ -1569,6 +1592,12 @@ class SegmentService:
                 segment.disabled_by = None
                 if document.doc_form == "qa_model":
                     segment.answer = segment_update_entity.answer
+                    segment.word_count += len(segment_update_entity.answer)
+                word_count_change = segment.word_count - word_count_change
+                # update document word count
+                if word_count_change != 0:
+                    document.word_count = max(0, document.word_count + word_count_change)
+                    db.session.add(document)
                 db.session.add(segment)
                 db.session.commit()
             # update segment vector index
@@ -1597,6 +1626,9 @@ class SegmentService:
             redis_client.setex(indexing_cache_key, 600, 1)
             delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
         db.session.delete(segment)
+        # update document word count
+        document.word_count -= segment.word_count
+        db.session.add(document)
         db.session.commit()
 
 
diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py
index de7f0ddec1..d1b41f2675 100644
--- a/api/tasks/batch_create_segment_to_index_task.py
+++ b/api/tasks/batch_create_segment_to_index_task.py
@@ -57,7 +57,7 @@ def batch_create_segment_to_index_task(
                 model_type=ModelType.TEXT_EMBEDDING,
                 model=dataset.embedding_model,
             )
-
+        word_count_change = 0
        for segment in content:
             content = segment["content"]
             doc_id = str(uuid.uuid4())
@@ -86,8 +86,13 @@ def batch_create_segment_to_index_task(
             )
             if dataset_document.doc_form == "qa_model":
                 segment_document.answer = segment["answer"]
+                segment_document.word_count += len(segment["answer"])
+            word_count_change += segment_document.word_count
             db.session.add(segment_document)
             document_segments.append(segment_document)
+        # update document word count
+        dataset_document.word_count += word_count_change
+        db.session.add(dataset_document)
         # add index to db
         indexing_runner = IndexingRunner()
         indexing_runner.batch_add_segments(document_segments, dataset)
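
Note on the pattern: every hunk in this patch applies the same bookkeeping. Snapshot the segment's old word_count, recompute it from the content (plus the answer when the document's doc_form is "qa_model"), then apply only the delta to the owning document, clamped at zero. Note that word_count is measured with len(), i.e. the character length of the text. Below is a minimal standalone sketch of that pattern; Document and Segment here are simplified stand-in dataclasses, not Dify's actual SQLAlchemy models, and the db.session bookkeeping is omitted.

# Standalone sketch of the word-count delta bookkeeping used in this patch.
from dataclasses import dataclass


@dataclass
class Document:
    doc_form: str = "text_model"
    word_count: int = 0


@dataclass
class Segment:
    content: str = ""
    answer: str = ""
    word_count: int = 0


def update_segment(document: Document, segment: Segment, content: str, answer: str = "") -> None:
    # Snapshot the old count so only the delta reaches the document.
    word_count_change = segment.word_count
    segment.content = content
    segment.word_count = len(content)  # "word count" is character length
    if document.doc_form == "qa_model":
        segment.answer = answer
        segment.word_count += len(answer)  # the answer counts toward the total
    word_count_change = segment.word_count - word_count_change
    if word_count_change != 0:
        # Clamp at zero so a stale document total can never go negative.
        document.word_count = max(0, document.word_count + word_count_change)


# Usage: shrinking a QA segment from 11 to 7 characters moves the
# document total by the same -4 delta.
doc = Document(doc_form="qa_model", word_count=11)
seg = Segment(content="hello world", word_count=11)
update_segment(doc, seg, "hi", "there")
assert seg.word_count == doc.word_count == 7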