diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 7b5ed7ddd7..0e1c9d6927 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -10,9 +10,33 @@ concurrency: cancel-in-progress: true jobs: + python-style: + name: Python Style + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Python dependencies + run: pip install ruff + + - name: Ruff check + run: ruff check ./api + + - name: Lint hints + if: failure() + run: echo "Please run 'dev/reformat' to fix the fixable linting errors." + test: name: ESLint and SuperLinter runs-on: ubuntu-latest + needs: python-style steps: - name: Checkout code diff --git a/api/app.py b/api/app.py index cb3b226e4c..255c1dbc05 100644 --- a/api/app.py +++ b/api/app.py @@ -19,18 +19,28 @@ import threading import time import warnings -from commands import register_commands -from config import CloudEditionConfig, Config -from events import event_handlers -from extensions import (ext_celery, ext_code_based_extension, ext_database, ext_hosting_provider, ext_login, ext_mail, - ext_migrate, ext_redis, ext_sentry, ext_storage) -from extensions.ext_database import db -from extensions.ext_login import login_manager from flask import Flask, Response, request from flask_cors import CORS + +from commands import register_commands +from config import CloudEditionConfig, Config +from extensions import ( + ext_celery, + ext_code_based_extension, + ext_database, + ext_hosting_provider, + ext_login, + ext_mail, + ext_migrate, + ext_redis, + ext_sentry, + ext_storage, +) +from extensions.ext_database import db +from extensions.ext_login import login_manager from libs.passport import PassportService + # DO NOT REMOVE BELOW -from models import account, dataset, model, source, task, tool, tools, web from services.account_service import AccountService # DO NOT REMOVE ABOVE diff --git a/api/commands.py b/api/commands.py index b44f166926..91b50445e6 100644 --- a/api/commands.py +++ b/api/commands.py @@ -3,11 +3,13 @@ import json import secrets import click +from flask import current_app +from werkzeug.exceptions import NotFound + from core.embedding.cached_embedding import CacheEmbedding from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db -from flask import current_app from libs.helper import email as email_validate from libs.password import hash_password, password_pattern, valid_password from libs.rsa import generate_key_pair @@ -15,7 +17,6 @@ from models.account import Tenant from models.dataset import Dataset from models.model import Account from models.provider import Provider, ProviderModel -from werkzeug.exceptions import NotFound @click.command('reset-password', help='Reset the account password.') diff --git a/api/constants/model_template.py b/api/constants/model_template.py index 5ec0f3125e..5b9a09fd9b 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -1,7 +1,5 @@ import json -from models.model import App, AppModelConfig - model_templates = { # completion default mode 'completion_default': { diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 4b8d27434d..aaa737f83a 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -1,14 +1,15 @@ import os from functools import wraps +from flask import request +from flask_restful import Resource, reqparse +from werkzeug.exceptions import NotFound, Unauthorized + from constants.languages import supported_language from controllers.console import api from controllers.console.wraps import only_edition_cloud from extensions.ext_database import db -from flask import request -from flask_restful import Resource, reqparse from models.model import App, InstalledApp, RecommendedApp -from werkzeug.exceptions import NotFound, Unauthorized def admin_required(view): diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index b8dd1ed5bf..324b831175 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,12 +1,13 @@ import flask_restful -from extensions.ext_database import db from flask_login import current_user from flask_restful import Resource, fields, marshal_with +from werkzeug.exceptions import Forbidden + +from extensions.ext_database import db from libs.helper import TimestampField from libs.login import login_required from models.dataset import Dataset from models.model import ApiToken, App -from werkzeug.exceptions import Forbidden from . import api from .setup import setup_required diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index c7693fb950..fa2b3807e8 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -1,7 +1,8 @@ +from flask_restful import Resource, reqparse + from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from flask_restful import Resource, reqparse from libs.login import login_required from services.advanced_prompt_template_service import AdvancedPromptTemplateService diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 8c7cae9519..1ac8e60dcd 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,17 +1,20 @@ +from flask import request +from flask_login import current_user +from flask_restful import Resource, marshal, marshal_with, reqparse +from werkzeug.exceptions import Forbidden + from controllers.console import api from controllers.console.app.error import NoFileUploadedError from controllers.console.datasets.error import TooManyFilesError from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from extensions.ext_redis import redis_client -from fields.annotation_fields import (annotation_fields, annotation_hit_history_fields, - annotation_hit_history_list_fields, annotation_list_fields) -from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal, marshal_with, reqparse +from fields.annotation_fields import ( + annotation_fields, + annotation_hit_history_fields, +) from libs.login import login_required from services.annotation_service import AppAnnotationService -from werkzeug.exceptions import Forbidden class AnnotationReplyActionApi(Resource): diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 2aac27af3e..5036d2074d 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -3,6 +3,10 @@ import json import logging from datetime import datetime +from flask_login import current_user +from flask_restful import Resource, abort, inputs, marshal_with, reqparse +from werkzeug.exceptions import Forbidden + from constants.languages import demo_model_templates, languages from constants.model_template import model_templates from controllers.console import api @@ -15,16 +19,15 @@ from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager from events.app_event import app_was_created, app_was_deleted from extensions.ext_database import db -from fields.app_fields import (app_detail_fields, app_detail_fields_with_site, app_pagination_fields, - template_list_fields) -from flask import current_app -from flask_login import current_user -from flask_restful import Resource, abort, inputs, marshal_with, reqparse +from fields.app_fields import ( + app_detail_fields, + app_detail_fields_with_site, + app_pagination_fields, + template_list_fields, +) from libs.login import login_required from models.model import App, AppModelConfig, Site -from models.tools import ApiToolProvider from services.app_model_config_service import AppModelConfigService -from werkzeug.exceptions import Forbidden def _get_app(app_id, tenant_id): diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index d7a4f3e3e0..775b3315a8 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -1,24 +1,36 @@ # -*- coding:utf-8 -*- import logging +from flask import request +from flask_restful import Resource +from werkzeug.exceptions import InternalServerError + import services from controllers.console import api from controllers.console.app import _get_app -from controllers.console.app.error import (AppUnavailableError, AudioTooLargeError, CompletionRequestError, - NoAudioUploadedError, ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, ProviderNotSupportSpeechToTextError, - ProviderQuotaExceededError, UnsupportedAudioTypeError) +from controllers.console.app.error import ( + AppUnavailableError, + AudioTooLargeError, + CompletionRequestError, + NoAudioUploadedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, + UnsupportedAudioTypeError, +) from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from flask import request -from flask_restful import Resource from libs.login import login_required from services.audio_service import AudioService -from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError, - ProviderNotSupportSpeechToTextServiceError, UnsupportedAudioTypeServiceError) -from werkzeug.exceptions import InternalServerError +from services.errors.audio import ( + AudioTooLargeServiceError, + NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, + UnsupportedAudioTypeServiceError, +) class ChatMessageAudioApi(Resource): diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index b530a9ee2f..be8d3bf082 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -4,24 +4,30 @@ import logging from typing import Generator, Union import flask_login +from flask import Response, stream_with_context +from flask_restful import Resource, reqparse +from werkzeug.exceptions import InternalServerError, NotFound + import services from controllers.console import api from controllers.console.app import _get_app -from controllers.console.app.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError, - ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, - ProviderQuotaExceededError) +from controllers.console.app.error import ( + AppUnavailableError, + CompletionRequestError, + ConversationCompletedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.application_queue_manager import ApplicationQueueManager from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from flask import Response, stream_with_context -from flask_restful import Resource, reqparse from libs.helper import uuid_value from libs.login import login_required from services.completion_service import CompletionService -from werkzeug.exceptions import InternalServerError, NotFound # define completion message api for user diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index f159f74c71..452b0fddf6 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -1,22 +1,27 @@ from datetime import datetime import pytz +from flask_login import current_user +from flask_restful import Resource, marshal_with, reqparse +from flask_restful.inputs import int_range +from sqlalchemy import func, or_ +from sqlalchemy.orm import joinedload +from werkzeug.exceptions import NotFound + from controllers.console import api from controllers.console.app import _get_app from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from extensions.ext_database import db -from fields.conversation_fields import (conversation_detail_fields, conversation_message_detail_fields, - conversation_pagination_fields, conversation_with_summary_pagination_fields) -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from fields.conversation_fields import ( + conversation_detail_fields, + conversation_message_detail_fields, + conversation_pagination_fields, + conversation_with_summary_pagination_fields, +) from libs.helper import datetime_string from libs.login import login_required from models.model import Conversation, Message, MessageAnnotation -from sqlalchemy import func, or_ -from sqlalchemy.orm import joinedload -from werkzeug.exceptions import NotFound class CompletionConversationApi(Resource): diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index d7a320db99..3ec932b5f1 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,13 +1,18 @@ +from flask_login import current_user +from flask_restful import Resource, reqparse + from controllers.console import api -from controllers.console.app.error import (CompletionRequestError, ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, ProviderQuotaExceededError) +from controllers.console.app.error import ( + CompletionRequestError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.generator.llm_generator import LLMGenerator from core.model_runtime.errors.invoke import InvokeError -from flask_login import current_user -from flask_restful import Resource, reqparse from libs.login import login_required diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 50b4e2d983..d29d826b69 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -2,11 +2,21 @@ import json import logging from typing import Generator, Union +from flask import Response, stream_with_context +from flask_login import current_user +from flask_restful import Resource, fields, marshal_with, reqparse +from flask_restful.inputs import int_range +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound + from controllers.console import api from controllers.console.app import _get_app -from controllers.console.app.error import (AppMoreLikeThisDisabledError, CompletionRequestError, - ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, - ProviderQuotaExceededError) +from controllers.console.app.error import ( + AppMoreLikeThisDisabledError, + CompletionRequestError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from core.entities.application_entities import InvokeFrom @@ -14,10 +24,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from fields.conversation_fields import annotation_fields, message_detail_fields -from flask import Response, stream_with_context -from flask_login import current_user -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range from libs.helper import uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import login_required @@ -28,7 +34,6 @@ from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError from services.message_service import MessageService -from werkzeug.exceptions import Forbidden, InternalServerError, NotFound class ChatMessageListApi(Resource): diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index d447bfa756..fd526b393d 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -1,14 +1,15 @@ # -*- coding:utf-8 -*- +from flask import request +from flask_login import current_user +from flask_restful import Resource + from controllers.console import api from controllers.console.app import _get_app from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from events.app_event import app_model_config_was_updated from extensions.ext_database import db -from flask import request -from flask_login import current_user -from flask_restful import Resource from libs.login import login_required from models.model import AppModelConfig from services.app_model_config_service import AppModelConfigService diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index daba012bd9..8d6231cbac 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,4 +1,8 @@ # -*- coding:utf-8 -*- +from flask_login import current_user +from flask_restful import Resource, marshal_with, reqparse +from werkzeug.exceptions import Forbidden, NotFound + from constants.languages import supported_language from controllers.console import api from controllers.console.app import _get_app @@ -6,11 +10,8 @@ from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from extensions.ext_database import db from fields.app_fields import app_site_fields -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse from libs.login import login_required from models.model import Site -from werkzeug.exceptions import Forbidden, NotFound def parse_app_site_args(): diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index f2c1726433..d6ced934a7 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -3,14 +3,15 @@ from datetime import datetime from decimal import Decimal import pytz +from flask import jsonify +from flask_login import current_user +from flask_restful import Resource, reqparse + from controllers.console import api from controllers.console.app import _get_app from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from extensions.ext_database import db -from flask import jsonify -from flask_login import current_user -from flask_restful import Resource, reqparse from libs.helper import datetime_string from libs.login import login_required diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index 0b3672efc9..20e028af99 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -2,14 +2,15 @@ import base64 import secrets from datetime import datetime +from flask_restful import Resource, reqparse + from constants.languages import supported_language from controllers.console import api from controllers.console.error import AlreadyActivateError from extensions.ext_database import db -from flask_restful import Resource, reqparse from libs.helper import email, str_len, timezone from libs.password import hash_password, valid_password -from models.account import AccountStatus, Tenant +from models.account import AccountStatus from services.account_service import RegisterService diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index d0b28c6d4b..293ec1c4d3 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -1,13 +1,14 @@ import logging import requests -from controllers.console import api from flask import current_app, redirect, request from flask_login import current_user from flask_restful import Resource +from werkzeug.exceptions import Forbidden + +from controllers.console import api from libs.login import login_required from libs.oauth_data_source import NotionOAuth -from werkzeug.exceptions import Forbidden from ..setup import setup_required from ..wraps import account_initialization_required diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 2c8fdeeaf5..646f672c72 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -1,11 +1,11 @@ # -*- coding:utf-8 -*- -import flask import flask_login +from flask import current_app, request +from flask_restful import Resource, reqparse + import services from controllers.console import api from controllers.console.setup import setup_required -from flask import current_app, request -from flask_restful import Resource, reqparse from libs.helper import email from libs.password import valid_password from services.account_service import AccountService diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index b7d4e51910..96765d189a 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -3,10 +3,11 @@ from datetime import datetime from typing import Optional import requests -from constants.languages import languages -from extensions.ext_database import db from flask import current_app, redirect, request from flask_restful import Resource + +from constants.languages import languages +from extensions.ext_database import db from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from models.account import Account, AccountStatus from services.account_service import AccountService, RegisterService diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 71de01c779..72a6129efa 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -1,8 +1,9 @@ +from flask_login import current_user +from flask_restful import Resource, reqparse + from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, only_edition_cloud -from flask_login import current_user -from flask_restful import Resource, reqparse from libs.login import login_required from services.billing_service import BillingService diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index a9ecd3d27d..86fcf704c7 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -1,6 +1,11 @@ import datetime import json +from flask import request +from flask_login import current_user +from flask_restful import Resource, marshal_with, reqparse +from werkzeug.exceptions import NotFound + from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required @@ -8,15 +13,11 @@ from core.data_loader.loader.notion import NotionLoader from core.indexing_runner import IndexingRunner from extensions.ext_database import db from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields -from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse from libs.login import login_required from models.dataset import Document from models.source import DataSourceBinding from services.dataset_service import DatasetService, DocumentService from tasks.document_indexing_sync_task import document_indexing_sync_task -from werkzeug.exceptions import NotFound class DataSourceApi(Resource): diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 01700ea63b..a6d869593b 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -1,5 +1,10 @@ # -*- coding:utf-8 -*- import flask_restful +from flask import current_app, request +from flask_login import current_user +from flask_restful import Resource, marshal, marshal_with, reqparse +from werkzeug.exceptions import Forbidden, NotFound + import services from controllers.console import api from controllers.console.apikey import api_key_fields, api_key_list @@ -15,14 +20,10 @@ from extensions.ext_database import db from fields.app_fields import related_app_list from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields from fields.document_fields import document_status_fields -from flask import current_app, request -from flask_login import current_user -from flask_restful import Resource, marshal, marshal_with, reqparse from libs.login import login_required from models.dataset import Dataset, Document, DocumentSegment from models.model import ApiToken, UploadFile from services.dataset_service import DatasetService, DocumentService -from werkzeug.exceptions import Forbidden, NotFound def _validate_name(name): diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 586bbafbb0..88bbd25645 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -2,35 +2,52 @@ from datetime import datetime from typing import List +from flask import request +from flask_login import current_user +from flask_restful import Resource, fields, marshal, marshal_with, reqparse +from sqlalchemy import asc, desc +from werkzeug.exceptions import Forbidden, NotFound + import services from controllers.console import api -from controllers.console.app.error import (ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, - ProviderQuotaExceededError) -from controllers.console.datasets.error import (ArchivedDocumentImmutableError, DocumentAlreadyFinishedError, - DocumentIndexingError, InvalidActionError, InvalidMetadataError) +from controllers.console.app.error import ( + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from controllers.console.datasets.error import ( + ArchivedDocumentImmutableError, + DocumentAlreadyFinishedError, + DocumentIndexingError, + InvalidActionError, + InvalidMetadataError, +) from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from core.errors.error import (LLMBadRequestError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError, - QuotaExceededError) +from core.errors.error import ( + LLMBadRequestError, + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) from core.indexing_runner import IndexingRunner from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from extensions.ext_redis import redis_client -from fields.document_fields import (dataset_and_document_fields, document_fields, document_status_fields, - document_with_segments_fields) -from flask import request -from flask_login import current_user -from flask_restful import Resource, fields, marshal, marshal_with, reqparse +from fields.document_fields import ( + dataset_and_document_fields, + document_fields, + document_status_fields, + document_with_segments_fields, +) from libs.login import login_required from models.dataset import Dataset, DatasetProcessRule, Document, DocumentSegment from models.model import UploadFile from services.dataset_service import DatasetService, DocumentService -from sqlalchemy import asc, desc from tasks.add_document_to_index_task import add_document_to_index_task from tasks.remove_document_from_index_task import remove_document_from_index_task -from werkzeug.exceptions import Forbidden, NotFound class DocumentResource(Resource): diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 8de5bc91d7..9cfc5ad796 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -3,6 +3,11 @@ import uuid from datetime import datetime import pandas as pd +from flask import request +from flask_login import current_user +from flask_restful import Resource, marshal, reqparse +from werkzeug.exceptions import Forbidden, NotFound + import services from controllers.console import api from controllers.console.app.error import ProviderNotInitializeError @@ -15,16 +20,12 @@ from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import segment_fields -from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal, reqparse from libs.login import login_required from models.dataset import DocumentSegment from services.dataset_service import DatasetService, DocumentService, SegmentService from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task from tasks.disable_segment_from_index_task import disable_segment_from_index_task from tasks.enable_segment_to_index_task import enable_segment_to_index_task -from werkzeug.exceptions import Forbidden, NotFound class DatasetDocumentSegmentListApi(Resource): diff --git a/api/controllers/console/datasets/file.py b/api/controllers/console/datasets/file.py index c15b3c0cd4..0eba232289 100644 --- a/api/controllers/console/datasets/file.py +++ b/api/controllers/console/datasets/file.py @@ -1,13 +1,18 @@ -import services -from controllers.console import api -from controllers.console.datasets.error import (FileTooLargeError, NoFileUploadedError, TooManyFilesError, - UnsupportedFileTypeError) -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required -from fields.file_fields import file_fields, upload_config_fields from flask import current_app, request from flask_login import current_user from flask_restful import Resource, marshal_with + +import services +from controllers.console import api +from controllers.console.datasets.error import ( + FileTooLargeError, + NoFileUploadedError, + TooManyFilesError, + UnsupportedFileTypeError, +) +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from fields.file_fields import file_fields, upload_config_fields from libs.login import login_required from services.file_service import ALLOWED_EXTENSIONS, UNSTRUSTURED_ALLOWED_EXTENSIONS, FileService diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index a32a3217e5..4738566241 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,22 +1,31 @@ import logging +from flask_login import current_user +from flask_restful import Resource, marshal, reqparse +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound + import services from controllers.console import api -from controllers.console.app.error import (CompletionRequestError, ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, ProviderQuotaExceededError) +from controllers.console.app.error import ( + CompletionRequestError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) from controllers.console.datasets.error import DatasetNotInitializedError, HighQualityDatasetOnlyError from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.errors.error import (LLMBadRequestError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError, - QuotaExceededError) +from core.errors.error import ( + LLMBadRequestError, + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) from core.model_runtime.errors.invoke import InvokeError from fields.hit_testing_fields import hit_testing_record_fields -from flask_login import current_user -from flask_restful import Resource, marshal, reqparse from libs.login import login_required from services.dataset_service import DatasetService from services.hit_testing_service import HitTestingService -from werkzeug.exceptions import Forbidden, InternalServerError, NotFound class HitTestingApi(Resource): diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 784c0c6330..48d58524bb 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -1,21 +1,33 @@ # -*- coding:utf-8 -*- import logging +from flask import request +from werkzeug.exceptions import InternalServerError + import services from controllers.console import api -from controllers.console.app.error import (AppUnavailableError, AudioTooLargeError, CompletionRequestError, - NoAudioUploadedError, ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, ProviderNotSupportSpeechToTextError, - ProviderQuotaExceededError, UnsupportedAudioTypeError) +from controllers.console.app.error import ( + AppUnavailableError, + AudioTooLargeError, + CompletionRequestError, + NoAudioUploadedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, + UnsupportedAudioTypeError, +) from controllers.console.explore.wraps import InstalledAppResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from flask import request from models.model import AppModelConfig from services.audio_service import AudioService -from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError, - ProviderNotSupportSpeechToTextServiceError, UnsupportedAudioTypeServiceError) -from werkzeug.exceptions import InternalServerError +from services.errors.audio import ( + AudioTooLargeServiceError, + NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, + UnsupportedAudioTypeServiceError, +) class ChatAudioApi(InstalledAppResource): diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index b608130307..924578f7b4 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -4,11 +4,21 @@ import logging from datetime import datetime from typing import Generator, Union +from flask import Response, stream_with_context +from flask_login import current_user +from flask_restful import reqparse +from werkzeug.exceptions import InternalServerError, NotFound + import services from controllers.console import api -from controllers.console.app.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError, - ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, - ProviderQuotaExceededError) +from controllers.console.app.error import ( + AppUnavailableError, + CompletionRequestError, + ConversationCompletedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) from controllers.console.explore.error import NotChatAppError, NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource from core.application_queue_manager import ApplicationQueueManager @@ -16,12 +26,8 @@ from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db -from flask import Response, stream_with_context -from flask_login import current_user -from flask_restful import reqparse from libs.helper import uuid_value from services.completion_service import CompletionService -from werkzeug.exceptions import InternalServerError, NotFound # define completion api for user diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 1b6b493671..8a3fb3a205 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,16 +1,17 @@ # -*- coding:utf-8 -*- +from flask_login import current_user +from flask_restful import marshal_with, reqparse +from flask_restful.inputs import int_range +from werkzeug.exceptions import NotFound + from controllers.console import api from controllers.console.explore.error import NotChatAppError from controllers.console.explore.wraps import InstalledAppResource from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields -from flask_login import current_user -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range -from libs.helper import TimestampField, uuid_value +from libs.helper import uuid_value from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError from services.web_conversation_service import WebConversationService -from werkzeug.exceptions import NotFound class ConversationListApi(InstalledAppResource): diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 44c54427a4..6e914ef3a4 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -1,18 +1,19 @@ # -*- coding:utf-8 -*- from datetime import datetime +from flask_login import current_user +from flask_restful import Resource, inputs, marshal_with, reqparse +from sqlalchemy import and_ +from werkzeug.exceptions import BadRequest, Forbidden, NotFound + from controllers.console import api from controllers.console.explore.wraps import InstalledAppResource from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from extensions.ext_database import db from fields.installed_app_fields import installed_app_list_fields -from flask_login import current_user -from flask_restful import Resource, inputs, marshal_with, reqparse from libs.login import login_required from models.model import App, InstalledApp, RecommendedApp from services.account_service import TenantService -from sqlalchemy import and_ -from werkzeug.exceptions import BadRequest, Forbidden, NotFound class InstalledAppsListApi(Resource): diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 684ecd8b28..75c3cdd5c4 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -3,29 +3,37 @@ import json import logging from typing import Generator, Union +from flask import Response, stream_with_context +from flask_login import current_user +from flask_restful import marshal_with, reqparse +from flask_restful.inputs import int_range +from werkzeug.exceptions import InternalServerError, NotFound + import services from controllers.console import api -from controllers.console.app.error import (AppMoreLikeThisDisabledError, CompletionRequestError, - ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, - ProviderQuotaExceededError) -from controllers.console.explore.error import (AppSuggestedQuestionsAfterAnswerDisabledError, NotChatAppError, - NotCompletionAppError) +from controllers.console.app.error import ( + AppMoreLikeThisDisabledError, + CompletionRequestError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from controllers.console.explore.error import ( + AppSuggestedQuestionsAfterAnswerDisabledError, + NotChatAppError, + NotCompletionAppError, +) from controllers.console.explore.wraps import InstalledAppResource from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields -from flask import Response, stream_with_context -from flask_login import current_user -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range -from libs.helper import TimestampField, uuid_value +from libs.helper import uuid_value from services.completion_service import CompletionService from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.message_service import MessageService -from werkzeug.exceptions import InternalServerError, NotFound class MessageListApi(InstalledAppResource): diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index c073ebad01..f37bf3e1e5 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,11 +1,12 @@ # -*- coding:utf-8 -*- import json +from flask import current_app +from flask_restful import fields, marshal_with + from controllers.console import api from controllers.console.explore.wraps import InstalledAppResource from extensions.ext_database import db -from flask import current_app -from flask_restful import fields, marshal_with from models.model import AppModelConfig, InstalledApp from models.tools import ApiToolProvider diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 4ce8fbfbe9..3c2c806664 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,15 +1,16 @@ # -*- coding:utf-8 -*- +from flask_login import current_user +from flask_restful import Resource, fields, marshal_with +from sqlalchemy import and_ + from constants.languages import languages from controllers.console import api from controllers.console.app.error import AppNotFoundError from controllers.console.wraps import account_initialization_required from extensions.ext_database import db -from flask_login import current_user -from flask_restful import Resource, fields, marshal_with from libs.login import login_required from models.model import App, InstalledApp, RecommendedApp from services.account_service import TenantService -from sqlalchemy import and_ app_fields = { 'id': fields.String, diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 9d355df355..cf86b2fee1 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -1,14 +1,15 @@ +from flask_login import current_user +from flask_restful import fields, marshal_with, reqparse +from flask_restful.inputs import int_range +from werkzeug.exceptions import NotFound + from controllers.console import api from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource from fields.conversation_fields import message_file_fields -from flask_login import current_user -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range from libs.helper import TimestampField, uuid_value from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService -from werkzeug.exceptions import NotFound feedback_fields = { 'rating': fields.String diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index d02b869bf7..84890f1b46 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -1,12 +1,13 @@ from functools import wraps +from flask_login import current_user +from flask_restful import Resource +from werkzeug.exceptions import NotFound + from controllers.console.wraps import account_initialization_required from extensions.ext_database import db -from flask_login import current_user -from flask_restful import Resource from libs.login import login_required from models.model import InstalledApp -from werkzeug.exceptions import NotFound def installed_app_required(view=None): diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 78374cf2a9..fa73c44c22 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,9 +1,10 @@ +from flask_login import current_user +from flask_restful import Resource, marshal_with, reqparse + from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from fields.api_based_extension_fields import api_based_extension_fields -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse from libs.login import login_required from models.api_based_extension import APIBasedExtension from services.api_based_extension_service import APIBasedExtensionService diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 40f86fc235..824549050f 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,5 +1,6 @@ from flask_login import current_user from flask_restful import Resource + from services.feature_service import FeatureService from . import api diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index d1994a84c9..b319f706b4 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -2,6 +2,7 @@ import os from flask import current_app, session from flask_restful import Resource, reqparse + from libs.helper import str_len from models.model import DifySetup from services.account_service import TenantService diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 765161ff9d..58c2853470 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,9 +1,10 @@ # -*- coding:utf-8 -*- from functools import wraps -from extensions.ext_database import db from flask import current_app, request from flask_restful import Resource, reqparse + +from extensions.ext_database import db from libs.helper import email, str_len from libs.password import valid_password from models.model import DifySetup diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index ba49506618..519fa25516 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -6,7 +6,6 @@ import logging import requests from flask import current_app from flask_restful import Resource, reqparse -from werkzeug.exceptions import InternalServerError from . import api diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 1f856394e2..c511c9778b 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -2,16 +2,21 @@ from datetime import datetime import pytz -from constants.languages import supported_language -from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.workspace.error import (AccountAlreadyInitedError, CurrentPasswordIncorrectError, - InvalidInvitationCodeError, RepeatPasswordNotMatchError) -from controllers.console.wraps import account_initialization_required -from extensions.ext_database import db from flask import current_app, request from flask_login import current_user from flask_restful import Resource, fields, marshal_with, reqparse + +from constants.languages import supported_language +from controllers.console import api +from controllers.console.setup import setup_required +from controllers.console.workspace.error import ( + AccountAlreadyInitedError, + CurrentPasswordIncorrectError, + InvalidInvitationCodeError, + RepeatPasswordNotMatchError, +) +from controllers.console.wraps import account_initialization_required +from extensions.ext_database import db from libs.helper import TimestampField, timezone from libs.login import login_required from models.account import AccountIntegrate, InvitationCode diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 84be878545..1b7d08a879 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,12 +1,13 @@ # -*- coding:utf-8 -*- +from flask import current_app +from flask_login import current_user +from flask_restful import Resource, abort, fields, marshal_with, reqparse + import services from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from extensions.ext_database import db -from flask import current_app -from flask_login import current_user -from flask_restful import Resource, abort, fields, marshal_with, reqparse from libs.helper import TimestampField from libs.login import login_required from models.account import Account diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index cb76e5cdd2..c888159f83 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -1,18 +1,19 @@ import io +from flask import send_file +from flask_login import current_user +from flask_restful import Resource, reqparse +from werkzeug.exceptions import Forbidden + from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder -from flask import send_file -from flask_login import current_user -from flask_restful import Resource, reqparse from libs.login import login_required from services.billing_service import BillingService from services.model_provider_service import ModelProviderService -from werkzeug.exceptions import Forbidden class ModelProviderListApi(Resource): diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 305c9f09af..5745c0d408 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -1,16 +1,17 @@ import logging +from flask_login import current_user +from flask_restful import Resource, reqparse +from werkzeug.exceptions import Forbidden + from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder -from flask_login import current_user -from flask_restful import Resource, reqparse from libs.login import login_required from services.model_provider_service import ModelProviderService -from werkzeug.exceptions import Forbidden class DefaultModelApi(Resource): diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index fb42146eee..c2c5286d51 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,15 +1,15 @@ import io -import json + +from flask import send_file +from flask_login import current_user +from flask_restful import Resource, reqparse +from werkzeug.exceptions import Forbidden from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from flask import send_file -from flask_login import current_user -from flask_restful import Resource, reqparse from libs.login import login_required from services.tools_manage_service import ToolManageService -from werkzeug.exceptions import Forbidden class ToolProviderListApi(Resource): diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 8f00d76f7a..dbeb712bc2 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -1,18 +1,23 @@ # -*- coding:utf-8 -*- import logging +from flask import request +from flask_login import current_user +from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse + import services from controllers.console import api from controllers.console.admin import admin_required -from controllers.console.datasets.error import (FileTooLargeError, NoFileUploadedError, TooManyFilesError, - UnsupportedFileTypeError) +from controllers.console.datasets.error import ( + FileTooLargeError, + NoFileUploadedError, + TooManyFilesError, + UnsupportedFileTypeError, +) from controllers.console.error import AccountNotLinkTenantError from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from extensions.ext_database import db -from flask import request -from flask_login import current_user -from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse from libs.helper import TimestampField from libs.login import login_required from models.account import Tenant diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 7bfb064a23..1e20265c4b 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -2,9 +2,10 @@ import json from functools import wraps -from controllers.console.workspace.error import AccountNotInitializedError from flask import abort, current_app, request from flask_login import current_user + +from controllers.console.workspace.error import AccountNotInitializedError from services.feature_service import FeatureService from services.operation_service import OperationService diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index 4227f139dd..66b9eee0de 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -1,11 +1,12 @@ -import services -from controllers.files import api from flask import Response, request from flask_restful import Resource +from werkzeug.exceptions import NotFound + +import services +from controllers.files import api from libs.exception import BaseHTTPException from services.account_service import TenantService from services.file_service import FileService -from werkzeug.exceptions import NotFound class ImagePreviewApi(Resource): diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index b4a290ec87..ecafd7b231 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -1,10 +1,11 @@ -from controllers.files import api -from core.tools.tool_file_manager import ToolFileManager from flask import Response from flask_restful import Resource, reqparse -from libs.exception import BaseHTTPException from werkzeug.exceptions import Forbidden, NotFound +from controllers.files import api +from core.tools.tool_file_manager import ToolFileManager +from libs.exception import BaseHTTPException + class ToolFilePreviewApi(Resource): def get(self, file_id, extension): diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 63591f8f49..8e1ecebce7 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,11 +1,12 @@ # -*- coding:utf-8 -*- import json +from flask import current_app +from flask_restful import fields, marshal_with + from controllers.service_api import api from controllers.service_api.wraps import AppApiResource from extensions.ext_database import db -from flask import current_app -from flask_restful import fields, marshal_with from models.model import App, AppModelConfig from models.tools import ApiToolProvider diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 9c5ae9a836..574fc55454 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -1,21 +1,33 @@ import logging +from flask import request +from flask_restful import reqparse +from werkzeug.exceptions import InternalServerError + import services from controllers.service_api import api -from controllers.service_api.app.error import (AppUnavailableError, AudioTooLargeError, CompletionRequestError, - NoAudioUploadedError, ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, ProviderNotSupportSpeechToTextError, - ProviderQuotaExceededError, UnsupportedAudioTypeError) +from controllers.service_api.app.error import ( + AppUnavailableError, + AudioTooLargeError, + CompletionRequestError, + NoAudioUploadedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, + UnsupportedAudioTypeError, +) from controllers.service_api.wraps import AppApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from flask import request -from flask_restful import reqparse from models.model import App, AppModelConfig from services.audio_service import AudioService -from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError, - ProviderNotSupportSpeechToTextServiceError, UnsupportedAudioTypeServiceError) -from werkzeug.exceptions import InternalServerError +from services.errors.audio import ( + AudioTooLargeServiceError, + NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, + UnsupportedAudioTypeServiceError, +) class AudioApi(AppApiResource): diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index cc1ad0888e..d47bb089dc 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -2,22 +2,29 @@ import json import logging from typing import Generator, Union +from flask import Response, stream_with_context +from flask_restful import reqparse +from werkzeug.exceptions import InternalServerError, NotFound + import services from controllers.service_api import api from controllers.service_api.app import create_or_update_end_user_for_user_id -from controllers.service_api.app.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError, - NotChatAppError, ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, ProviderQuotaExceededError) +from controllers.service_api.app.error import ( + AppUnavailableError, + CompletionRequestError, + ConversationCompletedError, + NotChatAppError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) from controllers.service_api.wraps import AppApiResource from core.application_queue_manager import ApplicationQueueManager from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from flask import Response, stream_with_context -from flask_restful import reqparse from libs.helper import uuid_value from services.completion_service import CompletionService -from werkzeug.exceptions import InternalServerError, NotFound class CompletionApi(AppApiResource): diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 604e2f93db..d275552d0b 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,16 +1,17 @@ # -*- coding:utf-8 -*- +from flask import request +from flask_restful import marshal_with, reqparse +from flask_restful.inputs import int_range +from werkzeug.exceptions import NotFound + import services from controllers.service_api import api from controllers.service_api.app import create_or_update_end_user_for_user_id from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import AppApiResource from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields -from flask import request -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range -from libs.helper import TimestampField, uuid_value +from libs.helper import uuid_value from services.conversation_service import ConversationService -from werkzeug.exceptions import NotFound class ConversationApi(AppApiResource): diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index 8e7984ced1..a901375ec0 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -1,12 +1,17 @@ +from flask import request +from flask_restful import marshal_with + import services from controllers.service_api import api from controllers.service_api.app import create_or_update_end_user_for_user_id -from controllers.service_api.app.error import (FileTooLargeError, NoFileUploadedError, TooManyFilesError, - UnsupportedFileTypeError) +from controllers.service_api.app.error import ( + FileTooLargeError, + NoFileUploadedError, + TooManyFilesError, + UnsupportedFileTypeError, +) from controllers.service_api.wraps import AppApiResource from fields.file_fields import file_fields -from flask import request -from flask_restful import marshal_with from services.file_service import FileService diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index c90a1fb1e2..a0257b3ed5 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,4 +1,8 @@ # -*- coding:utf-8 -*- +from flask_restful import fields, marshal_with, reqparse +from flask_restful.inputs import int_range +from werkzeug.exceptions import NotFound + import services from controllers.service_api import api from controllers.service_api.app import create_or_update_end_user_for_user_id @@ -6,12 +10,9 @@ from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import AppApiResource from extensions.ext_database import db from fields.conversation_fields import message_file_fields -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range from libs.helper import TimestampField, uuid_value from models.model import EndUser, Message from services.message_service import MessageService -from werkzeug.exceptions import NotFound class MessageListApi(AppApiResource): diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 900a796674..60c7ca4549 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,3 +1,6 @@ +from flask import request +from flask_restful import marshal, reqparse + import services.dataset_service from controllers.service_api import api from controllers.service_api.dataset.error import DatasetNameDuplicateError @@ -5,8 +8,6 @@ from controllers.service_api.wraps import DatasetApiResource from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager from fields.dataset_fields import dataset_detail_fields -from flask import request -from flask_restful import marshal, reqparse from libs.login import current_user from models.dataset import Dataset from services.dataset_service import DatasetService diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index d7694070f0..c997edc234 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -1,23 +1,28 @@ import json +from flask import request +from flask_login import current_user +from flask_restful import marshal, reqparse +from sqlalchemy import desc +from werkzeug.exceptions import NotFound + import services.dataset_service from controllers.service_api import api from controllers.service_api.app.error import ProviderNotInitializeError -from controllers.service_api.dataset.error import (ArchivedDocumentImmutableError, DocumentIndexingError, - NoFileUploadedError, TooManyFilesError) +from controllers.service_api.dataset.error import ( + ArchivedDocumentImmutableError, + DocumentIndexingError, + NoFileUploadedError, + TooManyFilesError, +) from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check from core.errors.error import ProviderTokenNotInitError from extensions.ext_database import db from fields.document_fields import document_fields, document_status_fields -from flask import request -from flask_login import current_user -from flask_restful import marshal, reqparse from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment from services.dataset_service import DocumentService from services.file_service import FileService -from sqlalchemy import desc -from werkzeug.exceptions import NotFound class DocumentAddByTextApi(DatasetApiResource): diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 4cc313e042..495d2f56a5 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -1,3 +1,7 @@ +from flask_login import current_user +from flask_restful import marshal, reqparse +from werkzeug.exceptions import NotFound + from controllers.service_api import api from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check @@ -6,11 +10,8 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from fields.segment_fields import segment_fields -from flask_login import current_user -from flask_restful import marshal, reqparse from models.dataset import Dataset, DocumentSegment from services.dataset_service import DatasetService, DocumentService, SegmentService -from werkzeug.exceptions import NotFound class SegmentApi(DatasetApiResource): diff --git a/api/controllers/service_api/index.py b/api/controllers/service_api/index.py index 489018cf9b..932388b562 100644 --- a/api/controllers/service_api/index.py +++ b/api/controllers/service_api/index.py @@ -1,7 +1,8 @@ -from controllers.service_api import api from flask import current_app from flask_restful import Resource +from controllers.service_api import api + class IndexApi(Resource): def get(self): diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 60e573ec93..0cc63a2ad3 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -2,15 +2,16 @@ from datetime import datetime from functools import wraps -from extensions.ext_database import db from flask import current_app, request from flask_login import user_logged_in from flask_restful import Resource +from werkzeug.exceptions import NotFound, Unauthorized + +from extensions.ext_database import db from libs.login import _get_user from models.account import Account, Tenant, TenantAccountJoin from models.model import ApiToken, App from services.feature_service import FeatureService -from werkzeug.exceptions import NotFound, Unauthorized def validate_app_token(view=None): diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 82a9ad8683..4dc15b9bee 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,11 +1,12 @@ # -*- coding:utf-8 -*- import json +from flask import current_app +from flask_restful import fields, marshal_with + from controllers.web import api from controllers.web.wraps import WebApiResource from extensions.ext_database import db -from flask import current_app -from flask_restful import fields, marshal_with from models.model import App, AppModelConfig from models.tools import ApiToolProvider diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 44ca7b660a..b3d7280b64 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -1,21 +1,33 @@ # -*- coding:utf-8 -*- import logging +from flask import request +from werkzeug.exceptions import InternalServerError + import services from controllers.web import api -from controllers.web.error import (AppUnavailableError, AudioTooLargeError, CompletionRequestError, - NoAudioUploadedError, ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, ProviderNotSupportSpeechToTextError, - ProviderQuotaExceededError, UnsupportedAudioTypeError) +from controllers.web.error import ( + AppUnavailableError, + AudioTooLargeError, + CompletionRequestError, + NoAudioUploadedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, + UnsupportedAudioTypeError, +) from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from flask import request from models.model import App, AppModelConfig from services.audio_service import AudioService -from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError, - ProviderNotSupportSpeechToTextServiceError, UnsupportedAudioTypeServiceError) -from werkzeug.exceptions import InternalServerError +from services.errors.audio import ( + AudioTooLargeServiceError, + NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, + UnsupportedAudioTypeServiceError, +) class AudioApi(WebApiResource): diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index af571f1ff7..c61995b72c 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -3,21 +3,29 @@ import json import logging from typing import Generator, Union +from flask import Response, stream_with_context +from flask_restful import reqparse +from werkzeug.exceptions import InternalServerError, NotFound + import services from controllers.web import api -from controllers.web.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError, - NotChatAppError, NotCompletionAppError, ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, ProviderQuotaExceededError) +from controllers.web.error import ( + AppUnavailableError, + CompletionRequestError, + ConversationCompletedError, + NotChatAppError, + NotCompletionAppError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) from controllers.web.wraps import WebApiResource from core.application_queue_manager import ApplicationQueueManager from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from flask import Response, stream_with_context -from flask_restful import reqparse from libs.helper import uuid_value from services.completion_service import CompletionService -from werkzeug.exceptions import InternalServerError, NotFound # define completion api for user diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 1f17f7883e..b0d7747d65 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,15 +1,16 @@ # -*- coding:utf-8 -*- +from flask_restful import marshal_with, reqparse +from flask_restful.inputs import int_range +from werkzeug.exceptions import NotFound + from controllers.web import api from controllers.web.error import NotChatAppError from controllers.web.wraps import WebApiResource from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range -from libs.helper import TimestampField, uuid_value +from libs.helper import uuid_value from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError from services.web_conversation_service import WebConversationService -from werkzeug.exceptions import NotFound class ConversationListApi(WebApiResource): diff --git a/api/controllers/web/file.py b/api/controllers/web/file.py index c43fe6fdf5..ca83f6037a 100644 --- a/api/controllers/web/file.py +++ b/api/controllers/web/file.py @@ -1,10 +1,11 @@ +from flask import request +from flask_restful import marshal_with + import services from controllers.web import api from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError from controllers.web.wraps import WebApiResource from fields.file_fields import file_fields -from flask import request -from flask_restful import marshal_with from services.file_service import FileService diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 2712e84691..1a084fe539 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -3,28 +3,35 @@ import json import logging from typing import Generator, Union +from flask import Response, stream_with_context +from flask_restful import fields, marshal_with, reqparse +from flask_restful.inputs import int_range +from werkzeug.exceptions import InternalServerError, NotFound + import services from controllers.web import api -from controllers.web.error import (AppMoreLikeThisDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError, - CompletionRequestError, NotChatAppError, NotCompletionAppError, - ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, - ProviderQuotaExceededError) +from controllers.web.error import ( + AppMoreLikeThisDisabledError, + AppSuggestedQuestionsAfterAnswerDisabledError, + CompletionRequestError, + NotChatAppError, + NotCompletionAppError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) from controllers.web.wraps import WebApiResource from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import message_file_fields from fields.message_fields import agent_thought_fields -from flask import Response, stream_with_context -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range from libs.helper import TimestampField, uuid_value from services.completion_service import CompletionService from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.message_service import MessageService -from werkzeug.exceptions import InternalServerError, NotFound class MessageListApi(WebApiResource): diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index bc6cf6028b..188cc41254 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -1,13 +1,14 @@ # -*- coding:utf-8 -*- import uuid -from controllers.web import api -from extensions.ext_database import db from flask import request from flask_restful import Resource +from werkzeug.exceptions import NotFound, Unauthorized + +from controllers.web import api +from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site -from werkzeug.exceptions import NotFound, Unauthorized class PassportResource(Resource): diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index b353b9682e..e17869ffdb 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -1,13 +1,14 @@ +from flask_restful import fields, marshal_with, reqparse +from flask_restful.inputs import int_range +from werkzeug.exceptions import NotFound + from controllers.web import api from controllers.web.error import NotCompletionAppError from controllers.web.wraps import WebApiResource from fields.conversation_fields import message_file_fields -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range from libs.helper import TimestampField, uuid_value from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService -from werkzeug.exceptions import NotFound feedback_fields = { 'rating': fields.String diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 9f1297a06c..8ce3a81083 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,14 +1,14 @@ # -*- coding:utf-8 -*- -import os + +from flask import current_app +from flask_restful import fields, marshal_with +from werkzeug.exceptions import Forbidden from controllers.web import api from controllers.web.wraps import WebApiResource from extensions.ext_database import db -from flask import current_app -from flask_restful import fields, marshal_with from models.model import Site from services.feature_service import FeatureService -from werkzeug.exceptions import Forbidden class AppSiteApi(WebApiResource): diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 0803a3b5ea..ebf6611784 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -1,12 +1,13 @@ # -*- coding:utf-8 -*- from functools import wraps -from extensions.ext_database import db from flask import request from flask_restful import Resource +from werkzeug.exceptions import NotFound, Unauthorized + +from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site -from werkzeug.exceptions import NotFound, Unauthorized def validate_jwt_token(view=None): diff --git a/api/core/agent/agent/calc_token_mixin.py b/api/core/agent/agent/calc_token_mixin.py index 1ca6c49812..b25ab2d88a 100644 --- a/api/core/agent/agent/calc_token_mixin.py +++ b/api/core/agent/agent/calc_token_mixin.py @@ -1,11 +1,9 @@ from typing import List, cast from core.entities.application_entities import ModelConfigEntity -from core.entities.message_entities import lc_messages_to_prompt_messages from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from langchain.schema import BaseMessage class CalcTokenMixin: diff --git a/api/core/agent/agent/multi_dataset_router_agent.py b/api/core/agent/agent/multi_dataset_router_agent.py index c13641b84d..201421910d 100644 --- a/api/core/agent/agent/multi_dataset_router_agent.py +++ b/api/core/agent/agent/multi_dataset_router_agent.py @@ -1,10 +1,5 @@ -from typing import Any, List, Optional, Sequence, Tuple, Union, cast +from typing import Any, List, Optional, Sequence, Tuple, Union -from core.entities.application_entities import ModelConfigEntity -from core.entities.message_entities import lc_messages_to_prompt_messages -from core.model_manager import ModelInstance -from core.model_runtime.entities.message_entities import PromptMessageTool -from core.third_party.langchain.llms.fake import FakeLLM from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message from langchain.callbacks.base import BaseCallbackManager @@ -14,6 +9,12 @@ from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage from langchain.tools import BaseTool from pydantic import root_validator +from core.entities.application_entities import ModelConfigEntity +from core.entities.message_entities import lc_messages_to_prompt_messages +from core.model_manager import ModelInstance +from core.model_runtime.entities.message_entities import PromptMessageTool +from core.third_party.langchain.llms.fake import FakeLLM + class MultiDatasetRouterAgent(OpenAIFunctionsAgent): """ diff --git a/api/core/agent/agent/openai_function_call.py b/api/core/agent/agent/openai_function_call.py index e17282a293..3dafa4517b 100644 --- a/api/core/agent/agent/openai_function_call.py +++ b/api/core/agent/agent/openai_function_call.py @@ -1,4 +1,23 @@ -from typing import Any, List, Optional, Sequence, Tuple, Union, cast +from typing import Any, List, Optional, Sequence, Tuple, Union + +from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent +from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message +from langchain.callbacks.base import BaseCallbackManager +from langchain.callbacks.manager import Callbacks +from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken +from langchain.memory.prompt import SUMMARY_PROMPT +from langchain.prompts.chat import BaseMessagePromptTemplate +from langchain.schema import ( + AgentAction, + AgentFinish, + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, + get_buffer_string, +) +from langchain.tools import BaseTool +from pydantic import root_validator from core.agent.agent.agent_llm_callback import AgentLLMCallback from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError @@ -7,19 +26,7 @@ from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import lc_messages_to_prompt_messages from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.third_party.langchain.llms.fake import FakeLLM -from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent -from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message -from langchain.callbacks.base import BaseCallbackManager -from langchain.callbacks.manager import Callbacks -from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken -from langchain.memory.prompt import SUMMARY_PROMPT -from langchain.prompts.chat import BaseMessagePromptTemplate -from langchain.schema import (AgentAction, AgentFinish, AIMessage, BaseMessage, HumanMessage, SystemMessage, - get_buffer_string) -from langchain.tools import BaseTool -from pydantic import root_validator class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin): diff --git a/api/core/agent/agent/structed_multi_dataset_router_agent.py b/api/core/agent/agent/structed_multi_dataset_router_agent.py index c8e6a84b09..9d36e01d7c 100644 --- a/api/core/agent/agent/structed_multi_dataset_router_agent.py +++ b/api/core/agent/agent/structed_multi_dataset_router_agent.py @@ -1,8 +1,6 @@ import re from typing import Any, List, Optional, Sequence, Tuple, Union, cast -from core.chain.llm_chain import LLMChain -from core.entities.application_entities import ModelConfigEntity from langchain import BasePromptTemplate, PromptTemplate from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE @@ -13,6 +11,9 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy from langchain.schema import AgentAction, AgentFinish, OutputParserException from langchain.tools import BaseTool +from core.chain.llm_chain import LLMChain +from core.entities.application_entities import ModelConfigEntity + FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. Valid "action" values: "Final Answer" or {tool_names} diff --git a/api/core/agent/agent/structured_chat.py b/api/core/agent/agent/structured_chat.py index af0130b314..03fea8c27d 100644 --- a/api/core/agent/agent/structured_chat.py +++ b/api/core/agent/agent/structured_chat.py @@ -1,11 +1,6 @@ import re from typing import Any, List, Optional, Sequence, Tuple, Union, cast -from core.agent.agent.agent_llm_callback import AgentLLMCallback -from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError -from core.chain.llm_chain import LLMChain -from core.entities.application_entities import ModelConfigEntity -from core.entities.message_entities import lc_messages_to_prompt_messages from langchain import BasePromptTemplate, PromptTemplate from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE @@ -14,10 +9,23 @@ from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import Callbacks from langchain.memory.prompt import SUMMARY_PROMPT from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate -from langchain.schema import (AgentAction, AgentFinish, AIMessage, BaseMessage, HumanMessage, OutputParserException, - get_buffer_string) +from langchain.schema import ( + AgentAction, + AgentFinish, + AIMessage, + BaseMessage, + HumanMessage, + OutputParserException, + get_buffer_string, +) from langchain.tools import BaseTool +from core.agent.agent.agent_llm_callback import AgentLLMCallback +from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError +from core.chain.llm_chain import LLMChain +from core.entities.application_entities import ModelConfigEntity +from core.entities.message_entities import lc_messages_to_prompt_messages + FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. Valid "action" values: "Final Answer" or {tool_names} diff --git a/api/core/agent/agent_executor.py b/api/core/agent/agent_executor.py index 2565fb2315..70fe00ee13 100644 --- a/api/core/agent/agent_executor.py +++ b/api/core/agent/agent_executor.py @@ -2,6 +2,12 @@ import enum import logging from typing import Optional, Union +from langchain.agents import AgentExecutor as LCAgentExecutor +from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent +from langchain.callbacks.manager import Callbacks +from langchain.tools import BaseTool +from pydantic import BaseModel, Extra + from core.agent.agent.agent_llm_callback import AgentLLMCallback from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent @@ -15,11 +21,6 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.errors.invoke import InvokeError from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool -from langchain.agents import AgentExecutor as LCAgentExecutor -from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent -from langchain.callbacks.manager import Callbacks -from langchain.tools import BaseTool -from pydantic import BaseModel, Extra class PlanningStrategy(str, enum.Enum): diff --git a/api/core/app_runner/app_runner.py b/api/core/app_runner/app_runner.py index d751e301cc..457cae8289 100644 --- a/api/core/app_runner/app_runner.py +++ b/api/core/app_runner/app_runner.py @@ -2,9 +2,14 @@ import time from typing import Generator, List, Optional, Tuple, Union, cast from core.application_queue_manager import ApplicationQueueManager, PublishFrom -from core.entities.application_entities import (ApplicationGenerateEntity, AppOrchestrationConfigEntity, - ExternalDataVariableEntity, InvokeFrom, ModelConfigEntity, - PromptTemplateEntity) +from core.entities.application_entities import ( + ApplicationGenerateEntity, + AppOrchestrationConfigEntity, + ExternalDataVariableEntity, + InvokeFrom, + ModelConfigEntity, + PromptTemplateEntity, +) from core.features.annotation_reply import AnnotationReplyFeature from core.features.external_data_fetch import ExternalDataFetchFeature from core.features.hosting_moderation import HostingModerationFeature diff --git a/api/core/app_runner/generate_task_pipeline.py b/api/core/app_runner/generate_task_pipeline.py index ea54e709e1..39f51ee1b6 100644 --- a/api/core/app_runner/generate_task_pipeline.py +++ b/api/core/app_runner/generate_task_pipeline.py @@ -3,28 +3,42 @@ import logging import time from typing import Generator, Optional, Union, cast +from pydantic import BaseModel + from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom -from core.entities.queue_entities import (AnnotationReplyEvent, QueueAgentMessageEvent, QueueAgentThoughtEvent, - QueueErrorEvent, QueueMessageEndEvent, QueueMessageEvent, - QueueMessageFileEvent, QueueMessageReplaceEvent, QueuePingEvent, - QueueRetrieverResourcesEvent, QueueStopEvent) +from core.entities.queue_entities import ( + AnnotationReplyEvent, + QueueAgentMessageEvent, + QueueAgentThoughtEvent, + QueueErrorEvent, + QueueMessageEndEvent, + QueueMessageEvent, + QueueMessageFileEvent, + QueueMessageReplaceEvent, + QueuePingEvent, + QueueRetrieverResourcesEvent, + QueueStopEvent, +) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, - PromptMessage, PromptMessageContentType, PromptMessageRole, - TextPromptMessageContent) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + PromptMessageRole, + TextPromptMessageContent, +) from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.prompt_template import PromptTemplateParser from core.tools.tool_file_manager import ToolFileManager -from core.tools.tool_manager import ToolManager from events.message_event import message_was_created from extensions.ext_database import db from models.model import Conversation, Message, MessageAgentThought, MessageFile -from pydantic import BaseModel from services.annotation_service import AppAnnotationService logger = logging.getLogger(__name__) diff --git a/api/core/app_runner/moderation_handler.py b/api/core/app_runner/moderation_handler.py index 24ea085612..392425ed8e 100644 --- a/api/core/app_runner/moderation_handler.py +++ b/api/core/app_runner/moderation_handler.py @@ -3,11 +3,12 @@ import threading import time from typing import Any, Dict, Optional +from flask import Flask, current_app +from pydantic import BaseModel + from core.application_queue_manager import PublishFrom from core.moderation.base import ModerationAction, ModerationOutputsResult from core.moderation.factory import ModerationFactory -from flask import Flask, current_app -from pydantic import BaseModel logger = logging.getLogger(__name__) diff --git a/api/core/application_manager.py b/api/core/application_manager.py index 7f07bed3a5..b718cefab6 100644 --- a/api/core/application_manager.py +++ b/api/core/application_manager.py @@ -4,17 +4,30 @@ import threading import uuid from typing import Any, Generator, Optional, Tuple, Union, cast +from flask import Flask, current_app +from pydantic import ValidationError + from core.app_runner.assistant_app_runner import AssistantApplicationRunner from core.app_runner.basic_app_runner import BasicApplicationRunner from core.app_runner.generate_task_pipeline import GenerateTaskPipeline from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom -from core.entities.application_entities import (AdvancedChatPromptTemplateEntity, - AdvancedCompletionPromptTemplateEntity, AgentEntity, AgentPromptEntity, - AgentToolEntity, ApplicationGenerateEntity, - AppOrchestrationConfigEntity, DatasetEntity, - DatasetRetrieveConfigEntity, ExternalDataVariableEntity, - FileUploadEntity, InvokeFrom, ModelConfigEntity, PromptTemplateEntity, - SensitiveWordAvoidanceEntity) +from core.entities.application_entities import ( + AdvancedChatPromptTemplateEntity, + AdvancedCompletionPromptTemplateEntity, + AgentEntity, + AgentPromptEntity, + AgentToolEntity, + ApplicationGenerateEntity, + AppOrchestrationConfigEntity, + DatasetEntity, + DatasetRetrieveConfigEntity, + ExternalDataVariableEntity, + FileUploadEntity, + InvokeFrom, + ModelConfigEntity, + PromptTemplateEntity, + SensitiveWordAvoidanceEntity, +) from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.file.file_obj import FileObj @@ -26,10 +39,8 @@ from core.prompt.prompt_template import PromptTemplateParser from core.provider_manager import ProviderManager from core.tools.prompt.template import REACT_PROMPT_TEMPLATES from extensions.ext_database import db -from flask import Flask, current_app from models.account import Account from models.model import App, Conversation, EndUser, Message, MessageFile -from pydantic import ValidationError logger = logging.getLogger(__name__) diff --git a/api/core/application_queue_manager.py b/api/core/application_queue_manager.py index 65b52fd1f3..75a56d6706 100644 --- a/api/core/application_queue_manager.py +++ b/api/core/application_queue_manager.py @@ -3,15 +3,27 @@ import time from enum import Enum from typing import Any, Generator +from sqlalchemy.orm import DeclarativeMeta + from core.entities.application_entities import InvokeFrom -from core.entities.queue_entities import (AnnotationReplyEvent, AppQueueEvent, QueueAgentMessageEvent, - QueueAgentThoughtEvent, QueueErrorEvent, QueueMessage, QueueMessageEndEvent, - QueueMessageEvent, QueueMessageFileEvent, QueueMessageReplaceEvent, - QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent) +from core.entities.queue_entities import ( + AnnotationReplyEvent, + AppQueueEvent, + QueueAgentMessageEvent, + QueueAgentThoughtEvent, + QueueErrorEvent, + QueueMessage, + QueueMessageEndEvent, + QueueMessageEvent, + QueueMessageFileEvent, + QueueMessageReplaceEvent, + QueuePingEvent, + QueueRetrieverResourcesEvent, + QueueStopEvent, +) from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from extensions.ext_redis import redis_client from models.model import MessageAgentThought, MessageFile -from sqlalchemy.orm import DeclarativeMeta class PublishFrom(Enum): diff --git a/api/core/callback_handler/agent_loop_gather_callback_handler.py b/api/core/callback_handler/agent_loop_gather_callback_handler.py index edee77e25f..f9347198dc 100644 --- a/api/core/callback_handler/agent_loop_gather_callback_handler.py +++ b/api/core/callback_handler/agent_loop_gather_callback_handler.py @@ -3,6 +3,10 @@ import logging import time from typing import Any, Dict, List, Optional, Union, cast +from langchain.agents import openai_functions_agent, openai_functions_multi_agent +from langchain.callbacks.base import BaseCallbackHandler +from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult + from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.callback_handler.entity.agent_loop import AgentLoop from core.entities.application_entities import ModelConfigEntity @@ -10,9 +14,6 @@ from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResu from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from extensions.ext_database import db -from langchain.agents import openai_functions_agent, openai_functions_multi_agent -from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish, BaseMessage, ChatGeneration, LLMResult from models.model import Message, MessageAgentThought, MessageChain diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 9947028806..63c9bbe416 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,9 +1,10 @@ -from typing import List, Union +from typing import List + +from langchain.schema import Document from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.entities.application_entities import InvokeFrom from extensions.ext_database import db -from langchain.schema import Document from models.dataset import DatasetQuery, DocumentSegment from models.model import DatasetRetrieverResource diff --git a/api/core/chain/llm_chain.py b/api/core/chain/llm_chain.py index 20b71f2f64..a5d160c99e 100644 --- a/api/core/chain/llm_chain.py +++ b/api/core/chain/llm_chain.py @@ -1,14 +1,15 @@ from typing import Any, Dict, List, Optional +from langchain import LLMChain as LCLLMChain +from langchain.callbacks.manager import CallbackManagerForChainRun +from langchain.schema import Generation, LLMResult +from langchain.schema.language_model import BaseLanguageModel + from core.agent.agent.agent_llm_callback import AgentLLMCallback from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import lc_messages_to_prompt_messages from core.model_manager import ModelInstance from core.third_party.langchain.llms.fake import FakeLLM -from langchain import LLMChain as LCLLMChain -from langchain.callbacks.manager import CallbackManagerForChainRun -from langchain.schema import Generation, LLMResult -from langchain.schema.language_model import BaseLanguageModel class LLMChain(LCLLMChain): diff --git a/api/core/data_loader/file_extractor.py b/api/core/data_loader/file_extractor.py index 14a9693623..af0fb1d35a 100644 --- a/api/core/data_loader/file_extractor.py +++ b/api/core/data_loader/file_extractor.py @@ -3,6 +3,10 @@ from pathlib import Path from typing import List, Optional, Union import requests +from flask import current_app +from langchain.document_loaders import Docx2txtLoader, TextLoader +from langchain.schema import Document + from core.data_loader.loader.csv_loader import CSVLoader from core.data_loader.loader.excel import ExcelLoader from core.data_loader.loader.html import HTMLLoader @@ -16,9 +20,6 @@ from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredP from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader from extensions.ext_storage import storage -from flask import current_app -from langchain.document_loaders import Docx2txtLoader, TextLoader -from langchain.schema import Document from models.model import UploadFile SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain'] diff --git a/api/core/data_loader/loader/excel.py b/api/core/data_loader/loader/excel.py index 5e76c21a8f..f5f6b2d69c 100644 --- a/api/core/data_loader/loader/excel.py +++ b/api/core/data_loader/loader/excel.py @@ -1,4 +1,3 @@ -import json import logging from typing import List diff --git a/api/core/data_loader/loader/notion.py b/api/core/data_loader/loader/notion.py index 914c04d5c0..9f9198c3ce 100644 --- a/api/core/data_loader/loader/notion.py +++ b/api/core/data_loader/loader/notion.py @@ -3,10 +3,11 @@ import logging from typing import Any, Dict, List, Optional import requests -from extensions.ext_database import db from flask import current_app from langchain.document_loaders.base import BaseLoader from langchain.schema import Document + +from extensions.ext_database import db from models.dataset import Document as DocumentModel from models.source import DataSourceBinding diff --git a/api/core/data_loader/loader/pdf.py b/api/core/data_loader/loader/pdf.py index 8b08393d91..881d0026b5 100644 --- a/api/core/data_loader/loader/pdf.py +++ b/api/core/data_loader/loader/pdf.py @@ -1,10 +1,11 @@ import logging from typing import List, Optional -from extensions.ext_storage import storage from langchain.document_loaders import PyPDFium2Loader from langchain.document_loaders.base import BaseLoader from langchain.schema import Document + +from extensions.ext_storage import storage from models.model import UploadFile logger = logging.getLogger(__name__) diff --git a/api/core/docstore/dataset_docstore.py b/api/core/docstore/dataset_docstore.py index 49e87ec340..77a5dde9ed 100644 --- a/api/core/docstore/dataset_docstore.py +++ b/api/core/docstore/dataset_docstore.py @@ -1,12 +1,13 @@ from typing import Any, Dict, Optional, Sequence, cast +from langchain.schema import Document +from sqlalchemy import func + from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from extensions.ext_database import db -from langchain.schema import Document from models.dataset import Dataset, DocumentSegment -from sqlalchemy import func class DatasetDocumentStore: diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index 185b87b8b6..4f7b3a1530 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -1,18 +1,17 @@ import base64 -import json import logging from typing import List, Optional, cast import numpy as np +from langchain.embeddings.base import Embeddings +from sqlalchemy.exc import IntegrityError + from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from extensions.ext_database import db from extensions.ext_redis import redis_client -from langchain.embeddings.base import Embeddings from libs import helper -from models.dataset import Embedding -from sqlalchemy.exc import IntegrityError logger = logging.getLogger(__name__) diff --git a/api/core/entities/application_entities.py b/api/core/entities/application_entities.py index 6883a004e4..d26998ce80 100644 --- a/api/core/entities/application_entities.py +++ b/api/core/entities/application_entities.py @@ -1,11 +1,12 @@ from enum import Enum -from typing import Any, Literal, Optional, Union, cast +from typing import Any, Literal, Optional, Union + +from pydantic import BaseModel from core.entities.provider_configuration import ProviderModelBundle from core.file.file_obj import FileObj from core.model_runtime.entities.message_entities import PromptMessageRole from core.model_runtime.entities.model_entities import AIModelEntity -from pydantic import BaseModel class ModelConfigEntity(BaseModel): diff --git a/api/core/entities/message_entities.py b/api/core/entities/message_entities.py index 9b0b287f28..51b9582a91 100644 --- a/api/core/entities/message_entities.py +++ b/api/core/entities/message_entities.py @@ -1,12 +1,19 @@ import enum from typing import Any, cast -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, - PromptMessage, SystemPromptMessage, TextPromptMessageContent, - ToolPromptMessage, UserPromptMessage) from langchain.schema import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage from pydantic import BaseModel +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, +) + class PromptMessageFileType(enum.Enum): IMAGE = 'image' diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 3888807227..05719e5b8d 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,10 +1,11 @@ from enum import Enum from typing import Optional +from pydantic import BaseModel + from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ModelType, ProviderModel -from core.model_runtime.entities.provider_entities import ProviderEntity, SimpleProviderEntity -from pydantic import BaseModel +from core.model_runtime.entities.provider_entities import ProviderEntity class ModelStatus(Enum): diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index a7a365fe69..fd61647635 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -4,20 +4,24 @@ import logging from json import JSONDecodeError from typing import Dict, Iterator, List, Optional, Tuple +from pydantic import BaseModel + from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity from core.entities.provider_entities import CustomConfiguration, SystemConfiguration, SystemConfigurationStatus from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.model_runtime.entities.model_entities import FetchFrom, ModelType -from core.model_runtime.entities.provider_entities import (ConfigurateMethod, CredentialFormSchema, FormType, - ProviderEntity) +from core.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.__base.model_provider import ModelProvider -from core.model_runtime.utils import encoders from extensions.ext_database import db from models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider -from pydantic import BaseModel logger = logging.getLogger(__name__) diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index ab6fea0a2f..114dfaf911 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -1,9 +1,10 @@ from enum import Enum from typing import Optional +from pydantic import BaseModel + from core.model_runtime.entities.model_entities import ModelType from models.provider import ProviderQuotaType -from pydantic import BaseModel class QuotaUnit(Enum): diff --git a/api/core/entities/queue_entities.py b/api/core/entities/queue_entities.py index d6ef28b138..c1f8fb7e89 100644 --- a/api/core/entities/queue_entities.py +++ b/api/core/entities/queue_entities.py @@ -1,9 +1,10 @@ from enum import Enum from typing import Any -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from pydantic import BaseModel +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk + class QueueEvent(Enum): """ diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index 97e2f4bdf0..40e60687b2 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -1,6 +1,7 @@ import os import requests + from models.api_based_extension import APIBasedExtensionPoint diff --git a/api/core/features/agent_runner.py b/api/core/features/agent_runner.py index 66d41dace0..7412d81281 100644 --- a/api/core/features/agent_runner.py +++ b/api/core/features/agent_runner.py @@ -1,5 +1,7 @@ import logging -from typing import List, Optional, cast +from typing import Optional, cast + +from langchain.tools import BaseTool from core.agent.agent.agent_llm_callback import AgentLLMCallback from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy @@ -7,20 +9,20 @@ from core.application_queue_manager import ApplicationQueueManager from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler -from core.entities.application_entities import (AgentEntity, AgentToolEntity, AppOrchestrationConfigEntity, InvokeFrom, - ModelConfigEntity) +from core.entities.application_entities import ( + AgentEntity, + AppOrchestrationConfigEntity, + InvokeFrom, + ModelConfigEntity, +) from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db -from langchain import WikipediaAPIWrapper -from langchain.callbacks.base import BaseCallbackHandler -from langchain.tools import BaseTool, Tool, WikipediaQueryRun from models.dataset import Dataset from models.model import Message -from pydantic import BaseModel, Field logger = logging.getLogger(__name__) diff --git a/api/core/features/annotation_reply.py b/api/core/features/annotation_reply.py index 09945aaf6e..bdc5467e62 100644 --- a/api/core/features/annotation_reply.py +++ b/api/core/features/annotation_reply.py @@ -1,13 +1,14 @@ import logging from typing import Optional +from flask import current_app + from core.embedding.cached_embedding import CacheEmbedding from core.entities.application_entities import InvokeFrom from core.index.vector_index.vector_index import VectorIndex from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db -from flask import current_app from models.dataset import Dataset from models.model import App, AppAnnotationSetting, Message, MessageAnnotation from services.annotation_service import AppAnnotationService diff --git a/api/core/features/assistant_base_runner.py b/api/core/features/assistant_base_runner.py index 5538918234..adc8f3b663 100644 --- a/api/core/features/assistant_base_runner.py +++ b/api/core/features/assistant_base_runner.py @@ -8,8 +8,14 @@ from core.app_runner.app_runner import AppRunner from core.application_queue_manager import ApplicationQueueManager from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.entities.application_entities import (AgentEntity, AgentToolEntity, ApplicationGenerateEntity, - AppOrchestrationConfigEntity, InvokeFrom, ModelConfigEntity) +from core.entities.application_entities import ( + AgentEntity, + AgentToolEntity, + ApplicationGenerateEntity, + AppOrchestrationConfigEntity, + InvokeFrom, + ModelConfigEntity, +) from core.file.message_file_parser import FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance @@ -18,8 +24,12 @@ from core.model_runtime.entities.message_entities import PromptMessage, PromptMe from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder -from core.tools.entities.tool_entities import (ToolInvokeMessage, ToolInvokeMessageBinary, ToolParameter, - ToolRuntimeVariablePool) +from core.tools.entities.tool_entities import ( + ToolInvokeMessage, + ToolInvokeMessageBinary, + ToolParameter, + ToolRuntimeVariablePool, +) from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool from core.tools.tool.tool import Tool from core.tools.tool_file_manager import ToolFileManager diff --git a/api/core/features/assistant_cot_runner.py b/api/core/features/assistant_cot_runner.py index 0d64920403..9d35832316 100644 --- a/api/core/features/assistant_cot_runner.py +++ b/api/core/features/assistant_cot_runner.py @@ -1,18 +1,27 @@ import json -import logging import re from typing import Dict, Generator, List, Literal, Union from core.application_queue_manager import PublishFrom from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit from core.features.assistant_base_runner import BaseAssistantApplicationRunner -from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) from core.model_runtime.utils.encoders import jsonable_encoder -from core.tools.errors import (ToolInvokeError, ToolNotFoundError, ToolNotSupportedError, ToolParameterValidationError, - ToolProviderCredentialValidationError, ToolProviderNotFoundError) +from core.tools.errors import ( + ToolInvokeError, + ToolNotFoundError, + ToolNotSupportedError, + ToolParameterValidationError, + ToolProviderCredentialValidationError, + ToolProviderNotFoundError, +) from models.model import Conversation, Message diff --git a/api/core/features/assistant_fc_runner.py b/api/core/features/assistant_fc_runner.py index 8b42244838..f0a55aa80b 100644 --- a/api/core/features/assistant_fc_runner.py +++ b/api/core/features/assistant_fc_runner.py @@ -4,12 +4,23 @@ from typing import Any, Dict, Generator, List, Tuple, Union from core.application_queue_manager import PublishFrom from core.features.assistant_base_runner import BaseAssistantApplicationRunner -from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, ToolPromptMessage, UserPromptMessage) -from core.tools.errors import (ToolInvokeError, ToolNotFoundError, ToolNotSupportedError, ToolParameterValidationError, - ToolProviderCredentialValidationError, ToolProviderNotFoundError) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.tools.errors import ( + ToolInvokeError, + ToolNotFoundError, + ToolNotSupportedError, + ToolParameterValidationError, + ToolProviderCredentialValidationError, + ToolProviderNotFoundError, +) from models.model import Conversation, Message, MessageAgentThought logger = logging.getLogger(__name__) diff --git a/api/core/features/dataset_retrieval.py b/api/core/features/dataset_retrieval.py index f8fcea7c10..159428aad4 100644 --- a/api/core/features/dataset_retrieval.py +++ b/api/core/features/dataset_retrieval.py @@ -1,5 +1,7 @@ from typing import List, Optional, cast +from langchain.tools import BaseTool + from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity @@ -9,7 +11,6 @@ from core.model_runtime.model_providers.__base.large_language_model import Large from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db -from langchain.tools import BaseTool from models.dataset import Dataset diff --git a/api/core/features/external_data_fetch.py b/api/core/features/external_data_fetch.py index 791fbf6ae3..33154d8389 100644 --- a/api/core/features/external_data_fetch.py +++ b/api/core/features/external_data_fetch.py @@ -4,9 +4,10 @@ import logging from concurrent.futures import ThreadPoolExecutor from typing import Optional, Tuple +from flask import Flask, current_app + from core.entities.application_entities import ExternalDataVariableEntity from core.external_data_tool.factory import ExternalDataToolFactory -from flask import Flask, current_app logger = logging.getLogger(__name__) diff --git a/api/core/file/file_obj.py b/api/core/file/file_obj.py index 626dbbca43..435074f743 100644 --- a/api/core/file/file_obj.py +++ b/api/core/file/file_obj.py @@ -1,11 +1,12 @@ import enum from typing import Optional +from pydantic import BaseModel + from core.file.upload_file_parser import UploadFileParser from core.model_runtime.entities.message_entities import ImagePromptMessageContent from extensions.ext_database import db from models.model import UploadFile -from pydantic import BaseModel class FileType(enum.Enum): diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py index c92f9e6950..ce783d8fbb 100644 --- a/api/core/file/message_file_parser.py +++ b/api/core/file/message_file_parser.py @@ -1,6 +1,7 @@ from typing import Dict, List, Optional, Union import requests + from core.file.file_obj import FileBelongsTo, FileObj, FileTransferMethod, FileType from extensions.ext_database import db from models.account import Account diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py index ca63301a59..b259a911d8 100644 --- a/api/core/file/upload_file_parser.py +++ b/api/core/file/upload_file_parser.py @@ -6,9 +6,10 @@ import os import time from typing import Optional -from extensions.ext_storage import storage from flask import current_app +from extensions.ext_storage import storage + IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) diff --git a/api/core/generator/llm_generator.py b/api/core/generator/llm_generator.py index 2a15575360..072b02dc94 100644 --- a/api/core/generator/llm_generator.py +++ b/api/core/generator/llm_generator.py @@ -1,6 +1,8 @@ import json import logging +from langchain.schema import OutputParserException + from core.model_manager import ModelManager from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import ModelType @@ -9,7 +11,6 @@ from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorO from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser from core.prompt.prompt_template import PromptTemplateParser from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT -from langchain.schema import OutputParserException class LLMGenerator: diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 9501f2ce5b..0bfe763fac 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -2,11 +2,16 @@ Proxy requests to avoid SSRF """ -from httpx import get as _get, post as _post, put as _put, patch as _patch, head as _head, options as _options -from requests import delete as _delete - import os +from httpx import get as _get +from httpx import head as _head +from httpx import options as _options +from httpx import patch as _patch +from httpx import post as _post +from httpx import put as _put +from requests import delete as _delete + SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '') SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '') diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index ea93c5336c..58b551f295 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -1,10 +1,11 @@ from typing import Optional +from flask import Config, Flask +from pydantic import BaseModel + from core.entities.provider_entities import QuotaUnit, RestrictModel from core.model_runtime.entities.model_entities import ModelType -from flask import Config, Flask from models.provider import ProviderQuotaType -from pydantic import BaseModel class HostingQuota(BaseModel): diff --git a/api/core/index/base.py b/api/core/index/base.py index 33178ff83b..1dc7cfdcc6 100644 --- a/api/core/index/base.py +++ b/api/core/index/base.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from typing import Any, List from langchain.schema import BaseRetriever, Document + from models.dataset import Dataset diff --git a/api/core/index/index.py b/api/core/index/index.py index 56ce3c99c6..42971c895e 100644 --- a/api/core/index/index.py +++ b/api/core/index/index.py @@ -1,10 +1,11 @@ +from flask import current_app +from langchain.embeddings import OpenAIEmbeddings + from core.embedding.cached_embedding import CacheEmbedding from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex from core.index.vector_index.vector_index import VectorIndex from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from flask import current_app -from langchain.embeddings import OpenAIEmbeddings from models.dataset import Dataset diff --git a/api/core/index/keyword_table_index/jieba_keyword_table_handler.py b/api/core/index/keyword_table_index/jieba_keyword_table_handler.py index fc07402206..db9fd027a0 100644 --- a/api/core/index/keyword_table_index/jieba_keyword_table_handler.py +++ b/api/core/index/keyword_table_index/jieba_keyword_table_handler.py @@ -2,9 +2,10 @@ import re from typing import Set import jieba -from core.index.keyword_table_index.stopwords import STOPWORDS from jieba.analyse import default_tfidf +from core.index.keyword_table_index.stopwords import STOPWORDS + class JiebaKeywordTableHandler: diff --git a/api/core/index/keyword_table_index/keyword_table_index.py b/api/core/index/keyword_table_index/keyword_table_index.py index 06eef1ebf2..9ad8b8d64e 100644 --- a/api/core/index/keyword_table_index/keyword_table_index.py +++ b/api/core/index/keyword_table_index/keyword_table_index.py @@ -2,12 +2,13 @@ import json from collections import defaultdict from typing import Any, Dict, List, Optional +from langchain.schema import BaseRetriever, Document +from pydantic import BaseModel, Extra, Field + from core.index.base import BaseIndex from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler from extensions.ext_database import db -from langchain.schema import BaseRetriever, Document from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment -from pydantic import BaseModel, Extra, Field class KeywordTableConfig(BaseModel): diff --git a/api/core/index/vector_index/base.py b/api/core/index/vector_index/base.py index ccc1833821..b9b8e6d3dc 100644 --- a/api/core/index/vector_index/base.py +++ b/api/core/index/vector_index/base.py @@ -3,14 +3,14 @@ import logging from abc import abstractmethod from typing import Any, List, cast -from core.index.base import BaseIndex -from extensions.ext_database import db from langchain.embeddings.base import Embeddings from langchain.schema import BaseRetriever, Document from langchain.vectorstores import VectorStore -from models.dataset import Dataset, DatasetCollectionBinding + +from core.index.base import BaseIndex +from extensions.ext_database import db +from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment from models.dataset import Document as DatasetDocument -from models.dataset import DocumentSegment class BaseVectorIndex(BaseIndex): diff --git a/api/core/index/vector_index/milvus_vector_index.py b/api/core/index/vector_index/milvus_vector_index.py index 67ba5a7b32..a0b6f5d207 100644 --- a/api/core/index/vector_index/milvus_vector_index.py +++ b/api/core/index/vector_index/milvus_vector_index.py @@ -1,13 +1,14 @@ from typing import Any, List, cast +from langchain.embeddings.base import Embeddings +from langchain.schema import Document +from langchain.vectorstores import VectorStore +from pydantic import BaseModel, root_validator + from core.index.base import BaseIndex from core.index.vector_index.base import BaseVectorIndex from core.vector_store.milvus_vector_store import MilvusVectorStore -from langchain.embeddings.base import Embeddings -from langchain.schema import Document -from langchain.vectorstores import VectorStore from models.dataset import Dataset -from pydantic import BaseModel, root_validator class MilvusConfig(BaseModel): diff --git a/api/core/index/vector_index/qdrant_vector_index.py b/api/core/index/vector_index/qdrant_vector_index.py index f755fe4101..f182c4c0e1 100644 --- a/api/core/index/vector_index/qdrant_vector_index.py +++ b/api/core/index/vector_index/qdrant_vector_index.py @@ -2,16 +2,17 @@ import os from typing import Any, List, Optional, cast import qdrant_client +from langchain.embeddings.base import Embeddings +from langchain.schema import Document +from langchain.vectorstores import VectorStore +from pydantic import BaseModel +from qdrant_client.http.models import HnswConfigDiff + from core.index.base import BaseIndex from core.index.vector_index.base import BaseVectorIndex from core.vector_store.qdrant_vector_store import QdrantVectorStore from extensions.ext_database import db -from langchain.embeddings.base import Embeddings -from langchain.schema import BaseRetriever, Document -from langchain.vectorstores import VectorStore from models.dataset import Dataset, DatasetCollectionBinding -from pydantic import BaseModel -from qdrant_client.http.models import HnswConfigDiff class QdrantConfig(BaseModel): diff --git a/api/core/index/vector_index/vector_index.py b/api/core/index/vector_index/vector_index.py index 0a69c4f734..74b5f0adc1 100644 --- a/api/core/index/vector_index/vector_index.py +++ b/api/core/index/vector_index/vector_index.py @@ -1,9 +1,10 @@ import json +from flask import current_app +from langchain.embeddings.base import Embeddings + from core.index.vector_index.base import BaseVectorIndex from extensions.ext_database import db -from flask import current_app -from langchain.embeddings.base import Embeddings from models.dataset import Dataset, Document diff --git a/api/core/index/vector_index/weaviate_vector_index.py b/api/core/index/vector_index/weaviate_vector_index.py index b4add6c11a..8af3c5926b 100644 --- a/api/core/index/vector_index/weaviate_vector_index.py +++ b/api/core/index/vector_index/weaviate_vector_index.py @@ -2,14 +2,15 @@ from typing import Any, List, Optional, cast import requests import weaviate +from langchain.embeddings.base import Embeddings +from langchain.schema import Document +from langchain.vectorstores import VectorStore +from pydantic import BaseModel, root_validator + from core.index.base import BaseIndex from core.index.vector_index.base import BaseVectorIndex from core.vector_store.weaviate_vector_store import WeaviateVectorStore -from langchain.embeddings.base import Embeddings -from langchain.schema import BaseRetriever, Document -from langchain.vectorstores import VectorStore from models.dataset import Dataset -from pydantic import BaseModel, root_validator class WeaviateConfig(BaseModel): diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 2f1cf282f8..04f5dbeab6 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -5,7 +5,13 @@ import re import threading import time import uuid -from typing import AbstractSet, Any, Collection, List, Literal, Optional, Type, Union, cast +from typing import List, Optional, cast + +from flask import Flask, current_app +from flask_login import current_user +from langchain.schema import Document +from langchain.text_splitter import TextSplitter +from sqlalchemy.orm.exc import ObjectDeletedError from core.data_loader.file_extractor import FileExtractor from core.data_loader.loader.notion import NotionLoader @@ -17,22 +23,15 @@ from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType, PriceType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer from core.spiltter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage -from flask import Flask, current_app -from flask_login import current_user -from langchain.schema import Document -from langchain.text_splitter import TS, TextSplitter, TokenTextSplitter from libs import helper -from models.dataset import Dataset, DatasetProcessRule +from models.dataset import Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument -from models.dataset import DocumentSegment from models.model import UploadFile from models.source import DataSourceBinding -from sqlalchemy.orm.exc import ObjectDeletedError class IndexingRunner: diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 663daa0856..f1f8ab3a3b 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,7 +1,12 @@ from core.file.message_file_parser import MessageFileParser from core.model_manager import ModelInstance -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole, - TextPromptMessageContent, UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageRole, + TextPromptMessageContent, + UserPromptMessage, +) from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers import model_provider_factory from extensions.ext_database import db diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index 76d4ef310e..b5bd9e267a 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -2,9 +2,10 @@ from decimal import Decimal from enum import Enum from typing import Optional +from pydantic import BaseModel + from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo -from pydantic import BaseModel class LLMMode(Enum): diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 2041cb3a97..ebde3ec85b 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -2,9 +2,10 @@ from decimal import Decimal from enum import Enum from typing import Any, Optional -from core.model_runtime.entities.common_entities import I18nObject from pydantic import BaseModel +from core.model_runtime.entities.common_entities import I18nObject + class ModelType(Enum): """ diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index bd55d60795..acc453bb84 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -1,9 +1,10 @@ from enum import Enum from typing import Optional +from pydantic import BaseModel + from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, ModelType, ProviderModel -from pydantic import BaseModel class ConfigurateMethod(Enum): diff --git a/api/core/model_runtime/entities/text_embedding_entities.py b/api/core/model_runtime/entities/text_embedding_entities.py index 499c76eb7d..7be3def379 100644 --- a/api/core/model_runtime/entities/text_embedding_entities.py +++ b/api/core/model_runtime/entities/text_embedding_entities.py @@ -1,8 +1,9 @@ from decimal import Decimal -from core.model_runtime.entities.model_entities import ModelUsage from pydantic import BaseModel +from core.model_runtime.entities.model_entities import ModelUsage + class EmbeddingUsage(ModelUsage): """ diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 11f9a7a6fb..eb811ab224 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -4,10 +4,18 @@ from abc import ABC, abstractmethod from typing import Optional import yaml + from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE -from core.model_runtime.entities.model_entities import (AIModelEntity, DefaultParameterName, FetchFrom, ModelType, - PriceConfig, PriceInfo, PriceType) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + ModelType, + PriceConfig, + PriceInfo, + PriceType, +) from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 75ea7bacef..173b4dcab7 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -9,8 +9,13 @@ from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.callbacks.logging_callback import LoggingCallback from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool -from core.model_runtime.entities.model_entities import (ModelPropertyKey, ModelType, ParameterRule, ParameterType, - PriceType) +from core.model_runtime.entities.model_entities import ( + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, + PriceType, +) from core.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py index a856d42588..f3d71670f1 100644 --- a/api/core/model_runtime/model_providers/__base/model_provider.py +++ b/api/core/model_runtime/model_providers/__base/model_provider.py @@ -1,9 +1,10 @@ import importlib import os from abc import ABC, abstractmethod -from typing import Dict, Optional +from typing import Dict import yaml + from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.provider_entities import ProviderEntity from core.model_runtime.model_providers.__base.ai_model import AIModel diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index 987f2fabf1..3f689a724d 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -3,14 +3,26 @@ from typing import Generator, List, Optional, Union import anthropic from anthropic import Anthropic, Stream from anthropic.types import Completion, completion_create_params +from httpx import Timeout + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from httpx import Timeout class AnthropicLargeLanguageModel(LargeLanguageModel): diff --git a/api/core/model_runtime/model_providers/azure_openai/_common.py b/api/core/model_runtime/model_providers/azure_openai/_common.py index 627b487357..b65138252b 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_common.py +++ b/api/core/model_runtime/model_providers/azure_openai/_common.py @@ -1,9 +1,16 @@ import openai -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) -from core.model_runtime.model_providers.azure_openai._constant import AZURE_OPENAI_API_VERSION from httpx import Timeout +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.model_providers.azure_openai._constant import AZURE_OPENAI_API_VERSION + class _CommonAzureOpenAI: @staticmethod diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index 8104df52dd..90dd2e7a6b 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -1,9 +1,18 @@ +from pydantic import BaseModel + from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.model_entities import (AIModelEntity, DefaultParameterName, FetchFrom, I18nObject, - ModelFeature, ModelPropertyKey, ModelType, ParameterRule, - PriceConfig) -from pydantic import BaseModel +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + I18nObject, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + PriceConfig, +) AZURE_OPENAI_API_VERSION = '2023-12-01-preview' diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index 326043aa39..1bab34edd6 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -3,22 +3,30 @@ import logging from typing import Generator, List, Optional, Union, cast import tiktoken -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, - PromptMessage, PromptMessageContentType, PromptMessageTool, - SystemPromptMessage, TextPromptMessageContent, - ToolPromptMessage, UserPromptMessage) -from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI -from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS, AzureBaseModel from openai import AzureOpenAI, Stream from openai.types import Completion from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall from openai.types.chat.chat_completion_message import FunctionCall +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI +from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS, AzureBaseModel + logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py index e472151cb5..606a898db5 100644 --- a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py @@ -5,13 +5,14 @@ from typing import Optional, Tuple, Union import numpy as np import tiktoken +from openai import AzureOpenAI + from core.model_runtime.entities.model_entities import AIModelEntity, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI from core.model_runtime.model_providers.azure_openai._constant import EMBEDDING_BASE_MODELS, AzureBaseModel -from openai import AzureOpenAI class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py index 48ed86f66b..46ba0cffaf 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py @@ -1,18 +1,19 @@ from enum import Enum from hashlib import md5 from json import dumps, loads -from os.path import join -from time import time -from typing import Any, Dict, Generator, List, Optional, Union +from typing import Any, Dict, Generator, List, Union -from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (BadRequestError, - InsufficientAccountBalance, - InternalServerError, - InvalidAPIKeyError, - InvalidAuthenticationError, - RateLimitReachedError) from requests import post +from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( + BadRequestError, + InsufficientAccountBalance, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) + class BaichuanMessage: class Role(Enum): diff --git a/api/core/model_runtime/model_providers/baichuan/llm/llm.py b/api/core/model_runtime/model_providers/baichuan/llm/llm.py index c8bb1feb52..d9a73477f6 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/llm.py @@ -1,20 +1,33 @@ -from typing import Generator, List, Optional, Union, cast +from typing import Generator, List, cast -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanMessage, BaichuanModel -from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (BadRequestError, - InsufficientAccountBalance, - InternalServerError, - InvalidAPIKeyError, - InvalidAuthenticationError, - RateLimitReachedError) +from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( + BadRequestError, + InsufficientAccountBalance, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) class BaichuanLarguageModel(LargeLanguageModel): diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py index 20aafea1eb..5020c58996 100644 --- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py @@ -1,21 +1,30 @@ import time -from json import dumps, loads +from json import dumps from typing import Optional, Tuple +from requests import post + from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer -from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (BadRequestError, - InsufficientAccountBalance, - InternalServerError, - InvalidAPIKeyError, - InvalidAuthenticationError, - RateLimitReachedError) -from requests import post +from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( + BadRequestError, + InsufficientAccountBalance, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) class BaichuanTextEmbeddingModel(TextEmbeddingModel): diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index 6a9c695350..7a2faae895 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -4,13 +4,30 @@ from typing import Generator, List, Optional, Union import boto3 from botocore.config import Config -from botocore.exceptions import (ClientError, EndpointConnectionError, NoRegionError, ServiceNotInRegionError, - UnknownServiceError) +from botocore.exceptions import ( + ClientError, + EndpointConnectionError, + NoRegionError, + ServiceNotInRegionError, + UnknownServiceError, +) + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel diff --git a/api/core/model_runtime/model_providers/chatglm/llm/llm.py b/api/core/model_runtime/model_providers/chatglm/llm/llm.py index 6f78f7aa88..fd2bcd5ec3 100644 --- a/api/core/model_runtime/model_providers/chatglm/llm/llm.py +++ b/api/core/model_runtime/model_providers/chatglm/llm/llm.py @@ -1,23 +1,44 @@ import logging -from json import dumps from os.path import join from typing import Generator, List, Optional, cast +from httpx import Timeout +from openai import ( + APIConnectionError, + APITimeoutError, + AuthenticationError, + ConflictError, + InternalServerError, + NotFoundError, + OpenAI, + PermissionDeniedError, + RateLimitError, + Stream, + UnprocessableEntityError, +) +from openai.types.chat import ChatCompletion, ChatCompletionChunk +from openai.types.chat.chat_completion_message import FunctionCall + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageFunction, - PromptMessageTool, SystemPromptMessage, ToolPromptMessage, - UserPromptMessage) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils import helper -from httpx import Timeout -from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError, - NotFoundError, OpenAI, PermissionDeniedError, RateLimitError, Stream, UnprocessableEntityError) -from openai.types.chat import ChatCompletion, ChatCompletionChunk -from openai.types.chat.chat_completion_message import FunctionCall -from requests import post logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/cohere/llm/llm.py b/api/core/model_runtime/model_providers/cohere/llm/llm.py index acff4177c3..95d3252b11 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/llm.py +++ b/api/core/model_runtime/model_providers/cohere/llm/llm.py @@ -5,14 +5,26 @@ import cohere from cohere.responses import Chat, Generations from cohere.responses.chat import StreamEnd, StreamingChat, StreamTextGeneration from cohere.responses.generation import StreamingGenerations, StreamingText + from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, - PromptMessageContentType, PromptMessageTool, - SystemPromptMessage, TextPromptMessageContent, - UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageContentType, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel diff --git a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py index 8c82cce766..7fee57f670 100644 --- a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py @@ -1,9 +1,16 @@ from typing import Optional import cohere + from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.rerank_model import RerankModel diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py index a239727814..fda8b27de4 100644 --- a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py @@ -4,10 +4,17 @@ from typing import Optional, Tuple import cohere import numpy as np from cohere.responses import Tokens + from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index 6fd5c9144c..e376e72c07 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -4,18 +4,30 @@ from typing import Generator, List, Optional, Union import google.api_core.exceptions as exceptions import google.generativeai as genai import google.generativeai.client as client -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, - PromptMessageContentType, PromptMessageRole, - PromptMessageTool, SystemPromptMessage, UserPromptMessage) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers import google -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory from google.generativeai.types.content_types import to_part +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageContentType, + PromptMessageRole, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + logger = logging.getLogger(__name__) class GoogleLargeLanguageModel(LargeLanguageModel): diff --git a/api/core/model_runtime/model_providers/huggingface_hub/_common.py b/api/core/model_runtime/model_providers/huggingface_hub/_common.py index 1140c947b9..dd8ae526e6 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/_common.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/_common.py @@ -1,6 +1,7 @@ -from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError from huggingface_hub.utils import BadRequestError, HfHubHTTPError +from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError + class _CommonHuggingfaceHub: diff --git a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py index e0701dff59..381d29c7e5 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py @@ -1,18 +1,30 @@ from typing import Generator, List, Optional, Union +from huggingface_hub import InferenceClient +from huggingface_hub.hf_api import HfApi +from huggingface_hub.utils import BadRequestError + from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) -from core.model_runtime.entities.model_entities import (AIModelEntity, DefaultParameterName, FetchFrom, - ModelPropertyKey, ModelType, ParameterRule) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub -from huggingface_hub import InferenceClient -from huggingface_hub.hf_api import HfApi -from huggingface_hub.utils import BadRequestError class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel): diff --git a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py index f0dc632fae..0f0c166f3e 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py @@ -4,13 +4,14 @@ from typing import Optional import numpy as np import requests +from huggingface_hub import HfApi, InferenceClient + from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub -from huggingface_hub import HfApi, InferenceClient HUGGINGFACE_ENDPOINT_API = 'https://api.endpoints.huggingface.cloud/v2/endpoint/' diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py index 50238fbcde..5c146972cd 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py @@ -2,14 +2,21 @@ import time from json import JSONDecodeError, dumps from typing import Optional +from requests import post + from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.jina.text_embedding.jina_tokenizer import JinaTokenizer -from requests import post class JinaTextEmbeddingModel(TextEmbeddingModel): diff --git a/api/core/model_runtime/model_providers/localai/llm/llm.py b/api/core/model_runtime/model_providers/localai/llm/llm.py index 117ef8c399..8d571d20b1 100644 --- a/api/core/model_runtime/model_providers/localai/llm/llm.py +++ b/api/core/model_runtime/model_providers/localai/llm/llm.py @@ -1,24 +1,53 @@ from os.path import join -from typing import Generator, List, Optional, Union, cast +from typing import Generator, List, cast -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) -from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, - ParameterRule, ParameterType) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.utils import helper from httpx import Timeout -from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError, - NotFoundError, OpenAI, PermissionDeniedError, RateLimitError, Stream, UnprocessableEntityError) +from openai import ( + APIConnectionError, + APITimeoutError, + AuthenticationError, + ConflictError, + InternalServerError, + NotFoundError, + OpenAI, + PermissionDeniedError, + RateLimitError, + Stream, + UnprocessableEntityError, +) from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat.chat_completion_message import FunctionCall from openai.types.completion import Completion +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.utils import helper + class LocalAILarguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, diff --git a/api/core/model_runtime/model_providers/localai/localai.py b/api/core/model_runtime/model_providers/localai/localai.py index 9ba94e5f21..6d2278fd54 100644 --- a/api/core/model_runtime/model_providers/localai/localai.py +++ b/api/core/model_runtime/model_providers/localai/localai.py @@ -1,7 +1,5 @@ import logging -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.model_provider import ModelProvider logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py index 511f09e3e7..39143127eb 100644 --- a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py @@ -3,13 +3,20 @@ from json import JSONDecodeError, dumps from os.path import join from typing import Optional +from requests import post + from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from requests import post class LocalAITextEmbeddingModel(TextEmbeddingModel): diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py index 718ebb1013..ee73005bd7 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py @@ -1,14 +1,18 @@ -from hashlib import md5 from json import dumps, loads -from time import time from typing import Any, Dict, Generator, List, Union -from core.model_runtime.model_providers.minimax.llm.errors import (BadRequestError, InsufficientAccountBalanceError, - InternalServerError, InvalidAPIKeyError, - InvalidAuthenticationError, RateLimitReachedError) -from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage from requests import Response, post +from core.model_runtime.model_providers.minimax.llm.errors import ( + BadRequestError, + InsufficientAccountBalanceError, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) +from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage + class MinimaxChatCompletion(object): """ diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py index 6233af26b6..2497a9d7b8 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py @@ -1,14 +1,18 @@ -from hashlib import md5 from json import dumps, loads -from time import time from typing import Any, Dict, Generator, List, Union -from core.model_runtime.model_providers.minimax.llm.errors import (BadRequestError, InsufficientAccountBalanceError, - InternalServerError, InvalidAPIKeyError, - InvalidAuthenticationError, RateLimitReachedError) -from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage from requests import Response, post +from core.model_runtime.model_providers.minimax.llm.errors import ( + BadRequestError, + InsufficientAccountBalanceError, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) +from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage + class MinimaxChatCompletionPro(object): """ diff --git a/api/core/model_runtime/model_providers/minimax/llm/llm.py b/api/core/model_runtime/model_providers/minimax/llm/llm.py index 2657c85419..bc65e756eb 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/llm.py +++ b/api/core/model_runtime/model_providers/minimax/llm/llm.py @@ -1,17 +1,34 @@ from typing import Generator, List from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, ToolPromptMessage, UserPromptMessage) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.minimax.llm.chat_completion import MinimaxChatCompletion from core.model_runtime.model_providers.minimax.llm.chat_completion_pro import MinimaxChatCompletionPro -from core.model_runtime.model_providers.minimax.llm.errors import (BadRequestError, InsufficientAccountBalanceError, - InternalServerError, InvalidAPIKeyError, - InvalidAuthenticationError, RateLimitReachedError) +from core.model_runtime.model_providers.minimax.llm.errors import ( + BadRequestError, + InsufficientAccountBalanceError, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage diff --git a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py index 65f2a9a225..edf4d6005a 100644 --- a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py @@ -1,17 +1,29 @@ import time -from json import dumps, loads +from json import dumps from typing import Optional +from requests import post + from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.minimax.llm.errors import (BadRequestError, InsufficientAccountBalanceError, - InternalServerError, InvalidAPIKeyError, - InvalidAuthenticationError, RateLimitReachedError) -from requests import post +from core.model_runtime.model_providers.minimax.llm.errors import ( + BadRequestError, + InsufficientAccountBalanceError, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) class MinimaxTextEmbeddingModel(TextEmbeddingModel): diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 06932b018d..e1e74ea806 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -5,12 +5,13 @@ from collections import OrderedDict from typing import Optional import yaml +from pydantic import BaseModel + from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.model_providers.__base.model_provider import ModelProvider from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator -from pydantic import BaseModel logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/moonshot/llm/llm.py b/api/core/model_runtime/model_providers/moonshot/llm/llm.py index 26f890b10b..40618b7fb4 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/llm.py +++ b/api/core/model_runtime/model_providers/moonshot/llm/llm.py @@ -1,7 +1,7 @@ -from typing import List, Optional, Union, Generator +from typing import Generator, List, Optional, Union from core.model_runtime.entities.llm_entities import LLMResult -from core.model_runtime.entities.message_entities import (PromptMessage, PromptMessageTool) +from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel diff --git a/api/core/model_runtime/model_providers/moonshot/moonshot.py b/api/core/model_runtime/model_providers/moonshot/moonshot.py index c35882e010..5654ae1459 100644 --- a/api/core/model_runtime/model_providers/moonshot/moonshot.py +++ b/api/core/model_runtime/model_providers/moonshot/moonshot.py @@ -1,6 +1,6 @@ import logging -from core.model_runtime.entities.model_entities import AIModelEntity, ModelType +from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.model_provider import ModelProvider diff --git a/api/core/model_runtime/model_providers/ollama/llm/llm.py b/api/core/model_runtime/model_providers/ollama/llm/llm.py index 27c4be125a..848ac76d33 100644 --- a/api/core/model_runtime/model_providers/ollama/llm/llm.py +++ b/api/core/model_runtime/model_providers/ollama/llm/llm.py @@ -6,16 +6,38 @@ from typing import Generator, List, Optional, Union, cast from urllib.parse import urljoin import requests + from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, - PromptMessage, PromptMessageContentType, PromptMessageTool, - SystemPromptMessage, TextPromptMessageContent, - UserPromptMessage) -from core.model_runtime.entities.model_entities import (AIModelEntity, DefaultParameterName, FetchFrom, I18nObject, - ModelFeature, ModelPropertyKey, ModelType, ParameterRule, - ParameterType, PriceConfig) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + I18nObject, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, + PriceConfig, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py index 5d96ac65ff..fd73728b78 100644 --- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -7,12 +7,25 @@ from urllib.parse import urljoin import numpy as np import requests + from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, - PriceConfig, PriceType) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, + PriceConfig, + PriceType, +) from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel diff --git a/api/core/model_runtime/model_providers/openai/_common.py b/api/core/model_runtime/model_providers/openai/_common.py index 91705c3ba8..436461c11e 100644 --- a/api/core/model_runtime/model_providers/openai/_common.py +++ b/api/core/model_runtime/model_providers/openai/_common.py @@ -1,8 +1,15 @@ import openai -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from httpx import Timeout +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) + class _CommonOpenAI: def _to_credential_kwargs(self, credentials: dict) -> dict: diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 7722c69a95..56a88884f8 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -2,23 +2,29 @@ import logging from typing import Generator, List, Optional, Union, cast import tiktoken -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, - PromptMessage, PromptMessageContentType, - PromptMessageFunction, PromptMessageTool, SystemPromptMessage, - TextPromptMessageContent, ToolPromptMessage, - UserPromptMessage) -from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType, PriceConfig -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.openai._common import _CommonOpenAI -from core.model_runtime.utils import helper from openai import OpenAI, Stream from openai.types import Completion from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall from openai.types.chat.chat_completion_message import FunctionCall +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType, PriceConfig +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.openai._common import _CommonOpenAI + logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/openai/moderation/moderation.py b/api/core/model_runtime/model_providers/openai/moderation/moderation.py index 2a0901d752..b1d0e57ad2 100644 --- a/api/core/model_runtime/model_providers/openai/moderation/moderation.py +++ b/api/core/model_runtime/model_providers/openai/moderation/moderation.py @@ -1,11 +1,12 @@ from typing import Optional +from openai import OpenAI +from openai.types import ModerationCreateResponse + from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.moderation_model import ModerationModel from core.model_runtime.model_providers.openai._common import _CommonOpenAI -from openai import OpenAI -from openai.types import ModerationCreateResponse class OpenAIModerationModel(_CommonOpenAI, ModerationModel): diff --git a/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py index b2b337a563..efbdd054f9 100644 --- a/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py @@ -1,9 +1,10 @@ from typing import IO, Optional +from openai import OpenAI + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel from core.model_runtime.model_providers.openai._common import _CommonOpenAI -from openai import OpenAI class OpenAISpeech2TextModel(_CommonOpenAI, Speech2TextModel): diff --git a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py index 87a5cf1a2a..28ab5c30ff 100644 --- a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py @@ -4,12 +4,13 @@ from typing import Optional, Tuple, Union import numpy as np import tiktoken +from openai import OpenAI + from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.openai._common import _CommonOpenAI -from openai import OpenAI class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): diff --git a/api/core/model_runtime/model_providers/openai/tts/tts.py b/api/core/model_runtime/model_providers/openai/tts/tts.py index 95a88e9bec..b3e66c1223 100644 --- a/api/core/model_runtime/model_providers/openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/openai/tts/tts.py @@ -3,13 +3,14 @@ from functools import reduce from io import BytesIO from typing import Optional +from flask import Response, stream_with_context +from openai import OpenAI +from pydub import AudioSegment + from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.model_runtime.model_providers.openai._common import _CommonOpenAI -from flask import Response, stream_with_context -from openai import OpenAI -from pydub import AudioSegment class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py index 9b7b052b99..51950ca377 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py @@ -1,13 +1,14 @@ -from decimal import Decimal import requests -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.model_entities import (AIModelEntity, DefaultParameterName, FetchFrom, - ModelPropertyKey, ModelType, ParameterRule, ParameterType, - PriceConfig) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) + +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) class _CommonOAI_API_Compat: diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index 53ee5817d9..2430ff2b2d 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -5,15 +5,31 @@ from typing import Generator, List, Optional, Union, cast from urllib.parse import urljoin import requests + from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, - PromptMessage, PromptMessageContent, PromptMessageContentType, - PromptMessageFunction, PromptMessageTool, SystemPromptMessage, - ToolPromptMessage, UserPromptMessage) -from core.model_runtime.entities.model_entities import (AIModelEntity, DefaultParameterName, FetchFrom, - ModelPropertyKey, ModelType, ParameterRule, ParameterType, - PriceConfig) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContent, + PromptMessageContentType, + PromptMessageFunction, + PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, + PriceConfig, +) from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py index 407eefa701..3445ebbaf7 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py @@ -1,7 +1,5 @@ import logging -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.model_provider import ModelProvider logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py index b735fdb792..4c75682de2 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -6,9 +6,16 @@ from urllib.parse import urljoin import numpy as np import requests + from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, - PriceConfig, PriceType) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, + PriceConfig, + PriceType, +) from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel diff --git a/api/core/model_runtime/model_providers/openllm/llm/llm.py b/api/core/model_runtime/model_providers/openllm/llm/llm.py index af62ddf92f..3491f107ab 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/llm.py +++ b/api/core/model_runtime/model_providers/openllm/llm/llm.py @@ -1,22 +1,40 @@ -from typing import Generator, List, Optional, Union +from typing import Generator, List from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) -from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, - ParameterRule, ParameterType) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.openllm.llm.openllm_generate import OpenLLMGenerate, OpenLLMGenerateMessage -from core.model_runtime.model_providers.openllm.llm.openllm_generate_errors import (BadRequestError, - InsufficientAccountBalanceError, - InternalServerError, - InvalidAPIKeyError, - InvalidAuthenticationError, - RateLimitReachedError) +from core.model_runtime.model_providers.openllm.llm.openllm_generate_errors import ( + BadRequestError, + InsufficientAccountBalanceError, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) class OpenLLMLargeLanguageModel(LargeLanguageModel): diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py index 2d9a10fa2a..06453cb3f8 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py +++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py @@ -2,15 +2,15 @@ from enum import Enum from json import dumps, loads from typing import Any, Dict, Generator, List, Union -from core.model_runtime.model_providers.openllm.llm.openllm_generate_errors import (BadRequestError, - InsufficientAccountBalanceError, - InternalServerError, - InvalidAPIKeyError, - InvalidAuthenticationError, - RateLimitReachedError) from requests import Response, post from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema +from core.model_runtime.model_providers.openllm.llm.openllm_generate_errors import ( + BadRequestError, + InternalServerError, + InvalidAuthenticationError, +) + class OpenLLMGenerateMessage: class Role(Enum): diff --git a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py index 2f30427d36..33847c0cb3 100644 --- a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py @@ -1,15 +1,22 @@ import time -from json import dumps, loads +from json import dumps from typing import Optional +from requests import post +from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema + from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from requests import post -from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema class OpenLLMTextEmbeddingModel(TextEmbeddingModel): diff --git a/api/core/model_runtime/model_providers/replicate/_common.py b/api/core/model_runtime/model_providers/replicate/_common.py index ad130cabbc..29d8427d8e 100644 --- a/api/core/model_runtime/model_providers/replicate/_common.py +++ b/api/core/model_runtime/model_providers/replicate/_common.py @@ -1,6 +1,7 @@ -from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError from replicate.exceptions import ModelError, ReplicateError +from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError + class _CommonReplicate: diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py index 69c0a82636..ce69c67984 100644 --- a/api/core/model_runtime/model_providers/replicate/llm/llm.py +++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py @@ -1,18 +1,30 @@ from typing import Generator, List, Optional, Union -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole, - PromptMessageTool, SystemPromptMessage, UserPromptMessage) -from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, - ParameterRule) -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.replicate._common import _CommonReplicate from replicate import Client as ReplicateClient from replicate.exceptions import ReplicateError from replicate.prediction import Prediction +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageRole, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.replicate._common import _CommonReplicate + class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index 37a275614c..a481aebc99 100644 --- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -2,13 +2,14 @@ import json import time from typing import Optional +from replicate import Client as ReplicateClient + from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.replicate._common import _CommonReplicate -from replicate import Client as ReplicateClient class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): diff --git a/api/core/model_runtime/model_providers/spark/llm/llm.py b/api/core/model_runtime/model_providers/spark/llm/llm.py index 33475f5769..6dfa1e3a6b 100644 --- a/api/core/model_runtime/model_providers/spark/llm/llm.py +++ b/api/core/model_runtime/model_providers/spark/llm/llm.py @@ -2,10 +2,21 @@ import threading from typing import Generator, List, Optional, Union from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel diff --git a/api/core/model_runtime/model_providers/spark/spark.py b/api/core/model_runtime/model_providers/spark/spark.py index c8bea10390..b3695e0501 100644 --- a/api/core/model_runtime/model_providers/spark/spark.py +++ b/api/core/model_runtime/model_providers/spark/spark.py @@ -1,7 +1,5 @@ import logging -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.model_provider import ModelProvider logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/togetherai/togetherai.py b/api/core/model_runtime/model_providers/togetherai/togetherai.py index e2ede35d69..ffce4794e7 100644 --- a/api/core/model_runtime/model_providers/togetherai/togetherai.py +++ b/api/core/model_runtime/model_providers/togetherai/togetherai.py @@ -1,7 +1,5 @@ import logging -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.model_provider import ModelProvider logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index 63f300fc19..8aac4412fd 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -1,18 +1,36 @@ from typing import Generator, List, Optional, Union -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from dashscope import get_tokenizer from dashscope.api_entities.dashscope_response import DashScopeAPIResponse -from dashscope.common.error import (AuthenticationError, InvalidParameter, RequestFailure, ServiceUnavailableError, - UnsupportedHTTPMethod, UnsupportedModel) +from dashscope.common.error import ( + AuthenticationError, + InvalidParameter, + RequestFailure, + ServiceUnavailableError, + UnsupportedHTTPMethod, + UnsupportedModel, +) from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + from ._client import EnhanceTongyi diff --git a/api/core/model_runtime/model_providers/tongyi/tts/tts.py b/api/core/model_runtime/model_providers/tongyi/tts/tts.py index a6fc201080..3e1608944b 100644 --- a/api/core/model_runtime/model_providers/tongyi/tts/tts.py +++ b/api/core/model_runtime/model_providers/tongyi/tts/tts.py @@ -4,12 +4,13 @@ from io import BytesIO from typing import Optional import dashscope +from flask import Response, stream_with_context +from pydub import AudioSegment + from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.model_runtime.model_providers.tongyi._common import _CommonTongyi -from flask import Response, stream_with_context -from pydub import AudioSegment class TongyiText2SpeechModel(_CommonTongyi, TTSModel): diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py index 65081a9665..f13fd27b91 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py @@ -4,13 +4,17 @@ from json import dumps, loads from threading import Lock from typing import Any, Dict, Generator, List, Union -from core.model_runtime.entities.message_entities import PromptMessageTool -from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (BadRequestError, InternalServerError, - InvalidAPIKeyError, - InvalidAuthenticationError, - RateLimitReachedError) from requests import Response, post +from core.model_runtime.entities.message_entities import PromptMessageTool +from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( + BadRequestError, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) + # map api_key to access_token baidu_access_tokens: Dict[str, 'BaiduAccessToken'] = {} baidu_access_tokens_lock = Lock() diff --git a/api/core/model_runtime/model_providers/wenxin/llm/llm.py b/api/core/model_runtime/model_providers/wenxin/llm/llm.py index 27b2bce9af..b13e340d91 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/llm.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/llm.py @@ -1,17 +1,32 @@ -from typing import Generator, List, Optional, Union, cast +from typing import Generator, List, cast -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.wenxin.llm.ernie_bot import BaiduAccessToken, ErnieBotModel, ErnieMessage -from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (BadRequestError, InsufficientAccountBalance, - InternalServerError, InvalidAPIKeyError, - InvalidAuthenticationError, - RateLimitReachedError) +from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( + BadRequestError, + InsufficientAccountBalance, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) class ErnieBotLarguageModel(LargeLanguageModel): diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 841e197873..7da1b00651 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -1,26 +1,62 @@ -from typing import Generator, Iterator, List, Optional, Union, cast +from typing import Generator, Iterator, List, cast -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, ToolPromptMessage, UserPromptMessage) -from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelFeature, ModelPropertyKey, - ModelType, ParameterRule, ParameterType) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.xinference.xinference_helper import (XinferenceHelper, - XinferenceModelExtraParameter) -from core.model_runtime.utils import helper -from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError, - NotFoundError, OpenAI, PermissionDeniedError, RateLimitError, Stream, UnprocessableEntityError) +from openai import ( + APIConnectionError, + APITimeoutError, + AuthenticationError, + ConflictError, + InternalServerError, + NotFoundError, + OpenAI, + PermissionDeniedError, + RateLimitError, + UnprocessableEntityError, +) from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall from openai.types.chat.chat_completion_message import FunctionCall from openai.types.completion import Completion -from xinference_client.client.restful.restful_client import (Client, RESTfulChatglmCppChatModelHandle, - RESTfulChatModelHandle, RESTfulGenerateModelHandle) +from xinference_client.client.restful.restful_client import ( + Client, + RESTfulChatglmCppChatModelHandle, + RESTfulChatModelHandle, + RESTfulGenerateModelHandle, +) + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.xinference.xinference_helper import ( + XinferenceHelper, + XinferenceModelExtraParameter, +) +from core.model_runtime.utils import helper class XinferenceAILargeLanguageModel(LargeLanguageModel): diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index 9ec9e09aa0..f1f87d47df 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -1,13 +1,20 @@ from typing import Optional +from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle + from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.rerank_model import RerankModel -from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle class XinferenceRerankModel(RerankModel): diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index bfc77db494..a68bc99976 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -1,15 +1,22 @@ import time from typing import Optional +from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle + from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper -from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle, RESTfulModelHandle class XinferenceTextEmbeddingModel(TextEmbeddingModel): diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index 764ffe8b65..089ffd691f 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -3,7 +3,6 @@ from threading import Lock from time import time from typing import List -from requests import get from requests.adapters import HTTPAdapter from requests.exceptions import ConnectionError, MissingSchema, Timeout from requests.sessions import Session diff --git a/api/core/model_runtime/model_providers/zhipuai/_common.py b/api/core/model_runtime/model_providers/zhipuai/_common.py index b961fe8b24..2574234abf 100644 --- a/api/core/model_runtime/model_providers/zhipuai/_common.py +++ b/api/core/model_runtime/model_providers/zhipuai/_common.py @@ -1,5 +1,11 @@ -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) class _CommonZhipuaiAI: diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index e3180ec177..6d1f462d0f 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -1,12 +1,15 @@ -import json -from typing import Any, Dict, Generator, List, Optional, Union +from typing import Generator, List, Optional, Union from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, - PromptMessage, PromptMessageContentType, PromptMessageRole, - PromptMessageTool, SystemPromptMessage, - TextPromptMessageContent, ToolPromptMessage, - UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageContentType, + PromptMessageRole, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index fd39f5a7a9..30c373729a 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -7,7 +7,6 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI -from langchain.schema.language_model import _get_token_ids_default_method class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py index dfe52fd54c..7796b778a3 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING import httpx from ..core._base_api import BaseAPI -from ..core._base_type import NOT_GIVEN, Body, FileTypes, Headers, NotGiven, Query +from ..core._base_type import NOT_GIVEN, FileTypes, Headers, NotGiven from ..core._files import is_file_content from ..core._http_client import make_user_request_input from ..types.file_object import FileObject, ListOfFileObject diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py index 65ce5b246f..ce852a48c3 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Optional import httpx diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py index 3f22731de6..2406e57820 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py @@ -1,8 +1,7 @@ from __future__ import annotations -from typing import Any, Union, cast +from typing import Any, Union -import pydantic.generics from httpx import Timeout from pydantic import ConfigDict from typing_extensions import ClassVar, TypedDict, Unpack diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py index c841e1d756..6197b6faaf 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py @@ -1,7 +1,6 @@ from typing import List, Optional, Union from pydantic import BaseModel -from typing_extensions import Literal __all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob" ] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py index 1a70483a7b..6ff3f77fd7 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py @@ -1,7 +1,6 @@ from typing import List, Optional, Union from pydantic import BaseModel -from typing_extensions import Literal __all__ = ["FineTuningJobEvent", "Metric", "JobEvent"] diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index 82b2f27234..9cafbf17a3 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -1,9 +1,10 @@ +from pydantic import BaseModel + from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor from core.helper.encrypter import decrypt_token from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult from extensions.ext_database import db from models.api_based_extension import APIBasedExtension -from pydantic import BaseModel class ModerationInputParams(BaseModel): diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 1cce8f18f2..9a369a9f87 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -2,9 +2,10 @@ from abc import ABC, abstractmethod from enum import Enum from typing import Optional -from core.extension.extensible import Extensible, ExtensionModule from pydantic import BaseModel +from core.extension.extensible import Extensible, ExtensionModule + class ModerationAction(Enum): DIRECT_OUTPUT = 'direct_output' diff --git a/api/core/prompt/output_parser/rule_config_generator.py b/api/core/prompt/output_parser/rule_config_generator.py index 61165d628e..2755910c28 100644 --- a/api/core/prompt/output_parser/rule_config_generator.py +++ b/api/core/prompt/output_parser/rule_config_generator.py @@ -1,7 +1,8 @@ from typing import Any -from core.prompt.prompts import RULE_CONFIG_GENERATE_TEMPLATE from langchain.schema import BaseOutputParser, OutputParserException + +from core.prompt.prompts import RULE_CONFIG_GENERATE_TEMPLATE from libs.json_in_md_parser import parse_and_check_json_markdown diff --git a/api/core/prompt/output_parser/suggested_questions_after_answer.py b/api/core/prompt/output_parser/suggested_questions_after_answer.py index 49501a2dd7..d8bb0809cf 100644 --- a/api/core/prompt/output_parser/suggested_questions_after_answer.py +++ b/api/core/prompt/output_parser/suggested_questions_after_answer.py @@ -2,9 +2,10 @@ import json import re from typing import Any +from langchain.schema import BaseOutputParser + from core.model_runtime.errors.invoke import InvokeError from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT -from langchain.schema import BaseOutputParser class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser): diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 01cad0c1d4..5ffcaaec65 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -4,13 +4,21 @@ import os import re from typing import List, Optional, Tuple, cast -from core.entities.application_entities import (AdvancedCompletionPromptTemplateEntity, ModelConfigEntity, - PromptTemplateEntity) +from core.entities.application_entities import ( + AdvancedCompletionPromptTemplateEntity, + ModelConfigEntity, + PromptTemplateEntity, +) from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole, - SystemPromptMessage, TextPromptMessageContent, - UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageRole, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.prompt_builder import PromptBuilder diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 1c505823d1..6e28247d38 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -3,21 +3,36 @@ from collections import defaultdict from json import JSONDecodeError from typing import Optional +from sqlalchemy.exc import IntegrityError + from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle -from core.entities.provider_entities import (CustomConfiguration, CustomModelConfiguration, CustomProviderConfiguration, - QuotaConfiguration, SystemConfiguration) +from core.entities.provider_entities import ( + CustomConfiguration, + CustomModelConfiguration, + CustomProviderConfiguration, + QuotaConfiguration, + SystemConfiguration, +) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import (ConfigurateMethod, CredentialFormSchema, FormType, - ProviderEntity) +from core.model_runtime.entities.provider_entities import ( + CredentialFormSchema, + FormType, + ProviderEntity, +) from core.model_runtime.model_providers import model_provider_factory from extensions import ext_hosting_provider from extensions.ext_database import db -from models.provider import (Provider, ProviderModel, ProviderQuotaType, ProviderType, TenantDefaultModel, - TenantPreferredModelProvider) -from sqlalchemy.exc import IntegrityError +from models.provider import ( + Provider, + ProviderModel, + ProviderQuotaType, + ProviderType, + TenantDefaultModel, + TenantPreferredModelProvider, +) class ProviderManager: diff --git a/api/core/rerank/rerank.py b/api/core/rerank/rerank.py index 4d2f84b492..984cdb4003 100644 --- a/api/core/rerank/rerank.py +++ b/api/core/rerank/rerank.py @@ -1,8 +1,9 @@ from typing import List, Optional -from core.model_manager import ModelInstance from langchain.schema import Document +from core.model_manager import ModelInstance + class RerankRunner: def __init__(self, rerank_model_instance: ModelInstance) -> None: diff --git a/api/core/spiltter/fixed_text_splitter.py b/api/core/spiltter/fixed_text_splitter.py index a6895998cf..babb360a5e 100644 --- a/api/core/spiltter/fixed_text_splitter.py +++ b/api/core/spiltter/fixed_text_splitter.py @@ -3,11 +3,20 @@ from __future__ import annotations from typing import Any, List, Optional, cast +from langchain.text_splitter import ( + TS, + AbstractSet, + Collection, + Literal, + RecursiveCharacterTextSplitter, + TokenTextSplitter, + Type, + Union, +) + from core.model_manager import ModelInstance from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer -from langchain.text_splitter import (TS, AbstractSet, Collection, Literal, RecursiveCharacterTextSplitter, - TokenTextSplitter, Type, Union) class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): diff --git a/api/core/tool/web_reader_tool.py b/api/core/tool/web_reader_tool.py index 18a0e93721..7de0521c03 100644 --- a/api/core/tool/web_reader_tool.py +++ b/api/core/tool/web_reader_tool.py @@ -11,10 +11,6 @@ from typing import Any, Type import requests from bs4 import BeautifulSoup, CData, Comment, NavigableString -from core.chain.llm_chain import LLMChain -from core.data_loader import file_extractor -from core.data_loader.file_extractor import FileExtractor -from core.entities.application_entities import ModelConfigEntity from langchain.chains import RefineDocumentsChain from langchain.chains.summarize import refine_prompts from langchain.schema import Document @@ -24,6 +20,11 @@ from newspaper import Article from pydantic import BaseModel, Field from regex import regex +from core.chain.llm_chain import LLMChain +from core.data_loader import file_extractor +from core.data_loader.file_extractor import FileExtractor +from core.entities.application_entities import ModelConfigEntity + FULL_TEMPLATE = """ TITLE: {title} AUTHORS: {authors} diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py index e9d677cbc1..70ca098cd4 100644 --- a/api/core/tools/entities/tool_bundle.py +++ b/api/core/tools/entities/tool_bundle.py @@ -1,8 +1,9 @@ from typing import Any, Dict, List, Optional -from core.tools.entities.tool_entities import ToolParameter, ToolProviderType from pydantic import BaseModel +from core.tools.entities.tool_entities import ToolParameter, ToolProviderType + class ApiBasedToolBundle(BaseModel): """ diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index d33c1ba8c3..15a0c028bd 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -1,9 +1,10 @@ from enum import Enum from typing import Any, Dict, List, Optional, Union, cast -from core.tools.entities.common_entities import I18nObject from pydantic import BaseModel, Field +from core.tools.entities.common_entities import I18nObject + class ToolProviderType(Enum): """ diff --git a/api/core/tools/entities/user_entities.py b/api/core/tools/entities/user_entities.py index 265f856eb4..c89d5e280a 100644 --- a/api/core/tools/entities/user_entities.py +++ b/api/core/tools/entities/user_entities.py @@ -1,10 +1,11 @@ from enum import Enum from typing import Dict, List, Optional +from pydantic import BaseModel + from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderCredentials from core.tools.tool.tool import ToolParameter -from pydantic import BaseModel class UserToolProvider(BaseModel): diff --git a/api/core/tools/model/tool_model_manager.py b/api/core/tools/model/tool_model_manager.py index ee17742161..ec24786046 100644 --- a/api/core/tools/model/tool_model_manager.py +++ b/api/core/tools/model/tool_model_manager.py @@ -11,8 +11,13 @@ from core.model_manager import ModelManager from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel, ModelPropertyKey from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.model.errors import InvokeModelError diff --git a/api/core/tools/provider/api_tool_provider.py b/api/core/tools/provider/api_tool_provider.py index 03edb365ce..7f7b04996e 100644 --- a/api/core/tools/provider/api_tool_provider.py +++ b/api/core/tools/provider/api_tool_provider.py @@ -2,8 +2,12 @@ from typing import Any, Dict, List from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiBasedToolBundle -from core.tools.entities.tool_entities import (ApiProviderAuthType, ToolCredentialsOption, ToolProviderCredentials, - ToolProviderType) +from core.tools.entities.tool_entities import ( + ApiProviderAuthType, + ToolCredentialsOption, + ToolProviderCredentials, + ToolProviderType, +) from core.tools.provider.tool_provider import ToolProviderController from core.tools.tool.api_tool import ApiTool from core.tools.tool.tool import Tool diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index ccee002185..bf954a6baa 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -1,9 +1,9 @@ -from core.tools.entities.user_entities import UserToolProvider -from core.tools.entities.tool_entities import ToolProviderType -from typing import List -from yaml import load, FullLoader - import os.path +from typing import List + +from yaml import FullLoader, load + +from core.tools.entities.user_entities import UserToolProvider position = {} diff --git a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py index 25fac01268..f46b7a0823 100644 --- a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py @@ -1,10 +1,10 @@ from base64 import b64decode -from os.path import join from typing import Any, Dict, List, Union +from openai import AzureOpenAI + from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -from openai import AzureOpenAI class DallE3Tool(BuiltinTool): diff --git a/api/core/tools/provider/builtin/bing/bing.py b/api/core/tools/provider/builtin/bing/bing.py index ab3718387a..07213b8909 100644 --- a/api/core/tools/provider/builtin/bing/bing.py +++ b/api/core/tools/provider/builtin/bing/bing.py @@ -1,9 +1,9 @@ -from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from typing import Any, Dict + from core.tools.errors import ToolProviderCredentialValidationError - from core.tools.provider.builtin.bing.tools.bing_web_search import BingSearchTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController -from typing import Any, Dict, List class BingProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: Dict[str, Any]) -> None: diff --git a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py index 490a9deee5..b94ee4b459 100644 --- a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py +++ b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py @@ -1,10 +1,11 @@ -from core.tools.tool.builtin_tool import BuiltinTool -from core.tools.entities.tool_entities import ToolInvokeMessage - from typing import Any, Dict, List, Union -from os import path + from requests import get +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + class BingSearchTool(BuiltinTool): url = 'https://api.bing.microsoft.com/v7.0/search' diff --git a/api/core/tools/provider/builtin/chart/chart.py b/api/core/tools/provider/builtin/chart/chart.py index 77773bf6bb..813b4abcf2 100644 --- a/api/core/tools/provider/builtin/chart/chart.py +++ b/api/core/tools/provider/builtin/chart/chart.py @@ -1,4 +1,5 @@ import matplotlib.pyplot as plt + from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.chart.tools.line import LinearChartTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController diff --git a/api/core/tools/provider/builtin/chart/tools/bar.py b/api/core/tools/provider/builtin/chart/tools/bar.py index bfc4936140..6aa84c19c1 100644 --- a/api/core/tools/provider/builtin/chart/tools/bar.py +++ b/api/core/tools/provider/builtin/chart/tools/bar.py @@ -2,6 +2,7 @@ import io from typing import Any, Dict, List, Union import matplotlib.pyplot as plt + from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/chart/tools/line.py b/api/core/tools/provider/builtin/chart/tools/line.py index a2b618315b..7b9ef99a3c 100644 --- a/api/core/tools/provider/builtin/chart/tools/line.py +++ b/api/core/tools/provider/builtin/chart/tools/line.py @@ -2,6 +2,7 @@ import io from typing import Any, Dict, List, Union import matplotlib.pyplot as plt + from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/chart/tools/pie.py b/api/core/tools/provider/builtin/chart/tools/pie.py index 47c2bb1976..dff48f1fe4 100644 --- a/api/core/tools/provider/builtin/chart/tools/pie.py +++ b/api/core/tools/provider/builtin/chart/tools/pie.py @@ -2,6 +2,7 @@ import io from typing import Any, Dict, List, Union import matplotlib.pyplot as plt + from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle2.py b/api/core/tools/provider/builtin/dalle/tools/dalle2.py index ef62faf8b8..51e7619f37 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle2.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle2.py @@ -2,9 +2,10 @@ from base64 import b64decode from os.path import join from typing import Any, Dict, List, Union +from openai import OpenAI + from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -from openai import OpenAI class DallE2Tool(BuiltinTool): diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.py b/api/core/tools/provider/builtin/dalle/tools/dalle3.py index 43e96d7780..d320ecfb94 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.py @@ -2,9 +2,10 @@ from base64 import b64decode from os.path import join from typing import Any, Dict, List, Union +from openai import OpenAI + from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -from openai import OpenAI class DallE3Tool(BuiltinTool): diff --git a/api/core/tools/provider/builtin/gaode/gaode.py b/api/core/tools/provider/builtin/gaode/gaode.py index a89ea6579a..b55d93e07b 100644 --- a/api/core/tools/provider/builtin/gaode/gaode.py +++ b/api/core/tools/provider/builtin/gaode/gaode.py @@ -1,6 +1,7 @@ import urllib.parse import requests + from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController diff --git a/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py b/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py index 527ee6ba25..bdb87c6363 100644 --- a/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py +++ b/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py @@ -2,6 +2,7 @@ import json from typing import Any, Dict, List, Union import requests + from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/github/github.py b/api/core/tools/provider/builtin/github/github.py index b224a82a4b..9275504208 100644 --- a/api/core/tools/provider/builtin/github/github.py +++ b/api/core/tools/provider/builtin/github/github.py @@ -1,4 +1,5 @@ import requests + from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController diff --git a/api/core/tools/provider/builtin/github/tools/github_repositories.py b/api/core/tools/provider/builtin/github/tools/github_repositories.py index 10ca5a5cd7..3ee3660d34 100644 --- a/api/core/tools/provider/builtin/github/tools/github_repositories.py +++ b/api/core/tools/provider/builtin/github/tools/github_repositories.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Union from urllib.parse import quote import requests + from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/google/google.py b/api/core/tools/provider/builtin/google/google.py index 0ac34af33d..6dba77953a 100644 --- a/api/core/tools/provider/builtin/google/google.py +++ b/api/core/tools/provider/builtin/google/google.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.google.tools.google_search import GoogleSearchTool diff --git a/api/core/tools/provider/builtin/google/tools/google_search.py b/api/core/tools/provider/builtin/google/tools/google_search.py index 81ff20fdc1..5c7a63751c 100644 --- a/api/core/tools/provider/builtin/google/tools/google_search.py +++ b/api/core/tools/provider/builtin/google/tools/google_search.py @@ -2,9 +2,10 @@ import os import sys from typing import Any, Dict, List, Union +from serpapi import GoogleSearch + from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -from serpapi import GoogleSearch class HiddenPrints: diff --git a/api/core/tools/provider/builtin/maths/tools/eval_expression.py b/api/core/tools/provider/builtin/maths/tools/eval_expression.py index b551e5a0c7..b0135bd309 100644 --- a/api/core/tools/provider/builtin/maths/tools/eval_expression.py +++ b/api/core/tools/provider/builtin/maths/tools/eval_expression.py @@ -1,11 +1,11 @@ import logging -from datetime import datetime, timezone from typing import Any, Dict, List, Union +import numexpr as ne + from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -from pytz import timezone as pytz_timezone -import numexpr as ne + class EvaluateExpressionTool(BuiltinTool): def _invoke(self, diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py index 2899235f76..46c8b152cd 100644 --- a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py @@ -5,12 +5,13 @@ from copy import deepcopy from os.path import join from typing import Any, Dict, List, Union +from httpx import get, post +from PIL import Image + from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption from core.tools.errors import ToolProviderCredentialValidationError from core.tools.tool.builtin_tool import BuiltinTool -from httpx import get, post -from PIL import Image DRAW_TEXT_OPTIONS = { "prompt": "", diff --git a/api/core/tools/provider/builtin/time/tools/current_time.py b/api/core/tools/provider/builtin/time/tools/current_time.py index 07a1f94b31..e3903bbbe0 100644 --- a/api/core/tools/provider/builtin/time/tools/current_time.py +++ b/api/core/tools/provider/builtin/time/tools/current_time.py @@ -1,9 +1,10 @@ from datetime import datetime, timezone from typing import Any, Dict, List, Union +from pytz import timezone as pytz_timezone + from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -from pytz import timezone as pytz_timezone class CurrentTimeTool(BuiltinTool): diff --git a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py index 8051729825..0ac08a1d7f 100644 --- a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py @@ -1,11 +1,12 @@ from base64 import b64decode from typing import Any, Dict, List, Union +from httpx import post + from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.vectorizer.tools.test_data import VECTORIZER_ICON_PNG from core.tools.tool.builtin_tool import BuiltinTool -from httpx import post class VectorizerTool(BuiltinTool): diff --git a/api/core/tools/provider/builtin/webscraper/webscraper.py b/api/core/tools/provider/builtin/webscraper/webscraper.py index aef35e02a4..63fd1253e8 100644 --- a/api/core/tools/provider/builtin/webscraper/webscraper.py +++ b/api/core/tools/provider/builtin/webscraper/webscraper.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.webscraper.tools.webscraper import WebscraperTool diff --git a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py index 3cab06fdd5..b7563973ff 100644 --- a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py +++ b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py @@ -1,11 +1,12 @@ from typing import Any, Dict, List, Union -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool from langchain import WikipediaAPIWrapper from langchain.tools import WikipediaQueryRun from pydantic import BaseModel, Field +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + class WikipediaInput(BaseModel): query: str = Field(..., description="search query.") diff --git a/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py index d915044d89..6152137d93 100644 --- a/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py +++ b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py @@ -1,9 +1,10 @@ from typing import Any, Dict, List, Union +from httpx import get + from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.errors import ToolInvokeError, ToolProviderCredentialValidationError from core.tools.tool.builtin_tool import BuiltinTool -from httpx import get class WolframAlphaTool(BuiltinTool): diff --git a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py index 55d827093c..d98c084f77 100644 --- a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py +++ b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py @@ -1,10 +1,8 @@ -from typing import Any, Dict, List +from typing import Any, Dict -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.wolframalpha.tools.wolframalpha import WolframAlphaTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController -from core.tools.tool.tool import Tool class GoogleProvider(BuiltinToolProviderController): diff --git a/api/core/tools/provider/builtin/yahoo/tools/analytics.py b/api/core/tools/provider/builtin/yahoo/tools/analytics.py index 85608b5324..74504b25a2 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/analytics.py +++ b/api/core/tools/provider/builtin/yahoo/tools/analytics.py @@ -2,11 +2,12 @@ from datetime import datetime from typing import Any, Dict, List, Union import pandas as pd -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool from requests.exceptions import HTTPError, ReadTimeout from yfinance import download +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + class YahooFinanceAnalyticsTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \ diff --git a/api/core/tools/provider/builtin/yahoo/tools/news.py b/api/core/tools/provider/builtin/yahoo/tools/news.py index 79de87e0f3..f1e4070974 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/news.py +++ b/api/core/tools/provider/builtin/yahoo/tools/news.py @@ -1,9 +1,10 @@ from typing import Any, Dict, List, Union import yfinance +from requests.exceptions import HTTPError, ReadTimeout + from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -from requests.exceptions import HTTPError, ReadTimeout class YahooFinanceSearchTickerTool(BuiltinTool): diff --git a/api/core/tools/provider/builtin/yahoo/tools/ticker.py b/api/core/tools/provider/builtin/yahoo/tools/ticker.py index b28cee0374..8064ae49b4 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/ticker.py +++ b/api/core/tools/provider/builtin/yahoo/tools/ticker.py @@ -1,9 +1,10 @@ from typing import Any, Dict, List, Union +from requests.exceptions import HTTPError, ReadTimeout +from yfinance import Ticker + from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -from requests.exceptions import HTTPError, ReadTimeout -from yfinance import Ticker class YahooFinanceSearchTickerTool(BuiltinTool): diff --git a/api/core/tools/provider/builtin/youtube/tools/videos.py b/api/core/tools/provider/builtin/youtube/tools/videos.py index 5922c37dda..b2da4b2d3d 100644 --- a/api/core/tools/provider/builtin/youtube/tools/videos.py +++ b/api/core/tools/provider/builtin/youtube/tools/videos.py @@ -1,9 +1,10 @@ from datetime import datetime from typing import Any, Dict, List, Union +from googleapiclient.discovery import build + from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -from googleapiclient.discovery import build class YoutubeVideosAnalyticsTool(BuiltinTool): diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index 14eefebf49..c1a58e536c 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -3,14 +3,19 @@ from abc import abstractmethod from os import listdir, path from typing import Any, Dict, List +from yaml import FullLoader, load + from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType from core.tools.entities.user_entities import UserToolProviderCredentials -from core.tools.errors import (ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError, - ToolProviderNotFoundError) +from core.tools.errors import ( + ToolNotFoundError, + ToolParameterValidationError, + ToolProviderCredentialValidationError, + ToolProviderNotFoundError, +) from core.tools.provider.tool_provider import ToolProviderController from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool -from yaml import FullLoader, load class BuiltinToolProviderController(ToolProviderController): diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py index 954727f774..6f4370354e 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -1,12 +1,17 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional -from core.tools.entities.tool_entities import (ToolParameter, ToolProviderCredentials, ToolProviderIdentity, - ToolProviderType) +from pydantic import BaseModel + +from core.tools.entities.tool_entities import ( + ToolParameter, + ToolProviderCredentials, + ToolProviderIdentity, + ToolProviderType, +) from core.tools.entities.user_entities import UserToolProviderCredentials from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError from core.tools.tool.tool import Tool -from pydantic import BaseModel class ToolProviderController(BaseModel, ABC): diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index c72a6ec183..be465476b6 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Union import httpx import requests + import core.helper.ssrf_proxy as ssrf_proxy from core.tools.entities.tool_bundle import ApiBasedToolBundle from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/tool/builtin_tool.py b/api/core/tools/tool/builtin_tool.py index e862d427af..3836c89a27 100644 --- a/api/core/tools/tool/builtin_tool.py +++ b/api/core/tools/tool/builtin_tool.py @@ -1,4 +1,3 @@ -from enum import Enum from typing import List from core.model_runtime.entities.llm_entities import LLMResult diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index e205401686..43174feed9 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -1,7 +1,10 @@ -import json import threading from typing import List, Optional, Type +from flask import Flask, current_app +from langchain.tools import BaseTool +from pydantic import BaseModel, Field + from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.embedding.cached_embedding import CacheEmbedding from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError @@ -10,10 +13,7 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rerank.rerank import RerankRunner from extensions.ext_database import db -from flask import Flask, current_app -from langchain.tools import BaseTool from models.dataset import Dataset, Document, DocumentSegment -from pydantic import BaseModel, Field from services.retrieval_service import RetrievalService default_retrieval_model = { diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index 79de38ca14..699be64867 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -1,6 +1,10 @@ import threading from typing import List, Optional, Type +from flask import current_app +from langchain.tools import BaseTool +from pydantic import BaseModel, Field + from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.embedding.cached_embedding import CacheEmbedding from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex @@ -9,10 +13,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rerank.rerank import RerankRunner from extensions.ext_database import db -from flask import current_app -from langchain.tools import BaseTool from models.dataset import Dataset, Document, DocumentSegment -from pydantic import BaseModel, Field from services.retrieval_service import RetrievalService default_retrieval_model = { diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py index 13b9a8497b..deb2d3e18b 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -1,4 +1,6 @@ -from typing import Any, Dict, List, Union +from typing import Any, Dict, List + +from langchain.tools import BaseTool from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom @@ -6,7 +8,6 @@ from core.features.dataset_retrieval import DatasetRetrievalFeature from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter from core.tools.tool.tool import Tool -from langchain.tools import BaseTool class DatasetRetrieverTool(Tool): diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 95a8f8578f..3c96ef2fe9 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -2,12 +2,20 @@ from abc import ABC, abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Union -from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler -from core.tools.entities.tool_entities import (ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter, - ToolRuntimeImageVariable, ToolRuntimeVariable, ToolRuntimeVariablePool) -from core.tools.tool_file_manager import ToolFileManager from pydantic import BaseModel +from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler +from core.tools.entities.tool_entities import ( + ToolDescription, + ToolIdentity, + ToolInvokeMessage, + ToolParameter, + ToolRuntimeImageVariable, + ToolRuntimeVariable, + ToolRuntimeVariablePool, +) +from core.tools.tool_file_manager import ToolFileManager + class Tool(BaseModel, ABC): identity: ToolIdentity = None diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index c373399606..ac028a3222 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -8,10 +8,11 @@ from mimetypes import guess_extension, guess_type from typing import Generator, Tuple, Union from uuid import uuid4 -from extensions.ext_database import db -from extensions.ext_storage import storage from flask import current_app from httpx import get + +from extensions.ext_database import db +from extensions.ext_storage import storage from models.model import MessageFile from models.tools import ToolFile diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index a43e5f218e..6dd8fec1a8 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -1,10 +1,12 @@ -from typing import Dict, Any +from typing import Any, Dict + from pydantic import BaseModel +from core.helper import encrypter +from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType from core.tools.entities.tool_entities import ToolProviderCredentials from core.tools.provider.tool_provider import ToolProviderController -from core.helper import encrypter -from core.helper.tool_provider_cache import ToolProviderCredentialsCacheType, ToolProviderCredentialsCache + class ToolConfiguration(BaseModel): tenant_id: str diff --git a/api/core/tools/utils/encoder.py b/api/core/tools/utils/encoder.py index cce50fb0ac..4387e4cc03 100644 --- a/api/core/tools/utils/encoder.py +++ b/api/core/tools/utils/encoder.py @@ -1,4 +1,3 @@ -from enum import Enum from typing import List from pydantic import BaseModel diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 48b84d630c..0e0b015469 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -1,15 +1,15 @@ -from json import dumps as json_dumps from json import loads as json_loads from typing import List, Tuple -from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_bundle import ApiBasedToolBundle -from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter, ToolParameterOption -from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError from requests import get from yaml import FullLoader, load +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_bundle import ApiBasedToolBundle +from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter +from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError + class ApiBasedToolSchemaParser: @staticmethod diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index 5a769f6f2d..5e361fd14d 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -11,10 +11,6 @@ from typing import Any, Type import requests from bs4 import BeautifulSoup, CData, Comment, NavigableString -from core.chain.llm_chain import LLMChain -from core.data_loader import file_extractor -from core.data_loader.file_extractor import FileExtractor -from core.entities.application_entities import ModelConfigEntity from langchain.chains import RefineDocumentsChain from langchain.chains.summarize import refine_prompts from langchain.schema import Document @@ -24,6 +20,11 @@ from newspaper import Article from pydantic import BaseModel, Field from regex import regex +from core.chain.llm_chain import LLMChain +from core.data_loader import file_extractor +from core.data_loader.file_extractor import FileExtractor +from core.entities.application_entities import ModelConfigEntity + FULL_TEMPLATE = """ TITLE: {title} AUTHORS: {authors} diff --git a/api/core/vector_store/qdrant_vector_store.py b/api/core/vector_store/qdrant_vector_store.py index 06544766f3..53ad7b2aae 100644 --- a/api/core/vector_store/qdrant_vector_store.py +++ b/api/core/vector_store/qdrant_vector_store.py @@ -1,10 +1,11 @@ from typing import Any, cast -from core.vector_store.vector.qdrant import Qdrant from langchain.schema import Document from qdrant_client.http.models import Filter, FilterSelector, PointIdsList from qdrant_client.local.qdrant_local import QdrantLocal +from core.vector_store.vector.qdrant import Qdrant + class QdrantVectorStore(Qdrant): def del_texts(self, filter: Filter): diff --git a/api/core/vector_store/vector/qdrant.py b/api/core/vector_store/vector/qdrant.py index b9c32f59d4..1e85824f82 100644 --- a/api/core/vector_store/vector/qdrant.py +++ b/api/core/vector_store/vector/qdrant.py @@ -14,7 +14,7 @@ from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings from langchain.vectorstores import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance -from qdrant_client.http.models import FilterSelector, PayloadSchemaType, TextIndexParams, TextIndexType, TokenizerType +from qdrant_client.http.models import PayloadSchemaType, TextIndexParams, TextIndexType, TokenizerType if TYPE_CHECKING: from qdrant_client import grpc # noqa @@ -1396,9 +1396,7 @@ class Qdrant(VectorStore): "Could not import qdrant-client python package. " "Please install it with `pip install qdrant-client`." ) - from grpc import RpcError from qdrant_client.http import models as rest - from qdrant_client.http.exceptions import UnexpectedResponse # Just do a single quick embedding to get vector size partial_embeddings = embedding.embed_documents(texts[:1]) diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 058a9b4b00..0b281c9271 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -3,13 +3,12 @@ import logging import time import click -from celery import shared_task +from werkzeug.exceptions import NotFound + from core.indexing_runner import DocumentIsPausedException, IndexingRunner -from events.dataset_event import dataset_was_deleted from events.event_handlers.document_index_event import document_index_created from extensions.ext_database import db from models.dataset import Document -from werkzeug.exceptions import NotFound @document_index_created.connect diff --git a/api/extensions/ext_hosting_provider.py b/api/extensions/ext_hosting_provider.py index 5752ec7f4c..49e2fcb0c7 100644 --- a/api/extensions/ext_hosting_provider.py +++ b/api/extensions/ext_hosting_provider.py @@ -1,6 +1,7 @@ -from core.hosting_configuration import HostingConfiguration from flask import Flask +from core.hosting_configuration import HostingConfiguration + hosting_configuration = HostingConfiguration() diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index 67cb8e4ea9..5974de34de 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -1,4 +1,5 @@ from flask_restful import fields + from libs.helper import TimestampField account_fields = { diff --git a/api/fields/api_based_extension_fields.py b/api/fields/api_based_extension_fields.py index 2ccc9ddfe0..749e9900de 100644 --- a/api/fields/api_based_extension_fields.py +++ b/api/fields/api_based_extension_fields.py @@ -1,4 +1,5 @@ from flask_restful import fields + from libs.helper import TimestampField diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 63e8f5b16a..e6c1272086 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -1,4 +1,5 @@ from flask_restful import fields + from libs.helper import TimestampField app_detail_kernel_fields = { diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index f7298933c1..4479dc5b7a 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -1,4 +1,5 @@ from flask_restful import fields + from libs.helper import TimestampField diff --git a/api/fields/data_source_fields.py b/api/fields/data_source_fields.py index a7035018df..6f3c920c85 100644 --- a/api/fields/data_source_fields.py +++ b/api/fields/data_source_fields.py @@ -1,4 +1,5 @@ from flask_restful import fields + from libs.helper import TimestampField integrate_icon_fields = { diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 1382871ae7..eb2ccb8f9f 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -1,4 +1,5 @@ from flask_restful import fields + from libs.helper import TimestampField dataset_fields = { diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index bf115659ef..94d905eafe 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -1,5 +1,6 @@ -from fields.dataset_fields import dataset_fields from flask_restful import fields + +from fields.dataset_fields import dataset_fields from libs.helper import TimestampField document_fields = { diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index 5f2322003a..2ef379dabc 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -1,4 +1,5 @@ from flask_restful import fields + from libs.helper import TimestampField upload_config_fields = { diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index bb8805417e..541e56a378 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -1,4 +1,5 @@ from flask_restful import fields + from libs.helper import TimestampField document_fields = { diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py index f1c2377f46..821d3c0ade 100644 --- a/api/fields/installed_app_fields.py +++ b/api/fields/installed_app_fields.py @@ -1,4 +1,5 @@ from flask_restful import fields + from libs.helper import TimestampField app_fields = { diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 59a029e3e2..21b2e8e9e2 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -1,5 +1,6 @@ -from fields.conversation_fields import message_file_fields from flask_restful import fields + +from fields.conversation_fields import message_file_fields from libs.helper import TimestampField feedback_fields = { diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index 8c07c09321..e41d1a53dd 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -1,4 +1,5 @@ from flask_restful import fields + from libs.helper import TimestampField segment_fields = { diff --git a/api/libs/login.py b/api/libs/login.py index 06c6a837af..5c03cfe957 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -1,14 +1,15 @@ import os from functools import wraps -from extensions.ext_database import db -from flask import current_app, g, has_request_context, request, session +from flask import current_app, g, has_request_context, request from flask_login import user_logged_in from flask_login.config import EXEMPT_METHODS -from models.account import Account, Tenant, TenantAccountJoin from werkzeug.exceptions import Unauthorized from werkzeug.local import LocalProxy +from extensions.ext_database import db +from models.account import Account, Tenant, TenantAccountJoin + #: A proxy for the current user. If no user is logged in, this will be an #: anonymous user current_user = LocalProxy(lambda: _get_user()) diff --git a/api/libs/oauth.py b/api/libs/oauth.py index 2a91d9941a..dacdee0bc1 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -1,11 +1,7 @@ -import json import urllib.parse from dataclasses import dataclass import requests -from extensions.ext_database import db -from flask_login import current_user -from models.source import DataSourceBinding @dataclass diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 1cf84e808a..7891b01182 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -1,9 +1,9 @@ -import json import urllib.parse import requests -from extensions.ext_database import db from flask_login import current_user + +from extensions.ext_database import db from models.source import DataSourceBinding diff --git a/api/libs/rsa.py b/api/libs/rsa.py index e2ffed8d05..9f499ff95c 100644 --- a/api/libs/rsa.py +++ b/api/libs/rsa.py @@ -1,10 +1,11 @@ # -*- coding:utf-8 -*- import hashlib -import libs.gmpy2_pkcs10aep_cipher as gmpy2_pkcs10aep_cipher from Crypto.Cipher import AES from Crypto.PublicKey import RSA from Crypto.Random import get_random_bytes + +import libs.gmpy2_pkcs10aep_cipher as gmpy2_pkcs10aep_cipher from extensions.ext_redis import redis_client from extensions.ext_storage import storage diff --git a/api/migrations/versions/16830a790f0f_.py b/api/migrations/versions/16830a790f0f_.py index fd1eaedf67..38d6e4940a 100644 --- a/api/migrations/versions/16830a790f0f_.py +++ b/api/migrations/versions/16830a790f0f_.py @@ -5,9 +5,8 @@ Revises: 380c6aa5a70d Create Date: 2024-02-01 08:21:31.111119 """ -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. revision = '16830a790f0f' diff --git a/api/migrations/versions/187385f442fc_modify_provider_model_name_length.py b/api/migrations/versions/187385f442fc_modify_provider_model_name_length.py index adff497e0b..13a823f7ec 100644 --- a/api/migrations/versions/187385f442fc_modify_provider_model_name_length.py +++ b/api/migrations/versions/187385f442fc_modify_provider_model_name_length.py @@ -7,7 +7,6 @@ Create Date: 2024-01-02 07:18:43.887428 """ import sqlalchemy as sa from alembic import op -from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = '187385f442fc' diff --git a/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py b/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py index 790ade64ec..2405021856 100644 --- a/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py +++ b/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py @@ -5,7 +5,6 @@ Revises: 114eed84c228 Create Date: 2024-01-12 03:42:27.362415 """ -import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import postgresql diff --git a/api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py b/api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py index dfb59839b4..5a8476501b 100644 --- a/api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py +++ b/api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py @@ -7,7 +7,6 @@ Create Date: 2023-08-19 17:01:57.471562 """ import sqlalchemy as sa from alembic import op -from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = '853f9b9cd3b6' diff --git a/api/migrations/versions/8ae9bc661daa_add_tool_conversation_variables_idx.py b/api/migrations/versions/8ae9bc661daa_add_tool_conversation_variables_idx.py index 430a8c78c2..f4c4ebb51b 100644 --- a/api/migrations/versions/8ae9bc661daa_add_tool_conversation_variables_idx.py +++ b/api/migrations/versions/8ae9bc661daa_add_tool_conversation_variables_idx.py @@ -5,7 +5,6 @@ Revises: 9fafbd60eca1 Create Date: 2024-01-15 14:22:03.597692 """ -import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. diff --git a/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py index cf296628a9..5dcb630aed 100644 --- a/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py +++ b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py @@ -7,7 +7,6 @@ Create Date: 2023-11-02 04:04:57.609485 """ import sqlalchemy as sa from alembic import op -from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = 'a9836e3baeee' diff --git a/api/migrations/versions/dfb3b7f477da_add_tool_index.py b/api/migrations/versions/dfb3b7f477da_add_tool_index.py index 3ef03595fe..e14a65a1ff 100644 --- a/api/migrations/versions/dfb3b7f477da_add_tool_index.py +++ b/api/migrations/versions/dfb3b7f477da_add_tool_index.py @@ -5,7 +5,6 @@ Revises: b24be59fbb04 Create Date: 2024-01-24 02:17:01.631635 """ -import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. diff --git a/api/models/account.py b/api/models/account.py index 21fc998185..55d8514220 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -2,10 +2,11 @@ import enum import json from typing import List -from extensions.ext_database import db from flask_login import UserMixin from sqlalchemy.dialects.postgresql import UUID +from extensions.ext_database import db + class AccountStatus(str, enum.Enum): PENDING = 'pending' diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 200675d766..e34cfb8f7b 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -1,8 +1,9 @@ import enum -from extensions.ext_database import db from sqlalchemy.dialects.postgresql import UUID +from extensions.ext_database import db + class APIBasedExtensionPoint(enum.Enum): APP_EXTERNAL_DATA_TOOL_QUERY = 'app.external_data_tool.query' diff --git a/api/models/dataset.py b/api/models/dataset.py index 06fd55aeca..d31e49f6ca 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -2,11 +2,12 @@ import json import pickle from json import JSONDecodeError +from sqlalchemy import func +from sqlalchemy.dialects.postgresql import JSONB, UUID + from extensions.ext_database import db from models.account import Account from models.model import App, UploadFile -from sqlalchemy import func -from sqlalchemy.dialects.postgresql import JSONB, UUID class Dataset(db.Model): diff --git a/api/models/model.py b/api/models/model.py index afe2bc1628..d642d9a397 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,14 +1,15 @@ import json import uuid +from flask import current_app, request +from flask_login import UserMixin +from sqlalchemy import Float, text +from sqlalchemy.dialects.postgresql import UUID + from core.file.tool_file_parser import ToolFileParser from core.file.upload_file_parser import UploadFileParser from extensions.ext_database import db -from flask import current_app, request -from flask_login import UserMixin from libs.helper import generate_string -from sqlalchemy import Float, text -from sqlalchemy.dialects.postgresql import UUID from .account import Account, Tenant diff --git a/api/models/provider.py b/api/models/provider.py index 514ea47cff..4c9fd793cc 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,8 +1,9 @@ from enum import Enum -from extensions.ext_database import db from sqlalchemy.dialects.postgresql import UUID +from extensions.ext_database import db + class ProviderType(Enum): CUSTOM = 'custom' diff --git a/api/models/source.py b/api/models/source.py index 9923e38b4b..8afe0f9522 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,6 +1,7 @@ -from extensions.ext_database import db from sqlalchemy.dialects.postgresql import JSONB, UUID +from extensions.ext_database import db + class DataSourceBinding(db.Model): __tablename__ = 'data_source_bindings' diff --git a/api/models/task.py b/api/models/task.py index fd4105b2ea..2a1bfa124f 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -1,6 +1,7 @@ from datetime import datetime from celery import states + from extensions.ext_database import db diff --git a/api/models/tool.py b/api/models/tool.py index 0e25659980..ac866e20a4 100644 --- a/api/models/tool.py +++ b/api/models/tool.py @@ -1,9 +1,10 @@ import json from enum import Enum -from extensions.ext_database import db from sqlalchemy.dialects.postgresql import UUID +from extensions.ext_database import db + class ToolProviderName(Enum): SERPAPI = 'serpapi' diff --git a/api/models/tools.py b/api/models/tools.py index 6b074786a2..10f572cdd0 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,14 +1,14 @@ import json -from enum import Enum from typing import List +from sqlalchemy import ForeignKey +from sqlalchemy.dialects.postgresql import UUID + from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiBasedToolBundle -from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolRuntimeVariablePool +from core.tools.entities.tool_entities import ApiProviderSchemaType from extensions.ext_database import db from models.model import Account, App, Tenant -from sqlalchemy import ForeignKey -from sqlalchemy.dialects.postgresql import UUID class BuiltinToolProvider(db.Model): diff --git a/api/models/web.py b/api/models/web.py index 2957703a20..b2466430b9 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,6 +1,7 @@ +from sqlalchemy.dialects.postgresql import UUID + from extensions.ext_database import db from models.model import Message -from sqlalchemy.dialects.postgresql import UUID class SavedMessage(db.Model): diff --git a/api/pyproject.toml b/api/pyproject.toml new file mode 100644 index 0000000000..2061092ef5 --- /dev/null +++ b/api/pyproject.toml @@ -0,0 +1,17 @@ +[project] +requires-python = ">=3.10" + +[tool.ruff] +exclude = [ + "__init__.py", + "tests/", +] +line-length = 120 + +[tool.ruff.lint] +ignore-init-module-imports = true +select = [ + "F401", # unused-import + "I001", # unsorted-imports + "I002", # missing-required-import +] diff --git a/api/schedule/clean_embedding_cache_task.py b/api/schedule/clean_embedding_cache_task.py index 53d9500bab..0daf651d2f 100644 --- a/api/schedule/clean_embedding_cache_task.py +++ b/api/schedule/clean_embedding_cache_task.py @@ -1,13 +1,14 @@ import datetime import time -import app import click -from extensions.ext_database import db from flask import current_app -from models.dataset import Embedding from werkzeug.exceptions import NotFound +import app +from extensions.ext_database import db +from models.dataset import Embedding + @app.celery.task(queue='dataset') def clean_embedding_cache_task(): diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index f5ba46463f..5db863fe8d 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -1,14 +1,14 @@ import datetime -import logging import time -import app import click +from flask import current_app +from werkzeug.exceptions import NotFound + +import app from core.index.index import IndexBuilder from extensions.ext_database import db -from flask import current_app -from models.dataset import Dataset, DatasetCollectionBinding, DatasetQuery, Document -from werkzeug.exceptions import NotFound +from models.dataset import Dataset, DatasetQuery, Document @app.celery.task(queue='dataset') diff --git a/api/services/account_service.py b/api/services/account_service.py index 0999d700b6..b5934aafb9 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -8,23 +8,33 @@ from datetime import datetime, timedelta from hashlib import sha256 from typing import Any, Dict, Optional +from flask import current_app +from sqlalchemy import func +from werkzeug.exceptions import Forbidden + from constants.languages import language_timezone_mapping, languages from events.tenant_event import tenant_was_created from extensions.ext_redis import redis_client -from flask import current_app from libs.helper import get_remote_ip from libs.passport import PassportService from libs.password import compare_password, hash_password from libs.rsa import generate_key_pair from models.account import * -from services.errors.account import (AccountAlreadyInTenantError, AccountLoginError, AccountNotLinkTenantError, - AccountRegisterError, CannotOperateSelfError, CurrentPasswordIncorrectError, - InvalidActionError, LinkAccountIntegrateError, MemberNotInTenantError, - NoPermissionError, RoleAlreadyAssignedError, TenantNotFound) -from sqlalchemy import func +from services.errors.account import ( + AccountAlreadyInTenantError, + AccountLoginError, + AccountNotLinkTenantError, + AccountRegisterError, + CannotOperateSelfError, + CurrentPasswordIncorrectError, + InvalidActionError, + LinkAccountIntegrateError, + MemberNotInTenantError, + NoPermissionError, + RoleAlreadyAssignedError, + TenantNotFound, +) from tasks.mail_invite_member_task import send_invite_member_mail_task -from werkzeug.exceptions import Forbidden -from sqlalchemy import exc def _create_tenant_for_account(account) -> Tenant: diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index d3cd911125..d52f6e20c2 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -1,13 +1,18 @@ import copy -from core.prompt.advanced_prompt_templates import (BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, - BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, - BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, - BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CONTEXT, - CHAT_APP_CHAT_PROMPT_CONFIG, CHAT_APP_COMPLETION_PROMPT_CONFIG, - COMPLETION_APP_CHAT_PROMPT_CONFIG, - COMPLETION_APP_COMPLETION_PROMPT_CONFIG, CONTEXT) +from core.prompt.advanced_prompt_templates import ( + BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, + BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, + BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, + BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, + BAICHUAN_CONTEXT, + CHAT_APP_CHAT_PROMPT_CONFIG, + CHAT_APP_COMPLETION_PROMPT_CONFIG, + COMPLETION_APP_CHAT_PROMPT_CONFIG, + COMPLETION_APP_COMPLETION_PROMPT_CONFIG, + CONTEXT, +) from core.prompt.prompt_transform import AppMode diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 1ffe44910b..0a9e835586 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -1,21 +1,21 @@ import datetime -import json import uuid import pandas as pd +from flask_login import current_user +from sqlalchemy import or_ +from werkzeug.datastructures import FileStorage +from werkzeug.exceptions import NotFound + from extensions.ext_database import db from extensions.ext_redis import redis_client -from flask_login import current_user from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation -from sqlalchemy import or_ from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task from tasks.annotation.delete_annotation_index_task import delete_annotation_index_task from tasks.annotation.disable_annotation_reply_task import disable_annotation_reply_task from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task -from werkzeug.datastructures import FileStorage -from werkzeug.exceptions import NotFound class AppAnnotationService: diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 4cc965bacf..ba6bf1ab6f 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -1,12 +1,17 @@ import io from typing import Optional +from werkzeug.datastructures import FileStorage + from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError, - ProviderNotSupportSpeechToTextServiceError, - ProviderNotSupportTextToSpeechServiceError, UnsupportedAudioTypeServiceError) -from werkzeug.datastructures import FileStorage +from services.errors.audio import ( + AudioTooLargeServiceError, + NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, + ProviderNotSupportTextToSpeechServiceError, + UnsupportedAudioTypeServiceError, +) FILE_SIZE = 15 FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024 diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 00a460b35b..37d362d083 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -1,6 +1,7 @@ import os import requests + from extensions.ext_database import db from models.account import TenantAccountJoin diff --git a/api/services/completion_service.py b/api/services/completion_service.py index 6035eb1b50..7d925ec9b3 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -1,6 +1,8 @@ import json from typing import Any, Generator, Union +from sqlalchemy import and_ + from core.application_manager import ApplicationManager from core.entities.application_entities import InvokeFrom from core.file.message_file_parser import MessageFileParser @@ -11,7 +13,6 @@ from services.errors.app import MoreLikeThisDisabledError from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError from services.errors.message import MessageNotExistsError -from sqlalchemy import and_ class CompletionService: diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 7cdded1ae8..8e587f07a7 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -6,6 +6,10 @@ import time import uuid from typing import List, Optional, cast +from flask import current_app +from flask_login import current_user +from sqlalchemy import func + from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.index.index import IndexBuilder from core.model_manager import ModelManager @@ -15,12 +19,17 @@ from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted from extensions.ext_database import db from extensions.ext_redis import redis_client -from flask import current_app -from flask_login import current_user from libs import helper from models.account import Account -from models.dataset import (AppDatasetJoin, Dataset, DatasetCollectionBinding, DatasetProcessRule, DatasetQuery, - Document, DocumentSegment) +from models.dataset import ( + AppDatasetJoin, + Dataset, + DatasetCollectionBinding, + DatasetProcessRule, + DatasetQuery, + Document, + DocumentSegment, +) from models.model import UploadFile from models.source import DataSourceBinding from services.errors.account import NoPermissionError @@ -28,7 +37,6 @@ from services.errors.dataset import DatasetNameDuplicateError from services.errors.document import DocumentIndexingError from services.errors.file import FileNotExistsError from services.vector_service import VectorService -from sqlalchemy import func from tasks.clean_notion_document_task import clean_notion_document_task from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task from tasks.delete_segment_from_index_task import delete_segment_from_index_task diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 58f00135dc..6cdd5090ae 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -1,16 +1,21 @@ from enum import Enum from typing import Optional +from flask import current_app +from pydantic import BaseModel + from core.entities.model_entities import ModelStatus, ModelWithProviderEntity from core.entities.provider_entities import QuotaConfiguration from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ModelType, ProviderModel -from core.model_runtime.entities.provider_entities import (ConfigurateMethod, ModelCredentialSchema, - ProviderCredentialSchema, ProviderHelpEntity, - SimpleProviderEntity) -from flask import current_app +from core.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderHelpEntity, + SimpleProviderEntity, +) from models.provider import ProviderQuotaType, ProviderType -from pydantic import BaseModel class CustomConfigurationStatus(Enum): diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 75feaf7800..14d262de7c 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -1,5 +1,6 @@ from flask import current_app from pydantic import BaseModel + from services.billing_service import BillingService diff --git a/api/services/file_service.py b/api/services/file_service.py index fd95cff74a..b796fbfb51 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -3,17 +3,18 @@ import hashlib import uuid from typing import Generator, Tuple, Union +from flask import current_app +from flask_login import current_user +from werkzeug.datastructures import FileStorage +from werkzeug.exceptions import NotFound + from core.data_loader.file_extractor import FileExtractor from core.file.upload_file_parser import UploadFileParser from extensions.ext_database import db from extensions.ext_storage import storage -from flask import current_app -from flask_login import current_user from models.account import Account from models.model import EndUser, UploadFile from services.errors.file import FileTooLargeError, UnsupportedFileTypeError -from werkzeug.datastructures import FileStorage -from werkzeug.exceptions import NotFound IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 0f38241948..7d2d57476d 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -4,18 +4,19 @@ import time from typing import List import numpy as np +from flask import current_app +from langchain.embeddings.base import Embeddings +from langchain.schema import Document +from sklearn.manifold import TSNE + from core.embedding.cached_embedding import CacheEmbedding from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rerank.rerank import RerankRunner from extensions.ext_database import db -from flask import current_app -from langchain.embeddings.base import Embeddings -from langchain.schema import Document from models.account import Account from models.dataset import Dataset, DatasetQuery, DocumentSegment from services.retrieval_service import RetrievalService -from sklearn.manifold import TSNE default_retrieval_model = { 'search_method': 'semantic_search', diff --git a/api/services/message_service.py b/api/services/message_service.py index 79feb1c669..fa34057d07 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -12,8 +12,12 @@ from models.model import App, AppModelConfig, EndUser, Message, MessageFeedback from services.conversation_service import ConversationService from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError -from services.errors.message import (FirstMessageNotExistsError, LastMessageNotExistsError, MessageNotExistsError, - SuggestedQuestionsAfterAnswerDisabledError) +from services.errors.message import ( + FirstMessageNotExistsError, + LastMessageNotExistsError, + MessageNotExistsError, + SuggestedQuestionsAfterAnswerDisabledError, +) class MessageService: diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 906ef57a39..7b2b9049f7 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -4,18 +4,25 @@ import os from typing import Optional, Tuple, cast import requests +from flask import current_app + from core.entities.model_entities import ModelStatus from core.model_runtime.entities.model_entities import ModelType, ParameterRule from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.provider_manager import ProviderManager -from flask import current_app from models.provider import ProviderType -from services.entities.model_provider_entities import (CustomConfigurationResponse, CustomConfigurationStatus, - DefaultModelResponse, ModelResponse, - ModelWithProviderEntityResponse, ProviderResponse, - ProviderWithModelsResponse, SimpleProviderEntityResponse, - SystemConfigurationResponse) +from services.entities.model_provider_entities import ( + CustomConfigurationResponse, + CustomConfigurationStatus, + DefaultModelResponse, + ModelResponse, + ModelWithProviderEntityResponse, + ProviderResponse, + ProviderWithModelsResponse, + SimpleProviderEntityResponse, + SystemConfigurationResponse, +) logger = logging.getLogger(__name__) diff --git a/api/services/retrieval_service.py b/api/services/retrieval_service.py index 2efa0dbee4..bc8f4ad5be 100644 --- a/api/services/retrieval_service.py +++ b/api/services/retrieval_service.py @@ -1,13 +1,14 @@ from typing import Optional +from flask import Flask, current_app +from langchain.embeddings.base import Embeddings + from core.index.vector_index.vector_index import VectorIndex from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rerank.rerank import RerankRunner from extensions.ext_database import db -from flask import Flask, current_app -from langchain.embeddings.base import Embeddings from models.dataset import Dataset default_retrieval_model = { diff --git a/api/services/tools_manage_service.py b/api/services/tools_manage_service.py index ac21d1d9d6..e5f800f53e 100644 --- a/api/services/tools_manage_service.py +++ b/api/services/tools_manage_service.py @@ -1,10 +1,17 @@ import json -from typing import List, Tuple +from typing import List + +from flask import current_app +from httpx import get from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiBasedToolBundle -from core.tools.entities.tool_entities import (ApiProviderAuthType, ApiProviderSchemaType, ToolCredentialsOption, - ToolProviderCredentials) +from core.tools.entities.tool_entities import ( + ApiProviderAuthType, + ApiProviderSchemaType, + ToolCredentialsOption, + ToolProviderCredentials, +) from core.tools.entities.user_entities import UserTool, UserToolProvider from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError from core.tools.provider.api_tool_provider import ApiBasedToolProviderController @@ -14,8 +21,6 @@ from core.tools.utils.configuration import ToolConfiguration from core.tools.utils.encoder import serialize_base_model_array, serialize_base_model_dict from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db -from flask import current_app -from httpx import get from models.tools import ApiToolProvider, BuiltinToolProvider diff --git a/api/services/vector_service.py b/api/services/vector_service.py index ee06bd175a..9c30c3f41d 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,8 +1,9 @@ from typing import List, Optional -from core.index.index import IndexBuilder from langchain.schema import Document + +from core.index.index import IndexBuilder from models.dataset import Dataset, DocumentSegment diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 1bdf9d8631..923e44dd85 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -1,8 +1,7 @@ -from extensions.ext_database import db -from flask import current_app from flask_login import current_user + +from extensions.ext_database import db from models.account import Tenant, TenantAccountJoin, TenantAccountJoinRole -from models.provider import Provider from services.account_service import TenantService from services.feature_service import FeatureService diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index ea5c17b487..ae235a2a63 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -4,13 +4,14 @@ import time import click from celery import shared_task +from langchain.schema import Document +from werkzeug.exceptions import NotFound + from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client -from langchain.schema import Document from models.dataset import Document as DatasetDocument from models.dataset import DocumentSegment -from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index 4c0e13feb3..61529f9bde 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -3,8 +3,9 @@ import time import click from celery import shared_task -from core.index.index import IndexBuilder from langchain.schema import Document + +from core.index.index import IndexBuilder from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index a7d026b15e..5b6c45b4f3 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -1,17 +1,17 @@ -import json import logging import time import click from celery import shared_task +from langchain.schema import Document +from werkzeug.exceptions import NotFound + from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client -from langchain.schema import Document from models.dataset import Dataset from models.model import App, AppAnnotationSetting, MessageAnnotation from services.dataset_service import DatasetCollectionBindingService -from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index f4afb2383f..852f899512 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -1,9 +1,9 @@ -import datetime import logging import time import click from celery import shared_task + from core.index.index import IndexBuilder from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index 6cb51eecdb..c5f028c72d 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -1,15 +1,15 @@ -import datetime import logging import time import click from celery import shared_task +from werkzeug.exceptions import NotFound + from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset -from models.model import App, AppAnnotationSetting, MessageAnnotation -from werkzeug.exceptions import NotFound +from models.model import App, AppAnnotationSetting @shared_task(queue='dataset') diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 42c3b23836..a125dd5717 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -4,14 +4,15 @@ import time import click from celery import shared_task +from langchain.schema import Document +from werkzeug.exceptions import NotFound + from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client -from langchain.schema import Document from models.dataset import Dataset from models.model import App, AppAnnotationSetting, MessageAnnotation from services.dataset_service import DatasetCollectionBindingService -from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index c1ca161c9a..e632c3a24e 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -3,8 +3,9 @@ import time import click from celery import shared_task -from core.index.index import IndexBuilder from langchain.schema import Document + +from core.index.index import IndexBuilder from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 1d5d966098..4c79b537d8 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -6,6 +6,8 @@ from typing import List, cast import click from celery import shared_task +from sqlalchemy import func + from core.indexing_runner import IndexingRunner from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType @@ -14,7 +16,6 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from libs import helper from models.dataset import Dataset, Document, DocumentSegment -from sqlalchemy import func @shared_task(queue='dataset') diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 5813f38706..74ebcea15f 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -3,12 +3,17 @@ import time import click from celery import shared_task + from core.index.index import IndexBuilder -from core.index.vector_index.vector_index import VectorIndex from extensions.ext_database import db -from flask import current_app -from models.dataset import (AppDatasetJoin, Dataset, DatasetKeywordTable, DatasetProcessRule, DatasetQuery, Document, - DocumentSegment) +from models.dataset import ( + AppDatasetJoin, + Dataset, + DatasetProcessRule, + DatasetQuery, + Document, + DocumentSegment, +) @shared_task(queue='dataset') diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 1750eb80aa..76eb1a572c 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task + from core.index.index import IndexBuilder from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index 46b066d7ba..536f766af3 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -4,6 +4,7 @@ from typing import List import click from celery import shared_task + from core.index.index import IndexBuilder from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index 23e599cf03..8ba09de2b2 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -5,12 +5,13 @@ from typing import List, Optional import click from celery import shared_task +from langchain.schema import Document +from werkzeug.exceptions import NotFound + from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client -from langchain.schema import Document from models.dataset import DocumentSegment -from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index d8a3b501ef..008f122a82 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -3,12 +3,12 @@ import time import click from celery import shared_task +from langchain.schema import Document + from core.index.index import IndexBuilder from extensions.ext_database import db -from langchain.schema import Document -from models.dataset import Dataset +from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument -from models.dataset import DocumentSegment @shared_task(queue='dataset') diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index 75776b2aa2..9c9b00a2f5 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -3,11 +3,11 @@ import time import click from celery import shared_task + from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.dataset import Dataset, Document, DocumentSegment -from werkzeug.exceptions import NotFound +from models.dataset import Dataset, Document @shared_task(queue='dataset') diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index 57c94c7fa1..97f4fd0677 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -3,11 +3,12 @@ import time import click from celery import shared_task +from werkzeug.exceptions import NotFound + from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import DocumentSegment -from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 319e8ddb0d..57f080e3ff 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -4,13 +4,14 @@ import time import click from celery import shared_task +from werkzeug.exceptions import NotFound + from core.data_loader.loader.notion import NotionLoader from core.index.index import IndexBuilder from core.indexing_runner import DocumentIsPausedException, IndexingRunner from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment from models.source import DataSourceBinding -from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 2ea6288059..87081e19e3 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -4,10 +4,10 @@ import time import click from celery import shared_task + from core.indexing_runner import DocumentIsPausedException, IndexingRunner from extensions.ext_database import db from models.dataset import Document -from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 54449662c3..12014799b0 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -4,11 +4,12 @@ import time import click from celery import shared_task +from werkzeug.exceptions import NotFound + from core.index.index import IndexBuilder from core.indexing_runner import DocumentIsPausedException, IndexingRunner from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index ce450563d2..8dffd01520 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -4,12 +4,13 @@ import time import click from celery import shared_task +from langchain.schema import Document +from werkzeug.exceptions import NotFound + from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client -from langchain.schema import Document from models.dataset import DocumentSegment -from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py index a7df7f9298..7d134fc34f 100644 --- a/api/tasks/mail_invite_member_task.py +++ b/api/tasks/mail_invite_member_task.py @@ -3,10 +3,10 @@ import time import click from celery import shared_task -from constants.languages import languages -from extensions.ext_mail import mail from flask import current_app, render_template +from extensions.ext_mail import mail + @shared_task(queue='mail') def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str): diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index e1ed87a395..02278f512b 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -3,10 +3,11 @@ import time import click from celery import shared_task +from werkzeug.exceptions import NotFound + from core.indexing_runner import DocumentIsPausedException, IndexingRunner from extensions.ext_database import db from models.dataset import Document -from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index 6bb6e96261..a18842a59a 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -3,11 +3,12 @@ import time import click from celery import shared_task +from werkzeug.exceptions import NotFound + from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Document, DocumentSegment -from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/update_segment_index_task.py b/api/tasks/update_segment_index_task.py index 1f6592a3e8..40089ad3e4 100644 --- a/api/tasks/update_segment_index_task.py +++ b/api/tasks/update_segment_index_task.py @@ -5,12 +5,13 @@ from typing import List, Optional import click from celery import shared_task +from langchain.schema import Document +from werkzeug.exceptions import NotFound + from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client -from langchain.schema import Document from models.dataset import DocumentSegment -from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/api/tasks/update_segment_keyword_index_task.py b/api/tasks/update_segment_keyword_index_task.py index 8ae4b64137..ee88beba98 100644 --- a/api/tasks/update_segment_keyword_index_task.py +++ b/api/tasks/update_segment_keyword_index_task.py @@ -1,16 +1,15 @@ import datetime import logging import time -from typing import List, Optional import click from celery import shared_task +from werkzeug.exceptions import NotFound + from core.index.index import IndexBuilder from extensions.ext_database import db from extensions.ext_redis import redis_client -from langchain.schema import Document from models.dataset import DocumentSegment -from werkzeug.exceptions import NotFound @shared_task(queue='dataset') diff --git a/dev/reformat b/dev/reformat index 8e0baf5e11..864f9b4b02 100755 --- a/dev/reformat +++ b/dev/reformat @@ -2,10 +2,11 @@ set -x -# python style checks rely on `isort` in path -if ! command -v isort &> /dev/null -then - echo "Skip Python imports linting, since 'isort' is not available. Please install it with 'pip install isort'." -else - isort --settings ./.github/linters/.isort.cfg ./ +# python style checks rely on `ruff` in path +if ! command -v ruff &> /dev/null; then + echo "Installing Ruff ..." + pip install ruff fi + +# run ruff linter +ruff check --fix ./api diff --git a/web/.husky/pre-commit b/web/.husky/pre-commit index 7a8fabf9e8..dfd6ec0209 100755 --- a/web/.husky/pre-commit +++ b/web/.husky/pre-commit @@ -1,4 +1,35 @@ -#!/usr/bin/env sh +#!/usr/bin/env bash . "$(dirname -- "$0")/_/husky.sh" -cd ./web && npx lint-staged +# get the list of modified files +files=$(git diff --cached --name-only) + +# check if api or web directory is modified + +api_modified=false +web_modified=false + +for file in $files +do + if [[ $file == "api/"* && $file == *.py ]]; then + # set api_modified flag to true + api_modified=true + elif [[ $file == "web/"* ]]; then + # set web_modified flag to true + web_modified=true + fi +done + +# run linters based on the modified modules + +if $api_modified; then + echo "Running Ruff linter on api module" + ./dev/reformat +fi + +if $web_modified; then + echo "Running ESLint on web module" + cd ./web || exit 1 + npx lint-staged + cd ../ +fi