Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-20 16:09:11 +08:00)
feat: support multi token count

parent: e4b8220bc2
commit: db726e02a0
@@ -720,10 +720,8 @@ class IndexingRunner:
             tokens = 0
             if embedding_model_instance:
-                tokens += sum(
-                    embedding_model_instance.get_text_embedding_num_tokens([document.page_content])
-                    for document in chunk_documents
-                )
+                page_content_list = [document.page_content for document in chunk_documents]
+                tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list))

             # load index
             index_processor.load(dataset, chunk_documents, with_keywords=False)

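For orientation, a minimal sketch of the batched counting pattern this hunk adopts; the get_text_embedding_num_tokens interface is assumed (per this commit) to take a list of texts and return one token count per text:

# Sketch only: aggregate embedding-token usage with a single batched call
# instead of one call per document (interface assumed from this commit).
def count_chunk_tokens(embedding_model_instance, chunk_documents) -> int:
    if embedding_model_instance is None:
        return 0
    page_content_list = [document.page_content for document in chunk_documents]
    per_text_tokens = embedding_model_instance.get_text_embedding_num_tokens(page_content_list)  # list[int]
    return sum(per_text_tokens)
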
@@ -175,7 +175,7 @@ class ModelInstance:

     def get_llm_num_tokens(
         self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
-    ) -> int:
+    ) -> list[int]:
        """
        Get number of tokens for llm
        """

@@ -235,7 +235,7 @@ class ModelInstance:
            model=self.model,
            credentials=self.credentials,
            texts=texts,
-        )[0]  # TODO: fix this, this is only for temporary compatibility with old
+        )

    def invoke_rerank(
        self,

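With the temporary [0] shim removed, ModelInstance.get_text_embedding_num_tokens now returns the full list of counts; a hedged sketch of how a single-text caller adapts (model_instance and the sample text are illustrative):

# Sketch only: the method now returns list[int], one count per input text,
# so callers that need a single value index the first element themselves.
token_counts = model_instance.get_text_embedding_num_tokens(texts=["What is Dify?"])
tokens = token_counts[0]
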
@@ -79,7 +79,13 @@ class DatasetDocumentStore:
                model=self._dataset.embedding_model,
            )

-        for doc in docs:
+        if embedding_model:
+            page_content_list = [doc.page_content for doc in docs]
+            tokens_list = embedding_model.get_text_embedding_num_tokens(page_content_list)
+        else:
+            tokens_list = [0] * len(docs)
+
+        for doc, tokens in zip(docs, tokens_list):
            if not isinstance(doc, Document):
                raise ValueError("doc must be a Document")

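The [0] * len(docs) fallback keeps the zip pairing valid even when no embedding model is configured; a small illustrative sketch with plain data (names hypothetical):

# Sketch only: zipping docs with a same-length token list keeps one code path
# whether or not an embedding model is available.
docs = ["first segment", "second segment"]
tokens_list = [0] * len(docs)  # fallback when embedding_model is None
for doc, tokens in zip(docs, tokens_list):
    print(doc, tokens)
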
@@ -91,12 +97,6 @@ class DatasetDocumentStore:
                    f"doc_id {doc.metadata['doc_id']} already exists. Set allow_update to True to overwrite."
                )

-            # calc embedding use tokens
-            if embedding_model:
-                tokens = embedding_model.get_text_embedding_num_tokens(texts=[doc.page_content])
-            else:
-                tokens = 0
-
            if not segment_document:
                max_position += 1

@@ -65,8 +65,9 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
            chunks = [text]

        final_chunks = []
-        for chunk in chunks:
-            if self._length_function(chunk) > self._chunk_size:
+        chunks_lengths = self._length_function(chunks)
+        for chunk, chunk_length in zip(chunks, chunks_lengths):
+            if chunk_length > self._chunk_size:
                final_chunks.extend(self.recursive_split_text(chunk))
            else:
                final_chunks.append(chunk)

@@ -93,7 +94,8 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        _good_splits_lengths = []  # cache the lengths of the splits
-        for s in splits:
+        s_lens = self._length_function(splits)
+        for s, s_len in zip(splits, s_lens):
            s_len = self._length_function(s)
            if s_len < self._chunk_size:
                _good_splits.append(s)

@@ -45,7 +45,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
-        length_function: Callable[[str], int] = len,
+        length_function: Callable[[list[str]], list[int]] = lambda x: [len(x) for x in x],
        keep_separator: bool = False,
        add_start_index: bool = False,
    ) -> None:

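The default length_function changes from Callable[[str], int] to a batched Callable[[list[str]], list[int]], so any custom length function must follow suit. A hypothetical token-based example (whitespace tokenization stands in for a real tokenizer):

# Hypothetical batched length function matching the new signature
# Callable[[list[str]], list[int]]: one length per input text, in order.
def batched_token_lengths(texts: list[str]) -> list[int]:
    return [len(text.split()) for text in texts]

assert batched_token_lengths(["first chunk", "a slightly longer second chunk"]) == [2, 5]
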
@@ -224,8 +224,8 @@ class CharacterTextSplitter(TextSplitter):
        splits = _split_text_with_regex(text, self._separator, self._keep_separator)
        _separator = "" if self._keep_separator else self._separator
        _good_splits_lengths = []  # cache the lengths of the splits
-        for split in splits:
-            _good_splits_lengths.append(self._length_function(split))
+        if splits:
+            _good_splits_lengths.extend(self._length_function(splits))
        return self._merge_splits(splits, _separator, _good_splits_lengths)

@@ -478,9 +478,8 @@ class RecursiveCharacterTextSplitter(TextSplitter):
        _good_splits = []
        _good_splits_lengths = []  # cache the lengths of the splits
        _separator = "" if self._keep_separator else separator
-        for s in splits:
-            s_len = self._length_function(s)
+        s_lens = self._length_function(splits)
+        for s, s_len in zip(splits, s_lens):
            if s_len < self._chunk_size:
                _good_splits.append(s)
                _good_splits_lengths.append(s_len)

@@ -1390,7 +1390,7 @@ class SegmentService:
                model=dataset.embedding_model,
            )
            # calc embedding use tokens
-            tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
+            tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
        lock_name = "add_segment_lock_document_id_{}".format(document.id)
        with redis_client.lock(lock_name, timeout=600):
            max_position = (

@@ -1467,9 +1467,12 @@ class SegmentService:
            if dataset.indexing_technique == "high_quality" and embedding_model:
                # calc embedding use tokens
                if document.doc_form == "qa_model":
-                    tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment_item["answer"]])
+                    tokens = embedding_model.get_text_embedding_num_tokens(
+                        texts=[content + segment_item["answer"]]
+                    )[0]
                else:
-                    tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
+                    tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]

            segment_document = DocumentSegment(
                tenant_id=current_user.current_tenant_id,
                dataset_id=document.dataset_id,

@@ -1577,9 +1580,9 @@ class SegmentService:

                    # calc embedding use tokens
                    if document.doc_form == "qa_model":
-                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])
+                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0]
                    else:
-                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
+                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
                    segment.content = content
                    segment.index_node_hash = segment_hash
                    segment.word_count = len(content)

@@ -58,12 +58,16 @@ def batch_create_segment_to_index_task(
            model=dataset.embedding_model,
        )
        word_count_change = 0
-        for segment in content:
+        if embedding_model:
+            tokens_list = embedding_model.get_text_embedding_num_tokens(
+                texts=[segment["content"] for segment in content]
+            )
+        else:
+            tokens_list = [0] * len(content)
+        for segment, tokens in zip(content, tokens_list):
            content = segment["content"]
            doc_id = str(uuid.uuid4())
            segment_hash = helper.generate_text_hash(content)
-            # calc embedding use tokens
-            tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) if embedding_model else 0
            max_position = (
                db.session.query(func.max(DocumentSegment.position))
                .filter(DocumentSegment.document_id == dataset_document.id)

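Taken together, the task now issues one embedding-token request for all pending segments instead of one per segment; a hedged usage sketch of that batching (names taken from the hunk above, zero fallback assumed):

# Sketch only: batch the token count for every segment body, falling back to
# zeros when no embedding model is configured.
segment_texts = [segment["content"] for segment in content]
tokens_list = (
    embedding_model.get_text_embedding_num_tokens(texts=segment_texts)
    if embedding_model
    else [0] * len(content)
)
for segment, tokens in zip(content, tokens_list):
    ...  # build each DocumentSegment with its precomputed token count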