From 1d91535ba6458068e51c5d6efb0b08ee29e6f2a2 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 17 Jan 2024 21:17:59 +0800 Subject: [PATCH] fix: azure customize model name duplicate (#2073) --- .../model_providers/azure_openai/llm/llm.py | 10 ++++++---- .../azure_openai/text_embedding/text_embedding.py | 10 ++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index 55f0a9408f..c1a5e23bc2 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -1,3 +1,4 @@ +import copy import logging from typing import Generator, List, Optional, Union, cast @@ -625,9 +626,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel: for ai_model_entity in LLM_BASE_MODELS: if ai_model_entity.base_model_name == base_model_name: - ai_model_entity.entity.model = model - ai_model_entity.entity.label.en_US = model - ai_model_entity.entity.label.zh_Hans = model - return ai_model_entity + ai_model_entity_copy = copy.deepcopy(ai_model_entity) + ai_model_entity_copy.entity.model = model + ai_model_entity_copy.entity.label.en_US = model + ai_model_entity_copy.entity.label.zh_Hans = model + return ai_model_entity_copy return None diff --git a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py index 06897a6c45..227cd64fba 100644 --- a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py @@ -1,4 +1,5 @@ import base64 +import copy import time from typing import Optional, Tuple @@ -186,9 +187,10 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel: for ai_model_entity in EMBEDDING_BASE_MODELS: if ai_model_entity.base_model_name == base_model_name: - ai_model_entity.entity.model = model - ai_model_entity.entity.label.en_US = model - ai_model_entity.entity.label.zh_Hans = model - return ai_model_entity + ai_model_entity_copy = copy.deepcopy(ai_model_entity) + ai_model_entity_copy.entity.model = model + ai_model_entity_copy.entity.label.en_US = model + ai_model_entity_copy.entity.label.zh_Hans = model + return ai_model_entity_copy return None