mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-05-21 12:08:50 +08:00
fix: #18744 The model order defined in position.yaml in the Model Plugin is not taking effect. (#18756)
This commit is contained in:
parent
a944542858
commit
9bcc8041e9
@ -798,7 +798,25 @@ class ProviderConfiguration(BaseModel):
|
|||||||
provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
|
provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
|
||||||
|
|
||||||
# resort provider_models
|
# resort provider_models
|
||||||
return sorted(provider_models, key=lambda x: x.model_type.value)
|
# Optimize sorting logic: first sort by provider.position order, then by model_type.value
|
||||||
|
# Get the position list for model types (retrieve only once for better performance)
|
||||||
|
model_type_positions = {}
|
||||||
|
if hasattr(self.provider, "position") and self.provider.position:
|
||||||
|
model_type_positions = self.provider.position
|
||||||
|
|
||||||
|
def get_sort_key(model: ModelWithProviderEntity):
|
||||||
|
# Get the position list for the current model type
|
||||||
|
positions = model_type_positions.get(model.model_type.value, [])
|
||||||
|
|
||||||
|
# If the model name is in the position list, use its index for sorting
|
||||||
|
# Otherwise use a large value (list length) to place undefined models at the end
|
||||||
|
position_index = positions.index(model.model) if model.model in positions else len(positions)
|
||||||
|
|
||||||
|
# Return composite sort key: (model_type value, model position index)
|
||||||
|
return (model.model_type.value, position_index)
|
||||||
|
|
||||||
|
# Sort using the composite sort key
|
||||||
|
return sorted(provider_models, key=get_sort_key)
|
||||||
|
|
||||||
def _get_system_provider_models(
|
def _get_system_provider_models(
|
||||||
self,
|
self,
|
||||||
|
@ -134,6 +134,9 @@ class ProviderEntity(BaseModel):
|
|||||||
# pydantic configs
|
# pydantic configs
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
# position from plugin _position.yaml
|
||||||
|
position: Optional[dict[str, list[str]]] = {}
|
||||||
|
|
||||||
@field_validator("models", mode="before")
|
@field_validator("models", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_models(cls, v):
|
def validate_models(cls, v):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user