mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-11 21:09:05 +08:00
refactor: use tiktoken for token calculation (#12416)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
196ed8101b
commit
d3f5b1cbb6
@ -1,13 +1,10 @@
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from os.path import abspath, dirname, join
|
||||
from threading import Lock
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
|
||||
import tiktoken
|
||||
|
||||
_tokenizer: Any = None
|
||||
_lock = Lock()
|
||||
_executor = ProcessPoolExecutor(max_workers=1)
|
||||
|
||||
|
||||
class GPT2Tokenizer:
|
||||
@ -17,22 +14,28 @@ class GPT2Tokenizer:
|
||||
use gpt2 tokenizer to get num tokens
|
||||
"""
|
||||
_tokenizer = GPT2Tokenizer.get_encoder()
|
||||
tokens = _tokenizer.encode(text, verbose=False)
|
||||
tokens = _tokenizer.encode(text)
|
||||
return len(tokens)
|
||||
|
||||
@staticmethod
|
||||
def get_num_tokens(text: str) -> int:
|
||||
future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
|
||||
result = future.result()
|
||||
return cast(int, result)
|
||||
# Because this process needs more cpu resource, we turn this back before we find a better way to handle it.
|
||||
#
|
||||
# future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
|
||||
# result = future.result()
|
||||
# return cast(int, result)
|
||||
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
|
||||
|
||||
@staticmethod
|
||||
def get_encoder() -> Any:
|
||||
global _tokenizer, _lock
|
||||
with _lock:
|
||||
if _tokenizer is None:
|
||||
base_path = abspath(__file__)
|
||||
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
|
||||
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
|
||||
# Try to use tiktoken to get the tokenizer because it is faster
|
||||
#
|
||||
_tokenizer = tiktoken.get_encoding("gpt2")
|
||||
# base_path = abspath(__file__)
|
||||
# gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
|
||||
# _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
|
||||
|
||||
return _tokenizer
|
||||
|
Loading…
x
Reference in New Issue
Block a user