diff --git a/api/commands.py b/api/commands.py
index 544b3110e7..4eed5de047 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -18,7 +18,8 @@ from models.model import Account
 import secrets
 import base64
 
-from models.provider import Provider
+from models.provider import Provider, ProviderName
+from services.provider_service import ProviderService
 
 
 @click.command('reset-password', help='Reset the account password.')
@@ -193,9 +194,40 @@ def recreate_all_dataset_indexes():
     click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
 
 
+@click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
+def sync_anthropic_hosted_providers():
+    click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
+    count = 0
+
+    page = 1
+    while True:
+        try:
+            tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
+        except NotFound:
+            break
+
+        page += 1
+        for tenant in tenants:
+            try:
+                click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
+                ProviderService.create_system_provider(
+                    tenant,
+                    ProviderName.ANTHROPIC.value,
+                    current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
+                    True
+                )
+                count += 1
+            except Exception as e:
+                click.echo(click.style('Sync tenant anthropic hosted provider error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+                continue
+
+    click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
+
+
 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
     app.cli.add_command(generate_invitation_codes)
     app.cli.add_command(reset_encrypt_key_pair)
     app.cli.add_command(recreate_all_dataset_indexes)
+    app.cli.add_command(sync_anthropic_hosted_providers)
diff --git a/api/config.py b/api/config.py
index 99ecc67656..951574722a 100644
--- a/api/config.py
+++ b/api/config.py
@@ -51,6 +51,8 @@ DEFAULTS = {
     'LOG_LEVEL': 'INFO',
     'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
     'DEFAULT_LLM_PROVIDER': 'openai',
+    'OPENAI_HOSTED_QUOTA_LIMIT': 200,
+    'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
     'TENANT_DOCUMENT_COUNT': 100
 }
 
@@ -192,6 +194,10 @@ class Config:
 
         # hosted provider credentials
         self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
+        self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')
+
+        self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
+        self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')
 
         # By default it is False
         # You could disable it for compatibility with certain OpenAPI providers
diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py
index e6bd2fdf28..075e8d4a91 100644
--- a/api/controllers/console/app/audio.py
+++ b/api/controllers/console/app/audio.py
@@ -50,8 +50,8 @@ class ChatMessageAudioApi(Resource):
             raise UnsupportedAudioTypeError()
         except ProviderNotSupportSpeechToTextServiceError:
             raise ProviderNotSupportSpeechToTextError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py
index 552271a9ec..e76186671d 100644
--- a/api/controllers/console/app/completion.py
+++ b/api/controllers/console/app/completion.py
@@ -63,8 +63,8 @@ class CompletionMessageApi(Resource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -133,8 +133,8 @@ class ChatMessageApi(Resource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -164,8 +164,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
             except services.errors.app_model_config.AppModelConfigBrokenError:
                 logging.exception("App model config broken.")
                 yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
-            except ProviderTokenNotInitError:
-                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError as ex:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
             except QuotaExceededError:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py
index 905e5b273a..b6086b68b8 100644
--- a/api/controllers/console/app/error.py
+++ b/api/controllers/console/app/error.py
@@ -16,7 +16,7 @@ class ProviderNotInitializeError(BaseHTTPException):
 
 
 class ProviderQuotaExceededError(BaseHTTPException):
     error_code = 'provider_quota_exceeded'
-    description = "Your quota for Dify Hosted OpenAI has been exhausted. " \
+    description = "Your quota for Dify Hosted Model Provider has been exhausted. " \
                   "Please go to Settings -> Model Provider to complete your own provider credentials."
     code = 400
diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py
index 6a74bf2584..6b9a0a2140 100644
--- a/api/controllers/console/app/generator.py
+++ b/api/controllers/console/app/generator.py
@@ -27,8 +27,8 @@ class IntroductionGenerateApi(Resource):
                 account.current_tenant_id,
                 args['prompt_template']
             )
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -58,8 +58,8 @@ class RuleGenerateApi(Resource):
                 args['audiences'],
                 args['hoping_to_solve']
             )
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
index 88594425de..c5764a7ec7 100644
--- a/api/controllers/console/app/message.py
+++ b/api/controllers/console/app/message.py
@@ -269,8 +269,8 @@ class MessageMoreLikeThisApi(Resource):
             raise NotFound("Message Not Exists.")
         except MoreLikeThisDisabledError:
             raise AppMoreLikeThisDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -297,8 +297,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
             except MoreLikeThisDisabledError:
                 yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
-            except ProviderTokenNotInitError:
-                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError as ex:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
             except QuotaExceededError:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
@@ -339,8 +339,8 @@ class MessageSuggestedQuestionApi(Resource):
             raise NotFound("Message not found")
         except ConversationNotExistsError:
             raise NotFound("Conversation not found")
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 3e65b1319a..e165d2130a 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -279,8 +279,8 @@ class DatasetDocumentListApi(Resource):
 
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -324,8 +324,8 @@ class DatasetInitApi(Resource):
                 document_data=args,
                 account=current_user
             )
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py
index 771d49045f..c1ccb30dd6 100644
--- a/api/controllers/console/datasets/hit_testing.py
+++ b/api/controllers/console/datasets/hit_testing.py
@@ -95,8 +95,8 @@ class HitTestingApi(Resource):
             return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
         except services.errors.index.IndexNotInitializedError:
             raise DatasetNotInitializedError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py
index a027fe625e..991a228dd5 100644
--- a/api/controllers/console/explore/audio.py
+++ b/api/controllers/console/explore/audio.py
@@ -47,8 +47,8 @@ class ChatAudioApi(InstalledAppResource):
             raise UnsupportedAudioTypeError()
         except ProviderNotSupportSpeechToTextServiceError:
             raise ProviderNotSupportSpeechToTextError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py
index f2a1acd6d5..bc4b88ad15 100644
--- a/api/controllers/console/explore/completion.py
+++ b/api/controllers/console/explore/completion.py
@@ -54,8 +54,8 @@ class CompletionApi(InstalledAppResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -113,8 +113,8 @@ class ChatApi(InstalledAppResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -155,8 +155,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
             except services.errors.app_model_config.AppModelConfigBrokenError:
                 logging.exception("App model config broken.")
                 yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
-            except ProviderTokenNotInitError:
-                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError as ex:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
             except QuotaExceededError:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py
index b5b9547ff7..1232169eab 100644
--- a/api/controllers/console/explore/message.py
+++ b/api/controllers/console/explore/message.py
@@ -107,8 +107,8 @@ class MessageMoreLikeThisApi(InstalledAppResource):
             raise NotFound("Message Not Exists.")
         except MoreLikeThisDisabledError:
             raise AppMoreLikeThisDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -135,8 +135,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
             except MoreLikeThisDisabledError:
                 yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
-            except ProviderTokenNotInitError:
-                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError as ex:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
             except QuotaExceededError:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
@@ -174,8 +174,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
             raise NotFound("Conversation not found")
         except SuggestedQuestionsAfterAnswerDisabledError:
             raise AppSuggestedQuestionsAfterAnswerDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/console/workspace/providers.py b/api/controllers/console/workspace/providers.py
index f2baec29c1..1991b1c6c9 100644
--- a/api/controllers/console/workspace/providers.py
+++ b/api/controllers/console/workspace/providers.py
@@ -3,6 +3,7 @@ import base64
 import json
 import logging
 
+from flask import current_app
 from flask_login import login_required, current_user
 from flask_restful import Resource, reqparse, abort
 from werkzeug.exceptions import Forbidden
@@ -34,7 +35,7 @@ class ProviderListApi(Resource):
         plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
         """
 
-        ProviderService.init_supported_provider(current_user.current_tenant, "cloud")
+        ProviderService.init_supported_provider(current_user.current_tenant)
         providers = Provider.query.filter_by(tenant_id=tenant_id).all()
 
         provider_list = [
@@ -50,7 +51,8 @@ class ProviderListApi(Resource):
                     'quota_used': p.quota_used
                 } if p.provider_type == ProviderType.SYSTEM.value else {}),
                 'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
-                                                                ProviderName(p.provider_name))
+                                                                ProviderName(p.provider_name), only_custom=True)
+                if p.provider_type == ProviderType.CUSTOM.value else None
             }
             for p in providers
         ]
@@ -121,9 +123,10 @@ class ProviderTokenApi(Resource):
                                       is_valid=token_is_valid)
             db.session.add(provider_model)
 
-        if provider_model.is_valid:
+        if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid:
             other_providers = db.session.query(Provider).filter(
                 Provider.tenant_id == tenant.id,
+                Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]),
                 Provider.provider_name != provider,
                 Provider.provider_type == ProviderType.CUSTOM.value
             ).all()
@@ -133,7 +136,7 @@ class ProviderTokenApi(Resource):
 
         db.session.commit()
 
-        if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
+        if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
                         ProviderName.HUGGINGFACEHUB.value]:
             return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
 
@@ -157,7 +160,7 @@ class ProviderTokenValidateApi(Resource):
         args = parser.parse_args()
 
         # todo: remove this when the provider is supported
-        if provider in [ProviderName.ANTHROPIC.value, ProviderName.COHERE.value,
+        if provider in [ProviderName.COHERE.value,
                         ProviderName.HUGGINGFACEHUB.value]:
             return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
 
@@ -203,7 +206,19 @@ class ProviderSystemApi(Resource):
             provider_model.is_valid = args['is_enabled']
             db.session.commit()
         elif not provider_model:
-            ProviderService.create_system_provider(tenant, provider, args['is_enabled'])
+            if provider == ProviderName.OPENAI.value:
+                quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT']
+            elif provider == ProviderName.ANTHROPIC.value:
+                quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT']
+            else:
+                quota_limit = 0
+
+            ProviderService.create_system_provider(
+                tenant,
+                provider,
+                quota_limit,
+                args['is_enabled']
+            )
         else:
             abort(403)
diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py
index aacf1ca2a2..470afc6b42 100644
--- a/api/controllers/service_api/app/audio.py
+++ b/api/controllers/service_api/app/audio.py
@@ -43,8 +43,8 @@ class AudioApi(AppApiResource):
             raise UnsupportedAudioTypeError()
         except ProviderNotSupportSpeechToTextServiceError:
             raise ProviderNotSupportSpeechToTextError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py
index e5eb4153aa..448c408bce 100644
--- a/api/controllers/service_api/app/completion.py
+++ b/api/controllers/service_api/app/completion.py
@@ -54,8 +54,8 @@ class CompletionApi(AppApiResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -115,8 +115,8 @@ class ChatApi(AppApiResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -156,8 +156,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
             except services.errors.app_model_config.AppModelConfigBrokenError:
                 logging.exception("App model config broken.")
                 yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
-            except ProviderTokenNotInitError:
-                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError as ex:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
             except QuotaExceededError:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index 3036882d71..e00de0f9a1 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -85,8 +85,8 @@ class DocumentListApi(DatasetApiResource):
                 dataset_process_rule=dataset.latest_process_rule,
                 created_from='api'
             )
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         document = documents[0]
         if doc_type and doc_metadata:
             metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py
index b07382176c..3e3fe3a28d 100644
--- a/api/controllers/web/audio.py
+++ b/api/controllers/web/audio.py
@@ -45,8 +45,8 @@ class AudioApi(WebApiResource):
             raise UnsupportedAudioTypeError()
         except ProviderNotSupportSpeechToTextServiceError:
             raise ProviderNotSupportSpeechToTextError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py
index 532bcfaa8d..db2f770e5a 100644
--- a/api/controllers/web/completion.py
+++ b/api/controllers/web/completion.py
@@ -52,8 +52,8 @@ class CompletionApi(WebApiResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -109,8 +109,8 @@ class ChatApi(WebApiResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -150,8 +150,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
             except services.errors.app_model_config.AppModelConfigBrokenError:
                 logging.exception("App model config broken.")
                 yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
-            except ProviderTokenNotInitError:
-                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError as ex:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
             except QuotaExceededError:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py
index 0d519eac06..3d978a1099 100644
--- a/api/controllers/web/message.py
+++ b/api/controllers/web/message.py
@@ -101,8 +101,8 @@ class MessageMoreLikeThisApi(WebApiResource):
             raise NotFound("Message Not Exists.")
         except MoreLikeThisDisabledError:
             raise AppMoreLikeThisDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -129,8 +129,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
             except MoreLikeThisDisabledError:
                 yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
-            except ProviderTokenNotInitError:
-                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError as ex:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
             except QuotaExceededError:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
@@ -167,8 +167,8 @@ class MessageSuggestedQuestionApi(WebApiResource):
             raise NotFound("Conversation not found")
         except SuggestedQuestionsAfterAnswerDisabledError:
             raise AppSuggestedQuestionsAfterAnswerDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
diff --git a/api/core/__init__.py b/api/core/__init__.py
index 0b26044aa5..2dc9a9e869 100644
--- a/api/core/__init__.py
+++ b/api/core/__init__.py
@@ -13,8 +13,13 @@ class HostedOpenAICredential(BaseModel):
     api_key: str
 
 
+class HostedAnthropicCredential(BaseModel):
+    api_key: str
+
+
 class HostedLLMCredentials(BaseModel):
     openai: Optional[HostedOpenAICredential] = None
+    anthropic: Optional[HostedAnthropicCredential] = None
 
 
 hosted_llm_credentials = HostedLLMCredentials()
@@ -26,3 +31,6 @@ def init_app(app: Flask):
 
     if app.config.get("OPENAI_API_KEY"):
         hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
+
+    if app.config.get("ANTHROPIC_API_KEY"):
+        hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))
diff --git a/api/core/callback_handler/llm_callback_handler.py b/api/core/callback_handler/llm_callback_handler.py
index c9db70fa76..6b3daf76e7 100644
--- a/api/core/callback_handler/llm_callback_handler.py
+++ b/api/core/callback_handler/llm_callback_handler.py
@@ -48,7 +48,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
                 })
 
         self.llm_message.prompt = real_prompts
-        self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
+        self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])
 
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
diff --git a/api/core/completion.py b/api/core/completion.py
index 38a81f2807..bcb75c2e07 100644
--- a/api/core/completion.py
+++ b/api/core/completion.py
@@ -118,6 +118,7 @@ class Completion:
         prompt, stop_words = cls.get_main_llm_prompt(
             mode=mode,
             llm=final_llm,
+            model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
             query=query,
             inputs=inputs,
@@ -129,6 +130,7 @@ class Completion:
 
         cls.recale_llm_max_tokens(
             final_llm=final_llm,
+            model=app_model_config.model_dict,
             prompt=prompt,
             mode=mode
         )
@@ -138,7 +140,8 @@ class Completion:
         return response
 
     @classmethod
-    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
+    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
+                            pre_prompt: str, query: str, inputs: dict,
                             chain_output: Optional[str],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
             Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
@@ -151,10 +154,11 @@ class Completion:
 
         if mode == 'completion':
             prompt_template = JinjaPromptTemplate.from_template(
-                template=("""Use the following CONTEXT as your learned knowledge:
-[CONTEXT]
+                template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
+
+<context>
 {{context}}
-[END CONTEXT]
+</context>
 
 When answer to user:
 - If you don't know, just say that you don't know.
@@ -204,10 +208,11 @@ And answer according to the language of the user's question.
 
         if chain_output:
             human_inputs['context'] = chain_output
-            human_message_prompt += """Use the following CONTEXT as your learned knowledge.
-[CONTEXT]
+            human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
+
+<context>
 {{context}}
-[END CONTEXT]
+</context>
 
 When answer to user:
 - If you don't know, just say that you don't know.
@@ -219,7 +224,7 @@ And answer according to the language of the user's question.
         if pre_prompt:
             human_message_prompt += pre_prompt
 
-        query_prompt = "\nHuman: {{query}}\nAI: "
+        query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
 
         if memory:
             # append chat histories
@@ -228,9 +233,11 @@ And answer according to the language of the user's question.
                 inputs=human_inputs
             )
 
-            curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
-            rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
-                          - memory.llm.max_tokens - curr_message_tokens
+            curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
+            model_name = model['name']
+            max_tokens = model.get("completion_params").get('max_tokens')
+            rest_tokens = llm_constant.max_context_token_length[model_name] \
+                          - max_tokens - curr_message_tokens
             rest_tokens = max(rest_tokens, 0)
             histories = cls.get_history_messages_from_memory(memory, rest_tokens)
@@ -241,7 +248,10 @@ And answer according to the language of the user's question.
             # if histories_param not in human_inputs:
             #     human_inputs[histories_param] = '{{' + histories_param + '}}'
 
-            human_message_prompt += "\n\n" + histories
+            human_message_prompt += "\n\n" if human_message_prompt else ""
+            human_message_prompt += "Here is the chat histories between human and assistant, " \
+                                    "inside <histories></histories> XML tags.\n\n<histories>\n"
+            human_message_prompt += histories + "\n</histories>"
 
         human_message_prompt += query_prompt
 
@@ -307,13 +317,15 @@ And answer according to the language of the user's question.
             model=app_model_config.model_dict
         )
 
-        model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
-        max_tokens = llm.max_tokens
+        model_name = app_model_config.model_dict.get("name")
+        model_limited_tokens = llm_constant.max_context_token_length[model_name]
+        max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')
 
         # get prompt without memory and context
         prompt, _ = cls.get_main_llm_prompt(
             mode=mode,
             llm=llm,
+            model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
             query=query,
             inputs=inputs,
@@ -332,16 +344,17 @@ And answer according to the language of the user's question.
         return rest_tokens
 
     @classmethod
-    def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
+    def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
                               prompt: Union[str, List[BaseMessage]], mode: str):
         # recalc max_tokens if sum(prompt_token +  max_tokens) over model token limit
-        model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name]
-        max_tokens = final_llm.max_tokens
+        model_name = model.get("name")
+        model_limited_tokens = llm_constant.max_context_token_length[model_name]
+        max_tokens = model.get("completion_params").get('max_tokens')
 
         if mode == 'completion' and isinstance(final_llm, BaseLLM):
             prompt_tokens = final_llm.get_num_tokens(prompt)
         else:
-            prompt_tokens = final_llm.get_messages_tokens(prompt)
+            prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)
 
         if prompt_tokens + max_tokens > model_limited_tokens:
             max_tokens = max(model_limited_tokens - prompt_tokens, 16)
@@ -350,9 +363,10 @@ And answer according to the language of the user's question.
     @classmethod
     def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
                                 app_model_config: AppModelConfig, user: Account, streaming: bool):
-        llm: StreamableOpenAI = LLMBuilder.to_llm(
+
+        llm = LLMBuilder.to_llm_from_model(
             tenant_id=app.tenant_id,
-            model_name='gpt-3.5-turbo',
+            model=app_model_config.model_dict,
             streaming=streaming
         )
 
@@ -360,6 +374,7 @@ And answer according to the language of the user's question.
         original_prompt, _ = cls.get_main_llm_prompt(
             mode="completion",
             llm=llm,
+            model=app_model_config.model_dict,
             pre_prompt=pre_prompt,
             query=message.query,
             inputs=message.inputs,
@@ -390,6 +405,7 @@ And answer according to the language of the user's question.
 
         cls.recale_llm_max_tokens(
             final_llm=llm,
+            model=app_model_config.model_dict,
             prompt=prompt,
             mode='completion'
         )
diff --git a/api/core/constant/llm_constant.py b/api/core/constant/llm_constant.py
index 397a3d4c8f..3a02abc90e 100644
--- a/api/core/constant/llm_constant.py
+++ b/api/core/constant/llm_constant.py
@@ -1,6 +1,8 @@
 from _decimal import Decimal
 
 models = {
+    'claude-instant-1': 'anthropic',  # 100,000 tokens
+    'claude-2': 'anthropic',  # 100,000 tokens
     'gpt-4': 'openai',  # 8,192 tokens
     'gpt-4-32k': 'openai',  # 32,768 tokens
     'gpt-3.5-turbo': 'openai',  # 4,096 tokens
@@ -10,10 +12,13 @@ models = {
     'text-curie-001': 'openai',  # 2,049 tokens
     'text-babbage-001': 'openai',  # 2,049 tokens
     'text-ada-001': 'openai',  # 2,049 tokens
-    'text-embedding-ada-002': 'openai'  # 8191 tokens, 1536 dimensions
+    'text-embedding-ada-002': 'openai',  # 8191 tokens, 1536 dimensions
+    'whisper-1': 'openai'
 }
 
 max_context_token_length = {
+    'claude-instant-1': 100000,
+    'claude-2': 100000,
     'gpt-4': 8192,
     'gpt-4-32k': 32768,
     'gpt-3.5-turbo': 4096,
@@ -23,17 +28,21 @@ max_context_token_length = {
     'text-curie-001': 2049,
     'text-babbage-001': 2049,
     'text-ada-001': 2049,
-    'text-embedding-ada-002': 8191
+    'text-embedding-ada-002': 8191,
 }
 
 models_by_mode = {
     'chat': [
+        'claude-instant-1',  # 100,000 tokens
+        'claude-2',  # 100,000 tokens
         'gpt-4',  # 8,192 tokens
         'gpt-4-32k',  # 32,768 tokens
         'gpt-3.5-turbo',  # 4,096 tokens
         'gpt-3.5-turbo-16k',  # 16,384 tokens
     ],
     'completion': [
+        'claude-instant-1',  # 100,000 tokens
+        'claude-2',  # 100,000 tokens
         'gpt-4',  # 8,192 tokens
         'gpt-4-32k',  # 32,768 tokens
         'gpt-3.5-turbo',  # 4,096 tokens
@@ -52,6 +61,14 @@ models_by_mode = {
 model_currency = 'USD'
 
 model_prices = {
+    'claude-instant-1': {
+        'prompt': Decimal('0.00163'),
+        'completion': Decimal('0.00551'),
+    },
+    'claude-2': {
+        'prompt': Decimal('0.01102'),
+        'completion': Decimal('0.03268'),
+    },
     'gpt-4': {
         'prompt': Decimal('0.03'),
         'completion': Decimal('0.06'),
diff --git a/api/core/conversation_message_task.py b/api/core/conversation_message_task.py
index d0e10fa119..3d8a50363f 100644
--- a/api/core/conversation_message_task.py
+++ b/api/core/conversation_message_task.py
@@ -56,7 +56,7 @@ class ConversationMessageTask:
         )
 
     def init(self):
-        provider_name = LLMBuilder.get_default_provider(self.app.tenant_id)
+        provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
         self.model_dict['provider'] = provider_name
 
         override_model_configs = None
@@ -89,7 +89,7 @@ class ConversationMessageTask:
             system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
             system_instruction = system_message.content
             llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
-            system_instruction_tokens = llm.get_messages_tokens([system_message])
+            system_instruction_tokens = llm.get_num_tokens_from_messages([system_message])
 
         if not self.conversation:
             self.is_new_conversation = True
@@ -185,6 +185,7 @@ class ConversationMessageTask:
         if provider and provider.provider_type == ProviderType.SYSTEM.value:
             db.session.query(Provider).filter(
                 Provider.tenant_id == self.app.tenant_id,
+                Provider.provider_name == provider.provider_name,
                 Provider.quota_limit > Provider.quota_used
             ).update({'quota_used': Provider.quota_used + 1})
diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py
index 4030eb158c..174f259456 100644
--- a/api/core/embedding/cached_embedding.py
+++ b/api/core/embedding/cached_embedding.py
@@ -4,6 +4,7 @@ from typing import List
 from langchain.embeddings.base import Embeddings
 from sqlalchemy.exc import IntegrityError
 
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
 from extensions.ext_database import db
 from libs import helper
 from models.dataset import Embedding
@@ -49,6 +50,7 @@ class CacheEmbedding(Embeddings):
             text_embeddings.extend(embedding_results)
         return text_embeddings
 
+    @handle_openai_exceptions
    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""
        # use doc embedding cache or store if not exists
diff --git a/api/core/generator/llm_generator.py b/api/core/generator/llm_generator.py
index 87540272da..9c6dbdfc8f 100644
--- a/api/core/generator/llm_generator.py
+++ b/api/core/generator/llm_generator.py
@@ -23,6 +23,10 @@ class LLMGenerator:
     @classmethod
     def generate_conversation_name(cls, tenant_id: str, query, answer):
         prompt = CONVERSATION_TITLE_PROMPT
+
+        if len(query) > 2000:
+            query = query[:300] + "...[TRUNCATED]..." + query[-300:]
+
         prompt = prompt.format(query=query)
         llm: StreamableOpenAI = LLMBuilder.to_llm(
             tenant_id=tenant_id,
@@ -52,7 +56,17 @@ class LLMGenerator:
             if not message.answer:
                 continue
 
-            message_qa_text = "Human:" + message.query + "\nAI:" + message.answer + "\n"
+            if len(message.query) > 2000:
+                query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
+            else:
+                query = message.query
+
+            if len(message.answer) > 2000:
+                answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
+            else:
+                answer = message.answer
+
+            message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
 
             if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
                 context += message_qa_text
diff --git a/api/core/index/index.py b/api/core/index/index.py
index 617b763982..bd32b3c493 100644
--- a/api/core/index/index.py
+++ b/api/core/index/index.py
@@ -17,7 +17,7 @@ class IndexBuilder:
 
         model_credentials = LLMBuilder.get_model_credentials(
             tenant_id=dataset.tenant_id,
-            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
+            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
             model_name='text-embedding-ada-002'
         )
diff --git a/api/core/llm/error.py b/api/core/llm/error.py
index 883d282e8a..9bba8401f2 100644
--- a/api/core/llm/error.py
+++ b/api/core/llm/error.py
@@ -40,6 +40,9 @@ class ProviderTokenNotInitError(Exception):
     """
     description = "Provider Token Not Init"
 
+    def __init__(self, *args, **kwargs):
+        self.description = args[0] if args else self.description
+
 
 class QuotaExceededError(Exception):
     """
diff --git a/api/core/llm/llm_builder.py b/api/core/llm/llm_builder.py
index c2deda5351..50cd1b620e 100644
--- a/api/core/llm/llm_builder.py
+++ b/api/core/llm/llm_builder.py
@@ -8,9 +8,10 @@ from core.llm.provider.base import BaseProvider
 from core.llm.provider.llm_provider_service import LLMProviderService
 from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
 from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
+from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
 from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
 from core.llm.streamable_open_ai import StreamableOpenAI
-from models.provider import ProviderType
+from models.provider import ProviderType, ProviderName
 
 
 class LLMBuilder:
@@ -32,43 +33,43 @@ class LLMBuilder:
 
     @classmethod
     def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
-        provider = cls.get_default_provider(tenant_id)
+        provider = cls.get_default_provider(tenant_id, model_name)
 
         model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
 
+        llm_cls = None
         mode = cls.get_mode_by_model(model_name)
         if mode == 'chat':
-            if provider == 'openai':
+            if provider == ProviderName.OPENAI.value:
                 llm_cls = StreamableChatOpenAI
-            else:
+            elif provider == ProviderName.AZURE_OPENAI.value:
                 llm_cls = StreamableAzureChatOpenAI
+            elif provider == ProviderName.ANTHROPIC.value:
+                llm_cls = StreamableChatAnthropic
         elif mode == 'completion':
-            if provider == 'openai':
+            if provider == ProviderName.OPENAI.value:
                 llm_cls = StreamableOpenAI
-            else:
+            elif provider == ProviderName.AZURE_OPENAI.value:
                 llm_cls = StreamableAzureOpenAI
-        else:
+
+        if not llm_cls:
             raise ValueError(f"model name {model_name} is not supported.")
-
         model_kwargs = {
+            'model_name': model_name,
+            'temperature': kwargs.get('temperature', 0),
+            'max_tokens': kwargs.get('max_tokens', 256),
             'top_p': kwargs.get('top_p', 1),
             'frequency_penalty': kwargs.get('frequency_penalty', 0),
             'presence_penalty': kwargs.get('presence_penalty', 0),
+            'callbacks': kwargs.get('callbacks', None),
+            'streaming': kwargs.get('streaming', False),
         }
 
-        model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}
+        model_kwargs.update(model_credentials)
+        model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)
 
-        return llm_cls(
-            model_name=model_name,
-            temperature=kwargs.get('temperature', 0),
-            max_tokens=kwargs.get('max_tokens', 256),
-            **model_extras_kwargs,
-            callbacks=kwargs.get('callbacks', None),
-            streaming=kwargs.get('streaming', False),
-            # request_timeout=None
-            **model_credentials
-        )
+        return llm_cls(**model_kwargs)
 
     @classmethod
     def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
@@ -118,14 +119,29 @@ class LLMBuilder:
         return provider_service.get_credentials(model_name)
 
     @classmethod
-    def get_default_provider(cls, tenant_id: str) -> str:
-        provider = BaseProvider.get_valid_provider(tenant_id)
-        if not provider:
-            raise ProviderTokenNotInitError()
+    def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
+        provider_name = llm_constant.models[model_name]
 
-        if provider.provider_type == ProviderType.SYSTEM.value:
-            provider_name = 'openai'
-        else:
-            provider_name = provider.provider_name
+        if provider_name == 'openai':
+            # get the default provider (openai / azure_openai) for the tenant
+            openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
+            azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)
+
+            provider = None
+            if openai_provider:
+                provider = openai_provider
+            elif azure_openai_provider:
+                provider = azure_openai_provider
+
+            if not provider:
+                raise ProviderTokenNotInitError(
+                    f"No valid {provider_name} model provider credentials found. "
+                    f"Please go to Settings -> Model Provider to complete your provider credentials."
+                )
+
+            if provider.provider_type == ProviderType.SYSTEM.value:
+                provider_name = 'openai'
+            else:
+                provider_name = provider.provider_name
 
         return provider_name
diff --git a/api/core/llm/provider/anthropic_provider.py b/api/core/llm/provider/anthropic_provider.py
index 4c7756305e..d6165d0329 100644
--- a/api/core/llm/provider/anthropic_provider.py
+++ b/api/core/llm/provider/anthropic_provider.py
@@ -1,23 +1,138 @@
-from typing import Optional
+import json
+import logging
+from typing import Optional, Union
 
+import anthropic
+from langchain.chat_models import ChatAnthropic
+from langchain.schema import HumanMessage
+
+from core import hosted_llm_credentials
+from core.llm.error import ProviderTokenNotInitError
 from core.llm.provider.base import BaseProvider
-from models.provider import ProviderName
+from core.llm.provider.errors import ValidateFailedError
+from models.provider import ProviderName, ProviderType
 
 
 class AnthropicProvider(BaseProvider):
     def get_models(self, model_id: Optional[str] = None) -> list[dict]:
-        credentials = self.get_credentials(model_id)
-        # todo
-        return []
+        return [
+            {
+                'id': 'claude-instant-1',
+                'name': 'claude-instant-1',
+            },
+            {
+                'id': 'claude-2',
+                'name': 'claude-2',
+            },
+        ]
 
     def get_credentials(self, model_id: Optional[str] = None) -> dict:
-        """
-        Returns the API credentials for Azure OpenAI as a dictionary, for the given tenant_id.
-        The dictionary contains keys: azure_api_type, azure_api_version, azure_api_base, and azure_api_key.
-        """
-        return {
-            'anthropic_api_key': self.get_provider_api_key(model_id=model_id)
-        }
+        return self.get_provider_api_key(model_id=model_id)
 
     def get_provider_name(self):
-        return ProviderName.ANTHROPIC
\ No newline at end of file
+        return ProviderName.ANTHROPIC
+
+    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
+        """
+        Returns the provider configs.
+        """
+        try:
+            config = self.get_provider_api_key(only_custom=only_custom)
+        except:
+            config = {
+                'anthropic_api_key': ''
+            }
+
+        if obfuscated:
+            if not config.get('anthropic_api_key'):
+                config = {
+                    'anthropic_api_key': ''
+                }
+
+            config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key'))
+            return config
+
+        return config
+
+    def get_encrypted_token(self, config: Union[dict | str]):
+        """
+        Returns the encrypted token.
+        """
+        return json.dumps({
+            'anthropic_api_key': self.encrypt_token(config['anthropic_api_key'])
+        })
+
+    def get_decrypted_token(self, token: str):
+        """
+        Returns the decrypted token.
+        """
+        config = json.loads(token)
+        config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key'])
+        return config
+
+    def get_token_type(self):
+        return dict
+
+    def config_validate(self, config: Union[dict | str]):
+        """
+        Validates the given config.
+ """ + # check OpenAI / Azure OpenAI credential is valid + openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value) + azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value) + + provider = None + if openai_provider: + provider = openai_provider + elif azure_openai_provider: + provider = azure_openai_provider + + if not provider: + raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.") + + if provider.provider_type == ProviderType.SYSTEM.value: + quota_used = provider.quota_used if provider.quota_used is not None else 0 + quota_limit = provider.quota_limit if provider.quota_limit is not None else 0 + if quota_used >= quota_limit: + raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, " + f"please configure OpenAI or Azure OpenAI provider first.") + + try: + if not isinstance(config, dict): + raise ValueError('Config must be a object.') + + if 'anthropic_api_key' not in config: + raise ValueError('anthropic_api_key must be provided.') + + chat_llm = ChatAnthropic( + model='claude-instant-1', + anthropic_api_key=config['anthropic_api_key'], + max_tokens_to_sample=10, + temperature=0, + default_request_timeout=60 + ) + + messages = [ + HumanMessage( + content="ping" + ) + ] + + chat_llm(messages) + except anthropic.APIConnectionError as ex: + raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}") + except (anthropic.APIStatusError, anthropic.RateLimitError) as ex: + raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - " + f"{ex.body['error']['type']}: {ex.body['error']['message']}") + except Exception as ex: + logging.exception('Anthropic config validation failed') + raise ex + + def get_hosted_credentials(self) -> Union[str | dict]: + if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key: + raise ProviderTokenNotInitError( + f"No valid {self.get_provider_name().value} model provider credentials found. " + f"Please go to Settings -> Model Provider to complete your provider credentials." + ) + + return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key} diff --git a/api/core/llm/provider/azure_provider.py b/api/core/llm/provider/azure_provider.py index 54b7af6462..706adc7be7 100644 --- a/api/core/llm/provider/azure_provider.py +++ b/api/core/llm/provider/azure_provider.py @@ -52,12 +52,12 @@ class AzureProvider(BaseProvider): def get_provider_name(self): return ProviderName.AZURE_OPENAI - def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]: + def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]: """ Returns the provider configs. 
""" try: - config = self.get_provider_api_key() + config = self.get_provider_api_key(only_custom=only_custom) except: config = { 'openai_api_type': 'azure', @@ -81,7 +81,6 @@ class AzureProvider(BaseProvider): return config def get_token_type(self): - # TODO: change to dict when implemented return dict def config_validate(self, config: Union[dict | str]): diff --git a/api/core/llm/provider/base.py b/api/core/llm/provider/base.py index 71bb32dca6..c3ff5cf237 100644 --- a/api/core/llm/provider/base.py +++ b/api/core/llm/provider/base.py @@ -2,7 +2,7 @@ import base64 from abc import ABC, abstractmethod from typing import Optional, Union -from core import hosted_llm_credentials +from core.constant import llm_constant from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError from extensions.ext_database import db from libs import rsa @@ -14,15 +14,18 @@ class BaseProvider(ABC): def __init__(self, tenant_id: str): self.tenant_id = tenant_id - def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str | dict]: + def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]: """ Returns the decrypted API key for the given tenant_id and provider_name. If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError. If the provider is not found or not valid, raises a ProviderTokenNotInitError. """ - provider = self.get_provider(prefer_custom) + provider = self.get_provider(only_custom) if not provider: - raise ProviderTokenNotInitError() + raise ProviderTokenNotInitError( + f"No valid {llm_constant.models[model_id]} model provider credentials found. " + f"Please go to Settings -> Model Provider to complete your provider credentials." + ) if provider.provider_type == ProviderType.SYSTEM.value: quota_used = provider.quota_used if provider.quota_used is not None else 0 @@ -38,18 +41,19 @@ class BaseProvider(ABC): else: return self.get_decrypted_token(provider.encrypted_config) - def get_provider(self, prefer_custom: bool) -> Optional[Provider]: + def get_provider(self, only_custom: bool = False) -> Optional[Provider]: """ Returns the Provider instance for the given tenant_id and provider_name. If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag. """ - return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom) + return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom) @classmethod - def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]: + def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[ + Provider]: """ Returns the Provider instance for the given tenant_id and provider_name. - If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag. + If both CUSTOM and System providers exist. 
""" query = db.session.query(Provider).filter( Provider.tenant_id == tenant_id @@ -58,39 +62,31 @@ class BaseProvider(ABC): if provider_name: query = query.filter(Provider.provider_name == provider_name) - providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all() + if only_custom: + query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value) - custom_provider = None - system_provider = None + providers = query.order_by(Provider.provider_type.asc()).all() for provider in providers: if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config: - custom_provider = provider + return provider elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid: - system_provider = provider + return provider - if custom_provider: - return custom_provider - elif system_provider: - return system_provider - else: - return None + return None - def get_hosted_credentials(self) -> str: - if self.get_provider_name() != ProviderName.OPENAI: - raise ProviderTokenNotInitError() + def get_hosted_credentials(self) -> Union[str | dict]: + raise ProviderTokenNotInitError( + f"No valid {self.get_provider_name().value} model provider credentials found. " + f"Please go to Settings -> Model Provider to complete your provider credentials." + ) - if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key: - raise ProviderTokenNotInitError() - - return hosted_llm_credentials.openai.api_key - - def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]: + def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]: """ Returns the provider configs. """ try: - config = self.get_provider_api_key() + config = self.get_provider_api_key(only_custom=only_custom) except: config = '' diff --git a/api/core/llm/provider/llm_provider_service.py b/api/core/llm/provider/llm_provider_service.py index ca4f8bec6d..a520e3d6bb 100644 --- a/api/core/llm/provider/llm_provider_service.py +++ b/api/core/llm/provider/llm_provider_service.py @@ -31,11 +31,11 @@ class LLMProviderService: def get_credentials(self, model_id: Optional[str] = None) -> dict: return self.provider.get_credentials(model_id) - def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]: - return self.provider.get_provider_configs(obfuscated) + def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]: + return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom) - def get_provider_db_record(self, prefer_custom: bool = False) -> Optional[Provider]: - return self.provider.get_provider(prefer_custom) + def get_provider_db_record(self) -> Optional[Provider]: + return self.provider.get_provider() def config_validate(self, config: Union[dict | str]): """ diff --git a/api/core/llm/provider/openai_provider.py b/api/core/llm/provider/openai_provider.py index 8257ad3aab..b24e98e5d1 100644 --- a/api/core/llm/provider/openai_provider.py +++ b/api/core/llm/provider/openai_provider.py @@ -4,6 +4,8 @@ from typing import Optional, Union import openai from openai.error import AuthenticationError, OpenAIError +from core import hosted_llm_credentials +from core.llm.error import ProviderTokenNotInitError from core.llm.moderation import Moderation from core.llm.provider.base import BaseProvider from core.llm.provider.errors import ValidateFailedError @@ -42,3 +44,12 @@ class 
OpenAIProvider(BaseProvider): except Exception as ex: logging.exception('OpenAI config validation failed') raise ex + + def get_hosted_credentials(self) -> Union[str | dict]: + if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key: + raise ProviderTokenNotInitError( + f"No valid {self.get_provider_name().value} model provider credentials found. " + f"Please go to Settings -> Model Provider to complete your provider credentials." + ) + + return hosted_llm_credentials.openai.api_key diff --git a/api/core/llm/streamable_azure_chat_open_ai.py b/api/core/llm/streamable_azure_chat_open_ai.py index 4d1d5be0b3..5bff42fedb 100644 --- a/api/core/llm/streamable_azure_chat_open_ai.py +++ b/api/core/llm/streamable_azure_chat_open_ai.py @@ -1,11 +1,11 @@ -from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks -from langchain.schema import BaseMessage, ChatResult, LLMResult +from langchain.callbacks.manager import Callbacks +from langchain.schema import BaseMessage, LLMResult from langchain.chat_models import AzureChatOpenAI from typing import Optional, List, Dict, Any from pydantic import root_validator -from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async +from core.llm.wrappers.openai_wrapper import handle_openai_exceptions class StreamableAzureChatOpenAI(AzureChatOpenAI): @@ -46,30 +46,7 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI): "organization": self.openai_organization if self.openai_organization else None, } - def get_messages_tokens(self, messages: List[BaseMessage]) -> int: - """Get the number of tokens in a list of messages. - - Args: - messages: The messages to count the tokens of. - - Returns: - The number of tokens in the messages. 
- """ - tokens_per_message = 5 - tokens_per_request = 3 - - message_tokens = tokens_per_request - message_strs = '' - for message in messages: - message_strs += message.content - message_tokens += tokens_per_message - - # calc once - message_tokens += self.get_num_tokens(message_strs) - - return message_tokens - - @handle_llm_exceptions + @handle_openai_exceptions def generate( self, messages: List[List[BaseMessage]], @@ -79,12 +56,18 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI): ) -> LLMResult: return super().generate(messages, stop, callbacks, **kwargs) - @handle_llm_exceptions_async - async def agenerate( - self, - messages: List[List[BaseMessage]], - stop: Optional[List[str]] = None, - callbacks: Callbacks = None, - **kwargs: Any, - ) -> LLMResult: - return await super().agenerate(messages, stop, callbacks, **kwargs) + @classmethod + def get_kwargs_from_model_params(cls, params: dict): + model_kwargs = { + 'top_p': params.get('top_p', 1), + 'frequency_penalty': params.get('frequency_penalty', 0), + 'presence_penalty': params.get('presence_penalty', 0), + } + + del params['top_p'] + del params['frequency_penalty'] + del params['presence_penalty'] + + params['model_kwargs'] = model_kwargs + + return params diff --git a/api/core/llm/streamable_azure_open_ai.py b/api/core/llm/streamable_azure_open_ai.py index ac2258bb61..108f723aa4 100644 --- a/api/core/llm/streamable_azure_open_ai.py +++ b/api/core/llm/streamable_azure_open_ai.py @@ -5,7 +5,7 @@ from typing import Optional, List, Dict, Mapping, Any from pydantic import root_validator -from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async +from core.llm.wrappers.openai_wrapper import handle_openai_exceptions class StreamableAzureOpenAI(AzureOpenAI): @@ -50,7 +50,7 @@ class StreamableAzureOpenAI(AzureOpenAI): "organization": self.openai_organization if self.openai_organization else None, }} - @handle_llm_exceptions + @handle_openai_exceptions def generate( self, prompts: List[str], @@ -60,12 +60,6 @@ class StreamableAzureOpenAI(AzureOpenAI): ) -> LLMResult: return super().generate(prompts, stop, callbacks, **kwargs) - @handle_llm_exceptions_async - async def agenerate( - self, - prompts: List[str], - stop: Optional[List[str]] = None, - callbacks: Callbacks = None, - **kwargs: Any, - ) -> LLMResult: - return await super().agenerate(prompts, stop, callbacks, **kwargs) + @classmethod + def get_kwargs_from_model_params(cls, params: dict): + return params diff --git a/api/core/llm/streamable_chat_anthropic.py b/api/core/llm/streamable_chat_anthropic.py new file mode 100644 index 0000000000..de268800ca --- /dev/null +++ b/api/core/llm/streamable_chat_anthropic.py @@ -0,0 +1,39 @@ +from typing import List, Optional, Any, Dict + +from langchain.callbacks.manager import Callbacks +from langchain.chat_models import ChatAnthropic +from langchain.schema import BaseMessage, LLMResult + +from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions + + +class StreamableChatAnthropic(ChatAnthropic): + """ + Wrapper around Anthropic's large language model. 
+ """ + + @handle_anthropic_exceptions + def generate( + self, + messages: List[List[BaseMessage]], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + *, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> LLMResult: + return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs) + + @classmethod + def get_kwargs_from_model_params(cls, params: dict): + params['model'] = params.get('model_name') + del params['model_name'] + + params['max_tokens_to_sample'] = params.get('max_tokens') + del params['max_tokens'] + + del params['frequency_penalty'] + del params['presence_penalty'] + + return params diff --git a/api/core/llm/streamable_chat_open_ai.py b/api/core/llm/streamable_chat_open_ai.py index a1fad702ab..ba4470b846 100644 --- a/api/core/llm/streamable_chat_open_ai.py +++ b/api/core/llm/streamable_chat_open_ai.py @@ -7,7 +7,7 @@ from typing import Optional, List, Dict, Any from pydantic import root_validator -from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async +from core.llm.wrappers.openai_wrapper import handle_openai_exceptions class StreamableChatOpenAI(ChatOpenAI): @@ -48,30 +48,7 @@ class StreamableChatOpenAI(ChatOpenAI): "organization": self.openai_organization if self.openai_organization else None, } - def get_messages_tokens(self, messages: List[BaseMessage]) -> int: - """Get the number of tokens in a list of messages. - - Args: - messages: The messages to count the tokens of. - - Returns: - The number of tokens in the messages. - """ - tokens_per_message = 5 - tokens_per_request = 3 - - message_tokens = tokens_per_request - message_strs = '' - for message in messages: - message_strs += message.content - message_tokens += tokens_per_message - - # calc once - message_tokens += self.get_num_tokens(message_strs) - - return message_tokens - - @handle_llm_exceptions + @handle_openai_exceptions def generate( self, messages: List[List[BaseMessage]], @@ -81,12 +58,18 @@ class StreamableChatOpenAI(ChatOpenAI): ) -> LLMResult: return super().generate(messages, stop, callbacks, **kwargs) - @handle_llm_exceptions_async - async def agenerate( - self, - messages: List[List[BaseMessage]], - stop: Optional[List[str]] = None, - callbacks: Callbacks = None, - **kwargs: Any, - ) -> LLMResult: - return await super().agenerate(messages, stop, callbacks, **kwargs) + @classmethod + def get_kwargs_from_model_params(cls, params: dict): + model_kwargs = { + 'top_p': params.get('top_p', 1), + 'frequency_penalty': params.get('frequency_penalty', 0), + 'presence_penalty': params.get('presence_penalty', 0), + } + + del params['top_p'] + del params['frequency_penalty'] + del params['presence_penalty'] + + params['model_kwargs'] = model_kwargs + + return params diff --git a/api/core/llm/streamable_open_ai.py b/api/core/llm/streamable_open_ai.py index a69e461d0d..37d2e11e7c 100644 --- a/api/core/llm/streamable_open_ai.py +++ b/api/core/llm/streamable_open_ai.py @@ -6,7 +6,7 @@ from typing import Optional, List, Dict, Any, Mapping from langchain import OpenAI from pydantic import root_validator -from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async +from core.llm.wrappers.openai_wrapper import handle_openai_exceptions class StreamableOpenAI(OpenAI): @@ -49,7 +49,7 @@ class StreamableOpenAI(OpenAI): "organization": self.openai_organization if self.openai_organization else None, }} - @handle_llm_exceptions + @handle_openai_exceptions def 
generate( self, prompts: List[str], @@ -59,12 +59,6 @@ class StreamableOpenAI(OpenAI): ) -> LLMResult: return super().generate(prompts, stop, callbacks, **kwargs) - @handle_llm_exceptions_async - async def agenerate( - self, - prompts: List[str], - stop: Optional[List[str]] = None, - callbacks: Callbacks = None, - **kwargs: Any, - ) -> LLMResult: - return await super().agenerate(prompts, stop, callbacks, **kwargs) + @classmethod + def get_kwargs_from_model_params(cls, params: dict): + return params diff --git a/api/core/llm/whisper.py b/api/core/llm/whisper.py index 3ad993be17..7f3bf3d794 100644 --- a/api/core/llm/whisper.py +++ b/api/core/llm/whisper.py @@ -1,6 +1,7 @@ import openai + +from core.llm.wrappers.openai_wrapper import handle_openai_exceptions from models.provider import ProviderName -from core.llm.error_handle_wraps import handle_llm_exceptions from core.llm.provider.base import BaseProvider @@ -13,7 +14,7 @@ class Whisper: self.client = openai.Audio self.credentials = provider.get_credentials() - @handle_llm_exceptions + @handle_openai_exceptions def transcribe(self, file): return self.client.transcribe( model='whisper-1', diff --git a/api/core/llm/wrappers/anthropic_wrapper.py b/api/core/llm/wrappers/anthropic_wrapper.py new file mode 100644 index 0000000000..7fddc277d2 --- /dev/null +++ b/api/core/llm/wrappers/anthropic_wrapper.py @@ -0,0 +1,27 @@ +import logging +from functools import wraps + +import anthropic + +from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \ + LLMBadRequestError + + +def handle_anthropic_exceptions(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except anthropic.APIConnectionError as e: + logging.exception("Failed to connect to Anthropic API.") + raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}") + except anthropic.RateLimitError: + raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.") + except anthropic.AuthenticationError as e: + raise LLMAuthorizationError(f"Anthropic: {e.message}") + except anthropic.BadRequestError as e: + raise LLMBadRequestError(f"Anthropic: {e.message}") + except anthropic.APIStatusError as e: + raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}") + + return wrapper diff --git a/api/core/llm/error_handle_wraps.py b/api/core/llm/wrappers/openai_wrapper.py similarity index 52% rename from api/core/llm/error_handle_wraps.py rename to api/core/llm/wrappers/openai_wrapper.py index 2bddebc26a..7f96e75edf 100644 --- a/api/core/llm/error_handle_wraps.py +++ b/api/core/llm/wrappers/openai_wrapper.py @@ -7,7 +7,7 @@ from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRat LLMBadRequestError -def handle_llm_exceptions(func): +def handle_openai_exceptions(func): @wraps(func) def wrapper(*args, **kwargs): try: @@ -29,27 +29,3 @@ def handle_llm_exceptions(func): raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e)) return wrapper - - -def handle_llm_exceptions_async(func): - @wraps(func) - async def wrapper(*args, **kwargs): - try: - return await func(*args, **kwargs) - except openai.error.InvalidRequestError as e: - logging.exception("Invalid request to OpenAI API.") - raise LLMBadRequestError(str(e)) - except openai.error.APIConnectionError as e: - logging.exception("Failed to connect to OpenAI API.") - raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e)) - except 
(openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e: - logging.exception("OpenAI service unavailable.") - raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e)) - except openai.error.RateLimitError as e: - raise LLMRateLimitError(str(e)) - except openai.error.AuthenticationError as e: - raise LLMAuthorizationError(str(e)) - except openai.error.OpenAIError as e: - raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e)) - - return wrapper diff --git a/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py b/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py index 16f982c592..d96187ece0 100644 --- a/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py +++ b/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py @@ -1,7 +1,7 @@ from typing import Any, List, Dict, Union from langchain.memory.chat_memory import BaseChatMemory -from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage +from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel from core.llm.streamable_chat_open_ai import StreamableChatOpenAI from core.llm.streamable_open_ai import StreamableOpenAI @@ -12,8 +12,8 @@ from models.model import Conversation, Message class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory): conversation: Conversation human_prefix: str = "Human" - ai_prefix: str = "AI" - llm: Union[StreamableChatOpenAI | StreamableOpenAI] + ai_prefix: str = "Assistant" + llm: BaseLanguageModel memory_key: str = "chat_history" max_token_limit: int = 2000 message_limit: int = 10 @@ -38,12 +38,12 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory): return chat_messages # prune the chat message if it exceeds the max token limit - curr_buffer_length = self.llm.get_messages_tokens(chat_messages) + curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages) if curr_buffer_length > self.max_token_limit: pruned_memory = [] while curr_buffer_length > self.max_token_limit and chat_messages: pruned_memory.append(chat_messages.pop(0)) - curr_buffer_length = self.llm.get_messages_tokens(chat_messages) + curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages) return chat_messages diff --git a/api/core/tool/dataset_index_tool.py b/api/core/tool/dataset_index_tool.py index 2776c6f48a..17b3c148b2 100644 --- a/api/core/tool/dataset_index_tool.py +++ b/api/core/tool/dataset_index_tool.py @@ -30,7 +30,7 @@ class DatasetTool(BaseTool): else: model_credentials = LLMBuilder.get_model_credentials( tenant_id=self.dataset.tenant_id, - model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id), + model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'), model_name='text-embedding-ada-002' ) @@ -60,7 +60,7 @@ class DatasetTool(BaseTool): async def _arun(self, tool_input: str) -> str: model_credentials = LLMBuilder.get_model_credentials( tenant_id=self.dataset.tenant_id, - model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id), + model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'), model_name='text-embedding-ada-002' ) diff --git a/api/events/event_handlers/create_provider_when_tenant_created.py b/api/events/event_handlers/create_provider_when_tenant_created.py index e967a5d071..0d35670258 100644 --- a/api/events/event_handlers/create_provider_when_tenant_created.py +++ 
b/api/events/event_handlers/create_provider_when_tenant_created.py @@ -1,4 +1,7 @@ +from flask import current_app + from events.tenant_event import tenant_was_updated +from models.provider import ProviderName from services.provider_service import ProviderService @@ -6,4 +9,16 @@ from services.provider_service import ProviderService def handle(sender, **kwargs): tenant = sender if tenant.status == 'normal': - ProviderService.create_system_provider(tenant) + ProviderService.create_system_provider( + tenant, + ProviderName.OPENAI.value, + current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'], + True + ) + + ProviderService.create_system_provider( + tenant, + ProviderName.ANTHROPIC.value, + current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'], + True + ) diff --git a/api/events/event_handlers/create_provider_when_tenant_updated.py b/api/events/event_handlers/create_provider_when_tenant_updated.py index 81a7d40ff6..366e13c599 100644 --- a/api/events/event_handlers/create_provider_when_tenant_updated.py +++ b/api/events/event_handlers/create_provider_when_tenant_updated.py @@ -1,4 +1,7 @@ +from flask import current_app + from events.tenant_event import tenant_was_created +from models.provider import ProviderName from services.provider_service import ProviderService @@ -6,4 +9,16 @@ from services.provider_service import ProviderService def handle(sender, **kwargs): tenant = sender if tenant.status == 'normal': - ProviderService.create_system_provider(tenant) + ProviderService.create_system_provider( + tenant, + ProviderName.OPENAI.value, + current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'], + True + ) + + ProviderService.create_system_provider( + tenant, + ProviderName.ANTHROPIC.value, + current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'], + True + ) diff --git a/api/requirements.txt b/api/requirements.txt index 5ffda02ced..97a1c5048d 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -10,7 +10,7 @@ flask-session2==1.3.1 flask-cors==3.0.10 gunicorn~=20.1.0 gevent~=22.10.2 -langchain==0.0.209 +langchain==0.0.230 openai~=0.27.5 psycopg2-binary~=2.9.6 pycryptodome==3.17 @@ -35,3 +35,4 @@ docx2txt==0.8 pypdfium2==4.16.0 resend~=0.5.1 pyjwt~=2.6.0 +anthropic~=0.3.4 diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 77293a1a05..ecf3a68269 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -6,6 +6,30 @@ from models.account import Account from services.dataset_service import DatasetService from core.llm.llm_builder import LLMBuilder +MODEL_PROVIDERS = [ + 'openai', + 'anthropic', +] + +MODELS_BY_APP_MODE = { + 'chat': [ + 'claude-instant-1', + 'claude-2', + 'gpt-4', + 'gpt-4-32k', + 'gpt-3.5-turbo', + 'gpt-3.5-turbo-16k', + ], + 'completion': [ + 'claude-instant-1', + 'claude-2', + 'gpt-4', + 'gpt-4-32k', + 'gpt-3.5-turbo', + 'gpt-3.5-turbo-16k', + 'text-davinci-003', + ] +} class AppModelConfigService: @staticmethod @@ -125,7 +149,7 @@ class AppModelConfigService: if not isinstance(config["speech_to_text"]["enabled"], bool): raise ValueError("enabled in speech_to_text must be of boolean type") - provider_name = LLMBuilder.get_default_provider(account.current_tenant_id) + provider_name = LLMBuilder.get_default_provider(account.current_tenant_id, 'whisper-1') if config["speech_to_text"]["enabled"] and provider_name != 'openai': raise ValueError("provider not support speech to text") @@ -153,14 +177,14 @@ class AppModelConfigService: raise ValueError("model must be of object type") # model.provider - 
if 'provider' not in config["model"] or config["model"]["provider"] != "openai": - raise ValueError("model.provider must be 'openai'") + if 'provider' not in config["model"] or config["model"]["provider"] not in MODEL_PROVIDERS: + raise ValueError(f"model.provider is required and must be in {str(MODEL_PROVIDERS)}") # model.name if 'name' not in config["model"]: raise ValueError("model.name is required") - if config["model"]["name"] not in llm_constant.models_by_mode[mode]: + if config["model"]["name"] not in MODELS_BY_APP_MODE[mode]: raise ValueError("model.name must be in the specified model list") # model.completion_params diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 9870702e46..667fb4cb67 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -27,7 +27,7 @@ class AudioService: message = f"Audio size larger than {FILE_SIZE} mb" raise AudioTooLargeServiceError(message) - provider_name = LLMBuilder.get_default_provider(tenant_id) + provider_name = LLMBuilder.get_default_provider(tenant_id, 'whisper-1') if provider_name != ProviderName.OPENAI.value: raise ProviderNotSupportSpeechToTextServiceError() @@ -37,8 +37,3 @@ class AudioService: buffer.name = 'temp.mp3' return Whisper(provider_service.provider).transcribe(buffer) - - - - - \ No newline at end of file diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index b0029f80ad..17a4a4f4c6 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -31,7 +31,7 @@ class HitTestingService: model_credentials = LLMBuilder.get_model_credentials( tenant_id=dataset.tenant_id, - model_provider=LLMBuilder.get_default_provider(dataset.tenant_id), + model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'), model_name='text-embedding-ada-002' ) diff --git a/api/services/provider_service.py b/api/services/provider_service.py index 39ee8353c0..fffd3fbd5b 100644 --- a/api/services/provider_service.py +++ b/api/services/provider_service.py @@ -10,50 +10,40 @@ from models.provider import * class ProviderService: @staticmethod - def init_supported_provider(tenant, edition): + def init_supported_provider(tenant): """Initialize the model provider, check whether the supported provider has a record""" - providers = Provider.query.filter_by(tenant_id=tenant.id).all() + need_init_provider_names = [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value, ProviderName.ANTHROPIC.value] - openai_provider_exists = False - azure_openai_provider_exists = False - - # TODO: The cloud version needs to construct the data of the SYSTEM type + providers = db.session.query(Provider).filter( + Provider.tenant_id == tenant.id, + Provider.provider_type == ProviderType.CUSTOM.value, + Provider.provider_name.in_(need_init_provider_names) + ).all() + exists_provider_names = [] for provider in providers: - if provider.provider_name == ProviderName.OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value: - openai_provider_exists = True - if provider.provider_name == ProviderName.AZURE_OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value: - azure_openai_provider_exists = True + exists_provider_names.append(provider.provider_name) - # Initialize the model provider, check whether the supported provider has a record + not_exists_provider_names = list(set(need_init_provider_names) - set(exists_provider_names)) - # Create default providers if they don't exist - if not openai_provider_exists: - 
openai_provider = Provider( - tenant_id=tenant.id, - provider_name=ProviderName.OPENAI.value, - provider_type=ProviderType.CUSTOM.value, - is_valid=False - ) - db.session.add(openai_provider) + if not_exists_provider_names: + # Initialize the model provider, check whether the supported provider has a record + for provider_name in not_exists_provider_names: + provider = Provider( + tenant_id=tenant.id, + provider_name=provider_name, + provider_type=ProviderType.CUSTOM.value, + is_valid=False + ) + db.session.add(provider) - if not azure_openai_provider_exists: - azure_openai_provider = Provider( - tenant_id=tenant.id, - provider_name=ProviderName.AZURE_OPENAI.value, - provider_type=ProviderType.CUSTOM.value, - is_valid=False - ) - db.session.add(azure_openai_provider) - - if not openai_provider_exists or not azure_openai_provider_exists: db.session.commit() @staticmethod - def get_obfuscated_api_key(tenant, provider_name: ProviderName): + def get_obfuscated_api_key(tenant, provider_name: ProviderName, only_custom: bool = False): llm_provider_service = LLMProviderService(tenant.id, provider_name.value) - return llm_provider_service.get_provider_configs(obfuscated=True) + return llm_provider_service.get_provider_configs(obfuscated=True, only_custom=only_custom) @staticmethod def get_token_type(tenant, provider_name: ProviderName): @@ -73,7 +63,7 @@ class ProviderService: return llm_provider_service.get_encrypted_token(configs) @staticmethod - def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value, + def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value, quota_limit: int = 200, is_valid: bool = True): if current_app.config['EDITION'] != 'CLOUD': return @@ -90,7 +80,7 @@ class ProviderService: provider_name=provider_name, provider_type=ProviderType.SYSTEM.value, quota_type=ProviderQuotaType.TRIAL.value, - quota_limit=200, + quota_limit=quota_limit, encrypted_config='', is_valid=is_valid, ) diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 92227ffa7a..abd1f7f3fb 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -1,6 +1,6 @@ from extensions.ext_database import db from models.account import Tenant -from models.provider import Provider, ProviderType +from models.provider import Provider, ProviderType, ProviderName class WorkspaceService: @@ -33,7 +33,7 @@ class WorkspaceService: if provider.is_valid and provider.encrypted_config: custom_provider = provider elif provider.provider_type == ProviderType.SYSTEM.value: - if provider.is_valid: + if provider.provider_name == ProviderName.OPENAI.value and provider.is_valid: system_provider = provider if system_provider and not custom_provider:
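
A minimal usage sketch of the two extension points this patch introduces: the per-class get_kwargs_from_model_params hooks, which normalize one shared parameter dict into provider-specific kwargs, and the exception wrappers, which map each SDK's errors onto the shared core.llm.error hierarchy. The sketch is not part of the patch; raw_params and call_claude below are hypothetical illustrations.

from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions

# Shared params as an app's model config might produce them (assumed shape;
# each hook mutates its input, so pass a copy).
raw_params = {
    'model_name': 'claude-2',
    'max_tokens': 256,
    'temperature': 0.7,
    'top_p': 1,
    'frequency_penalty': 0,
    'presence_penalty': 0,
}

# Anthropic: model_name -> model, max_tokens -> max_tokens_to_sample,
# and the OpenAI-only penalty params are dropped.
StreamableChatAnthropic.get_kwargs_from_model_params(dict(raw_params))
# -> {'temperature': 0.7, 'top_p': 1, 'model': 'claude-2', 'max_tokens_to_sample': 256}

# OpenAI chat models: top_p and the penalties move into `model_kwargs`,
# which langchain's ChatOpenAI forwards to the completion endpoint.
StreamableChatOpenAI.get_kwargs_from_model_params(dict(raw_params))
# -> {'model_name': 'claude-2', 'max_tokens': 256, 'temperature': 0.7,
#     'model_kwargs': {'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}}

# The wrapper can decorate any callable that hits the Anthropic SDK, so
# callers only ever see the shared LLM*Error types (hypothetical helper):
@handle_anthropic_exceptions
def call_claude(llm: StreamableChatAnthropic, messages):
    # anthropic.* exceptions raised inside surface as LLM*Error instead.
    return llm.generate([messages])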