diff --git a/api/commands.py b/api/commands.py
index 544b3110e7..4eed5de047 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -18,7 +18,8 @@ from models.model import Account
import secrets
import base64
-from models.provider import Provider
+from models.provider import Provider, ProviderName
+from services.provider_service import ProviderService
@click.command('reset-password', help='Reset the account password.')
@@ -193,9 +194,40 @@ def recreate_all_dataset_indexes():
click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
+@click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
+def sync_anthropic_hosted_providers():
+ click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
+ count = 0
+
+ page = 1
+ while True:
+ try:
+ tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
+ except NotFound:
+ break
+
+ page += 1
+ for tenant in tenants:
+ try:
+ click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
+ ProviderService.create_system_provider(
+ tenant,
+ ProviderName.ANTHROPIC.value,
+ current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
+ True
+ )
+ count += 1
+ except Exception as e:
+ click.echo(click.style('Sync tenant anthropic hosted provider error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+ continue
+
+ click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
+
+
def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
app.cli.add_command(generate_invitation_codes)
app.cli.add_command(reset_encrypt_key_pair)
app.cli.add_command(recreate_all_dataset_indexes)
+ app.cli.add_command(sync_anthropic_hosted_providers)
diff --git a/api/config.py b/api/config.py
index 99ecc67656..951574722a 100644
--- a/api/config.py
+++ b/api/config.py
@@ -51,6 +51,8 @@ DEFAULTS = {
'LOG_LEVEL': 'INFO',
'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
'DEFAULT_LLM_PROVIDER': 'openai',
+ 'OPENAI_HOSTED_QUOTA_LIMIT': 200,
+ 'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
'TENANT_DOCUMENT_COUNT': 100
}
@@ -192,6 +194,10 @@ class Config:
# hosted provider credentials
self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
+ self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')
+
+ self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
+ self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')
# By default it is False
# You could disable it for compatibility with certain OpenAPI providers
diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py
index e6bd2fdf28..075e8d4a91 100644
--- a/api/controllers/console/app/audio.py
+++ b/api/controllers/console/app/audio.py
@@ -50,8 +50,8 @@ class ChatMessageAudioApi(Resource):
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py
index 552271a9ec..e76186671d 100644
--- a/api/controllers/console/app/completion.py
+++ b/api/controllers/console/app/completion.py
@@ -63,8 +63,8 @@ class CompletionMessageApi(Resource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -133,8 +133,8 @@ class ChatMessageApi(Resource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -164,8 +164,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
- except ProviderTokenNotInitError:
- yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+ except ProviderTokenNotInitError as ex:
+ yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py
index 905e5b273a..b6086b68b8 100644
--- a/api/controllers/console/app/error.py
+++ b/api/controllers/console/app/error.py
@@ -16,7 +16,7 @@ class ProviderNotInitializeError(BaseHTTPException):
class ProviderQuotaExceededError(BaseHTTPException):
error_code = 'provider_quota_exceeded'
- description = "Your quota for Dify Hosted OpenAI has been exhausted. " \
+ description = "Your quota for Dify Hosted Model Provider has been exhausted. " \
"Please go to Settings -> Model Provider to complete your own provider credentials."
code = 400
diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py
index 6a74bf2584..6b9a0a2140 100644
--- a/api/controllers/console/app/generator.py
+++ b/api/controllers/console/app/generator.py
@@ -27,8 +27,8 @@ class IntroductionGenerateApi(Resource):
account.current_tenant_id,
args['prompt_template']
)
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -58,8 +58,8 @@ class RuleGenerateApi(Resource):
args['audiences'],
args['hoping_to_solve']
)
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
index 88594425de..c5764a7ec7 100644
--- a/api/controllers/console/app/message.py
+++ b/api/controllers/console/app/message.py
@@ -269,8 +269,8 @@ class MessageMoreLikeThisApi(Resource):
raise NotFound("Message Not Exists.")
except MoreLikeThisDisabledError:
raise AppMoreLikeThisDisabledError()
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -297,8 +297,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
except MoreLikeThisDisabledError:
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
- except ProviderTokenNotInitError:
- yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+ except ProviderTokenNotInitError as ex:
+ yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
@@ -339,8 +339,8 @@ class MessageSuggestedQuestionApi(Resource):
raise NotFound("Message not found")
except ConversationNotExistsError:
raise NotFound("Conversation not found")
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 3e65b1319a..e165d2130a 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -279,8 +279,8 @@ class DatasetDocumentListApi(Resource):
try:
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -324,8 +324,8 @@ class DatasetInitApi(Resource):
document_data=args,
account=current_user
)
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py
index 771d49045f..c1ccb30dd6 100644
--- a/api/controllers/console/datasets/hit_testing.py
+++ b/api/controllers/console/datasets/hit_testing.py
@@ -95,8 +95,8 @@ class HitTestingApi(Resource):
return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
except services.errors.index.IndexNotInitializedError:
raise DatasetNotInitializedError()
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py
index a027fe625e..991a228dd5 100644
--- a/api/controllers/console/explore/audio.py
+++ b/api/controllers/console/explore/audio.py
@@ -47,8 +47,8 @@ class ChatAudioApi(InstalledAppResource):
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py
index f2a1acd6d5..bc4b88ad15 100644
--- a/api/controllers/console/explore/completion.py
+++ b/api/controllers/console/explore/completion.py
@@ -54,8 +54,8 @@ class CompletionApi(InstalledAppResource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -113,8 +113,8 @@ class ChatApi(InstalledAppResource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -155,8 +155,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
- except ProviderTokenNotInitError:
- yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+ except ProviderTokenNotInitError as ex:
+ yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py
index b5b9547ff7..1232169eab 100644
--- a/api/controllers/console/explore/message.py
+++ b/api/controllers/console/explore/message.py
@@ -107,8 +107,8 @@ class MessageMoreLikeThisApi(InstalledAppResource):
raise NotFound("Message Not Exists.")
except MoreLikeThisDisabledError:
raise AppMoreLikeThisDisabledError()
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -135,8 +135,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
except MoreLikeThisDisabledError:
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
- except ProviderTokenNotInitError:
- yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+ except ProviderTokenNotInitError as ex:
+ yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
@@ -174,8 +174,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
raise NotFound("Conversation not found")
except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError()
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/console/workspace/providers.py b/api/controllers/console/workspace/providers.py
index f2baec29c1..1991b1c6c9 100644
--- a/api/controllers/console/workspace/providers.py
+++ b/api/controllers/console/workspace/providers.py
@@ -3,6 +3,7 @@ import base64
import json
import logging
+from flask import current_app
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, abort
from werkzeug.exceptions import Forbidden
@@ -34,7 +35,7 @@ class ProviderListApi(Resource):
plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
"""
- ProviderService.init_supported_provider(current_user.current_tenant, "cloud")
+ ProviderService.init_supported_provider(current_user.current_tenant)
providers = Provider.query.filter_by(tenant_id=tenant_id).all()
provider_list = [
@@ -50,7 +51,8 @@ class ProviderListApi(Resource):
'quota_used': p.quota_used
} if p.provider_type == ProviderType.SYSTEM.value else {}),
'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
- ProviderName(p.provider_name))
+ ProviderName(p.provider_name), only_custom=True)
+ if p.provider_type == ProviderType.CUSTOM.value else None
}
for p in providers
]
@@ -121,9 +123,10 @@ class ProviderTokenApi(Resource):
is_valid=token_is_valid)
db.session.add(provider_model)
- if provider_model.is_valid:
+ if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid:
other_providers = db.session.query(Provider).filter(
Provider.tenant_id == tenant.id,
+ Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]),
Provider.provider_name != provider,
Provider.provider_type == ProviderType.CUSTOM.value
).all()
@@ -133,7 +136,7 @@ class ProviderTokenApi(Resource):
db.session.commit()
- if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
+ if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
ProviderName.HUGGINGFACEHUB.value]:
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
@@ -157,7 +160,7 @@ class ProviderTokenValidateApi(Resource):
args = parser.parse_args()
# todo: remove this when the provider is supported
- if provider in [ProviderName.ANTHROPIC.value, ProviderName.COHERE.value,
+ if provider in [ProviderName.COHERE.value,
ProviderName.HUGGINGFACEHUB.value]:
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
@@ -203,7 +206,19 @@ class ProviderSystemApi(Resource):
provider_model.is_valid = args['is_enabled']
db.session.commit()
elif not provider_model:
- ProviderService.create_system_provider(tenant, provider, args['is_enabled'])
+ if provider == ProviderName.OPENAI.value:
+ quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT']
+ elif provider == ProviderName.ANTHROPIC.value:
+ quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT']
+ else:
+ quota_limit = 0
+
+ ProviderService.create_system_provider(
+ tenant,
+ provider,
+ quota_limit,
+ args['is_enabled']
+ )
else:
abort(403)
diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py
index aacf1ca2a2..470afc6b42 100644
--- a/api/controllers/service_api/app/audio.py
+++ b/api/controllers/service_api/app/audio.py
@@ -43,8 +43,8 @@ class AudioApi(AppApiResource):
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py
index e5eb4153aa..448c408bce 100644
--- a/api/controllers/service_api/app/completion.py
+++ b/api/controllers/service_api/app/completion.py
@@ -54,8 +54,8 @@ class CompletionApi(AppApiResource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -115,8 +115,8 @@ class ChatApi(AppApiResource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -156,8 +156,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
- except ProviderTokenNotInitError:
- yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+ except ProviderTokenNotInitError as ex:
+ yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index 3036882d71..e00de0f9a1 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -85,8 +85,8 @@ class DocumentListApi(DatasetApiResource):
dataset_process_rule=dataset.latest_process_rule,
created_from='api'
)
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
document = documents[0]
if doc_type and doc_metadata:
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py
index b07382176c..3e3fe3a28d 100644
--- a/api/controllers/web/audio.py
+++ b/api/controllers/web/audio.py
@@ -45,8 +45,8 @@ class AudioApi(WebApiResource):
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py
index 532bcfaa8d..db2f770e5a 100644
--- a/api/controllers/web/completion.py
+++ b/api/controllers/web/completion.py
@@ -52,8 +52,8 @@ class CompletionApi(WebApiResource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -109,8 +109,8 @@ class ChatApi(WebApiResource):
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -150,8 +150,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
- except ProviderTokenNotInitError:
- yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+ except ProviderTokenNotInitError as ex:
+ yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py
index 0d519eac06..3d978a1099 100644
--- a/api/controllers/web/message.py
+++ b/api/controllers/web/message.py
@@ -101,8 +101,8 @@ class MessageMoreLikeThisApi(WebApiResource):
raise NotFound("Message Not Exists.")
except MoreLikeThisDisabledError:
raise AppMoreLikeThisDisabledError()
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
@@ -129,8 +129,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
except MoreLikeThisDisabledError:
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
- except ProviderTokenNotInitError:
- yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+ except ProviderTokenNotInitError as ex:
+ yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
@@ -167,8 +167,8 @@ class MessageSuggestedQuestionApi(WebApiResource):
raise NotFound("Conversation not found")
except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError()
- except ProviderTokenNotInitError:
- raise ProviderNotInitializeError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
diff --git a/api/core/__init__.py b/api/core/__init__.py
index 0b26044aa5..2dc9a9e869 100644
--- a/api/core/__init__.py
+++ b/api/core/__init__.py
@@ -13,8 +13,13 @@ class HostedOpenAICredential(BaseModel):
api_key: str
+class HostedAnthropicCredential(BaseModel):
+ api_key: str
+
+
class HostedLLMCredentials(BaseModel):
openai: Optional[HostedOpenAICredential] = None
+ anthropic: Optional[HostedAnthropicCredential] = None
hosted_llm_credentials = HostedLLMCredentials()
@@ -26,3 +31,6 @@ def init_app(app: Flask):
if app.config.get("OPENAI_API_KEY"):
hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
+
+ if app.config.get("ANTHROPIC_API_KEY"):
+ hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))
diff --git a/api/core/callback_handler/llm_callback_handler.py b/api/core/callback_handler/llm_callback_handler.py
index c9db70fa76..6b3daf76e7 100644
--- a/api/core/callback_handler/llm_callback_handler.py
+++ b/api/core/callback_handler/llm_callback_handler.py
@@ -48,7 +48,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
})
self.llm_message.prompt = real_prompts
- self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
+ self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
diff --git a/api/core/completion.py b/api/core/completion.py
index 38a81f2807..bcb75c2e07 100644
--- a/api/core/completion.py
+++ b/api/core/completion.py
@@ -118,6 +118,7 @@ class Completion:
prompt, stop_words = cls.get_main_llm_prompt(
mode=mode,
llm=final_llm,
+ model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt,
query=query,
inputs=inputs,
@@ -129,6 +130,7 @@ class Completion:
cls.recale_llm_max_tokens(
final_llm=final_llm,
+ model=app_model_config.model_dict,
prompt=prompt,
mode=mode
)
@@ -138,7 +140,8 @@ class Completion:
return response
@classmethod
- def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
+ def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
+ pre_prompt: str, query: str, inputs: dict,
chain_output: Optional[str],
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
@@ -151,10 +154,11 @@ class Completion:
if mode == 'completion':
prompt_template = JinjaPromptTemplate.from_template(
- template=("""Use the following CONTEXT as your learned knowledge:
-[CONTEXT]
+ template=("""Use the following context as your learned knowledge, inside XML tags.
+
+
{{context}}
-[END CONTEXT]
+
When answer to user:
- If you don't know, just say that you don't know.
@@ -204,10 +208,11 @@ And answer according to the language of the user's question.
if chain_output:
human_inputs['context'] = chain_output
- human_message_prompt += """Use the following CONTEXT as your learned knowledge.
-[CONTEXT]
+ human_message_prompt += """Use the following context as your learned knowledge, inside XML tags.
+
+
{{context}}
-[END CONTEXT]
+
When answer to user:
- If you don't know, just say that you don't know.
@@ -219,7 +224,7 @@ And answer according to the language of the user's question.
if pre_prompt:
human_message_prompt += pre_prompt
- query_prompt = "\nHuman: {{query}}\nAI: "
+ query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
if memory:
# append chat histories
@@ -228,9 +233,11 @@ And answer according to the language of the user's question.
inputs=human_inputs
)
- curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
- rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
- - memory.llm.max_tokens - curr_message_tokens
+ curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
+ model_name = model['name']
+ max_tokens = model.get("completion_params").get('max_tokens')
+ rest_tokens = llm_constant.max_context_token_length[model_name] \
+ - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
histories = cls.get_history_messages_from_memory(memory, rest_tokens)
@@ -241,7 +248,10 @@ And answer according to the language of the user's question.
# if histories_param not in human_inputs:
# human_inputs[histories_param] = '{{' + histories_param + '}}'
- human_message_prompt += "\n\n" + histories
+ human_message_prompt += "\n\n" if human_message_prompt else ""
+ human_message_prompt += "Here is the chat histories between human and assistant, " \
+ "inside XML tags.\n\n"
+ human_message_prompt += histories + ""
human_message_prompt += query_prompt
@@ -307,13 +317,15 @@ And answer according to the language of the user's question.
model=app_model_config.model_dict
)
- model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
- max_tokens = llm.max_tokens
+ model_name = app_model_config.model_dict.get("name")
+ model_limited_tokens = llm_constant.max_context_token_length[model_name]
+ max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')
# get prompt without memory and context
prompt, _ = cls.get_main_llm_prompt(
mode=mode,
llm=llm,
+ model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt,
query=query,
inputs=inputs,
@@ -332,16 +344,17 @@ And answer according to the language of the user's question.
return rest_tokens
@classmethod
- def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
+ def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
prompt: Union[str, List[BaseMessage]], mode: str):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
- model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name]
- max_tokens = final_llm.max_tokens
+ model_name = model.get("name")
+ model_limited_tokens = llm_constant.max_context_token_length[model_name]
+ max_tokens = model.get("completion_params").get('max_tokens')
if mode == 'completion' and isinstance(final_llm, BaseLLM):
prompt_tokens = final_llm.get_num_tokens(prompt)
else:
- prompt_tokens = final_llm.get_messages_tokens(prompt)
+ prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)
if prompt_tokens + max_tokens > model_limited_tokens:
max_tokens = max(model_limited_tokens - prompt_tokens, 16)
@@ -350,9 +363,10 @@ And answer according to the language of the user's question.
@classmethod
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
app_model_config: AppModelConfig, user: Account, streaming: bool):
- llm: StreamableOpenAI = LLMBuilder.to_llm(
+
+ llm = LLMBuilder.to_llm_from_model(
tenant_id=app.tenant_id,
- model_name='gpt-3.5-turbo',
+ model=app_model_config.model_dict,
streaming=streaming
)
@@ -360,6 +374,7 @@ And answer according to the language of the user's question.
original_prompt, _ = cls.get_main_llm_prompt(
mode="completion",
llm=llm,
+ model=app_model_config.model_dict,
pre_prompt=pre_prompt,
query=message.query,
inputs=message.inputs,
@@ -390,6 +405,7 @@ And answer according to the language of the user's question.
cls.recale_llm_max_tokens(
final_llm=llm,
+ model=app_model_config.model_dict,
prompt=prompt,
mode='completion'
)
diff --git a/api/core/constant/llm_constant.py b/api/core/constant/llm_constant.py
index 397a3d4c8f..3a02abc90e 100644
--- a/api/core/constant/llm_constant.py
+++ b/api/core/constant/llm_constant.py
@@ -1,6 +1,8 @@
from _decimal import Decimal
models = {
+ 'claude-instant-1': 'anthropic', # 100,000 tokens
+ 'claude-2': 'anthropic', # 100,000 tokens
'gpt-4': 'openai', # 8,192 tokens
'gpt-4-32k': 'openai', # 32,768 tokens
'gpt-3.5-turbo': 'openai', # 4,096 tokens
@@ -10,10 +12,13 @@ models = {
'text-curie-001': 'openai', # 2,049 tokens
'text-babbage-001': 'openai', # 2,049 tokens
'text-ada-001': 'openai', # 2,049 tokens
- 'text-embedding-ada-002': 'openai' # 8191 tokens, 1536 dimensions
+ 'text-embedding-ada-002': 'openai', # 8191 tokens, 1536 dimensions
+ 'whisper-1': 'openai'
}
max_context_token_length = {
+ 'claude-instant-1': 100000,
+ 'claude-2': 100000,
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
@@ -23,17 +28,21 @@ max_context_token_length = {
'text-curie-001': 2049,
'text-babbage-001': 2049,
'text-ada-001': 2049,
- 'text-embedding-ada-002': 8191
+ 'text-embedding-ada-002': 8191,
}
models_by_mode = {
'chat': [
+ 'claude-instant-1', # 100,000 tokens
+ 'claude-2', # 100,000 tokens
'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo', # 4,096 tokens
'gpt-3.5-turbo-16k', # 16,384 tokens
],
'completion': [
+ 'claude-instant-1', # 100,000 tokens
+ 'claude-2', # 100,000 tokens
'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo', # 4,096 tokens
@@ -52,6 +61,14 @@ models_by_mode = {
model_currency = 'USD'
model_prices = {
+ 'claude-instant-1': {
+ 'prompt': Decimal('0.00163'),
+ 'completion': Decimal('0.00551'),
+ },
+ 'claude-2': {
+ 'prompt': Decimal('0.01102'),
+ 'completion': Decimal('0.03268'),
+ },
'gpt-4': {
'prompt': Decimal('0.03'),
'completion': Decimal('0.06'),
diff --git a/api/core/conversation_message_task.py b/api/core/conversation_message_task.py
index d0e10fa119..3d8a50363f 100644
--- a/api/core/conversation_message_task.py
+++ b/api/core/conversation_message_task.py
@@ -56,7 +56,7 @@ class ConversationMessageTask:
)
def init(self):
- provider_name = LLMBuilder.get_default_provider(self.app.tenant_id)
+ provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
self.model_dict['provider'] = provider_name
override_model_configs = None
@@ -89,7 +89,7 @@ class ConversationMessageTask:
system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
system_instruction = system_message.content
llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
- system_instruction_tokens = llm.get_messages_tokens([system_message])
+ system_instruction_tokens = llm.get_num_tokens_from_messages([system_message])
if not self.conversation:
self.is_new_conversation = True
@@ -185,6 +185,7 @@ class ConversationMessageTask:
if provider and provider.provider_type == ProviderType.SYSTEM.value:
db.session.query(Provider).filter(
Provider.tenant_id == self.app.tenant_id,
+ Provider.provider_name == provider.provider_name,
Provider.quota_limit > Provider.quota_used
).update({'quota_used': Provider.quota_used + 1})
diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py
index 4030eb158c..174f259456 100644
--- a/api/core/embedding/cached_embedding.py
+++ b/api/core/embedding/cached_embedding.py
@@ -4,6 +4,7 @@ from typing import List
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
from extensions.ext_database import db
from libs import helper
from models.dataset import Embedding
@@ -49,6 +50,7 @@ class CacheEmbedding(Embeddings):
text_embeddings.extend(embedding_results)
return text_embeddings
+ @handle_openai_exceptions
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists
diff --git a/api/core/generator/llm_generator.py b/api/core/generator/llm_generator.py
index 87540272da..9c6dbdfc8f 100644
--- a/api/core/generator/llm_generator.py
+++ b/api/core/generator/llm_generator.py
@@ -23,6 +23,10 @@ class LLMGenerator:
@classmethod
def generate_conversation_name(cls, tenant_id: str, query, answer):
prompt = CONVERSATION_TITLE_PROMPT
+
+ if len(query) > 2000:
+ query = query[:300] + "...[TRUNCATED]..." + query[-300:]
+
prompt = prompt.format(query=query)
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
@@ -52,7 +56,17 @@ class LLMGenerator:
if not message.answer:
continue
- message_qa_text = "Human:" + message.query + "\nAI:" + message.answer + "\n"
+ if len(message.query) > 2000:
+ query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
+ else:
+ query = message.query
+
+ if len(message.answer) > 2000:
+ answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
+ else:
+ answer = message.answer
+
+ message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
context += message_qa_text
diff --git a/api/core/index/index.py b/api/core/index/index.py
index 617b763982..bd32b3c493 100644
--- a/api/core/index/index.py
+++ b/api/core/index/index.py
@@ -17,7 +17,7 @@ class IndexBuilder:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id,
- model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
+ model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002'
)
diff --git a/api/core/llm/error.py b/api/core/llm/error.py
index 883d282e8a..9bba8401f2 100644
--- a/api/core/llm/error.py
+++ b/api/core/llm/error.py
@@ -40,6 +40,9 @@ class ProviderTokenNotInitError(Exception):
"""
description = "Provider Token Not Init"
+ def __init__(self, *args, **kwargs):
+ self.description = args[0] if args else self.description
+
class QuotaExceededError(Exception):
"""
diff --git a/api/core/llm/llm_builder.py b/api/core/llm/llm_builder.py
index c2deda5351..50cd1b620e 100644
--- a/api/core/llm/llm_builder.py
+++ b/api/core/llm/llm_builder.py
@@ -8,9 +8,10 @@ from core.llm.provider.base import BaseProvider
from core.llm.provider.llm_provider_service import LLMProviderService
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
+from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
-from models.provider import ProviderType
+from models.provider import ProviderType, ProviderName
class LLMBuilder:
@@ -32,43 +33,43 @@ class LLMBuilder:
@classmethod
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
- provider = cls.get_default_provider(tenant_id)
+ provider = cls.get_default_provider(tenant_id, model_name)
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
+ llm_cls = None
mode = cls.get_mode_by_model(model_name)
if mode == 'chat':
- if provider == 'openai':
+ if provider == ProviderName.OPENAI.value:
llm_cls = StreamableChatOpenAI
- else:
+ elif provider == ProviderName.AZURE_OPENAI.value:
llm_cls = StreamableAzureChatOpenAI
+ elif provider == ProviderName.ANTHROPIC.value:
+ llm_cls = StreamableChatAnthropic
elif mode == 'completion':
- if provider == 'openai':
+ if provider == ProviderName.OPENAI.value:
llm_cls = StreamableOpenAI
- else:
+ elif provider == ProviderName.AZURE_OPENAI.value:
llm_cls = StreamableAzureOpenAI
- else:
+
+ if not llm_cls:
raise ValueError(f"model name {model_name} is not supported.")
-
model_kwargs = {
+ 'model_name': model_name,
+ 'temperature': kwargs.get('temperature', 0),
+ 'max_tokens': kwargs.get('max_tokens', 256),
'top_p': kwargs.get('top_p', 1),
'frequency_penalty': kwargs.get('frequency_penalty', 0),
'presence_penalty': kwargs.get('presence_penalty', 0),
+ 'callbacks': kwargs.get('callbacks', None),
+ 'streaming': kwargs.get('streaming', False),
}
- model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}
+ model_kwargs.update(model_credentials)
+ model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)
- return llm_cls(
- model_name=model_name,
- temperature=kwargs.get('temperature', 0),
- max_tokens=kwargs.get('max_tokens', 256),
- **model_extras_kwargs,
- callbacks=kwargs.get('callbacks', None),
- streaming=kwargs.get('streaming', False),
- # request_timeout=None
- **model_credentials
- )
+ return llm_cls(**model_kwargs)
@classmethod
def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
@@ -118,14 +119,29 @@ class LLMBuilder:
return provider_service.get_credentials(model_name)
@classmethod
- def get_default_provider(cls, tenant_id: str) -> str:
- provider = BaseProvider.get_valid_provider(tenant_id)
- if not provider:
- raise ProviderTokenNotInitError()
+ def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
+ provider_name = llm_constant.models[model_name]
- if provider.provider_type == ProviderType.SYSTEM.value:
- provider_name = 'openai'
- else:
- provider_name = provider.provider_name
+ if provider_name == 'openai':
+ # get the default provider (openai / azure_openai) for the tenant
+ openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
+ azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)
+
+ provider = None
+ if openai_provider:
+ provider = openai_provider
+ elif azure_openai_provider:
+ provider = azure_openai_provider
+
+ if not provider:
+ raise ProviderTokenNotInitError(
+ f"No valid {provider_name} model provider credentials found. "
+ f"Please go to Settings -> Model Provider to complete your provider credentials."
+ )
+
+ if provider.provider_type == ProviderType.SYSTEM.value:
+ provider_name = 'openai'
+ else:
+ provider_name = provider.provider_name
return provider_name
diff --git a/api/core/llm/provider/anthropic_provider.py b/api/core/llm/provider/anthropic_provider.py
index 4c7756305e..d6165d0329 100644
--- a/api/core/llm/provider/anthropic_provider.py
+++ b/api/core/llm/provider/anthropic_provider.py
@@ -1,23 +1,138 @@
-from typing import Optional
+import json
+import logging
+from typing import Optional, Union
+import anthropic
+from langchain.chat_models import ChatAnthropic
+from langchain.schema import HumanMessage
+
+from core import hosted_llm_credentials
+from core.llm.error import ProviderTokenNotInitError
from core.llm.provider.base import BaseProvider
-from models.provider import ProviderName
+from core.llm.provider.errors import ValidateFailedError
+from models.provider import ProviderName, ProviderType
class AnthropicProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
- credentials = self.get_credentials(model_id)
- # todo
- return []
+ return [
+ {
+ 'id': 'claude-instant-1',
+ 'name': 'claude-instant-1',
+ },
+ {
+ 'id': 'claude-2',
+ 'name': 'claude-2',
+ },
+ ]
def get_credentials(self, model_id: Optional[str] = None) -> dict:
- """
- Returns the API credentials for Azure OpenAI as a dictionary, for the given tenant_id.
- The dictionary contains keys: azure_api_type, azure_api_version, azure_api_base, and azure_api_key.
- """
- return {
- 'anthropic_api_key': self.get_provider_api_key(model_id=model_id)
- }
+ return self.get_provider_api_key(model_id=model_id)
def get_provider_name(self):
- return ProviderName.ANTHROPIC
\ No newline at end of file
+ return ProviderName.ANTHROPIC
+
+ def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
+ """
+ Returns the provider configs.
+ """
+ try:
+ config = self.get_provider_api_key(only_custom=only_custom)
+ except:
+ config = {
+ 'anthropic_api_key': ''
+ }
+
+ if obfuscated:
+ if not config.get('anthropic_api_key'):
+ config = {
+ 'anthropic_api_key': ''
+ }
+
+ config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key'))
+ return config
+
+ return config
+
+ def get_encrypted_token(self, config: Union[dict | str]):
+ """
+ Returns the encrypted token.
+ """
+ return json.dumps({
+ 'anthropic_api_key': self.encrypt_token(config['anthropic_api_key'])
+ })
+
+ def get_decrypted_token(self, token: str):
+ """
+ Returns the decrypted token.
+ """
+ config = json.loads(token)
+ config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key'])
+ return config
+
+ def get_token_type(self):
+ return dict
+
+ def config_validate(self, config: Union[dict | str]):
+ """
+ Validates the given config.
+ """
+ # check OpenAI / Azure OpenAI credential is valid
+ openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value)
+ azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value)
+
+ provider = None
+ if openai_provider:
+ provider = openai_provider
+ elif azure_openai_provider:
+ provider = azure_openai_provider
+
+ if not provider:
+ raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.")
+
+ if provider.provider_type == ProviderType.SYSTEM.value:
+ quota_used = provider.quota_used if provider.quota_used is not None else 0
+ quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
+ if quota_used >= quota_limit:
+ raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, "
+ f"please configure OpenAI or Azure OpenAI provider first.")
+
+ try:
+ if not isinstance(config, dict):
+ raise ValueError('Config must be a object.')
+
+ if 'anthropic_api_key' not in config:
+ raise ValueError('anthropic_api_key must be provided.')
+
+ chat_llm = ChatAnthropic(
+ model='claude-instant-1',
+ anthropic_api_key=config['anthropic_api_key'],
+ max_tokens_to_sample=10,
+ temperature=0,
+ default_request_timeout=60
+ )
+
+ messages = [
+ HumanMessage(
+ content="ping"
+ )
+ ]
+
+ chat_llm(messages)
+ except anthropic.APIConnectionError as ex:
+ raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}")
+ except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
+ raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - "
+ f"{ex.body['error']['type']}: {ex.body['error']['message']}")
+ except Exception as ex:
+ logging.exception('Anthropic config validation failed')
+ raise ex
+
+ def get_hosted_credentials(self) -> Union[str | dict]:
+ if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key:
+ raise ProviderTokenNotInitError(
+ f"No valid {self.get_provider_name().value} model provider credentials found. "
+ f"Please go to Settings -> Model Provider to complete your provider credentials."
+ )
+
+ return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key}
diff --git a/api/core/llm/provider/azure_provider.py b/api/core/llm/provider/azure_provider.py
index 54b7af6462..706adc7be7 100644
--- a/api/core/llm/provider/azure_provider.py
+++ b/api/core/llm/provider/azure_provider.py
@@ -52,12 +52,12 @@ class AzureProvider(BaseProvider):
def get_provider_name(self):
return ProviderName.AZURE_OPENAI
- def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
+ def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the provider configs.
"""
try:
- config = self.get_provider_api_key()
+ config = self.get_provider_api_key(only_custom=only_custom)
except:
config = {
'openai_api_type': 'azure',
@@ -81,7 +81,6 @@ class AzureProvider(BaseProvider):
return config
def get_token_type(self):
- # TODO: change to dict when implemented
return dict
def config_validate(self, config: Union[dict | str]):
diff --git a/api/core/llm/provider/base.py b/api/core/llm/provider/base.py
index 71bb32dca6..c3ff5cf237 100644
--- a/api/core/llm/provider/base.py
+++ b/api/core/llm/provider/base.py
@@ -2,7 +2,7 @@ import base64
from abc import ABC, abstractmethod
from typing import Optional, Union
-from core import hosted_llm_credentials
+from core.constant import llm_constant
from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
from extensions.ext_database import db
from libs import rsa
@@ -14,15 +14,18 @@ class BaseProvider(ABC):
def __init__(self, tenant_id: str):
self.tenant_id = tenant_id
- def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str | dict]:
+ def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the decrypted API key for the given tenant_id and provider_name.
If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
If the provider is not found or not valid, raises a ProviderTokenNotInitError.
"""
- provider = self.get_provider(prefer_custom)
+ provider = self.get_provider(only_custom)
if not provider:
- raise ProviderTokenNotInitError()
+ raise ProviderTokenNotInitError(
+ f"No valid {llm_constant.models[model_id]} model provider credentials found. "
+ f"Please go to Settings -> Model Provider to complete your provider credentials."
+ )
if provider.provider_type == ProviderType.SYSTEM.value:
quota_used = provider.quota_used if provider.quota_used is not None else 0
@@ -38,18 +41,19 @@ class BaseProvider(ABC):
else:
return self.get_decrypted_token(provider.encrypted_config)
- def get_provider(self, prefer_custom: bool) -> Optional[Provider]:
+ def get_provider(self, only_custom: bool = False) -> Optional[Provider]:
"""
Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
"""
- return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom)
+ return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom)
@classmethod
- def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]:
+ def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[
+ Provider]:
"""
Returns the Provider instance for the given tenant_id and provider_name.
- If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
+ If both CUSTOM and System providers exist.
"""
query = db.session.query(Provider).filter(
Provider.tenant_id == tenant_id
@@ -58,39 +62,31 @@ class BaseProvider(ABC):
if provider_name:
query = query.filter(Provider.provider_name == provider_name)
- providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()
+ if only_custom:
+ query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value)
- custom_provider = None
- system_provider = None
+ providers = query.order_by(Provider.provider_type.asc()).all()
for provider in providers:
if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
- custom_provider = provider
+ return provider
elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
- system_provider = provider
+ return provider
- if custom_provider:
- return custom_provider
- elif system_provider:
- return system_provider
- else:
- return None
+ return None
- def get_hosted_credentials(self) -> str:
- if self.get_provider_name() != ProviderName.OPENAI:
- raise ProviderTokenNotInitError()
+ def get_hosted_credentials(self) -> Union[str | dict]:
+ raise ProviderTokenNotInitError(
+ f"No valid {self.get_provider_name().value} model provider credentials found. "
+ f"Please go to Settings -> Model Provider to complete your provider credentials."
+ )
- if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
- raise ProviderTokenNotInitError()
-
- return hosted_llm_credentials.openai.api_key
-
- def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
+ def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the provider configs.
"""
try:
- config = self.get_provider_api_key()
+ config = self.get_provider_api_key(only_custom=only_custom)
except:
config = ''
diff --git a/api/core/llm/provider/llm_provider_service.py b/api/core/llm/provider/llm_provider_service.py
index ca4f8bec6d..a520e3d6bb 100644
--- a/api/core/llm/provider/llm_provider_service.py
+++ b/api/core/llm/provider/llm_provider_service.py
@@ -31,11 +31,11 @@ class LLMProviderService:
def get_credentials(self, model_id: Optional[str] = None) -> dict:
return self.provider.get_credentials(model_id)
- def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
- return self.provider.get_provider_configs(obfuscated)
+ def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
+ return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom)
- def get_provider_db_record(self, prefer_custom: bool = False) -> Optional[Provider]:
- return self.provider.get_provider(prefer_custom)
+ def get_provider_db_record(self) -> Optional[Provider]:
+ return self.provider.get_provider()
def config_validate(self, config: Union[dict | str]):
"""
diff --git a/api/core/llm/provider/openai_provider.py b/api/core/llm/provider/openai_provider.py
index 8257ad3aab..b24e98e5d1 100644
--- a/api/core/llm/provider/openai_provider.py
+++ b/api/core/llm/provider/openai_provider.py
@@ -4,6 +4,8 @@ from typing import Optional, Union
import openai
from openai.error import AuthenticationError, OpenAIError
+from core import hosted_llm_credentials
+from core.llm.error import ProviderTokenNotInitError
from core.llm.moderation import Moderation
from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError
@@ -42,3 +44,12 @@ class OpenAIProvider(BaseProvider):
except Exception as ex:
logging.exception('OpenAI config validation failed')
raise ex
+
+ def get_hosted_credentials(self) -> Union[str | dict]:
+ if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
+ raise ProviderTokenNotInitError(
+ f"No valid {self.get_provider_name().value} model provider credentials found. "
+ f"Please go to Settings -> Model Provider to complete your provider credentials."
+ )
+
+ return hosted_llm_credentials.openai.api_key
diff --git a/api/core/llm/streamable_azure_chat_open_ai.py b/api/core/llm/streamable_azure_chat_open_ai.py
index 4d1d5be0b3..5bff42fedb 100644
--- a/api/core/llm/streamable_azure_chat_open_ai.py
+++ b/api/core/llm/streamable_azure_chat_open_ai.py
@@ -1,11 +1,11 @@
-from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks
-from langchain.schema import BaseMessage, ChatResult, LLMResult
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import BaseMessage, LLMResult
from langchain.chat_models import AzureChatOpenAI
from typing import Optional, List, Dict, Any
from pydantic import root_validator
-from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableAzureChatOpenAI(AzureChatOpenAI):
@@ -46,30 +46,7 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
"organization": self.openai_organization if self.openai_organization else None,
}
- def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
- """Get the number of tokens in a list of messages.
-
- Args:
- messages: The messages to count the tokens of.
-
- Returns:
- The number of tokens in the messages.
- """
- tokens_per_message = 5
- tokens_per_request = 3
-
- message_tokens = tokens_per_request
- message_strs = ''
- for message in messages:
- message_strs += message.content
- message_tokens += tokens_per_message
-
- # calc once
- message_tokens += self.get_num_tokens(message_strs)
-
- return message_tokens
-
- @handle_llm_exceptions
+ @handle_openai_exceptions
def generate(
self,
messages: List[List[BaseMessage]],
@@ -79,12 +56,18 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
) -> LLMResult:
return super().generate(messages, stop, callbacks, **kwargs)
- @handle_llm_exceptions_async
- async def agenerate(
- self,
- messages: List[List[BaseMessage]],
- stop: Optional[List[str]] = None,
- callbacks: Callbacks = None,
- **kwargs: Any,
- ) -> LLMResult:
- return await super().agenerate(messages, stop, callbacks, **kwargs)
+ @classmethod
+ def get_kwargs_from_model_params(cls, params: dict):
+ model_kwargs = {
+ 'top_p': params.get('top_p', 1),
+ 'frequency_penalty': params.get('frequency_penalty', 0),
+ 'presence_penalty': params.get('presence_penalty', 0),
+ }
+
+ del params['top_p']
+ del params['frequency_penalty']
+ del params['presence_penalty']
+
+ params['model_kwargs'] = model_kwargs
+
+ return params
diff --git a/api/core/llm/streamable_azure_open_ai.py b/api/core/llm/streamable_azure_open_ai.py
index ac2258bb61..108f723aa4 100644
--- a/api/core/llm/streamable_azure_open_ai.py
+++ b/api/core/llm/streamable_azure_open_ai.py
@@ -5,7 +5,7 @@ from typing import Optional, List, Dict, Mapping, Any
from pydantic import root_validator
-from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableAzureOpenAI(AzureOpenAI):
@@ -50,7 +50,7 @@ class StreamableAzureOpenAI(AzureOpenAI):
"organization": self.openai_organization if self.openai_organization else None,
}}
- @handle_llm_exceptions
+ @handle_openai_exceptions
def generate(
self,
prompts: List[str],
@@ -60,12 +60,6 @@ class StreamableAzureOpenAI(AzureOpenAI):
) -> LLMResult:
return super().generate(prompts, stop, callbacks, **kwargs)
- @handle_llm_exceptions_async
- async def agenerate(
- self,
- prompts: List[str],
- stop: Optional[List[str]] = None,
- callbacks: Callbacks = None,
- **kwargs: Any,
- ) -> LLMResult:
- return await super().agenerate(prompts, stop, callbacks, **kwargs)
+ @classmethod
+ def get_kwargs_from_model_params(cls, params: dict):
+ return params
diff --git a/api/core/llm/streamable_chat_anthropic.py b/api/core/llm/streamable_chat_anthropic.py
new file mode 100644
index 0000000000..de268800ca
--- /dev/null
+++ b/api/core/llm/streamable_chat_anthropic.py
@@ -0,0 +1,39 @@
+from typing import List, Optional, Any, Dict
+
+from langchain.callbacks.manager import Callbacks
+from langchain.chat_models import ChatAnthropic
+from langchain.schema import BaseMessage, LLMResult
+
+from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
+
+
+class StreamableChatAnthropic(ChatAnthropic):
+ """
+ Wrapper around Anthropic's large language model.
+ """
+
+ @handle_anthropic_exceptions
+ def generate(
+ self,
+ messages: List[List[BaseMessage]],
+ stop: Optional[List[str]] = None,
+ callbacks: Callbacks = None,
+ *,
+ tags: Optional[List[str]] = None,
+ metadata: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> LLMResult:
+ return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs)
+
+ @classmethod
+ def get_kwargs_from_model_params(cls, params: dict):
+ params['model'] = params.get('model_name')
+ del params['model_name']
+
+ params['max_tokens_to_sample'] = params.get('max_tokens')
+ del params['max_tokens']
+
+ del params['frequency_penalty']
+ del params['presence_penalty']
+
+ return params
diff --git a/api/core/llm/streamable_chat_open_ai.py b/api/core/llm/streamable_chat_open_ai.py
index a1fad702ab..ba4470b846 100644
--- a/api/core/llm/streamable_chat_open_ai.py
+++ b/api/core/llm/streamable_chat_open_ai.py
@@ -7,7 +7,7 @@ from typing import Optional, List, Dict, Any
from pydantic import root_validator
-from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableChatOpenAI(ChatOpenAI):
@@ -48,30 +48,7 @@ class StreamableChatOpenAI(ChatOpenAI):
"organization": self.openai_organization if self.openai_organization else None,
}
- def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
- """Get the number of tokens in a list of messages.
-
- Args:
- messages: The messages to count the tokens of.
-
- Returns:
- The number of tokens in the messages.
- """
- tokens_per_message = 5
- tokens_per_request = 3
-
- message_tokens = tokens_per_request
- message_strs = ''
- for message in messages:
- message_strs += message.content
- message_tokens += tokens_per_message
-
- # calc once
- message_tokens += self.get_num_tokens(message_strs)
-
- return message_tokens
-
- @handle_llm_exceptions
+ @handle_openai_exceptions
def generate(
self,
messages: List[List[BaseMessage]],
@@ -81,12 +58,18 @@ class StreamableChatOpenAI(ChatOpenAI):
) -> LLMResult:
return super().generate(messages, stop, callbacks, **kwargs)
- @handle_llm_exceptions_async
- async def agenerate(
- self,
- messages: List[List[BaseMessage]],
- stop: Optional[List[str]] = None,
- callbacks: Callbacks = None,
- **kwargs: Any,
- ) -> LLMResult:
- return await super().agenerate(messages, stop, callbacks, **kwargs)
+ @classmethod
+ def get_kwargs_from_model_params(cls, params: dict):
+ model_kwargs = {
+ 'top_p': params.get('top_p', 1),
+ 'frequency_penalty': params.get('frequency_penalty', 0),
+ 'presence_penalty': params.get('presence_penalty', 0),
+ }
+
+ del params['top_p']
+ del params['frequency_penalty']
+ del params['presence_penalty']
+
+ params['model_kwargs'] = model_kwargs
+
+ return params
diff --git a/api/core/llm/streamable_open_ai.py b/api/core/llm/streamable_open_ai.py
index a69e461d0d..37d2e11e7c 100644
--- a/api/core/llm/streamable_open_ai.py
+++ b/api/core/llm/streamable_open_ai.py
@@ -6,7 +6,7 @@ from typing import Optional, List, Dict, Any, Mapping
from langchain import OpenAI
from pydantic import root_validator
-from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableOpenAI(OpenAI):
@@ -49,7 +49,7 @@ class StreamableOpenAI(OpenAI):
"organization": self.openai_organization if self.openai_organization else None,
}}
- @handle_llm_exceptions
+ @handle_openai_exceptions
def generate(
self,
prompts: List[str],
@@ -59,12 +59,6 @@ class StreamableOpenAI(OpenAI):
) -> LLMResult:
return super().generate(prompts, stop, callbacks, **kwargs)
- @handle_llm_exceptions_async
- async def agenerate(
- self,
- prompts: List[str],
- stop: Optional[List[str]] = None,
- callbacks: Callbacks = None,
- **kwargs: Any,
- ) -> LLMResult:
- return await super().agenerate(prompts, stop, callbacks, **kwargs)
+ @classmethod
+ def get_kwargs_from_model_params(cls, params: dict):
+ return params
diff --git a/api/core/llm/whisper.py b/api/core/llm/whisper.py
index 3ad993be17..7f3bf3d794 100644
--- a/api/core/llm/whisper.py
+++ b/api/core/llm/whisper.py
@@ -1,6 +1,7 @@
import openai
+
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
from models.provider import ProviderName
-from core.llm.error_handle_wraps import handle_llm_exceptions
from core.llm.provider.base import BaseProvider
@@ -13,7 +14,7 @@ class Whisper:
self.client = openai.Audio
self.credentials = provider.get_credentials()
- @handle_llm_exceptions
+ @handle_openai_exceptions
def transcribe(self, file):
return self.client.transcribe(
model='whisper-1',
diff --git a/api/core/llm/wrappers/anthropic_wrapper.py b/api/core/llm/wrappers/anthropic_wrapper.py
new file mode 100644
index 0000000000..7fddc277d2
--- /dev/null
+++ b/api/core/llm/wrappers/anthropic_wrapper.py
@@ -0,0 +1,27 @@
+import logging
+from functools import wraps
+
+import anthropic
+
+from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
+ LLMBadRequestError
+
+
+def handle_anthropic_exceptions(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ try:
+ return func(*args, **kwargs)
+ except anthropic.APIConnectionError as e:
+ logging.exception("Failed to connect to Anthropic API.")
+ raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}")
+ except anthropic.RateLimitError:
+ raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
+ except anthropic.AuthenticationError as e:
+ raise LLMAuthorizationError(f"Anthropic: {e.message}")
+ except anthropic.BadRequestError as e:
+ raise LLMBadRequestError(f"Anthropic: {e.message}")
+ except anthropic.APIStatusError as e:
+ raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}")
+
+ return wrapper
diff --git a/api/core/llm/error_handle_wraps.py b/api/core/llm/wrappers/openai_wrapper.py
similarity index 52%
rename from api/core/llm/error_handle_wraps.py
rename to api/core/llm/wrappers/openai_wrapper.py
index 2bddebc26a..7f96e75edf 100644
--- a/api/core/llm/error_handle_wraps.py
+++ b/api/core/llm/wrappers/openai_wrapper.py
@@ -7,7 +7,7 @@ from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRat
LLMBadRequestError
-def handle_llm_exceptions(func):
+def handle_openai_exceptions(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
@@ -29,27 +29,3 @@ def handle_llm_exceptions(func):
raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
return wrapper
-
-
-def handle_llm_exceptions_async(func):
- @wraps(func)
- async def wrapper(*args, **kwargs):
- try:
- return await func(*args, **kwargs)
- except openai.error.InvalidRequestError as e:
- logging.exception("Invalid request to OpenAI API.")
- raise LLMBadRequestError(str(e))
- except openai.error.APIConnectionError as e:
- logging.exception("Failed to connect to OpenAI API.")
- raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
- except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
- logging.exception("OpenAI service unavailable.")
- raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
- except openai.error.RateLimitError as e:
- raise LLMRateLimitError(str(e))
- except openai.error.AuthenticationError as e:
- raise LLMAuthorizationError(str(e))
- except openai.error.OpenAIError as e:
- raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
-
- return wrapper
diff --git a/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py b/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py
index 16f982c592..d96187ece0 100644
--- a/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py
+++ b/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py
@@ -1,7 +1,7 @@
from typing import Any, List, Dict, Union
from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage
+from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
@@ -12,8 +12,8 @@ from models.model import Conversation, Message
class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
conversation: Conversation
human_prefix: str = "Human"
- ai_prefix: str = "AI"
- llm: Union[StreamableChatOpenAI | StreamableOpenAI]
+ ai_prefix: str = "Assistant"
+ llm: BaseLanguageModel
memory_key: str = "chat_history"
max_token_limit: int = 2000
message_limit: int = 10
@@ -38,12 +38,12 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
return chat_messages
# prune the chat message if it exceeds the max token limit
- curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
+ curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
if curr_buffer_length > self.max_token_limit:
pruned_memory = []
while curr_buffer_length > self.max_token_limit and chat_messages:
pruned_memory.append(chat_messages.pop(0))
- curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
+ curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
return chat_messages
diff --git a/api/core/tool/dataset_index_tool.py b/api/core/tool/dataset_index_tool.py
index 2776c6f48a..17b3c148b2 100644
--- a/api/core/tool/dataset_index_tool.py
+++ b/api/core/tool/dataset_index_tool.py
@@ -30,7 +30,7 @@ class DatasetTool(BaseTool):
else:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id,
- model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
+ model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002'
)
@@ -60,7 +60,7 @@ class DatasetTool(BaseTool):
async def _arun(self, tool_input: str) -> str:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id,
- model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
+ model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002'
)
diff --git a/api/events/event_handlers/create_provider_when_tenant_created.py b/api/events/event_handlers/create_provider_when_tenant_created.py
index e967a5d071..0d35670258 100644
--- a/api/events/event_handlers/create_provider_when_tenant_created.py
+++ b/api/events/event_handlers/create_provider_when_tenant_created.py
@@ -1,4 +1,7 @@
+from flask import current_app
+
from events.tenant_event import tenant_was_updated
+from models.provider import ProviderName
from services.provider_service import ProviderService
@@ -6,4 +9,16 @@ from services.provider_service import ProviderService
def handle(sender, **kwargs):
tenant = sender
if tenant.status == 'normal':
- ProviderService.create_system_provider(tenant)
+ ProviderService.create_system_provider(
+ tenant,
+ ProviderName.OPENAI.value,
+ current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
+ True
+ )
+
+ ProviderService.create_system_provider(
+ tenant,
+ ProviderName.ANTHROPIC.value,
+ current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
+ True
+ )
diff --git a/api/events/event_handlers/create_provider_when_tenant_updated.py b/api/events/event_handlers/create_provider_when_tenant_updated.py
index 81a7d40ff6..366e13c599 100644
--- a/api/events/event_handlers/create_provider_when_tenant_updated.py
+++ b/api/events/event_handlers/create_provider_when_tenant_updated.py
@@ -1,4 +1,7 @@
+from flask import current_app
+
from events.tenant_event import tenant_was_created
+from models.provider import ProviderName
from services.provider_service import ProviderService
@@ -6,4 +9,16 @@ from services.provider_service import ProviderService
def handle(sender, **kwargs):
tenant = sender
if tenant.status == 'normal':
- ProviderService.create_system_provider(tenant)
+ ProviderService.create_system_provider(
+ tenant,
+ ProviderName.OPENAI.value,
+ current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
+ True
+ )
+
+ ProviderService.create_system_provider(
+ tenant,
+ ProviderName.ANTHROPIC.value,
+ current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
+ True
+ )
diff --git a/api/requirements.txt b/api/requirements.txt
index 5ffda02ced..97a1c5048d 100644
--- a/api/requirements.txt
+++ b/api/requirements.txt
@@ -10,7 +10,7 @@ flask-session2==1.3.1
flask-cors==3.0.10
gunicorn~=20.1.0
gevent~=22.10.2
-langchain==0.0.209
+langchain==0.0.230
openai~=0.27.5
psycopg2-binary~=2.9.6
pycryptodome==3.17
@@ -35,3 +35,4 @@ docx2txt==0.8
pypdfium2==4.16.0
resend~=0.5.1
pyjwt~=2.6.0
+anthropic~=0.3.4
diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py
index 77293a1a05..ecf3a68269 100644
--- a/api/services/app_model_config_service.py
+++ b/api/services/app_model_config_service.py
@@ -6,6 +6,30 @@ from models.account import Account
from services.dataset_service import DatasetService
from core.llm.llm_builder import LLMBuilder
+MODEL_PROVIDERS = [
+ 'openai',
+ 'anthropic',
+]
+
+MODELS_BY_APP_MODE = {
+ 'chat': [
+ 'claude-instant-1',
+ 'claude-2',
+ 'gpt-4',
+ 'gpt-4-32k',
+ 'gpt-3.5-turbo',
+ 'gpt-3.5-turbo-16k',
+ ],
+ 'completion': [
+ 'claude-instant-1',
+ 'claude-2',
+ 'gpt-4',
+ 'gpt-4-32k',
+ 'gpt-3.5-turbo',
+ 'gpt-3.5-turbo-16k',
+ 'text-davinci-003',
+ ]
+}
class AppModelConfigService:
@staticmethod
@@ -125,7 +149,7 @@ class AppModelConfigService:
if not isinstance(config["speech_to_text"]["enabled"], bool):
raise ValueError("enabled in speech_to_text must be of boolean type")
- provider_name = LLMBuilder.get_default_provider(account.current_tenant_id)
+ provider_name = LLMBuilder.get_default_provider(account.current_tenant_id, 'whisper-1')
if config["speech_to_text"]["enabled"] and provider_name != 'openai':
raise ValueError("provider not support speech to text")
@@ -153,14 +177,14 @@ class AppModelConfigService:
raise ValueError("model must be of object type")
# model.provider
- if 'provider' not in config["model"] or config["model"]["provider"] != "openai":
- raise ValueError("model.provider must be 'openai'")
+ if 'provider' not in config["model"] or config["model"]["provider"] not in MODEL_PROVIDERS:
+ raise ValueError(f"model.provider is required and must be in {str(MODEL_PROVIDERS)}")
# model.name
if 'name' not in config["model"]:
raise ValueError("model.name is required")
- if config["model"]["name"] not in llm_constant.models_by_mode[mode]:
+ if config["model"]["name"] not in MODELS_BY_APP_MODE[mode]:
raise ValueError("model.name must be in the specified model list")
# model.completion_params
diff --git a/api/services/audio_service.py b/api/services/audio_service.py
index 9870702e46..667fb4cb67 100644
--- a/api/services/audio_service.py
+++ b/api/services/audio_service.py
@@ -27,7 +27,7 @@ class AudioService:
message = f"Audio size larger than {FILE_SIZE} mb"
raise AudioTooLargeServiceError(message)
- provider_name = LLMBuilder.get_default_provider(tenant_id)
+ provider_name = LLMBuilder.get_default_provider(tenant_id, 'whisper-1')
if provider_name != ProviderName.OPENAI.value:
raise ProviderNotSupportSpeechToTextServiceError()
@@ -37,8 +37,3 @@ class AudioService:
buffer.name = 'temp.mp3'
return Whisper(provider_service.provider).transcribe(buffer)
-
-
-
-
-
\ No newline at end of file
diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py
index b0029f80ad..17a4a4f4c6 100644
--- a/api/services/hit_testing_service.py
+++ b/api/services/hit_testing_service.py
@@ -31,7 +31,7 @@ class HitTestingService:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id,
- model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
+ model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002'
)
diff --git a/api/services/provider_service.py b/api/services/provider_service.py
index 39ee8353c0..fffd3fbd5b 100644
--- a/api/services/provider_service.py
+++ b/api/services/provider_service.py
@@ -10,50 +10,40 @@ from models.provider import *
class ProviderService:
@staticmethod
- def init_supported_provider(tenant, edition):
+ def init_supported_provider(tenant):
"""Initialize the model provider, check whether the supported provider has a record"""
- providers = Provider.query.filter_by(tenant_id=tenant.id).all()
+ need_init_provider_names = [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value, ProviderName.ANTHROPIC.value]
- openai_provider_exists = False
- azure_openai_provider_exists = False
-
- # TODO: The cloud version needs to construct the data of the SYSTEM type
+ providers = db.session.query(Provider).filter(
+ Provider.tenant_id == tenant.id,
+ Provider.provider_type == ProviderType.CUSTOM.value,
+ Provider.provider_name.in_(need_init_provider_names)
+ ).all()
+ exists_provider_names = []
for provider in providers:
- if provider.provider_name == ProviderName.OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value:
- openai_provider_exists = True
- if provider.provider_name == ProviderName.AZURE_OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value:
- azure_openai_provider_exists = True
+ exists_provider_names.append(provider.provider_name)
- # Initialize the model provider, check whether the supported provider has a record
+ not_exists_provider_names = list(set(need_init_provider_names) - set(exists_provider_names))
- # Create default providers if they don't exist
- if not openai_provider_exists:
- openai_provider = Provider(
- tenant_id=tenant.id,
- provider_name=ProviderName.OPENAI.value,
- provider_type=ProviderType.CUSTOM.value,
- is_valid=False
- )
- db.session.add(openai_provider)
+ if not_exists_provider_names:
+ # Initialize the model provider, check whether the supported provider has a record
+ for provider_name in not_exists_provider_names:
+ provider = Provider(
+ tenant_id=tenant.id,
+ provider_name=provider_name,
+ provider_type=ProviderType.CUSTOM.value,
+ is_valid=False
+ )
+ db.session.add(provider)
- if not azure_openai_provider_exists:
- azure_openai_provider = Provider(
- tenant_id=tenant.id,
- provider_name=ProviderName.AZURE_OPENAI.value,
- provider_type=ProviderType.CUSTOM.value,
- is_valid=False
- )
- db.session.add(azure_openai_provider)
-
- if not openai_provider_exists or not azure_openai_provider_exists:
db.session.commit()
@staticmethod
- def get_obfuscated_api_key(tenant, provider_name: ProviderName):
+ def get_obfuscated_api_key(tenant, provider_name: ProviderName, only_custom: bool = False):
llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
- return llm_provider_service.get_provider_configs(obfuscated=True)
+ return llm_provider_service.get_provider_configs(obfuscated=True, only_custom=only_custom)
@staticmethod
def get_token_type(tenant, provider_name: ProviderName):
@@ -73,7 +63,7 @@ class ProviderService:
return llm_provider_service.get_encrypted_token(configs)
@staticmethod
- def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value,
+ def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value, quota_limit: int = 200,
is_valid: bool = True):
if current_app.config['EDITION'] != 'CLOUD':
return
@@ -90,7 +80,7 @@ class ProviderService:
provider_name=provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=ProviderQuotaType.TRIAL.value,
- quota_limit=200,
+ quota_limit=quota_limit,
encrypted_config='',
is_valid=is_valid,
)
diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py
index 92227ffa7a..abd1f7f3fb 100644
--- a/api/services/workspace_service.py
+++ b/api/services/workspace_service.py
@@ -1,6 +1,6 @@
from extensions.ext_database import db
from models.account import Tenant
-from models.provider import Provider, ProviderType
+from models.provider import Provider, ProviderType, ProviderName
class WorkspaceService:
@@ -33,7 +33,7 @@ class WorkspaceService:
if provider.is_valid and provider.encrypted_config:
custom_provider = provider
elif provider.provider_type == ProviderType.SYSTEM.value:
- if provider.is_valid:
+ if provider.provider_name == ProviderName.OPENAI.value and provider.is_valid:
system_provider = provider
if system_provider and not custom_provider: