refactor: use tiktoken for token calculation (#12416)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-01-07 13:32:30 +08:00 committed by GitHub
parent 196ed8101b
commit d3f5b1cbb6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,13 +1,10 @@
from concurrent.futures import ProcessPoolExecutor
from os.path import abspath, dirname, join
from threading import Lock
from typing import Any, cast
from typing import Any
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
import tiktoken
_tokenizer: Any = None
_lock = Lock()
_executor = ProcessPoolExecutor(max_workers=1)
class GPT2Tokenizer:
@ -17,22 +14,28 @@ class GPT2Tokenizer:
use gpt2 tokenizer to get num tokens
"""
_tokenizer = GPT2Tokenizer.get_encoder()
tokens = _tokenizer.encode(text, verbose=False)
tokens = _tokenizer.encode(text)
return len(tokens)
@staticmethod
def get_num_tokens(text: str) -> int:
future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
result = future.result()
return cast(int, result)
# Because this process needs more cpu resource, we turn this back before we find a better way to handle it.
#
# future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
# result = future.result()
# return cast(int, result)
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
@staticmethod
def get_encoder() -> Any:
global _tokenizer, _lock
with _lock:
if _tokenizer is None:
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
# Try to use tiktoken to get the tokenizer because it is faster
#
_tokenizer = tiktoken.get_encoding("gpt2")
# base_path = abspath(__file__)
# gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
# _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
return _tokenizer