From 5b24d7129eed0cf44cccbb4801f275dc3c2011c4 Mon Sep 17 00:00:00 2001 From: "Charlie.Wei" Date: Tue, 9 Jan 2024 19:17:47 +0800 Subject: [PATCH] Azure openai init (#1929) Co-authored-by: luowei Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> --- api/core/entities/provider_configuration.py | 86 +++++++++++++++++-- api/core/entities/provider_entities.py | 8 +- api/core/hosting_configuration.py | 63 +++++++++++--- api/core/model_manager.py | 5 +- .../model_runtime/entities/model_entities.py | 2 +- .../model_providers/azure_openai/_constant.py | 6 +- .../model_providers/azure_openai/llm/llm.py | 10 +-- api/core/provider_manager.py | 5 +- 8 files changed, 151 insertions(+), 34 deletions(-) diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index a1c05a2af3..be1d71ff51 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1,7 +1,7 @@ import datetime import json import logging -import time + from json import JSONDecodeError from typing import Optional, List, Dict, Tuple, Iterator @@ -11,8 +11,9 @@ from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, S from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType +from core.model_runtime.entities.model_entities import ModelType, FetchFrom +from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \ + ConfigurateMethod from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.__base.model_provider import ModelProvider @@ -22,6 +23,8 @@ from models.provider import ProviderType, Provider, ProviderModel, TenantPreferr logger = logging.getLogger(__name__) +original_provider_configurate_methods = {} + class ProviderConfiguration(BaseModel): """ @@ -34,6 +37,20 @@ class ProviderConfiguration(BaseModel): system_configuration: SystemConfiguration custom_configuration: CustomConfiguration + def __init__(self, **data): + super().__init__(**data) + + if self.provider.provider not in original_provider_configurate_methods: + original_provider_configurate_methods[self.provider.provider] = [] + for configurate_method in self.provider.configurate_methods: + original_provider_configurate_methods[self.provider.provider].append(configurate_method) + + if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: + if (any([len(quota_configuration.restrict_models) > 0 + for quota_configuration in self.system_configuration.quota_configurations]) + and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods): + self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) + def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: """ Get current credentials. @@ -123,7 +140,8 @@ class ProviderConfiguration(BaseModel): if provider_record: try: - original_credentials = json.loads(provider_record.encrypted_config) if provider_record.encrypted_config else {} + original_credentials = json.loads( + provider_record.encrypted_config) if provider_record.encrypted_config else {} except JSONDecodeError: original_credentials = {} @@ -265,7 +283,8 @@ class ProviderConfiguration(BaseModel): if provider_model_record: try: - original_credentials = json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} + original_credentials = json.loads( + provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} except JSONDecodeError: original_credentials = {} @@ -534,21 +553,70 @@ class ProviderConfiguration(BaseModel): ] ) + if self.provider.provider not in original_provider_configurate_methods: + original_provider_configurate_methods[self.provider.provider] = [] + for configurate_method in provider_instance.get_provider_schema().configurate_methods: + original_provider_configurate_methods[self.provider.provider].append(configurate_method) + + should_use_custom_model = False + if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: + should_use_custom_model = True + for quota_configuration in self.system_configuration.quota_configurations: if self.system_configuration.current_quota_type != quota_configuration.quota_type: continue - restrict_llms = quota_configuration.restrict_llms - if not restrict_llms: + restrict_models = quota_configuration.restrict_models + if len(restrict_models) == 0: break + if should_use_custom_model: + if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: + # only customizable model + for restrict_model in restrict_models: + copy_credentials = self.system_configuration.credentials.copy() + if restrict_model.base_model_name: + copy_credentials['base_model_name'] = restrict_model.base_model_name + + try: + custom_model_schema = ( + provider_instance.get_model_instance(restrict_model.model_type) + .get_customizable_model_schema_from_credentials( + restrict_model.model, + copy_credentials + ) + ) + except Exception as ex: + logger.warning(f'get custom model schema failed, {ex}') + continue + + if not custom_model_schema: + continue + + if custom_model_schema.model_type not in model_types: + continue + + provider_models.append( + ModelWithProviderEntity( + model=custom_model_schema.model, + label=custom_model_schema.label, + model_type=custom_model_schema.model_type, + features=custom_model_schema.features, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties=custom_model_schema.model_properties, + deprecated=custom_model_schema.deprecated, + provider=SimpleModelProviderEntity(self.provider), + status=ModelStatus.ACTIVE + ) + ) + # if llm name not in restricted llm list, remove it + restrict_model_names = [rm.model for rm in restrict_models] for m in provider_models: - if m.model_type == ModelType.LLM and m.model not in restrict_llms: + if m.model_type == ModelType.LLM and m.model not in restrict_model_names: m.status = ModelStatus.NO_PERMISSION elif not quota_configuration.is_valid: m.status = ModelStatus.QUOTA_EXCEEDED - return provider_models def _get_custom_provider_models(self, diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index d8905c71af..866b064a4e 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -21,6 +21,12 @@ class SystemConfigurationStatus(Enum): UNSUPPORTED = 'unsupported' +class RestrictModel(BaseModel): + model: str + base_model_name: Optional[str] = None + model_type: ModelType + + class QuotaConfiguration(BaseModel): """ Model class for provider quota configuration. @@ -30,7 +36,7 @@ class QuotaConfiguration(BaseModel): quota_limit: int quota_used: int is_valid: bool - restrict_llms: list[str] = [] + restrict_models: list[RestrictModel] = [] class SystemConfiguration(BaseModel): diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index ad58dd382e..22605c71f9 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -4,13 +4,14 @@ from typing import Optional from flask import Flask from pydantic import BaseModel -from core.entities.provider_entities import QuotaUnit +from core.entities.provider_entities import QuotaUnit, RestrictModel +from core.model_runtime.entities.model_entities import ModelType from models.provider import ProviderQuotaType class HostingQuota(BaseModel): quota_type: ProviderQuotaType - restrict_llms: list[str] = [] + restrict_models: list[RestrictModel] = [] class TrialHostingQuota(HostingQuota): @@ -47,10 +48,9 @@ class HostingConfiguration: provider_map: dict[str, HostingProvider] = {} moderation_config: HostedModerationConfig = None - def init_app(self, app: Flask): - if app.config.get('EDITION') != 'CLOUD': - return + def init_app(self, app: Flask) -> None: + self.provider_map["azure_openai"] = self.init_azure_openai() self.provider_map["openai"] = self.init_openai() self.provider_map["anthropic"] = self.init_anthropic() self.provider_map["minimax"] = self.init_minimax() @@ -59,6 +59,47 @@ class HostingConfiguration: self.moderation_config = self.init_moderation_config() + def init_azure_openai(self) -> HostingProvider: + quota_unit = QuotaUnit.TIMES + if os.environ.get("HOSTED_AZURE_OPENAI_ENABLED") and os.environ.get("HOSTED_AZURE_OPENAI_ENABLED").lower() == 'true': + credentials = { + "openai_api_key": os.environ.get("HOSTED_AZURE_OPENAI_API_KEY"), + "openai_api_base": os.environ.get("HOSTED_AZURE_OPENAI_API_BASE"), + "base_model_name": "gpt-35-turbo" + } + + quotas = [] + hosted_quota_limit = int(os.environ.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000")) + if hosted_quota_limit != -1 or hosted_quota_limit > 0: + trial_quota = TrialHostingQuota( + quota_limit=hosted_quota_limit, + restrict_models=[ + RestrictModel(model="gpt-4", base_model_name="gpt-4", model_type=ModelType.LLM), + RestrictModel(model="gpt-4-32k", base_model_name="gpt-4-32k", model_type=ModelType.LLM), + RestrictModel(model="gpt-4-1106-preview", base_model_name="gpt-4-1106-preview", model_type=ModelType.LLM), + RestrictModel(model="gpt-4-vision-preview", base_model_name="gpt-4-vision-preview", model_type=ModelType.LLM), + RestrictModel(model="gpt-35-turbo", base_model_name="gpt-35-turbo", model_type=ModelType.LLM), + RestrictModel(model="gpt-35-turbo-1106", base_model_name="gpt-35-turbo-1106", model_type=ModelType.LLM), + RestrictModel(model="gpt-35-turbo-instruct", base_model_name="gpt-35-turbo-instruct", model_type=ModelType.LLM), + RestrictModel(model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM), + RestrictModel(model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM), + RestrictModel(model="text-embedding-ada-002", base_model_name="text-embedding-ada-002", model_type=ModelType.TEXT_EMBEDDING), + ] + ) + quotas.append(trial_quota) + + return HostingProvider( + enabled=True, + credentials=credentials, + quota_unit=quota_unit, + quotas=quotas + ) + + return HostingProvider( + enabled=False, + quota_unit=quota_unit, + ) + def init_openai(self) -> HostingProvider: quota_unit = QuotaUnit.TIMES if os.environ.get("HOSTED_OPENAI_ENABLED") and os.environ.get("HOSTED_OPENAI_ENABLED").lower() == 'true': @@ -77,12 +118,12 @@ class HostingConfiguration: if hosted_quota_limit != -1 or hosted_quota_limit > 0: trial_quota = TrialHostingQuota( quota_limit=hosted_quota_limit, - restrict_llms=[ - "gpt-3.5-turbo", - "gpt-3.5-turbo-1106", - "gpt-3.5-turbo-instruct", - "gpt-3.5-turbo-16k", - "text-davinci-003" + restrict_models=[ + RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM), + RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM), + RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM), + RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM), + RestrictModel(model="text-davinci-003", model_type=ModelType.LLM), ] ) quotas.append(trial_quota) diff --git a/api/core/model_manager.py b/api/core/model_manager.py index ffbec8578e..dab30b2c3a 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -144,7 +144,7 @@ class ModelInstance: user=user ) - def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \ + def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None, **params) \ -> str: """ Invoke large language model @@ -161,7 +161,8 @@ class ModelInstance: model=self.model, credentials=self.credentials, file=file, - user=user + user=user, + **params ) diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index b9393071a9..5aa0e19ef0 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -32,7 +32,7 @@ class ModelType(Enum): return cls.TEXT_EMBEDDING elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value: return cls.RERANK - elif origin_model_type == cls.SPEECH2TEXT.value: + elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value: return cls.SPEECH2TEXT elif origin_model_type == cls.MODERATION.value: return cls.MODERATION diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index 4f58a850d0..3ffcd96769 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -2,7 +2,7 @@ from pydantic import BaseModel from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.model_entities import ModelFeature, ModelType, FetchFrom, ParameterRule, \ - DefaultParameterName, PriceConfig + DefaultParameterName, PriceConfig, ModelPropertyKey from core.model_runtime.entities.model_entities import AIModelEntity, I18nObject from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE @@ -502,8 +502,8 @@ EMBEDDING_BASE_MODELS = [ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ - 'context_size': 8097, - 'max_chunks': 32, + ModelPropertyKey.CONTEXT_SIZE: 8097, + ModelPropertyKey.MAX_CHUNKS: 32, }, pricing=PriceConfig( input=0.0001, diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index b52b19f417..2e4cd069ab 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -30,7 +30,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: - ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model) if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: # chat model @@ -59,7 +59,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: - model_mode = self._get_ai_model_entity(credentials['base_model_name'], model).entity.model_properties.get( + model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get( ModelPropertyKey.MODE) if model_mode == LLMMode.CHAT.value: @@ -79,7 +79,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): if 'base_model_name' not in credentials: raise CredentialsValidateFailedError('Base Model Name is required') - ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model) if not ai_model_entity: raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid') @@ -109,8 +109,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) - return ai_model_entity.entity + ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model) + return ai_model_entity.entity if ai_model_entity else None def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index d1e93bb839..1265c4e423 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -12,7 +12,8 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType +from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \ + ConfigurateMethod from core.model_runtime.model_providers import model_provider_factory from extensions import ext_hosting_provider from extensions.ext_database import db @@ -607,7 +608,7 @@ class ProviderManager: quota_used=provider_record.quota_used, quota_limit=provider_record.quota_limit, is_valid=provider_record.quota_limit > provider_record.quota_used or provider_record.quota_limit == -1, - restrict_llms=provider_quota.restrict_llms + restrict_models=provider_quota.restrict_models ) quota_configurations.append(quota_configuration)