From 22a1bc337f7a46dc75c58d8fc88e0bde8af6590b Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 17 Apr 2025 16:44:00 +0900 Subject: [PATCH] fix: perferred model provider not match with provider. (#18282) Signed-off-by: -LAN- --- api/core/provider_manager.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 099acfd7f4..7570200175 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -124,6 +124,15 @@ class ProviderManager: # Get All preferred provider types of the workspace provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id) + # Ensure that both the original provider name and its ModelProviderID string representation + # are present in the dictionary to handle cases where either form might be used + for provider_name in list(provider_name_to_preferred_model_provider_records_dict.keys()): + provider_id = ModelProviderID(provider_name) + if str(provider_id) not in provider_name_to_preferred_model_provider_records_dict: + # Add the ModelProviderID string representation if it's not already present + provider_name_to_preferred_model_provider_records_dict[str(provider_id)] = ( + provider_name_to_preferred_model_provider_records_dict[provider_name] + ) # Get All provider model settings provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id) @@ -497,8 +506,8 @@ class ProviderManager: @staticmethod def _init_trial_provider_records( - tenant_id: str, provider_name_to_provider_records_dict: dict[str, list] - ) -> dict[str, list]: + tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]] + ) -> dict[str, list[Provider]]: """ Initialize trial provider records if not exists. @@ -532,7 +541,7 @@ class ProviderManager: if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: try: # FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic - provider_record = Provider( + new_provider_record = Provider( tenant_id=tenant_id, # TODO: Use provider name with prefix after the data migration. provider_name=ModelProviderID(provider_name).provider_name, @@ -542,11 +551,12 @@ class ProviderManager: quota_used=0, is_valid=True, ) - db.session.add(provider_record) + db.session.add(new_provider_record) db.session.commit() + provider_name_to_provider_records_dict[provider_name].append(new_provider_record) except IntegrityError: db.session.rollback() - provider_record = ( + existed_provider_record = ( db.session.query(Provider) .filter( Provider.tenant_id == tenant_id, @@ -556,11 +566,14 @@ class ProviderManager: ) .first() ) - if provider_record and not provider_record.is_valid: - provider_record.is_valid = True + if not existed_provider_record: + continue + + if not existed_provider_record.is_valid: + existed_provider_record.is_valid = True db.session.commit() - provider_name_to_provider_records_dict[provider_name].append(provider_record) + provider_name_to_provider_records_dict[provider_name].append(existed_provider_record) return provider_name_to_provider_records_dict