From db726e02a0888f4ee820b80e5c0e934f915d8379 Mon Sep 17 00:00:00 2001
From: Yeuoly
Date: Tue, 26 Nov 2024 18:59:03 +0800
Subject: [PATCH] feat: support multi token count

---
 api/core/indexing_runner.py                     |  6 ++----
 api/core/model_manager.py                       |  4 ++--
 api/core/rag/docstore/dataset_docstore.py       | 14 +++++++-------
 api/core/rag/splitter/fixed_text_splitter.py    |  8 +++++---
 api/core/rag/splitter/text_splitter.py          | 11 +++++------
 api/services/dataset_service.py                 | 13 ++++++++-----
 api/tasks/batch_create_segment_to_index_task.py | 10 +++++++---
 7 files changed, 36 insertions(+), 30 deletions(-)

diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 388cbc0be3..3bcd4c2a4d 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -720,10 +720,8 @@ class IndexingRunner:
 
             tokens = 0
             if embedding_model_instance:
-                tokens += sum(
-                    embedding_model_instance.get_text_embedding_num_tokens([document.page_content])
-                    for document in chunk_documents
-                )
+                page_content_list = [document.page_content for document in chunk_documents]
+                tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list))
 
             # load index
             index_processor.load(dataset, chunk_documents, with_keywords=False)
diff --git a/api/core/model_manager.py b/api/core/model_manager.py
index fc7ae18783..5956ea1ae9 100644
--- a/api/core/model_manager.py
+++ b/api/core/model_manager.py
@@ -175,7 +175,7 @@ class ModelInstance:
 
     def get_llm_num_tokens(
         self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
-    ) -> int:
+    ) -> list[int]:
         """
         Get number of tokens for llm
 
@@ -235,7 +235,7 @@ class ModelInstance:
             model=self.model,
             credentials=self.credentials,
             texts=texts,
-        )[0]  # TODO: fix this, this is only for temporary compatibility with old
+        )
 
     def invoke_rerank(
         self,
diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py
index 319a2612c7..306f0c27ea 100644
--- a/api/core/rag/docstore/dataset_docstore.py
+++ b/api/core/rag/docstore/dataset_docstore.py
@@ -79,7 +79,13 @@ class DatasetDocumentStore:
                 model=self._dataset.embedding_model,
             )
 
-        for doc in docs:
+        if embedding_model:
+            page_content_list = [doc.page_content for doc in docs]
+            tokens_list = embedding_model.get_text_embedding_num_tokens(page_content_list)
+        else:
+            tokens_list = [0] * len(docs)
+
+        for doc, tokens in zip(docs, tokens_list):
             if not isinstance(doc, Document):
                 raise ValueError("doc must be a Document")
 
@@ -91,12 +97,6 @@ class DatasetDocumentStore:
                    f"doc_id {doc.metadata['doc_id']} already exists. Set allow_update to True to overwrite."
                )
 
-            # calc embedding use tokens
-            if embedding_model:
-                tokens = embedding_model.get_text_embedding_num_tokens(texts=[doc.page_content])
-            else:
-                tokens = 0
-
             if not segment_document:
                 max_position += 1
 
diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py
index 53032b34d5..e0cd3e53f1 100644
--- a/api/core/rag/splitter/fixed_text_splitter.py
+++ b/api/core/rag/splitter/fixed_text_splitter.py
@@ -65,8 +65,9 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
             chunks = [text]
 
         final_chunks = []
-        for chunk in chunks:
-            if self._length_function(chunk) > self._chunk_size:
+        chunks_lengths = self._length_function(chunks)
+        for chunk, chunk_length in zip(chunks, chunks_lengths):
+            if chunk_length > self._chunk_size:
                 final_chunks.extend(self.recursive_split_text(chunk))
             else:
                 final_chunks.append(chunk)
@@ -93,7 +94,8 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
         # Now go merging things, recursively splitting longer texts.
         _good_splits = []
         _good_splits_lengths = []  # cache the lengths of the splits
-        for s in splits:
+        s_lens = self._length_function(splits)
+        for s, s_len in zip(splits, s_lens):
             s_len = self._length_function(s)
             if s_len < self._chunk_size:
                 _good_splits.append(s)
diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py
index 7dd62f8de1..89a9650b68 100644
--- a/api/core/rag/splitter/text_splitter.py
+++ b/api/core/rag/splitter/text_splitter.py
@@ -45,7 +45,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         self,
         chunk_size: int = 4000,
         chunk_overlap: int = 200,
-        length_function: Callable[[str], int] = len,
+        length_function: Callable[[list[str]], list[int]] = lambda x: [len(x) for x in x],
         keep_separator: bool = False,
         add_start_index: bool = False,
     ) -> None:
@@ -224,8 +224,8 @@ class CharacterTextSplitter(TextSplitter):
         splits = _split_text_with_regex(text, self._separator, self._keep_separator)
         _separator = "" if self._keep_separator else self._separator
         _good_splits_lengths = []  # cache the lengths of the splits
-        for split in splits:
-            _good_splits_lengths.append(self._length_function(split))
+        if splits:
+            _good_splits_lengths.extend(self._length_function(splits))
         return self._merge_splits(splits, _separator, _good_splits_lengths)
 
 
@@ -478,9 +478,8 @@ class RecursiveCharacterTextSplitter(TextSplitter):
         _good_splits = []
         _good_splits_lengths = []  # cache the lengths of the splits
         _separator = "" if self._keep_separator else separator
-
-        for s in splits:
-            s_len = self._length_function(s)
+        s_lens = self._length_function(splits)
+        for s, s_len in zip(splits, s_lens):
             if s_len < self._chunk_size:
                 _good_splits.append(s)
                 _good_splits_lengths.append(s_len)
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index d38729f31e..e76a660494 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -1390,7 +1390,7 @@ class SegmentService:
                model=dataset.embedding_model,
            )
            # calc embedding use tokens
-            tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
+            tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
        lock_name = "add_segment_lock_document_id_{}".format(document.id)
        with redis_client.lock(lock_name, timeout=600):
            max_position = (
@@ -1467,9 +1467,12 @@ class SegmentService:
                if dataset.indexing_technique == "high_quality" and embedding_model:
                    # calc embedding use tokens
                    if document.doc_form == "qa_model":
-                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment_item["answer"]])
+                        tokens = embedding_model.get_text_embedding_num_tokens(
+                            texts=[content + segment_item["answer"]]
+                        )[0]
                    else:
-                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
+                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
+
                segment_document = DocumentSegment(
                    tenant_id=current_user.current_tenant_id,
                    dataset_id=document.dataset_id,
@@ -1577,9 +1580,9 @@
 
                    # calc embedding use tokens
                    if document.doc_form == "qa_model":
-                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])
+                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0]
                    else:
-                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
+                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
                segment.content = content
                segment.index_node_hash = segment_hash
                segment.word_count = len(content)
diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py
index dcb7009e44..41e1419d25 100644
--- a/api/tasks/batch_create_segment_to_index_task.py
+++ b/api/tasks/batch_create_segment_to_index_task.py
@@ -58,12 +58,16 @@ def batch_create_segment_to_index_task(
                model=dataset.embedding_model,
            )
        word_count_change = 0
-        for segment in content:
+        if embedding_model:
+            tokens_list = embedding_model.get_text_embedding_num_tokens(
+                texts=[segment["content"] for segment in content]
+            )
+        else:
+            tokens_list = [0] * len(content)
+        for segment, tokens in zip(content, tokens_list):
            content = segment["content"]
            doc_id = str(uuid.uuid4())
            segment_hash = helper.generate_text_hash(content)
-            # calc embedding use tokens
-            tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) if embedding_model else 0
            max_position = (
                db.session.query(func.max(DocumentSegment.position))
                .filter(DocumentSegment.document_id == dataset_document.id)
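
Illustrative usage sketch (not part of the patch): the core change is that get_text_embedding_num_tokens now takes the whole batch of texts and returns one count per text (list[int]) rather than a single int, so call sites sum the list, index [0] for a single text, or zip it back against the inputs. The snippet below shows that calling convention with a stand-in counter class; FakeEmbeddingModel and its word-count heuristic are assumptions for the example, not code from this patch.

    class FakeEmbeddingModel:
        # Stand-in for ModelInstance, which in the real codebase resolves a provider and credentials.
        def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
            # One count per input text, matching the new list[int] return type.
            return [len(text.split()) for text in texts]

    embedding_model = FakeEmbeddingModel()
    docs = ["first chunk of text", "a second, slightly longer chunk of text"]

    # Old per-text style (pre-patch API returned a bare int per call):
    # tokens = sum(embedding_model.get_text_embedding_num_tokens([d]) for d in docs)

    # New batched style: one call for all texts, then aggregate or zip as needed.
    tokens_list = embedding_model.get_text_embedding_num_tokens(docs)
    total_tokens = sum(tokens_list)
    for doc, tokens in zip(docs, tokens_list):
        print(f"{tokens} tokens: {doc!r}")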