mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 21:59:00 +08:00
feat: implement asynchronous token counting in GPT2Tokenizer (#12239)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
63a0b8ba79
commit
6a85960605
@ -1,11 +1,13 @@
|
|||||||
from os.path import abspath, dirname, join
|
from os.path import abspath, dirname, join
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
|
import gevent.threadpool # type: ignore
|
||||||
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
|
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
|
||||||
|
|
||||||
_tokenizer: Any = None
|
_tokenizer: Any = None
|
||||||
_lock = Lock()
|
_lock = Lock()
|
||||||
|
_pool = gevent.threadpool.ThreadPool(1)
|
||||||
|
|
||||||
|
|
||||||
class GPT2Tokenizer:
|
class GPT2Tokenizer:
|
||||||
@ -20,7 +22,9 @@ class GPT2Tokenizer:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_num_tokens(text: str) -> int:
|
def get_num_tokens(text: str) -> int:
|
||||||
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
|
future = _pool.spawn(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
|
||||||
|
result = future.get(block=True)
|
||||||
|
return cast(int, result)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_encoder() -> Any:
|
def get_encoder() -> Any:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user