mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 17:39:06 +08:00
refactor: avoid to use extra space when finding model by name (#13043)
This commit is contained in:
parent
b4b09ddc3c
commit
b09c39c8dc
@ -221,13 +221,12 @@ class AIModel(ABC):
|
|||||||
:param credentials: model credentials
|
:param credentials: model credentials
|
||||||
:return: model schema
|
:return: model schema
|
||||||
"""
|
"""
|
||||||
# get predefined models (predefined_models)
|
# Try to get model schema from predefined models
|
||||||
models = self.predefined_models()
|
for predefined_model in self.predefined_models():
|
||||||
|
if model == predefined_model.model:
|
||||||
model_map = {model.model: model for model in models}
|
return predefined_model
|
||||||
if model in model_map:
|
|
||||||
return model_map[model]
|
|
||||||
|
|
||||||
|
# Try to get model schema from credentials
|
||||||
if credentials:
|
if credentials:
|
||||||
model_schema = self.get_customizable_model_schema_from_credentials(model, credentials)
|
model_schema = self.get_customizable_model_schema_from_credentials(model, credentials)
|
||||||
if model_schema:
|
if model_schema:
|
||||||
|
@ -677,16 +677,17 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
:return: model schema
|
:return: model schema
|
||||||
"""
|
"""
|
||||||
# get model schema
|
|
||||||
models = self.predefined_models()
|
|
||||||
model_map = {model.model: model for model in models}
|
|
||||||
|
|
||||||
mode = credentials.get("mode")
|
mode = credentials.get("mode")
|
||||||
|
base_model_schema = None
|
||||||
|
for predefined_model in self.predefined_models():
|
||||||
|
if (
|
||||||
|
mode == "chat" and predefined_model.model == "command-light-chat"
|
||||||
|
) or predefined_model.model == "command-light":
|
||||||
|
base_model_schema = predefined_model
|
||||||
|
break
|
||||||
|
|
||||||
if mode == "chat":
|
if not base_model_schema:
|
||||||
base_model_schema = model_map["command-light-chat"]
|
raise ValueError("Model not found")
|
||||||
else:
|
|
||||||
base_model_schema = model_map["command-light"]
|
|
||||||
|
|
||||||
base_model_schema = cast(AIModelEntity, base_model_schema)
|
base_model_schema = cast(AIModelEntity, base_model_schema)
|
||||||
|
|
||||||
|
@ -341,9 +341,6 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
|||||||
:param credentials: provider credentials
|
:param credentials: provider credentials
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# get predefined models
|
|
||||||
predefined_models = self.predefined_models()
|
|
||||||
predefined_models_map = {model.model: model for model in predefined_models}
|
|
||||||
|
|
||||||
# transform credentials to kwargs for model instance
|
# transform credentials to kwargs for model instance
|
||||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||||
@ -359,9 +356,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
|||||||
base_model = model.id.split(":")[1]
|
base_model = model.id.split(":")[1]
|
||||||
|
|
||||||
base_model_schema = None
|
base_model_schema = None
|
||||||
for predefined_model_name, predefined_model in predefined_models_map.items():
|
for predefined_model in self.predefined_models():
|
||||||
if predefined_model_name in base_model:
|
if predefined_model.model in base_model:
|
||||||
base_model_schema = predefined_model
|
base_model_schema = predefined_model
|
||||||
|
break
|
||||||
|
|
||||||
if not base_model_schema:
|
if not base_model_schema:
|
||||||
continue
|
continue
|
||||||
@ -1186,12 +1184,14 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
|||||||
base_model = model.split(":")[1]
|
base_model = model.split(":")[1]
|
||||||
|
|
||||||
# get model schema
|
# get model schema
|
||||||
models = self.predefined_models()
|
base_model_schema = None
|
||||||
model_map = {model.model: model for model in models}
|
for predefined_model in self.predefined_models():
|
||||||
if base_model not in model_map:
|
if base_model == predefined_model.model:
|
||||||
raise ValueError(f"Base model {base_model} not found")
|
base_model_schema = predefined_model
|
||||||
|
break
|
||||||
|
|
||||||
base_model_schema = model_map[base_model]
|
if not base_model_schema:
|
||||||
|
raise ValueError(f"Base model {base_model} not found")
|
||||||
|
|
||||||
base_model_schema_features = base_model_schema.features or []
|
base_model_schema_features = base_model_schema.features or []
|
||||||
base_model_schema_model_properties = base_model_schema.model_properties
|
base_model_schema_model_properties = base_model_schema.model_properties
|
||||||
|
Loading…
x
Reference in New Issue
Block a user