diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py index d0487c62b0..b547f59d95 100644 --- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage @@ -38,6 +38,50 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): raise ValueError('Invalid model name') if not api_key: raise CredentialsValidateFailedError('api_key is required') + + # split into chunks of batch size 16 + chunks = [] + for i in range(0, len(texts), 16): + chunks.append(texts[i:i + 16]) + + embeddings = [] + token_usage = 0 + + for chunk in chunks: + # embeding chunk + chunk_embeddings, chunk_usage = self.embedding( + model=model, + api_key=api_key, + texts=chunk, + user=user + ) + + embeddings.extend(chunk_embeddings) + token_usage += chunk_usage + + result = TextEmbeddingResult( + model=model, + embeddings=embeddings, + usage=self._calc_response_usage( + model=model, + credentials=credentials, + tokens=token_usage + ) + ) + + return result + + def embedding(self, model: str, api_key, texts: list[str], user: Optional[str] = None) \ + -> Tuple[list[list[float]], int]: + """ + Embed given texts + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ url = self.api_base headers = { 'Authorization': 'Bearer ' + api_key, @@ -85,17 +129,10 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) + return [ + data['embedding'] for data in embeddings + ], usage['total_tokens'] - result = TextEmbeddingResult( - model=model, - embeddings=[[ - float(data) for data in x['embedding'] - ] for x in embeddings], - usage=usage - ) - - return result def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py b/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py index b6d806df02..b0a6620bb0 100644 --- a/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py @@ -59,3 +59,40 @@ def test_get_num_tokens(): ) assert num_tokens == 2 + +def test_max_chunks(): + model = BaichuanTextEmbeddingModel() + + result = model.invoke( + model='baichuan-text-embedding', + credentials={ + 'api_key': os.environ.get('BAICHUAN_API_KEY'), + }, + texts=[ + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + "hello", + "world", + ] + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 22 \ No newline at end of file