From 92ab7ef65925edfb312952f0153d4849ab2309e4 Mon Sep 17 00:00:00 2001 From: Zhichang Yu Date: Tue, 3 Dec 2024 16:22:39 +0800 Subject: [PATCH] Refactor embedding batch_size (#3825) ### What problem does this PR solve? Refactor embedding batch_size. Close #3657 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] Refactoring --- api/db/services/llm_service.py | 8 +- rag/benchmark.py | 13 +- rag/llm/embedding_model.py | 248 ++++++++++++++++++++------------- 3 files changed, 160 insertions(+), 109 deletions(-) diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 4f69e72a2..128f154f2 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -232,13 +232,13 @@ class LLMBundle(object): self.max_length = lm.max_tokens break - def encode(self, texts: list, batch_size=32): - emd, used_tokens = self.mdl.encode(texts, batch_size) + def encode(self, texts: list): + embeddings, used_tokens = self.mdl.encode(texts) if not TenantLLMService.increase_usage( self.tenant_id, self.llm_type, used_tokens): logging.error( "LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens)) - return emd, used_tokens + return embeddings, used_tokens def encode_queries(self, query: str): emd, used_tokens = self.mdl.encode_queries(query) @@ -280,7 +280,7 @@ class LLMBundle(object): logging.error( "LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id)) return - yield chunk + yield chunk def chat(self, system, history, gen_conf): txt, used_tokens = self.mdl.chat(system, history, gen_conf) diff --git a/rag/benchmark.py b/rag/benchmark.py index 1146d55bf..31a6c92ff 100644 --- a/rag/benchmark.py +++ b/rag/benchmark.py @@ -63,16 +63,13 @@ class Benchmark: run[query][c["chunk_id"]] = c["similarity"] return run - def embedding(self, docs, batch_size=16): - vects = [] - cnts = [d["content_with_weight"] for d in docs] - for i in range(0, len(cnts), batch_size): - vts, c = self.embd_mdl.encode(cnts[i: i + batch_size]) - vects.extend(vts.tolist()) - assert len(docs) == len(vects) + def embedding(self, docs): + texts = [d["content_with_weight"] for d in docs] + embeddings, _ = self.embd_mdl.encode(texts) + assert len(docs) == len(embeddings) vector_size = 0 for i, d in enumerate(docs): - v = vects[i] + v = embeddings[i] vector_size = len(v) d["q_%d_vec" % len(v)] = v return docs, vector_size diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 6d70ada4b..bdebdc3a3 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -38,7 +38,7 @@ class Base(ABC): def __init__(self, key, model_name): pass - def encode(self, texts: list, batch_size=16): + def encode(self, texts: list): raise NotImplementedError("Please implement encode method!") def encode_queries(self, text: str): @@ -78,15 +78,16 @@ class DefaultEmbedding(Base): use_fp16=torch.cuda.is_available()) self._model = DefaultEmbedding._model - def encode(self, texts: list, batch_size=16): + def encode(self, texts: list): + batch_size = 16 texts = [truncate(t, 2048) for t in texts] token_count = 0 for t in texts: token_count += num_tokens_from_string(t) - res = [] + ress = [] for i in range(0, len(texts), batch_size): - res.extend(self._model.encode(texts[i:i + batch_size]).tolist()) - return np.array(res), token_count + ress.extend(self._model.encode(texts[i:i + batch_size]).tolist()) + return np.array(ress), token_count def encode_queries(self, text: str): token_count = 
num_tokens_from_string(text)
@@ -101,12 +102,18 @@ class OpenAIEmbed(Base):
         self.client = OpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=16):
+    def encode(self, texts: list):
+        # OpenAI requires batch size <=16
+        batch_size = 16
         texts = [truncate(t, 8191) for t in texts]
