feat: implement asynchronous token counting in GPT2Tokenizer (#12239)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Authored by -LAN- on 2024-12-31 17:02:08 +08:00, committed by GitHub
parent 63a0b8ba79
commit 6a85960605


@@ -1,11 +1,13 @@
 from os.path import abspath, dirname, join
 from threading import Lock
-from typing import Any
+from typing import Any, cast
 
+import gevent.threadpool  # type: ignore
 from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer  # type: ignore
 
 _tokenizer: Any = None
 _lock = Lock()
+_pool = gevent.threadpool.ThreadPool(1)
 
 
 class GPT2Tokenizer:
@@ -20,7 +22,9 @@ class GPT2Tokenizer:
     @staticmethod
     def get_num_tokens(text: str) -> int:
-        return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
+        future = _pool.spawn(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
+        result = future.get(block=True)
+        return cast(int, result)
 
     @staticmethod
     def get_encoder() -> Any:
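
Not part of the commit, but a minimal self-contained sketch of the pattern the diff applies: offloading a blocking, CPU-bound call to a single-worker gevent ThreadPool so that cooperative greenlets are not starved while the caller waits for the result. The names _blocking_count and the standalone get_num_tokens below are illustrative stand-ins, not identifiers from the repository beyond those mirrored from the diff.

import gevent
import gevent.threadpool

# One worker thread, mirroring the diff's module-level `_pool`.
_pool = gevent.threadpool.ThreadPool(1)


def _blocking_count(text: str) -> int:
    # Illustrative stand-in for GPT2Tokenizer._get_num_tokens_by_gpt2: a
    # CPU-bound call that would otherwise block the gevent hub.
    return len(text.split())


def get_num_tokens(text: str) -> int:
    # spawn() schedules the callable on the pool's OS thread and returns an
    # async result; get(block=True) suspends only the calling greenlet until
    # the worker finishes, so other greenlets keep running in the meantime.
    future = _pool.spawn(_blocking_count, text)
    return future.get(block=True)


if __name__ == "__main__":
    # Several greenlets count tokens concurrently without blocking one another.
    jobs = [gevent.spawn(get_num_tokens, "hello world " * n) for n in range(1, 4)]
    gevent.joinall(jobs)
    print([job.value for job in jobs])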