From 6a85960605376f9cd1b88b64c9326a82fe8073b2 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 31 Dec 2024 17:02:08 +0800 Subject: [PATCH] feat: implement asynchronous token counting in GPT2Tokenizer (#12239) Signed-off-by: -LAN- --- .../model_providers/__base/tokenizers/gpt2_tokenzier.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py index 6dab0aaf2d..ab45a95803 100644 --- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py +++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py @@ -1,11 +1,13 @@ from os.path import abspath, dirname, join from threading import Lock -from typing import Any +from typing import Any, cast +import gevent.threadpool # type: ignore from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore _tokenizer: Any = None _lock = Lock() +_pool = gevent.threadpool.ThreadPool(1) class GPT2Tokenizer: @@ -20,7 +22,9 @@ class GPT2Tokenizer: @staticmethod def get_num_tokens(text: str) -> int: - return GPT2Tokenizer._get_num_tokens_by_gpt2(text) + future = _pool.spawn(GPT2Tokenizer._get_num_tokens_by_gpt2, text) + result = future.get(block=True) + return cast(int, result) @staticmethod def get_encoder() -> Any: