feat: implement asynchronous token counting in GPT2Tokenizer (#12239)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2024-12-31 17:02:08 +08:00 committed by GitHub
parent 63a0b8ba79
commit 6a85960605
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1,11 +1,13 @@
from os.path import abspath, dirname, join from os.path import abspath, dirname, join
from threading import Lock from threading import Lock
from typing import Any from typing import Any, cast
import gevent.threadpool # type: ignore
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
# Lazily-initialized shared tokenizer; populated on first use.
_tokenizer: Any = None

# Guards the one-time initialization of ``_tokenizer`` across threads.
_lock = Lock()

# Dedicated worker pool for tokenization. NOTE(review): size 1 serializes
# all token-counting work; presumably chosen so blocking tokenization runs
# off the gevent hub without unbounded thread growth — confirm with callers.
_pool = gevent.threadpool.ThreadPool(1)
class GPT2Tokenizer: class GPT2Tokenizer:
@@ -20,7 +22,9 @@ class GPT2Tokenizer:
@staticmethod
def get_num_tokens(text: str) -> int:
    """Return the number of GPT-2 tokens in *text*.

    The tokenization itself is dispatched to the module-level gevent
    thread pool (a real OS thread) and this call blocks until the
    result is available, so the calling greenlet yields while waiting.
    """
    async_result = _pool.spawn(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
    token_count = async_result.get(block=True)
    return cast(int, token_count)
@staticmethod @staticmethod
def get_encoder() -> Any: def get_encoder() -> Any: