From d3f5b1cbb66655c33d7153b94dc9c185c404a245 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 7 Jan 2025 13:32:30 +0800 Subject: [PATCH] refactor: use tiktoken for token calculation (#12416) Signed-off-by: -LAN- --- .../__base/tokenizers/gpt2_tokenzier.py | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py index 72d9b7163c..9a5c40addb 100644 --- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py +++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py @@ -1,13 +1,10 @@ -from concurrent.futures import ProcessPoolExecutor -from os.path import abspath, dirname, join from threading import Lock -from typing import Any, cast +from typing import Any -from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore +import tiktoken _tokenizer: Any = None _lock = Lock() -_executor = ProcessPoolExecutor(max_workers=1) class GPT2Tokenizer: @@ -17,22 +14,28 @@ class GPT2Tokenizer: use gpt2 tokenizer to get num tokens """ _tokenizer = GPT2Tokenizer.get_encoder() - tokens = _tokenizer.encode(text, verbose=False) + tokens = _tokenizer.encode(text) return len(tokens) @staticmethod def get_num_tokens(text: str) -> int: - future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text) - result = future.result() - return cast(int, result) + # Because this process needs more CPU resources, we reverted this change until we find a better way to handle it.
+ # + # future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text) + # result = future.result() + # return cast(int, result) + return GPT2Tokenizer._get_num_tokens_by_gpt2(text) @staticmethod def get_encoder() -> Any: global _tokenizer, _lock with _lock: if _tokenizer is None: - base_path = abspath(__file__) - gpt2_tokenizer_path = join(dirname(base_path), "gpt2") - _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path) + # Use tiktoken to get the tokenizer because it is faster + # + _tokenizer = tiktoken.get_encoding("gpt2") + # base_path = abspath(__file__) + # gpt2_tokenizer_path = join(dirname(base_path), "gpt2") + # _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path) + return _tokenizer