mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-07-13 07:11:47 +08:00
Merge branch 'main' into tp
This commit is contained in:
commit
05aec43ee3
@ -36,7 +36,7 @@
|
||||
| 被团队成员标记为高优先级的功能 | 高优先级 |
|
||||
| 在 [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks) 内反馈的常见功能请求 | 中等优先级 |
|
||||
| 非核心功能和小幅改进 | 低优先级 |
|
||||
| 有价值当不紧急 | 未来功能 |
|
||||
| 有价值但不紧急 | 未来功能 |
|
||||
|
||||
### 其他任何事情(例如 bug 报告、性能优化、拼写错误更正):
|
||||
* 立即开始编码。
|
||||
@ -138,7 +138,7 @@ Dify 的后端使用 Python 编写,使用 [Flask](https://flask.palletsproject
|
||||
├── models // 描述数据模型和 API 响应的形状
|
||||
├── public // 如 favicon 等元资源
|
||||
├── service // 定义 API 操作的形状
|
||||
├── test
|
||||
├── test
|
||||
├── types // 函数参数和返回值的描述
|
||||
└── utils // 共享的实用函数
|
||||
```
|
||||
|
@ -65,14 +65,12 @@
|
||||
|
||||
8. Start Dify [web](../web) service.
|
||||
9. Setup your application by visiting `http://localhost:3000`...
|
||||
10. If you need to debug local async processing, please start the worker service.
|
||||
10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
||||
|
||||
```bash
|
||||
poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion
|
||||
```
|
||||
|
||||
The started celery app handles the async tasks, e.g. dataset importing and documents indexing.
|
||||
|
||||
## Testing
|
||||
|
||||
1. Install dependencies for both the backend and the test environment
|
||||
|
130
api/commands.py
130
api/commands.py
@ -28,28 +28,28 @@ from services.account_service import RegisterService, TenantService
|
||||
|
||||
|
||||
@click.command("reset-password", help="Reset the account password.")
|
||||
@click.option("--email", prompt=True, help="The email address of the account whose password you need to reset")
|
||||
@click.option("--new-password", prompt=True, help="the new password.")
|
||||
@click.option("--password-confirm", prompt=True, help="the new password confirm.")
|
||||
@click.option("--email", prompt=True, help="Account email to reset password for")
|
||||
@click.option("--new-password", prompt=True, help="New password")
|
||||
@click.option("--password-confirm", prompt=True, help="Confirm new password")
|
||||
def reset_password(email, new_password, password_confirm):
|
||||
"""
|
||||
Reset password of owner account
|
||||
Only available in SELF_HOSTED mode
|
||||
"""
|
||||
if str(new_password).strip() != str(password_confirm).strip():
|
||||
click.echo(click.style("sorry. The two passwords do not match.", fg="red"))
|
||||
click.echo(click.style("Passwords do not match.", fg="red"))
|
||||
return
|
||||
|
||||
account = db.session.query(Account).filter(Account.email == email).one_or_none()
|
||||
|
||||
if not account:
|
||||
click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
|
||||
click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
valid_password(new_password)
|
||||
except:
|
||||
click.echo(click.style("sorry. The passwords must match {} ".format(password_pattern), fg="red"))
|
||||
click.echo(click.style("Invalid password. Must match {}".format(password_pattern), fg="red"))
|
||||
return
|
||||
|
||||
# generate password salt
|
||||
@ -62,37 +62,37 @@ def reset_password(email, new_password, password_confirm):
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.commit()
|
||||
click.echo(click.style("Congratulations! Password has been reset.", fg="green"))
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
|
||||
|
||||
@click.command("reset-email", help="Reset the account email.")
|
||||
@click.option("--email", prompt=True, help="The old email address of the account whose email you need to reset")
|
||||
@click.option("--new-email", prompt=True, help="the new email.")
|
||||
@click.option("--email-confirm", prompt=True, help="the new email confirm.")
|
||||
@click.option("--email", prompt=True, help="Current account email")
|
||||
@click.option("--new-email", prompt=True, help="New email")
|
||||
@click.option("--email-confirm", prompt=True, help="Confirm new email")
|
||||
def reset_email(email, new_email, email_confirm):
|
||||
"""
|
||||
Replace account email
|
||||
:return:
|
||||
"""
|
||||
if str(new_email).strip() != str(email_confirm).strip():
|
||||
click.echo(click.style("Sorry, new email and confirm email do not match.", fg="red"))
|
||||
click.echo(click.style("New emails do not match.", fg="red"))
|
||||
return
|
||||
|
||||
account = db.session.query(Account).filter(Account.email == email).one_or_none()
|
||||
|
||||
if not account:
|
||||
click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
|
||||
click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
email_validate(new_email)
|
||||
except:
|
||||
click.echo(click.style("sorry. {} is not a valid email. ".format(email), fg="red"))
|
||||
click.echo(click.style("Invalid email: {}".format(new_email), fg="red"))
|
||||
return
|
||||
|
||||
account.email = new_email
|
||||
db.session.commit()
|
||||
click.echo(click.style("Congratulations!, email has been reset.", fg="green"))
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
|
||||
|
||||
@click.command(
|
||||
@ -104,7 +104,7 @@ def reset_email(email, new_email, email_confirm):
|
||||
)
|
||||
@click.confirmation_option(
|
||||
prompt=click.style(
|
||||
"Are you sure you want to reset encrypt key pair? this operation cannot be rolled back!", fg="red"
|
||||
"Are you sure you want to reset encrypt key pair? This operation cannot be rolled back!", fg="red"
|
||||
)
|
||||
)
|
||||
def reset_encrypt_key_pair():
|
||||
@ -114,13 +114,13 @@ def reset_encrypt_key_pair():
|
||||
Only support SELF_HOSTED mode.
|
||||
"""
|
||||
if dify_config.EDITION != "SELF_HOSTED":
|
||||
click.echo(click.style("Sorry, only support SELF_HOSTED mode.", fg="red"))
|
||||
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
|
||||
return
|
||||
|
||||
tenants = db.session.query(Tenant).all()
|
||||
for tenant in tenants:
|
||||
if not tenant:
|
||||
click.echo(click.style("Sorry, no workspace found. Please enter /install to initialize.", fg="red"))
|
||||
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
|
||||
return
|
||||
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
@ -137,7 +137,7 @@ def reset_encrypt_key_pair():
|
||||
)
|
||||
|
||||
|
||||
@click.command("vdb-migrate", help="migrate vector db.")
|
||||
@click.command("vdb-migrate", help="Migrate vector db.")
|
||||
@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
|
||||
def vdb_migrate(scope: str):
|
||||
if scope in {"knowledge", "all"}:
|
||||
@ -150,7 +150,7 @@ def migrate_annotation_vector_database():
|
||||
"""
|
||||
Migrate annotation datas to target vector database .
|
||||
"""
|
||||
click.echo(click.style("Start migrate annotation data.", fg="green"))
|
||||
click.echo(click.style("Starting annotation data migration.", fg="green"))
|
||||
create_count = 0
|
||||
skipped_count = 0
|
||||
total_count = 0
|
||||
@ -174,14 +174,14 @@ def migrate_annotation_vector_database():
|
||||
f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped."
|
||||
)
|
||||
try:
|
||||
click.echo("Create app annotation index: {}".format(app.id))
|
||||
click.echo("Creating app annotation index: {}".format(app.id))
|
||||
app_annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first()
|
||||
)
|
||||
|
||||
if not app_annotation_setting:
|
||||
skipped_count = skipped_count + 1
|
||||
click.echo("App annotation setting is disabled: {}".format(app.id))
|
||||
click.echo("App annotation setting disabled: {}".format(app.id))
|
||||
continue
|
||||
# get dataset_collection_binding info
|
||||
dataset_collection_binding = (
|
||||
@ -190,7 +190,7 @@ def migrate_annotation_vector_database():
|
||||
.first()
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
click.echo("App annotation collection binding is not exist: {}".format(app.id))
|
||||
click.echo("App annotation collection binding not found: {}".format(app.id))
|
||||
continue
|
||||
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
|
||||
dataset = Dataset(
|
||||
@ -211,11 +211,11 @@ def migrate_annotation_vector_database():
|
||||
documents.append(document)
|
||||
|
||||
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
|
||||
click.echo(f"Start to migrate annotation, app_id: {app.id}.")
|
||||
click.echo(f"Migrating annotations for app: {app.id}.")
|
||||
|
||||
try:
|
||||
vector.delete()
|
||||
click.echo(click.style(f"Successfully delete vector index for app: {app.id}.", fg="green"))
|
||||
click.echo(click.style(f"Deleted vector index for app {app.id}.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red"))
|
||||
raise e
|
||||
@ -223,12 +223,12 @@ def migrate_annotation_vector_database():
|
||||
try:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Start to created vector index with {len(documents)} annotations for app {app.id}.",
|
||||
f"Creating vector index with {len(documents)} annotations for app {app.id}.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
vector.create(documents)
|
||||
click.echo(click.style(f"Successfully created vector index for app {app.id}.", fg="green"))
|
||||
click.echo(click.style(f"Created vector index for app {app.id}.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red"))
|
||||
raise e
|
||||
@ -237,14 +237,14 @@ def migrate_annotation_vector_database():
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(
|
||||
"Create app annotation index error: {} {}".format(e.__class__.__name__, str(e)), fg="red"
|
||||
"Error creating app annotation index: {} {}".format(e.__class__.__name__, str(e)), fg="red"
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.",
|
||||
f"Migration complete. Created {create_count} app annotation indexes. Skipped {skipped_count} apps.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
@ -254,7 +254,7 @@ def migrate_knowledge_vector_database():
|
||||
"""
|
||||
Migrate vector database datas to target vector database .
|
||||
"""
|
||||
click.echo(click.style("Start migrate vector db.", fg="green"))
|
||||
click.echo(click.style("Starting vector database migration.", fg="green"))
|
||||
create_count = 0
|
||||
skipped_count = 0
|
||||
total_count = 0
|
||||
@ -278,7 +278,7 @@ def migrate_knowledge_vector_database():
|
||||
f"Processing the {total_count} dataset {dataset.id}. {create_count} created, {skipped_count} skipped."
|
||||
)
|
||||
try:
|
||||
click.echo("Create dataset vdb index: {}".format(dataset.id))
|
||||
click.echo("Creating dataset vector database index: {}".format(dataset.id))
|
||||
if dataset.index_struct_dict:
|
||||
if dataset.index_struct_dict["type"] == vector_type:
|
||||
skipped_count = skipped_count + 1
|
||||
@ -299,7 +299,7 @@ def migrate_knowledge_vector_database():
|
||||
if dataset_collection_binding:
|
||||
collection_name = dataset_collection_binding.collection_name
|
||||
else:
|
||||
raise ValueError("Dataset Collection Bindings is not exist!")
|
||||
raise ValueError("Dataset Collection Binding not found")
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
@ -351,14 +351,12 @@ def migrate_knowledge_vector_database():
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
vector = Vector(dataset)
|
||||
click.echo(f"Start to migrate dataset {dataset.id}.")
|
||||
click.echo(f"Migrating dataset {dataset.id}.")
|
||||
|
||||
try:
|
||||
vector.delete()
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Successfully delete vector index {collection_name} for dataset {dataset.id}.", fg="green"
|
||||
)
|
||||
click.style(f"Deleted vector index {collection_name} for dataset {dataset.id}.", fg="green")
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
@ -410,15 +408,13 @@ def migrate_knowledge_vector_database():
|
||||
try:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Start to created vector index with {len(documents)} documents of {segments_count}"
|
||||
f"Creating vector index with {len(documents)} documents of {segments_count}"
|
||||
f" segments for dataset {dataset.id}.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
vector.create(documents)
|
||||
click.echo(
|
||||
click.style(f"Successfully created vector index for dataset {dataset.id}.", fg="green")
|
||||
)
|
||||
click.echo(click.style(f"Created vector index for dataset {dataset.id}.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red"))
|
||||
raise e
|
||||
@ -429,13 +425,13 @@ def migrate_knowledge_vector_database():
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
click.echo(
|
||||
click.style("Create dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red")
|
||||
click.style("Error creating dataset index: {} {}".format(e.__class__.__name__, str(e)), fg="red")
|
||||
)
|
||||
continue
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.", fg="green"
|
||||
f"Migration complete. Created {create_count} dataset indexes. Skipped {skipped_count} datasets.", fg="green"
|
||||
)
|
||||
)
|
||||
|
||||
@ -445,7 +441,7 @@ def convert_to_agent_apps():
|
||||
"""
|
||||
Convert Agent Assistant to Agent App.
|
||||
"""
|
||||
click.echo(click.style("Start convert to agent apps.", fg="green"))
|
||||
click.echo(click.style("Starting convert to agent apps.", fg="green"))
|
||||
|
||||
proceeded_app_ids = []
|
||||
|
||||
@ -496,23 +492,23 @@ def convert_to_agent_apps():
|
||||
except Exception as e:
|
||||
click.echo(click.style("Convert app error: {} {}".format(e.__class__.__name__, str(e)), fg="red"))
|
||||
|
||||
click.echo(click.style("Congratulations! Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green"))
|
||||
click.echo(click.style("Conversion complete. Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green"))
|
||||
|
||||
|
||||
@click.command("add-qdrant-doc-id-index", help="add qdrant doc_id index.")
|
||||
@click.option("--field", default="metadata.doc_id", prompt=False, help="index field , default is metadata.doc_id.")
|
||||
@click.command("add-qdrant-doc-id-index", help="Add Qdrant doc_id index.")
|
||||
@click.option("--field", default="metadata.doc_id", prompt=False, help="Index field , default is metadata.doc_id.")
|
||||
def add_qdrant_doc_id_index(field: str):
|
||||
click.echo(click.style("Start add qdrant doc_id index.", fg="green"))
|
||||
click.echo(click.style("Starting Qdrant doc_id index creation.", fg="green"))
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
if vector_type != "qdrant":
|
||||
click.echo(click.style("Sorry, only support qdrant vector store.", fg="red"))
|
||||
click.echo(click.style("This command only supports Qdrant vector store.", fg="red"))
|
||||
return
|
||||
create_count = 0
|
||||
|
||||
try:
|
||||
bindings = db.session.query(DatasetCollectionBinding).all()
|
||||
if not bindings:
|
||||
click.echo(click.style("Sorry, no dataset collection bindings found.", fg="red"))
|
||||
click.echo(click.style("No dataset collection bindings found.", fg="red"))
|
||||
return
|
||||
import qdrant_client
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
@ -522,7 +518,7 @@ def add_qdrant_doc_id_index(field: str):
|
||||
|
||||
for binding in bindings:
|
||||
if dify_config.QDRANT_URL is None:
|
||||
raise ValueError("Qdrant url is required.")
|
||||
raise ValueError("Qdrant URL is required.")
|
||||
qdrant_config = QdrantConfig(
|
||||
endpoint=dify_config.QDRANT_URL,
|
||||
api_key=dify_config.QDRANT_API_KEY,
|
||||
@ -539,41 +535,39 @@ def add_qdrant_doc_id_index(field: str):
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
click.echo(
|
||||
click.style(f"Collection not found, collection_name:{binding.collection_name}.", fg="red")
|
||||
)
|
||||
click.echo(click.style(f"Collection not found: {binding.collection_name}.", fg="red"))
|
||||
continue
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Failed to create qdrant index, collection_name:{binding.collection_name}.", fg="red"
|
||||
f"Failed to create Qdrant index for collection: {binding.collection_name}.", fg="red"
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
click.echo(click.style("Failed to create qdrant client.", fg="red"))
|
||||
click.echo(click.style("Failed to create Qdrant client.", fg="red"))
|
||||
|
||||
click.echo(click.style(f"Congratulations! Create {create_count} collection indexes.", fg="green"))
|
||||
click.echo(click.style(f"Index creation complete. Created {create_count} collection indexes.", fg="green"))
|
||||
|
||||
|
||||
@click.command("create-tenant", help="Create account and tenant.")
|
||||
@click.option("--email", prompt=True, help="The email address of the tenant account.")
|
||||
@click.option("--name", prompt=True, help="The workspace name of the tenant account.")
|
||||
@click.option("--email", prompt=True, help="Tenant account email.")
|
||||
@click.option("--name", prompt=True, help="Workspace name.")
|
||||
@click.option("--language", prompt=True, help="Account language, default: en-US.")
|
||||
def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None):
|
||||
"""
|
||||
Create tenant account
|
||||
"""
|
||||
if not email:
|
||||
click.echo(click.style("Sorry, email is required.", fg="red"))
|
||||
click.echo(click.style("Email is required.", fg="red"))
|
||||
return
|
||||
|
||||
# Create account
|
||||
email = email.strip()
|
||||
|
||||
if "@" not in email:
|
||||
click.echo(click.style("Sorry, invalid email address.", fg="red"))
|
||||
click.echo(click.style("Invalid email address.", fg="red"))
|
||||
return
|
||||
|
||||
account_name = email.split("@")[0]
|
||||
@ -593,19 +587,19 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
"Congratulations! Account and tenant created.\nAccount: {}\nPassword: {}".format(email, new_password),
|
||||
"Account and tenant created.\nAccount: {}\nPassword: {}".format(email, new_password),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.command("upgrade-db", help="upgrade the database")
|
||||
@click.command("upgrade-db", help="Upgrade the database")
|
||||
def upgrade_db():
|
||||
click.echo("Preparing database migration...")
|
||||
lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
|
||||
if lock.acquire(blocking=False):
|
||||
try:
|
||||
click.echo(click.style("Start database migration.", fg="green"))
|
||||
click.echo(click.style("Starting database migration.", fg="green"))
|
||||
|
||||
# run db migration
|
||||
import flask_migrate
|
||||
@ -615,7 +609,7 @@ def upgrade_db():
|
||||
click.echo(click.style("Database migration successful!", fg="green"))
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(f"Database migration failed, error: {e}")
|
||||
logging.exception(f"Database migration failed: {e}")
|
||||
finally:
|
||||
lock.release()
|
||||
else:
|
||||
@ -627,7 +621,7 @@ def fix_app_site_missing():
|
||||
"""
|
||||
Fix app related site missing issue.
|
||||
"""
|
||||
click.echo(click.style("Start fix app related site missing issue.", fg="green"))
|
||||
click.echo(click.style("Starting fix for missing app-related sites.", fg="green"))
|
||||
|
||||
failed_app_ids = []
|
||||
while True:
|
||||
@ -650,22 +644,22 @@ where sites.id is null limit 1000"""
|
||||
if tenant:
|
||||
accounts = tenant.get_accounts()
|
||||
if not accounts:
|
||||
print("Fix app {} failed.".format(app.id))
|
||||
print("Fix failed for app {}".format(app.id))
|
||||
continue
|
||||
|
||||
account = accounts[0]
|
||||
print("Fix app {} related site missing issue.".format(app.id))
|
||||
print("Fixing missing site for app {}".format(app.id))
|
||||
app_was_created.send(app, account=account)
|
||||
except Exception as e:
|
||||
failed_app_ids.append(app_id)
|
||||
click.echo(click.style("Fix app {} related site missing issue failed!".format(app_id), fg="red"))
|
||||
click.echo(click.style("FFailed to fix missing site for app {}".format(app_id), fg="red"))
|
||||
logging.exception(f"Fix app related site missing issue failed, error: {e}")
|
||||
continue
|
||||
|
||||
if not processed_count:
|
||||
break
|
||||
|
||||
click.echo(click.style("Congratulations! Fix app related site missing issue successful!", fg="green"))
|
||||
click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))
|
||||
|
||||
|
||||
def register_commands(app):
|
||||
|
@ -4,30 +4,30 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class DeploymentConfig(BaseSettings):
|
||||
"""
|
||||
Deployment configs
|
||||
Configuration settings for application deployment
|
||||
"""
|
||||
|
||||
APPLICATION_NAME: str = Field(
|
||||
description="application name",
|
||||
description="Name of the application, used for identification and logging purposes",
|
||||
default="langgenius/dify",
|
||||
)
|
||||
|
||||
DEBUG: bool = Field(
|
||||
description="whether to enable debug mode.",
|
||||
description="Enable debug mode for additional logging and development features",
|
||||
default=False,
|
||||
)
|
||||
|
||||
TESTING: bool = Field(
|
||||
description="",
|
||||
description="Enable testing mode for running automated tests",
|
||||
default=False,
|
||||
)
|
||||
|
||||
EDITION: str = Field(
|
||||
description="deployment edition",
|
||||
description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')",
|
||||
default="SELF_HOSTED",
|
||||
)
|
||||
|
||||
DEPLOY_ENV: str = Field(
|
||||
description="deployment environment, default to PRODUCTION.",
|
||||
description="Deployment environment (e.g., 'PRODUCTION', 'DEVELOPMENT'), default to PRODUCTION",
|
||||
default="PRODUCTION",
|
||||
)
|
||||
|
@ -4,17 +4,17 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class EnterpriseFeatureConfig(BaseSettings):
|
||||
"""
|
||||
Enterprise feature configs.
|
||||
Configuration for enterprise-level features.
|
||||
**Before using, please contact business@dify.ai by email to inquire about licensing matters.**
|
||||
"""
|
||||
|
||||
ENTERPRISE_ENABLED: bool = Field(
|
||||
description="whether to enable enterprise features."
|
||||
description="Enable or disable enterprise-level features."
|
||||
"Before using, please contact business@dify.ai by email to inquire about licensing matters.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
CAN_REPLACE_LOGO: bool = Field(
|
||||
description="whether to allow replacing enterprise logo.",
|
||||
description="Allow customization of the enterprise logo.",
|
||||
default=False,
|
||||
)
|
||||
|
@ -6,30 +6,31 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class NotionConfig(BaseSettings):
|
||||
"""
|
||||
Notion integration configs
|
||||
Configuration settings for Notion integration
|
||||
"""
|
||||
|
||||
NOTION_CLIENT_ID: Optional[str] = Field(
|
||||
description="Notion client ID",
|
||||
description="Client ID for Notion API authentication. Required for OAuth 2.0 flow.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
NOTION_CLIENT_SECRET: Optional[str] = Field(
|
||||
description="Notion client secret key",
|
||||
description="Client secret for Notion API authentication. Required for OAuth 2.0 flow.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
NOTION_INTEGRATION_TYPE: Optional[str] = Field(
|
||||
description="Notion integration type, default to None, available values: internal.",
|
||||
description="Type of Notion integration."
|
||||
" Set to 'internal' for internal integrations, or None for public integrations.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
NOTION_INTERNAL_SECRET: Optional[str] = Field(
|
||||
description="Notion internal secret key",
|
||||
description="Secret key for internal Notion integrations. Required when NOTION_INTEGRATION_TYPE is 'internal'.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
NOTION_INTEGRATION_TOKEN: Optional[str] = Field(
|
||||
description="Notion integration token",
|
||||
description="Integration token for Notion API access. Used for direct API calls without OAuth flow.",
|
||||
default=None,
|
||||
)
|
||||
|
@ -6,20 +6,23 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class SentryConfig(BaseSettings):
|
||||
"""
|
||||
Sentry configs
|
||||
Configuration settings for Sentry error tracking and performance monitoring
|
||||
"""
|
||||
|
||||
SENTRY_DSN: Optional[str] = Field(
|
||||
description="Sentry DSN",
|
||||
description="Sentry Data Source Name (DSN)."
|
||||
" This is the unique identifier of your Sentry project, used to send events to the correct project.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
SENTRY_TRACES_SAMPLE_RATE: NonNegativeFloat = Field(
|
||||
description="Sentry trace sample rate",
|
||||
description="Sample rate for Sentry performance monitoring traces."
|
||||
" Value between 0.0 and 1.0, where 1.0 means 100% of traces are sent to Sentry.",
|
||||
default=1.0,
|
||||
)
|
||||
|
||||
SENTRY_PROFILES_SAMPLE_RATE: NonNegativeFloat = Field(
|
||||
description="Sentry profiles sample rate",
|
||||
description="Sample rate for Sentry profiling."
|
||||
" Value between 0.0 and 1.0, where 1.0 means 100% of profiles are sent to Sentry.",
|
||||
default=1.0,
|
||||
)
|
||||
|
@ -8,145 +8,143 @@ from configs.feature.hosted_service import HostedServiceConfig
|
||||
|
||||
class SecurityConfig(BaseSettings):
|
||||
"""
|
||||
Secret Key configs
|
||||
Security-related configurations for the application
|
||||
"""
|
||||
|
||||
SECRET_KEY: Optional[str] = Field(
|
||||
description="Your App secret key will be used for securely signing the session cookie"
|
||||
description="Secret key for secure session cookie signing."
|
||||
"Make sure you are changing this key for your deployment with a strong key."
|
||||
"You can generate a strong key using `openssl rand -base64 42`."
|
||||
"Alternatively you can set it with `SECRET_KEY` environment variable.",
|
||||
"Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
|
||||
description="Expiry time in hours for reset token",
|
||||
description="Duration in hours for which a password reset token remains valid",
|
||||
default=24,
|
||||
)
|
||||
|
||||
|
||||
class AppExecutionConfig(BaseSettings):
|
||||
"""
|
||||
App Execution configs
|
||||
Configuration parameters for application execution
|
||||
"""
|
||||
|
||||
APP_MAX_EXECUTION_TIME: PositiveInt = Field(
|
||||
description="execution timeout in seconds for app execution",
|
||||
description="Maximum allowed execution time for the application in seconds",
|
||||
default=1200,
|
||||
)
|
||||
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
|
||||
description="max active request per app, 0 means unlimited",
|
||||
description="Maximum number of concurrent active requests per app (0 for unlimited)",
|
||||
default=0,
|
||||
)
|
||||
|
||||
|
||||
class CodeExecutionSandboxConfig(BaseSettings):
|
||||
"""
|
||||
Code Execution Sandbox configs
|
||||
Configuration for the code execution sandbox environment
|
||||
"""
|
||||
|
||||
CODE_EXECUTION_ENDPOINT: HttpUrl = Field(
|
||||
description="endpoint URL of code execution service",
|
||||
description="URL endpoint for the code execution service",
|
||||
default="http://sandbox:8194",
|
||||
)
|
||||
|
||||
CODE_EXECUTION_API_KEY: str = Field(
|
||||
description="API key for code execution service",
|
||||
description="API key for accessing the code execution service",
|
||||
default="dify-sandbox",
|
||||
)
|
||||
|
||||
CODE_EXECUTION_CONNECT_TIMEOUT: Optional[float] = Field(
|
||||
description="connect timeout in seconds for code execution request",
|
||||
description="Connection timeout in seconds for code execution requests",
|
||||
default=10.0,
|
||||
)
|
||||
|
||||
CODE_EXECUTION_READ_TIMEOUT: Optional[float] = Field(
|
||||
description="read timeout in seconds for code execution request",
|
||||
description="Read timeout in seconds for code execution requests",
|
||||
default=60.0,
|
||||
)
|
||||
|
||||
CODE_EXECUTION_WRITE_TIMEOUT: Optional[float] = Field(
|
||||
description="write timeout in seconds for code execution request",
|
||||
description="Write timeout in seconds for code execution request",
|
||||
default=10.0,
|
||||
)
|
||||
|
||||
CODE_MAX_NUMBER: PositiveInt = Field(
|
||||
description="max depth for code execution",
|
||||
description="Maximum allowed numeric value in code execution",
|
||||
default=9223372036854775807,
|
||||
)
|
||||
|
||||
CODE_MIN_NUMBER: NegativeInt = Field(
|
||||
description="",
|
||||
description="Minimum allowed numeric value in code execution",
|
||||
default=-9223372036854775807,
|
||||
)
|
||||
|
||||
CODE_MAX_DEPTH: PositiveInt = Field(
|
||||
description="max depth for code execution",
|
||||
description="Maximum allowed depth for nested structures in code execution",
|
||||
default=5,
|
||||
)
|
||||
|
||||
CODE_MAX_PRECISION: PositiveInt = Field(
|
||||
description="max precision digits for float type in code execution",
|
||||
description="mMaximum number of decimal places for floating-point numbers in code execution",
|
||||
default=20,
|
||||
)
|
||||
|
||||
CODE_MAX_STRING_LENGTH: PositiveInt = Field(
|
||||
description="max string length for code execution",
|
||||
description="Maximum allowed length for strings in code execution",
|
||||
default=80000,
|
||||
)
|
||||
|
||||
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
|
||||
description="",
|
||||
description="Maximum allowed length for string arrays in code execution",
|
||||
default=30,
|
||||
)
|
||||
|
||||
CODE_MAX_OBJECT_ARRAY_LENGTH: PositiveInt = Field(
|
||||
description="",
|
||||
description="Maximum allowed length for object arrays in code execution",
|
||||
default=30,
|
||||
)
|
||||
|
||||
CODE_MAX_NUMBER_ARRAY_LENGTH: PositiveInt = Field(
|
||||
description="",
|
||||
description="Maximum allowed length for numeric arrays in code execution",
|
||||
default=1000,
|
||||
)
|
||||
|
||||
|
||||
class EndpointConfig(BaseSettings):
|
||||
"""
|
||||
Module URL configs
|
||||
Configuration for various application endpoints and URLs
|
||||
"""
|
||||
|
||||
CONSOLE_API_URL: str = Field(
|
||||
description="The backend URL prefix of the console API."
|
||||
"used to concatenate the login authorization callback or notion integration callback.",
|
||||
description="Base URL for the console API,"
|
||||
"used for login authentication callback or notion integration callbacks",
|
||||
default="",
|
||||
)
|
||||
|
||||
CONSOLE_WEB_URL: str = Field(
|
||||
description="The front-end URL prefix of the console web."
|
||||
"used to concatenate some front-end addresses and for CORS configuration use.",
|
||||
description="Base URL for the console web interface," "used for frontend references and CORS configuration",
|
||||
default="",
|
||||
)
|
||||
|
||||
SERVICE_API_URL: str = Field(
|
||||
description="Service API Url prefix. used to display Service API Base Url to the front-end.",
|
||||
description="Base URL for the service API, displayed to users for API access",
|
||||
default="",
|
||||
)
|
||||
|
||||
APP_WEB_URL: str = Field(
|
||||
description="WebApp Url prefix. used to display WebAPP API Base Url to the front-end.",
|
||||
description="Base URL for the web application, used for frontend references",
|
||||
default="",
|
||||
)
|
||||
|
||||
|
||||
class FileAccessConfig(BaseSettings):
|
||||
"""
|
||||
File Access configs
|
||||
Configuration for file access and handling
|
||||
"""
|
||||
|
||||
FILES_URL: str = Field(
|
||||
description="File preview or download Url prefix."
|
||||
" used to display File preview or download Url to the front-end or as Multi-model inputs;"
|
||||
description="Base URL for file preview or download,"
|
||||
" used for frontend display and multi-model inputs"
|
||||
"Url is signed and has expiration time.",
|
||||
validation_alias=AliasChoices("FILES_URL", "CONSOLE_API_URL"),
|
||||
alias_priority=1,
|
||||
@ -154,49 +152,49 @@ class FileAccessConfig(BaseSettings):
|
||||
)
|
||||
|
||||
FILES_ACCESS_TIMEOUT: int = Field(
|
||||
description="timeout in seconds for file accessing",
|
||||
description="Expiration time in seconds for file access URLs",
|
||||
default=300,
|
||||
)
|
||||
|
||||
|
||||
class FileUploadConfig(BaseSettings):
|
||||
"""
|
||||
File Uploading configs
|
||||
Configuration for file upload limitations
|
||||
"""
|
||||
|
||||
UPLOAD_FILE_SIZE_LIMIT: NonNegativeInt = Field(
|
||||
description="size limit in Megabytes for uploading files",
|
||||
description="Maximum allowed file size for uploads in megabytes",
|
||||
default=15,
|
||||
)
|
||||
|
||||
UPLOAD_FILE_BATCH_LIMIT: NonNegativeInt = Field(
|
||||
description="batch size limit for uploading files",
|
||||
description="Maximum number of files allowed in a single upload batch",
|
||||
default=5,
|
||||
)
|
||||
|
||||
UPLOAD_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
|
||||
description="image file size limit in Megabytes for uploading files",
|
||||
description="Maximum allowed image file size for uploads in megabytes",
|
||||
default=10,
|
||||
)
|
||||
|
||||
BATCH_UPLOAD_LIMIT: NonNegativeInt = Field(
|
||||
description="", # todo: to be clarified
|
||||
description="Maximum number of files allowed in a batch upload operation",
|
||||
default=20,
|
||||
)
|
||||
|
||||
|
||||
class HttpConfig(BaseSettings):
|
||||
"""
|
||||
HTTP configs
|
||||
HTTP-related configurations for the application
|
||||
"""
|
||||
|
||||
API_COMPRESSION_ENABLED: bool = Field(
|
||||
description="whether to enable HTTP response compression of gzip",
|
||||
description="Enable or disable gzip compression for HTTP responses",
|
||||
default=False,
|
||||
)
|
||||
|
||||
inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field(
|
||||
description="",
|
||||
description="Comma-separated list of allowed origins for CORS in the console",
|
||||
validation_alias=AliasChoices("CONSOLE_CORS_ALLOW_ORIGINS", "CONSOLE_WEB_URL"),
|
||||
default="",
|
||||
)
|
||||
@ -218,359 +216,360 @@ class HttpConfig(BaseSettings):
|
||||
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
|
||||
|
||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[
|
||||
PositiveInt, Field(ge=10, description="connect timeout in seconds for HTTP request")
|
||||
PositiveInt, Field(ge=10, description="Maximum connection timeout in seconds for HTTP requests")
|
||||
] = 10
|
||||
|
||||
HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[
|
||||
PositiveInt, Field(ge=60, description="read timeout in seconds for HTTP request")
|
||||
PositiveInt, Field(ge=60, description="Maximum read timeout in seconds for HTTP requests")
|
||||
] = 60
|
||||
|
||||
HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[
|
||||
PositiveInt, Field(ge=10, description="read timeout in seconds for HTTP request")
|
||||
PositiveInt, Field(ge=10, description="Maximum write timeout in seconds for HTTP requests")
|
||||
] = 20
|
||||
|
||||
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
|
||||
description="",
|
||||
description="Maximum allowed size in bytes for binary data in HTTP requests",
|
||||
default=10 * 1024 * 1024,
|
||||
)
|
||||
|
||||
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: PositiveInt = Field(
|
||||
description="",
|
||||
description="Maximum allowed size in bytes for text data in HTTP requests",
|
||||
default=1 * 1024 * 1024,
|
||||
)
|
||||
|
||||
SSRF_PROXY_HTTP_URL: Optional[str] = Field(
|
||||
description="HTTP URL for SSRF proxy",
|
||||
description="Proxy URL for HTTP requests to prevent Server-Side Request Forgery (SSRF)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
SSRF_PROXY_HTTPS_URL: Optional[str] = Field(
|
||||
description="HTTPS URL for SSRF proxy",
|
||||
description="Proxy URL for HTTPS requests to prevent Server-Side Request Forgery (SSRF)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class InnerAPIConfig(BaseSettings):
|
||||
"""
|
||||
Inner API configs
|
||||
Configuration for internal API functionality
|
||||
"""
|
||||
|
||||
INNER_API: bool = Field(
|
||||
description="whether to enable the inner API",
|
||||
description="Enable or disable the internal API",
|
||||
default=False,
|
||||
)
|
||||
|
||||
INNER_API_KEY: Optional[str] = Field(
|
||||
description="The inner API key is used to authenticate the inner API",
|
||||
description="API key for accessing the internal API",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class LoggingConfig(BaseSettings):
|
||||
"""
|
||||
Logging configs
|
||||
Configuration for application logging
|
||||
"""
|
||||
|
||||
LOG_LEVEL: str = Field(
|
||||
description="Log output level, default to INFO. It is recommended to set it to ERROR for production.",
|
||||
description="Logging level, default to INFO. Set to ERROR for production environments.",
|
||||
default="INFO",
|
||||
)
|
||||
|
||||
LOG_FILE: Optional[str] = Field(
|
||||
description="logging output file path",
|
||||
description="File path for log output.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
LOG_FORMAT: str = Field(
|
||||
description="log format",
|
||||
description="Format string for log messages",
|
||||
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
|
||||
)
|
||||
|
||||
LOG_DATEFORMAT: Optional[str] = Field(
|
||||
description="log date format",
|
||||
description="Date format string for log timestamps",
|
||||
default=None,
|
||||
)
|
||||
|
||||
LOG_TZ: Optional[str] = Field(
|
||||
description="specify log timezone, eg: America/New_York",
|
||||
description="Timezone for log timestamps (e.g., 'America/New_York')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class ModelLoadBalanceConfig(BaseSettings):
|
||||
"""
|
||||
Model load balance configs
|
||||
Configuration for model load balancing
|
||||
"""
|
||||
|
||||
MODEL_LB_ENABLED: bool = Field(
|
||||
description="whether to enable model load balancing",
|
||||
description="Enable or disable load balancing for models",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class BillingConfig(BaseSettings):
|
||||
"""
|
||||
Platform Billing Configurations
|
||||
Configuration for platform billing features
|
||||
"""
|
||||
|
||||
BILLING_ENABLED: bool = Field(
|
||||
description="whether to enable billing",
|
||||
description="Enable or disable billing functionality",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class UpdateConfig(BaseSettings):
|
||||
"""
|
||||
Update configs
|
||||
Configuration for application update checks
|
||||
"""
|
||||
|
||||
CHECK_UPDATE_URL: str = Field(
|
||||
description="url for checking updates",
|
||||
description="URL to check for application updates",
|
||||
default="https://updates.dify.ai",
|
||||
)
|
||||
|
||||
|
||||
class WorkflowConfig(BaseSettings):
|
||||
"""
|
||||
Workflow feature configs
|
||||
Configuration for workflow execution
|
||||
"""
|
||||
|
||||
WORKFLOW_MAX_EXECUTION_STEPS: PositiveInt = Field(
|
||||
description="max execution steps in single workflow execution",
|
||||
description="Maximum number of steps allowed in a single workflow execution",
|
||||
default=500,
|
||||
)
|
||||
|
||||
WORKFLOW_MAX_EXECUTION_TIME: PositiveInt = Field(
|
||||
description="max execution time in seconds in single workflow execution",
|
||||
description="Maximum execution time in seconds for a single workflow",
|
||||
default=1200,
|
||||
)
|
||||
|
||||
WORKFLOW_CALL_MAX_DEPTH: PositiveInt = Field(
|
||||
description="max depth of calling in single workflow execution",
|
||||
description="Maximum allowed depth for nested workflow calls",
|
||||
default=5,
|
||||
)
|
||||
|
||||
MAX_VARIABLE_SIZE: PositiveInt = Field(
|
||||
description="The maximum size in bytes of a variable. default to 5KB.",
|
||||
description="Maximum size in bytes for a single variable in workflows. Default to 5KB.",
|
||||
default=5 * 1024,
|
||||
)
|
||||
|
||||
|
||||
class OAuthConfig(BaseSettings):
|
||||
"""
|
||||
oauth configs
|
||||
Configuration for OAuth authentication
|
||||
"""
|
||||
|
||||
OAUTH_REDIRECT_PATH: str = Field(
|
||||
description="redirect path for OAuth",
|
||||
description="Redirect path for OAuth authentication callbacks",
|
||||
default="/console/api/oauth/authorize",
|
||||
)
|
||||
|
||||
GITHUB_CLIENT_ID: Optional[str] = Field(
|
||||
description="GitHub client id for OAuth",
|
||||
description="GitHub OAuth client secret",
|
||||
default=None,
|
||||
)
|
||||
|
||||
GITHUB_CLIENT_SECRET: Optional[str] = Field(
|
||||
description="GitHub client secret key for OAuth",
|
||||
description="GitHub OAuth client secret",
|
||||
default=None,
|
||||
)
|
||||
|
||||
GOOGLE_CLIENT_ID: Optional[str] = Field(
|
||||
description="Google client id for OAuth",
|
||||
description="Google OAuth client ID",
|
||||
default=None,
|
||||
)
|
||||
|
||||
GOOGLE_CLIENT_SECRET: Optional[str] = Field(
|
||||
description="Google client secret key for OAuth",
|
||||
description="Google OAuth client secret",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class ModerationConfig(BaseSettings):
|
||||
"""
|
||||
Moderation in app configs.
|
||||
Configuration for content moderation
|
||||
"""
|
||||
|
||||
MODERATION_BUFFER_SIZE: PositiveInt = Field(
|
||||
description="buffer size for moderation",
|
||||
description="Size of the buffer for content moderation processing",
|
||||
default=300,
|
||||
)
|
||||
|
||||
|
||||
class ToolConfig(BaseSettings):
|
||||
"""
|
||||
Tool configs
|
||||
Configuration for tool management
|
||||
"""
|
||||
|
||||
TOOL_ICON_CACHE_MAX_AGE: PositiveInt = Field(
|
||||
description="max age in seconds for tool icon caching",
|
||||
description="Maximum age in seconds for caching tool icons",
|
||||
default=3600,
|
||||
)
|
||||
|
||||
|
||||
class MailConfig(BaseSettings):
|
||||
"""
|
||||
Mail Configurations
|
||||
Configuration for email services
|
||||
"""
|
||||
|
||||
MAIL_TYPE: Optional[str] = Field(
|
||||
description="Mail provider type name, default to None, available values are `smtp` and `resend`.",
|
||||
description="Email service provider type ('smtp' or 'resend'), default to None.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
MAIL_DEFAULT_SEND_FROM: Optional[str] = Field(
|
||||
description="default email address for sending from ",
|
||||
description="Default email address to use as the sender",
|
||||
default=None,
|
||||
)
|
||||
|
||||
RESEND_API_KEY: Optional[str] = Field(
|
||||
description="API key for Resend",
|
||||
description="API key for Resend email service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
RESEND_API_URL: Optional[str] = Field(
|
||||
description="API URL for Resend",
|
||||
description="API URL for Resend email service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
SMTP_SERVER: Optional[str] = Field(
|
||||
description="smtp server host",
|
||||
description="SMTP server hostname",
|
||||
default=None,
|
||||
)
|
||||
|
||||
SMTP_PORT: Optional[int] = Field(
|
||||
description="smtp server port",
|
||||
description="SMTP server port number",
|
||||
default=465,
|
||||
)
|
||||
|
||||
SMTP_USERNAME: Optional[str] = Field(
|
||||
description="smtp server username",
|
||||
description="Username for SMTP authentication",
|
||||
default=None,
|
||||
)
|
||||
|
||||
SMTP_PASSWORD: Optional[str] = Field(
|
||||
description="smtp server password",
|
||||
description="Password for SMTP authentication",
|
||||
default=None,
|
||||
)
|
||||
|
||||
SMTP_USE_TLS: bool = Field(
|
||||
description="whether to use TLS connection to smtp server",
|
||||
description="Enable TLS encryption for SMTP connections",
|
||||
default=False,
|
||||
)
|
||||
|
||||
SMTP_OPPORTUNISTIC_TLS: bool = Field(
|
||||
description="whether to use opportunistic TLS connection to smtp server",
|
||||
description="Enable opportunistic TLS for SMTP connections",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class RagEtlConfig(BaseSettings):
|
||||
"""
|
||||
RAG ETL Configurations.
|
||||
Configuration for RAG ETL processes
|
||||
"""
|
||||
|
||||
ETL_TYPE: str = Field(
|
||||
description="RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ",
|
||||
description="RAG ETL type ('dify' or 'Unstructured'), default to 'dify'",
|
||||
default="dify",
|
||||
)
|
||||
|
||||
KEYWORD_DATA_SOURCE_TYPE: str = Field(
|
||||
description="source type for keyword data, default to `database`, available values are `database` .",
|
||||
description="Data source type for keyword extraction"
|
||||
" ('database' or other supported types), default to 'database'",
|
||||
default="database",
|
||||
)
|
||||
|
||||
UNSTRUCTURED_API_URL: Optional[str] = Field(
|
||||
description="API URL for Unstructured",
|
||||
description="API URL for Unstructured.io service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
UNSTRUCTURED_API_KEY: Optional[str] = Field(
|
||||
description="API key for Unstructured",
|
||||
description="API key for Unstructured.io service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class DataSetConfig(BaseSettings):
|
||||
"""
|
||||
Dataset configs
|
||||
Configuration for dataset management
|
||||
"""
|
||||
|
||||
CLEAN_DAY_SETTING: PositiveInt = Field(
|
||||
description="interval in days for cleaning up dataset",
|
||||
description="Interval in days for dataset cleanup operations",
|
||||
default=30,
|
||||
)
|
||||
|
||||
DATASET_OPERATOR_ENABLED: bool = Field(
|
||||
description="whether to enable dataset operator",
|
||||
description="Enable or disable dataset operator functionality",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class WorkspaceConfig(BaseSettings):
|
||||
"""
|
||||
Workspace configs
|
||||
Configuration for workspace management
|
||||
"""
|
||||
|
||||
INVITE_EXPIRY_HOURS: PositiveInt = Field(
|
||||
description="workspaces invitation expiration in hours",
|
||||
description="Expiration time in hours for workspace invitation links",
|
||||
default=72,
|
||||
)
|
||||
|
||||
|
||||
class IndexingConfig(BaseSettings):
|
||||
"""
|
||||
Indexing configs.
|
||||
Configuration for indexing operations
|
||||
"""
|
||||
|
||||
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: PositiveInt = Field(
|
||||
description="max segmentation token length for indexing",
|
||||
description="Maximum token length for text segmentation during indexing",
|
||||
default=1000,
|
||||
)
|
||||
|
||||
|
||||
class ImageFormatConfig(BaseSettings):
|
||||
MULTIMODAL_SEND_IMAGE_FORMAT: str = Field(
|
||||
description="multi model send image format, support base64, url, default is base64",
|
||||
description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
|
||||
default="base64",
|
||||
)
|
||||
|
||||
|
||||
class CeleryBeatConfig(BaseSettings):
|
||||
CELERY_BEAT_SCHEDULER_TIME: int = Field(
|
||||
description="the time of the celery scheduler, default to 1 day",
|
||||
description="Interval in days for Celery Beat scheduler execution, default to 1 day",
|
||||
default=1,
|
||||
)
|
||||
|
||||
|
||||
class PositionConfig(BaseSettings):
|
||||
POSITION_PROVIDER_PINS: str = Field(
|
||||
description="The heads of model providers",
|
||||
description="Comma-separated list of pinned model providers",
|
||||
default="",
|
||||
)
|
||||
|
||||
POSITION_PROVIDER_INCLUDES: str = Field(
|
||||
description="The included model providers",
|
||||
description="Comma-separated list of included model providers",
|
||||
default="",
|
||||
)
|
||||
|
||||
POSITION_PROVIDER_EXCLUDES: str = Field(
|
||||
description="The excluded model providers",
|
||||
description="Comma-separated list of excluded model providers",
|
||||
default="",
|
||||
)
|
||||
|
||||
POSITION_TOOL_PINS: str = Field(
|
||||
description="The heads of tools",
|
||||
description="Comma-separated list of pinned tools",
|
||||
default="",
|
||||
)
|
||||
|
||||
POSITION_TOOL_INCLUDES: str = Field(
|
||||
description="The included tools",
|
||||
description="Comma-separated list of included tools",
|
||||
default="",
|
||||
)
|
||||
|
||||
POSITION_TOOL_EXCLUDES: str = Field(
|
||||
description="The excluded tools",
|
||||
description="Comma-separated list of excluded tools",
|
||||
default="",
|
||||
)
|
||||
|
||||
|
@ -6,31 +6,31 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class HostedOpenAiConfig(BaseSettings):
|
||||
"""
|
||||
Hosted OpenAI service config
|
||||
Configuration for hosted OpenAI service
|
||||
"""
|
||||
|
||||
HOSTED_OPENAI_API_KEY: Optional[str] = Field(
|
||||
description="",
|
||||
description="API key for hosted OpenAI service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_API_BASE: Optional[str] = Field(
|
||||
description="",
|
||||
description="Base URL for hosted OpenAI API",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_API_ORGANIZATION: Optional[str] = Field(
|
||||
description="",
|
||||
description="Organization ID for hosted OpenAI service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_TRIAL_ENABLED: bool = Field(
|
||||
description="",
|
||||
description="Enable trial access to hosted OpenAI service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_TRIAL_MODELS: str = Field(
|
||||
description="",
|
||||
description="Comma-separated list of available models for trial access",
|
||||
default="gpt-3.5-turbo,"
|
||||
"gpt-3.5-turbo-1106,"
|
||||
"gpt-3.5-turbo-instruct,"
|
||||
@ -42,17 +42,17 @@ class HostedOpenAiConfig(BaseSettings):
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
|
||||
description="",
|
||||
description="Quota limit for hosted OpenAI service usage",
|
||||
default=200,
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_PAID_ENABLED: bool = Field(
|
||||
description="",
|
||||
description="Enable paid access to hosted OpenAI service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_PAID_MODELS: str = Field(
|
||||
description="",
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="gpt-4,"
|
||||
"gpt-4-turbo-preview,"
|
||||
"gpt-4-turbo-2024-04-09,"
|
||||
@ -71,124 +71,122 @@ class HostedOpenAiConfig(BaseSettings):
|
||||
|
||||
class HostedAzureOpenAiConfig(BaseSettings):
|
||||
"""
|
||||
Hosted OpenAI service config
|
||||
Configuration for hosted Azure OpenAI service
|
||||
"""
|
||||
|
||||
HOSTED_AZURE_OPENAI_ENABLED: bool = Field(
|
||||
description="",
|
||||
description="Enable hosted Azure OpenAI service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field(
|
||||
description="",
|
||||
description="API key for hosted Azure OpenAI service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_AZURE_OPENAI_API_BASE: Optional[str] = Field(
|
||||
description="",
|
||||
description="Base URL for hosted Azure OpenAI API",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_AZURE_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
|
||||
description="",
|
||||
description="Quota limit for hosted Azure OpenAI service usage",
|
||||
default=200,
|
||||
)
|
||||
|
||||
|
||||
class HostedAnthropicConfig(BaseSettings):
|
||||
"""
|
||||
Hosted Azure OpenAI service config
|
||||
Configuration for hosted Anthropic service
|
||||
"""
|
||||
|
||||
HOSTED_ANTHROPIC_API_BASE: Optional[str] = Field(
|
||||
description="",
|
||||
description="Base URL for hosted Anthropic API",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_ANTHROPIC_API_KEY: Optional[str] = Field(
|
||||
description="",
|
||||
description="API key for hosted Anthropic service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_ANTHROPIC_TRIAL_ENABLED: bool = Field(
|
||||
description="",
|
||||
description="Enable trial access to hosted Anthropic service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
|
||||
description="",
|
||||
description="Quota limit for hosted Anthropic service usage",
|
||||
default=600000,
|
||||
)
|
||||
|
||||
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
|
||||
description="",
|
||||
description="Enable paid access to hosted Anthropic service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class HostedMinmaxConfig(BaseSettings):
|
||||
"""
|
||||
Hosted Minmax service config
|
||||
Configuration for hosted Minmax service
|
||||
"""
|
||||
|
||||
HOSTED_MINIMAX_ENABLED: bool = Field(
|
||||
description="",
|
||||
description="Enable hosted Minmax service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class HostedSparkConfig(BaseSettings):
|
||||
"""
|
||||
Hosted Spark service config
|
||||
Configuration for hosted Spark service
|
||||
"""
|
||||
|
||||
HOSTED_SPARK_ENABLED: bool = Field(
|
||||
description="",
|
||||
description="Enable hosted Spark service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class HostedZhipuAIConfig(BaseSettings):
|
||||
"""
|
||||
Hosted Minmax service config
|
||||
Configuration for hosted ZhipuAI service
|
||||
"""
|
||||
|
||||
HOSTED_ZHIPUAI_ENABLED: bool = Field(
|
||||
description="",
|
||||
description="Enable hosted ZhipuAI service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class HostedModerationConfig(BaseSettings):
|
||||
"""
|
||||
Hosted Moderation service config
|
||||
Configuration for hosted Moderation service
|
||||
"""
|
||||
|
||||
HOSTED_MODERATION_ENABLED: bool = Field(
|
||||
description="",
|
||||
description="Enable hosted Moderation service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_MODERATION_PROVIDERS: str = Field(
|
||||
description="",
|
||||
description="Comma-separated list of moderation providers",
|
||||
default="",
|
||||
)
|
||||
|
||||
|
||||
class HostedFetchAppTemplateConfig(BaseSettings):
|
||||
"""
|
||||
Hosted Moderation service config
|
||||
Configuration for fetching app templates
|
||||
"""
|
||||
|
||||
HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field(
|
||||
description="the mode for fetching app templates,"
|
||||
" default to remote,"
|
||||
" available values: remote, db, builtin",
|
||||
description="Mode for fetching app templates: remote, db, or builtin" " default to remote,",
|
||||
default="remote",
|
||||
)
|
||||
|
||||
HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN: str = Field(
|
||||
description="the domain for fetching remote app templates",
|
||||
description="Domain for fetching remote app templates",
|
||||
default="https://tmpl.dify.ai",
|
||||
)
|
||||
|
||||
|
@ -31,70 +31,71 @@ from configs.middleware.vdb.weaviate_config import WeaviateConfig
|
||||
|
||||
class StorageConfig(BaseSettings):
|
||||
STORAGE_TYPE: str = Field(
|
||||
description="storage type,"
|
||||
" default to `local`,"
|
||||
" available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.",
|
||||
description="Type of storage to use."
|
||||
" Options: 'local', 's3', 'azure-blob', 'aliyun-oss', 'google-storage'. Default is 'local'.",
|
||||
default="local",
|
||||
)
|
||||
|
||||
STORAGE_LOCAL_PATH: str = Field(
|
||||
description="local storage path",
|
||||
description="Path for local storage when STORAGE_TYPE is set to 'local'.",
|
||||
default="storage",
|
||||
)
|
||||
|
||||
|
||||
class VectorStoreConfig(BaseSettings):
|
||||
VECTOR_STORE: Optional[str] = Field(
|
||||
description="vector store type",
|
||||
description="Type of vector store to use for efficient similarity search."
|
||||
" Set to None if not using a vector store.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class KeywordStoreConfig(BaseSettings):
|
||||
KEYWORD_STORE: str = Field(
|
||||
description="keyword store type",
|
||||
description="Method for keyword extraction and storage."
|
||||
" Default is 'jieba', a Chinese text segmentation library.",
|
||||
default="jieba",
|
||||
)
|
||||
|
||||
|
||||
class DatabaseConfig:
|
||||
DB_HOST: str = Field(
|
||||
description="db host",
|
||||
description="Hostname or IP address of the database server.",
|
||||
default="localhost",
|
||||
)
|
||||
|
||||
DB_PORT: PositiveInt = Field(
|
||||
description="db port",
|
||||
description="Port number for database connection.",
|
||||
default=5432,
|
||||
)
|
||||
|
||||
DB_USERNAME: str = Field(
|
||||
description="db username",
|
||||
description="Username for database authentication.",
|
||||
default="postgres",
|
||||
)
|
||||
|
||||
DB_PASSWORD: str = Field(
|
||||
description="db password",
|
||||
description="Password for database authentication.",
|
||||
default="",
|
||||
)
|
||||
|
||||
DB_DATABASE: str = Field(
|
||||
description="db database",
|
||||
description="Name of the database to connect to.",
|
||||
default="dify",
|
||||
)
|
||||
|
||||
DB_CHARSET: str = Field(
|
||||
description="db charset",
|
||||
description="Character set for database connection.",
|
||||
default="",
|
||||
)
|
||||
|
||||
DB_EXTRAS: str = Field(
|
||||
description="db extras options. Example: keepalives_idle=60&keepalives=1",
|
||||
description="Additional database connection parameters. Example: 'keepalives_idle=60&keepalives=1'",
|
||||
default="",
|
||||
)
|
||||
|
||||
SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
|
||||
description="db uri scheme",
|
||||
description="Database URI scheme for SQLAlchemy connection.",
|
||||
default="postgresql",
|
||||
)
|
||||
|
||||
@ -112,27 +113,27 @@ class DatabaseConfig:
|
||||
)
|
||||
|
||||
SQLALCHEMY_POOL_SIZE: NonNegativeInt = Field(
|
||||
description="pool size of SqlAlchemy",
|
||||
description="Maximum number of database connections in the pool.",
|
||||
default=30,
|
||||
)
|
||||
|
||||
SQLALCHEMY_MAX_OVERFLOW: NonNegativeInt = Field(
|
||||
description="max overflows for SqlAlchemy",
|
||||
description="Maximum number of connections that can be created beyond the pool_size.",
|
||||
default=10,
|
||||
)
|
||||
|
||||
SQLALCHEMY_POOL_RECYCLE: NonNegativeInt = Field(
|
||||
description="SqlAlchemy pool recycle",
|
||||
description="Number of seconds after which a connection is automatically recycled.",
|
||||
default=3600,
|
||||
)
|
||||
|
||||
SQLALCHEMY_POOL_PRE_PING: bool = Field(
|
||||
description="whether to enable pool pre-ping in SqlAlchemy",
|
||||
description="If True, enables connection pool pre-ping feature to check connections.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
SQLALCHEMY_ECHO: bool | str = Field(
|
||||
description="whether to enable SqlAlchemy echo",
|
||||
description="If True, SQLAlchemy will log all SQL statements.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
@ -150,27 +151,27 @@ class DatabaseConfig:
|
||||
|
||||
class CeleryConfig(DatabaseConfig):
|
||||
CELERY_BACKEND: str = Field(
|
||||
description="Celery backend, available values are `database`, `redis`",
|
||||
description="Backend for Celery task results. Options: 'database', 'redis'.",
|
||||
default="database",
|
||||
)
|
||||
|
||||
CELERY_BROKER_URL: Optional[str] = Field(
|
||||
description="CELERY_BROKER_URL",
|
||||
description="URL of the message broker for Celery tasks.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CELERY_USE_SENTINEL: Optional[bool] = Field(
|
||||
description="Whether to use Redis Sentinel mode",
|
||||
description="Whether to use Redis Sentinel for high availability.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
CELERY_SENTINEL_MASTER_NAME: Optional[str] = Field(
|
||||
description="Redis Sentinel master name",
|
||||
description="Name of the Redis Sentinel master.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
|
||||
description="Redis Sentinel socket timeout",
|
||||
description="Timeout for Redis Sentinel socket operations in seconds.",
|
||||
default=0.1,
|
||||
)
|
||||
|
||||
|
26
api/configs/middleware/cache/redis_config.py
vendored
26
api/configs/middleware/cache/redis_config.py
vendored
@ -6,65 +6,65 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class RedisConfig(BaseSettings):
|
||||
"""
|
||||
Redis configs
|
||||
Configuration settings for Redis connection
|
||||
"""
|
||||
|
||||
REDIS_HOST: str = Field(
|
||||
description="Redis host",
|
||||
description="Hostname or IP address of the Redis server",
|
||||
default="localhost",
|
||||
)
|
||||
|
||||
REDIS_PORT: PositiveInt = Field(
|
||||
description="Redis port",
|
||||
description="Port number on which the Redis server is listening",
|
||||
default=6379,
|
||||
)
|
||||
|
||||
REDIS_USERNAME: Optional[str] = Field(
|
||||
description="Redis username",
|
||||
description="Username for Redis authentication (if required)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
REDIS_PASSWORD: Optional[str] = Field(
|
||||
description="Redis password",
|
||||
description="Password for Redis authentication (if required)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
REDIS_DB: NonNegativeInt = Field(
|
||||
description="Redis database id, default to 0",
|
||||
description="Redis database number to use (0-15)",
|
||||
default=0,
|
||||
)
|
||||
|
||||
REDIS_USE_SSL: bool = Field(
|
||||
description="whether to use SSL for Redis connection",
|
||||
description="Enable SSL/TLS for the Redis connection",
|
||||
default=False,
|
||||
)
|
||||
|
||||
REDIS_USE_SENTINEL: Optional[bool] = Field(
|
||||
description="Whether to use Redis Sentinel mode",
|
||||
description="Enable Redis Sentinel mode for high availability",
|
||||
default=False,
|
||||
)
|
||||
|
||||
REDIS_SENTINELS: Optional[str] = Field(
|
||||
description="Redis Sentinel nodes",
|
||||
description="Comma-separated list of Redis Sentinel nodes (host:port)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
REDIS_SENTINEL_SERVICE_NAME: Optional[str] = Field(
|
||||
description="Redis Sentinel service name",
|
||||
description="Name of the Redis Sentinel service to monitor",
|
||||
default=None,
|
||||
)
|
||||
|
||||
REDIS_SENTINEL_USERNAME: Optional[str] = Field(
|
||||
description="Redis Sentinel username",
|
||||
description="Username for Redis Sentinel authentication (if required)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
REDIS_SENTINEL_PASSWORD: Optional[str] = Field(
|
||||
description="Redis Sentinel password",
|
||||
description="Password for Redis Sentinel authentication (if required)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
REDIS_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
|
||||
description="Redis Sentinel socket timeout",
|
||||
description="Socket timeout in seconds for Redis Sentinel connections",
|
||||
default=0.1,
|
||||
)
|
||||
|
@ -6,40 +6,40 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class AliyunOSSStorageConfig(BaseSettings):
|
||||
"""
|
||||
Aliyun storage configs
|
||||
Configuration settings for Aliyun Object Storage Service (OSS)
|
||||
"""
|
||||
|
||||
ALIYUN_OSS_BUCKET_NAME: Optional[str] = Field(
|
||||
description="Aliyun OSS bucket name",
|
||||
description="Name of the Aliyun OSS bucket to store and retrieve objects",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_OSS_ACCESS_KEY: Optional[str] = Field(
|
||||
description="Aliyun OSS access key",
|
||||
description="Access key ID for authenticating with Aliyun OSS",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_OSS_SECRET_KEY: Optional[str] = Field(
|
||||
description="Aliyun OSS secret key",
|
||||
description="Secret access key for authenticating with Aliyun OSS",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_OSS_ENDPOINT: Optional[str] = Field(
|
||||
description="Aliyun OSS endpoint URL",
|
||||
description="URL of the Aliyun OSS endpoint for your chosen region",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_OSS_REGION: Optional[str] = Field(
|
||||
description="Aliyun OSS region",
|
||||
description="Aliyun OSS region where your bucket is located (e.g., 'oss-cn-hangzhou')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_OSS_AUTH_VERSION: Optional[str] = Field(
|
||||
description="Aliyun OSS authentication version",
|
||||
description="Version of the authentication protocol to use with Aliyun OSS (e.g., 'v4')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_OSS_PATH: Optional[str] = Field(
|
||||
description="Aliyun OSS path",
|
||||
description="Base path within the bucket to store objects (e.g., 'my-app-data/')",
|
||||
default=None,
|
||||
)
|
||||
|
@ -6,40 +6,40 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class S3StorageConfig(BaseSettings):
|
||||
"""
|
||||
S3 storage configs
|
||||
Configuration settings for S3-compatible object storage
|
||||
"""
|
||||
|
||||
S3_ENDPOINT: Optional[str] = Field(
|
||||
description="S3 storage endpoint",
|
||||
description="URL of the S3-compatible storage endpoint (e.g., 'https://s3.amazonaws.com')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
S3_REGION: Optional[str] = Field(
|
||||
description="S3 storage region",
|
||||
description="Region where the S3 bucket is located (e.g., 'us-east-1')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
S3_BUCKET_NAME: Optional[str] = Field(
|
||||
description="S3 storage bucket name",
|
||||
description="Name of the S3 bucket to store and retrieve objects",
|
||||
default=None,
|
||||
)
|
||||
|
||||
S3_ACCESS_KEY: Optional[str] = Field(
|
||||
description="S3 storage access key",
|
||||
description="Access key ID for authenticating with the S3 service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
S3_SECRET_KEY: Optional[str] = Field(
|
||||
description="S3 storage secret key",
|
||||
description="Secret access key for authenticating with the S3 service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
S3_ADDRESS_STYLE: str = Field(
|
||||
description="S3 storage address style",
|
||||
description="S3 addressing style: 'auto', 'path', or 'virtual'",
|
||||
default="auto",
|
||||
)
|
||||
|
||||
S3_USE_AWS_MANAGED_IAM: bool = Field(
|
||||
description="whether to use aws managed IAM for S3",
|
||||
description="Use AWS managed IAM roles for authentication instead of access/secret keys",
|
||||
default=False,
|
||||
)
|
||||
|
@ -6,25 +6,25 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class AzureBlobStorageConfig(BaseSettings):
|
||||
"""
|
||||
Azure Blob storage configs
|
||||
Configuration settings for Azure Blob Storage
|
||||
"""
|
||||
|
||||
AZURE_BLOB_ACCOUNT_NAME: Optional[str] = Field(
|
||||
description="Azure Blob account name",
|
||||
description="Name of the Azure Storage account (e.g., 'mystorageaccount')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
AZURE_BLOB_ACCOUNT_KEY: Optional[str] = Field(
|
||||
description="Azure Blob account key",
|
||||
description="Access key for authenticating with the Azure Storage account",
|
||||
default=None,
|
||||
)
|
||||
|
||||
AZURE_BLOB_CONTAINER_NAME: Optional[str] = Field(
|
||||
description="Azure Blob container name",
|
||||
description="Name of the Azure Blob container to store and retrieve objects",
|
||||
default=None,
|
||||
)
|
||||
|
||||
AZURE_BLOB_ACCOUNT_URL: Optional[str] = Field(
|
||||
description="Azure Blob account URL",
|
||||
description="URL of the Azure Blob storage endpoint (e.g., 'https://mystorageaccount.blob.core.windows.net')",
|
||||
default=None,
|
||||
)
|
||||
|
@ -6,15 +6,15 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class GoogleCloudStorageConfig(BaseSettings):
|
||||
"""
|
||||
Google Cloud storage configs
|
||||
Configuration settings for Google Cloud Storage
|
||||
"""
|
||||
|
||||
GOOGLE_STORAGE_BUCKET_NAME: Optional[str] = Field(
|
||||
description="Google Cloud storage bucket name",
|
||||
description="Name of the Google Cloud Storage bucket to store and retrieve objects (e.g., 'my-gcs-bucket')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: Optional[str] = Field(
|
||||
description="Google Cloud storage service account json base64",
|
||||
description="Base64-encoded JSON key file for Google Cloud service account authentication",
|
||||
default=None,
|
||||
)
|
||||
|
@ -5,25 +5,25 @@ from pydantic import BaseModel, Field
|
||||
|
||||
class HuaweiCloudOBSStorageConfig(BaseModel):
|
||||
"""
|
||||
Huawei Cloud OBS storage configs
|
||||
Configuration settings for Huawei Cloud Object Storage Service (OBS)
|
||||
"""
|
||||
|
||||
HUAWEI_OBS_BUCKET_NAME: Optional[str] = Field(
|
||||
description="Huawei Cloud OBS bucket name",
|
||||
description="Name of the Huawei Cloud OBS bucket to store and retrieve objects (e.g., 'my-obs-bucket')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HUAWEI_OBS_ACCESS_KEY: Optional[str] = Field(
|
||||
description="Huawei Cloud OBS Access key",
|
||||
description="Access Key ID for authenticating with Huawei Cloud OBS",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HUAWEI_OBS_SECRET_KEY: Optional[str] = Field(
|
||||
description="Huawei Cloud OBS Secret key",
|
||||
description="Secret Access Key for authenticating with Huawei Cloud OBS",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HUAWEI_OBS_SERVER: Optional[str] = Field(
|
||||
description="Huawei Cloud OBS server URL",
|
||||
description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')",
|
||||
default=None,
|
||||
)
|
||||
|
@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class OCIStorageConfig(BaseSettings):
|
||||
"""
|
||||
OCI storage configs
|
||||
Configuration settings for Oracle Cloud Infrastructure (OCI) Object Storage
|
||||
"""
|
||||
|
||||
OCI_ENDPOINT: Optional[str] = Field(
|
||||
description="OCI storage endpoint",
|
||||
description="URL of the OCI Object Storage endpoint (e.g., 'https://objectstorage.us-phoenix-1.oraclecloud.com')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
OCI_REGION: Optional[str] = Field(
|
||||
description="OCI storage region",
|
||||
description="OCI region where the bucket is located (e.g., 'us-phoenix-1')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
OCI_BUCKET_NAME: Optional[str] = Field(
|
||||
description="OCI storage bucket name",
|
||||
description="Name of the OCI Object Storage bucket to store and retrieve objects (e.g., 'my-oci-bucket')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
OCI_ACCESS_KEY: Optional[str] = Field(
|
||||
description="OCI storage access key",
|
||||
description="Access key (also known as API key) for authenticating with OCI Object Storage",
|
||||
default=None,
|
||||
)
|
||||
|
||||
OCI_SECRET_KEY: Optional[str] = Field(
|
||||
description="OCI storage secret key",
|
||||
description="Secret key associated with the access key for authenticating with OCI Object Storage",
|
||||
default=None,
|
||||
)
|
||||
|
@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class TencentCloudCOSStorageConfig(BaseSettings):
|
||||
"""
|
||||
Tencent Cloud COS storage configs
|
||||
Configuration settings for Tencent Cloud Object Storage (COS)
|
||||
"""
|
||||
|
||||
TENCENT_COS_BUCKET_NAME: Optional[str] = Field(
|
||||
description="Tencent Cloud COS bucket name",
|
||||
description="Name of the Tencent Cloud COS bucket to store and retrieve objects",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TENCENT_COS_REGION: Optional[str] = Field(
|
||||
description="Tencent Cloud COS region",
|
||||
description="Tencent Cloud region where the COS bucket is located (e.g., 'ap-guangzhou')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TENCENT_COS_SECRET_ID: Optional[str] = Field(
|
||||
description="Tencent Cloud COS secret id",
|
||||
description="SecretId for authenticating with Tencent Cloud COS (part of API credentials)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TENCENT_COS_SECRET_KEY: Optional[str] = Field(
|
||||
description="Tencent Cloud COS secret key",
|
||||
description="SecretKey for authenticating with Tencent Cloud COS (part of API credentials)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TENCENT_COS_SCHEME: Optional[str] = Field(
|
||||
description="Tencent Cloud COS scheme",
|
||||
description="Protocol scheme for COS requests: 'https' (recommended) or 'http'",
|
||||
default=None,
|
||||
)
|
||||
|
@ -5,30 +5,30 @@ from pydantic import BaseModel, Field
|
||||
|
||||
class VolcengineTOSStorageConfig(BaseModel):
|
||||
"""
|
||||
Volcengine tos storage configs
|
||||
Configuration settings for Volcengine Tinder Object Storage (TOS)
|
||||
"""
|
||||
|
||||
VOLCENGINE_TOS_BUCKET_NAME: Optional[str] = Field(
|
||||
description="Volcengine TOS Bucket Name",
|
||||
description="Name of the Volcengine TOS bucket to store and retrieve objects (e.g., 'my-tos-bucket')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VOLCENGINE_TOS_ACCESS_KEY: Optional[str] = Field(
|
||||
description="Volcengine TOS Access Key",
|
||||
description="Access Key ID for authenticating with Volcengine TOS",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VOLCENGINE_TOS_SECRET_KEY: Optional[str] = Field(
|
||||
description="Volcengine TOS Secret Key",
|
||||
description="Secret Access Key for authenticating with Volcengine TOS",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VOLCENGINE_TOS_ENDPOINT: Optional[str] = Field(
|
||||
description="Volcengine TOS Endpoint URL",
|
||||
description="URL of the Volcengine TOS endpoint (e.g., 'https://tos-cn-beijing.volces.com')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VOLCENGINE_TOS_REGION: Optional[str] = Field(
|
||||
description="Volcengine TOS Region",
|
||||
description="Volcengine region where the TOS bucket is located (e.g., 'cn-beijing')",
|
||||
default=None,
|
||||
)
|
||||
|
@ -5,33 +5,38 @@ from pydantic import BaseModel, Field
|
||||
|
||||
class AnalyticdbConfig(BaseModel):
|
||||
"""
|
||||
Configuration for connecting to AnalyticDB.
|
||||
Configuration for connecting to Alibaba Cloud AnalyticDB for PostgreSQL.
|
||||
Refer to the following documentation for details on obtaining credentials:
|
||||
https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
|
||||
"""
|
||||
|
||||
ANALYTICDB_KEY_ID: Optional[str] = Field(
|
||||
default=None, description="The Access Key ID provided by Alibaba Cloud for authentication."
|
||||
default=None, description="The Access Key ID provided by Alibaba Cloud for API authentication."
|
||||
)
|
||||
ANALYTICDB_KEY_SECRET: Optional[str] = Field(
|
||||
default=None, description="The Secret Access Key corresponding to the Access Key ID for secure access."
|
||||
default=None, description="The Secret Access Key corresponding to the Access Key ID for secure API access."
|
||||
)
|
||||
ANALYTICDB_REGION_ID: Optional[str] = Field(
|
||||
default=None, description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
|
||||
default=None,
|
||||
description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou', 'ap-southeast-1').",
|
||||
)
|
||||
ANALYTICDB_INSTANCE_ID: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456')..",
|
||||
description="The unique identifier of the AnalyticDB instance you want to connect to.",
|
||||
)
|
||||
ANALYTICDB_ACCOUNT: Optional[str] = Field(
|
||||
default=None, description="The account name used to log in to the AnalyticDB instance."
|
||||
default=None,
|
||||
description="The account name used to log in to the AnalyticDB instance"
|
||||
" (usually the initial account created with the instance).",
|
||||
)
|
||||
ANALYTICDB_PASSWORD: Optional[str] = Field(
|
||||
default=None, description="The password associated with the AnalyticDB account for authentication."
|
||||
default=None, description="The password associated with the AnalyticDB account for database authentication."
|
||||
)
|
||||
ANALYTICDB_NAMESPACE: Optional[str] = Field(
|
||||
default=None, description="The namespace within AnalyticDB for schema isolation."
|
||||
default=None, description="The namespace within AnalyticDB for schema isolation (if using namespace feature)."
|
||||
)
|
||||
ANALYTICDB_NAMESPACE_PASSWORD: Optional[str] = Field(
|
||||
default=None, description="The password for accessing the specified namespace within the AnalyticDB instance."
|
||||
default=None,
|
||||
description="The password for accessing the specified namespace within the AnalyticDB instance"
|
||||
" (if namespace feature is enabled).",
|
||||
)
|
||||
|
@ -6,35 +6,35 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class ChromaConfig(BaseSettings):
|
||||
"""
|
||||
Chroma configs
|
||||
Configuration settings for Chroma vector database
|
||||
"""
|
||||
|
||||
CHROMA_HOST: Optional[str] = Field(
|
||||
description="Chroma host",
|
||||
description="Hostname or IP address of the Chroma server (e.g., 'localhost' or '192.168.1.100')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CHROMA_PORT: PositiveInt = Field(
|
||||
description="Chroma port",
|
||||
description="Port number on which the Chroma server is listening (default is 8000)",
|
||||
default=8000,
|
||||
)
|
||||
|
||||
CHROMA_TENANT: Optional[str] = Field(
|
||||
description="Chroma database",
|
||||
description="Tenant identifier for multi-tenancy support in Chroma",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CHROMA_DATABASE: Optional[str] = Field(
|
||||
description="Chroma database",
|
||||
description="Name of the Chroma database to connect to",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CHROMA_AUTH_PROVIDER: Optional[str] = Field(
|
||||
description="Chroma authentication provider",
|
||||
description="Authentication provider for Chroma (e.g., 'basic', 'token', or a custom provider)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CHROMA_AUTH_CREDENTIALS: Optional[str] = Field(
|
||||
description="Chroma authentication credentials",
|
||||
description="Authentication credentials for Chroma (format depends on the auth provider)",
|
||||
default=None,
|
||||
)
|
||||
|
@ -6,25 +6,25 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class ElasticsearchConfig(BaseSettings):
|
||||
"""
|
||||
Elasticsearch configs
|
||||
Configuration settings for Elasticsearch
|
||||
"""
|
||||
|
||||
ELASTICSEARCH_HOST: Optional[str] = Field(
|
||||
description="Elasticsearch host",
|
||||
description="Hostname or IP address of the Elasticsearch server (e.g., 'localhost' or '192.168.1.100')",
|
||||
default="127.0.0.1",
|
||||
)
|
||||
|
||||
ELASTICSEARCH_PORT: PositiveInt = Field(
|
||||
description="Elasticsearch port",
|
||||
description="Port number on which the Elasticsearch server is listening (default is 9200)",
|
||||
default=9200,
|
||||
)
|
||||
|
||||
ELASTICSEARCH_USERNAME: Optional[str] = Field(
|
||||
description="Elasticsearch username",
|
||||
description="Username for authenticating with Elasticsearch (default is 'elastic')",
|
||||
default="elastic",
|
||||
)
|
||||
|
||||
ELASTICSEARCH_PASSWORD: Optional[str] = Field(
|
||||
description="Elasticsearch password",
|
||||
description="Password for authenticating with Elasticsearch (default is 'elastic')",
|
||||
default="elastic",
|
||||
)
|
||||
|
@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class MilvusConfig(BaseSettings):
|
||||
"""
|
||||
Milvus configs
|
||||
Configuration settings for Milvus vector database
|
||||
"""
|
||||
|
||||
MILVUS_URI: Optional[str] = Field(
|
||||
description="Milvus uri",
|
||||
description="URI for connecting to the Milvus server (e.g., 'http://localhost:19530' or 'https://milvus-instance.example.com:19530')",
|
||||
default="http://127.0.0.1:19530",
|
||||
)
|
||||
|
||||
MILVUS_TOKEN: Optional[str] = Field(
|
||||
description="Milvus token",
|
||||
description="Authentication token for Milvus, if token-based authentication is enabled",
|
||||
default=None,
|
||||
)
|
||||
|
||||
MILVUS_USER: Optional[str] = Field(
|
||||
description="Milvus user",
|
||||
description="Username for authenticating with Milvus, if username/password authentication is enabled",
|
||||
default=None,
|
||||
)
|
||||
|
||||
MILVUS_PASSWORD: Optional[str] = Field(
|
||||
description="Milvus password",
|
||||
description="Password for authenticating with Milvus, if username/password authentication is enabled",
|
||||
default=None,
|
||||
)
|
||||
|
||||
MILVUS_DATABASE: str = Field(
|
||||
description="Milvus database, default to `default`",
|
||||
description="Name of the Milvus database to connect to (default is 'default')",
|
||||
default="default",
|
||||
)
|
||||
|
@ -3,35 +3,35 @@ from pydantic import BaseModel, Field, PositiveInt
|
||||
|
||||
class MyScaleConfig(BaseModel):
|
||||
"""
|
||||
MyScale configs
|
||||
Configuration settings for MyScale vector database
|
||||
"""
|
||||
|
||||
MYSCALE_HOST: str = Field(
|
||||
description="MyScale host",
|
||||
description="Hostname or IP address of the MyScale server (e.g., 'localhost' or 'myscale.example.com')",
|
||||
default="localhost",
|
||||
)
|
||||
|
||||
MYSCALE_PORT: PositiveInt = Field(
|
||||
description="MyScale port",
|
||||
description="Port number on which the MyScale server is listening (default is 8123)",
|
||||
default=8123,
|
||||
)
|
||||
|
||||
MYSCALE_USER: str = Field(
|
||||
description="MyScale user",
|
||||
description="Username for authenticating with MyScale (default is 'default')",
|
||||
default="default",
|
||||
)
|
||||
|
||||
MYSCALE_PASSWORD: str = Field(
|
||||
description="MyScale password",
|
||||
description="Password for authenticating with MyScale (default is an empty string)",
|
||||
default="",
|
||||
)
|
||||
|
||||
MYSCALE_DATABASE: str = Field(
|
||||
description="MyScale database name",
|
||||
description="Name of the MyScale database to connect to (default is 'default')",
|
||||
default="default",
|
||||
)
|
||||
|
||||
MYSCALE_FTS_PARAMS: str = Field(
|
||||
description="MyScale fts index parameters",
|
||||
description="Additional parameters for MyScale Full Text Search index)",
|
||||
default="",
|
||||
)
|
||||
|
@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class OpenSearchConfig(BaseSettings):
|
||||
"""
|
||||
OpenSearch configs
|
||||
Configuration settings for OpenSearch
|
||||
"""
|
||||
|
||||
OPENSEARCH_HOST: Optional[str] = Field(
|
||||
description="OpenSearch host",
|
||||
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
OPENSEARCH_PORT: PositiveInt = Field(
|
||||
description="OpenSearch port",
|
||||
description="Port number on which the OpenSearch server is listening (default is 9200)",
|
||||
default=9200,
|
||||
)
|
||||
|
||||
OPENSEARCH_USER: Optional[str] = Field(
|
||||
description="OpenSearch user",
|
||||
description="Username for authenticating with OpenSearch",
|
||||
default=None,
|
||||
)
|
||||
|
||||
OPENSEARCH_PASSWORD: Optional[str] = Field(
|
||||
description="OpenSearch password",
|
||||
description="Password for authenticating with OpenSearch",
|
||||
default=None,
|
||||
)
|
||||
|
||||
OPENSEARCH_SECURE: bool = Field(
|
||||
description="whether to use SSL connection for OpenSearch",
|
||||
description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
|
||||
default=False,
|
||||
)
|
||||
|
@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class OracleConfig(BaseSettings):
|
||||
"""
|
||||
ORACLE configs
|
||||
Configuration settings for Oracle database
|
||||
"""
|
||||
|
||||
ORACLE_HOST: Optional[str] = Field(
|
||||
description="ORACLE host",
|
||||
description="Hostname or IP address of the Oracle database server (e.g., 'localhost' or 'oracle.example.com')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ORACLE_PORT: Optional[PositiveInt] = Field(
|
||||
description="ORACLE port",
|
||||
description="Port number on which the Oracle database server is listening (default is 1521)",
|
||||
default=1521,
|
||||
)
|
||||
|
||||
ORACLE_USER: Optional[str] = Field(
|
||||
description="ORACLE user",
|
||||
description="Username for authenticating with the Oracle database",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ORACLE_PASSWORD: Optional[str] = Field(
|
||||
description="ORACLE password",
|
||||
description="Password for authenticating with the Oracle database",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ORACLE_DATABASE: Optional[str] = Field(
|
||||
description="ORACLE database",
|
||||
description="Name of the Oracle database or service to connect to (e.g., 'ORCL' or 'pdborcl')",
|
||||
default=None,
|
||||
)
|
||||
|
@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class PGVectorConfig(BaseSettings):
|
||||
"""
|
||||
PGVector configs
|
||||
Configuration settings for PGVector (PostgreSQL with vector extension)
|
||||
"""
|
||||
|
||||
PGVECTOR_HOST: Optional[str] = Field(
|
||||
description="PGVector host",
|
||||
description="Hostname or IP address of the PostgreSQL server with PGVector extension (e.g., 'localhost')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
PGVECTOR_PORT: Optional[PositiveInt] = Field(
|
||||
description="PGVector port",
|
||||
description="Port number on which the PostgreSQL server is listening (default is 5433)",
|
||||
default=5433,
|
||||
)
|
||||
|
||||
PGVECTOR_USER: Optional[str] = Field(
|
||||
description="PGVector user",
|
||||
description="Username for authenticating with the PostgreSQL database",
|
||||
default=None,
|
||||
)
|
||||
|
||||
PGVECTOR_PASSWORD: Optional[str] = Field(
|
||||
description="PGVector password",
|
||||
description="Password for authenticating with the PostgreSQL database",
|
||||
default=None,
|
||||
)
|
||||
|
||||
PGVECTOR_DATABASE: Optional[str] = Field(
|
||||
description="PGVector database",
|
||||
description="Name of the PostgreSQL database to connect to",
|
||||
default=None,
|
||||
)
|
||||
|
@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class PGVectoRSConfig(BaseSettings):
|
||||
"""
|
||||
PGVectoRS configs
|
||||
Configuration settings for PGVecto.RS (Rust-based vector extension for PostgreSQL)
|
||||
"""
|
||||
|
||||
PGVECTO_RS_HOST: Optional[str] = Field(
|
||||
description="PGVectoRS host",
|
||||
description="Hostname or IP address of the PostgreSQL server with PGVecto.RS extension (e.g., 'localhost')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
PGVECTO_RS_PORT: Optional[PositiveInt] = Field(
|
||||
description="PGVectoRS port",
|
||||
description="Port number on which the PostgreSQL server with PGVecto.RS is listening (default is 5431)",
|
||||
default=5431,
|
||||
)
|
||||
|
||||
PGVECTO_RS_USER: Optional[str] = Field(
|
||||
description="PGVectoRS user",
|
||||
description="Username for authenticating with the PostgreSQL database using PGVecto.RS",
|
||||
default=None,
|
||||
)
|
||||
|
||||
PGVECTO_RS_PASSWORD: Optional[str] = Field(
|
||||
description="PGVectoRS password",
|
||||
description="Password for authenticating with the PostgreSQL database using PGVecto.RS",
|
||||
default=None,
|
||||
)
|
||||
|
||||
PGVECTO_RS_DATABASE: Optional[str] = Field(
|
||||
description="PGVectoRS database",
|
||||
description="Name of the PostgreSQL database with PGVecto.RS extension to connect to",
|
||||
default=None,
|
||||
)
|
||||
|
@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class QdrantConfig(BaseSettings):
|
||||
"""
|
||||
Qdrant configs
|
||||
Configuration settings for Qdrant vector database
|
||||
"""
|
||||
|
||||
QDRANT_URL: Optional[str] = Field(
|
||||
description="Qdrant url",
|
||||
description="URL of the Qdrant server (e.g., 'http://localhost:6333' or 'https://qdrant.example.com')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
QDRANT_API_KEY: Optional[str] = Field(
|
||||
description="Qdrant api key",
|
||||
description="API key for authenticating with the Qdrant server",
|
||||
default=None,
|
||||
)
|
||||
|
||||
QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field(
|
||||
description="Qdrant client timeout in seconds",
|
||||
description="Timeout in seconds for Qdrant client operations (default is 20 seconds)",
|
||||
default=20,
|
||||
)
|
||||
|
||||
QDRANT_GRPC_ENABLED: bool = Field(
|
||||
description="whether enable grpc support for Qdrant connection",
|
||||
description="Whether to enable gRPC support for Qdrant connection (True for gRPC, False for HTTP)",
|
||||
default=False,
|
||||
)
|
||||
|
||||
QDRANT_GRPC_PORT: PositiveInt = Field(
|
||||
description="Qdrant grpc port",
|
||||
description="Port number for gRPC connection to Qdrant server (default is 6334)",
|
||||
default=6334,
|
||||
)
|
||||
|
@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class RelytConfig(BaseSettings):
|
||||
"""
|
||||
Relyt configs
|
||||
Configuration settings for Relyt database
|
||||
"""
|
||||
|
||||
RELYT_HOST: Optional[str] = Field(
|
||||
description="Relyt host",
|
||||
description="Hostname or IP address of the Relyt server (e.g., 'localhost' or 'relyt.example.com')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
RELYT_PORT: PositiveInt = Field(
|
||||
description="Relyt port",
|
||||
description="Port number on which the Relyt server is listening (default is 9200)",
|
||||
default=9200,
|
||||
)
|
||||
|
||||
RELYT_USER: Optional[str] = Field(
|
||||
description="Relyt user",
|
||||
description="Username for authenticating with the Relyt database",
|
||||
default=None,
|
||||
)
|
||||
|
||||
RELYT_PASSWORD: Optional[str] = Field(
|
||||
description="Relyt password",
|
||||
description="Password for authenticating with the Relyt database",
|
||||
default=None,
|
||||
)
|
||||
|
||||
RELYT_DATABASE: Optional[str] = Field(
|
||||
description="Relyt database",
|
||||
description="Name of the Relyt database to connect to (default is 'default')",
|
||||
default="default",
|
||||
)
|
||||
|
@ -6,45 +6,45 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class TencentVectorDBConfig(BaseSettings):
|
||||
"""
|
||||
Tencent Vector configs
|
||||
Configuration settings for Tencent Vector Database
|
||||
"""
|
||||
|
||||
TENCENT_VECTOR_DB_URL: Optional[str] = Field(
|
||||
description="Tencent Vector URL",
|
||||
description="URL of the Tencent Vector Database service (e.g., 'https://vectordb.tencentcloudapi.com')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TENCENT_VECTOR_DB_API_KEY: Optional[str] = Field(
|
||||
description="Tencent Vector API key",
|
||||
description="API key for authenticating with the Tencent Vector Database service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TENCENT_VECTOR_DB_TIMEOUT: PositiveInt = Field(
|
||||
description="Tencent Vector timeout in seconds",
|
||||
description="Timeout in seconds for Tencent Vector Database operations (default is 30 seconds)",
|
||||
default=30,
|
||||
)
|
||||
|
||||
TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field(
|
||||
description="Tencent Vector username",
|
||||
description="Username for authenticating with the Tencent Vector Database (if required)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TENCENT_VECTOR_DB_PASSWORD: Optional[str] = Field(
|
||||
description="Tencent Vector password",
|
||||
description="Password for authenticating with the Tencent Vector Database (if required)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TENCENT_VECTOR_DB_SHARD: PositiveInt = Field(
|
||||
description="Tencent Vector sharding number",
|
||||
description="Number of shards for the Tencent Vector Database (default is 1)",
|
||||
default=1,
|
||||
)
|
||||
|
||||
TENCENT_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
|
||||
description="Tencent Vector replicas",
|
||||
description="Number of replicas for the Tencent Vector Database (default is 2)",
|
||||
default=2,
|
||||
)
|
||||
|
||||
TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field(
|
||||
description="Tencent Vector Database",
|
||||
description="Name of the specific Tencent Vector Database to connect to",
|
||||
default=None,
|
||||
)
|
||||
|
@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class TiDBVectorConfig(BaseSettings):
|
||||
"""
|
||||
TiDB Vector configs
|
||||
Configuration settings for TiDB Vector database
|
||||
"""
|
||||
|
||||
TIDB_VECTOR_HOST: Optional[str] = Field(
|
||||
description="TiDB Vector host",
|
||||
description="Hostname or IP address of the TiDB Vector server (e.g., 'localhost' or 'tidb.example.com')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TIDB_VECTOR_PORT: Optional[PositiveInt] = Field(
|
||||
description="TiDB Vector port",
|
||||
description="Port number on which the TiDB Vector server is listening (default is 4000)",
|
||||
default=4000,
|
||||
)
|
||||
|
||||
TIDB_VECTOR_USER: Optional[str] = Field(
|
||||
description="TiDB Vector user",
|
||||
description="Username for authenticating with the TiDB Vector database",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TIDB_VECTOR_PASSWORD: Optional[str] = Field(
|
||||
description="TiDB Vector password",
|
||||
description="Password for authenticating with the TiDB Vector database",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TIDB_VECTOR_DATABASE: Optional[str] = Field(
|
||||
description="TiDB Vector database",
|
||||
description="Name of the TiDB Vector database to connect to",
|
||||
default=None,
|
||||
)
|
||||
|
@ -6,25 +6,25 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
class WeaviateConfig(BaseSettings):
|
||||
"""
|
||||
Weaviate configs
|
||||
Configuration settings for Weaviate vector database
|
||||
"""
|
||||
|
||||
WEAVIATE_ENDPOINT: Optional[str] = Field(
|
||||
description="Weaviate endpoint URL",
|
||||
description="URL of the Weaviate server (e.g., 'http://localhost:8080' or 'https://weaviate.example.com')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
WEAVIATE_API_KEY: Optional[str] = Field(
|
||||
description="Weaviate API key",
|
||||
description="API key for authenticating with the Weaviate server",
|
||||
default=None,
|
||||
)
|
||||
|
||||
WEAVIATE_GRPC_ENABLED: bool = Field(
|
||||
description="whether to enable gRPC for Weaviate connection",
|
||||
description="Whether to enable gRPC for Weaviate connection (True for gRPC, False for HTTP)",
|
||||
default=True,
|
||||
)
|
||||
|
||||
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
|
||||
description="Weaviate batch size",
|
||||
description="Number of objects to be processed in a single batch operation (default is 100)",
|
||||
default=100,
|
||||
)
|
||||
|
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="0.8.2",
|
||||
default="0.8.3",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
@ -1 +1,2 @@
|
||||
HIDDEN_VALUE = "[__HIDDEN__]"
|
||||
UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
||||
|
@ -109,6 +109,7 @@ class ChatMessageApi(Resource):
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
args = parser.parse_args()
|
||||
|
@ -105,8 +105,6 @@ class ChatMessageListApi(Resource):
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
|
||||
history_messages = list(reversed(history_messages))
|
||||
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
|
||||
|
||||
|
||||
|
@ -166,6 +166,8 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
parser.add_argument("query", type=str, required=True, location="json", default="")
|
||||
parser.add_argument("files", type=list, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
@ -100,6 +100,7 @@ class ChatApi(InstalledAppResource):
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -51,7 +51,7 @@ class MessageListApi(InstalledAppResource):
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
@ -54,6 +54,7 @@ class MessageListApi(Resource):
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
"inputs": fields.Raw,
|
||||
"query": fields.String,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
|
@ -96,6 +96,7 @@ class ChatApi(WebApiResource):
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
@ -57,6 +57,7 @@ class MessageListApi(WebApiResource):
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
"inputs": fields.Raw,
|
||||
"query": fields.String,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
@ -89,7 +90,7 @@ class MessageListApi(WebApiResource):
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
@ -32,6 +32,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolRuntimeVariablePool,
|
||||
@ -441,10 +442,12 @@ class BaseAgentRunner(AppRunner):
|
||||
.filter(
|
||||
Message.conversation_id == self.message.conversation_id,
|
||||
)
|
||||
.order_by(Message.created_at.asc())
|
||||
.order_by(Message.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
messages = list(reversed(extract_thread_messages(messages)))
|
||||
|
||||
for message in messages:
|
||||
if message.id == self.message.id:
|
||||
continue
|
||||
|
@ -121,6 +121,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
parent_message_id=args.get("parent_message_id"),
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
|
@ -127,6 +127,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
parent_message_id=args.get("parent_message_id"),
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
|
@ -128,6 +128,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
parent_message_id=args.get("parent_message_id"),
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
|
@ -218,6 +218,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency="USD",
|
||||
|
@ -122,6 +122,7 @@ class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
||||
"""
|
||||
|
||||
conversation_id: Optional[str] = None
|
||||
parent_message_id: Optional[str] = None
|
||||
|
||||
|
||||
class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
||||
@ -138,6 +139,7 @@ class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
||||
"""
|
||||
|
||||
conversation_id: Optional[str] = None
|
||||
parent_message_id: Optional[str] = None
|
||||
|
||||
|
||||
class AdvancedChatAppGenerateEntity(AppGenerateEntity):
|
||||
@ -149,6 +151,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
|
||||
app_config: WorkflowUIBasedAppConfig
|
||||
|
||||
conversation_id: Optional[str] = None
|
||||
parent_message_id: Optional[str] = None
|
||||
query: str
|
||||
|
||||
class SingleIterationRunEntity(BaseModel):
|
||||
|
@ -47,6 +47,8 @@ class LLMGenerator:
|
||||
)
|
||||
answer = response.message.content
|
||||
cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
|
||||
if cleaned_answer is None:
|
||||
return ""
|
||||
result_dict = json.loads(cleaned_answer)
|
||||
answer = result_dict["Your Output"]
|
||||
name = answer.strip()
|
||||
|
@ -11,6 +11,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from extensions.ext_database import db
|
||||
from models.model import AppMode, Conversation, Message, MessageFile
|
||||
from models.workflow import WorkflowRun
|
||||
@ -33,8 +34,17 @@ class TokenBufferMemory:
|
||||
|
||||
# fetch limited messages, and return reversed
|
||||
query = (
|
||||
db.session.query(Message.id, Message.query, Message.answer, Message.created_at, Message.workflow_run_id)
|
||||
.filter(Message.conversation_id == self.conversation.id, Message.answer != "")
|
||||
db.session.query(
|
||||
Message.id,
|
||||
Message.query,
|
||||
Message.answer,
|
||||
Message.created_at,
|
||||
Message.workflow_run_id,
|
||||
Message.parent_message_id,
|
||||
)
|
||||
.filter(
|
||||
Message.conversation_id == self.conversation.id,
|
||||
)
|
||||
.order_by(Message.created_at.desc())
|
||||
)
|
||||
|
||||
@ -45,7 +55,12 @@ class TokenBufferMemory:
|
||||
|
||||
messages = query.limit(message_limit).all()
|
||||
|
||||
messages = list(reversed(messages))
|
||||
# instead of all messages from the conversation, we only need to extract messages
|
||||
# that belong to the thread of last message
|
||||
thread_messages = extract_thread_messages(messages)
|
||||
thread_messages.pop(0)
|
||||
messages = list(reversed(thread_messages))
|
||||
|
||||
message_file_parser = MessageFileParser(tenant_id=app_record.tenant_id, app_id=app_record.id)
|
||||
prompt_messages = []
|
||||
for message in messages:
|
||||
|
@ -62,7 +62,7 @@ pricing: # 价格信息
|
||||
|
||||
建议将所有模型配置都准备完毕后再开始模型代码的实现。
|
||||
|
||||
同样,也可以参考 `model_providers` 目录下其他供应商对应模型类型目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#AIModel)。
|
||||
同样,也可以参考 `model_providers` 目录下其他供应商对应模型类型目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#aimodelentity)。
|
||||
|
||||
### 实现模型调用代码
|
||||
|
||||
|
@ -37,3 +37,4 @@
|
||||
- siliconflow
|
||||
- perfxcloud
|
||||
- zhinao
|
||||
- fireworks
|
||||
|
@ -0,0 +1,3 @@
|
||||
<svg width="130" role="graphics-symbol" aria-label="Fireworks AI Home" viewBox="0 0 835 130" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M112.65 0L91.33 51.09L69.99 0H56.3L79.69 55.85C81.63 60.51 86.18 63.52 91.25 63.52C96.32 63.52 100.86 60.51 102.81 55.87L126.34 0H112.65ZM121.76 77.84L160.76 38.41L155.44 25.86L112.84 69.01C109.28 72.62 108.26 77.94 110.23 82.6C112.19 87.22 116.72 90.21 121.77 90.21L121.79 90.23L182.68 90.08L177.36 77.53L121.77 77.84H121.76ZM21.92 38.38L27.24 25.83L69.84 68.98C73.4 72.58 74.43 77.92 72.45 82.57C70.49 87.2 65.94 90.18 60.91 90.18L0.02 90.04L0 90.06L5.32 77.51L60.91 77.82L21.92 38.38Z" fill="#6720FF"></path>
|
||||
<path d="M231.32 85.2198L231.33 85.2298H241.8V49.1698H275.62V39.8198H241.8V16.3598H279V7.00977H231.32V85.2198Z" class="fill-black dark:fill-white"></path><path d="M299.68 28.73H289.86V85.22H299.68V28.73Z" class="fill-black dark:fill-white"></path><path d="M324.58 36.2198H324.59C324.37 36.7598 324.16 37.0898 323.5 37.0898C322.95 37.0898 322.74 36.8798 322.74 36.3398V28.7298H312.92V85.2198H322.72V53.1598C322.72 42.3098 327.75 38.0698 337.24 38.0698H345.1V28.5098H338.77C331.03 28.5098 327.1 30.7898 324.58 36.2198Z" class="fill-black dark:fill-white"></path><path d="M377.76 78.3996C367.23 78.3996 359.37 72.4196 358.71 59.7196H404.6V54.2796C404.6 38.5296 395 27.1196 377.53 27.1196C360.06 27.1196 348.93 38.5296 348.93 56.9896C348.93 75.4496 359.73 86.8596 377.74 86.8596C395.75 86.8596 403.15 75.8996 404.81 67.3196H394.57C392.98 73.7396 388.29 78.3996 377.76 78.3996ZM377.53 35.5696C387.91 35.5696 394.33 41.1196 394.78 51.5496H358.98C360.61 40.8896 368.14 35.5696 377.53 35.5696Z" class="fill-black dark:fill-white"></path><path d="M474.29 74.68C474.05 75.66 473.75 75.99 472.97 75.99C472.19 75.99 471.86 75.66 471.65 74.68L460.73 28.73H443.81L432.89 74.68C432.65 75.66 432.35 75.99 431.57 75.99C430.79 75.99 430.46 75.66 430.25 74.68L419.33 28.73H409.73V30.91H409.79L423.11 85.22H439.97L451.22 37.85C451.43 37.08 451.64 36.87 452.3 36.87C452.84 36.87 453.17 37.1 453.38 37.85L464.63 85.22H481.49L494.81 30.91V28.73H485.21L474.29 74.68Z" class="fill-black dark:fill-white"></path><path d="M529.05 27.1099C512.56 27.1099 499.47 37.4199 499.47 56.9799C499.47 76.5399 512.55 86.8499 529.05 86.8499C545.55 86.8499 558.64 76.5399 558.64 56.9799C558.64 37.4199 545.54 27.1099 529.05 27.1099ZM529.07 78.1599C517.61 78.1599 509.42 70.5699 509.42 56.9799C509.42 43.3899 517.61 35.7999 529.07 35.7999C540.53 35.7999 548.72 43.4099 548.72 56.9799C548.72 70.5499 540.53 78.1599 529.07 78.1599Z" class="fill-black dark:fill-white"></path><path d="M580.68 36.2198C580.47 36.7598 580.26 37.0898 579.6 37.0898C579.05 37.0898 578.841 36.8798 578.841 36.3398V28.7298H569.021V85.2098H578.82V53.1598C578.82 42.3098 583.851 38.0698 593.341 38.0698H601.201V28.5098H594.87C587.13 28.5098 583.2 30.7898 580.68 36.2198Z" class="fill-black dark:fill-white"></path><path d="M618.591 55.0198V7.00977H608.771V85.2698H618.591V67.2298L629.24 58.1498L650.42 85.2498H661.16V83.0698L636.49 51.9398L661.16 30.9098V28.7298H648.54L618.591 55.0198Z" class="fill-black dark:fill-white"></path><path d="M695.19 52.8899L687.12 51.3699C679.38 49.8999 675.99 48.2799 675.99 43.5999C675.99 38.9199 679.82 35.4499 688.98 35.4499C698.14 35.4499 703.38 38.9399 704.14 46.6499H714.14C713.03 32.8799 702.34 27.1299 688.94 27.1299C675.54 27.1299 666.13 32.8899 666.13 43.7399C666.13 54.5899 673.83 58.3499 684.91 60.4099L692.98 61.9299C700.84 63.3999 704.77 65.0899 704.77 69.9699C704.77 74.8499 700.83 78.4899 691.35 78.4899C681.87 78.4899 675.58 74.5799 674.82 67.0799H664.83C665.76 80.5499 676.73 86.8499 691.36 86.8499C705.99 86.8499 714.61 80.6099 714.61 69.4099C714.61 58.2099 705.55 54.8399 695.19 52.8899Z" class="fill-black dark:fill-white"></path><path d="M834.64 7.00977H823.63V85.2698H834.64V7.00977Z" class="fill-black dark:fill-white"></path><path d="M770.23 7.77L739.71 83.8398V85.2698H750.61L758.34 64.8398H795.08L802.81 85.2698H814.04V83.8598L783.3 7.00977H770.23ZM761.97 55.3798L775.09 21.0098H775.08C775.3 20.4198 775.87 20.0298 776.5 20.0298H777.04C777.67 20.0298 778.24 20.4198 778.46 21.0098L791.48 55.3798H761.97Z" class="fill-black dark:fill-white"></path><path d="M299.68 7.00977H289.86V18.5298H299.68V7.00977Z" class="fill-black dark:fill-white"></path></svg>
|
After Width: | Height: | Size: 4.2 KiB |
@ -0,0 +1,5 @@
|
||||
<svg width="638" height="315" viewBox="0 0 638 315" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M318.563 221.755C300.863 221.755 284.979 211.247 278.206 194.978L196.549 0H244.342L318.842 178.361L393.273 0H441.066L358.92 195.048C352.112 211.247 336.263 221.755 318.563 221.755Z" fill="#6720FF"/>
|
||||
<path d="M425.111 314.933C407.481 314.933 391.667 304.494 384.824 288.366C377.947 272.097 381.507 253.524 393.936 240.921L542.657 90.2803L561.229 134.094L425.076 271.748L619.147 270.666L637.72 314.479L425.146 315.003L425.076 314.933H425.111Z" fill="#6720FF"/>
|
||||
<path d="M0 314.408L18.5727 270.595L212.643 271.677L76.525 133.988L95.0977 90.1748L243.819 240.816C256.247 253.384 259.843 272.026 252.93 288.26C246.088 304.424 230.203 314.827 212.643 314.827L0.0698221 314.339L0 314.408Z" fill="#6720FF"/>
|
||||
</svg>
|
After Width: | Height: | Size: 815 B |
52
api/core/model_runtime/model_providers/fireworks/_common.py
Normal file
52
api/core/model_runtime/model_providers/fireworks/_common.py
Normal file
@ -0,0 +1,52 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
import openai
|
||||
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
|
||||
|
||||
class _CommonFireworks:
|
||||
def _to_credential_kwargs(self, credentials: Mapping) -> dict:
|
||||
"""
|
||||
Transform credentials to kwargs for model instance
|
||||
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
credentials_kwargs = {
|
||||
"api_key": credentials["fireworks_api_key"],
|
||||
"base_url": "https://api.fireworks.ai/inference/v1",
|
||||
"max_retries": 1,
|
||||
}
|
||||
|
||||
return credentials_kwargs
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
|
||||
InvokeServerUnavailableError: [openai.InternalServerError],
|
||||
InvokeRateLimitError: [openai.RateLimitError],
|
||||
InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError],
|
||||
InvokeBadRequestError: [
|
||||
openai.BadRequestError,
|
||||
openai.NotFoundError,
|
||||
openai.UnprocessableEntityError,
|
||||
openai.APIError,
|
||||
],
|
||||
}
|
@ -0,0 +1,27 @@
|
||||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FireworksProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
model_instance.validate_credentials(
|
||||
model="accounts/fireworks/models/llama-v3p1-8b-instruct", credentials=credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||
raise ex
|
@ -0,0 +1,29 @@
|
||||
provider: fireworks
|
||||
label:
|
||||
zh_Hans: Fireworks AI
|
||||
en_US: Fireworks AI
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
en_US: icon_l_en.svg
|
||||
background: "#FCFDFF"
|
||||
help:
|
||||
title:
|
||||
en_US: Get your API Key from Fireworks AI
|
||||
zh_Hans: 从 Fireworks AI 获取 API Key
|
||||
url:
|
||||
en_US: https://fireworks.ai/account/api-keys
|
||||
supported_model_types:
|
||||
- llm
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: fireworks_api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
@ -0,0 +1,16 @@
|
||||
- llama-v3p1-405b-instruct
|
||||
- llama-v3p1-70b-instruct
|
||||
- llama-v3p1-8b-instruct
|
||||
- llama-v3-70b-instruct
|
||||
- mixtral-8x22b-instruct
|
||||
- mixtral-8x7b-instruct
|
||||
- firefunction-v2
|
||||
- firefunction-v1
|
||||
- gemma2-9b-it
|
||||
- llama-v3-70b-instruct-hf
|
||||
- llama-v3-8b-instruct
|
||||
- llama-v3-8b-instruct-hf
|
||||
- mixtral-8x7b-instruct-hf
|
||||
- mythomax-l2-13b
|
||||
- phi-3-vision-128k-instruct
|
||||
- yi-large
|
@ -0,0 +1,46 @@
|
||||
model: accounts/fireworks/models/firefunction-v1
|
||||
label:
|
||||
zh_Hans: Firefunction V1
|
||||
en_US: Firefunction V1
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32768
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.5'
|
||||
output: '0.5'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
@ -0,0 +1,46 @@
|
||||
model: accounts/fireworks/models/firefunction-v2
|
||||
label:
|
||||
zh_Hans: Firefunction V2
|
||||
en_US: Firefunction V2
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.9'
|
||||
output: '0.9'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
@ -0,0 +1,45 @@
|
||||
model: accounts/fireworks/models/gemma2-9b-it
|
||||
label:
|
||||
zh_Hans: Gemma2 9B Instruct
|
||||
en_US: Gemma2 9B Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.2'
|
||||
output: '0.2'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
@ -0,0 +1,46 @@
|
||||
model: accounts/fireworks/models/llama-v3-70b-instruct-hf
|
||||
label:
|
||||
zh_Hans: Llama3 70B Instruct(HF version)
|
||||
en_US: Llama3 70B Instruct(HF version)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.9'
|
||||
output: '0.9'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
@ -0,0 +1,46 @@
|
||||
model: accounts/fireworks/models/llama-v3-70b-instruct
|
||||
label:
|
||||
zh_Hans: Llama3 70B Instruct
|
||||
en_US: Llama3 70B Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.9'
|
||||
output: '0.9'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
@ -0,0 +1,46 @@
|
||||
model: accounts/fireworks/models/llama-v3-8b-instruct-hf
|
||||
label:
|
||||
zh_Hans: Llama3 8B Instruct(HF version)
|
||||
en_US: Llama3 8B Instruct(HF version)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.2'
|
||||
output: '0.2'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
@ -0,0 +1,46 @@
|
||||
model: accounts/fireworks/models/llama-v3-8b-instruct
|
||||
label:
|
||||
zh_Hans: Llama3 8B Instruct
|
||||
en_US: Llama3 8B Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.2'
|
||||
output: '0.2'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
@ -0,0 +1,46 @@
|
||||
model: accounts/fireworks/models/llama-v3p1-405b-instruct
|
||||
label:
|
||||
zh_Hans: Llama3.1 405B Instruct
|
||||
en_US: Llama3.1 405B Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '3'
|
||||
output: '3'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
@ -0,0 +1,46 @@
|
||||
model: accounts/fireworks/models/llama-v3p1-70b-instruct
|
||||
label:
|
||||
zh_Hans: Llama3.1 70B Instruct
|
||||
en_US: Llama3.1 70B Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.2'
|
||||
output: '0.2'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
@ -0,0 +1,46 @@
|
||||
model: accounts/fireworks/models/llama-v3p1-8b-instruct
|
||||
label:
|
||||
zh_Hans: Llama3.1 8B Instruct
|
||||
en_US: Llama3.1 8B Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.2'
|
||||
output: '0.2'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
610
api/core/model_runtime/model_providers/fireworks/llm/llm.py
Normal file
610
api/core/model_runtime/model_providers/fireworks/llm/llm.py
Normal file
@ -0,0 +1,610 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from openai import OpenAI, Stream
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
|
||||
from openai.types.chat.chat_completion_message import FunctionCall
|
||||
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.fireworks._common import _CommonFireworks
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FIREWORKS_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
|
||||
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
|
||||
if you are not sure about the structure.
|
||||
|
||||
<instructions>
|
||||
{{instructions}}
|
||||
</instructions>
|
||||
""" # noqa: E501
|
||||
|
||||
|
||||
class FireworksLargeLanguageModel(_CommonFireworks, LargeLanguageModel):
|
||||
"""
|
||||
Model class for Fireworks large language model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
|
||||
return self._chat_generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
)
|
||||
|
||||
def _code_block_mode_wrapper(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Code block mode wrapper for invoking large language model
|
||||
"""
|
||||
if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
|
||||
stop = stop or []
|
||||
self._transform_chat_json_prompts(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
response_format=model_parameters["response_format"],
|
||||
)
|
||||
model_parameters.pop("response_format")
|
||||
|
||||
return self._invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
)
|
||||
|
||||
def _transform_chat_json_prompts(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
response_format: str = "JSON",
|
||||
) -> None:
|
||||
"""
|
||||
Transform json prompts
|
||||
"""
|
||||
if stop is None:
|
||||
stop = []
|
||||
if "```\n" not in stop:
|
||||
stop.append("```\n")
|
||||
if "\n```" not in stop:
|
||||
stop.append("\n```")
|
||||
|
||||
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
|
||||
prompt_messages[0] = SystemPromptMessage(
|
||||
content=FIREWORKS_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace(
|
||||
"{{block}}", response_format
|
||||
)
|
||||
)
|
||||
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n"))
|
||||
else:
|
||||
prompt_messages.insert(
|
||||
0,
|
||||
SystemPromptMessage(
|
||||
content=FIREWORKS_BLOCK_MODE_PROMPT.replace(
|
||||
"{{instructions}}", f"Please output a valid {response_format} object."
|
||||
).replace("{{block}}", response_format)
|
||||
),
|
||||
)
|
||||
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
|
||||
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
return self._num_tokens_from_messages(model, prompt_messages, tools)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = OpenAI(**credentials_kwargs)
|
||||
|
||||
client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=10, stream=False
|
||||
)
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(str(e))
|
||||
|
||||
def _chat_generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = OpenAI(**credentials_kwargs)
|
||||
|
||||
extra_model_kwargs = {}
|
||||
|
||||
if tools:
|
||||
extra_model_kwargs["functions"] = [
|
||||
{"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools
|
||||
]
|
||||
|
||||
if stop:
|
||||
extra_model_kwargs["stop"] = stop
|
||||
|
||||
if user:
|
||||
extra_model_kwargs["user"] = user
|
||||
|
||||
# chat model
|
||||
response = client.chat.completions.create(
|
||||
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
|
||||
model=model,
|
||||
stream=stream,
|
||||
**model_parameters,
|
||||
**extra_model_kwargs,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
|
||||
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
|
||||
|
||||
def _handle_chat_generate_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: ChatCompletion,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm chat response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return: llm response
|
||||
"""
|
||||
assistant_message = response.choices[0].message
|
||||
# assistant_message_tool_calls = assistant_message.tool_calls
|
||||
assistant_message_function_call = assistant_message.function_call
|
||||
|
||||
# extract tool calls from response
|
||||
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
||||
function_call = self._extract_response_function_call(assistant_message_function_call)
|
||||
tool_calls = [function_call] if function_call else []
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls)
|
||||
|
||||
# calculate num tokens
|
||||
if response.usage:
|
||||
# transform usage
|
||||
prompt_tokens = response.usage.prompt_tokens
|
||||
completion_tokens = response.usage.completion_tokens
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools)
|
||||
completion_tokens = self._num_tokens_from_messages(model, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
# transform response
|
||||
response = LLMResult(
|
||||
model=response.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage,
|
||||
system_fingerprint=response.system_fingerprint,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _handle_chat_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: Stream[ChatCompletionChunk],
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm chat stream response
|
||||
|
||||
:param model: model name
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return: llm response chunk generator
|
||||
"""
|
||||
full_assistant_content = ""
|
||||
delta_assistant_message_function_call_storage: Optional[ChoiceDeltaFunctionCall] = None
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
final_tool_calls = []
|
||||
final_chunk = LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
),
|
||||
)
|
||||
|
||||
for chunk in response:
|
||||
if len(chunk.choices) == 0:
|
||||
if chunk.usage:
|
||||
# calculate num tokens
|
||||
prompt_tokens = chunk.usage.prompt_tokens
|
||||
completion_tokens = chunk.usage.completion_tokens
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0]
|
||||
has_finish_reason = delta.finish_reason is not None
|
||||
|
||||
if (
|
||||
not has_finish_reason
|
||||
and (delta.delta.content is None or delta.delta.content == "")
|
||||
and delta.delta.function_call is None
|
||||
):
|
||||
continue
|
||||
|
||||
# assistant_message_tool_calls = delta.delta.tool_calls
|
||||
assistant_message_function_call = delta.delta.function_call
|
||||
|
||||
# extract tool calls from response
|
||||
if delta_assistant_message_function_call_storage is not None:
|
||||
# handle process of stream function call
|
||||
if assistant_message_function_call:
|
||||
# message has not ended ever
|
||||
delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
|
||||
continue
|
||||
else:
|
||||
# message has ended
|
||||
assistant_message_function_call = delta_assistant_message_function_call_storage
|
||||
delta_assistant_message_function_call_storage = None
|
||||
else:
|
||||
if assistant_message_function_call:
|
||||
# start of stream function call
|
||||
delta_assistant_message_function_call_storage = assistant_message_function_call
|
||||
if delta_assistant_message_function_call_storage.arguments is None:
|
||||
delta_assistant_message_function_call_storage.arguments = ""
|
||||
if not has_finish_reason:
|
||||
continue
|
||||
|
||||
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
||||
function_call = self._extract_response_function_call(assistant_message_function_call)
|
||||
tool_calls = [function_call] if function_call else []
|
||||
if tool_calls:
|
||||
final_tool_calls.extend(tool_calls)
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)
|
||||
|
||||
full_assistant_content += delta.delta.content or ""
|
||||
|
||||
if has_finish_reason:
|
||||
final_chunk = LLMResultChunk(
|
||||
model=chunk.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=chunk.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=delta.index,
|
||||
message=assistant_prompt_message,
|
||||
finish_reason=delta.finish_reason,
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield LLMResultChunk(
|
||||
model=chunk.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=chunk.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=delta.index,
|
||||
message=assistant_prompt_message,
|
||||
),
|
||||
)
|
||||
|
||||
if not prompt_tokens:
|
||||
prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools)
|
||||
|
||||
if not completion_tokens:
|
||||
full_assistant_prompt_message = AssistantPromptMessage(
|
||||
content=full_assistant_content, tool_calls=final_tool_calls
|
||||
)
|
||||
completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
final_chunk.delta.usage = usage
|
||||
|
||||
yield final_chunk
|
||||
|
||||
def _extract_response_tool_calls(
|
||||
self, response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]
|
||||
) -> list[AssistantPromptMessage.ToolCall]:
|
||||
"""
|
||||
Extract tool calls from response
|
||||
|
||||
:param response_tool_calls: response tool calls
|
||||
:return: list of tool calls
|
||||
"""
|
||||
tool_calls = []
|
||||
if response_tool_calls:
|
||||
for response_tool_call in response_tool_calls:
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
|
||||
)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_tool_call.id, type=response_tool_call.type, function=function
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return tool_calls
|
||||
|
||||
def _extract_response_function_call(
|
||||
self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall
|
||||
) -> AssistantPromptMessage.ToolCall:
|
||||
"""
|
||||
Extract function call from response
|
||||
|
||||
:param response_function_call: response function call
|
||||
:return: tool call
|
||||
"""
|
||||
tool_call = None
|
||||
if response_function_call:
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_function_call.name, arguments=response_function_call.arguments
|
||||
)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_function_call.name, type="function", function=function
|
||||
)
|
||||
|
||||
return tool_call
|
||||
|
||||
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
|
||||
"""
|
||||
Convert PromptMessage to dict for Fireworks API
|
||||
"""
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
sub_messages = []
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(TextPromptMessageContent, message_content)
|
||||
sub_message_dict = {"type": "text", "text": message_content.data}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
|
||||
message_dict = {"role": "user", "content": sub_messages}
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls:
|
||||
# message_dict["tool_calls"] = [tool_call.dict() for tool_call in
|
||||
# message.tool_calls]
|
||||
function_call = message.tool_calls[0]
|
||||
message_dict["function_call"] = {
|
||||
"name": function_call.function.name,
|
||||
"arguments": function_call.function.arguments,
|
||||
}
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
# message_dict = {
|
||||
# "role": "tool",
|
||||
# "content": message.content,
|
||||
# "tool_call_id": message.tool_call_id
|
||||
# }
|
||||
message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
if message.name:
|
||||
message_dict["name"] = message.name
|
||||
|
||||
return message_dict
|
||||
|
||||
def _num_tokens_from_messages(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
credentials: dict = None,
|
||||
) -> int:
|
||||
"""
|
||||
Approximate num tokens with GPT2 tokenizer.
|
||||
"""
|
||||
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
|
||||
num_tokens = 0
|
||||
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
# Cast str(value) in case the message value is not a string
|
||||
# This occurs with function messages
|
||||
# TODO: The current token calculation method for the image type is not implemented,
|
||||
# which need to download the image and then get the resolution for calculation,
|
||||
# and will increase the request delay
|
||||
if isinstance(value, list):
|
||||
text = ""
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item["type"] == "text":
|
||||
text += item["text"]
|
||||
|
||||
value = text
|
||||
|
||||
if key == "tool_calls":
|
||||
for tool_call in value:
|
||||
for t_key, t_value in tool_call.items():
|
||||
num_tokens += self._get_num_tokens_by_gpt2(t_key)
|
||||
if t_key == "function":
|
||||
for f_key, f_value in t_value.items():
|
||||
num_tokens += self._get_num_tokens_by_gpt2(f_key)
|
||||
num_tokens += self._get_num_tokens_by_gpt2(f_value)
|
||||
else:
|
||||
num_tokens += self._get_num_tokens_by_gpt2(t_key)
|
||||
num_tokens += self._get_num_tokens_by_gpt2(t_value)
|
||||
else:
|
||||
num_tokens += self._get_num_tokens_by_gpt2(str(value))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
|
||||
"""
|
||||
Calculate num tokens for tool calling with tiktoken package.
|
||||
|
||||
:param tools: tools for tool calling
|
||||
:return: number of tokens
|
||||
"""
|
||||
num_tokens = 0
|
||||
for tool in tools:
|
||||
num_tokens += self._get_num_tokens_by_gpt2("type")
|
||||
num_tokens += self._get_num_tokens_by_gpt2("function")
|
||||
num_tokens += self._get_num_tokens_by_gpt2("function")
|
||||
|
||||
# calculate num tokens for function object
|
||||
num_tokens += self._get_num_tokens_by_gpt2("name")
|
||||
num_tokens += self._get_num_tokens_by_gpt2(tool.name)
|
||||
num_tokens += self._get_num_tokens_by_gpt2("description")
|
||||
num_tokens += self._get_num_tokens_by_gpt2(tool.description)
|
||||
parameters = tool.parameters
|
||||
num_tokens += self._get_num_tokens_by_gpt2("parameters")
|
||||
if "title" in parameters:
|
||||
num_tokens += self._get_num_tokens_by_gpt2("title")
|
||||
num_tokens += self._get_num_tokens_by_gpt2(parameters.get("title"))
|
||||
num_tokens += self._get_num_tokens_by_gpt2("type")
|
||||
num_tokens += self._get_num_tokens_by_gpt2(parameters.get("type"))
|
||||
if "properties" in parameters:
|
||||
num_tokens += self._get_num_tokens_by_gpt2("properties")
|
||||
for key, value in parameters.get("properties").items():
|
||||
num_tokens += self._get_num_tokens_by_gpt2(key)
|
||||
for field_key, field_value in value.items():
|
||||
num_tokens += self._get_num_tokens_by_gpt2(field_key)
|
||||
if field_key == "enum":
|
||||
for enum_field in field_value:
|
||||
num_tokens += 3
|
||||
num_tokens += self._get_num_tokens_by_gpt2(enum_field)
|
||||
else:
|
||||
num_tokens += self._get_num_tokens_by_gpt2(field_key)
|
||||
num_tokens += self._get_num_tokens_by_gpt2(str(field_value))
|
||||
if "required" in parameters:
|
||||
num_tokens += self._get_num_tokens_by_gpt2("required")
|
||||
for required_field in parameters["required"]:
|
||||
num_tokens += 3
|
||||
num_tokens += self._get_num_tokens_by_gpt2(required_field)
|
||||
|
||||
return num_tokens
|
@ -0,0 +1,46 @@
|
||||
model: accounts/fireworks/models/mixtral-8x22b-instruct
|
||||
label:
|
||||
zh_Hans: Mixtral MoE 8x22B Instruct
|
||||
en_US: Mixtral MoE 8x22B Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 65536
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '1.2'
|
||||
output: '1.2'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
@ -0,0 +1,46 @@
|
||||
model: accounts/fireworks/models/mixtral-8x7b-instruct-hf
|
||||
label:
|
||||
zh_Hans: Mixtral MoE 8x7B Instruct(HF version)
|
||||
en_US: Mixtral MoE 8x7B Instruct(HF version)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32768
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.5'
|
||||
output: '0.5'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
@ -0,0 +1,46 @@
|
||||
model: accounts/fireworks/models/mixtral-8x7b-instruct
|
||||
label:
|
||||
zh_Hans: Mixtral MoE 8x7B Instruct
|
||||
en_US: Mixtral MoE 8x7B Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32768
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.5'
|
||||
output: '0.5'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
@ -0,0 +1,46 @@
|
||||
model: accounts/fireworks/models/mythomax-l2-13b
|
||||
label:
|
||||
zh_Hans: MythoMax L2 13b
|
||||
en_US: MythoMax L2 13b
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 4096
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.2'
|
||||
output: '0.2'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
@ -0,0 +1,46 @@
|
||||
model: accounts/fireworks/models/phi-3-vision-128k-instruct
|
||||
label:
|
||||
zh_Hans: Phi3.5 Vision Instruct
|
||||
en_US: Phi3.5 Vision Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.2'
|
||||
output: '0.2'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
@ -0,0 +1,45 @@
|
||||
model: accounts/yi-01-ai/models/yi-large
|
||||
label:
|
||||
zh_Hans: Yi-Large
|
||||
en_US: Yi-Large
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32768
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
- name: context_length_exceeded_behavior
|
||||
default: None
|
||||
label:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
help:
|
||||
zh_Hans: 上下文长度超出行为
|
||||
en_US: Context Length Exceeded Behavior
|
||||
type: string
|
||||
options:
|
||||
- None
|
||||
- truncate
|
||||
- error
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '3'
|
||||
output: '3'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
@ -0,0 +1,39 @@
|
||||
model: gemini-1.5-flash-8b-exp-0827
|
||||
label:
|
||||
en_US: Gemini 1.5 Flash 8B 0827
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens_to_sample
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
@ -0,0 +1,39 @@
|
||||
model: gemini-1.5-flash-exp-0827
|
||||
label:
|
||||
en_US: Gemini 1.5 Flash 0827
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens_to_sample
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
@ -0,0 +1,39 @@
|
||||
model: gemini-1.5-pro-exp-0801
|
||||
label:
|
||||
en_US: Gemini 1.5 Pro 0801
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens_to_sample
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
@ -0,0 +1,39 @@
|
||||
model: gemini-1.5-pro-exp-0827
|
||||
label:
|
||||
en_US: Gemini 1.5 Pro 0827
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens_to_sample
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
@ -9,7 +9,7 @@ features:
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
context_size: 2097152
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
|
@ -3,3 +3,4 @@
|
||||
- hunyuan-standard-256k
|
||||
- hunyuan-pro
|
||||
- hunyuan-turbo
|
||||
- hunyuan-vision
|
||||
|
@ -0,0 +1,39 @@
|
||||
model: hunyuan-vision
|
||||
label:
|
||||
zh_Hans: hunyuan-vision
|
||||
en_US: hunyuan-vision
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- multi-tool-call
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 8000
|
||||
- name: enable_enhance
|
||||
label:
|
||||
zh_Hans: 功能增强
|
||||
en_US: Enable Enhancement
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
|
||||
en_US: Allow the model to perform external search to enhance the generation results.
|
||||
required: false
|
||||
default: true
|
||||
pricing:
|
||||
input: '0.018'
|
||||
output: '0.018'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from tencentcloud.common import credential
|
||||
from tencentcloud.common.exception import TencentCloudSDKException
|
||||
@ -11,9 +12,12 @@ from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
@ -143,6 +147,25 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
|
||||
tool_execute_result = {"result": message.content}
|
||||
content = json.dumps(tool_execute_result, ensure_ascii=False)
|
||||
dict_list.append({"Role": message.role.value, "Content": content, "ToolCallId": message.tool_call_id})
|
||||
elif isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
dict_list.append({"Role": message.role.value, "Content": message.content})
|
||||
else:
|
||||
sub_messages = []
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(TextPromptMessageContent, message_content)
|
||||
sub_message_dict = {"Type": "text", "Text": message_content.data}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"Type": "image_url",
|
||||
"ImageUrl": {"Url": message_content.data},
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
dict_list.append({"Role": message.role.value, "Contents": sub_messages})
|
||||
else:
|
||||
dict_list.append({"Role": message.role.value, "Content": message.content})
|
||||
return dict_list
|
||||
|
@ -1,3 +1,8 @@
|
||||
- pixtral-12b-2409
|
||||
- codestral-latest
|
||||
- mistral-embed
|
||||
- open-mistral-nemo
|
||||
- open-codestral-mamba
|
||||
- open-mistral-7b
|
||||
- open-mixtral-8x7b
|
||||
- open-mixtral-8x22b
|
||||
|
@ -0,0 +1,51 @@
|
||||
model: codestral-latest
|
||||
label:
|
||||
zh_Hans: codestral-latest
|
||||
en_US: codestral-latest
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.7
|
||||
min: 0
|
||||
max: 1
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 1
|
||||
min: 0
|
||||
max: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 4096
|
||||
- name: safe_prompt
|
||||
default: false
|
||||
type: boolean
|
||||
help:
|
||||
en_US: Whether to inject a safety prompt before all conversations.
|
||||
zh_Hans: 是否开启提示词审查
|
||||
label:
|
||||
en_US: SafePrompt
|
||||
zh_Hans: 提示词审查
|
||||
- name: random_seed
|
||||
type: int
|
||||
help:
|
||||
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
|
||||
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
|
||||
label:
|
||||
en_US: RandomSeed
|
||||
zh_Hans: 随机数种子
|
||||
default: 0
|
||||
min: 0
|
||||
max: 2147483647
|
||||
pricing:
|
||||
input: '0.008'
|
||||
output: '0.024'
|
||||
unit: '0.001'
|
||||
currency: USD
|
@ -0,0 +1,51 @@
|
||||
model: mistral-embed
|
||||
label:
|
||||
zh_Hans: mistral-embed
|
||||
en_US: mistral-embed
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.7
|
||||
min: 0
|
||||
max: 1
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 1
|
||||
min: 0
|
||||
max: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 1024
|
||||
- name: safe_prompt
|
||||
default: false
|
||||
type: boolean
|
||||
help:
|
||||
en_US: Whether to inject a safety prompt before all conversations.
|
||||
zh_Hans: 是否开启提示词审查
|
||||
label:
|
||||
en_US: SafePrompt
|
||||
zh_Hans: 提示词审查
|
||||
- name: random_seed
|
||||
type: int
|
||||
help:
|
||||
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
|
||||
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
|
||||
label:
|
||||
en_US: RandomSeed
|
||||
zh_Hans: 随机数种子
|
||||
default: 0
|
||||
min: 0
|
||||
max: 2147483647
|
||||
pricing:
|
||||
input: '0.008'
|
||||
output: '0.024'
|
||||
unit: '0.001'
|
||||
currency: USD
|
@ -0,0 +1,51 @@
|
||||
model: open-codestral-mamba
|
||||
label:
|
||||
zh_Hans: open-codestral-mamba
|
||||
en_US: open-codestral-mamba
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 256000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.7
|
||||
min: 0
|
||||
max: 1
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 1
|
||||
min: 0
|
||||
max: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 16384
|
||||
- name: safe_prompt
|
||||
default: false
|
||||
type: boolean
|
||||
help:
|
||||
en_US: Whether to inject a safety prompt before all conversations.
|
||||
zh_Hans: 是否开启提示词审查
|
||||
label:
|
||||
en_US: SafePrompt
|
||||
zh_Hans: 提示词审查
|
||||
- name: random_seed
|
||||
type: int
|
||||
help:
|
||||
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
|
||||
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
|
||||
label:
|
||||
en_US: RandomSeed
|
||||
zh_Hans: 随机数种子
|
||||
default: 0
|
||||
min: 0
|
||||
max: 2147483647
|
||||
pricing:
|
||||
input: '0.008'
|
||||
output: '0.024'
|
||||
unit: '0.001'
|
||||
currency: USD
|
@ -0,0 +1,51 @@
|
||||
model: open-mistral-nemo
|
||||
label:
|
||||
zh_Hans: open-mistral-nemo
|
||||
en_US: open-mistral-nemo
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.7
|
||||
min: 0
|
||||
max: 1
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 1
|
||||
min: 0
|
||||
max: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: safe_prompt
|
||||
default: false
|
||||
type: boolean
|
||||
help:
|
||||
en_US: Whether to inject a safety prompt before all conversations.
|
||||
zh_Hans: 是否开启提示词审查
|
||||
label:
|
||||
en_US: SafePrompt
|
||||
zh_Hans: 提示词审查
|
||||
- name: random_seed
|
||||
type: int
|
||||
help:
|
||||
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
|
||||
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
|
||||
label:
|
||||
en_US: RandomSeed
|
||||
zh_Hans: 随机数种子
|
||||
default: 0
|
||||
min: 0
|
||||
max: 2147483647
|
||||
pricing:
|
||||
input: '0.008'
|
||||
output: '0.024'
|
||||
unit: '0.001'
|
||||
currency: USD
|
@ -0,0 +1,51 @@
|
||||
model: pixtral-12b-2409
|
||||
label:
|
||||
zh_Hans: pixtral-12b-2409
|
||||
en_US: pixtral-12b-2409
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.7
|
||||
min: 0
|
||||
max: 1
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 1
|
||||
min: 0
|
||||
max: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: safe_prompt
|
||||
default: false
|
||||
type: boolean
|
||||
help:
|
||||
en_US: Whether to inject a safety prompt before all conversations.
|
||||
zh_Hans: 是否开启提示词审查
|
||||
label:
|
||||
en_US: SafePrompt
|
||||
zh_Hans: 提示词审查
|
||||
- name: random_seed
|
||||
type: int
|
||||
help:
|
||||
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
|
||||
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
|
||||
label:
|
||||
en_US: RandomSeed
|
||||
zh_Hans: 随机数种子
|
||||
default: 0
|
||||
min: 0
|
||||
max: 2147483647
|
||||
pricing:
|
||||
input: '0.008'
|
||||
output: '0.024'
|
||||
unit: '0.001'
|
||||
currency: USD
|
@ -472,12 +472,13 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.TEMPERATURE.value,
|
||||
use_template=DefaultParameterName.TEMPERATURE.value,
|
||||
label=I18nObject(en_US="Temperature"),
|
||||
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
||||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(
|
||||
en_US="The temperature of the model. "
|
||||
"Increasing the temperature will make the model answer "
|
||||
"more creatively. (Default: 0.8)"
|
||||
"more creatively. (Default: 0.8)",
|
||||
zh_Hans="模型的温度。增加温度将使模型的回答更具创造性。(默认值:0.8)",
|
||||
),
|
||||
default=0.1,
|
||||
min=0,
|
||||
@ -486,12 +487,13 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.TOP_P.value,
|
||||
use_template=DefaultParameterName.TOP_P.value,
|
||||
label=I18nObject(en_US="Top P"),
|
||||
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
|
||||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(
|
||||
en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
|
||||
"more diverse text, while a lower value (e.g., 0.5) will generate more "
|
||||
"focused and conservative text. (Default: 0.9)"
|
||||
"focused and conservative text. (Default: 0.9)",
|
||||
zh_Hans="与top-k一起工作。较高的值(例如,0.95)会导致生成更多样化的文本,而较低的值(例如,0.5)会生成更专注和保守的文本。(默认值:0.9)",
|
||||
),
|
||||
default=0.9,
|
||||
min=0,
|
||||
@ -499,12 +501,13 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_k",
|
||||
label=I18nObject(en_US="Top K"),
|
||||
label=I18nObject(en_US="Top K", zh_Hans="Top K"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(
|
||||
en_US="Reduces the probability of generating nonsense. "
|
||||
"A higher value (e.g. 100) will give more diverse answers, "
|
||||
"while a lower value (e.g. 10) will be more conservative. (Default: 40)"
|
||||
"while a lower value (e.g. 10) will be more conservative. (Default: 40)",
|
||||
zh_Hans="减少生成无意义内容的可能性。较高的值(例如100)将提供更多样化的答案,而较低的值(例如10)将更为保守。(默认值:40)",
|
||||
),
|
||||
min=1,
|
||||
max=100,
|
||||
@ -516,7 +519,8 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
help=I18nObject(
|
||||
en_US="Sets how strongly to penalize repetitions. "
|
||||
"A higher value (e.g., 1.5) will penalize repetitions more strongly, "
|
||||
"while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"
|
||||
"while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)",
|
||||
zh_Hans="设置对重复内容的惩罚强度。一个较高的值(例如,1.5)会更强地惩罚重复内容,而一个较低的值(例如,0.9)则会相对宽容。(默认值:1.1)",
|
||||
),
|
||||
min=-2,
|
||||
max=2,
|
||||
@ -524,11 +528,12 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
ParameterRule(
|
||||
name="num_predict",
|
||||
use_template="max_tokens",
|
||||
label=I18nObject(en_US="Num Predict"),
|
||||
label=I18nObject(en_US="Num Predict", zh_Hans="最大令牌数预测"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(
|
||||
en_US="Maximum number of tokens to predict when generating text. "
|
||||
"(Default: 128, -1 = infinite generation, -2 = fill context)"
|
||||
"(Default: 128, -1 = infinite generation, -2 = fill context)",
|
||||
zh_Hans="生成文本时预测的最大令牌数。(默认值:128,-1 = 无限生成,-2 = 填充上下文)",
|
||||
),
|
||||
default=(512 if int(credentials.get("max_tokens", 4096)) >= 768 else 128),
|
||||
min=-2,
|
||||
@ -536,121 +541,137 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
),
|
||||
ParameterRule(
|
||||
name="mirostat",
|
||||
label=I18nObject(en_US="Mirostat sampling"),
|
||||
label=I18nObject(en_US="Mirostat sampling", zh_Hans="Mirostat 采样"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(
|
||||
en_US="Enable Mirostat sampling for controlling perplexity. "
|
||||
"(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"
|
||||
"(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)",
|
||||
zh_Hans="启用 Mirostat 采样以控制困惑度。"
|
||||
"(默认值:0,0 = 禁用,1 = Mirostat,2 = Mirostat 2.0)",
|
||||
),
|
||||
min=0,
|
||||
max=2,
|
||||
),
|
||||
ParameterRule(
|
||||
name="mirostat_eta",
|
||||
label=I18nObject(en_US="Mirostat Eta"),
|
||||
label=I18nObject(en_US="Mirostat Eta", zh_Hans="学习率"),
|
||||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(
|
||||
en_US="Influences how quickly the algorithm responds to feedback from "
|
||||
"the generated text. A lower learning rate will result in slower adjustments, "
|
||||
"while a higher learning rate will make the algorithm more responsive. "
|
||||
"(Default: 0.1)"
|
||||
"(Default: 0.1)",
|
||||
zh_Hans="影响算法对生成文本反馈响应的速度。较低的学习率会导致调整速度变慢,而较高的学习率会使得算法更加灵敏。(默认值:0.1)",
|
||||
),
|
||||
precision=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name="mirostat_tau",
|
||||
label=I18nObject(en_US="Mirostat Tau"),
|
||||
label=I18nObject(en_US="Mirostat Tau", zh_Hans="文本连贯度"),
|
||||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(
|
||||
en_US="Controls the balance between coherence and diversity of the output. "
|
||||
"A lower value will result in more focused and coherent text. (Default: 5.0)"
|
||||
"A lower value will result in more focused and coherent text. (Default: 5.0)",
|
||||
zh_Hans="控制输出的连贯性和多样性之间的平衡。较低的值会导致更专注和连贯的文本。(默认值:5.0)",
|
||||
),
|
||||
precision=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name="num_ctx",
|
||||
label=I18nObject(en_US="Size of context window"),
|
||||
label=I18nObject(en_US="Size of context window", zh_Hans="上下文窗口大小"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(
|
||||
en_US="Sets the size of the context window used to generate the next token. (Default: 2048)"
|
||||
en_US="Sets the size of the context window used to generate the next token. (Default: 2048)",
|
||||
zh_Hans="设置用于生成下一个标记的上下文窗口大小。(默认值:2048)",
|
||||
),
|
||||
default=2048,
|
||||
min=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name="num_gpu",
|
||||
label=I18nObject(en_US="GPU Layers"),
|
||||
label=I18nObject(en_US="GPU Layers", zh_Hans="GPU 层数"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(
|
||||
en_US="The number of layers to offload to the GPU(s). "
|
||||
"On macOS it defaults to 1 to enable metal support, 0 to disable."
|
||||
"As long as a model fits into one gpu it stays in one. "
|
||||
"It does not set the number of GPU(s). "
|
||||
"It does not set the number of GPU(s). ",
|
||||
zh_Hans="加载到 GPU 的层数。在 macOS 上,默认为 1 以启用 Metal 支持,设置为 0 则禁用。"
|
||||
"只要模型适合一个 GPU,它就保留在其中。它不设置 GPU 的数量。",
|
||||
),
|
||||
min=-1,
|
||||
default=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name="num_thread",
|
||||
label=I18nObject(en_US="Num Thread"),
|
||||
label=I18nObject(en_US="Num Thread", zh_Hans="线程数"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(
|
||||
en_US="Sets the number of threads to use during computation. "
|
||||
"By default, Ollama will detect this for optimal performance. "
|
||||
"It is recommended to set this value to the number of physical CPU cores "
|
||||
"your system has (as opposed to the logical number of cores)."
|
||||
"your system has (as opposed to the logical number of cores).",
|
||||
zh_Hans="设置计算过程中使用的线程数。默认情况下,Ollama会检测以获得最佳性能。建议将此值设置为系统拥有的物理CPU核心数(而不是逻辑核心数)。",
|
||||
),
|
||||
min=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name="repeat_last_n",
|
||||
label=I18nObject(en_US="Repeat last N"),
|
||||
label=I18nObject(en_US="Repeat last N", zh_Hans="回溯内容"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(
|
||||
en_US="Sets how far back for the model to look back to prevent repetition. "
|
||||
"(Default: 64, 0 = disabled, -1 = num_ctx)"
|
||||
"(Default: 64, 0 = disabled, -1 = num_ctx)",
|
||||
zh_Hans="设置模型回溯多远的内容以防止重复。(默认值:64,0 = 禁用,-1 = num_ctx)",
|
||||
),
|
||||
min=-1,
|
||||
),
|
||||
ParameterRule(
|
||||
name="tfs_z",
|
||||
label=I18nObject(en_US="TFS Z"),
|
||||
label=I18nObject(en_US="TFS Z", zh_Hans="减少标记影响"),
|
||||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(
|
||||
en_US="Tail free sampling is used to reduce the impact of less probable tokens "
|
||||
"from the output. A higher value (e.g., 2.0) will reduce the impact more, "
|
||||
"while a value of 1.0 disables this setting. (default: 1)"
|
||||
"while a value of 1.0 disables this setting. (default: 1)",
|
||||
zh_Hans="用于减少输出中不太可能的标记的影响。较高的值(例如,2.0)会更多地减少这种影响,而1.0的值则会禁用此设置。(默认值:1)",
|
||||
),
|
||||
precision=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name="seed",
|
||||
label=I18nObject(en_US="Seed"),
|
||||
label=I18nObject(en_US="Seed", zh_Hans="随机数种子"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(
|
||||
en_US="Sets the random number seed to use for generation. Setting this to "
|
||||
"a specific number will make the model generate the same text for "
|
||||
"the same prompt. (Default: 0)"
|
||||
"the same prompt. (Default: 0)",
|
||||
zh_Hans="设置用于生成的随机数种子。将此设置为特定数字将使模型对相同的提示生成相同的文本。(默认值:0)",
|
||||
),
|
||||
),
|
||||
ParameterRule(
|
||||
name="keep_alive",
|
||||
label=I18nObject(en_US="Keep Alive"),
|
||||
label=I18nObject(en_US="Keep Alive", zh_Hans="模型存活时间"),
|
||||
type=ParameterType.STRING,
|
||||
help=I18nObject(
|
||||
en_US="Sets how long the model is kept in memory after generating a response. "
|
||||
"This must be a duration string with a unit (e.g., '10m' for 10 minutes or '24h' for 24 hours)."
|
||||
" A negative number keeps the model loaded indefinitely, and '0' unloads the model"
|
||||
" immediately after generating a response."
|
||||
" Valid time units are 's','m','h'. (Default: 5m)"
|
||||
" Valid time units are 's','m','h'. (Default: 5m)",
|
||||
zh_Hans="设置模型在生成响应后在内存中保留的时间。"
|
||||
"这必须是一个带有单位的持续时间字符串(例如,'10m' 表示10分钟,'24h' 表示24小时)。"
|
||||
"负数表示无限期地保留模型,'0'表示在生成响应后立即卸载模型。"
|
||||
"有效的时间单位有 's'(秒)、'm'(分钟)、'h'(小时)。(默认值:5m)",
|
||||
),
|
||||
),
|
||||
ParameterRule(
|
||||
name="format",
|
||||
label=I18nObject(en_US="Format"),
|
||||
label=I18nObject(en_US="Format", zh_Hans="返回格式"),
|
||||
type=ParameterType.STRING,
|
||||
help=I18nObject(
|
||||
en_US="the format to return a response in. Currently the only accepted value is json."
|
||||
en_US="the format to return a response in. Currently the only accepted value is json.",
|
||||
zh_Hans="返回响应的格式。目前唯一接受的值是json。",
|
||||
),
|
||||
options=["json"],
|
||||
),
|
||||
|
@ -205,7 +205,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.TEMPERATURE.value,
|
||||
label=I18nObject(en_US="Temperature"),
|
||||
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
||||
help=I18nObject(
|
||||
en_US="Kernel sampling threshold. Used to determine the randomness of the results."
|
||||
"The higher the value, the stronger the randomness."
|
||||
"The higher the possibility of getting different answers to the same question.",
|
||||
zh_Hans="核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。",
|
||||
),
|
||||
type=ParameterType.FLOAT,
|
||||
default=float(credentials.get("temperature", 0.7)),
|
||||
min=0,
|
||||
@ -214,7 +220,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.TOP_P.value,
|
||||
label=I18nObject(en_US="Top P"),
|
||||
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
|
||||
help=I18nObject(
|
||||
en_US="The probability threshold of the nucleus sampling method during the generation process."
|
||||
"The larger the value is, the higher the randomness of generation will be."
|
||||
"The smaller the value is, the higher the certainty of generation will be.",
|
||||
zh_Hans="生成过程中核采样方法概率阈值。取值越大,生成的随机性越高;取值越小,生成的确定性越高。",
|
||||
),
|
||||
type=ParameterType.FLOAT,
|
||||
default=float(credentials.get("top_p", 1)),
|
||||
min=0,
|
||||
@ -223,7 +235,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.FREQUENCY_PENALTY.value,
|
||||
label=I18nObject(en_US="Frequency Penalty"),
|
||||
label=I18nObject(en_US="Frequency Penalty", zh_Hans="频率惩罚"),
|
||||
help=I18nObject(
|
||||
en_US="For controlling the repetition rate of words used by the model."
|
||||
"Increasing this can reduce the repetition of the same words in the model's output.",
|
||||
zh_Hans="用于控制模型已使用字词的重复率。 提高此项可以降低模型在输出中重复相同字词的重复度。",
|
||||
),
|
||||
type=ParameterType.FLOAT,
|
||||
default=float(credentials.get("frequency_penalty", 0)),
|
||||
min=-2,
|
||||
@ -231,7 +248,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.PRESENCE_PENALTY.value,
|
||||
label=I18nObject(en_US="Presence Penalty"),
|
||||
label=I18nObject(en_US="Presence Penalty", zh_Hans="存在惩罚"),
|
||||
help=I18nObject(
|
||||
en_US="Used to control the repetition rate when generating models."
|
||||
"Increasing this can reduce the repetition rate of model generation.",
|
||||
zh_Hans="用于控制模型生成时的重复度。提高此项可以降低模型生成的重复度。",
|
||||
),
|
||||
type=ParameterType.FLOAT,
|
||||
default=float(credentials.get("presence_penalty", 0)),
|
||||
min=-2,
|
||||
@ -239,7 +261,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.MAX_TOKENS.value,
|
||||
label=I18nObject(en_US="Max Tokens"),
|
||||
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
|
||||
help=I18nObject(
|
||||
en_US="Maximum length of tokens for the model response.", zh_Hans="模型回答的tokens的最大长度。"
|
||||
),
|
||||
type=ParameterType.INT,
|
||||
default=512,
|
||||
min=1,
|
||||
|
@ -1,3 +1,5 @@
|
||||
- openai/o1-preview
|
||||
- openai/o1-mini
|
||||
- openai/gpt-4o
|
||||
- openai/gpt-4o-mini
|
||||
- openai/gpt-4
|
||||
|
@ -1,7 +1,7 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
|
||||
@ -26,7 +26,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
) -> Union[LLMResult, Generator]:
|
||||
self._update_credential(model, credentials)
|
||||
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._update_credential(model, credentials)
|
||||
@ -46,7 +46,48 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
) -> Union[LLMResult, Generator]:
|
||||
self._update_credential(model, credentials)
|
||||
|
||||
return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
block_as_stream = False
|
||||
if model.startswith("openai/o1"):
|
||||
block_as_stream = True
|
||||
stop = None
|
||||
|
||||
# invoke block as stream
|
||||
if stream and block_as_stream:
|
||||
return self._generate_block_as_stream(
|
||||
model, credentials, prompt_messages, model_parameters, tools, stop, user
|
||||
)
|
||||
else:
|
||||
return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def _generate_block_as_stream(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> Generator:
|
||||
resp: LLMResult = super()._generate(
|
||||
model, credentials, prompt_messages, model_parameters, tools, stop, False, user
|
||||
)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=resp.message,
|
||||
usage=self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_tokens=resp.usage.prompt_tokens,
|
||||
completion_tokens=resp.usage.completion_tokens,
|
||||
),
|
||||
finish_reason="stop",
|
||||
),
|
||||
)
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
self._update_credential(model, credentials)
|
||||
|
@ -0,0 +1,40 @@
|
||||
model: openai/o1-mini
|
||||
label:
|
||||
en_US: o1-mini
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 65536
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: response_format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: "3.00"
|
||||
output: "12.00"
|
||||
unit: "0.000001"
|
||||
currency: USD
|
@ -0,0 +1,40 @@
|
||||
model: openai/o1-preview
|
||||
label:
|
||||
en_US: o1-preview
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 32768
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: response_format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: "15.00"
|
||||
output: "60.00"
|
||||
unit: "0.000001"
|
||||
currency: USD
|
@ -59,3 +59,4 @@ pricing:
|
||||
output: "0.000"
|
||||
unit: "0.000"
|
||||
currency: RMB
|
||||
deprecated: true
|
||||
|
@ -59,3 +59,4 @@ pricing:
|
||||
output: "0.000"
|
||||
unit: "0.000"
|
||||
currency: RMB
|
||||
deprecated: true
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user