diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py index fd73728b78..069de9acec 100644 --- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -59,7 +59,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel): if not endpoint_url.endswith('/'): endpoint_url += '/' - endpoint_url = urljoin(endpoint_url, 'api/embeddings') + endpoint_url = urljoin(endpoint_url, 'api/embed') # get model properties context_size = self._get_context_size(model, credentials) @@ -78,32 +78,28 @@ class OllamaEmbeddingModel(TextEmbeddingModel): else: inputs.append(text) - batched_embeddings = [] + # Prepare the payload for the request + payload = { + 'input': inputs, + 'model': model, + } - for text in inputs: - # Prepare the payload for the request - payload = { - 'prompt': text, - 'model': model, - } + # Make the request to the OpenAI API + response = requests.post( + endpoint_url, + headers=headers, + data=json.dumps(payload), + timeout=(10, 300) + ) - # Make the request to the OpenAI API - response = requests.post( - endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + response.raise_for_status() # Raise an exception for HTTP errors + response_data = response.json() - response.raise_for_status() # Raise an exception for HTTP errors - response_data = response.json() + # Extract embeddings and used tokens from the response + embeddings = response_data['embeddings'] + embedding_used_tokens = self.get_num_tokens(model, credentials, inputs) - # Extract embeddings and used tokens from the response - embeddings = response_data['embedding'] - embedding_used_tokens = self.get_num_tokens(model, credentials, [text]) - - used_tokens += embedding_used_tokens - batched_embeddings.append(embeddings) + used_tokens += embedding_used_tokens # calc usage usage = self._calc_response_usage( @@ -113,7 +109,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel): ) return TextEmbeddingResult( - embeddings=batched_embeddings, + embeddings=embeddings, usage=usage, model=model )