From 2381264a3f7a39a5c9facc6da87c62059dc88251 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 3 Jan 2024 09:12:53 +0800 Subject: [PATCH] fix: provider create cause IntegrityError (#1866) --- api/core/provider_manager.py | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 0dafb6f0f2..df4157c1f2 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -3,6 +3,8 @@ from collections import defaultdict from json import JSONDecodeError from typing import Optional +from psycopg2 import IntegrityError + from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from core.entities.provider_configuration import ProviderConfigurations, ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, CustomModelConfiguration, \ @@ -380,17 +382,28 @@ class ProviderManager: if quota.quota_type == ProviderQuotaType.TRIAL: # Init trial provider records if not exists if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: - provider_record = Provider( - tenant_id=tenant_id, - provider_name=provider_name, - provider_type=ProviderType.SYSTEM.value, - quota_type=ProviderQuotaType.TRIAL.value, - quota_limit=quota.quota_limit, - quota_used=0, - is_valid=True - ) - db.session.add(provider_record) - db.session.commit() + try: + provider_record = Provider( + tenant_id=tenant_id, + provider_name=provider_name, + provider_type=ProviderType.SYSTEM.value, + quota_type=ProviderQuotaType.TRIAL.value, + quota_limit=quota.quota_limit, + quota_used=0, + is_valid=True + ) + db.session.add(provider_record) + db.session.commit() + except IntegrityError: + db.session.rollback() + provider_record = db.session.query(Provider) \ + .filter( + Provider.tenant_id == tenant_id, + Provider.provider_name == provider_name, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == ProviderQuotaType.TRIAL.value, + Provider.is_valid == True + ).first() provider_name_to_provider_records_dict[provider_name].append(provider_record)