-        res = self.client.embeddings.create(input=texts,
-                                            model=self.model_name)
-        return np.array([d.embedding for d in res.data]
-                        ), res.usage.total_tokens
+        ress = []
+        total_tokens = 0
+        for i in range(0, len(texts), batch_size):
+            res = self.client.embeddings.create(input=texts[i:i + batch_size],
+                                                model=self.model_name)
+            ress.extend([d.embedding for d in res.data])
+            total_tokens += res.usage.total_tokens
+        return np.array(ress), total_tokens
 
     def encode_queries(self, text):
         res = self.client.embeddings.create(input=[truncate(text, 8191)],
@@ -123,12 +130,14 @@ class LocalAIEmbed(Base):
         self.client = OpenAI(api_key="empty", base_url=base_url)
         self.model_name = model_name.split("___")[0]
 
-    def encode(self, texts: list, batch_size=16):
-        res = self.client.embeddings.create(input=texts, model=self.model_name)
-        return (
-            np.array([d.embedding for d in res.data]),
-            1024,
-        )  # local embedding for LmStudio donot count tokens
+    def encode(self, texts: list):
+        batch_size = 16
+        ress = []
+        for i in range(0, len(texts), batch_size):
+            res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
+            ress.extend([d.embedding for d in res.data])
+        # local embedding for LmStudio does not count tokens
+        return np.array(ress), 1024
 
     def encode_queries(self, text):
         embds, cnt = self.encode([text])
@@ -155,12 +164,12 @@ class BaiChuanEmbed(OpenAIEmbed):
 
 class QWenEmbed(Base):
     def __init__(self, key, model_name="text_embedding_v2", **kwargs):
-        dashscope.api_key = key
+        self.key = key
        self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=10):
+    def encode(self, texts: list):
         import dashscope
-        batch_size = min(batch_size, 4)
+        batch_size = 4
         try:
             res = []
             token_count = 0
@@ -169,6 +178,7 @@
                 resp = dashscope.TextEmbedding.call(
                     model=self.model_name,
                     input=texts[i:i + batch_size],
+                    api_key=self.key,
                     text_type="document"
                 )
                 embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
@@ -186,6 +196,7 @@
             resp = dashscope.TextEmbedding.call(
                 model=self.model_name,
                 input=text[:2048],
+                api_key=self.key,
                 text_type="query"
             )
             return np.array(resp["output"]["embeddings"][0]
@@ -200,7 +211,7 @@ class ZhipuEmbed(Base):
         self.client = ZhipuAI(api_key=key)
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=16):
+    def encode(self, texts: list):
         arr = []
         tks_num = 0
         for txt in texts:
@@ -221,7 +232,7 @@ class OllamaEmbed(Base):
         self.client = Client(host=kwargs["base_url"])
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=16):
+    def encode(self, texts: list):
         arr = []
         tks_num = 0
         for txt in texts:
@@ -252,13 +263,13 @@ class FastEmbed(Base):
         from fastembed import TextEmbedding
         self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
 
-    def encode(self, texts: list, batch_size=16):
+    def encode(self, texts: list):
         # Using the internal tokenizer to encode the texts and get the total
         # number of tokens
         encodings = self._model.model.tokenizer.encode_batch(texts)
         total_tokens = sum(len(e) for e in encodings)
-        embeddings = [e.tolist() for e in self._model.embed(texts, batch_size)]
+        embeddings = [e.tolist() for e in self._model.embed(texts, batch_size=16)]
         return np.array(embeddings), total_tokens
@@ -278,11 +289,15 @@ class XinferenceEmbed(Base):
         self.client = OpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=16):
-        res = self.client.embeddings.create(input=texts,
-                                            model=self.model_name)
-        return np.array([d.embedding for d in res.data]
-                        ), res.usage.total_tokens
+    def encode(self, texts: list):
+        batch_size = 16
+        ress = []
+        total_tokens = 0
+        for i in range(0, len(texts), batch_size):
+            res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
+            ress.extend([d.embedding for d in res.data])
+            total_tokens += res.usage.total_tokens
+        return np.array(ress), total_tokens
 
     def encode_queries(self, text):
         res = self.client.embeddings.create(input=[text],
@@ -306,7 +321,8 @@ class YoudaoEmbed(Base):
                 model_name_or_path=model_name.replace(
                     "maidalun1020", "InfiniFlow"))
 
-    def encode(self, texts: list, batch_size=10):
+    def encode(self, texts: list):
+        batch_size = 10
         res = []
         token_count = 0
         for t in texts:
@@ -332,15 +348,21 @@ class JinaEmbed(Base):
         }
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=None):
+    def encode(self, texts: list):
         texts = [truncate(t, 8196) for t in texts]
