diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py
index 6dab0aaf2d..ab45a95803 100644
--- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py
+++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py
@@ -1,11 +1,13 @@
 from os.path import abspath, dirname, join
 from threading import Lock
-from typing import Any
+from typing import Any, cast
 
+import gevent.threadpool  # type: ignore
 from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer  # type: ignore
 
 _tokenizer: Any = None
 _lock = Lock()
+_pool = gevent.threadpool.ThreadPool(1)
 
 
 class GPT2Tokenizer:
@@ -20,7 +22,9 @@ class GPT2Tokenizer:
 
     @staticmethod
     def get_num_tokens(text: str) -> int:
-        return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
+        future = _pool.spawn(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
+        result = future.get(block=True)
+        return cast(int, result)
 
     @staticmethod
     def get_encoder() -> Any: