refactor: use tiktoken for token calculation (#12416)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Author: -LAN- <laipz8200@outlook.com> (committed by GitHub)
Date: 2025-01-07 13:32:30 +08:00
Parent: 196ed8101b
Commit: d3f5b1cbb6

@@ -1,13 +1,10 @@
-from concurrent.futures import ProcessPoolExecutor
-from os.path import abspath, dirname, join
 from threading import Lock
-from typing import Any, cast
+from typing import Any
 
-from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer  # type: ignore
+import tiktoken
 
 _tokenizer: Any = None
 _lock = Lock()
-_executor = ProcessPoolExecutor(max_workers=1)
 
 
 class GPT2Tokenizer:
@@ -17,22 +14,28 @@ class GPT2Tokenizer:
         use gpt2 tokenizer to get num tokens
         """
         _tokenizer = GPT2Tokenizer.get_encoder()
-        tokens = _tokenizer.encode(text, verbose=False)
+        tokens = _tokenizer.encode(text)
         return len(tokens)
 
     @staticmethod
     def get_num_tokens(text: str) -> int:
-        future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
-        result = future.result()
-        return cast(int, result)
+        # Because this process needs more cpu resource, we turn this back before we find a better way to handle it.
+        #
+        # future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
+        # result = future.result()
+        # return cast(int, result)
+        return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
 
     @staticmethod
     def get_encoder() -> Any:
         global _tokenizer, _lock
         with _lock:
             if _tokenizer is None:
-                base_path = abspath(__file__)
-                gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
-                _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
+                # Try to use tiktoken to get the tokenizer because it is faster
+                #
+                _tokenizer = tiktoken.get_encoding("gpt2")
+                # base_path = abspath(__file__)
+                # gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
+                # _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
 
             return _tokenizer
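
For context, a minimal standalone sketch of the counting path after this change. It assumes only that tiktoken is installed; the snippet is illustrative and not part of the commit itself:

    import tiktoken

    # tiktoken bundles the GPT-2 byte-pair-encoding tables, so no local
    # vocabulary files (and no transformers dependency) are required.
    _encoding = tiktoken.get_encoding("gpt2")

    def get_num_tokens(text: str) -> int:
        # encode() returns a list of token ids; its length is the token count.
        return len(_encoding.encode(text))

    print(get_num_tokens("Hello, world!"))  # -> 4 with the GPT-2 BPE

Dropping the single-worker ProcessPoolExecutor follows from the same change: the pure-Python transformers tokenizer was CPU-heavy enough to justify offloading to a separate process, whereas tiktoken's Rust-backed encoder is fast enough to call inline, as the retained comment in get_num_tokens notes.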