-        data = {
-            "model": self.model_name,
-            "input": texts,
-            'encoding_type': 'float'
-        }
-        res = requests.post(self.base_url, headers=self.headers, json=data).json()
-        return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"]
+        batch_size = 16
+        ress = []
+        token_count = 0
+        for i in range(0, len(texts), batch_size):
+            data = {
+                "model": self.model_name,
+                "input": texts[i:i + batch_size],
+                'encoding_type': 'float'
+            }
+            res = requests.post(self.base_url, headers=self.headers, json=data).json()
+            ress.extend([d["embedding"] for d in res["data"]])
+            token_count += res["usage"]["total_tokens"]
+        return np.array(ress), token_count
 
     def encode_queries(self, text):
         embds, cnt = self.encode([text])
@@ -394,12 +416,17 @@ class MistralEmbed(Base):
         self.client = MistralClient(api_key=key)
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=16):
+    def encode(self, texts: list):
         texts = [truncate(t, 8196) for t in texts]
-        res = self.client.embeddings(input=texts,
-                                     model=self.model_name)
-        return np.array([d.embedding for d in res.data]
-                        ), res.usage.total_tokens
+        batch_size = 16
+        ress = []
+        token_count = 0
+        for i in range(0, len(texts), batch_size):
+            res = self.client.embeddings(input=texts[i:i + batch_size],
+                                         model=self.model_name)
+            ress.extend([d.embedding for d in res.data])
+            token_count += res.usage.total_tokens
+        return np.array(ress), token_count
 
     def encode_queries(self, text):
         res = self.client.embeddings(input=[truncate(text, 8196)],
@@ -418,7 +445,7 @@ class BedrockEmbed(Base):
         self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
                                    aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
 
-    def encode(self, texts: list, batch_size=16):
+    def encode(self, texts: list):
         texts = [truncate(t, 8196) for t in texts]
         embeddings = []
         token_count = 0
@@ -436,7 +463,6 @@
         return np.array(embeddings), token_count
 
     def encode_queries(self, text):
-        embeddings = []
         token_count = num_tokens_from_string(text)
 
         if self.model_name.split('.')[0] == 'amazon':
@@ -453,20 +479,26 @@
 
 class GeminiEmbed(Base):
     def __init__(self, key, model_name='models/text-embedding-004', **kwargs):
-        genai.configure(api_key=key)
+        self.key = key
         self.model_name = 'models/' + model_name
 
-    def encode(self, texts: list, batch_size=16):
+    def encode(self, texts: list):
         texts = [truncate(t, 2048) for t in texts]
         token_count = sum(num_tokens_from_string(text) for text in texts)
-        result = genai.embed_content(
-            model=self.model_name,
-            content=texts,
-            task_type="retrieval_document",
-            title="Embedding of list of strings")
-        return np.array(result['embedding']),token_count
+        genai.configure(api_key=self.key)
+        batch_size = 16
+        ress = []
+        for i in range(0, len(texts), batch_size):
+            result = genai.embed_content(
+                model=self.model_name,
+                content=texts[i:i + batch_size],
+                task_type="retrieval_document",
+                title="Embedding of list of strings")
+            ress.extend(result['embedding'])
+        return np.array(ress), token_count
 
     def encode_queries(self, text):
+        genai.configure(api_key=self.key)
         result = genai.embed_content(
             model=self.model_name,
             content=truncate(text,2048),
@@ -495,19 +527,22 @@ class NvidiaEmbed(Base):
         if model_name == "snowflake/arctic-embed-l":
             self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings"
 
-    def encode(self, texts: list, batch_size=None):
-        payload = {
-            "input": texts,
-            "input_type": "query",
-            "model": self.model_name,
-            "encoding_format": "float",
-            "truncate": "END",
-        }
-        res = requests.post(self.base_url, headers=self.headers, json=payload).json()
-        return (
-            np.array([d["embedding"] for d in res["data"]]),
-            res["usage"]["total_tokens"],
-        )
+    def encode(self, texts: list):
+        batch_size = 16
+        ress = []
+        token_count = 0
+        for i in range(0, len(texts), batch_size):
+            payload = {
+                "input": texts[i : i + batch_size],
+                "input_type": "query",
+                "model": self.model_name,
+                "encoding_format": "float",
+                "truncate": "END",
+            }
+            res = requests.post(self.base_url, headers=self.headers, json=payload).json()
+            ress.extend([d["embedding"] for d in res["data"]])
+            token_count += res["usage"]["total_tokens"]
+        return np.array(ress), token_count
 
     def encode_queries(self, text):
         embds, cnt = self.encode([text])
@@ -541,16 +576,20 @@ class CoHereEmbed(Base):
         self.client = Client(api_key=key)
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=16):
-        res = self.client.embed(
-            texts=texts,
-            model=self.model_name,
-            input_type="search_query",
-            embedding_types=["float"],
-        )
-        return np.array([d for d in res.embeddings.float]), int(
-            res.meta.billed_units.input_tokens
-        )
+    def encode(self, texts: list):
+        batch_size = 16
+        ress = []
+        token_count = 0
+        for i in range(0, len(texts), batch_size):
+            res = self.client.embed(
+                texts=texts[i : i + batch_size],
+                model=self.model_name,
+                input_type="search_document",
+                embedding_types=["float"],
+            )
+            ress.extend([d for d in res.embeddings.float])
+            token_count += res.meta.billed_units.input_tokens
+        return np.array(ress), token_count
 
     def encode_queries(self, text):
         res = self.client.embed(
@@ -599,19 +638,23 @@ class SILICONFLOWEmbed(Base):
         self.base_url = base_url
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=16):
-        payload = {
-            "model": self.model_name,
-            "input": texts,
-            "encoding_format": "float",
-        }
-        res = requests.post(self.base_url, json=payload, headers=self.headers).json()
-        if "data" not in res or not isinstance(res["data"], list) or len(res["data"])!= len(texts):
-            raise ValueError(f"SILICONFLOWEmbed.encode got invalid response from {self.base_url}")
-        return (
-            np.array([d["embedding"] for d in res["data"]]),
-
res["usage"]["total_tokens"], - ) + def encode(self, texts: list): + batch_size = 16 + ress = [] + token_count = 0 + for i in range(0, len(texts), batch_size): + texts_batch = texts[i : i + batch_size] + payload = { + "model": self.model_name, + "input": texts_batch, + "encoding_format": "float", + } + res = requests.post(self.base_url, json=payload, headers=self.headers).json() + if "data" not in res or not isinstance(res["data"], list) or len(res["data"]) != len(texts_batch): + raise ValueError(f"SILICONFLOWEmbed.encode got invalid response from {self.base_url}") + ress.extend([d["embedding"] for d in res["data"]]) + token_count += res["usage"]["total_tokens"] + return np.array(ress), token_count def encode_queries(self, text): payload = { @@ -632,9 +675,14 @@ class ReplicateEmbed(Base): self.model_name = model_name self.client = Client(api_token=key) - def encode(self, texts: list, batch_size=16): - res = self.client.run(self.model_name, input={"texts": json.dumps(texts)}) - return np.array(res), sum([num_tokens_from_string(text) for text in texts]) + def encode(self, texts: list): + batch_size = 16 + token_count = sum([num_tokens_from_string(text) for text in texts]) + ress = [] + for i in range(0, len(texts), batch_size): + res = self.client.run(self.model_name, input={"texts": texts[i : i + batch_size]}) + ress.extend(res) + return np.array(ress), token_count def encode_queries(self, text): res = self.client.embed(self.model_name, input={"texts": [text]}) @@ -673,11 +721,17 @@ class VoyageEmbed(Base): self.client = voyageai.Client(api_key=key) self.model_name = model_name - def encode(self, texts: list, batch_size=16): - res = self.client.embed( - texts=texts, model=self.model_name, input_type="document" - ) - return np.array(res.embeddings), res.total_tokens + def encode(self, texts: list): + batch_size = 16 + ress = [] + token_count = 0 + for i in range(0, len(texts), batch_size): + res = self.client.embed( + texts=texts[i : i + batch_size], model=self.model_name, input_type="document" + ) + ress.extend(res.embeddings) + token_count += res.total_tokens + return np.array(ress), token_count def encode_queries(self, text): res = self.client.embed( @@ -694,7 +748,7 @@ class HuggingFaceEmbed(Base): self.model_name = model_name self.base_url = base_url or "http://127.0.0.1:8080" - def encode(self, texts: list, batch_size=16): + def encode(self, texts: list): embeddings = [] for text in texts: response = requests.post(