feat: support multi token count

Yeuoly 2024-11-26 18:59:03 +08:00
parent e4b8220bc2
commit db726e02a0
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
7 changed files with 36 additions and 30 deletions
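
In short, the change moves embedding token counting from one call per text to one call per batch: get_text_embedding_num_tokens now returns one count per input text, and single-text call sites take the first element. A minimal self-contained sketch of that calling pattern, with a hypothetical whitespace-based stub standing in for the real embedding model:

    from typing import List

    def get_text_embedding_num_tokens(texts: List[str]) -> List[int]:
        # Hypothetical stand-in for the embedding model's token counter:
        # one count per input text (whitespace split, for illustration only).
        return [len(text.split()) for text in texts]

    texts = ["first chunk of text", "second, slightly longer chunk"]

    # Batched callers get every count from a single call ...
    tokens_list = get_text_embedding_num_tokens(texts)

    # ... while single-text call sites keep working by indexing the first element,
    # which is the pattern the updated services below use.
    tokens = get_text_embedding_num_tokens([texts[0]])[0]

    assert tokens == tokens_list[0]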

View File

@@ -720,10 +720,8 @@ class IndexingRunner:
                 tokens = 0
                 if embedding_model_instance:
-                    tokens += sum(
-                        embedding_model_instance.get_text_embedding_num_tokens([document.page_content])
-                        for document in chunk_documents
-                    )
+                    page_content_list = [document.page_content for document in chunk_documents]
+                    tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list))
 
                 # load index
                 index_processor.load(dataset, chunk_documents, with_keywords=False)
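
A runnable sketch of the summing pattern above, with hypothetical stand-ins for the document objects and the embedding model instance (the real ones come from the indexing pipeline):

    from typing import List

    class StubEmbeddingModelInstance:
        # Hypothetical stand-in; the real instance delegates to the provider's tokenizer.
        def get_text_embedding_num_tokens(self, texts: List[str]) -> List[int]:
            return [len(t.split()) for t in texts]

    class StubDocument:
        def __init__(self, page_content: str) -> None:
            self.page_content = page_content

    chunk_documents = [StubDocument("alpha beta"), StubDocument("gamma delta epsilon")]
    embedding_model_instance = StubEmbeddingModelInstance()

    tokens = 0
    if embedding_model_instance:
        # One batched request for all page contents, then sum the per-document counts.
        page_content_list = [document.page_content for document in chunk_documents]
        tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list))

    print(tokens)  # 5 with the whitespace-based stub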

View File

@@ -175,7 +175,7 @@ class ModelInstance:
     def get_llm_num_tokens(
         self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
-    ) -> int:
+    ) -> list[int]:
         """
         Get number of tokens for llm
@@ -235,7 +235,7 @@ class ModelInstance:
                 model=self.model,
                 credentials=self.credentials,
                 texts=texts,
-            )[0]  # TODO: fix this, this is only for temporary compatibility with old
+            )
 
     def invoke_rerank(
         self,
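
The second hunk drops the temporary "[0]" truncation, so the wrapper now forwards the provider's full per-text list instead of collapsing it to a single int. A sketch of the before/after behaviour, with a hypothetical provider call standing in for the real round-robin invocation:

    from typing import List

    def provider_get_num_tokens(texts: List[str]) -> List[int]:
        # Hypothetical provider call that already yields one count per text.
        return [len(t.split()) for t in texts]

    def get_text_embedding_num_tokens_old(texts: List[str]) -> int:
        # Previous behaviour: truncate to the first element for compatibility.
        return provider_get_num_tokens(texts)[0]

    def get_text_embedding_num_tokens_new(texts: List[str]) -> List[int]:
        # New behaviour: pass the whole list through.
        return provider_get_num_tokens(texts)

    assert get_text_embedding_num_tokens_old(["a b", "c d e"]) == 2
    assert get_text_embedding_num_tokens_new(["a b", "c d e"]) == [2, 3]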

View File

@ -79,7 +79,13 @@ class DatasetDocumentStore:
model=self._dataset.embedding_model, model=self._dataset.embedding_model,
) )
for doc in docs: if embedding_model:
page_content_list = [doc.page_content for doc in docs]
tokens_list = embedding_model.get_text_embedding_num_tokens(page_content_list)
else:
tokens_list = [0] * len(docs)
for doc, tokens in zip(docs, tokens_list):
if not isinstance(doc, Document): if not isinstance(doc, Document):
raise ValueError("doc must be a Document") raise ValueError("doc must be a Document")
@ -91,12 +97,6 @@ class DatasetDocumentStore:
f"doc_id {doc.metadata['doc_id']} already exists. Set allow_update to True to overwrite." f"doc_id {doc.metadata['doc_id']} already exists. Set allow_update to True to overwrite."
) )
# calc embedding use tokens
if embedding_model:
tokens = embedding_model.get_text_embedding_num_tokens(texts=[doc.page_content])
else:
tokens = 0
if not segment_document: if not segment_document:
max_position += 1 max_position += 1
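
The docstore now counts tokens for all documents up front and zips the counts back onto the documents, so the loop body no longer issues one call per document. A self-contained sketch of that flow, with hypothetical stand-ins for the documents and the counter:

    from typing import List

    class StubDoc:
        def __init__(self, page_content: str) -> None:
            self.page_content = page_content

    def count_tokens(texts: List[str]) -> List[int]:
        # Hypothetical batched counter; the real store calls the embedding model.
        return [len(t.split()) for t in texts]

    docs = [StubDoc("one two"), StubDoc("three four five")]
    embedding_model_available = True

    if embedding_model_available:
        tokens_list = count_tokens([doc.page_content for doc in docs])
    else:
        tokens_list = [0] * len(docs)

    for doc, tokens in zip(docs, tokens_list):
        # tokens arrives pre-computed; the loop only attaches it to the segment record.
        print(doc.page_content, tokens)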

View File

@@ -65,8 +65,9 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
             chunks = [text]
         final_chunks = []
-        for chunk in chunks:
-            if self._length_function(chunk) > self._chunk_size:
+        chunks_lengths = self._length_function(chunks)
+        for chunk, chunk_length in zip(chunks, chunks_lengths):
+            if chunk_length > self._chunk_size:
                 final_chunks.extend(self.recursive_split_text(chunk))
             else:
                 final_chunks.append(chunk)
@@ -93,7 +94,8 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
         # Now go merging things, recursively splitting longer texts.
         _good_splits = []
         _good_splits_lengths = []  # cache the lengths of the splits
-        for s in splits:
+        s_lens = self._length_function(splits)
+        for s, s_len in zip(splits, s_lens):
             s_len = self._length_function(s)
             if s_len < self._chunk_size:
                 _good_splits.append(s)
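
Both hunks compute all lengths with one batched _length_function call before the loop instead of one call per chunk. A small sketch of that length check, assuming a whitespace token count as the batched length function (the recursive re-split of oversized chunks is elided):

    from typing import Callable, List

    # Batched length function, as the splitter now expects: list[str] -> list[int].
    length_function: Callable[[List[str]], List[int]] = lambda xs: [len(x.split()) for x in xs]

    chunks = ["short chunk", "a noticeably longer chunk that would need re-splitting"]
    chunk_size = 5

    final_chunks: List[str] = []
    chunks_lengths = length_function(chunks)  # one call for the whole batch
    for chunk, chunk_length in zip(chunks, chunks_lengths):
        if chunk_length > chunk_size:
            # the real splitter would recurse here; the sketch just tags the chunk
            final_chunks.append(chunk + " [too long]")
        else:
            final_chunks.append(chunk)

    print(final_chunks)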

View File

@@ -45,7 +45,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         self,
         chunk_size: int = 4000,
         chunk_overlap: int = 200,
-        length_function: Callable[[str], int] = len,
+        length_function: Callable[[list[str]], list[int]] = lambda x: [len(x) for x in x],
         keep_separator: bool = False,
         add_start_index: bool = False,
     ) -> None:
@@ -224,8 +224,8 @@ class CharacterTextSplitter(TextSplitter):
         splits = _split_text_with_regex(text, self._separator, self._keep_separator)
         _separator = "" if self._keep_separator else self._separator
         _good_splits_lengths = []  # cache the lengths of the splits
-        for split in splits:
-            _good_splits_lengths.append(self._length_function(split))
+        if splits:
+            _good_splits_lengths.extend(self._length_function(splits))
         return self._merge_splits(splits, _separator, _good_splits_lengths)
@@ -478,9 +478,8 @@ class RecursiveCharacterTextSplitter(TextSplitter):
         _good_splits = []
         _good_splits_lengths = []  # cache the lengths of the splits
         _separator = "" if self._keep_separator else separator
-        for s in splits:
-            s_len = self._length_function(s)
+        s_lens = self._length_function(splits)
+        for s, s_len in zip(splits, s_lens):
             if s_len < self._chunk_size:
                 _good_splits.append(s)
                 _good_splits_lengths.append(s_len)
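
The underlying contract change is in the constructor: length_function now maps a list of strings to a list of lengths, and the new default lambda is the per-character equivalent of the old len. That, presumably, is what lets a batched token counter such as the embedding model's plug in directly. A short sketch of both kinds of length function under the new signature, using hypothetical local lambdas:

    from typing import Callable, List

    # New contract: one length per input string, computed in a single call.
    LengthFunction = Callable[[List[str]], List[int]]

    # Character-based default, equivalent to the lambda in the signature above.
    char_lengths: LengthFunction = lambda texts: [len(t) for t in texts]

    # A token-based counter fits the same slot; hypothetical whitespace stand-in
    # for a batched embedding token count.
    token_lengths: LengthFunction = lambda texts: [len(t.split()) for t in texts]

    splits = ["alpha beta", "gamma"]
    print(char_lengths(splits))   # [10, 5]
    print(token_lengths(splits))  # [2, 1]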

View File

@@ -1390,7 +1390,7 @@ class SegmentService:
                 model=dataset.embedding_model,
             )
             # calc embedding use tokens
-            tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
+            tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
         lock_name = "add_segment_lock_document_id_{}".format(document.id)
         with redis_client.lock(lock_name, timeout=600):
             max_position = (
@@ -1467,9 +1467,12 @@ class SegmentService:
                 if dataset.indexing_technique == "high_quality" and embedding_model:
                     # calc embedding use tokens
                     if document.doc_form == "qa_model":
-                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment_item["answer"]])
+                        tokens = embedding_model.get_text_embedding_num_tokens(
+                            texts=[content + segment_item["answer"]]
+                        )[0]
                     else:
-                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
+                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
                 segment_document = DocumentSegment(
                     tenant_id=current_user.current_tenant_id,
                     dataset_id=document.dataset_id,
@@ -1577,9 +1580,9 @@ class SegmentService:
                 # calc embedding use tokens
                 if document.doc_form == "qa_model":
-                    tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])
+                    tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0]
                 else:
-                    tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
+                    tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
             segment.content = content
             segment.index_node_hash = segment_hash
             segment.word_count = len(content)
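
Since the counter now always returns a list, the single-segment service paths keep their one-text calls and simply index [0]; qa_model segments still count question and answer as one concatenated text. A minimal sketch with a hypothetical stub:

    from typing import List

    def get_text_embedding_num_tokens(texts: List[str]) -> List[int]:
        # Hypothetical batched counter standing in for the embedding model.
        return [len(t.split()) for t in texts]

    content = "question text"
    answer = "answer text"
    doc_form = "qa_model"

    if doc_form == "qa_model":
        # Question and answer are concatenated and counted as a single text.
        tokens = get_text_embedding_num_tokens(texts=[content + answer])[0]
    else:
        tokens = get_text_embedding_num_tokens(texts=[content])[0]

    print(tokens)  # 3 with the whitespace-based stub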

View File

@@ -58,12 +58,16 @@ def batch_create_segment_to_index_task(
                 model=dataset.embedding_model,
             )
         word_count_change = 0
-        for segment in content:
+        if embedding_model:
+            tokens_list = embedding_model.get_text_embedding_num_tokens(
+                texts=[segment["content"] for segment in content]
+            )
+        else:
+            tokens_list = [0] * len(content)
+        for segment, tokens in zip(content, tokens_list):
             content = segment["content"]
             doc_id = str(uuid.uuid4())
             segment_hash = helper.generate_text_hash(content)
-            # calc embedding use tokens
-            tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) if embedding_model else 0
             max_position = (
                 db.session.query(func.max(DocumentSegment.position))
                 .filter(DocumentSegment.document_id == dataset_document.id)
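
In the background task the same idea cuts N counting calls (one per segment) down to a single batched call before the loop. A self-contained sketch of the batching-then-zip flow; the dict layout and the stub counter are assumptions for illustration:

    import uuid
    from typing import Dict, List

    def get_text_embedding_num_tokens(texts: List[str]) -> List[int]:
        # Hypothetical batched counter; the real task calls the embedding model.
        return [len(t.split()) for t in texts]

    content: List[Dict[str, str]] = [
        {"content": "first segment body"},
        {"content": "second segment body with more words"},
    ]
    embedding_model_available = True

    if embedding_model_available:
        # One counting request covers every segment in the batch.
        tokens_list = get_text_embedding_num_tokens(
            texts=[segment["content"] for segment in content]
        )
    else:
        tokens_list = [0] * len(content)

    for segment, tokens in zip(content, tokens_list):
        body = segment["content"]
        doc_id = str(uuid.uuid4())
        print(doc_id, tokens, body)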