From 2cf1187b32c0885e9f9b987d57c6ca9da4508a2c Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Tue, 10 Sep 2024 17:00:20 +0800 Subject: [PATCH] chore(api/core): apply ruff reformatting (#7624) --- api/core/__init__.py | 2 +- api/core/agent/cot_agent_runner.py | 213 ++- api/core/agent/cot_chat_agent_runner.py | 26 +- api/core/agent/cot_completion_agent_runner.py | 21 +- api/core/agent/entities.py | 16 +- api/core/agent/fc_agent_runner.py | 235 ++- .../agent/output_parser/cot_output_parser.py | 74 +- api/core/agent/prompt/template.py | 18 +- .../app/app_config/base_app_config_manager.py | 26 +- .../sensitive_word_avoidance/manager.py | 23 +- .../easy_ui_based_app/agent/manager.py | 63 +- .../easy_ui_based_app/dataset/manager.py | 88 +- .../model_config/converter.py | 30 +- .../easy_ui_based_app/model_config/manager.py | 33 +- .../prompt_template/manager.py | 57 +- .../easy_ui_based_app/variables/manager.py | 52 +- api/core/app/app_config/entities.py | 46 +- .../features/file_upload/manager.py | 28 +- .../features/more_like_this/manager.py | 8 +- .../features/opening_statement/manager.py | 6 +- .../features/retrieval_resource/manager.py | 8 +- .../features/speech_to_text/manager.py | 8 +- .../manager.py | 14 +- .../features/text_to_speech/manager.py | 16 +- .../apps/advanced_chat/app_config_manager.py | 26 +- .../app/apps/advanced_chat/app_generator.py | 166 +- .../app_generator_tts_publisher.py | 37 +- api/core/app/apps/advanced_chat/app_runner.py | 84 +- .../generate_response_converter.py | 54 +- .../advanced_chat/generate_task_pipeline.py | 167 +- .../app/apps/agent_chat/app_config_manager.py | 59 +- api/core/app/apps/agent_chat/app_generator.py | 102 +- api/core/app/apps/agent_chat/app_runner.py | 91 +- .../agent_chat/generate_response_converter.py | 56 +- .../base_app_generate_response_converter.py | 85 +- api/core/app/apps/base_app_generator.py | 12 +- api/core/app/apps/base_app_queue_manager.py | 30 +- api/core/app/apps/base_app_runner.py | 272 ++- api/core/app/apps/chat/app_config_manager.py | 45 +- api/core/app/apps/chat/app_generator.py | 106 +- api/core/app/apps/chat/app_runner.py | 51 +- .../apps/chat/generate_response_converter.py | 56 +- .../app/apps/completion/app_config_manager.py | 35 +- api/core/app/apps/completion/app_generator.py | 175 +- api/core/app/apps/completion/app_runner.py | 36 +- .../completion/generate_response_converter.py | 50 +- .../app/apps/message_based_app_generator.py | 104 +- .../apps/message_based_app_queue_manager.py | 21 +- .../app/apps/workflow/app_config_manager.py | 18 +- api/core/app/apps/workflow/app_generator.py | 129 +- .../app/apps/workflow/app_queue_manager.py | 24 +- api/core/app/apps/workflow/app_runner.py | 23 +- .../workflow/generate_response_converter.py | 22 +- .../apps/workflow/generate_task_pipeline.py | 133 +- api/core/app/apps/workflow_app_runner.py | 154 +- .../app/apps/workflow_logging_callback.py | 198 +-- api/core/app/entities/app_invoke_entities.py | 34 +- api/core/app/entities/queue_entities.py | 54 +- api/core/app/entities/task_entities.py | 40 +- .../annotation_reply/annotation_reply.py | 60 +- .../hosting_moderation/hosting_moderation.py | 10 +- .../app/features/rate_limiting/rate_limit.py | 23 +- api/core/app/segments/__init__.py | 42 +- api/core/app/segments/factory.py | 20 +- api/core/app/segments/parser.py | 4 +- api/core/app/segments/segment_group.py | 6 +- api/core/app/segments/segments.py | 24 +- api/core/app/segments/types.py | 20 +- api/core/app/segments/variables.py | 5 +- .../based_generate_task_pipeline.py | 42 +- .../easy_ui_based_generate_task_pipeline.py | 175 +- .../app/task_pipeline/message_cycle_manage.py | 75 +- .../task_pipeline/workflow_cycle_manage.py | 117 +- .../agent_tool_callback_handler.py | 38 +- .../index_tool_callback_handler.py | 62 +- .../workflow_tool_callback_handler.py | 2 +- api/core/embedding/cached_embedding.py | 50 +- api/core/entities/agent_entities.py | 8 +- api/core/entities/message_entities.py | 6 +- api/core/entities/model_entities.py | 8 +- api/core/entities/provider_configuration.py | 368 +++-- api/core/entities/provider_entities.py | 20 +- api/core/errors/error.py | 7 + .../api_based_extension_requestor.py | 25 +- api/core/extension/extensible.py | 46 +- api/core/extension/extension.py | 5 +- api/core/external_data_tool/api/api.py | 68 +- .../external_data_tool/external_data_fetch.py | 41 +- api/core/external_data_tool/factory.py | 6 +- api/core/file/file_obj.py | 59 +- api/core/file/message_file_parser.py | 100 +- api/core/file/tool_file_parser.py | 9 +- api/core/file/upload_file_parser.py | 12 +- .../helper/code_executor/code_executor.py | 72 +- .../code_executor/code_node_provider.py | 20 +- .../javascript/javascript_code_provider.py | 3 +- .../javascript/javascript_transformer.py | 3 +- .../code_executor/jinja2/jinja2_formatter.py | 6 +- .../jinja2/jinja2_transformer.py | 4 +- .../python3/python3_code_provider.py | 3 +- .../code_executor/template_transformer.py | 14 +- api/core/helper/encrypter.py | 7 +- api/core/helper/model_provider_cache.py | 2 +- api/core/helper/moderation.py | 21 +- api/core/helper/module_import_helper.py | 9 +- api/core/helper/position_helper.py | 22 +- api/core/helper/ssrf_proxy.py | 33 +- api/core/helper/tool_parameter_cache.py | 15 +- api/core/helper/tool_provider_cache.py | 5 +- api/core/hosting_configuration.py | 106 +- api/core/indexing_runner.py | 474 +++--- api/core/llm_generator/llm_generator.py | 99 +- .../output_parser/rule_config_generator.py | 20 +- .../suggested_questions_after_answer.py | 3 +- api/core/llm_generator/prompts.py | 20 +- api/core/memory/token_buffer_memory.py | 64 +- .../model_runtime/callbacks/base_callback.py | 72 +- .../callbacks/logging_callback.py | 114 +- .../model_runtime/entities/common_entities.py | 1 + api/core/model_runtime/entities/defaults.py | 176 +- .../model_runtime/entities/llm_entities.py | 35 +- .../entities/message_entities.py | 31 +- .../model_runtime/entities/model_entities.py | 47 +- .../entities/provider_entities.py | 17 +- .../model_runtime/entities/rerank_entities.py | 2 + .../entities/text_embedding_entities.py | 3 +- api/core/model_runtime/errors/invoke.py | 6 + api/core/model_runtime/errors/validate.py | 1 + .../model_providers/__base/ai_model.py | 79 +- .../__base/large_language_model.py | 260 +-- .../model_providers/__base/model_provider.py | 24 +- .../__base/moderation_model.py | 10 +- .../model_providers/__base/rerank_model.py | 29 +- .../__base/speech2text_model.py | 11 +- .../model_providers/__base/text2img_model.py | 13 +- .../__base/text_embedding_model.py | 13 +- .../__base/tokenizers/gpt2_tokenzier.py | 11 +- .../model_providers/__base/tts_model.py | 38 +- .../model_providers/anthropic/anthropic.py | 7 +- .../model_providers/anthropic/llm/llm.py | 304 ++-- .../model_providers/azure_openai/_common.py | 28 +- .../model_providers/azure_openai/_constant.py | 905 +++++----- .../azure_openai/azure_openai.py | 1 - .../model_providers/azure_openai/llm/llm.py | 261 ++- .../azure_openai/speech2text/speech2text.py | 9 +- .../text_embedding/text_embedding.py | 86 +- .../model_providers/azure_openai/tts/tts.py | 44 +- .../model_providers/baichuan/baichuan.py | 8 +- .../baichuan/llm/baichuan_tokenizer.py | 8 +- .../baichuan/llm/baichuan_turbo.py | 5 +- .../baichuan/llm/baichuan_turbo_errors.py | 7 +- .../model_providers/baichuan/llm/llm.py | 93 +- .../baichuan/text_embedding/text_embedding.py | 111 +- .../model_providers/bedrock/bedrock.py | 10 +- .../model_providers/bedrock/llm/llm.py | 528 +++--- .../bedrock/text_embedding/text_embedding.py | 113 +- .../model_providers/chatglm/chatglm.py | 7 +- .../model_providers/chatglm/llm/llm.py | 219 +-- .../model_providers/cohere/cohere.py | 7 +- .../model_providers/cohere/llm/llm.py | 312 ++-- .../model_providers/cohere/rerank/rerank.py | 46 +- .../cohere/text_embedding/text_embedding.py | 86 +- .../model_providers/deepseek/deepseek.py | 9 +- .../model_providers/deepseek/llm/llm.py | 45 +- .../model_providers/fishaudio/__init__.py | 1 - .../model_providers/fishaudio/fishaudio.py | 8 +- .../model_providers/fishaudio/tts/tts.py | 34 +- .../model_providers/google/google.py | 7 +- .../model_providers/google/llm/llm.py | 206 ++- .../model_providers/groq/groq.py | 9 +- .../model_providers/groq/llm/llm.py | 21 +- .../huggingface_hub/_common.py | 8 +- .../huggingface_hub/huggingface_hub.py | 1 - .../huggingface_hub/llm/llm.py | 187 +-- .../text_embedding/text_embedding.py | 109 +- .../huggingface_tei/huggingface_tei.py | 1 - .../huggingface_tei/rerank/rerank.py | 32 +- .../huggingface_tei/tei_helper.py | 54 +- .../text_embedding/text_embedding.py | 44 +- .../model_providers/hunyuan/hunyuan.py | 9 +- .../model_providers/hunyuan/llm/llm.py | 180 +- .../hunyuan/text_embedding/text_embedding.py | 51 +- .../model_providers/jina/jina.py | 8 +- .../model_providers/jina/rerank/rerank.py | 59 +- .../jina/text_embedding/jina_tokenizer.py | 8 +- .../jina/text_embedding/text_embedding.py | 92 +- .../model_providers/leptonai/leptonai.py | 9 +- .../model_providers/leptonai/llm/llm.py | 36 +- .../model_providers/localai/llm/llm.py | 353 ++-- .../model_providers/localai/localai.py | 3 +- .../model_providers/localai/rerank/rerank.py | 72 +- .../localai/speech2text/speech2text.py | 38 +- .../localai/text_embedding/text_embedding.py | 96 +- .../minimax/llm/chat_completion.py | 152 +- .../minimax/llm/chat_completion_pro.py | 159 +- .../model_providers/minimax/llm/errors.py | 7 +- .../model_providers/minimax/llm/llm.py | 202 +-- .../model_providers/minimax/llm/types.py | 31 +- .../model_providers/minimax/minimax.py | 10 +- .../minimax/text_embedding/text_embedding.py | 83 +- .../model_providers/mistralai/llm/llm.py | 23 +- .../model_providers/mistralai/mistralai.py | 8 +- .../model_providers/moonshot/llm/llm.py | 169 +- .../model_providers/moonshot/moonshot.py | 8 +- .../model_providers/novita/llm/llm.py | 52 +- .../model_providers/novita/novita.py | 7 +- .../model_providers/nvidia/llm/llm.py | 214 ++- .../model_providers/nvidia/nvidia.py | 8 +- .../model_providers/nvidia/rerank/rerank.py | 23 +- .../nvidia/text_embedding/text_embedding.py | 81 +- .../model_providers/nvidia_nim/llm/llm.py | 1 + .../model_providers/nvidia_nim/nvidia_nim.py | 1 - .../model_providers/oci/llm/llm.py | 189 ++- .../model_runtime/model_providers/oci/oci.py | 10 +- .../oci/text_embedding/text_embedding.py | 90 +- .../model_providers/ollama/llm/llm.py | 100 +- .../model_providers/ollama/ollama.py | 1 - .../ollama/text_embedding/text_embedding.py | 79 +- .../model_providers/openai/_common.py | 6 +- .../model_providers/openai/llm/llm.py | 442 ++--- .../openai/moderation/moderation.py | 10 +- .../model_providers/openai/openai.py | 8 +- .../openai/speech2text/speech2text.py | 6 +- .../openai/text_embedding/text_embedding.py | 65 +- .../model_providers/openai/tts/tts.py | 47 +- .../openai_api_compatible/_common.py | 9 +- .../openai_api_compatible/llm/llm.py | 445 +++-- .../openai_api_compatible.py | 1 - .../speech2text/speech2text.py | 4 +- .../text_embedding/text_embedding.py | 118 +- .../model_providers/openllm/llm/llm.py | 194 +-- .../openllm/llm/openllm_generate.py | 121 +- .../openllm/llm/openllm_generate_errors.py | 7 +- .../openllm/text_embedding/text_embedding.py | 67 +- .../model_providers/openrouter/llm/llm.py | 46 +- .../model_providers/openrouter/openrouter.py | 10 +- .../model_providers/perfxcloud/llm/llm.py | 42 +- .../model_providers/perfxcloud/perfxcloud.py | 8 +- .../text_embedding/text_embedding.py | 126 +- .../model_providers/replicate/_common.py | 8 +- .../model_providers/replicate/llm/llm.py | 172 +- .../model_providers/replicate/replicate.py | 1 - .../text_embedding/text_embedding.py | 97 +- .../model_providers/sagemaker/llm/llm.py | 260 ++- .../sagemaker/rerank/rerank.py | 102 +- .../model_providers/sagemaker/sagemaker.py | 30 +- .../sagemaker/speech2text/speech2text.py | 81 +- .../text_embedding/text_embedding.py | 101 +- .../model_providers/sagemaker/tts/tts.py | 226 ++- .../model_providers/siliconflow/llm/llm.py | 20 +- .../siliconflow/rerank/rerank.py | 47 +- .../siliconflow/siliconflow.py | 8 +- .../siliconflow/speech2text/speech2text.py | 4 +- .../text_embedding/text_embedding.py | 13 +- .../model_providers/spark/llm/_client.py | 130 +- .../model_providers/spark/llm/llm.py | 117 +- .../model_providers/stepfun/llm/llm.py | 166 +- .../model_providers/stepfun/stepfun.py | 8 +- .../tencent/speech2text/flash_recognizer.py | 51 +- .../tencent/speech2text/speech2text.py | 14 +- .../model_providers/tencent/tencent.py | 7 +- .../model_providers/togetherai/llm/llm.py | 90 +- .../model_providers/togetherai/togetherai.py | 1 - .../model_providers/tongyi/_common.py | 4 +- .../model_providers/tongyi/llm/llm.py | 221 +-- .../tongyi/text_embedding/text_embedding.py | 23 +- .../model_providers/tongyi/tongyi.py | 7 +- .../model_providers/tongyi/tts/tts.py | 54 +- .../triton_inference_server/llm/llm.py | 273 +-- .../triton_inference_server.py | 1 + .../model_providers/upstage/_common.py | 9 +- .../model_providers/upstage/llm/llm.py | 262 +-- .../upstage/text_embedding/text_embedding.py | 63 +- .../model_providers/upstage/upstage.py | 11 +- .../model_providers/vertex_ai/llm/llm.py | 302 ++-- .../text_embedding/text_embedding.py | 52 +- .../model_providers/vertex_ai/vertex_ai.py | 7 +- .../model_providers/volcengine_maas/client.py | 129 +- .../volcengine_maas/legacy/client.py | 69 +- .../volcengine_maas/legacy/errors.py | 50 +- .../legacy/volc_sdk/__init__.py | 2 +- .../legacy/volc_sdk/base/auth.py | 109 +- .../legacy/volc_sdk/base/service.py | 51 +- .../legacy/volc_sdk/base/util.py | 18 +- .../volcengine_maas/legacy/volc_sdk/common.py | 20 +- .../volcengine_maas/legacy/volc_sdk/maas.py | 59 +- .../volcengine_maas/llm/llm.py | 231 +-- .../volcengine_maas/llm/models.py | 141 +- .../volcengine_maas/text_embedding/models.py | 10 +- .../text_embedding/text_embedding.py | 59 +- .../model_providers/wenxin/_common.py | 124 +- .../model_providers/wenxin/llm/ernie_bot.py | 186 ++- .../model_providers/wenxin/llm/llm.py | 230 ++- .../wenxin/text_embedding/text_embedding.py | 70 +- .../model_providers/wenxin/wenxin.py | 8 +- .../model_providers/wenxin/wenxin_errors.py | 23 +- .../model_providers/xinference/llm/llm.py | 506 +++--- .../xinference/rerank/rerank.py | 121 +- .../xinference/speech2text/speech2text.py | 69 +- .../text_embedding/text_embedding.py | 82 +- .../model_providers/xinference/tts/tts.py | 180 +- .../xinference/xinference_helper.py | 86 +- .../model_providers/yi/llm/llm.py | 46 +- .../model_runtime/model_providers/yi/yi.py | 8 +- .../model_providers/zhinao/llm/llm.py | 20 +- .../model_providers/zhinao/zhinao.py | 8 +- .../model_providers/zhipuai/_common.py | 5 +- .../model_providers/zhipuai/llm/llm.py | 269 ++- .../zhipuai/text_embedding/text_embedding.py | 33 +- .../model_providers/zhipuai/zhipuai.py | 7 +- .../zhipuai/zhipuai_sdk/__init__.py | 1 - .../zhipuai/zhipuai_sdk/__version__.py | 3 +- .../zhipuai/zhipuai_sdk/_client.py | 21 +- .../api_resource/chat/async_completions.py | 48 +- .../api_resource/chat/completions.py | 36 +- .../zhipuai_sdk/api_resource/embeddings.py | 24 +- .../zhipuai/zhipuai_sdk/api_resource/files.py | 33 +- .../api_resource/fine_tuning/fine_tuning.py | 1 - .../api_resource/fine_tuning/jobs.py | 52 +- .../zhipuai_sdk/api_resource/images.py | 36 +- .../zhipuai/zhipuai_sdk/core/_errors.py | 29 +- .../zhipuai/zhipuai_sdk/core/_http_client.py | 193 +-- .../zhipuai/zhipuai_sdk/core/_request_opt.py | 13 +- .../zhipuai/zhipuai_sdk/core/_response.py | 18 +- .../zhipuai/zhipuai_sdk/core/_sse_client.py | 53 +- .../types/chat/async_chat_completion.py | 2 +- .../zhipuai_sdk/types/chat/chat_completion.py | 2 - .../zhipuai/zhipuai_sdk/types/file_object.py | 2 - .../types/fine_tuning/fine_tuning_job.py | 2 +- .../schema_validators/common_validator.py | 27 +- .../model_credential_schema_validator.py | 1 - .../provider_credential_schema_validator.py | 1 - api/core/model_runtime/utils/encoders.py | 19 +- api/core/model_runtime/utils/helper.py | 2 +- api/core/moderation/api/api.py | 32 +- api/core/moderation/base.py | 5 +- api/core/moderation/input_moderation.py | 12 +- api/core/moderation/keywords/keywords.py | 24 +- .../openai_moderation/openai_moderation.py | 31 +- api/core/moderation/output_moderation.py | 44 +- api/core/ops/base_trace_instance.py | 2 +- api/core/ops/entities/config_entity.py | 23 +- api/core/ops/entities/trace_entity.py | 32 +- .../entities/langsmith_trace_entity.py | 62 +- .../ops/langsmith_trace/langsmith_trace.py | 14 +- api/core/ops/ops_trace_manager.py | 157 +- api/core/ops/utils.py | 8 +- api/core/prompt/advanced_prompt_transform.py | 140 +- .../prompt/agent_history_prompt_transform.py | 22 +- .../entities/advanced_prompt_entities.py | 9 +- .../advanced_prompt_templates.py | 64 +- api/core/prompt/prompt_transform.py | 75 +- api/core/prompt/simple_prompt_transform.py | 204 +-- api/core/prompt/utils/prompt_message_util.py | 67 +- .../prompt/utils/prompt_template_parser.py | 4 +- api/core/provider_manager.py | 348 ++-- api/core/rag/cleaner/clean_processor.py | 30 +- api/core/rag/cleaner/cleaner_base.py | 5 +- .../unstructured_extra_whitespace_cleaner.py | 2 +- ...uctured_group_broken_paragraphs_cleaner.py | 2 +- .../unstructured_non_ascii_chars_cleaner.py | 2 +- ...ructured_replace_unicode_quotes_cleaner.py | 3 +- .../unstructured_translate_text_cleaner.py | 2 +- .../data_post_processor.py | 51 +- api/core/rag/data_post_processor/reorder.py | 1 - .../rag/datasource/keyword/jieba/jieba.py | 136 +- .../jieba/jieba_keyword_table_handler.py | 3 +- .../rag/datasource/keyword/jieba/stopwords.py | 1466 ++++++++++++++++- .../rag/datasource/keyword/keyword_base.py | 10 +- .../rag/datasource/keyword/keyword_factory.py | 9 +- api/core/rag/datasource/retrieval_service.py | 237 +-- .../vdb/analyticdb/analyticdb_vector.py | 55 +- .../datasource/vdb/chroma/chroma_vector.py | 44 +- .../vdb/elasticsearch/elasticsearch_vector.py | 119 +- .../datasource/vdb/milvus/milvus_vector.py | 114 +- .../datasource/vdb/myscale/myscale_vector.py | 22 +- .../vdb/opensearch/opensearch_vector.py | 99 +- .../rag/datasource/vdb/oracle/oraclevector.py | 56 +- .../datasource/vdb/pgvecto_rs/pgvecto_rs.py | 48 +- .../rag/datasource/vdb/pgvector/pgvector.py | 5 +- .../datasource/vdb/qdrant/qdrant_vector.py | 172 +- .../rag/datasource/vdb/relyt/relyt_vector.py | 60 +- .../datasource/vdb/tencent/tencent_vector.py | 59 +- .../datasource/vdb/tidb_vector/tidb_vector.py | 75 +- api/core/rag/datasource/vdb/vector_base.py | 16 +- api/core/rag/datasource/vdb/vector_factory.py | 54 +- api/core/rag/datasource/vdb/vector_type.py | 28 +- .../vdb/weaviate/weaviate_vector.py | 78 +- api/core/rag/docstore/dataset_docstore.py | 70 +- api/core/rag/extractor/blob/blob.py | 1 + api/core/rag/extractor/csv_extractor.py | 17 +- .../rag/extractor/entity/extract_setting.py | 3 + api/core/rag/extractor/excel_extractor.py | 35 +- api/core/rag/extractor/extract_processor.py | 103 +- api/core/rag/extractor/extractor_base.py | 5 +- .../rag/extractor/firecrawl/firecrawl_app.py | 102 +- .../firecrawl/firecrawl_web_extractor.py | 52 +- api/core/rag/extractor/helpers.py | 4 +- api/core/rag/extractor/html_extractor.py | 13 +- api/core/rag/extractor/markdown_extractor.py | 20 +- api/core/rag/extractor/notion_extractor.py | 143 +- api/core/rag/extractor/pdf_extractor.py | 15 +- api/core/rag/extractor/text_extractor.py | 8 +- .../unstructured_doc_extractor.py | 14 +- .../unstructured_eml_extractor.py | 6 +- .../unstructured_epub_extractor.py | 1 + .../unstructured_markdown_extractor.py | 1 + .../unstructured_msg_extractor.py | 7 +- .../unstructured_ppt_extractor.py | 7 +- .../unstructured_pptx_extractor.py | 6 +- .../unstructured_text_extractor.py | 7 +- .../unstructured_xml_extractor.py | 7 +- api/core/rag/extractor/word_extractor.py | 77 +- .../index_processor/index_processor_base.py | 34 +- .../index_processor_factory.py | 4 +- .../processor/paragraph_index_processor.py | 48 +- .../processor/qa_index_processor.py | 86 +- api/core/rag/models/document.py | 8 +- api/core/rag/rerank/constants/rerank_mode.py | 6 +- api/core/rag/rerank/rerank_model.py | 32 +- api/core/rag/rerank/weight_rerank.py | 39 +- api/core/rag/retrieval/dataset_retrieval.py | 385 +++-- .../output_parser/structured_chat.py | 4 +- api/core/rag/retrieval/retrieval_methods.py | 6 +- .../multi_dataset_function_call_router.py | 24 +- .../router/multi_dataset_react_route.py | 120 +- api/core/rag/splitter/fixed_text_splitter.py | 19 +- api/core/rag/splitter/text_splitter.py | 129 +- api/core/tools/entities/api_entities.py | 49 +- api/core/tools/entities/common_entities.py | 7 +- api/core/tools/entities/tool_bundle.py | 1 + api/core/tools/entities/tool_entities.py | 177 +- api/core/tools/entities/values.py | 132 +- api/core/tools/errors.py | 9 +- api/core/tools/provider/api_tool_provider.py | 168 +- api/core/tools/provider/app_tool_provider.py | 120 +- api/core/tools/provider/builtin/_positions.py | 2 +- .../tools/provider/builtin/aippt/aippt.py | 2 +- .../provider/builtin/aippt/tools/aippt.py | 479 +++--- .../builtin/alphavantage/alphavantage.py | 2 +- .../builtin/alphavantage/tools/query_stock.py | 31 +- .../tools/provider/builtin/arxiv/arxiv.py | 3 +- .../builtin/arxiv/tools/arxiv_search.py | 16 +- api/core/tools/provider/builtin/aws/aws.py | 11 +- .../builtin/aws/tools/apply_guardrail.py | 37 +- .../aws/tools/lambda_translate_utils.py | 79 +- .../builtin/aws/tools/lambda_yaml_to_json.py | 47 +- .../aws/tools/sagemaker_text_rerank.py | 58 +- .../builtin/aws/tools/sagemaker_tts.py | 88 +- .../provider/builtin/azuredalle/azuredalle.py | 8 +- .../builtin/azuredalle/tools/dalle3.py | 67 +- .../builtin/bing/tools/bing_web_search.py | 198 ++- .../tools/provider/builtin/brave/brave.py | 3 +- .../builtin/brave/tools/brave_search.py | 17 +- .../tools/provider/builtin/chart/chart.py | 41 +- .../tools/provider/builtin/chart/tools/bar.py | 27 +- .../provider/builtin/chart/tools/line.py | 31 +- .../tools/provider/builtin/chart/tools/pie.py | 28 +- .../builtin/code/tools/simple_code.py | 14 +- .../tools/provider/builtin/cogview/cogview.py | 11 +- .../builtin/cogview/tools/cogview3.py | 59 +- .../provider/builtin/crossref/crossref.py | 4 +- .../builtin/crossref/tools/query_doi.py | 11 +- .../builtin/crossref/tools/query_title.py | 73 +- .../tools/provider/builtin/dalle/dalle.py | 9 +- .../provider/builtin/dalle/tools/dalle2.py | 55 +- .../provider/builtin/dalle/tools/dalle3.py | 73 +- .../tools/provider/builtin/devdocs/devdocs.py | 3 +- .../builtin/devdocs/tools/searchDevDocs.py | 16 +- api/core/tools/provider/builtin/did/did.py | 9 +- .../tools/provider/builtin/did/did_appx.py | 46 +- .../provider/builtin/did/tools/animations.py | 36 +- .../tools/provider/builtin/did/tools/talks.py | 62 +- .../dingtalk/tools/dingtalk_group_bot.py | 62 +- .../provider/builtin/duckduckgo/duckduckgo.py | 3 +- .../builtin/duckduckgo/tools/ddgo_ai.py | 4 +- .../builtin/duckduckgo/tools/ddgo_img.py | 17 +- .../builtin/duckduckgo/tools/ddgo_search.py | 17 +- .../duckduckgo/tools/ddgo_translate.py | 6 +- .../builtin/feishu/tools/feishu_group_bot.py | 31 +- .../builtin/feishu_base/feishu_base.py | 2 +- .../feishu_base/tools/add_base_record.py | 42 +- .../builtin/feishu_base/tools/create_base.py | 26 +- .../feishu_base/tools/create_base_table.py | 34 +- .../feishu_base/tools/delete_base_records.py | 42 +- .../feishu_base/tools/delete_base_tables.py | 29 +- .../feishu_base/tools/get_base_info.py | 21 +- .../tools/get_tenant_access_token.py | 24 +- .../feishu_base/tools/list_base_records.py | 46 +- .../feishu_base/tools/list_base_tables.py | 25 +- .../feishu_base/tools/read_base_record.py | 34 +- .../feishu_base/tools/update_base_record.py | 46 +- .../feishu_document/feishu_document.py | 6 +- .../feishu_document/tools/create_document.py | 10 +- .../tools/get_document_raw_content.py | 8 +- .../tools/list_document_block.py | 10 +- .../feishu_document/tools/write_document.py | 10 +- .../builtin/feishu_message/feishu_message.py | 6 +- .../feishu_message/tools/send_bot_message.py | 12 +- .../tools/send_webhook_message.py | 12 +- .../provider/builtin/firecrawl/firecrawl.py | 11 +- .../builtin/firecrawl/firecrawl_appx.py | 61 +- .../provider/builtin/firecrawl/tools/crawl.py | 48 +- .../builtin/firecrawl/tools/crawl_job.py | 17 +- .../builtin/firecrawl/tools/scrape.py | 36 +- .../builtin/firecrawl/tools/search.py | 19 +- .../tools/provider/builtin/gaode/gaode.py | 14 +- .../builtin/gaode/tools/gaode_weather.py | 67 +- .../provider/builtin/getimgai/getimgai.py | 9 +- .../builtin/getimgai/getimgai_appx.py | 22 +- .../builtin/getimgai/tools/text2image.py | 34 +- .../tools/provider/builtin/github/github.py | 16 +- .../github/tools/github_repositories.py | 64 +- .../tools/provider/builtin/gitlab/gitlab.py | 18 +- .../builtin/gitlab/tools/gitlab_commits.py | 107 +- .../builtin/gitlab/tools/gitlab_files.py | 61 +- .../tools/provider/builtin/google/google.py | 8 +- .../builtin/google/tools/google_search.py | 23 +- .../google_translate/google_translate.py | 6 +- .../google_translate/tools/translate.py | 34 +- api/core/tools/provider/builtin/hap/hap.py | 2 +- .../builtin/hap/tools/add_worksheet_record.py | 37 +- .../hap/tools/delete_worksheet_record.py | 37 +- .../builtin/hap/tools/get_worksheet_fields.py | 107 +- .../hap/tools/get_worksheet_pivot_data.py | 117 +- .../hap/tools/list_worksheet_records.py | 193 ++- .../builtin/hap/tools/list_worksheets.py | 77 +- .../hap/tools/update_worksheet_record.py | 41 +- api/core/tools/provider/builtin/jina/jina.py | 38 +- .../builtin/jina/tools/jina_reader.py | 67 +- .../builtin/jina/tools/jina_search.py | 41 +- .../builtin/jina/tools/jina_tokenizer.py | 32 +- .../builtin/json_process/json_process.py | 11 +- .../builtin/json_process/tools/delete.py | 23 +- .../builtin/json_process/tools/insert.py | 45 +- .../builtin/json_process/tools/parse.py | 23 +- .../builtin/json_process/tools/replace.py | 55 +- .../provider/builtin/judge0ce/judge0ce.py | 3 +- .../builtin/judge0ce/tools/executeCode.py | 44 +- .../tools/provider/builtin/maths/maths.py | 4 +- .../builtin/maths/tools/eval_expression.py | 21 +- .../provider/builtin/nominatim/nominatim.py | 24 +- .../nominatim/tools/nominatim_lookup.py | 45 +- .../nominatim/tools/nominatim_reverse.py | 48 +- .../nominatim/tools/nominatim_search.py | 48 +- .../builtin/novitaai/_novita_tool_base.py | 24 +- .../provider/builtin/novitaai/novitaai.py | 38 +- .../novitaai/tools/novitaai_createtile.py | 27 +- .../novitaai/tools/novitaai_modelquery.py | 134 +- .../novitaai/tools/novitaai_txt2img.py | 65 +- .../tools/provider/builtin/onebot/onebot.py | 4 +- .../builtin/onebot/tools/send_group_msg.py | 53 +- .../builtin/onebot/tools/send_private_msg.py | 55 +- .../builtin/openweather/openweather.py | 13 +- .../builtin/openweather/tools/weather.py | 14 +- .../provider/builtin/perplexity/perplexity.py | 20 +- .../perplexity/tools/perplexity_search.py | 77 +- .../tools/provider/builtin/pubmed/pubmed.py | 3 +- .../builtin/pubmed/tools/pubmed_search.py | 40 +- .../tools/provider/builtin/qrcode/qrcode.py | 5 +- .../builtin/qrcode/tools/qrcode_generator.py | 41 +- .../tools/provider/builtin/regex/regex.py | 6 +- .../builtin/regex/tools/regex_extract.py | 21 +- .../provider/builtin/searchapi/searchapi.py | 7 +- .../builtin/searchapi/tools/google.py | 23 +- .../builtin/searchapi/tools/google_jobs.py | 34 +- .../builtin/searchapi/tools/google_news.py | 23 +- .../searchapi/tools/youtube_transcripts.py | 17 +- .../tools/provider/builtin/searxng/searxng.py | 8 +- .../builtin/searxng/tools/searxng_search.py | 19 +- .../tools/provider/builtin/serper/serper.py | 7 +- .../builtin/serper/tools/serper_search.py | 30 +- .../builtin/siliconflow/siliconflow.py | 4 +- .../builtin/siliconflow/tools/flux.py | 12 +- .../siliconflow/tools/stable_diffusion.py | 8 +- .../builtin/slack/tools/slack_webhook.py | 27 +- .../tools/provider/builtin/spark/spark.py | 8 +- .../spark/tools/spark_img_generation.py | 39 +- .../tools/provider/builtin/spider/spider.py | 12 +- .../provider/builtin/spider/spiderApp.py | 38 +- .../builtin/spider/tools/scraper_crawler.py | 38 +- .../provider/builtin/stability/stability.py | 1 + .../provider/builtin/stability/tools/base.py | 19 +- .../builtin/stability/tools/text2image.py | 44 +- .../stablediffusion/stablediffusion.py | 1 - .../stablediffusion/tools/stable_diffusion.py | 328 ++-- .../builtin/stackexchange/stackexchange.py | 7 +- .../tools/fetchAnsByStackExQuesID.py | 8 +- .../tools/searchStackExQuestions.py | 14 +- .../tools/provider/builtin/stepfun/stepfun.py | 5 +- .../provider/builtin/stepfun/tools/image.py | 60 +- .../tools/provider/builtin/tavily/tavily.py | 5 +- .../builtin/tavily/tools/tavily_search.py | 28 +- .../provider/builtin/tianditu/tianditu.py | 12 +- .../builtin/tianditu/tools/geocoder.py | 32 +- .../builtin/tianditu/tools/poisearch.py | 67 +- .../builtin/tianditu/tools/staticmap.py | 59 +- api/core/tools/provider/builtin/time/time.py | 3 +- .../builtin/time/tools/current_time.py | 25 +- .../provider/builtin/time/tools/weekday.py | 21 +- .../builtin/trello/tools/create_board.py | 17 +- .../trello/tools/create_list_on_board.py | 19 +- .../trello/tools/create_new_card_on_board.py | 13 +- .../builtin/trello/tools/delete_board.py | 7 +- .../builtin/trello/tools/delete_card.py | 7 +- .../builtin/trello/tools/fetch_all_boards.py | 8 +- .../builtin/trello/tools/get_board_actions.py | 11 +- .../builtin/trello/tools/get_board_by_id.py | 7 +- .../builtin/trello/tools/get_board_cards.py | 7 +- .../trello/tools/get_filterd_board_cards.py | 13 +- .../trello/tools/get_lists_on_board.py | 7 +- .../builtin/trello/tools/update_board.py | 11 +- .../builtin/trello/tools/update_card.py | 10 +- .../tools/provider/builtin/trello/trello.py | 7 +- .../builtin/twilio/tools/send_message.py | 11 +- .../tools/provider/builtin/twilio/twilio.py | 3 +- .../tools/provider/builtin/vanna/vanna.py | 6 +- .../builtin/vectorizer/tools/test_data.py | 2 +- .../builtin/vectorizer/tools/vectorizer.py | 61 +- .../provider/builtin/vectorizer/vectorizer.py | 8 +- .../builtin/webscraper/tools/webscraper.py | 19 +- .../provider/builtin/webscraper/webscraper.py | 7 +- .../builtin/websearch/tools/job_search.py | 18 +- .../builtin/websearch/tools/news_search.py | 16 +- .../builtin/websearch/tools/scholar_search.py | 18 +- .../builtin/websearch/tools/web_search.py | 14 +- .../builtin/wecom/tools/wecom_group_bot.py | 37 +- .../wikipedia/tools/wikipedia_search.py | 1 - .../provider/builtin/wikipedia/wikipedia.py | 3 +- .../wolframalpha/tools/wolframalpha.py | 78 +- .../builtin/wolframalpha/wolframalpha.py | 3 +- .../provider/builtin/yahoo/tools/analytics.py | 60 +- .../provider/builtin/yahoo/tools/news.py | 37 +- .../provider/builtin/yahoo/tools/ticker.py | 19 +- .../tools/provider/builtin/yahoo/yahoo.py | 3 +- .../provider/builtin/youtube/tools/videos.py | 67 +- .../tools/provider/builtin/youtube/youtube.py | 3 +- .../tools/provider/builtin_tool_provider.py | 152 +- api/core/tools/provider/tool_provider.py | 132 +- .../tools/provider/workflow_tool_provider.py | 135 +- api/core/tools/tool/api_tool.py | 189 ++- api/core/tools/tool/builtin_tool.py | 86 +- .../dataset_multi_retriever_tool.py | 155 +- .../dataset_retriever_base_tool.py | 1 + .../dataset_retriever_tool.py | 140 +- api/core/tools/tool/dataset_retriever_tool.py | 51 +- api/core/tools/tool/tool.py | 155 +- api/core/tools/tool/workflow_tool.py | 105 +- api/core/tools/tool_engine.py | 201 ++- api/core/tools/tool_file_manager.py | 29 +- api/core/tools/tool_label_manager.py | 52 +- api/core/tools/tool_manager.py | 439 ++--- api/core/tools/utils/configuration.py | 59 +- api/core/tools/utils/feishu_api_utils.py | 16 +- api/core/tools/utils/message_transformer.py | 109 +- .../tools/utils/model_invocation_utils.py | 66 +- api/core/tools/utils/parser.py | 325 ++-- .../tools/utils/tool_parameter_converter.py | 32 +- api/core/tools/utils/web_reader_tool.py | 73 +- .../utils/workflow_configuration_sync.py | 26 +- api/core/tools/utils/yaml_utils.py | 4 +- .../callbacks/base_workflow_callback.py | 5 +- .../entities/base_node_data_entities.py | 4 +- api/core/workflow/entities/node_entities.py | 61 +- .../workflow/entities/variable_entities.py | 1 + api/core/workflow/entities/variable_pool.py | 12 +- .../workflow/entities/workflow_entities.py | 17 +- .../condition_handlers/base_handler.py | 10 +- .../branch_identify_handler.py | 5 +- .../condition_handlers/condition_handler.py | 8 +- .../condition_handlers/condition_manager.py | 16 +- .../workflow/graph_engine/entities/graph.py | 278 ++-- .../entities/graph_runtime_state.py | 2 +- .../graph_engine/entities/run_condition.py | 2 +- .../entities/runtime_route_state.py | 14 +- .../workflow/graph_engine/graph_engine.py | 308 ++-- api/core/workflow/nodes/answer/answer_node.py | 20 +- .../answer/answer_stream_generate_router.py | 77 +- .../nodes/answer/answer_stream_processor.py | 39 +- .../nodes/answer/base_stream_processor.py | 13 +- api/core/workflow/nodes/answer/entities.py | 10 +- api/core/workflow/nodes/base_node.py | 38 +- api/core/workflow/nodes/code/code_node.py | 170 +- api/core/workflow/nodes/code/entities.py | 7 +- api/core/workflow/nodes/end/end_node.py | 11 +- .../nodes/end/end_stream_generate_router.py | 75 +- .../nodes/end/end_stream_processor.py | 42 +- api/core/workflow/nodes/end/entities.py | 8 +- .../workflow/nodes/http_request/entities.py | 14 +- .../nodes/http_request/http_executor.py | 123 +- .../nodes/http_request/http_request_node.py | 50 +- api/core/workflow/nodes/if_else/entities.py | 1 + .../workflow/nodes/if_else/if_else_node.py | 32 +- api/core/workflow/nodes/iteration/entities.py | 15 +- .../nodes/iteration/iteration_node.py | 138 +- .../nodes/iteration/iteration_start_node.py | 12 +- .../nodes/knowledge_retrieval/entities.py | 17 +- .../knowledge_retrieval_node.py | 224 ++- api/core/workflow/nodes/llm/entities.py | 14 +- api/core/workflow/nodes/llm/llm_node.py | 305 ++-- api/core/workflow/nodes/loop/entities.py | 4 +- api/core/workflow/nodes/loop/loop_node.py | 17 +- .../nodes/parameter_extractor/entities.py | 57 +- .../parameter_extractor_node.py | 460 +++--- .../nodes/parameter_extractor/prompts.py | 154 +- .../nodes/question_classifier/entities.py | 7 +- .../question_classifier_node.py | 182 +- .../question_classifier/template_prompts.py | 2 - api/core/workflow/nodes/start/entities.py | 1 + api/core/workflow/nodes/start/start_node.py | 14 +- .../nodes/template_transform/entities.py | 5 +- .../template_transform_node.py | 40 +- api/core/workflow/nodes/tool/entities.py | 41 +- api/core/workflow/nodes/tool/tool_node.py | 139 +- .../nodes/variable_aggregator/entities.py | 10 +- .../variable_aggregator_node.py | 25 +- .../nodes/variable_assigner/__init__.py | 6 +- .../workflow/nodes/variable_assigner/node.py | 26 +- .../nodes/variable_assigner/node_data.py | 10 +- api/core/workflow/utils/condition/entities.py | 19 +- .../workflow/utils/condition/processor.py | 52 +- .../utils/variable_template_parser.py | 20 +- api/core/workflow/workflow_entry.py | 121 +- api/pyproject.toml | 1 - 724 files changed, 21180 insertions(+), 21123 deletions(-) diff --git a/api/core/__init__.py b/api/core/__init__.py index 8c986fc8bd..6eaea7b1c8 100644 --- a/api/core/__init__.py +++ b/api/core/__init__.py @@ -1 +1 @@ -import core.moderation.base \ No newline at end of file +import core.moderation.base diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 89c948d2e2..29b428a7c3 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -25,17 +25,19 @@ from models.model import Message class CotAgentRunner(BaseAgentRunner, ABC): _is_first_iteration = True - _ignore_observation_providers = ['wenxin'] + _ignore_observation_providers = ["wenxin"] _historic_prompt_messages: list[PromptMessage] = None _agent_scratchpad: list[AgentScratchpadUnit] = None _instruction: str = None _query: str = None _prompt_messages_tools: list[PromptMessage] = None - def run(self, message: Message, - query: str, - inputs: dict[str, str], - ) -> Union[Generator, LLMResult]: + def run( + self, + message: Message, + query: str, + inputs: dict[str, str], + ) -> Union[Generator, LLMResult]: """ Run Cot agent application """ @@ -46,17 +48,16 @@ class CotAgentRunner(BaseAgentRunner, ABC): trace_manager = app_generate_entity.trace_manager # check model mode - if 'Observation' not in app_generate_entity.model_conf.stop: + if "Observation" not in app_generate_entity.model_conf.stop: if app_generate_entity.model_conf.provider not in self._ignore_observation_providers: - app_generate_entity.model_conf.stop.append('Observation') + app_generate_entity.model_conf.stop.append("Observation") app_config = self.app_config # init instruction inputs = inputs or {} instruction = app_config.prompt_template.simple_prompt_template - self._instruction = self._fill_in_inputs_from_external_data_tools( - instruction, inputs) + self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) iteration_step = 1 max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 @@ -65,16 +66,14 @@ class CotAgentRunner(BaseAgentRunner, ABC): tool_instances, self._prompt_messages_tools = self._init_prompt_tools() function_call_state = True - llm_usage = { - 'usage': None - } - final_answer = '' + llm_usage = {"usage": None} + final_answer = "" def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): - if not final_llm_usage_dict['usage']: - final_llm_usage_dict['usage'] = usage + if not final_llm_usage_dict["usage"]: + final_llm_usage_dict["usage"] = usage else: - llm_usage = final_llm_usage_dict['usage'] + llm_usage = final_llm_usage_dict["usage"] llm_usage.prompt_tokens += usage.prompt_tokens llm_usage.completion_tokens += usage.completion_tokens llm_usage.prompt_price += usage.prompt_price @@ -94,17 +93,13 @@ class CotAgentRunner(BaseAgentRunner, ABC): message_file_ids = [] agent_thought = self.create_agent_thought( - message_id=message.id, - message='', - tool_name='', - tool_input='', - messages_ids=message_file_ids + message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids ) if iteration_step > 1: - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) # recalc llm max tokens prompt_messages = self._organize_prompt_messages() @@ -125,21 +120,20 @@ class CotAgentRunner(BaseAgentRunner, ABC): raise ValueError("failed to invoke llm") usage_dict = {} - react_chunks = CotAgentOutputParser.handle_react_stream_output( - chunks, usage_dict) + react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) scratchpad = AgentScratchpadUnit( - agent_response='', - thought='', - action_str='', - observation='', + agent_response="", + thought="", + action_str="", + observation="", action=None, ) # publish agent thought if it's first iteration if iteration_step == 1: - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) for chunk in react_chunks: if isinstance(chunk, AgentScratchpadUnit.Action): @@ -154,61 +148,51 @@ class CotAgentRunner(BaseAgentRunner, ABC): yield LLMResultChunk( model=self.model_config.model, prompt_messages=prompt_messages, - system_fingerprint='', - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage( - content=chunk - ), - usage=None - ) + system_fingerprint="", + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None), ) - scratchpad.thought = scratchpad.thought.strip( - ) or 'I am thinking about how to help you' + scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" self._agent_scratchpad.append(scratchpad) # get llm usage - if 'usage' in usage_dict: - increase_usage(llm_usage, usage_dict['usage']) + if "usage" in usage_dict: + increase_usage(llm_usage, usage_dict["usage"]) else: - usage_dict['usage'] = LLMUsage.empty_usage() + usage_dict["usage"] = LLMUsage.empty_usage() self.save_agent_thought( agent_thought=agent_thought, - tool_name=scratchpad.action.action_name if scratchpad.action else '', - tool_input={ - scratchpad.action.action_name: scratchpad.action.action_input - } if scratchpad.action else {}, + tool_name=scratchpad.action.action_name if scratchpad.action else "", + tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {}, tool_invoke_meta={}, thought=scratchpad.thought, - observation='', + observation="", answer=scratchpad.agent_response, messages_ids=[], - llm_usage=usage_dict['usage'] + llm_usage=usage_dict["usage"], ) if not scratchpad.is_final(): - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) if not scratchpad.action: # failed to extract action, return final answer directly - final_answer = '' + final_answer = "" else: if scratchpad.action.action_name.lower() == "final answer": # action is final answer, return final answer directly try: if isinstance(scratchpad.action.action_input, dict): - final_answer = json.dumps( - scratchpad.action.action_input) + final_answer = json.dumps(scratchpad.action.action_input) elif isinstance(scratchpad.action.action_input, str): final_answer = scratchpad.action.action_input else: - final_answer = f'{scratchpad.action.action_input}' + final_answer = f"{scratchpad.action.action_input}" except json.JSONDecodeError: - final_answer = f'{scratchpad.action.action_input}' + final_answer = f"{scratchpad.action.action_input}" else: function_call_state = True # action is tool call, invoke tool @@ -224,21 +208,18 @@ class CotAgentRunner(BaseAgentRunner, ABC): self.save_agent_thought( agent_thought=agent_thought, tool_name=scratchpad.action.action_name, - tool_input={ - scratchpad.action.action_name: scratchpad.action.action_input}, + tool_input={scratchpad.action.action_name: scratchpad.action.action_input}, thought=scratchpad.thought, - observation={ - scratchpad.action.action_name: tool_invoke_response}, - tool_invoke_meta={ - scratchpad.action.action_name: tool_invoke_meta.to_dict()}, + observation={scratchpad.action.action_name: tool_invoke_response}, + tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()}, answer=scratchpad.agent_response, messages_ids=message_file_ids, - llm_usage=usage_dict['usage'] + llm_usage=usage_dict["usage"], ) - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) # update prompt tool message for prompt_tool in self._prompt_messages_tools: @@ -250,44 +231,45 @@ class CotAgentRunner(BaseAgentRunner, ABC): model=model_instance.model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage( - content=final_answer - ), - usage=llm_usage['usage'] + index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"] ), - system_fingerprint='' + system_fingerprint="", ) # save agent thought self.save_agent_thought( agent_thought=agent_thought, - tool_name='', + tool_name="", tool_input={}, tool_invoke_meta={}, thought=final_answer, observation={}, answer=final_answer, - messages_ids=[] + messages_ids=[], ) self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event - self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult( - model=model_instance.model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=final_answer + self.queue_manager.publish( + QueueMessageEndEvent( + llm_result=LLMResult( + model=model_instance.model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content=final_answer), + usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(), + system_fingerprint="", + ) ), - usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(), - system_fingerprint='' - )), PublishFrom.APPLICATION_MANAGER) + PublishFrom.APPLICATION_MANAGER, + ) - def _handle_invoke_action(self, action: AgentScratchpadUnit.Action, - tool_instances: dict[str, Tool], - message_file_ids: list[str], - trace_manager: Optional[TraceQueueManager] = None - ) -> tuple[str, ToolInvokeMeta]: + def _handle_invoke_action( + self, + action: AgentScratchpadUnit.Action, + tool_instances: dict[str, Tool], + message_file_ids: list[str], + trace_manager: Optional[TraceQueueManager] = None, + ) -> tuple[str, ToolInvokeMeta]: """ handle invoke action :param action: action @@ -326,13 +308,12 @@ class CotAgentRunner(BaseAgentRunner, ABC): # publish files for message_file_id, save_as in message_files: if save_as: - self.variables_pool.set_file( - tool_name=tool_call_name, value=message_file_id, name=save_as) + self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as) # publish message file - self.queue_manager.publish(QueueMessageFileEvent( - message_file_id=message_file_id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER + ) # add message file ids message_file_ids.append(message_file_id) @@ -342,10 +323,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): """ convert dict to action """ - return AgentScratchpadUnit.Action( - action_name=action['action'], - action_input=action['action_input'] - ) + return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"]) def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str: """ @@ -353,7 +331,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): """ for key, value in inputs.items(): try: - instruction = instruction.replace(f'{{{{{key}}}}}', str(value)) + instruction = instruction.replace(f"{{{{{key}}}}}", str(value)) except Exception as e: continue @@ -370,14 +348,14 @@ class CotAgentRunner(BaseAgentRunner, ABC): @abstractmethod def _organize_prompt_messages(self) -> list[PromptMessage]: """ - organize prompt messages + organize prompt messages """ def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str: """ - format assistant message + format assistant message """ - message = '' + message = "" for scratchpad in agent_scratchpad: if scratchpad.is_final(): message += f"Final Answer: {scratchpad.agent_response}" @@ -390,9 +368,11 @@ class CotAgentRunner(BaseAgentRunner, ABC): return message - def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]: + def _organize_historic_prompt_messages( + self, current_session_messages: list[PromptMessage] = None + ) -> list[PromptMessage]: """ - organize historic prompt messages + organize historic prompt messages """ result: list[PromptMessage] = [] scratchpads: list[AgentScratchpadUnit] = [] @@ -403,8 +383,8 @@ class CotAgentRunner(BaseAgentRunner, ABC): if not current_scratchpad: current_scratchpad = AgentScratchpadUnit( agent_response=message.content, - thought=message.content or 'I am thinking about how to help you', - action_str='', + thought=message.content or "I am thinking about how to help you", + action_str="", action=None, observation=None, ) @@ -413,12 +393,9 @@ class CotAgentRunner(BaseAgentRunner, ABC): try: current_scratchpad.action = AgentScratchpadUnit.Action( action_name=message.tool_calls[0].function.name, - action_input=json.loads( - message.tool_calls[0].function.arguments) - ) - current_scratchpad.action_str = json.dumps( - current_scratchpad.action.to_dict() + action_input=json.loads(message.tool_calls[0].function.arguments), ) + current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict()) except: pass elif isinstance(message, ToolPromptMessage): @@ -426,23 +403,19 @@ class CotAgentRunner(BaseAgentRunner, ABC): current_scratchpad.observation = message.content elif isinstance(message, UserPromptMessage): if scratchpads: - result.append(AssistantPromptMessage( - content=self._format_assistant_message(scratchpads) - )) + result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) scratchpads = [] current_scratchpad = None result.append(message) if scratchpads: - result.append(AssistantPromptMessage( - content=self._format_assistant_message(scratchpads) - )) + result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) historic_prompts = AgentHistoryPromptTransform( model_config=self.model_config, prompt_messages=current_session_messages or [], history_messages=result, - memory=self.memory + memory=self.memory, ).get_prompt() return historic_prompts diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index 8debbe5c5d..bdec6b7ed1 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -19,14 +19,15 @@ class CotChatAgentRunner(CotAgentRunner): prompt_entity = self.app_config.agent.prompt first_prompt = prompt_entity.first_prompt - system_prompt = first_prompt \ - .replace("{{instruction}}", self._instruction) \ - .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \ - .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools])) + system_prompt = ( + first_prompt.replace("{{instruction}}", self._instruction) + .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) + .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools])) + ) return SystemPromptMessage(content=system_prompt) - def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: + def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: """ Organize user query """ @@ -43,7 +44,7 @@ class CotChatAgentRunner(CotAgentRunner): def _organize_prompt_messages(self) -> list[PromptMessage]: """ - Organize + Organize """ # organize system prompt system_message = self._organize_system_prompt() @@ -53,7 +54,7 @@ class CotChatAgentRunner(CotAgentRunner): if not agent_scratchpad: assistant_messages = [] else: - assistant_message = AssistantPromptMessage(content='') + assistant_message = AssistantPromptMessage(content="") for unit in agent_scratchpad: if unit.is_final(): assistant_message.content += f"Final Answer: {unit.agent_response}" @@ -71,18 +72,15 @@ class CotChatAgentRunner(CotAgentRunner): if assistant_messages: # organize historic prompt messages - historic_messages = self._organize_historic_prompt_messages([ - system_message, - *query_messages, - *assistant_messages, - UserPromptMessage(content='continue') - ]) + historic_messages = self._organize_historic_prompt_messages( + [system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")] + ) messages = [ system_message, *historic_messages, *query_messages, *assistant_messages, - UserPromptMessage(content='continue') + UserPromptMessage(content="continue"), ] else: # organize historic prompt messages diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 9e6eb54f4f..9dab956f9a 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -13,10 +13,12 @@ class CotCompletionAgentRunner(CotAgentRunner): prompt_entity = self.app_config.agent.prompt first_prompt = prompt_entity.first_prompt - system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \ - .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \ - .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools])) - + system_prompt = ( + first_prompt.replace("{{instruction}}", self._instruction) + .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) + .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools])) + ) + return system_prompt def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str: @@ -46,7 +48,7 @@ class CotCompletionAgentRunner(CotAgentRunner): # organize current assistant messages agent_scratchpad = self._agent_scratchpad - assistant_prompt = '' + assistant_prompt = "" for unit in agent_scratchpad: if unit.is_final(): assistant_prompt += f"Final Answer: {unit.agent_response}" @@ -61,9 +63,10 @@ class CotCompletionAgentRunner(CotAgentRunner): query_prompt = f"Question: {self._query}" # join all messages - prompt = system_prompt \ - .replace("{{historic_messages}}", historic_prompt) \ - .replace("{{agent_scratchpad}}", assistant_prompt) \ + prompt = ( + system_prompt.replace("{{historic_messages}}", historic_prompt) + .replace("{{agent_scratchpad}}", assistant_prompt) .replace("{{query}}", query_prompt) + ) - return [UserPromptMessage(content=prompt)] \ No newline at end of file + return [UserPromptMessage(content=prompt)] diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 5274224de5..119a88fc7b 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -8,6 +8,7 @@ class AgentToolEntity(BaseModel): """ Agent Tool Entity. """ + provider_type: Literal["builtin", "api", "workflow"] provider_id: str tool_name: str @@ -18,6 +19,7 @@ class AgentPromptEntity(BaseModel): """ Agent Prompt Entity. """ + first_prompt: str next_iteration: str @@ -31,6 +33,7 @@ class AgentScratchpadUnit(BaseModel): """ Action Entity. """ + action_name: str action_input: Union[dict, str] @@ -39,8 +42,8 @@ class AgentScratchpadUnit(BaseModel): Convert to dictionary. """ return { - 'action': self.action_name, - 'action_input': self.action_input, + "action": self.action_name, + "action_input": self.action_input, } agent_response: Optional[str] = None @@ -54,10 +57,10 @@ class AgentScratchpadUnit(BaseModel): Check if the scratchpad unit is final. """ return self.action is None or ( - 'final' in self.action.action_name.lower() and - 'answer' in self.action.action_name.lower() + "final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower() ) + class AgentEntity(BaseModel): """ Agent Entity. @@ -67,8 +70,9 @@ class AgentEntity(BaseModel): """ Agent Strategy. """ - CHAIN_OF_THOUGHT = 'chain-of-thought' - FUNCTION_CALLING = 'function-calling' + + CHAIN_OF_THOUGHT = "chain-of-thought" + FUNCTION_CALLING = "function-calling" provider: str model: str diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 3ee6e47742..27cf561e3d 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -24,11 +24,9 @@ from models.model import Message logger = logging.getLogger(__name__) -class FunctionCallAgentRunner(BaseAgentRunner): - def run(self, - message: Message, query: str, **kwargs: Any - ) -> Generator[LLMResultChunk, None, None]: +class FunctionCallAgentRunner(BaseAgentRunner): + def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]: """ Run FunctionCall agent application """ @@ -45,19 +43,17 @@ class FunctionCallAgentRunner(BaseAgentRunner): # continue to run until there is not any tool call function_call_state = True - llm_usage = { - 'usage': None - } - final_answer = '' + llm_usage = {"usage": None} + final_answer = "" # get tracing instance trace_manager = app_generate_entity.trace_manager - + def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): - if not final_llm_usage_dict['usage']: - final_llm_usage_dict['usage'] = usage + if not final_llm_usage_dict["usage"]: + final_llm_usage_dict["usage"] = usage else: - llm_usage = final_llm_usage_dict['usage'] + llm_usage = final_llm_usage_dict["usage"] llm_usage.prompt_tokens += usage.prompt_tokens llm_usage.completion_tokens += usage.completion_tokens llm_usage.prompt_price += usage.prompt_price @@ -75,11 +71,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): message_file_ids = [] agent_thought = self.create_agent_thought( - message_id=message.id, - message='', - tool_name='', - tool_input='', - messages_ids=message_file_ids + message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids ) # recalc llm max tokens @@ -99,11 +91,11 @@ class FunctionCallAgentRunner(BaseAgentRunner): tool_calls: list[tuple[str, str, dict[str, Any]]] = [] # save full response - response = '' + response = "" # save tool call names and inputs - tool_call_names = '' - tool_call_inputs = '' + tool_call_names = "" + tool_call_inputs = "" current_llm_usage = None @@ -111,24 +103,22 @@ class FunctionCallAgentRunner(BaseAgentRunner): is_first_chunk = True for chunk in chunks: if is_first_chunk: - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) is_first_chunk = False # check if there is any tool call if self.check_tool_calls(chunk): function_call_state = True tool_calls.extend(self.extract_tool_calls(chunk)) - tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls]) + tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) try: - tool_call_inputs = json.dumps({ - tool_call[1]: tool_call[2] for tool_call in tool_calls - }, ensure_ascii=False) + tool_call_inputs = json.dumps( + {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False + ) except json.JSONDecodeError as e: # ensure ascii to avoid encoding error - tool_call_inputs = json.dumps({ - tool_call[1]: tool_call[2] for tool_call in tool_calls - }) + tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) if chunk.delta.message and chunk.delta.message.content: if isinstance(chunk.delta.message.content, list): @@ -148,16 +138,14 @@ class FunctionCallAgentRunner(BaseAgentRunner): if self.check_blocking_tool_calls(result): function_call_state = True tool_calls.extend(self.extract_blocking_tool_calls(result)) - tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls]) + tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) try: - tool_call_inputs = json.dumps({ - tool_call[1]: tool_call[2] for tool_call in tool_calls - }, ensure_ascii=False) + tool_call_inputs = json.dumps( + {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False + ) except json.JSONDecodeError as e: # ensure ascii to avoid encoding error - tool_call_inputs = json.dumps({ - tool_call[1]: tool_call[2] for tool_call in tool_calls - }) + tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) if result.usage: increase_usage(llm_usage, result.usage) @@ -171,12 +159,12 @@ class FunctionCallAgentRunner(BaseAgentRunner): response += result.message.content if not result.message.content: - result.message.content = '' + result.message.content = "" + + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) - yield LLMResultChunk( model=model_instance.model, prompt_messages=result.prompt_messages, @@ -185,32 +173,29 @@ class FunctionCallAgentRunner(BaseAgentRunner): index=0, message=result.message, usage=result.usage, - ) + ), ) - assistant_message = AssistantPromptMessage( - content='', - tool_calls=[] - ) + assistant_message = AssistantPromptMessage(content="", tool_calls=[]) if tool_calls: - assistant_message.tool_calls=[ + assistant_message.tool_calls = [ AssistantPromptMessage.ToolCall( id=tool_call[0], - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool_call[1], - arguments=json.dumps(tool_call[2], ensure_ascii=False) - ) - ) for tool_call in tool_calls + name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False) + ), + ) + for tool_call in tool_calls ] else: assistant_message.content = response - + self._current_thoughts.append(assistant_message) # save thought self.save_agent_thought( - agent_thought=agent_thought, + agent_thought=agent_thought, tool_name=tool_call_names, tool_input=tool_call_inputs, thought=response, @@ -218,13 +203,13 @@ class FunctionCallAgentRunner(BaseAgentRunner): observation=None, answer=response, messages_ids=[], - llm_usage=current_llm_usage + llm_usage=current_llm_usage, ) - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) - - final_answer += response + '\n' + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) + + final_answer += response + "\n" # call tools tool_responses = [] @@ -235,7 +220,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): "tool_call_id": tool_call_id, "tool_call_name": tool_call_name, "tool_response": f"there is not a tool named {tool_call_name}", - "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict() + "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(), } else: # invoke tool @@ -255,50 +240,49 @@ class FunctionCallAgentRunner(BaseAgentRunner): self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as) # publish message file - self.queue_manager.publish(QueueMessageFileEvent( - message_file_id=message_file_id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER + ) # add message file ids message_file_ids.append(message_file_id) - + tool_response = { "tool_call_id": tool_call_id, "tool_call_name": tool_call_name, "tool_response": tool_invoke_response, - "meta": tool_invoke_meta.to_dict() + "meta": tool_invoke_meta.to_dict(), } - + tool_responses.append(tool_response) - if tool_response['tool_response'] is not None: + if tool_response["tool_response"] is not None: self._current_thoughts.append( ToolPromptMessage( - content=tool_response['tool_response'], + content=tool_response["tool_response"], tool_call_id=tool_call_id, name=tool_call_name, ) - ) + ) if len(tool_responses) > 0: # save agent thought self.save_agent_thought( - agent_thought=agent_thought, + agent_thought=agent_thought, tool_name=None, tool_input=None, - thought=None, + thought=None, tool_invoke_meta={ - tool_response['tool_call_name']: tool_response['meta'] - for tool_response in tool_responses + tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses }, observation={ - tool_response['tool_call_name']: tool_response['tool_response'] + tool_response["tool_call_name"]: tool_response["tool_response"] for tool_response in tool_responses }, answer=None, - messages_ids=message_file_ids + messages_ids=message_file_ids, + ) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER ) - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) # update prompt tool for prompt_tool in prompt_messages_tools: @@ -308,15 +292,18 @@ class FunctionCallAgentRunner(BaseAgentRunner): self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event - self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult( - model=model_instance.model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=final_answer + self.queue_manager.publish( + QueueMessageEndEvent( + llm_result=LLMResult( + model=model_instance.model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content=final_answer), + usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(), + system_fingerprint="", + ) ), - usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(), - system_fingerprint='' - )), PublishFrom.APPLICATION_MANAGER) + PublishFrom.APPLICATION_MANAGER, + ) def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool: """ @@ -325,7 +312,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): if llm_result_chunk.delta.message.tool_calls: return True return False - + def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool: """ Check if there is any blocking tool call in llm result @@ -334,7 +321,9 @@ class FunctionCallAgentRunner(BaseAgentRunner): return True return False - def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: + def extract_tool_calls( + self, llm_result_chunk: LLMResultChunk + ) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: """ Extract tool calls from llm result chunk @@ -344,17 +333,19 @@ class FunctionCallAgentRunner(BaseAgentRunner): tool_calls = [] for prompt_message in llm_result_chunk.delta.message.tool_calls: args = {} - if prompt_message.function.arguments != '': + if prompt_message.function.arguments != "": args = json.loads(prompt_message.function.arguments) - tool_calls.append(( - prompt_message.id, - prompt_message.function.name, - args, - )) + tool_calls.append( + ( + prompt_message.id, + prompt_message.function.name, + args, + ) + ) return tool_calls - + def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: """ Extract blocking tool calls from llm result @@ -365,18 +356,22 @@ class FunctionCallAgentRunner(BaseAgentRunner): tool_calls = [] for prompt_message in llm_result.message.tool_calls: args = {} - if prompt_message.function.arguments != '': + if prompt_message.function.arguments != "": args = json.loads(prompt_message.function.arguments) - tool_calls.append(( - prompt_message.id, - prompt_message.function.name, - args, - )) + tool_calls.append( + ( + prompt_message.id, + prompt_message.function.name, + args, + ) + ) return tool_calls - def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: + def _init_system_message( + self, prompt_template: str, prompt_messages: list[PromptMessage] = None + ) -> list[PromptMessage]: """ Initialize system message """ @@ -384,13 +379,13 @@ class FunctionCallAgentRunner(BaseAgentRunner): return [ SystemPromptMessage(content=prompt_template), ] - + if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template: prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) return prompt_messages - def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: + def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: """ Organize user query """ @@ -404,7 +399,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): prompt_messages.append(UserPromptMessage(content=query)) return prompt_messages - + def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: """ As for now, gpt supports both fc and vision at the first iteration. @@ -415,17 +410,21 @@ class FunctionCallAgentRunner(BaseAgentRunner): for prompt_message in prompt_messages: if isinstance(prompt_message, UserPromptMessage): if isinstance(prompt_message.content, list): - prompt_message.content = '\n'.join([ - content.data if content.type == PromptMessageContentType.TEXT else - '[image]' if content.type == PromptMessageContentType.IMAGE else - '[file]' - for content in prompt_message.content - ]) + prompt_message.content = "\n".join( + [ + content.data + if content.type == PromptMessageContentType.TEXT + else "[image]" + if content.type == PromptMessageContentType.IMAGE + else "[file]" + for content in prompt_message.content + ] + ) return prompt_messages def _organize_prompt_messages(self): - prompt_template = self.app_config.prompt_template.simple_prompt_template or '' + prompt_template = self.app_config.prompt_template.simple_prompt_template or "" self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages) query_prompt_messages = self._organize_user_query(self.query, []) @@ -433,14 +432,10 @@ class FunctionCallAgentRunner(BaseAgentRunner): model_config=self.model_config, prompt_messages=[*query_prompt_messages, *self._current_thoughts], history_messages=self.history_prompt_messages, - memory=self.memory + memory=self.memory, ).get_prompt() - prompt_messages = [ - *self.history_prompt_messages, - *query_prompt_messages, - *self._current_thoughts - ] + prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts] if len(self._current_thoughts) != 0: # clear messages after the first iteration prompt_messages = self._clear_user_prompt_image_messages(prompt_messages) diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index c53fa5000e..1a161677dd 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -9,8 +9,9 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk class CotAgentOutputParser: @classmethod - def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict) -> \ - Generator[Union[str, AgentScratchpadUnit.Action], None, None]: + def handle_react_stream_output( + cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict + ) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]: def parse_action(json_str): try: action = json.loads(json_str) @@ -22,7 +23,7 @@ class CotAgentOutputParser: action = action[0] for key, value in action.items(): - if 'input' in key.lower(): + if "input" in key.lower(): action_input = value else: action_name = value @@ -33,37 +34,37 @@ class CotAgentOutputParser: action_input=action_input, ) else: - return json_str or '' + return json_str or "" except: - return json_str or '' - + return json_str or "" + def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]: - code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL) + code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL) if not code_blocks: return for block in code_blocks: - json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE) + json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE) yield parse_action(json_text) - - code_block_cache = '' + + code_block_cache = "" code_block_delimiter_count = 0 in_code_block = False - json_cache = '' + json_cache = "" json_quote_count = 0 in_json = False got_json = False - action_cache = '' - action_str = 'action:' + action_cache = "" + action_str = "action:" action_idx = 0 - thought_cache = '' - thought_str = 'thought:' + thought_cache = "" + thought_str = "thought:" thought_idx = 0 for response in llm_response: if response.delta.usage: - usage_dict['usage'] = response.delta.usage + usage_dict["usage"] = response.delta.usage response = response.delta.message.content if not isinstance(response, str): continue @@ -72,24 +73,24 @@ class CotAgentOutputParser: index = 0 while index < len(response): steps = 1 - delta = response[index:index+steps] - last_character = response[index-1] if index > 0 else '' + delta = response[index : index + steps] + last_character = response[index - 1] if index > 0 else "" - if delta == '`': + if delta == "`": code_block_cache += delta code_block_delimiter_count += 1 else: if not in_code_block: if code_block_delimiter_count > 0: yield code_block_cache - code_block_cache = '' + code_block_cache = "" else: code_block_cache += delta code_block_delimiter_count = 0 if not in_code_block and not in_json: if delta.lower() == action_str[action_idx] and action_idx == 0: - if last_character not in ['\n', ' ', '']: + if last_character not in ["\n", " ", ""]: index += steps yield delta continue @@ -97,7 +98,7 @@ class CotAgentOutputParser: action_cache += delta action_idx += 1 if action_idx == len(action_str): - action_cache = '' + action_cache = "" action_idx = 0 index += steps continue @@ -105,18 +106,18 @@ class CotAgentOutputParser: action_cache += delta action_idx += 1 if action_idx == len(action_str): - action_cache = '' + action_cache = "" action_idx = 0 index += steps continue else: if action_cache: yield action_cache - action_cache = '' + action_cache = "" action_idx = 0 - + if delta.lower() == thought_str[thought_idx] and thought_idx == 0: - if last_character not in ['\n', ' ', '']: + if last_character not in ["\n", " ", ""]: index += steps yield delta continue @@ -124,7 +125,7 @@ class CotAgentOutputParser: thought_cache += delta thought_idx += 1 if thought_idx == len(thought_str): - thought_cache = '' + thought_cache = "" thought_idx = 0 index += steps continue @@ -132,31 +133,31 @@ class CotAgentOutputParser: thought_cache += delta thought_idx += 1 if thought_idx == len(thought_str): - thought_cache = '' + thought_cache = "" thought_idx = 0 index += steps continue else: if thought_cache: yield thought_cache - thought_cache = '' + thought_cache = "" thought_idx = 0 if code_block_delimiter_count == 3: if in_code_block: yield from extra_json_from_code_block(code_block_cache) - code_block_cache = '' - + code_block_cache = "" + in_code_block = not in_code_block code_block_delimiter_count = 0 if not in_code_block: # handle single json - if delta == '{': + if delta == "{": json_quote_count += 1 in_json = True json_cache += delta - elif delta == '}': + elif delta == "}": json_cache += delta if json_quote_count > 0: json_quote_count -= 1 @@ -172,12 +173,12 @@ class CotAgentOutputParser: if got_json: got_json = False yield parse_action(json_cache) - json_cache = '' + json_cache = "" json_quote_count = 0 in_json = False - + if not in_code_block and not in_json: - yield delta.replace('`', '') + yield delta.replace("`", "") index += steps @@ -186,4 +187,3 @@ class CotAgentOutputParser: if json_cache: yield parse_action(json_cache) - diff --git a/api/core/agent/prompt/template.py b/api/core/agent/prompt/template.py index b0cf1a77fb..cb98f5501d 100644 --- a/api/core/agent/prompt/template.py +++ b/api/core/agent/prompt/template.py @@ -91,14 +91,14 @@ Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = "" REACT_PROMPT_TEMPLATES = { - 'english': { - 'chat': { - 'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES, - 'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES + "english": { + "chat": { + "prompt": ENGLISH_REACT_CHAT_PROMPT_TEMPLATES, + "agent_scratchpad": ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES, + }, + "completion": { + "prompt": ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES, + "agent_scratchpad": ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES, }, - 'completion': { - 'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES, - 'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES - } } -} \ No newline at end of file +} diff --git a/api/core/app/app_config/base_app_config_manager.py b/api/core/app/app_config/base_app_config_manager.py index 3dea305e98..0fd2a779a4 100644 --- a/api/core/app/app_config/base_app_config_manager.py +++ b/api/core/app/app_config/base_app_config_manager.py @@ -26,34 +26,24 @@ class BaseAppConfigManager: config_dict = dict(config_dict.items()) additional_features = AppAdditionalFeatures() - additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert( - config=config_dict - ) + additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict) additional_features.file_upload = FileUploadConfigManager.convert( - config=config_dict, - is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT] + config=config_dict, is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT] ) - additional_features.opening_statement, additional_features.suggested_questions = \ - OpeningStatementConfigManager.convert( - config=config_dict - ) + additional_features.opening_statement, additional_features.suggested_questions = ( + OpeningStatementConfigManager.convert(config=config_dict) + ) additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert( config=config_dict ) - additional_features.more_like_this = MoreLikeThisConfigManager.convert( - config=config_dict - ) + additional_features.more_like_this = MoreLikeThisConfigManager.convert(config=config_dict) - additional_features.speech_to_text = SpeechToTextConfigManager.convert( - config=config_dict - ) + additional_features.speech_to_text = SpeechToTextConfigManager.convert(config=config_dict) - additional_features.text_to_speech = TextToSpeechConfigManager.convert( - config=config_dict - ) + additional_features.text_to_speech = TextToSpeechConfigManager.convert(config=config_dict) return additional_features diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py index 1ca8b1e3b8..037037e6ca 100644 --- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py +++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py @@ -7,25 +7,24 @@ from core.moderation.factory import ModerationFactory class SensitiveWordAvoidanceConfigManager: @classmethod def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]: - sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance') + sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance") if not sensitive_word_avoidance_dict: return None - if sensitive_word_avoidance_dict.get('enabled'): + if sensitive_word_avoidance_dict.get("enabled"): return SensitiveWordAvoidanceEntity( - type=sensitive_word_avoidance_dict.get('type'), - config=sensitive_word_avoidance_dict.get('config'), + type=sensitive_word_avoidance_dict.get("type"), + config=sensitive_word_avoidance_dict.get("config"), ) else: return None @classmethod - def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \ - -> tuple[dict, list[str]]: + def validate_and_set_defaults( + cls, tenant_id, config: dict, only_structure_validate: bool = False + ) -> tuple[dict, list[str]]: if not config.get("sensitive_word_avoidance"): - config["sensitive_word_avoidance"] = { - "enabled": False - } + config["sensitive_word_avoidance"] = {"enabled": False} if not isinstance(config["sensitive_word_avoidance"], dict): raise ValueError("sensitive_word_avoidance must be of dict type") @@ -41,10 +40,6 @@ class SensitiveWordAvoidanceConfigManager: typ = config["sensitive_word_avoidance"]["type"] sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"] - ModerationFactory.validate_config( - name=typ, - tenant_id=tenant_id, - config=sensitive_word_avoidance_config - ) + ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config) return config, ["sensitive_word_avoidance"] diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index dc65d4439b..6e89f19508 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -12,67 +12,70 @@ class AgentConfigManager: :param config: model config args """ - if 'agent_mode' in config and config['agent_mode'] \ - and 'enabled' in config['agent_mode']: + if "agent_mode" in config and config["agent_mode"] and "enabled" in config["agent_mode"]: + agent_dict = config.get("agent_mode", {}) + agent_strategy = agent_dict.get("strategy", "cot") - agent_dict = config.get('agent_mode', {}) - agent_strategy = agent_dict.get('strategy', 'cot') - - if agent_strategy == 'function_call': + if agent_strategy == "function_call": strategy = AgentEntity.Strategy.FUNCTION_CALLING - elif agent_strategy == 'cot' or agent_strategy == 'react': + elif agent_strategy == "cot" or agent_strategy == "react": strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT else: # old configs, try to detect default strategy - if config['model']['provider'] == 'openai': + if config["model"]["provider"] == "openai": strategy = AgentEntity.Strategy.FUNCTION_CALLING else: strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT agent_tools = [] - for tool in agent_dict.get('tools', []): + for tool in agent_dict.get("tools", []): keys = tool.keys() if len(keys) >= 4: if "enabled" not in tool or not tool["enabled"]: continue agent_tool_properties = { - 'provider_type': tool['provider_type'], - 'provider_id': tool['provider_id'], - 'tool_name': tool['tool_name'], - 'tool_parameters': tool.get('tool_parameters', {}) + "provider_type": tool["provider_type"], + "provider_id": tool["provider_id"], + "tool_name": tool["tool_name"], + "tool_parameters": tool.get("tool_parameters", {}), } agent_tools.append(AgentToolEntity(**agent_tool_properties)) - if 'strategy' in config['agent_mode'] and \ - config['agent_mode']['strategy'] not in ['react_router', 'router']: - agent_prompt = agent_dict.get('prompt', None) or {} + if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in [ + "react_router", + "router", + ]: + agent_prompt = agent_dict.get("prompt", None) or {} # check model mode - model_mode = config.get('model', {}).get('mode', 'completion') - if model_mode == 'completion': + model_mode = config.get("model", {}).get("mode", "completion") + if model_mode == "completion": agent_prompt_entity = AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', - REACT_PROMPT_TEMPLATES['english']['completion']['prompt']), - next_iteration=agent_prompt.get('next_iteration', - REACT_PROMPT_TEMPLATES['english']['completion'][ - 'agent_scratchpad']), + first_prompt=agent_prompt.get( + "first_prompt", REACT_PROMPT_TEMPLATES["english"]["completion"]["prompt"] + ), + next_iteration=agent_prompt.get( + "next_iteration", REACT_PROMPT_TEMPLATES["english"]["completion"]["agent_scratchpad"] + ), ) else: agent_prompt_entity = AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', - REACT_PROMPT_TEMPLATES['english']['chat']['prompt']), - next_iteration=agent_prompt.get('next_iteration', - REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']), + first_prompt=agent_prompt.get( + "first_prompt", REACT_PROMPT_TEMPLATES["english"]["chat"]["prompt"] + ), + next_iteration=agent_prompt.get( + "next_iteration", REACT_PROMPT_TEMPLATES["english"]["chat"]["agent_scratchpad"] + ), ) return AgentEntity( - provider=config['model']['provider'], - model=config['model']['name'], + provider=config["model"]["provider"], + model=config["model"]["name"], strategy=strategy, prompt=agent_prompt_entity, tools=agent_tools, - max_iteration=agent_dict.get('max_iteration', 5) + max_iteration=agent_dict.get("max_iteration", 5), ) return None diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index 1a621d2090..ff131b62e2 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -15,39 +15,38 @@ class DatasetConfigManager: :param config: model config args """ dataset_ids = [] - if 'datasets' in config.get('dataset_configs', {}): - datasets = config.get('dataset_configs', {}).get('datasets', { - 'strategy': 'router', - 'datasets': [] - }) + if "datasets" in config.get("dataset_configs", {}): + datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []}) - for dataset in datasets.get('datasets', []): + for dataset in datasets.get("datasets", []): keys = list(dataset.keys()) - if len(keys) == 0 or keys[0] != 'dataset': + if len(keys) == 0 or keys[0] != "dataset": continue - dataset = dataset['dataset'] + dataset = dataset["dataset"] - if 'enabled' not in dataset or not dataset['enabled']: + if "enabled" not in dataset or not dataset["enabled"]: continue - dataset_id = dataset.get('id', None) + dataset_id = dataset.get("id", None) if dataset_id: dataset_ids.append(dataset_id) - if 'agent_mode' in config and config['agent_mode'] \ - and 'enabled' in config['agent_mode'] \ - and config['agent_mode']['enabled']: + if ( + "agent_mode" in config + and config["agent_mode"] + and "enabled" in config["agent_mode"] + and config["agent_mode"]["enabled"] + ): + agent_dict = config.get("agent_mode", {}) - agent_dict = config.get('agent_mode', {}) - - for tool in agent_dict.get('tools', []): + for tool in agent_dict.get("tools", []): keys = tool.keys() if len(keys) == 1: # old standard key = list(tool.keys())[0] - if key != 'dataset': + if key != "dataset": continue tool_item = tool[key] @@ -55,30 +54,28 @@ class DatasetConfigManager: if "enabled" not in tool_item or not tool_item["enabled"]: continue - dataset_id = tool_item['id'] + dataset_id = tool_item["id"] dataset_ids.append(dataset_id) if len(dataset_ids) == 0: return None # dataset configs - if 'dataset_configs' in config and config.get('dataset_configs'): - dataset_configs = config.get('dataset_configs') + if "dataset_configs" in config and config.get("dataset_configs"): + dataset_configs = config.get("dataset_configs") else: - dataset_configs = { - 'retrieval_model': 'multiple' - } - query_variable = config.get('dataset_query_variable') + dataset_configs = {"retrieval_model": "multiple"} + query_variable = config.get("dataset_query_variable") - if dataset_configs['retrieval_model'] == 'single': + if dataset_configs["retrieval_model"] == "single": return DatasetEntity( dataset_ids=dataset_ids, retrieve_config=DatasetRetrieveConfigEntity( query_variable=query_variable, retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( - dataset_configs['retrieval_model'] - ) - ) + dataset_configs["retrieval_model"] + ), + ), ) else: return DatasetEntity( @@ -86,15 +83,15 @@ class DatasetConfigManager: retrieve_config=DatasetRetrieveConfigEntity( query_variable=query_variable, retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( - dataset_configs['retrieval_model'] + dataset_configs["retrieval_model"] ), - top_k=dataset_configs.get('top_k', 4), - score_threshold=dataset_configs.get('score_threshold'), - reranking_model=dataset_configs.get('reranking_model'), - weights=dataset_configs.get('weights'), - reranking_enabled=dataset_configs.get('reranking_enabled', True), - rerank_mode=dataset_configs.get('reranking_mode', 'reranking_model'), - ) + top_k=dataset_configs.get("top_k", 4), + score_threshold=dataset_configs.get("score_threshold"), + reranking_model=dataset_configs.get("reranking_model"), + weights=dataset_configs.get("weights"), + reranking_enabled=dataset_configs.get("reranking_enabled", True), + rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"), + ), ) @classmethod @@ -111,13 +108,10 @@ class DatasetConfigManager: # dataset_configs if not config.get("dataset_configs"): - config["dataset_configs"] = {'retrieval_model': 'single'} + config["dataset_configs"] = {"retrieval_model": "single"} if not config["dataset_configs"].get("datasets"): - config["dataset_configs"]["datasets"] = { - "strategy": "router", - "datasets": [] - } + config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []} if not isinstance(config["dataset_configs"], dict): raise ValueError("dataset_configs must be of object type") @@ -125,8 +119,9 @@ class DatasetConfigManager: if not isinstance(config["dataset_configs"], dict): raise ValueError("dataset_configs must be of object type") - need_manual_query_datasets = (config.get("dataset_configs") - and config["dataset_configs"].get("datasets", {}).get("datasets")) + need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get( + "datasets", {} + ).get("datasets") if need_manual_query_datasets and app_mode == AppMode.COMPLETION: # Only check when mode is completion @@ -148,10 +143,7 @@ class DatasetConfigManager: """ # Extract dataset config for legacy compatibility if not config.get("agent_mode"): - config["agent_mode"] = { - "enabled": False, - "tools": [] - } + config["agent_mode"] = {"enabled": False, "tools": []} if not isinstance(config["agent_mode"], dict): raise ValueError("agent_mode must be of object type") @@ -188,7 +180,7 @@ class DatasetConfigManager: if not isinstance(tool_item["enabled"], bool): raise ValueError("enabled in agent_mode.tools must be of boolean type") - if 'id' not in tool_item: + if "id" not in tool_item: raise ValueError("id is required in dataset") try: diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index 5c9b2cfec7..a91b9f0f02 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -11,9 +11,7 @@ from core.provider_manager import ProviderManager class ModelConfigConverter: @classmethod - def convert(cls, app_config: EasyUIBasedAppConfig, - skip_check: bool = False) \ - -> ModelConfigWithCredentialsEntity: + def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity: """ Convert app model config dict to entity. :param app_config: app config @@ -25,9 +23,7 @@ class ModelConfigConverter: provider_manager = ProviderManager() provider_model_bundle = provider_manager.get_provider_model_bundle( - tenant_id=app_config.tenant_id, - provider=model_config.provider, - model_type=ModelType.LLM + tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM ) provider_name = provider_model_bundle.configuration.provider.provider @@ -38,8 +34,7 @@ class ModelConfigConverter: # check model credentials model_credentials = provider_model_bundle.configuration.get_current_credentials( - model_type=ModelType.LLM, - model=model_config.model + model_type=ModelType.LLM, model=model_config.model ) if model_credentials is None: @@ -51,8 +46,7 @@ class ModelConfigConverter: if not skip_check: # check model provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_config.model, - model_type=ModelType.LLM + model=model_config.model, model_type=ModelType.LLM ) if provider_model is None: @@ -69,24 +63,18 @@ class ModelConfigConverter: # model config completion_params = model_config.parameters stop = [] - if 'stop' in completion_params: - stop = completion_params['stop'] - del completion_params['stop'] + if "stop" in completion_params: + stop = completion_params["stop"] + del completion_params["stop"] # get model mode model_mode = model_config.mode if not model_mode: - mode_enum = model_type_instance.get_model_mode( - model=model_config.model, - credentials=model_credentials - ) + mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials) model_mode = mode_enum.value - model_schema = model_type_instance.get_model_schema( - model_config.model, - model_credentials - ) + model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials) if not skip_check and not model_schema: raise ValueError(f"Model {model_name} not exist.") diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 730a9527cf..b5e4554181 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -13,23 +13,23 @@ class ModelConfigManager: :param config: model config args """ # model config - model_config = config.get('model') + model_config = config.get("model") if not model_config: raise ValueError("model is required") - completion_params = model_config.get('completion_params') + completion_params = model_config.get("completion_params") stop = [] - if 'stop' in completion_params: - stop = completion_params['stop'] - del completion_params['stop'] + if "stop" in completion_params: + stop = completion_params["stop"] + del completion_params["stop"] # get model mode - model_mode = model_config.get('mode') + model_mode = model_config.get("mode") return ModelConfigEntity( - provider=config['model']['provider'], - model=config['model']['name'], + provider=config["model"]["provider"], + model=config["model"]["name"], mode=model_mode, parameters=completion_params, stop=stop, @@ -43,7 +43,7 @@ class ModelConfigManager: :param tenant_id: tenant id :param config: app model config args """ - if 'model' not in config: + if "model" not in config: raise ValueError("model is required") if not isinstance(config["model"], dict): @@ -52,17 +52,16 @@ class ModelConfigManager: # model.provider provider_entities = model_provider_factory.get_providers() model_provider_names = [provider.provider for provider in provider_entities] - if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names: + if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names: raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") # model.name - if 'name' not in config["model"]: + if "name" not in config["model"]: raise ValueError("model.name is required") provider_manager = ProviderManager() models = provider_manager.get_configurations(tenant_id).get_models( - provider=config["model"]["provider"], - model_type=ModelType.LLM + provider=config["model"]["provider"], model_type=ModelType.LLM ) if not models: @@ -80,12 +79,12 @@ class ModelConfigManager: # model.mode if model_mode: - config['model']["mode"] = model_mode + config["model"]["mode"] = model_mode else: - config['model']["mode"] = "completion" + config["model"]["mode"] = "completion" # model.completion_params - if 'completion_params' not in config["model"]: + if "completion_params" not in config["model"]: raise ValueError("model.completion_params is required") config["model"]["completion_params"] = cls.validate_model_completion_params( @@ -101,7 +100,7 @@ class ModelConfigManager: raise ValueError("model.completion_params must be of object type") # stop - if 'stop' not in cp: + if "stop" not in cp: cp["stop"] = [] elif not isinstance(cp["stop"], list): raise ValueError("stop in model.completion_params must be of list type") diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index 1f410758aa..de91c9a065 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -14,39 +14,33 @@ class PromptTemplateConfigManager: if not config.get("prompt_type"): raise ValueError("prompt_type is required") - prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type']) + prompt_type = PromptTemplateEntity.PromptType.value_of(config["prompt_type"]) if prompt_type == PromptTemplateEntity.PromptType.SIMPLE: simple_prompt_template = config.get("pre_prompt", "") - return PromptTemplateEntity( - prompt_type=prompt_type, - simple_prompt_template=simple_prompt_template - ) + return PromptTemplateEntity(prompt_type=prompt_type, simple_prompt_template=simple_prompt_template) else: advanced_chat_prompt_template = None chat_prompt_config = config.get("chat_prompt_config", {}) if chat_prompt_config: chat_prompt_messages = [] for message in chat_prompt_config.get("prompt", []): - chat_prompt_messages.append({ - "text": message["text"], - "role": PromptMessageRole.value_of(message["role"]) - }) + chat_prompt_messages.append( + {"text": message["text"], "role": PromptMessageRole.value_of(message["role"])} + ) - advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity( - messages=chat_prompt_messages - ) + advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages) advanced_completion_prompt_template = None completion_prompt_config = config.get("completion_prompt_config", {}) if completion_prompt_config: completion_prompt_template_params = { - 'prompt': completion_prompt_config['prompt']['text'], + "prompt": completion_prompt_config["prompt"]["text"], } - if 'conversation_histories_role' in completion_prompt_config: - completion_prompt_template_params['role_prefix'] = { - 'user': completion_prompt_config['conversation_histories_role']['user_prefix'], - 'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix'] + if "conversation_histories_role" in completion_prompt_config: + completion_prompt_template_params["role_prefix"] = { + "user": completion_prompt_config["conversation_histories_role"]["user_prefix"], + "assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"], } advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity( @@ -56,7 +50,7 @@ class PromptTemplateConfigManager: return PromptTemplateEntity( prompt_type=prompt_type, advanced_chat_prompt_template=advanced_chat_prompt_template, - advanced_completion_prompt_template=advanced_completion_prompt_template + advanced_completion_prompt_template=advanced_completion_prompt_template, ) @classmethod @@ -72,7 +66,7 @@ class PromptTemplateConfigManager: config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType] - if config['prompt_type'] not in prompt_type_vals: + if config["prompt_type"] not in prompt_type_vals: raise ValueError(f"prompt_type must be in {prompt_type_vals}") # chat_prompt_config @@ -89,27 +83,28 @@ class PromptTemplateConfigManager: if not isinstance(config["completion_prompt_config"], dict): raise ValueError("completion_prompt_config must be of object type") - if config['prompt_type'] == PromptTemplateEntity.PromptType.ADVANCED.value: - if not config['chat_prompt_config'] and not config['completion_prompt_config']: - raise ValueError("chat_prompt_config or completion_prompt_config is required " - "when prompt_type is advanced") + if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value: + if not config["chat_prompt_config"] and not config["completion_prompt_config"]: + raise ValueError( + "chat_prompt_config or completion_prompt_config is required " "when prompt_type is advanced" + ) model_mode_vals = [mode.value for mode in ModelMode] - if config['model']["mode"] not in model_mode_vals: + if config["model"]["mode"] not in model_mode_vals: raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced") - if app_mode == AppMode.CHAT and config['model']["mode"] == ModelMode.COMPLETION.value: - user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix'] - assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] + if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value: + user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] + assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] if not user_prefix: - config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human' + config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] = "Human" if not assistant_prefix: - config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant' + config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant" - if config['model']["mode"] == ModelMode.CHAT.value: - prompt_list = config['chat_prompt_config']['prompt'] + if config["model"]["mode"] == ModelMode.CHAT.value: + prompt_list = config["chat_prompt_config"]["prompt"] if len(prompt_list) > 10: raise ValueError("prompt messages must be less than 10") diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index 15fa4d99fd..2c0232c743 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -16,32 +16,30 @@ class BasicVariablesConfigManager: variable_entities = [] # old external_data_tools - external_data_tools = config.get('external_data_tools', []) + external_data_tools = config.get("external_data_tools", []) for external_data_tool in external_data_tools: - if 'enabled' not in external_data_tool or not external_data_tool['enabled']: + if "enabled" not in external_data_tool or not external_data_tool["enabled"]: continue external_data_variables.append( ExternalDataVariableEntity( - variable=external_data_tool['variable'], - type=external_data_tool['type'], - config=external_data_tool['config'] + variable=external_data_tool["variable"], + type=external_data_tool["type"], + config=external_data_tool["config"], ) ) # variables and external_data_tools - for variables in config.get('user_input_form', []): + for variables in config.get("user_input_form", []): variable_type = list(variables.keys())[0] if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL: variable = variables[variable_type] - if 'config' not in variable: + if "config" not in variable: continue external_data_variables.append( ExternalDataVariableEntity( - variable=variable['variable'], - type=variable['type'], - config=variable['config'] + variable=variable["variable"], type=variable["type"], config=variable["config"] ) ) elif variable_type in [ @@ -54,13 +52,13 @@ class BasicVariablesConfigManager: variable_entities.append( VariableEntity( type=variable_type, - variable=variable.get('variable'), - description=variable.get('description'), - label=variable.get('label'), - required=variable.get('required', False), - max_length=variable.get('max_length'), - options=variable.get('options'), - default=variable.get('default'), + variable=variable.get("variable"), + description=variable.get("description"), + label=variable.get("label"), + required=variable.get("required", False), + max_length=variable.get("max_length"), + options=variable.get("options"), + default=variable.get("default"), ) ) @@ -103,13 +101,13 @@ class BasicVariablesConfigManager: raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'") form_item = item[key] - if 'label' not in form_item: + if "label" not in form_item: raise ValueError("label is required in user_input_form") if not isinstance(form_item["label"], str): raise ValueError("label in user_input_form must be of string type") - if 'variable' not in form_item: + if "variable" not in form_item: raise ValueError("variable is required in user_input_form") if not isinstance(form_item["variable"], str): @@ -117,26 +115,24 @@ class BasicVariablesConfigManager: pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$") if pattern.match(form_item["variable"]) is None: - raise ValueError("variable in user_input_form must be a string, " - "and cannot start with a number") + raise ValueError("variable in user_input_form must be a string, " "and cannot start with a number") variables.append(form_item["variable"]) - if 'required' not in form_item or not form_item["required"]: + if "required" not in form_item or not form_item["required"]: form_item["required"] = False if not isinstance(form_item["required"], bool): raise ValueError("required in user_input_form must be of boolean type") if key == "select": - if 'options' not in form_item or not form_item["options"]: + if "options" not in form_item or not form_item["options"]: form_item["options"] = [] if not isinstance(form_item["options"], list): raise ValueError("options in user_input_form must be a list of strings") - if "default" in form_item and form_item['default'] \ - and form_item["default"] not in form_item["options"]: + if "default" in form_item and form_item["default"] and form_item["default"] not in form_item["options"]: raise ValueError("default value in user_input_form must be in the options list") return config, ["user_input_form"] @@ -168,10 +164,6 @@ class BasicVariablesConfigManager: typ = tool["type"] config = tool["config"] - ExternalDataToolFactory.validate_config( - name=typ, - tenant_id=tenant_id, - config=config - ) + ExternalDataToolFactory.validate_config(name=typ, tenant_id=tenant_id, config=config) return config, ["external_data_tools"] diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index bbb10d3d76..d208db2b01 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -12,6 +12,7 @@ class ModelConfigEntity(BaseModel): """ Model Config Entity. """ + provider: str model: str mode: Optional[str] = None @@ -23,6 +24,7 @@ class AdvancedChatMessageEntity(BaseModel): """ Advanced Chat Message Entity. """ + text: str role: PromptMessageRole @@ -31,6 +33,7 @@ class AdvancedChatPromptTemplateEntity(BaseModel): """ Advanced Chat Prompt Template Entity. """ + messages: list[AdvancedChatMessageEntity] @@ -43,6 +46,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel): """ Role Prefix Entity. """ + user: str assistant: str @@ -60,11 +64,12 @@ class PromptTemplateEntity(BaseModel): Prompt Type. 'simple', 'advanced' """ - SIMPLE = 'simple' - ADVANCED = 'advanced' + + SIMPLE = "simple" + ADVANCED = "advanced" @classmethod - def value_of(cls, value: str) -> 'PromptType': + def value_of(cls, value: str) -> "PromptType": """ Get value of given mode. @@ -74,7 +79,7 @@ class PromptTemplateEntity(BaseModel): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid prompt type value {value}') + raise ValueError(f"invalid prompt type value {value}") prompt_type: PromptType simple_prompt_template: Optional[str] = None @@ -110,6 +115,7 @@ class ExternalDataVariableEntity(BaseModel): """ External Data Variable Entity. """ + variable: str type: str config: dict[str, Any] = {} @@ -125,11 +131,12 @@ class DatasetRetrieveConfigEntity(BaseModel): Dataset Retrieve Strategy. 'single' or 'multiple' """ - SINGLE = 'single' - MULTIPLE = 'multiple' + + SINGLE = "single" + MULTIPLE = "multiple" @classmethod - def value_of(cls, value: str) -> 'RetrieveStrategy': + def value_of(cls, value: str) -> "RetrieveStrategy": """ Get value of given mode. @@ -139,25 +146,24 @@ class DatasetRetrieveConfigEntity(BaseModel): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid retrieve strategy value {value}') + raise ValueError(f"invalid retrieve strategy value {value}") query_variable: Optional[str] = None # Only when app mode is completion retrieve_strategy: RetrieveStrategy top_k: Optional[int] = None - score_threshold: Optional[float] = .0 - rerank_mode: Optional[str] = 'reranking_model' + score_threshold: Optional[float] = 0.0 + rerank_mode: Optional[str] = "reranking_model" reranking_model: Optional[dict] = None weights: Optional[dict] = None reranking_enabled: Optional[bool] = True - - class DatasetEntity(BaseModel): """ Dataset Config Entity. """ + dataset_ids: list[str] retrieve_config: DatasetRetrieveConfigEntity @@ -166,6 +172,7 @@ class SensitiveWordAvoidanceEntity(BaseModel): """ Sensitive Word Avoidance Entity. """ + type: str config: dict[str, Any] = {} @@ -174,6 +181,7 @@ class TextToSpeechEntity(BaseModel): """ Sensitive Word Avoidance Entity. """ + enabled: bool voice: Optional[str] = None language: Optional[str] = None @@ -183,12 +191,11 @@ class TracingConfigEntity(BaseModel): """ Tracing Config Entity. """ + enabled: bool tracing_provider: str - - class AppAdditionalFeatures(BaseModel): file_upload: Optional[FileExtraConfig] = None opening_statement: Optional[str] = None @@ -200,10 +207,12 @@ class AppAdditionalFeatures(BaseModel): text_to_speech: Optional[TextToSpeechEntity] = None trace_config: Optional[TracingConfigEntity] = None + class AppConfig(BaseModel): """ Application Config Entity. """ + tenant_id: str app_id: str app_mode: AppMode @@ -216,15 +225,17 @@ class EasyUIBasedAppModelConfigFrom(Enum): """ App Model Config From. """ - ARGS = 'args' - APP_LATEST_CONFIG = 'app-latest-config' - CONVERSATION_SPECIFIC_CONFIG = 'conversation-specific-config' + + ARGS = "args" + APP_LATEST_CONFIG = "app-latest-config" + CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config" class EasyUIBasedAppConfig(AppConfig): """ Easy UI Based App Config Entity. """ + app_model_config_from: EasyUIBasedAppModelConfigFrom app_model_config_id: str app_model_config_dict: dict @@ -238,4 +249,5 @@ class WorkflowUIBasedAppConfig(AppConfig): """ Workflow UI Based App Config Entity. """ + workflow_id: str diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 3da3c2eddb..5f7fc99151 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -13,21 +13,19 @@ class FileUploadConfigManager: :param config: model config args :param is_vision: if True, the feature is vision feature """ - file_upload_dict = config.get('file_upload') + file_upload_dict = config.get("file_upload") if file_upload_dict: - if file_upload_dict.get('image'): - if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']: + if file_upload_dict.get("image"): + if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]: image_config = { - 'number_limits': file_upload_dict['image']['number_limits'], - 'transfer_methods': file_upload_dict['image']['transfer_methods'] + "number_limits": file_upload_dict["image"]["number_limits"], + "transfer_methods": file_upload_dict["image"]["transfer_methods"], } if is_vision: - image_config['detail'] = file_upload_dict['image']['detail'] + image_config["detail"] = file_upload_dict["image"]["detail"] - return FileExtraConfig( - image_config=image_config - ) + return FileExtraConfig(image_config=image_config) return None @@ -49,21 +47,21 @@ class FileUploadConfigManager: if not config["file_upload"].get("image"): config["file_upload"]["image"] = {"enabled": False} - if config['file_upload']['image']['enabled']: - number_limits = config['file_upload']['image']['number_limits'] + if config["file_upload"]["image"]["enabled"]: + number_limits = config["file_upload"]["image"]["number_limits"] if number_limits < 1 or number_limits > 6: raise ValueError("number_limits must be in [1, 6]") if is_vision: - detail = config['file_upload']['image']['detail'] - if detail not in ['high', 'low']: + detail = config["file_upload"]["image"]["detail"] + if detail not in ["high", "low"]: raise ValueError("detail must be in ['high', 'low']") - transfer_methods = config['file_upload']['image']['transfer_methods'] + transfer_methods = config["file_upload"]["image"]["transfer_methods"] if not isinstance(transfer_methods, list): raise ValueError("transfer_methods must be of list type") for method in transfer_methods: - if method not in ['remote_url', 'local_file']: + if method not in ["remote_url", "local_file"]: raise ValueError("transfer_methods must be in ['remote_url', 'local_file']") return config, ["file_upload"] diff --git a/api/core/app/app_config/features/more_like_this/manager.py b/api/core/app/app_config/features/more_like_this/manager.py index 2ba99a5c40..496e1beeec 100644 --- a/api/core/app/app_config/features/more_like_this/manager.py +++ b/api/core/app/app_config/features/more_like_this/manager.py @@ -7,9 +7,9 @@ class MoreLikeThisConfigManager: :param config: model config args """ more_like_this = False - more_like_this_dict = config.get('more_like_this') + more_like_this_dict = config.get("more_like_this") if more_like_this_dict: - if more_like_this_dict.get('enabled'): + if more_like_this_dict.get("enabled"): more_like_this = True return more_like_this @@ -22,9 +22,7 @@ class MoreLikeThisConfigManager: :param config: app model config args """ if not config.get("more_like_this"): - config["more_like_this"] = { - "enabled": False - } + config["more_like_this"] = {"enabled": False} if not isinstance(config["more_like_this"], dict): raise ValueError("more_like_this must be of dict type") diff --git a/api/core/app/app_config/features/opening_statement/manager.py b/api/core/app/app_config/features/opening_statement/manager.py index 0d8a71bfcf..b4dacbc409 100644 --- a/api/core/app/app_config/features/opening_statement/manager.py +++ b/api/core/app/app_config/features/opening_statement/manager.py @@ -1,5 +1,3 @@ - - class OpeningStatementConfigManager: @classmethod def convert(cls, config: dict) -> tuple[str, list]: @@ -9,10 +7,10 @@ class OpeningStatementConfigManager: :param config: model config args """ # opening statement - opening_statement = config.get('opening_statement') + opening_statement = config.get("opening_statement") # suggested questions - suggested_questions_list = config.get('suggested_questions') + suggested_questions_list = config.get("suggested_questions") return opening_statement, suggested_questions_list diff --git a/api/core/app/app_config/features/retrieval_resource/manager.py b/api/core/app/app_config/features/retrieval_resource/manager.py index fca58e12e8..d098abac2f 100644 --- a/api/core/app/app_config/features/retrieval_resource/manager.py +++ b/api/core/app/app_config/features/retrieval_resource/manager.py @@ -2,9 +2,9 @@ class RetrievalResourceConfigManager: @classmethod def convert(cls, config: dict) -> bool: show_retrieve_source = False - retriever_resource_dict = config.get('retriever_resource') + retriever_resource_dict = config.get("retriever_resource") if retriever_resource_dict: - if retriever_resource_dict.get('enabled'): + if retriever_resource_dict.get("enabled"): show_retrieve_source = True return show_retrieve_source @@ -17,9 +17,7 @@ class RetrievalResourceConfigManager: :param config: app model config args """ if not config.get("retriever_resource"): - config["retriever_resource"] = { - "enabled": False - } + config["retriever_resource"] = {"enabled": False} if not isinstance(config["retriever_resource"], dict): raise ValueError("retriever_resource must be of dict type") diff --git a/api/core/app/app_config/features/speech_to_text/manager.py b/api/core/app/app_config/features/speech_to_text/manager.py index 88b4be25d3..e10ae03e04 100644 --- a/api/core/app/app_config/features/speech_to_text/manager.py +++ b/api/core/app/app_config/features/speech_to_text/manager.py @@ -7,9 +7,9 @@ class SpeechToTextConfigManager: :param config: model config args """ speech_to_text = False - speech_to_text_dict = config.get('speech_to_text') + speech_to_text_dict = config.get("speech_to_text") if speech_to_text_dict: - if speech_to_text_dict.get('enabled'): + if speech_to_text_dict.get("enabled"): speech_to_text = True return speech_to_text @@ -22,9 +22,7 @@ class SpeechToTextConfigManager: :param config: app model config args """ if not config.get("speech_to_text"): - config["speech_to_text"] = { - "enabled": False - } + config["speech_to_text"] = {"enabled": False} if not isinstance(config["speech_to_text"], dict): raise ValueError("speech_to_text must be of dict type") diff --git a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py index c6cab01220..9ac5114d12 100644 --- a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py +++ b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py @@ -7,9 +7,9 @@ class SuggestedQuestionsAfterAnswerConfigManager: :param config: model config args """ suggested_questions_after_answer = False - suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer') + suggested_questions_after_answer_dict = config.get("suggested_questions_after_answer") if suggested_questions_after_answer_dict: - if suggested_questions_after_answer_dict.get('enabled'): + if suggested_questions_after_answer_dict.get("enabled"): suggested_questions_after_answer = True return suggested_questions_after_answer @@ -22,15 +22,15 @@ class SuggestedQuestionsAfterAnswerConfigManager: :param config: app model config args """ if not config.get("suggested_questions_after_answer"): - config["suggested_questions_after_answer"] = { - "enabled": False - } + config["suggested_questions_after_answer"] = {"enabled": False} if not isinstance(config["suggested_questions_after_answer"], dict): raise ValueError("suggested_questions_after_answer must be of dict type") - if "enabled" not in config["suggested_questions_after_answer"] or not \ - config["suggested_questions_after_answer"]["enabled"]: + if ( + "enabled" not in config["suggested_questions_after_answer"] + or not config["suggested_questions_after_answer"]["enabled"] + ): config["suggested_questions_after_answer"]["enabled"] = False if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool): diff --git a/api/core/app/app_config/features/text_to_speech/manager.py b/api/core/app/app_config/features/text_to_speech/manager.py index f11e268e73..1c75981785 100644 --- a/api/core/app/app_config/features/text_to_speech/manager.py +++ b/api/core/app/app_config/features/text_to_speech/manager.py @@ -10,13 +10,13 @@ class TextToSpeechConfigManager: :param config: model config args """ text_to_speech = None - text_to_speech_dict = config.get('text_to_speech') + text_to_speech_dict = config.get("text_to_speech") if text_to_speech_dict: - if text_to_speech_dict.get('enabled'): + if text_to_speech_dict.get("enabled"): text_to_speech = TextToSpeechEntity( - enabled=text_to_speech_dict.get('enabled'), - voice=text_to_speech_dict.get('voice'), - language=text_to_speech_dict.get('language'), + enabled=text_to_speech_dict.get("enabled"), + voice=text_to_speech_dict.get("voice"), + language=text_to_speech_dict.get("language"), ) return text_to_speech @@ -29,11 +29,7 @@ class TextToSpeechConfigManager: :param config: app model config args """ if not config.get("text_to_speech"): - config["text_to_speech"] = { - "enabled": False, - "voice": "", - "language": "" - } + config["text_to_speech"] = {"enabled": False, "voice": "", "language": ""} if not isinstance(config["text_to_speech"], dict): raise ValueError("text_to_speech must be of dict type") diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py index c3d0e8ba03..b52f235849 100644 --- a/api/core/app/apps/advanced_chat/app_config_manager.py +++ b/api/core/app/apps/advanced_chat/app_config_manager.py @@ -1,4 +1,3 @@ - from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.entities import WorkflowUIBasedAppConfig @@ -19,13 +18,13 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig): """ Advanced Chatbot App Config Entity. """ + pass class AdvancedChatAppConfigManager(BaseAppConfigManager): @classmethod - def get_app_config(cls, app_model: App, - workflow: Workflow) -> AdvancedChatAppConfig: + def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig: features_dict = workflow.features_dict app_mode = AppMode.value_of(app_model.mode) @@ -34,13 +33,9 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager): app_id=app_model.id, app_mode=app_mode, workflow_id=workflow.id, - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=features_dict - ), - variables=WorkflowVariablesConfigManager.convert( - workflow=workflow - ), - additional_features=cls.convert_features(features_dict, app_mode) + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict), + variables=WorkflowVariablesConfigManager.convert(workflow=workflow), + additional_features=cls.convert_features(features_dict, app_mode), ) return app_config @@ -58,8 +53,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager): # file upload validation config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( - config=config, - is_vision=False + config=config, is_vision=False ) related_config_keys.extend(current_related_config_keys) @@ -69,7 +63,8 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager): # suggested_questions_after_answer config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( - config) + config + ) related_config_keys.extend(current_related_config_keys) # speech_to_text @@ -86,9 +81,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager): # moderation validation config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( - tenant_id=tenant_id, - config=config, - only_structure_validate=only_structure_validate + tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate ) related_config_keys.extend(current_related_config_keys) @@ -98,4 +91,3 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager): filtered_config = {key: config.get(key) for key in related_config_keys} return filtered_config - diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 638cc07461..1277dcebc5 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -34,7 +34,8 @@ logger = logging.getLogger(__name__) class AdvancedChatAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], args: dict, @@ -44,7 +45,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], args: dict, @@ -53,14 +55,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): ) -> dict: ... def generate( - self, - app_model: App, - workflow: Workflow, - user: Union[Account, EndUser], - args: dict, - invoke_from: InvokeFrom, - stream: bool = True, - ) -> dict[str, Any] | Generator[str, Any, None]: + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: bool = True, + ) -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. @@ -71,44 +73,37 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :param invoke_from: invoke from source :param stream: is stream """ - if not args.get('query'): - raise ValueError('query is required') + if not args.get("query"): + raise ValueError("query is required") - query = args['query'] + query = args["query"] if not isinstance(query, str): - raise ValueError('query must be a string') + raise ValueError("query must be a string") - query = query.replace('\x00', '') - inputs = args['inputs'] + query = query.replace("\x00", "") + inputs = args["inputs"] - extras = { - "auto_generate_conversation_name": args.get('auto_generate_name', False) - } + extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)} # get conversation conversation = None - conversation_id = args.get('conversation_id') + conversation_id = args.get("conversation_id") if conversation_id: - conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user) + conversation = self._get_conversation_by_user( + app_model=app_model, conversation_id=conversation_id, user=user + ) # parse files - files = args['files'] if args.get('files') else [] + files = args["files"] if args.get("files") else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) else: file_objs = [] # convert to app config - app_config = AdvancedChatAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) # get tracing instance user_id = user.id if isinstance(user, Account) else user.session_id @@ -130,7 +125,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): stream=stream, invoke_from=invoke_from, extras=extras, - trace_manager=trace_manager + trace_manager=trace_manager, ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -140,16 +135,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): invoke_from=invoke_from, application_generate_entity=application_generate_entity, conversation=conversation, - stream=stream + stream=stream, ) - def single_iteration_generate(self, app_model: App, - workflow: Workflow, - node_id: str, - user: Account, - args: dict, - stream: bool = True) \ - -> dict[str, Any] | Generator[str, Any, None]: + def single_iteration_generate( + self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True + ) -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. @@ -161,16 +152,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :param stream: is stream """ if not node_id: - raise ValueError('node_id is required') + raise ValueError("node_id is required") - if args.get('inputs') is None: - raise ValueError('inputs is required') + if args.get("inputs") is None: + raise ValueError("inputs is required") # convert to app config - app_config = AdvancedChatAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) # init application generate entity application_generate_entity = AdvancedChatAppGenerateEntity( @@ -178,18 +166,15 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app_config=app_config, conversation_id=None, inputs={}, - query='', + query="", files=[], user_id=user.id, stream=stream, invoke_from=InvokeFrom.DEBUGGER, - extras={ - "auto_generate_conversation_name": False - }, + extras={"auto_generate_conversation_name": False}, single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity( - node_id=node_id, - inputs=args['inputs'] - ) + node_id=node_id, inputs=args["inputs"] + ), ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -199,17 +184,19 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, conversation=None, - stream=stream + stream=stream, ) - def _generate(self, *, - workflow: Workflow, - user: Union[Account, EndUser], - invoke_from: InvokeFrom, - application_generate_entity: AdvancedChatAppGenerateEntity, - conversation: Optional[Conversation] = None, - stream: bool = True) \ - -> dict[str, Any] | Generator[str, Any, None]: + def _generate( + self, + *, + workflow: Workflow, + user: Union[Account, EndUser], + invoke_from: InvokeFrom, + application_generate_entity: AdvancedChatAppGenerateEntity, + conversation: Optional[Conversation] = None, + stream: bool = True, + ) -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. @@ -225,10 +212,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): is_first_conversation = True # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity, conversation) + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) if is_first_conversation: # update conversation features @@ -243,18 +227,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), # type: ignore - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'conversation_id': conversation.id, - 'message_id': message.id, - 'context': contextvars.copy_context(), - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + "context": contextvars.copy_context(), + }, + ) worker_thread.start() @@ -269,17 +256,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): stream=stream, ) - return AdvancedChatAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) - def _generate_worker(self, flask_app: Flask, - application_generate_entity: AdvancedChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation_id: str, - message_id: str, - context: contextvars.Context) -> None: + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str, + context: contextvars.Context, + ) -> None: """ Generate worker in a new thread. :param flask_app: Flask app @@ -302,7 +289,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): application_generate_entity=application_generate_entity, queue_manager=queue_manager, conversation=conversation, - message=message + message=message, ) runner.run() @@ -310,14 +297,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG", "false").lower() == 'true': + if os.environ.get("DEBUG", "false").lower() == "true": logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: diff --git a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py index 0caff4a2e3..d9fc599542 100644 --- a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py +++ b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py @@ -25,10 +25,7 @@ def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str): if not text_content or text_content.isspace(): return return model_instance.invoke_tts( - content_text=text_content.strip(), - user="responding_tts", - tenant_id=tenant_id, - voice=voice + content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice ) @@ -44,28 +41,26 @@ def _process_future(future_queue, audio_queue): except Exception as e: logging.getLogger(__name__).warning(e) break - audio_queue.put(AudioTrunk("finish", b'')) + audio_queue.put(AudioTrunk("finish", b"")) class AppGeneratorTTSPublisher: - def __init__(self, tenant_id: str, voice: str): self.logger = logging.getLogger(__name__) self.tenant_id = tenant_id - self.msg_text = '' + self.msg_text = "" self._audio_queue = queue.Queue() self._msg_queue = queue.Queue() - self.match = re.compile(r'[。.!?]') + self.match = re.compile(r"[。.!?]") self.model_manager = ModelManager() self.model_instance = self.model_manager.get_default_model_instance( - tenant_id=self.tenant_id, - model_type=ModelType.TTS + tenant_id=self.tenant_id, model_type=ModelType.TTS ) self.voices = self.model_instance.get_tts_voices() - values = [voice.get('value') for voice in self.voices] + values = [voice.get("value") for voice in self.voices] self.voice = voice if not voice or voice not in values: - self.voice = self.voices[0].get('value') + self.voice = self.voices[0].get("value") self.MAX_SENTENCE = 2 self._last_audio_event = None self._runtime_thread = threading.Thread(target=self._runtime).start() @@ -85,8 +80,9 @@ class AppGeneratorTTSPublisher: message = self._msg_queue.get() if message is None: if self.msg_text and len(self.msg_text.strip()) > 0: - futures_result = self.executor.submit(_invoiceTTS, self.msg_text, - self.model_instance, self.tenant_id, self.voice) + futures_result = self.executor.submit( + _invoiceTTS, self.msg_text, self.model_instance, self.tenant_id, self.voice + ) future_queue.put(futures_result) break elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent): @@ -94,21 +90,20 @@ class AppGeneratorTTSPublisher: elif isinstance(message.event, QueueTextChunkEvent): self.msg_text += message.event.text elif isinstance(message.event, QueueNodeSucceededEvent): - self.msg_text += message.event.outputs.get('output', '') + self.msg_text += message.event.outputs.get("output", "") self.last_message = message sentence_arr, text_tmp = self._extract_sentence(self.msg_text) if len(sentence_arr) >= min(self.MAX_SENTENCE, 7): self.MAX_SENTENCE += 1 - text_content = ''.join(sentence_arr) - futures_result = self.executor.submit(_invoiceTTS, text_content, - self.model_instance, - self.tenant_id, - self.voice) + text_content = "".join(sentence_arr) + futures_result = self.executor.submit( + _invoiceTTS, text_content, self.model_instance, self.tenant_id, self.voice + ) future_queue.put(futures_result) if text_tmp: self.msg_text = text_tmp else: - self.msg_text = '' + self.msg_text = "" except Exception as e: self.logger.warning(e) diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 4da3d093d2..90f547b0f2 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -38,11 +38,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): """ def __init__( - self, - application_generate_entity: AdvancedChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message + self, + application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, ) -> None: """ :param application_generate_entity: application generate entity @@ -66,11 +66,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: - raise ValueError('App not found') + raise ValueError("App not found") workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) if not workflow: - raise ValueError('Workflow not initialized') + raise ValueError("Workflow not initialized") user_id = None if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: @@ -81,7 +81,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): user_id = self.application_generate_entity.user_id workflow_callbacks: list[WorkflowCallback] = [] - if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): + if bool(os.environ.get("DEBUG", "False").lower() == "true"): workflow_callbacks.append(WorkflowLoggingCallback()) if self.application_generate_entity.single_iteration_run: @@ -89,7 +89,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( workflow=workflow, node_id=self.application_generate_entity.single_iteration_run.node_id, - user_inputs=self.application_generate_entity.single_iteration_run.inputs + user_inputs=self.application_generate_entity.single_iteration_run.inputs, ) else: inputs = self.application_generate_entity.inputs @@ -98,26 +98,27 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): # moderation if self.handle_input_moderation( - app_record=app_record, - app_generate_entity=self.application_generate_entity, - inputs=inputs, - query=query, - message_id=self.message.id + app_record=app_record, + app_generate_entity=self.application_generate_entity, + inputs=inputs, + query=query, + message_id=self.message.id, ): return # annotation reply if self.handle_annotation_reply( - app_record=app_record, - message=self.message, - query=query, - app_generate_entity=self.application_generate_entity + app_record=app_record, + message=self.message, + query=query, + app_generate_entity=self.application_generate_entity, ): return # Init conversation variables stmt = select(ConversationVariable).where( - ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id + ConversationVariable.app_id == self.conversation.app_id, + ConversationVariable.conversation_id == self.conversation.id, ) with Session(db.engine) as session: conversation_variables = session.scalars(stmt).all() @@ -190,12 +191,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): self._handle_event(workflow_entry, event) def handle_input_moderation( - self, - app_record: App, - app_generate_entity: AdvancedChatAppGenerateEntity, - inputs: Mapping[str, Any], - query: str, - message_id: str + self, + app_record: App, + app_generate_entity: AdvancedChatAppGenerateEntity, + inputs: Mapping[str, Any], + query: str, + message_id: str, ) -> bool: """ Handle input moderation @@ -217,18 +218,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): message_id=message_id, ) except ModerationException as e: - self._complete_with_stream_output( - text=str(e), - stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION - ) + self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION) return True return False - def handle_annotation_reply(self, app_record: App, - message: Message, - query: str, - app_generate_entity: AdvancedChatAppGenerateEntity) -> bool: + def handle_annotation_reply( + self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity + ) -> bool: """ Handle annotation reply :param app_record: app record @@ -246,32 +243,21 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): ) if annotation_reply: - self._publish_event( - QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id) - ) + self._publish_event(QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id)) self._complete_with_stream_output( - text=annotation_reply.content, - stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY + text=annotation_reply.content, stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY ) return True return False - def _complete_with_stream_output(self, - text: str, - stopped_by: QueueStopEvent.StopBy) -> None: + def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None: """ Direct output :param text: text :return: """ - self._publish_event( - QueueTextChunkEvent( - text=text - ) - ) + self._publish_event(QueueTextChunkEvent(text=text)) - self._publish_event( - QueueStopEvent(stopped_by=stopped_by) - ) + self._publish_event(QueueStopEvent(stopped_by=stopped_by)) diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py index ef579827b4..5fbd3e9a94 100644 --- a/api/core/app/apps/advanced_chat/generate_response_converter.py +++ b/api/core/app/apps/advanced_chat/generate_response_converter.py @@ -28,15 +28,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): """ blocking_response = cast(ChatbotAppBlockingResponse, blocking_response) response = { - 'event': 'message', - 'task_id': blocking_response.task_id, - 'id': blocking_response.data.id, - 'message_id': blocking_response.data.message_id, - 'conversation_id': blocking_response.data.conversation_id, - 'mode': blocking_response.data.mode, - 'answer': blocking_response.data.answer, - 'metadata': blocking_response.data.metadata, - 'created_at': blocking_response.data.created_at + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "conversation_id": blocking_response.data.conversation_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, } return response @@ -50,13 +50,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): """ response = cls.convert_blocking_full_response(blocking_response) - metadata = response.get('metadata', {}) - response['metadata'] = cls._get_simple_metadata(metadata) + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) return response @classmethod - def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]: + def convert_stream_full_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[str, Any, None]: """ Convert stream full response. :param stream_response: stream response @@ -67,14 +69,14 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -85,7 +87,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[str, Any, None]: """ Convert stream simple response. :param stream_response: stream response @@ -96,20 +100,20 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, MessageEndStreamResponse): sub_stream_response_dict = sub_stream_response.to_dict() - metadata = sub_stream_response_dict.get('metadata', {}) - sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index fb013cd1b1..f3e1a49cc2 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -65,6 +65,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ + _task_state: WorkflowTaskState _application_generate_entity: AdvancedChatAppGenerateEntity _workflow: Workflow @@ -72,14 +73,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _workflow_system_variables: dict[SystemVariableKey, Any] def __init__( - self, - application_generate_entity: AdvancedChatAppGenerateEntity, - workflow: Workflow, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool, + self, + application_generate_entity: AdvancedChatAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool, ) -> None: """ Initialize AdvancedChatAppGenerateTaskPipeline. @@ -123,13 +124,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, - self._application_generate_entity.query + self._conversation, self._application_generate_entity.query ) - generator = self._wrapper_process_stream_response( - trace_manager=self._application_generate_entity.trace_manager - ) + generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) if self._stream: return self._to_stream_response(generator) @@ -147,7 +145,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc elif isinstance(stream_response, MessageEndStreamResponse): extras = {} if stream_response.metadata: - extras['metadata'] = stream_response.metadata + extras["metadata"] = stream_response.metadata return ChatbotAppBlockingResponse( task_id=stream_response.task_id, @@ -158,15 +156,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc message_id=self._message.id, answer=self._task_state.answer, created_at=int(self._message.created_at.timestamp()), - **extras - ) + **extras, + ), ) else: continue - raise Exception('Queue listening stopped unexpectedly.') + raise Exception("Queue listening stopped unexpectedly.") - def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]: + def _to_stream_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Generator[ChatbotAppStreamResponse, Any, None]: """ To stream response. :return: @@ -176,7 +176,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc conversation_id=self._conversation.id, message_id=self._message.id, created_at=int(self._message.created_at.timestamp()), - stream_response=stream_response + stream_response=stream_response, ) def _listenAudioMsg(self, publisher, task_id: str): @@ -187,17 +187,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None - def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ - Generator[StreamResponse, None, None]: - + def _wrapper_process_stream_response( + self, trace_manager: Optional[TraceQueueManager] = None + ) -> Generator[StreamResponse, None, None]: tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id features_dict = self._workflow.features_dict - if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ - 'text_to_speech'].get('autoPlay') == 'enabled': - tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) + if ( + features_dict.get("text_to_speech") + and features_dict["text_to_speech"].get("enabled") + and features_dict["text_to_speech"].get("autoPlay") == "enabled" + ): + tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice")) for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: @@ -228,12 +231,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc except Exception as e: logger.error(e) break - yield MessageAudioEndStreamResponse(audio='', task_id=task_id) + yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( - self, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - trace_manager: Optional[TraceQueueManager] = None + self, + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, + trace_manager: Optional[TraceQueueManager] = None, ) -> Generator[StreamResponse, None, None]: """ Process stream response. @@ -267,22 +270,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc db.session.close() yield self._workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, QueueNodeStartedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") - workflow_node_execution = self._handle_node_execution_start( - workflow_run=workflow_run, - event=event - ) + workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) response = self._workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) if response: @@ -293,7 +292,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc response = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) if response: @@ -304,62 +303,52 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc response = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) if response: yield response elif isinstance(event, QueueParallelBranchRunStartedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_parallel_branch_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_parallel_branch_finished_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationStartEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_iteration_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationNextEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_iteration_next_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationCompletedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_iteration_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueWorkflowSucceededEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") if not graph_runtime_state: - raise Exception('Graph runtime state not initialized.') + raise Exception("Graph runtime state not initialized.") workflow_run = self._handle_workflow_run_success( workflow_run=workflow_run, @@ -372,20 +361,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc ) yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) - self._queue_manager.publish( - QueueAdvancedChatMessageEndEvent(), - PublishFrom.TASK_PIPELINE - ) + self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) elif isinstance(event, QueueWorkflowFailedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") if not graph_runtime_state: - raise Exception('Graph runtime state not initialized.') + raise Exception("Graph runtime state not initialized.") workflow_run = self._handle_workflow_run_failed( workflow_run=workflow_run, @@ -399,11 +384,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc ) yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) - err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) + err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) yield self._error_to_stream_response(self._handle_error(err_event, self._message)) break elif isinstance(event, QueueStopEvent): @@ -420,8 +404,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc ) yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) # Save message @@ -434,8 +417,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._refetch_message() - self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ - if self._task_state.metadata else None + self._message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) db.session.commit() db.session.refresh(self._message) @@ -445,8 +429,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._refetch_message() - self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ - if self._task_state.metadata else None + self._message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) db.session.commit() db.session.refresh(self._message) @@ -472,7 +457,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc yield self._message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueueAdvancedChatMessageEndEvent): if not graph_runtime_state: - raise Exception('Graph runtime state not initialized.') + raise Exception("Graph runtime state not initialized.") output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) if output_moderation_answer: @@ -502,8 +487,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._message.answer = self._task_state.answer self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ - if self._task_state.metadata else None + self._message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) if graph_runtime_state and graph_runtime_state.llm_usage: usage = graph_runtime_state.llm_usage @@ -523,7 +509,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc application_generate_entity=self._application_generate_entity, conversation=self._conversation, is_first_message=self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras + extras=self._application_generate_entity.extras, ) def _message_end_to_stream_response(self) -> MessageEndStreamResponse: @@ -533,15 +519,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc """ extras = {} if self._task_state.metadata: - extras['metadata'] = self._task_state.metadata.copy() + extras["metadata"] = self._task_state.metadata.copy() - if 'annotation_reply' in extras['metadata']: - del extras['metadata']['annotation_reply'] + if "annotation_reply" in extras["metadata"]: + del extras["metadata"]["annotation_reply"] return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, - id=self._message.id, - **extras + task_id=self._application_generate_entity.task_id, id=self._message.id, **extras ) def _handle_output_moderation_chunk(self, text: str) -> bool: @@ -555,14 +539,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc # stop subscribe new token when output moderation should direct output self._task_state.answer = self._output_moderation_handler.get_final_output() self._queue_manager.publish( - QueueTextChunkEvent( - text=self._task_state.answer - ), PublishFrom.TASK_PIPELINE + QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE ) self._queue_manager.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), - PublishFrom.TASK_PIPELINE + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE ) return True else: diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index f495ebbf35..9040f18bfd 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -28,15 +28,19 @@ class AgentChatAppConfig(EasyUIBasedAppConfig): """ Agent Chatbot App Config Entity. """ + agent: Optional[AgentEntity] = None class AgentChatAppConfigManager(BaseAppConfigManager): @classmethod - def get_app_config(cls, app_model: App, - app_model_config: AppModelConfig, - conversation: Optional[Conversation] = None, - override_config_dict: Optional[dict] = None) -> AgentChatAppConfig: + def get_app_config( + cls, + app_model: App, + app_model_config: AppModelConfig, + conversation: Optional[Conversation] = None, + override_config_dict: Optional[dict] = None, + ) -> AgentChatAppConfig: """ Convert app model config to agent chat app config :param app_model: app model @@ -66,22 +70,12 @@ class AgentChatAppConfigManager(BaseAppConfigManager): app_model_config_from=config_from, app_model_config_id=app_model_config.id, app_model_config_dict=config_dict, - model=ModelConfigManager.convert( - config=config_dict - ), - prompt_template=PromptTemplateConfigManager.convert( - config=config_dict - ), - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=config_dict - ), - dataset=DatasetConfigManager.convert( - config=config_dict - ), - agent=AgentConfigManager.convert( - config=config_dict - ), - additional_features=cls.convert_features(config_dict, app_mode) + model=ModelConfigManager.convert(config=config_dict), + prompt_template=PromptTemplateConfigManager.convert(config=config_dict), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), + dataset=DatasetConfigManager.convert(config=config_dict), + agent=AgentConfigManager.convert(config=config_dict), + additional_features=cls.convert_features(config_dict, app_mode), ) app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( @@ -128,7 +122,8 @@ class AgentChatAppConfigManager(BaseAppConfigManager): # suggested_questions_after_answer config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( - config) + config + ) related_config_keys.extend(current_related_config_keys) # speech_to_text @@ -145,13 +140,15 @@ class AgentChatAppConfigManager(BaseAppConfigManager): # dataset configs # dataset_query_variable - config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, - config) + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults( + tenant_id, app_mode, config + ) related_config_keys.extend(current_related_config_keys) # moderation validation - config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, - config) + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id, config + ) related_config_keys.extend(current_related_config_keys) related_config_keys = list(set(related_config_keys)) @@ -170,10 +167,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): :param config: app model config args """ if not config.get("agent_mode"): - config["agent_mode"] = { - "enabled": False, - "tools": [] - } + config["agent_mode"] = {"enabled": False, "tools": []} if not isinstance(config["agent_mode"], dict): raise ValueError("agent_mode must be of object type") @@ -187,8 +181,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager): if not config["agent_mode"].get("strategy"): config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value - if config["agent_mode"]["strategy"] not in [member.value for member in - list(PlanningStrategy.__members__.values())]: + if config["agent_mode"]["strategy"] not in [ + member.value for member in list(PlanningStrategy.__members__.values()) + ]: raise ValueError("strategy in agent_mode must be in the specified strategy list") if not config["agent_mode"].get("tools"): @@ -210,7 +205,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): raise ValueError("enabled in agent_mode.tools must be of boolean type") if key == "dataset": - if 'id' not in tool_item: + if "id" not in tool_item: raise ValueError("id is required in dataset") try: diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index daf37f4a50..7ba6bbab94 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -30,7 +30,8 @@ logger = logging.getLogger(__name__) class AgentChatAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, @@ -39,19 +40,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, stream: Literal[False] = False, ) -> dict: ... - def generate(self, app_model: App, - user: Union[Account, EndUser], - args: Any, - invoke_from: InvokeFrom, - stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + def generate( + self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True + ) -> Union[dict, Generator[dict, None, None]]: """ Generate App response. @@ -62,60 +61,48 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): :param stream: is stream """ if not stream: - raise ValueError('Agent Chat App does not support blocking mode') + raise ValueError("Agent Chat App does not support blocking mode") - if not args.get('query'): - raise ValueError('query is required') + if not args.get("query"): + raise ValueError("query is required") - query = args['query'] + query = args["query"] if not isinstance(query, str): - raise ValueError('query must be a string') + raise ValueError("query must be a string") - query = query.replace('\x00', '') - inputs = args['inputs'] + query = query.replace("\x00", "") + inputs = args["inputs"] - extras = { - "auto_generate_conversation_name": args.get('auto_generate_name', True) - } + extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)} # get conversation conversation = None - if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + if args.get("conversation_id"): + conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user) # get app model config - app_model_config = self._get_app_model_config( - app_model=app_model, - conversation=conversation - ) + app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) # validate override model config override_model_config_dict = None - if args.get('model_config'): + if args.get("model_config"): if invoke_from != InvokeFrom.DEBUGGER: - raise ValueError('Only in App debug mode can override model config') + raise ValueError("Only in App debug mode can override model config") # validate config override_model_config_dict = AgentChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=args.get('model_config') + tenant_id=app_model.tenant_id, config=args.get("model_config") ) # always enable retriever resource in debugger mode - override_model_config_dict["retriever_resource"] = { - "enabled": True - } + override_model_config_dict["retriever_resource"] = {"enabled": True} # parse files - files = args['files'] if args.get('files') else [] + files = args["files"] if args.get("files") else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) else: file_objs = [] @@ -124,7 +111,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): app_model=app_model, app_model_config=app_model_config, conversation=conversation, - override_config_dict=override_model_config_dict + override_config_dict=override_model_config_dict, ) # get tracing instance @@ -145,14 +132,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): invoke_from=invoke_from, extras=extras, call_depth=0, - trace_manager=trace_manager + trace_manager=trace_manager, ) # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity, conversation) + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -161,17 +145,20 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'conversation_id': conversation.id, - 'message_id': message.id, - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + }, + ) worker_thread.start() @@ -185,13 +172,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): stream=stream, ) - return AgentChatAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( - self, flask_app: Flask, + self, + flask_app: Flask, application_generate_entity: AgentChatAppGenerateEntity, queue_manager: AppQueueManager, conversation_id: str, @@ -224,14 +209,13 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true": logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index d1bbf679c5..6b676b0353 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -30,7 +30,8 @@ class AgentChatAppRunner(AppRunner): """ def run( - self, application_generate_entity: AgentChatAppGenerateEntity, + self, + application_generate_entity: AgentChatAppGenerateEntity, queue_manager: AppQueueManager, conversation: Conversation, message: Message, @@ -65,7 +66,7 @@ class AgentChatAppRunner(AppRunner): prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, - query=query + query=query, ) memory = None @@ -73,13 +74,10 @@ class AgentChatAppRunner(AppRunner): # get memory of conversation (read-only) model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) # organize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) @@ -91,7 +89,7 @@ class AgentChatAppRunner(AppRunner): inputs=inputs, files=files, query=query, - memory=memory + memory=memory, ) # moderation @@ -103,7 +101,7 @@ class AgentChatAppRunner(AppRunner): app_generate_entity=application_generate_entity, inputs=inputs, query=query, - message_id=message.id + message_id=message.id, ) except ModerationException as e: self.direct_output( @@ -111,7 +109,7 @@ class AgentChatAppRunner(AppRunner): app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -122,13 +120,13 @@ class AgentChatAppRunner(AppRunner): message=message, query=query, user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from + invoke_from=application_generate_entity.invoke_from, ) if annotation_reply: queue_manager.publish( QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), - PublishFrom.APPLICATION_MANAGER + PublishFrom.APPLICATION_MANAGER, ) self.direct_output( @@ -136,7 +134,7 @@ class AgentChatAppRunner(AppRunner): app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=annotation_reply.content, - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -148,7 +146,7 @@ class AgentChatAppRunner(AppRunner): app_id=app_record.id, external_data_tools=external_data_tools, inputs=inputs, - query=query + query=query, ) # reorganize all inputs and template to prompt messages @@ -161,14 +159,14 @@ class AgentChatAppRunner(AppRunner): inputs=inputs, files=files, query=query, - memory=memory + memory=memory, ) # check hosting moderation hosting_moderation_result = self.check_hosting_moderation( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - prompt_messages=prompt_messages + prompt_messages=prompt_messages, ) if hosting_moderation_result: @@ -177,9 +175,9 @@ class AgentChatAppRunner(AppRunner): agent_entity = app_config.agent # load tool variables - tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id, - user_id=application_generate_entity.user_id, - tenant_id=app_config.tenant_id) + tool_conversation_variables = self._load_tool_variables( + conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id + ) # convert db variables to tool variables tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables) @@ -187,7 +185,7 @@ class AgentChatAppRunner(AppRunner): # init model instance model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) prompt_message, _ = self.organize_prompt_messages( app_record=app_record, @@ -238,7 +236,7 @@ class AgentChatAppRunner(AppRunner): prompt_messages=prompt_message, variables_pool=tool_variables, db_variables=tool_conversation_variables, - model_instance=model_instance + model_instance=model_instance, ) invoke_result = runner.run( @@ -252,17 +250,21 @@ class AgentChatAppRunner(AppRunner): invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream, - agent=True + agent=True, ) def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables: """ load tool variables from database """ - tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter( - ToolConversationVariables.conversation_id == conversation_id, - ToolConversationVariables.tenant_id == tenant_id - ).first() + tool_variables: ToolConversationVariables = ( + db.session.query(ToolConversationVariables) + .filter( + ToolConversationVariables.conversation_id == conversation_id, + ToolConversationVariables.tenant_id == tenant_id, + ) + .first() + ) if tool_variables: # save tool variables to session, so that we can update it later @@ -273,34 +275,40 @@ class AgentChatAppRunner(AppRunner): conversation_id=conversation_id, user_id=user_id, tenant_id=tenant_id, - variables_str='[]', + variables_str="[]", ) db.session.add(tool_variables) db.session.commit() return tool_variables - - def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool: + + def _convert_db_variables_to_tool_variables( + self, db_variables: ToolConversationVariables + ) -> ToolRuntimeVariablePool: """ convert db variables to tool variables """ - return ToolRuntimeVariablePool(**{ - 'conversation_id': db_variables.conversation_id, - 'user_id': db_variables.user_id, - 'tenant_id': db_variables.tenant_id, - 'pool': db_variables.variables - }) + return ToolRuntimeVariablePool( + **{ + "conversation_id": db_variables.conversation_id, + "user_id": db_variables.user_id, + "tenant_id": db_variables.tenant_id, + "pool": db_variables.variables, + } + ) - def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredentialsEntity, - message: Message) -> LLMUsage: + def _get_usage_of_all_agent_thoughts( + self, model_config: ModelConfigWithCredentialsEntity, message: Message + ) -> LLMUsage: """ Get usage of all agent thoughts :param model_config: model config :param message: message :return: """ - agent_thoughts = (db.session.query(MessageAgentThought) - .filter(MessageAgentThought.message_id == message.id).all()) + agent_thoughts = ( + db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all() + ) all_message_tokens = 0 all_answer_tokens = 0 @@ -312,8 +320,5 @@ class AgentChatAppRunner(AppRunner): model_type_instance = cast(LargeLanguageModel, model_type_instance) return model_type_instance._calc_response_usage( - model_config.model, - model_config.credentials, - all_message_tokens, - all_answer_tokens + model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens ) diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index 118d82c495..629c309c06 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -23,15 +23,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): :return: """ response = { - 'event': 'message', - 'task_id': blocking_response.task_id, - 'id': blocking_response.data.id, - 'message_id': blocking_response.data.message_id, - 'conversation_id': blocking_response.data.conversation_id, - 'mode': blocking_response.data.mode, - 'answer': blocking_response.data.answer, - 'metadata': blocking_response.data.metadata, - 'created_at': blocking_response.data.created_at + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "conversation_id": blocking_response.data.conversation_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, } return response @@ -45,14 +45,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): """ response = cls.convert_blocking_full_response(blocking_response) - metadata = response.get('metadata', {}) - response['metadata'] = cls._get_simple_metadata(metadata) + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) return response @classmethod - def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_full_response( + cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -63,14 +64,14 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -81,8 +82,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream simple response. :param stream_response: stream response @@ -93,20 +95,20 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, MessageEndStreamResponse): sub_stream_response_dict = sub_stream_response.to_dict() - metadata = sub_stream_response_dict.get('metadata', {}) - sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index a196d36be5..73025d99d0 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -13,32 +13,33 @@ class AppGenerateResponseConverter(ABC): _blocking_response_type: type[AppBlockingResponse] @classmethod - def convert(cls, response: Union[ - AppBlockingResponse, - Generator[AppStreamResponse, Any, None] - ], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]: + def convert( + cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom + ) -> dict[str, Any] | Generator[str, Any, None]: if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: if isinstance(response, AppBlockingResponse): return cls.convert_blocking_full_response(response) else: + def _generate_full_response() -> Generator[str, Any, None]: for chunk in cls.convert_stream_full_response(response): - if chunk == 'ping': - yield f'event: {chunk}\n\n' + if chunk == "ping": + yield f"event: {chunk}\n\n" else: - yield f'data: {chunk}\n\n' + yield f"data: {chunk}\n\n" return _generate_full_response() else: if isinstance(response, AppBlockingResponse): return cls.convert_blocking_simple_response(response) else: + def _generate_simple_response() -> Generator[str, Any, None]: for chunk in cls.convert_stream_simple_response(response): - if chunk == 'ping': - yield f'event: {chunk}\n\n' + if chunk == "ping": + yield f"event: {chunk}\n\n" else: - yield f'data: {chunk}\n\n' + yield f"data: {chunk}\n\n" return _generate_simple_response() @@ -54,14 +55,16 @@ class AppGenerateResponseConverter(ABC): @classmethod @abstractmethod - def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_full_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[str, None, None]: raise NotImplementedError @classmethod @abstractmethod - def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[str, None, None]: raise NotImplementedError @classmethod @@ -72,24 +75,26 @@ class AppGenerateResponseConverter(ABC): :return: """ # show_retrieve_source - if 'retriever_resources' in metadata: - metadata['retriever_resources'] = [] - for resource in metadata['retriever_resources']: - metadata['retriever_resources'].append({ - 'segment_id': resource['segment_id'], - 'position': resource['position'], - 'document_name': resource['document_name'], - 'score': resource['score'], - 'content': resource['content'], - }) + if "retriever_resources" in metadata: + metadata["retriever_resources"] = [] + for resource in metadata["retriever_resources"]: + metadata["retriever_resources"].append( + { + "segment_id": resource["segment_id"], + "position": resource["position"], + "document_name": resource["document_name"], + "score": resource["score"], + "content": resource["content"], + } + ) # show annotation reply - if 'annotation_reply' in metadata: - del metadata['annotation_reply'] + if "annotation_reply" in metadata: + del metadata["annotation_reply"] # show usage - if 'usage' in metadata: - del metadata['usage'] + if "usage" in metadata: + del metadata["usage"] return metadata @@ -101,16 +106,16 @@ class AppGenerateResponseConverter(ABC): :return: """ error_responses = { - ValueError: {'code': 'invalid_param', 'status': 400}, - ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400}, + ValueError: {"code": "invalid_param", "status": 400}, + ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400}, QuotaExceededError: { - 'code': 'provider_quota_exceeded', - 'message': "Your quota for Dify Hosted Model Provider has been exhausted. " - "Please go to Settings -> Model Provider to complete your own provider credentials.", - 'status': 400 + "code": "provider_quota_exceeded", + "message": "Your quota for Dify Hosted Model Provider has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials.", + "status": 400, }, - ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400}, - InvokeError: {'code': 'completion_request_error', 'status': 400} + ModelCurrentlyNotSupportError: {"code": "model_currently_not_support", "status": 400}, + InvokeError: {"code": "completion_request_error", "status": 400}, } # Determine the response based on the type of exception @@ -120,13 +125,13 @@ class AppGenerateResponseConverter(ABC): data = v if data: - data.setdefault('message', getattr(e, 'description', str(e))) + data.setdefault("message", getattr(e, "description", str(e))) else: logging.error(e) data = { - 'code': 'internal_server_error', - 'message': 'Internal Server Error, please contact support.', - 'status': 500 + "code": "internal_server_error", + "message": "Internal Server Error, please contact support.", + "status": 500, } return data diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 9e331dff4d..ce6f7d4338 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -16,10 +16,10 @@ class BaseAppGenerator: def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): user_input_value = inputs.get(var.variable) if var.required and not user_input_value: - raise ValueError(f'{var.variable} is required in input form') + raise ValueError(f"{var.variable} is required in input form") if not var.required and not user_input_value: # TODO: should we return None here if the default value is None? - return var.default or '' + return var.default or "" if ( var.type in ( @@ -34,7 +34,7 @@ class BaseAppGenerator: if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str): # may raise ValueError if user_input_value is not a valid number try: - if '.' in user_input_value: + if "." in user_input_value: return float(user_input_value) else: return int(user_input_value) @@ -43,14 +43,14 @@ class BaseAppGenerator: if var.type == VariableEntityType.SELECT: options = var.options or [] if user_input_value not in options: - raise ValueError(f'{var.variable} in input form must be one of the following: {options}') + raise ValueError(f"{var.variable} in input form must be one of the following: {options}") elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): if var.max_length and user_input_value and len(user_input_value) > var.max_length: - raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters') + raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters") return user_input_value def _sanitize_value(self, value: Any) -> Any: if isinstance(value, str): - return value.replace('\x00', '') + return value.replace("\x00", "") return value diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index f929a979f1..df972756d5 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -24,9 +24,7 @@ class PublishFrom(Enum): class AppQueueManager: - def __init__(self, task_id: str, - user_id: str, - invoke_from: InvokeFrom) -> None: + def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None: if not user_id: raise ValueError("user is required") @@ -34,9 +32,10 @@ class AppQueueManager: self._user_id = user_id self._invoke_from = invoke_from - user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' - redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, - f"{user_prefix}-{self._user_id}") + user_prefix = "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user" + redis_client.setex( + AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}" + ) q = queue.Queue() @@ -66,8 +65,7 @@ class AppQueueManager: # publish two messages to make sure the client can receive the stop signal # and stop listening after the stop signal processed self.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), - PublishFrom.TASK_PIPELINE + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE ) if elapsed_time // 10 > last_ping_time: @@ -88,9 +86,7 @@ class AppQueueManager: :param pub_from: publish from :return: """ - self.publish(QueueErrorEvent( - error=e - ), pub_from) + self.publish(QueueErrorEvent(error=e), pub_from) def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: """ @@ -122,8 +118,8 @@ class AppQueueManager: if result is None: return - user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' - if result.decode('utf-8') != f"{user_prefix}-{user_id}": + user_prefix = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user" + if result.decode("utf-8") != f"{user_prefix}-{user_id}": return stopped_cache_key = cls._generate_stopped_cache_key(task_id) @@ -168,9 +164,11 @@ class AppQueueManager: for item in data: self._check_for_sqlalchemy_models(item) else: - if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'): - raise TypeError("Critical Error: Passing SQLAlchemy Model instances " - "that cause thread safety issues is not allowed.") + if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"): + raise TypeError( + "Critical Error: Passing SQLAlchemy Model instances " + "that cause thread safety issues is not allowed." + ) class GenerateTaskStoppedException(Exception): diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 60216959a8..aadb43ad39 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -31,12 +31,15 @@ if TYPE_CHECKING: class AppRunner: - def get_pre_calculate_rest_tokens(self, app_record: App, - model_config: ModelConfigWithCredentialsEntity, - prompt_template_entity: PromptTemplateEntity, - inputs: dict[str, str], - files: list["FileVar"], - query: Optional[str] = None) -> int: + def get_pre_calculate_rest_tokens( + self, + app_record: App, + model_config: ModelConfigWithCredentialsEntity, + prompt_template_entity: PromptTemplateEntity, + inputs: dict[str, str], + files: list["FileVar"], + query: Optional[str] = None, + ) -> int: """ Get pre calculate rest tokens :param app_record: app record @@ -49,18 +52,20 @@ class AppRunner: """ # Invoke model model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 if model_context_tokens is None: return -1 @@ -75,36 +80,39 @@ class AppRunner: prompt_template_entity=prompt_template_entity, inputs=inputs, files=files, - query=query + query=query, ) - prompt_tokens = model_instance.get_llm_num_tokens( - prompt_messages - ) + prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) rest_tokens = model_context_tokens - max_tokens - prompt_tokens if rest_tokens < 0: - raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, " - "or shrink the max token, or switch to a llm with a larger token limit size.") + raise InvokeBadRequestError( + "Query or prefix prompt is too long, you can reduce the prefix prompt, " + "or shrink the max token, or switch to a llm with a larger token limit size." + ) return rest_tokens - def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity, - prompt_messages: list[PromptMessage]): + def recalc_llm_max_tokens( + self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage] + ): # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 if model_context_tokens is None: return -1 @@ -112,27 +120,28 @@ class AppRunner: if max_tokens is None: max_tokens = 0 - prompt_tokens = model_instance.get_llm_num_tokens( - prompt_messages - ) + prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) if prompt_tokens + max_tokens > model_context_tokens: max_tokens = max(model_context_tokens - prompt_tokens, 16) for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): model_config.parameters[parameter_rule.name] = max_tokens - def organize_prompt_messages(self, app_record: App, - model_config: ModelConfigWithCredentialsEntity, - prompt_template_entity: PromptTemplateEntity, - inputs: dict[str, str], - files: list["FileVar"], - query: Optional[str] = None, - context: Optional[str] = None, - memory: Optional[TokenBufferMemory] = None) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def organize_prompt_messages( + self, + app_record: App, + model_config: ModelConfigWithCredentialsEntity, + prompt_template_entity: PromptTemplateEntity, + inputs: dict[str, str], + files: list["FileVar"], + query: Optional[str] = None, + context: Optional[str] = None, + memory: Optional[TokenBufferMemory] = None, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: """ Organize prompt messages :param context: @@ -152,60 +161,54 @@ class AppRunner: app_mode=AppMode.value_of(app_record.mode), prompt_template_entity=prompt_template_entity, inputs=inputs, - query=query if query else '', + query=query if query else "", files=files, context=context, memory=memory, - model_config=model_config + model_config=model_config, ) else: - memory_config = MemoryConfig( - window=MemoryConfig.WindowConfig( - enabled=False - ) - ) + memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) model_mode = ModelMode.value_of(model_config.mode) if model_mode == ModelMode.COMPLETION: advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template - prompt_template = CompletionModelPromptTemplate( - text=advanced_completion_prompt_template.prompt - ) + prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt) if advanced_completion_prompt_template.role_prefix: memory_config.role_prefix = MemoryConfig.RolePrefix( user=advanced_completion_prompt_template.role_prefix.user, - assistant=advanced_completion_prompt_template.role_prefix.assistant + assistant=advanced_completion_prompt_template.role_prefix.assistant, ) else: prompt_template = [] for message in prompt_template_entity.advanced_chat_prompt_template.messages: - prompt_template.append(ChatModelMessage( - text=message.text, - role=message.role - )) + prompt_template.append(ChatModelMessage(text=message.text, role=message.role)) prompt_transform = AdvancedPromptTransform() prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs=inputs, - query=query if query else '', + query=query if query else "", files=files, context=context, memory_config=memory_config, memory=memory, - model_config=model_config + model_config=model_config, ) stop = model_config.stop return prompt_messages, stop - def direct_output(self, queue_manager: AppQueueManager, - app_generate_entity: EasyUIBasedAppGenerateEntity, - prompt_messages: list, - text: str, - stream: bool, - usage: Optional[LLMUsage] = None) -> None: + def direct_output( + self, + queue_manager: AppQueueManager, + app_generate_entity: EasyUIBasedAppGenerateEntity, + prompt_messages: list, + text: str, + stream: bool, + usage: Optional[LLMUsage] = None, + ) -> None: """ Direct output :param queue_manager: application queue manager @@ -222,17 +225,10 @@ class AppRunner: chunk = LLMResultChunk( model=app_generate_entity.model_conf.model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=AssistantPromptMessage(content=token) - ) + delta=LLMResultChunkDelta(index=index, message=AssistantPromptMessage(content=token)), ) - queue_manager.publish( - QueueLLMChunkEvent( - chunk=chunk - ), PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER) index += 1 time.sleep(0.01) @@ -242,15 +238,19 @@ class AppRunner: model=app_generate_entity.model_conf.model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), - usage=usage if usage else LLMUsage.empty_usage() + usage=usage if usage else LLMUsage.empty_usage(), ), - ), PublishFrom.APPLICATION_MANAGER + ), + PublishFrom.APPLICATION_MANAGER, ) - def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], - queue_manager: AppQueueManager, - stream: bool, - agent: bool = False) -> None: + def _handle_invoke_result( + self, + invoke_result: Union[LLMResult, Generator], + queue_manager: AppQueueManager, + stream: bool, + agent: bool = False, + ) -> None: """ Handle invoke result :param invoke_result: invoke result @@ -260,21 +260,13 @@ class AppRunner: :return: """ if not stream: - self._handle_invoke_result_direct( - invoke_result=invoke_result, - queue_manager=queue_manager, - agent=agent - ) + self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) else: - self._handle_invoke_result_stream( - invoke_result=invoke_result, - queue_manager=queue_manager, - agent=agent - ) + self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) - def _handle_invoke_result_direct(self, invoke_result: LLMResult, - queue_manager: AppQueueManager, - agent: bool) -> None: + def _handle_invoke_result_direct( + self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool + ) -> None: """ Handle invoke result direct :param invoke_result: invoke result @@ -285,12 +277,13 @@ class AppRunner: queue_manager.publish( QueueMessageEndEvent( llm_result=invoke_result, - ), PublishFrom.APPLICATION_MANAGER + ), + PublishFrom.APPLICATION_MANAGER, ) - def _handle_invoke_result_stream(self, invoke_result: Generator, - queue_manager: AppQueueManager, - agent: bool) -> None: + def _handle_invoke_result_stream( + self, invoke_result: Generator, queue_manager: AppQueueManager, agent: bool + ) -> None: """ Handle invoke result :param invoke_result: invoke result @@ -300,21 +293,13 @@ class AppRunner: """ model = None prompt_messages = [] - text = '' + text = "" usage = None for result in invoke_result: if not agent: - queue_manager.publish( - QueueLLMChunkEvent( - chunk=result - ), PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER) else: - queue_manager.publish( - QueueAgentMessageEvent( - chunk=result - ), PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER) text += result.delta.message.content @@ -331,25 +316,24 @@ class AppRunner: usage = LLMUsage.empty_usage() llm_result = LLMResult( - model=model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage(content=text), - usage=usage + model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage ) queue_manager.publish( QueueMessageEndEvent( llm_result=llm_result, - ), PublishFrom.APPLICATION_MANAGER + ), + PublishFrom.APPLICATION_MANAGER, ) def moderation_for_inputs( - self, app_id: str, - tenant_id: str, - app_generate_entity: AppGenerateEntity, - inputs: Mapping[str, Any], - query: str, - message_id: str, + self, + app_id: str, + tenant_id: str, + app_generate_entity: AppGenerateEntity, + inputs: Mapping[str, Any], + query: str, + message_id: str, ) -> tuple[bool, dict, str]: """ Process sensitive_word_avoidance. @@ -367,14 +351,17 @@ class AppRunner: tenant_id=tenant_id, app_config=app_generate_entity.app_config, inputs=inputs, - query=query if query else '', + query=query if query else "", message_id=message_id, - trace_manager=app_generate_entity.trace_manager + trace_manager=app_generate_entity.trace_manager, ) - def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity, - queue_manager: AppQueueManager, - prompt_messages: list[PromptMessage]) -> bool: + def check_hosting_moderation( + self, + application_generate_entity: EasyUIBasedAppGenerateEntity, + queue_manager: AppQueueManager, + prompt_messages: list[PromptMessage], + ) -> bool: """ Check hosting moderation :param application_generate_entity: application generate entity @@ -384,8 +371,7 @@ class AppRunner: """ hosting_moderation_feature = HostingModerationFeature() moderation_result = hosting_moderation_feature.check( - application_generate_entity=application_generate_entity, - prompt_messages=prompt_messages + application_generate_entity=application_generate_entity, prompt_messages=prompt_messages ) if moderation_result: @@ -393,18 +379,20 @@ class AppRunner: queue_manager=queue_manager, app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, - text="I apologize for any confusion, " \ - "but I'm an AI assistant to be helpful, harmless, and honest.", - stream=application_generate_entity.stream + text="I apologize for any confusion, " "but I'm an AI assistant to be helpful, harmless, and honest.", + stream=application_generate_entity.stream, ) return moderation_result - def fill_in_inputs_from_external_data_tools(self, tenant_id: str, - app_id: str, - external_data_tools: list[ExternalDataVariableEntity], - inputs: dict, - query: str) -> dict: + def fill_in_inputs_from_external_data_tools( + self, + tenant_id: str, + app_id: str, + external_data_tools: list[ExternalDataVariableEntity], + inputs: dict, + query: str, + ) -> dict: """ Fill in variable inputs from external data tools if exists. @@ -417,18 +405,12 @@ class AppRunner: """ external_data_fetch_feature = ExternalDataFetch() return external_data_fetch_feature.fetch( - tenant_id=tenant_id, - app_id=app_id, - external_data_tools=external_data_tools, - inputs=inputs, - query=query + tenant_id=tenant_id, app_id=app_id, external_data_tools=external_data_tools, inputs=inputs, query=query ) - def query_app_annotations_to_reply(self, app_record: App, - message: Message, - query: str, - user_id: str, - invoke_from: InvokeFrom) -> Optional[MessageAnnotation]: + def query_app_annotations_to_reply( + self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom + ) -> Optional[MessageAnnotation]: """ Query app annotations to reply :param app_record: app record @@ -440,9 +422,5 @@ class AppRunner: """ annotation_reply_feature = AnnotationReplyFeature() return annotation_reply_feature.query( - app_record=app_record, - message=message, - query=query, - user_id=user_id, - invoke_from=invoke_from + app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from ) diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py index a286c349b2..96dc7dda79 100644 --- a/api/core/app/apps/chat/app_config_manager.py +++ b/api/core/app/apps/chat/app_config_manager.py @@ -22,15 +22,19 @@ class ChatAppConfig(EasyUIBasedAppConfig): """ Chatbot App Config Entity. """ + pass class ChatAppConfigManager(BaseAppConfigManager): @classmethod - def get_app_config(cls, app_model: App, - app_model_config: AppModelConfig, - conversation: Optional[Conversation] = None, - override_config_dict: Optional[dict] = None) -> ChatAppConfig: + def get_app_config( + cls, + app_model: App, + app_model_config: AppModelConfig, + conversation: Optional[Conversation] = None, + override_config_dict: Optional[dict] = None, + ) -> ChatAppConfig: """ Convert app model config to chat app config :param app_model: app model @@ -51,7 +55,7 @@ class ChatAppConfigManager(BaseAppConfigManager): config_dict = app_model_config_dict.copy() else: if not override_config_dict: - raise Exception('override_config_dict is required when config_from is ARGS') + raise Exception("override_config_dict is required when config_from is ARGS") config_dict = override_config_dict @@ -63,19 +67,11 @@ class ChatAppConfigManager(BaseAppConfigManager): app_model_config_from=config_from, app_model_config_id=app_model_config.id, app_model_config_dict=config_dict, - model=ModelConfigManager.convert( - config=config_dict - ), - prompt_template=PromptTemplateConfigManager.convert( - config=config_dict - ), - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=config_dict - ), - dataset=DatasetConfigManager.convert( - config=config_dict - ), - additional_features=cls.convert_features(config_dict, app_mode) + model=ModelConfigManager.convert(config=config_dict), + prompt_template=PromptTemplateConfigManager.convert(config=config_dict), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), + dataset=DatasetConfigManager.convert(config=config_dict), + additional_features=cls.convert_features(config_dict, app_mode), ) app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( @@ -113,8 +109,9 @@ class ChatAppConfigManager(BaseAppConfigManager): related_config_keys.extend(current_related_config_keys) # dataset_query_variable - config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, - config) + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults( + tenant_id, app_mode, config + ) related_config_keys.extend(current_related_config_keys) # opening_statement @@ -123,7 +120,8 @@ class ChatAppConfigManager(BaseAppConfigManager): # suggested_questions_after_answer config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( - config) + config + ) related_config_keys.extend(current_related_config_keys) # speech_to_text @@ -139,8 +137,9 @@ class ChatAppConfigManager(BaseAppConfigManager): related_config_keys.extend(current_related_config_keys) # moderation validation - config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, - config) + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id, config + ) related_config_keys.extend(current_related_config_keys) related_config_keys = list(set(related_config_keys)) diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index ab15928b74..15c7140308 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -30,7 +30,8 @@ logger = logging.getLogger(__name__) class ChatAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, @@ -39,7 +40,8 @@ class ChatAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, @@ -47,7 +49,8 @@ class ChatAppGenerator(MessageBasedAppGenerator): ) -> dict: ... def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, @@ -62,58 +65,46 @@ class ChatAppGenerator(MessageBasedAppGenerator): :param invoke_from: invoke from source :param stream: is stream """ - if not args.get('query'): - raise ValueError('query is required') + if not args.get("query"): + raise ValueError("query is required") - query = args['query'] + query = args["query"] if not isinstance(query, str): - raise ValueError('query must be a string') + raise ValueError("query must be a string") - query = query.replace('\x00', '') - inputs = args['inputs'] + query = query.replace("\x00", "") + inputs = args["inputs"] - extras = { - "auto_generate_conversation_name": args.get('auto_generate_name', True) - } + extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)} # get conversation conversation = None - if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + if args.get("conversation_id"): + conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user) # get app model config - app_model_config = self._get_app_model_config( - app_model=app_model, - conversation=conversation - ) + app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) # validate override model config override_model_config_dict = None - if args.get('model_config'): + if args.get("model_config"): if invoke_from != InvokeFrom.DEBUGGER: - raise ValueError('Only in App debug mode can override model config') + raise ValueError("Only in App debug mode can override model config") # validate config override_model_config_dict = ChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=args.get('model_config') + tenant_id=app_model.tenant_id, config=args.get("model_config") ) # always enable retriever resource in debugger mode - override_model_config_dict["retriever_resource"] = { - "enabled": True - } + override_model_config_dict["retriever_resource"] = {"enabled": True} # parse files - files = args['files'] if args.get('files') else [] + files = args["files"] if args.get("files") else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) else: file_objs = [] @@ -122,7 +113,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): app_model=app_model, app_model_config=app_model_config, conversation=conversation, - override_config_dict=override_model_config_dict + override_config_dict=override_model_config_dict, ) # get tracing instance @@ -141,14 +132,11 @@ class ChatAppGenerator(MessageBasedAppGenerator): stream=stream, invoke_from=invoke_from, extras=extras, - trace_manager=trace_manager + trace_manager=trace_manager, ) # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity, conversation) + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -157,17 +145,20 @@ class ChatAppGenerator(MessageBasedAppGenerator): invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'conversation_id': conversation.id, - 'message_id': message.id, - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + }, + ) worker_thread.start() @@ -181,16 +172,16 @@ class ChatAppGenerator(MessageBasedAppGenerator): stream=stream, ) - return ChatAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) - def _generate_worker(self, flask_app: Flask, - application_generate_entity: ChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation_id: str, - message_id: str) -> None: + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: ChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str, + ) -> None: """ Generate worker in a new thread. :param flask_app: Flask app @@ -212,20 +203,19 @@ class ChatAppGenerator(MessageBasedAppGenerator): application_generate_entity=application_generate_entity, queue_manager=queue_manager, conversation=conversation, - message=message + message=message, ) except GenerateTaskStoppedException: pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true": logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 89a498eb36..bd90586825 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -24,10 +24,13 @@ class ChatAppRunner(AppRunner): Chat Application Runner """ - def run(self, application_generate_entity: ChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message) -> None: + def run( + self, + application_generate_entity: ChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + ) -> None: """ Run application :param application_generate_entity: application generate entity @@ -58,7 +61,7 @@ class ChatAppRunner(AppRunner): prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, - query=query + query=query, ) memory = None @@ -66,13 +69,10 @@ class ChatAppRunner(AppRunner): # get memory of conversation (read-only) model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) # organize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) @@ -84,7 +84,7 @@ class ChatAppRunner(AppRunner): inputs=inputs, files=files, query=query, - memory=memory + memory=memory, ) # moderation @@ -96,7 +96,7 @@ class ChatAppRunner(AppRunner): app_generate_entity=application_generate_entity, inputs=inputs, query=query, - message_id=message.id + message_id=message.id, ) except ModerationException as e: self.direct_output( @@ -104,7 +104,7 @@ class ChatAppRunner(AppRunner): app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -115,13 +115,13 @@ class ChatAppRunner(AppRunner): message=message, query=query, user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from + invoke_from=application_generate_entity.invoke_from, ) if annotation_reply: queue_manager.publish( QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), - PublishFrom.APPLICATION_MANAGER + PublishFrom.APPLICATION_MANAGER, ) self.direct_output( @@ -129,7 +129,7 @@ class ChatAppRunner(AppRunner): app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=annotation_reply.content, - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -141,7 +141,7 @@ class ChatAppRunner(AppRunner): app_id=app_record.id, external_data_tools=external_data_tools, inputs=inputs, - query=query + query=query, ) # get context from datasets @@ -152,7 +152,7 @@ class ChatAppRunner(AppRunner): app_record.id, message.id, application_generate_entity.user_id, - application_generate_entity.invoke_from + application_generate_entity.invoke_from, ) dataset_retrieval = DatasetRetrieval(application_generate_entity) @@ -181,29 +181,26 @@ class ChatAppRunner(AppRunner): files=files, query=query, context=context, - memory=memory + memory=memory, ) # check hosting moderation hosting_moderation_result = self.check_hosting_moderation( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - prompt_messages=prompt_messages + prompt_messages=prompt_messages, ) if hosting_moderation_result: return # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit - self.recalc_llm_max_tokens( - model_config=application_generate_entity.model_conf, - prompt_messages=prompt_messages - ) + self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages) # Invoke model model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) db.session.close() @@ -218,7 +215,5 @@ class ChatAppRunner(AppRunner): # handle invoke result self._handle_invoke_result( - invoke_result=invoke_result, - queue_manager=queue_manager, - stream=application_generate_entity.stream + invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream ) diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index 625e14c9c3..0fa7af0a7f 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -23,15 +23,15 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): :return: """ response = { - 'event': 'message', - 'task_id': blocking_response.task_id, - 'id': blocking_response.data.id, - 'message_id': blocking_response.data.message_id, - 'conversation_id': blocking_response.data.conversation_id, - 'mode': blocking_response.data.mode, - 'answer': blocking_response.data.answer, - 'metadata': blocking_response.data.metadata, - 'created_at': blocking_response.data.created_at + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "conversation_id": blocking_response.data.conversation_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, } return response @@ -45,14 +45,15 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): """ response = cls.convert_blocking_full_response(blocking_response) - metadata = response.get('metadata', {}) - response['metadata'] = cls._get_simple_metadata(metadata) + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) return response @classmethod - def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_full_response( + cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -63,14 +64,14 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -81,8 +82,9 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream simple response. :param stream_response: stream response @@ -93,20 +95,20 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, MessageEndStreamResponse): sub_stream_response_dict = sub_stream_response.to_dict() - metadata = sub_stream_response_dict.get('metadata', {}) - sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index a771198324..1193c4b7a4 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -17,14 +17,15 @@ class CompletionAppConfig(EasyUIBasedAppConfig): """ Completion App Config Entity. """ + pass class CompletionAppConfigManager(BaseAppConfigManager): @classmethod - def get_app_config(cls, app_model: App, - app_model_config: AppModelConfig, - override_config_dict: Optional[dict] = None) -> CompletionAppConfig: + def get_app_config( + cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: Optional[dict] = None + ) -> CompletionAppConfig: """ Convert app model config to completion app config :param app_model: app model @@ -51,19 +52,11 @@ class CompletionAppConfigManager(BaseAppConfigManager): app_model_config_from=config_from, app_model_config_id=app_model_config.id, app_model_config_dict=config_dict, - model=ModelConfigManager.convert( - config=config_dict - ), - prompt_template=PromptTemplateConfigManager.convert( - config=config_dict - ), - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=config_dict - ), - dataset=DatasetConfigManager.convert( - config=config_dict - ), - additional_features=cls.convert_features(config_dict, app_mode) + model=ModelConfigManager.convert(config=config_dict), + prompt_template=PromptTemplateConfigManager.convert(config=config_dict), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), + dataset=DatasetConfigManager.convert(config=config_dict), + additional_features=cls.convert_features(config_dict, app_mode), ) app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( @@ -101,8 +94,9 @@ class CompletionAppConfigManager(BaseAppConfigManager): related_config_keys.extend(current_related_config_keys) # dataset_query_variable - config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, - config) + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults( + tenant_id, app_mode, config + ) related_config_keys.extend(current_related_config_keys) # text_to_speech @@ -114,8 +108,9 @@ class CompletionAppConfigManager(BaseAppConfigManager): related_config_keys.extend(current_related_config_keys) # moderation validation - config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, - config) + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id, config + ) related_config_keys.extend(current_related_config_keys) related_config_keys = list(set(related_config_keys)) diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index c0b13b40fd..d7301224e8 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -32,7 +32,8 @@ logger = logging.getLogger(__name__) class CompletionAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, @@ -41,19 +42,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, stream: Literal[False] = False, ) -> dict: ... - def generate(self, app_model: App, - user: Union[Account, EndUser], - args: Any, - invoke_from: InvokeFrom, - stream: bool = True) \ - -> Union[dict, Generator[str, None, None]]: + def generate( + self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True + ) -> Union[dict, Generator[str, None, None]]: """ Generate App response. @@ -63,12 +62,12 @@ class CompletionAppGenerator(MessageBasedAppGenerator): :param invoke_from: invoke from source :param stream: is stream """ - query = args['query'] + query = args["query"] if not isinstance(query, str): - raise ValueError('query must be a string') + raise ValueError("query must be a string") - query = query.replace('\x00', '') - inputs = args['inputs'] + query = query.replace("\x00", "") + inputs = args["inputs"] extras = {} @@ -76,41 +75,31 @@ class CompletionAppGenerator(MessageBasedAppGenerator): conversation = None # get app model config - app_model_config = self._get_app_model_config( - app_model=app_model, - conversation=conversation - ) + app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) # validate override model config override_model_config_dict = None - if args.get('model_config'): + if args.get("model_config"): if invoke_from != InvokeFrom.DEBUGGER: - raise ValueError('Only in App debug mode can override model config') + raise ValueError("Only in App debug mode can override model config") # validate config override_model_config_dict = CompletionAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=args.get('model_config') + tenant_id=app_model.tenant_id, config=args.get("model_config") ) # parse files - files = args['files'] if args.get('files') else [] + files = args["files"] if args.get("files") else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) else: file_objs = [] # convert to app config app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config, - override_config_dict=override_model_config_dict + app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict ) # get tracing instance @@ -128,14 +117,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator): stream=stream, invoke_from=invoke_from, extras=extras, - trace_manager=trace_manager + trace_manager=trace_manager, ) # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity) + (conversation, message) = self._init_generate_records(application_generate_entity) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -144,16 +130,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator): invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'message_id': message.id, - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "message_id": message.id, + }, + ) worker_thread.start() @@ -167,15 +156,15 @@ class CompletionAppGenerator(MessageBasedAppGenerator): stream=stream, ) - return CompletionAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) - def _generate_worker(self, flask_app: Flask, - application_generate_entity: CompletionAppGenerateEntity, - queue_manager: AppQueueManager, - message_id: str) -> None: + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: CompletionAppGenerateEntity, + queue_manager: AppQueueManager, + message_id: str, + ) -> None: """ Generate worker in a new thread. :param flask_app: Flask app @@ -194,20 +183,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator): runner.run( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - message=message + message=message, ) except GenerateTaskStoppedException: pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true": logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: @@ -216,12 +204,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator): finally: db.session.close() - def generate_more_like_this(self, app_model: App, - message_id: str, - user: Union[Account, EndUser], - invoke_from: InvokeFrom, - stream: bool = True) \ - -> Union[dict, Generator[str, None, None]]: + def generate_more_like_this( + self, + app_model: App, + message_id: str, + user: Union[Account, EndUser], + invoke_from: InvokeFrom, + stream: bool = True, + ) -> Union[dict, Generator[str, None, None]]: """ Generate App response. @@ -231,13 +221,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator): :param invoke_from: invoke from source :param stream: is stream """ - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id, - Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), - Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Message.from_account_id == (user.id if isinstance(user, Account) else None), - ).first() + message = ( + db.session.query(Message) + .filter( + Message.id == message_id, + Message.app_id == app_model.id, + Message.from_source == ("api" if isinstance(user, EndUser) else "console"), + Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Message.from_account_id == (user.id if isinstance(user, Account) else None), + ) + .first() + ) if not message: raise MessageNotExistsError() @@ -250,29 +244,23 @@ class CompletionAppGenerator(MessageBasedAppGenerator): app_model_config = message.app_model_config override_model_config_dict = app_model_config.to_dict() - model_dict = override_model_config_dict['model'] - completion_params = model_dict.get('completion_params') - completion_params['temperature'] = 0.9 - model_dict['completion_params'] = completion_params - override_model_config_dict['model'] = model_dict + model_dict = override_model_config_dict["model"] + completion_params = model_dict.get("completion_params") + completion_params["temperature"] = 0.9 + model_dict["completion_params"] = completion_params + override_model_config_dict["model"] = model_dict # parse files message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - message.files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user) else: file_objs = [] # convert to app config app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config, - override_config_dict=override_model_config_dict + app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict ) # init application generate entity @@ -286,14 +274,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator): user_id=user.id, stream=stream, invoke_from=invoke_from, - extras={} + extras={}, ) # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity) + (conversation, message) = self._init_generate_records(application_generate_entity) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -302,16 +287,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator): invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'message_id': message.id, - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "message_id": message.id, + }, + ) worker_thread.start() @@ -325,7 +313,4 @@ class CompletionAppGenerator(MessageBasedAppGenerator): stream=stream, ) - return CompletionAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index f0e5f9ae17..da49c8701f 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -22,9 +22,9 @@ class CompletionAppRunner(AppRunner): Completion Application Runner """ - def run(self, application_generate_entity: CompletionAppGenerateEntity, - queue_manager: AppQueueManager, - message: Message) -> None: + def run( + self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message + ) -> None: """ Run application :param application_generate_entity: application generate entity @@ -54,7 +54,7 @@ class CompletionAppRunner(AppRunner): prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, - query=query + query=query, ) # organize all inputs and template to prompt messages @@ -65,7 +65,7 @@ class CompletionAppRunner(AppRunner): prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, - query=query + query=query, ) # moderation @@ -77,7 +77,7 @@ class CompletionAppRunner(AppRunner): app_generate_entity=application_generate_entity, inputs=inputs, query=query, - message_id=message.id + message_id=message.id, ) except ModerationException as e: self.direct_output( @@ -85,7 +85,7 @@ class CompletionAppRunner(AppRunner): app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -97,7 +97,7 @@ class CompletionAppRunner(AppRunner): app_id=app_record.id, external_data_tools=external_data_tools, inputs=inputs, - query=query + query=query, ) # get context from datasets @@ -108,7 +108,7 @@ class CompletionAppRunner(AppRunner): app_record.id, message.id, application_generate_entity.user_id, - application_generate_entity.invoke_from + application_generate_entity.invoke_from, ) dataset_config = app_config.dataset @@ -126,7 +126,7 @@ class CompletionAppRunner(AppRunner): invoke_from=application_generate_entity.invoke_from, show_retrieve_source=app_config.additional_features.show_retrieve_source, hit_callback=hit_callback, - message_id=message.id + message_id=message.id, ) # reorganize all inputs and template to prompt messages @@ -139,29 +139,26 @@ class CompletionAppRunner(AppRunner): inputs=inputs, files=files, query=query, - context=context + context=context, ) # check hosting moderation hosting_moderation_result = self.check_hosting_moderation( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - prompt_messages=prompt_messages + prompt_messages=prompt_messages, ) if hosting_moderation_result: return # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit - self.recalc_llm_max_tokens( - model_config=application_generate_entity.model_conf, - prompt_messages=prompt_messages - ) + self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages) # Invoke model model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) db.session.close() @@ -176,8 +173,5 @@ class CompletionAppRunner(AppRunner): # handle invoke result self._handle_invoke_result( - invoke_result=invoke_result, - queue_manager=queue_manager, - stream=application_generate_entity.stream + invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream ) - \ No newline at end of file diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index 14db74dbd0..697f0273a5 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -23,14 +23,14 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): :return: """ response = { - 'event': 'message', - 'task_id': blocking_response.task_id, - 'id': blocking_response.data.id, - 'message_id': blocking_response.data.message_id, - 'mode': blocking_response.data.mode, - 'answer': blocking_response.data.answer, - 'metadata': blocking_response.data.metadata, - 'created_at': blocking_response.data.created_at + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, } return response @@ -44,14 +44,15 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): """ response = cls.convert_blocking_full_response(blocking_response) - metadata = response.get('metadata', {}) - response['metadata'] = cls._get_simple_metadata(metadata) + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) return response @classmethod - def convert_stream_full_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_full_response( + cls, stream_response: Generator[CompletionAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -62,13 +63,13 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -79,8 +80,9 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[CompletionAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream simple response. :param stream_response: stream response @@ -91,19 +93,19 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, MessageEndStreamResponse): sub_stream_response_dict = sub_stream_response.to_dict() - metadata = sub_stream_response_dict.get('metadata', {}) - sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index fceed95b91..a91d48d246 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -35,23 +35,23 @@ logger = logging.getLogger(__name__) class MessageBasedAppGenerator(BaseAppGenerator): - def _handle_response( - self, application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity - ], - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool = False, + self, + application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity, + ], + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool = False, ) -> Union[ ChatbotAppBlockingResponse, CompletionAppBlockingResponse, - Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None] + Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None], ]: """ Handle response. @@ -70,7 +70,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): conversation=conversation, message=message, user=user, - stream=stream + stream=stream, ) try: @@ -82,12 +82,13 @@ class MessageBasedAppGenerator(BaseAppGenerator): logger.exception(e) raise e - def _get_conversation_by_user(self, app_model: App, conversation_id: str, - user: Union[Account, EndUser]) -> Conversation: + def _get_conversation_by_user( + self, app_model: App, conversation_id: str, user: Union[Account, EndUser] + ) -> Conversation: conversation_filter = [ Conversation.id == conversation_id, Conversation.app_id == app_model.id, - Conversation.status == 'normal' + Conversation.status == "normal", ] if isinstance(user, Account): @@ -100,19 +101,18 @@ class MessageBasedAppGenerator(BaseAppGenerator): if not conversation: raise ConversationNotExistsError() - if conversation.status != 'normal': + if conversation.status != "normal": raise ConversationCompletedError() return conversation - def _get_app_model_config(self, app_model: App, - conversation: Optional[Conversation] = None) \ - -> AppModelConfig: + def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig: if conversation: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation.app_model_config_id, - AppModelConfig.app_id == app_model.id - ).first() + app_model_config = ( + db.session.query(AppModelConfig) + .filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) + .first() + ) if not app_model_config: raise AppModelConfigBrokenError() @@ -127,15 +127,16 @@ class MessageBasedAppGenerator(BaseAppGenerator): return app_model_config - def _init_generate_records(self, - application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity - ], - conversation: Optional[Conversation] = None) \ - -> tuple[Conversation, Message]: + def _init_generate_records( + self, + application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity, + ], + conversation: Optional[Conversation] = None, + ) -> tuple[Conversation, Message]: """ Initialize generate records :param application_generate_entity: application generate entity @@ -148,10 +149,10 @@ class MessageBasedAppGenerator(BaseAppGenerator): end_user_id = None account_id = None if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - from_source = 'api' + from_source = "api" end_user_id = application_generate_entity.user_id else: - from_source = 'console' + from_source = "console" account_id = application_generate_entity.user_id if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity): @@ -164,8 +165,11 @@ class MessageBasedAppGenerator(BaseAppGenerator): model_provider = application_generate_entity.model_conf.provider model_id = application_generate_entity.model_conf.model override_model_configs = None - if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \ - and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]: + if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in [ + AppMode.AGENT_CHAT, + AppMode.CHAT, + AppMode.COMPLETION, + ]: override_model_configs = app_config.app_model_config_dict # get conversation introduction @@ -179,12 +183,12 @@ class MessageBasedAppGenerator(BaseAppGenerator): model_id=model_id, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, mode=app_config.app_mode.value, - name='New conversation', + name="New conversation", inputs=application_generate_entity.inputs, introduction=introduction, system_instruction="", system_instruction_tokens=0, - status='normal', + status="normal", invoke_from=application_generate_entity.invoke_from.value, from_source=from_source, from_end_user_id=end_user_id, @@ -216,11 +220,11 @@ class MessageBasedAppGenerator(BaseAppGenerator): answer_price_unit=0, provider_response_latency=0, total_price=0, - currency='USD', + currency="USD", invoke_from=application_generate_entity.invoke_from.value, from_source=from_source, from_end_user_id=end_user_id, - from_account_id=account_id + from_account_id=account_id, ) db.session.add(message) @@ -232,10 +236,10 @@ class MessageBasedAppGenerator(BaseAppGenerator): message_id=message.id, type=file.type.value, transfer_method=file.transfer_method.value, - belongs_to='user', + belongs_to="user", url=file.url, upload_file_id=file.related_id, - created_by_role=('account' if account_id else 'end_user'), + created_by_role=("account" if account_id else "end_user"), created_by=account_id or end_user_id, ) db.session.add(message_file) @@ -269,11 +273,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): :param conversation_id: conversation id :return: conversation """ - conversation = ( - db.session.query(Conversation) - .filter(Conversation.id == conversation_id) - .first() - ) + conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() if not conversation: raise ConversationNotExistsError() @@ -286,10 +286,6 @@ class MessageBasedAppGenerator(BaseAppGenerator): :param message_id: message id :return: message """ - message = ( - db.session.query(Message) - .filter(Message.id == message_id) - .first() - ) + message = db.session.query(Message).filter(Message.id == message_id).first() return message diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index f4ff44ddda..7f259db6eb 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -12,12 +12,9 @@ from core.app.entities.queue_entities import ( class MessageBasedAppQueueManager(AppQueueManager): - def __init__(self, task_id: str, - user_id: str, - invoke_from: InvokeFrom, - conversation_id: str, - app_mode: str, - message_id: str) -> None: + def __init__( + self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str + ) -> None: super().__init__(task_id, user_id, invoke_from) self._conversation_id = str(conversation_id) @@ -30,7 +27,7 @@ class MessageBasedAppQueueManager(AppQueueManager): message_id=self._message_id, conversation_id=self._conversation_id, app_mode=self._app_mode, - event=event + event=event, ) def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: @@ -45,17 +42,15 @@ class MessageBasedAppQueueManager(AppQueueManager): message_id=self._message_id, conversation_id=self._conversation_id, app_mode=self._app_mode, - event=event + event=event, ) self._q.put(message) - if isinstance(event, QueueStopEvent - | QueueErrorEvent - | QueueMessageEndEvent - | QueueAdvancedChatMessageEndEvent): + if isinstance( + event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent | QueueAdvancedChatMessageEndEvent + ): self.stop_listen() if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): raise GenerateTaskStoppedException() - diff --git a/api/core/app/apps/workflow/app_config_manager.py b/api/core/app/apps/workflow/app_config_manager.py index 36d3696d60..8b98e74b85 100644 --- a/api/core/app/apps/workflow/app_config_manager.py +++ b/api/core/app/apps/workflow/app_config_manager.py @@ -12,6 +12,7 @@ class WorkflowAppConfig(WorkflowUIBasedAppConfig): """ Workflow App Config Entity. """ + pass @@ -26,13 +27,9 @@ class WorkflowAppConfigManager(BaseAppConfigManager): app_id=app_model.id, app_mode=app_mode, workflow_id=workflow.id, - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=features_dict - ), - variables=WorkflowVariablesConfigManager.convert( - workflow=workflow - ), - additional_features=cls.convert_features(features_dict, app_mode) + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict), + variables=WorkflowVariablesConfigManager.convert(workflow=workflow), + additional_features=cls.convert_features(features_dict, app_mode), ) return app_config @@ -50,8 +47,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager): # file upload validation config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( - config=config, - is_vision=False + config=config, is_vision=False ) related_config_keys.extend(current_related_config_keys) @@ -61,9 +57,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager): # moderation validation config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( - tenant_id=tenant_id, - config=config, - only_structure_validate=only_structure_validate + tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate ) related_config_keys.extend(current_related_config_keys) diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 4347e5277b..c685008577 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -34,26 +34,28 @@ logger = logging.getLogger(__name__) class WorkflowAppGenerator(BaseAppGenerator): @overload def generate( - self, app_model: App, + self, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, stream: Literal[True] = True, call_depth: int = 0, - workflow_thread_pool_id: Optional[str] = None + workflow_thread_pool_id: Optional[str] = None, ) -> Generator[str, None, None]: ... @overload def generate( - self, app_model: App, + self, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, stream: Literal[False] = False, call_depth: int = 0, - workflow_thread_pool_id: Optional[str] = None + workflow_thread_pool_id: Optional[str] = None, ) -> dict: ... def generate( @@ -65,7 +67,7 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, stream: bool = True, call_depth: int = 0, - workflow_thread_pool_id: Optional[str] = None + workflow_thread_pool_id: Optional[str] = None, ): """ Generate App response. @@ -79,26 +81,19 @@ class WorkflowAppGenerator(BaseAppGenerator): :param call_depth: call depth :param workflow_thread_pool_id: workflow thread pool id """ - inputs = args['inputs'] + inputs = args["inputs"] # parse files - files = args['files'] if args.get('files') else [] + files = args["files"] if args.get("files") else [] message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) + file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user) else: file_objs = [] # convert to app config - app_config = WorkflowAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) # get tracing instance user_id = user.id if isinstance(user, Account) else user.session_id @@ -114,7 +109,7 @@ class WorkflowAppGenerator(BaseAppGenerator): stream=stream, invoke_from=invoke_from, call_depth=call_depth, - trace_manager=trace_manager + trace_manager=trace_manager, ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -125,18 +120,19 @@ class WorkflowAppGenerator(BaseAppGenerator): application_generate_entity=application_generate_entity, invoke_from=invoke_from, stream=stream, - workflow_thread_pool_id=workflow_thread_pool_id + workflow_thread_pool_id=workflow_thread_pool_id, ) def _generate( - self, *, + self, + *, app_model: App, workflow: Workflow, user: Union[Account, EndUser], application_generate_entity: WorkflowAppGenerateEntity, invoke_from: InvokeFrom, stream: bool = True, - workflow_thread_pool_id: Optional[str] = None + workflow_thread_pool_id: Optional[str] = None, ) -> dict[str, Any] | Generator[str, None, None]: """ Generate App response. @@ -154,17 +150,20 @@ class WorkflowAppGenerator(BaseAppGenerator): task_id=application_generate_entity.task_id, user_id=application_generate_entity.user_id, invoke_from=application_generate_entity.invoke_from, - app_mode=app_model.mode + app_mode=app_model.mode, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), # type: ignore - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'context': contextvars.copy_context(), - 'workflow_thread_pool_id': workflow_thread_pool_id - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "context": contextvars.copy_context(), + "workflow_thread_pool_id": workflow_thread_pool_id, + }, + ) worker_thread.start() @@ -177,17 +176,11 @@ class WorkflowAppGenerator(BaseAppGenerator): stream=stream, ) - return WorkflowAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) - def single_iteration_generate(self, app_model: App, - workflow: Workflow, - node_id: str, - user: Account, - args: dict, - stream: bool = True) -> dict[str, Any] | Generator[str, Any, None]: + def single_iteration_generate( + self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True + ) -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. @@ -199,16 +192,13 @@ class WorkflowAppGenerator(BaseAppGenerator): :param stream: is stream """ if not node_id: - raise ValueError('node_id is required') + raise ValueError("node_id is required") - if args.get('inputs') is None: - raise ValueError('inputs is required') + if args.get("inputs") is None: + raise ValueError("inputs is required") # convert to app config - app_config = WorkflowAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) # init application generate entity application_generate_entity = WorkflowAppGenerateEntity( @@ -219,13 +209,10 @@ class WorkflowAppGenerator(BaseAppGenerator): user_id=user.id, stream=stream, invoke_from=InvokeFrom.DEBUGGER, - extras={ - "auto_generate_conversation_name": False - }, + extras={"auto_generate_conversation_name": False}, single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( - node_id=node_id, - inputs=args['inputs'] - ) + node_id=node_id, inputs=args["inputs"] + ), ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -235,14 +222,17 @@ class WorkflowAppGenerator(BaseAppGenerator): user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, - stream=stream + stream=stream, ) - def _generate_worker(self, flask_app: Flask, - application_generate_entity: WorkflowAppGenerateEntity, - queue_manager: AppQueueManager, - context: contextvars.Context, - workflow_thread_pool_id: Optional[str] = None) -> None: + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager, + context: contextvars.Context, + workflow_thread_pool_id: Optional[str] = None, + ) -> None: """ Generate worker in a new thread. :param flask_app: Flask app @@ -259,7 +249,7 @@ class WorkflowAppGenerator(BaseAppGenerator): runner = WorkflowAppRunner( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - workflow_thread_pool_id=workflow_thread_pool_id + workflow_thread_pool_id=workflow_thread_pool_id, ) runner.run() @@ -267,14 +257,13 @@ class WorkflowAppGenerator(BaseAppGenerator): pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true': + if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == "true": logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: @@ -283,14 +272,14 @@ class WorkflowAppGenerator(BaseAppGenerator): finally: db.session.close() - def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity, - workflow: Workflow, - queue_manager: AppQueueManager, - user: Union[Account, EndUser], - stream: bool = False) -> Union[ - WorkflowAppBlockingResponse, - Generator[WorkflowAppStreamResponse, None, None] - ]: + def _handle_response( + self, + application_generate_entity: WorkflowAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + stream: bool = False, + ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ Handle response. :param application_generate_entity: application generate entity @@ -306,7 +295,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow=workflow, queue_manager=queue_manager, user=user, - stream=stream + stream=stream, ) try: diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py index f448138b53..c9f501cd5e 100644 --- a/api/core/app/apps/workflow/app_queue_manager.py +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -12,10 +12,7 @@ from core.app.entities.queue_entities import ( class WorkflowAppQueueManager(AppQueueManager): - def __init__(self, task_id: str, - user_id: str, - invoke_from: InvokeFrom, - app_mode: str) -> None: + def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None: super().__init__(task_id, user_id, invoke_from) self._app_mode = app_mode @@ -27,19 +24,18 @@ class WorkflowAppQueueManager(AppQueueManager): :param pub_from: :return: """ - message = WorkflowQueueMessage( - task_id=self._task_id, - app_mode=self._app_mode, - event=event - ) + message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event) self._q.put(message) - if isinstance(event, QueueStopEvent - | QueueErrorEvent - | QueueMessageEndEvent - | QueueWorkflowSucceededEvent - | QueueWorkflowFailedEvent): + if isinstance( + event, + QueueStopEvent + | QueueErrorEvent + | QueueMessageEndEvent + | QueueWorkflowSucceededEvent + | QueueWorkflowFailedEvent, + ): self.stop_listen() if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 9d48db7546..81c8463dd5 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -28,10 +28,10 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): """ def __init__( - self, - application_generate_entity: WorkflowAppGenerateEntity, - queue_manager: AppQueueManager, - workflow_thread_pool_id: Optional[str] = None + self, + application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager, + workflow_thread_pool_id: Optional[str] = None, ) -> None: """ :param application_generate_entity: application generate entity @@ -62,16 +62,16 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: - raise ValueError('App not found') + raise ValueError("App not found") workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) if not workflow: - raise ValueError('Workflow not initialized') + raise ValueError("Workflow not initialized") db.session.close() workflow_callbacks: list[WorkflowCallback] = [] - if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): + if bool(os.environ.get("DEBUG", "False").lower() == "true"): workflow_callbacks.append(WorkflowLoggingCallback()) # if only single iteration run is requested @@ -80,10 +80,9 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( workflow=workflow, node_id=self.application_generate_entity.single_iteration_run.node_id, - user_inputs=self.application_generate_entity.single_iteration_run.inputs + user_inputs=self.application_generate_entity.single_iteration_run.inputs, ) else: - inputs = self.application_generate_entity.inputs files = self.application_generate_entity.files @@ -120,12 +119,10 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): invoke_from=self.application_generate_entity.invoke_from, call_depth=self.application_generate_entity.call_depth, variable_pool=variable_pool, - thread_pool_id=self.workflow_thread_pool_id + thread_pool_id=self.workflow_thread_pool_id, ) - generator = workflow_entry.run( - callbacks=workflow_callbacks - ) + generator = workflow_entry.run(callbacks=workflow_callbacks) for event in generator: self._handle_event(workflow_entry, event) diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index 88bde58ba0..08d00ee180 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -35,8 +35,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): return cls.convert_blocking_full_response(blocking_response) @classmethod - def convert_stream_full_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_full_response( + cls, stream_response: Generator[WorkflowAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -47,12 +48,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'workflow_run_id': chunk.workflow_run_id, + "event": sub_stream_response.event.value, + "workflow_run_id": chunk.workflow_run_id, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -63,8 +64,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[WorkflowAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream simple response. :param stream_response: stream response @@ -75,12 +77,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'workflow_run_id': chunk.workflow_run_id, + "event": sub_stream_response.event.value, + "workflow_run_id": chunk.workflow_run_id, } if isinstance(sub_stream_response, ErrorStreamResponse): diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 00b3b9f57e..904b649381 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -63,17 +63,21 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa """ WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ + _workflow: Workflow _user: Union[Account, EndUser] _task_state: WorkflowTaskState _application_generate_entity: WorkflowAppGenerateEntity _workflow_system_variables: dict[SystemVariableKey, Any] - def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, - workflow: Workflow, - queue_manager: AppQueueManager, - user: Union[Account, EndUser], - stream: bool) -> None: + def __init__( + self, + application_generate_entity: WorkflowAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + stream: bool, + ) -> None: """ Initialize GenerateTaskPipeline. :param application_generate_entity: application generate entity @@ -92,7 +96,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa self._workflow = workflow self._workflow_system_variables = { SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.USER_ID: user_id + SystemVariableKey.USER_ID: user_id, } self._task_state = WorkflowTaskState() @@ -106,16 +110,13 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa db.session.refresh(self._user) db.session.close() - generator = self._wrapper_process_stream_response( - trace_manager=self._application_generate_entity.trace_manager - ) + generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) if self._stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) - def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \ - -> WorkflowAppBlockingResponse: + def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse: """ To blocking response. :return: @@ -137,18 +138,19 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa total_tokens=stream_response.data.total_tokens, total_steps=stream_response.data.total_steps, created_at=int(stream_response.data.created_at), - finished_at=int(stream_response.data.finished_at) - ) + finished_at=int(stream_response.data.finished_at), + ), ) return response else: continue - raise Exception('Queue listening stopped unexpectedly.') + raise Exception("Queue listening stopped unexpectedly.") - def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \ - -> Generator[WorkflowAppStreamResponse, None, None]: + def _to_stream_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Generator[WorkflowAppStreamResponse, None, None]: """ To stream response. :return: @@ -158,10 +160,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if isinstance(stream_response, WorkflowStartStreamResponse): workflow_run_id = stream_response.workflow_run_id - yield WorkflowAppStreamResponse( - workflow_run_id=workflow_run_id, - stream_response=stream_response - ) + yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response) def _listenAudioMsg(self, publisher, task_id: str): if not publisher: @@ -171,17 +170,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None - def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ - Generator[StreamResponse, None, None]: - + def _wrapper_process_stream_response( + self, trace_manager: Optional[TraceQueueManager] = None + ) -> Generator[StreamResponse, None, None]: tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id features_dict = self._workflow.features_dict - if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ - 'text_to_speech'].get('autoPlay') == 'enabled': - tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) + if ( + features_dict.get("text_to_speech") + and features_dict["text_to_speech"].get("enabled") + and features_dict["text_to_speech"].get("autoPlay") == "enabled" + ): + tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice")) for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: @@ -210,13 +212,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa except Exception as e: logger.error(e) break - yield MessageAudioEndStreamResponse(audio='', task_id=task_id) - + yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( self, tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> Generator[StreamResponse, None, None]: """ Process stream response. @@ -241,22 +242,18 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa # init workflow run workflow_run = self._handle_workflow_run_start() yield self._workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, QueueNodeStartedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") - workflow_node_execution = self._handle_node_execution_start( - workflow_run=workflow_run, - event=event - ) + workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) response = self._workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) if response: @@ -267,7 +264,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa response = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) if response: @@ -278,69 +275,61 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa response = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) if response: yield response elif isinstance(event, QueueParallelBranchRunStartedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_parallel_branch_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_parallel_branch_finished_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationStartEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_iteration_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationNextEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_iteration_next_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationCompletedEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") yield self._workflow_iteration_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueWorkflowSucceededEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") if not graph_runtime_state: - raise Exception('Graph runtime state not initialized.') + raise Exception("Graph runtime state not initialized.") workflow_run = self._handle_workflow_run_success( workflow_run=workflow_run, start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, - outputs=json.dumps(event.outputs) if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs else None, + outputs=json.dumps(event.outputs) + if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs + else None, conversation_id=None, trace_manager=trace_manager, ) @@ -349,22 +338,23 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa self._save_workflow_app_log(workflow_run) yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): if not workflow_run: - raise Exception('Workflow run not initialized.') + raise Exception("Workflow run not initialized.") if not graph_runtime_state: - raise Exception('Graph runtime state not initialized.') + raise Exception("Graph runtime state not initialized.") workflow_run = self._handle_workflow_run_failed( workflow_run=workflow_run, start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED, + status=WorkflowRunStatus.FAILED + if isinstance(event, QueueWorkflowFailedEvent) + else WorkflowRunStatus.STOPPED, error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), conversation_id=None, trace_manager=trace_manager, @@ -374,8 +364,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa self._save_workflow_app_log(workflow_run) yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, QueueTextChunkEvent): delta_text = event.text @@ -394,7 +383,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if tts_publisher: tts_publisher.publish(None) - def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: """ Save workflow app log. @@ -417,7 +405,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa workflow_app_log.workflow_id = workflow_run.workflow_id workflow_app_log.workflow_run_id = workflow_run.id workflow_app_log.created_from = created_from.value - workflow_app_log.created_by_role = 'account' if isinstance(self._user, Account) else 'end_user' + workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user" workflow_app_log.created_by = self._user.id db.session.add(workflow_app_log) @@ -431,8 +419,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa :return: """ response = TextChunkStreamResponse( - task_id=self._application_generate_entity.task_id, - data=TextChunkStreamResponse.Data(text=text) + task_id=self._application_generate_entity.task_id, data=TextChunkStreamResponse.Data(text=text) ) return response diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 1709726887..ce266116a7 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -58,89 +58,86 @@ class WorkflowBasedAppRunner(AppRunner): """ Init graph """ - if 'nodes' not in graph_config or 'edges' not in graph_config: - raise ValueError('nodes or edges not found in workflow graph') + if "nodes" not in graph_config or "edges" not in graph_config: + raise ValueError("nodes or edges not found in workflow graph") - if not isinstance(graph_config.get('nodes'), list): - raise ValueError('nodes in workflow graph must be a list') + if not isinstance(graph_config.get("nodes"), list): + raise ValueError("nodes in workflow graph must be a list") - if not isinstance(graph_config.get('edges'), list): - raise ValueError('edges in workflow graph must be a list') + if not isinstance(graph_config.get("edges"), list): + raise ValueError("edges in workflow graph must be a list") # init graph - graph = Graph.init( - graph_config=graph_config - ) + graph = Graph.init(graph_config=graph_config) if not graph: - raise ValueError('graph not found in workflow') - + raise ValueError("graph not found in workflow") + return graph def _get_graph_and_variable_pool_of_single_iteration( - self, - workflow: Workflow, - node_id: str, - user_inputs: dict, - ) -> tuple[Graph, VariablePool]: + self, + workflow: Workflow, + node_id: str, + user_inputs: dict, + ) -> tuple[Graph, VariablePool]: """ Get variable pool of single iteration """ # fetch workflow graph graph_config = workflow.graph_dict if not graph_config: - raise ValueError('workflow graph not found') - + raise ValueError("workflow graph not found") + graph_config = cast(dict[str, Any], graph_config) - if 'nodes' not in graph_config or 'edges' not in graph_config: - raise ValueError('nodes or edges not found in workflow graph') + if "nodes" not in graph_config or "edges" not in graph_config: + raise ValueError("nodes or edges not found in workflow graph") - if not isinstance(graph_config.get('nodes'), list): - raise ValueError('nodes in workflow graph must be a list') + if not isinstance(graph_config.get("nodes"), list): + raise ValueError("nodes in workflow graph must be a list") - if not isinstance(graph_config.get('edges'), list): - raise ValueError('edges in workflow graph must be a list') + if not isinstance(graph_config.get("edges"), list): + raise ValueError("edges in workflow graph must be a list") # filter nodes only in iteration node_configs = [ - node for node in graph_config.get('nodes', []) - if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id + node + for node in graph_config.get("nodes", []) + if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id ] - graph_config['nodes'] = node_configs + graph_config["nodes"] = node_configs - node_ids = [node.get('id') for node in node_configs] + node_ids = [node.get("id") for node in node_configs] # filter edges only in iteration edge_configs = [ - edge for edge in graph_config.get('edges', []) - if (edge.get('source') is None or edge.get('source') in node_ids) - and (edge.get('target') is None or edge.get('target') in node_ids) + edge + for edge in graph_config.get("edges", []) + if (edge.get("source") is None or edge.get("source") in node_ids) + and (edge.get("target") is None or edge.get("target") in node_ids) ] - graph_config['edges'] = edge_configs + graph_config["edges"] = edge_configs # init graph - graph = Graph.init( - graph_config=graph_config, - root_node_id=node_id - ) + graph = Graph.init(graph_config=graph_config, root_node_id=node_id) if not graph: - raise ValueError('graph not found in workflow') - + raise ValueError("graph not found in workflow") + # fetch node config from node id iteration_node_config = None for node in node_configs: - if node.get('id') == node_id: + if node.get("id") == node_id: iteration_node_config = node break if not iteration_node_config: - raise ValueError('iteration node id not found in workflow graph') - + raise ValueError("iteration node id not found in workflow graph") + # Get node class - node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type')) + node_type = NodeType.value_of(iteration_node_config.get("data", {}).get("type")) node_cls = node_classes.get(node_type) node_cls = cast(type[BaseNode], node_cls) @@ -153,8 +150,7 @@ class WorkflowBasedAppRunner(AppRunner): try: variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=workflow.graph_dict, - config=iteration_node_config + graph_config=workflow.graph_dict, config=iteration_node_config ) except NotImplementedError: variable_mapping = {} @@ -165,7 +161,7 @@ class WorkflowBasedAppRunner(AppRunner): variable_pool=variable_pool, tenant_id=workflow.tenant_id, node_type=node_type, - node_data=IterationNodeData(**iteration_node_config.get('data', {})) + node_data=IterationNodeData(**iteration_node_config.get("data", {})), ) return graph, variable_pool @@ -178,18 +174,12 @@ class WorkflowBasedAppRunner(AppRunner): """ if isinstance(event, GraphRunStartedEvent): self._publish_event( - QueueWorkflowStartedEvent( - graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state - ) + QueueWorkflowStartedEvent(graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state) ) elif isinstance(event, GraphRunSucceededEvent): - self._publish_event( - QueueWorkflowSucceededEvent(outputs=event.outputs) - ) + self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs)) elif isinstance(event, GraphRunFailedEvent): - self._publish_event( - QueueWorkflowFailedEvent(error=event.error) - ) + self._publish_event(QueueWorkflowFailedEvent(error=event.error)) elif isinstance(event, NodeRunStartedEvent): self._publish_event( QueueNodeStartedEvent( @@ -204,7 +194,7 @@ class WorkflowBasedAppRunner(AppRunner): start_at=event.route_node_state.start_at, node_run_index=event.route_node_state.index, predecessor_node_id=event.predecessor_node_id, - in_iteration_id=event.in_iteration_id + in_iteration_id=event.in_iteration_id, ) ) elif isinstance(event, NodeRunSucceededEvent): @@ -220,14 +210,18 @@ class WorkflowBasedAppRunner(AppRunner): parent_parallel_start_node_id=event.parent_parallel_start_node_id, start_at=event.route_node_state.start_at, inputs=event.route_node_state.node_run_result.inputs - if event.route_node_state.node_run_result else {}, + if event.route_node_state.node_run_result + else {}, process_data=event.route_node_state.node_run_result.process_data - if event.route_node_state.node_run_result else {}, + if event.route_node_state.node_run_result + else {}, outputs=event.route_node_state.node_run_result.outputs - if event.route_node_state.node_run_result else {}, + if event.route_node_state.node_run_result + else {}, execution_metadata=event.route_node_state.node_run_result.metadata - if event.route_node_state.node_run_result else {}, - in_iteration_id=event.in_iteration_id + if event.route_node_state.node_run_result + else {}, + in_iteration_id=event.in_iteration_id, ) ) elif isinstance(event, NodeRunFailedEvent): @@ -243,16 +237,18 @@ class WorkflowBasedAppRunner(AppRunner): parent_parallel_start_node_id=event.parent_parallel_start_node_id, start_at=event.route_node_state.start_at, inputs=event.route_node_state.node_run_result.inputs - if event.route_node_state.node_run_result else {}, - process_data=event.route_node_state.node_run_result.process_data - if event.route_node_state.node_run_result else {}, - outputs=event.route_node_state.node_run_result.outputs - if event.route_node_state.node_run_result else {}, - error=event.route_node_state.node_run_result.error if event.route_node_state.node_run_result - and event.route_node_state.node_run_result.error + else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result + else {}, + outputs=event.route_node_state.node_run_result.outputs + if event.route_node_state.node_run_result + else {}, + error=event.route_node_state.node_run_result.error + if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error else "Unknown error", - in_iteration_id=event.in_iteration_id + in_iteration_id=event.in_iteration_id, ) ) elif isinstance(event, NodeRunStreamChunkEvent): @@ -260,14 +256,13 @@ class WorkflowBasedAppRunner(AppRunner): QueueTextChunkEvent( text=event.chunk_content, from_variable_selector=event.from_variable_selector, - in_iteration_id=event.in_iteration_id + in_iteration_id=event.in_iteration_id, ) ) elif isinstance(event, NodeRunRetrieverResourceEvent): self._publish_event( QueueRetrieverResourcesEvent( - retriever_resources=event.retriever_resources, - in_iteration_id=event.in_iteration_id + retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id ) ) elif isinstance(event, ParallelBranchRunStartedEvent): @@ -277,7 +272,7 @@ class WorkflowBasedAppRunner(AppRunner): parallel_start_node_id=event.parallel_start_node_id, parent_parallel_id=event.parent_parallel_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id, - in_iteration_id=event.in_iteration_id + in_iteration_id=event.in_iteration_id, ) ) elif isinstance(event, ParallelBranchRunSucceededEvent): @@ -287,7 +282,7 @@ class WorkflowBasedAppRunner(AppRunner): parallel_start_node_id=event.parallel_start_node_id, parent_parallel_id=event.parent_parallel_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id, - in_iteration_id=event.in_iteration_id + in_iteration_id=event.in_iteration_id, ) ) elif isinstance(event, ParallelBranchRunFailedEvent): @@ -298,7 +293,7 @@ class WorkflowBasedAppRunner(AppRunner): parent_parallel_id=event.parent_parallel_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id, in_iteration_id=event.in_iteration_id, - error=event.error + error=event.error, ) ) elif isinstance(event, IterationRunStartedEvent): @@ -316,7 +311,7 @@ class WorkflowBasedAppRunner(AppRunner): node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, inputs=event.inputs, predecessor_node_id=event.predecessor_node_id, - metadata=event.metadata + metadata=event.metadata, ) ) elif isinstance(event, IterationRunNextEvent): @@ -352,7 +347,7 @@ class WorkflowBasedAppRunner(AppRunner): outputs=event.outputs, metadata=event.metadata, steps=event.steps, - error=event.error if isinstance(event, IterationRunFailedEvent) else None + error=event.error if isinstance(event, IterationRunFailedEvent) else None, ) ) @@ -371,9 +366,6 @@ class WorkflowBasedAppRunner(AppRunner): # return workflow return workflow - + def _publish_event(self, event: AppQueueEvent) -> None: - self.queue_manager.publish( - event, - PublishFrom.APPLICATION_MANAGER - ) + self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) diff --git a/api/core/app/apps/workflow_logging_callback.py b/api/core/app/apps/workflow_logging_callback.py index 4e8f3644b1..cdd21bf7c2 100644 --- a/api/core/app/apps/workflow_logging_callback.py +++ b/api/core/app/apps/workflow_logging_callback.py @@ -30,169 +30,145 @@ _TEXT_COLOR_MAPPING = { class WorkflowLoggingCallback(WorkflowCallback): - def __init__(self) -> None: self.current_node_id = None - def on_event( - self, - event: GraphEngineEvent - ) -> None: + def on_event(self, event: GraphEngineEvent) -> None: if isinstance(event, GraphRunStartedEvent): - self.print_text("\n[GraphRunStartedEvent]", color='pink') + self.print_text("\n[GraphRunStartedEvent]", color="pink") elif isinstance(event, GraphRunSucceededEvent): - self.print_text("\n[GraphRunSucceededEvent]", color='green') + self.print_text("\n[GraphRunSucceededEvent]", color="green") elif isinstance(event, GraphRunFailedEvent): - self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red') + self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red") elif isinstance(event, NodeRunStartedEvent): - self.on_workflow_node_execute_started( - event=event - ) + self.on_workflow_node_execute_started(event=event) elif isinstance(event, NodeRunSucceededEvent): - self.on_workflow_node_execute_succeeded( - event=event - ) + self.on_workflow_node_execute_succeeded(event=event) elif isinstance(event, NodeRunFailedEvent): - self.on_workflow_node_execute_failed( - event=event - ) + self.on_workflow_node_execute_failed(event=event) elif isinstance(event, NodeRunStreamChunkEvent): - self.on_node_text_chunk( - event=event - ) + self.on_node_text_chunk(event=event) elif isinstance(event, ParallelBranchRunStartedEvent): - self.on_workflow_parallel_started( - event=event - ) + self.on_workflow_parallel_started(event=event) elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent): - self.on_workflow_parallel_completed( - event=event - ) + self.on_workflow_parallel_completed(event=event) elif isinstance(event, IterationRunStartedEvent): - self.on_workflow_iteration_started( - event=event - ) + self.on_workflow_iteration_started(event=event) elif isinstance(event, IterationRunNextEvent): - self.on_workflow_iteration_next( - event=event - ) + self.on_workflow_iteration_next(event=event) elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent): - self.on_workflow_iteration_completed( - event=event - ) + self.on_workflow_iteration_completed(event=event) else: - self.print_text(f"\n[{event.__class__.__name__}]", color='blue') + self.print_text(f"\n[{event.__class__.__name__}]", color="blue") - def on_workflow_node_execute_started( - self, - event: NodeRunStartedEvent - ) -> None: + def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None: """ Workflow node execute started """ - self.print_text("\n[NodeRunStartedEvent]", color='yellow') - self.print_text(f"Node ID: {event.node_id}", color='yellow') - self.print_text(f"Node Title: {event.node_data.title}", color='yellow') - self.print_text(f"Type: {event.node_type.value}", color='yellow') + self.print_text("\n[NodeRunStartedEvent]", color="yellow") + self.print_text(f"Node ID: {event.node_id}", color="yellow") + self.print_text(f"Node Title: {event.node_data.title}", color="yellow") + self.print_text(f"Type: {event.node_type.value}", color="yellow") - def on_workflow_node_execute_succeeded( - self, - event: NodeRunSucceededEvent - ) -> None: + def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None: """ Workflow node execute succeeded """ route_node_state = event.route_node_state - self.print_text("\n[NodeRunSucceededEvent]", color='green') - self.print_text(f"Node ID: {event.node_id}", color='green') - self.print_text(f"Node Title: {event.node_data.title}", color='green') - self.print_text(f"Type: {event.node_type.value}", color='green') + self.print_text("\n[NodeRunSucceededEvent]", color="green") + self.print_text(f"Node ID: {event.node_id}", color="green") + self.print_text(f"Node Title: {event.node_data.title}", color="green") + self.print_text(f"Type: {event.node_type.value}", color="green") if route_node_state.node_run_result: node_run_result = route_node_state.node_run_result - self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", - color='green') + self.print_text( + f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="green" + ) self.print_text( f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", - color='green') - self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", - color='green') + color="green", + ) + self.print_text( + f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", + color="green", + ) self.print_text( f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}", - color='green') + color="green", + ) - def on_workflow_node_execute_failed( - self, - event: NodeRunFailedEvent - ) -> None: + def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None: """ Workflow node execute failed """ route_node_state = event.route_node_state - self.print_text("\n[NodeRunFailedEvent]", color='red') - self.print_text(f"Node ID: {event.node_id}", color='red') - self.print_text(f"Node Title: {event.node_data.title}", color='red') - self.print_text(f"Type: {event.node_type.value}", color='red') + self.print_text("\n[NodeRunFailedEvent]", color="red") + self.print_text(f"Node ID: {event.node_id}", color="red") + self.print_text(f"Node Title: {event.node_data.title}", color="red") + self.print_text(f"Type: {event.node_type.value}", color="red") if route_node_state.node_run_result: node_run_result = route_node_state.node_run_result - self.print_text(f"Error: {node_run_result.error}", color='red') - self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", - color='red') + self.print_text(f"Error: {node_run_result.error}", color="red") + self.print_text( + f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="red" + ) self.print_text( f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", - color='red') - self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", - color='red') + color="red", + ) + self.print_text( + f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", color="red" + ) - def on_node_text_chunk( - self, - event: NodeRunStreamChunkEvent - ) -> None: + def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None: """ Publish text chunk """ route_node_state = event.route_node_state if not self.current_node_id or self.current_node_id != route_node_state.node_id: self.current_node_id = route_node_state.node_id - self.print_text('\n[NodeRunStreamChunkEvent]') + self.print_text("\n[NodeRunStreamChunkEvent]") self.print_text(f"Node ID: {route_node_state.node_id}") node_run_result = route_node_state.node_run_result if node_run_result: self.print_text( - f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}") + f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}" + ) self.print_text(event.chunk_content, color="pink", end="") - def on_workflow_parallel_started( - self, - event: ParallelBranchRunStartedEvent - ) -> None: + def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None: """ Publish parallel started """ - self.print_text("\n[ParallelBranchRunStartedEvent]", color='blue') - self.print_text(f"Parallel ID: {event.parallel_id}", color='blue') - self.print_text(f"Branch ID: {event.parallel_start_node_id}", color='blue') + self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue") + self.print_text(f"Parallel ID: {event.parallel_id}", color="blue") + self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue") if event.in_iteration_id: - self.print_text(f"Iteration ID: {event.in_iteration_id}", color='blue') + self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue") def on_workflow_parallel_completed( - self, - event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent + self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent ) -> None: """ Publish parallel completed """ if isinstance(event, ParallelBranchRunSucceededEvent): - color = 'blue' + color = "blue" elif isinstance(event, ParallelBranchRunFailedEvent): - color = 'red' + color = "red" - self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color) + self.print_text( + "\n[ParallelBranchRunSucceededEvent]" + if isinstance(event, ParallelBranchRunSucceededEvent) + else "\n[ParallelBranchRunFailedEvent]", + color=color, + ) self.print_text(f"Parallel ID: {event.parallel_id}", color=color) self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color) if event.in_iteration_id: @@ -201,43 +177,37 @@ class WorkflowLoggingCallback(WorkflowCallback): if isinstance(event, ParallelBranchRunFailedEvent): self.print_text(f"Error: {event.error}", color=color) - def on_workflow_iteration_started( - self, - event: IterationRunStartedEvent - ) -> None: + def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None: """ Publish iteration started """ - self.print_text("\n[IterationRunStartedEvent]", color='blue') - self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue') + self.print_text("\n[IterationRunStartedEvent]", color="blue") + self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue") - def on_workflow_iteration_next( - self, - event: IterationRunNextEvent - ) -> None: + def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None: """ Publish iteration next """ - self.print_text("\n[IterationRunNextEvent]", color='blue') - self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue') - self.print_text(f"Iteration Index: {event.index}", color='blue') + self.print_text("\n[IterationRunNextEvent]", color="blue") + self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue") + self.print_text(f"Iteration Index: {event.index}", color="blue") - def on_workflow_iteration_completed( - self, - event: IterationRunSucceededEvent | IterationRunFailedEvent - ) -> None: + def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None: """ Publish iteration completed """ - self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue') - self.print_text(f"Node ID: {event.iteration_id}", color='blue') + self.print_text( + "\n[IterationRunSucceededEvent]" + if isinstance(event, IterationRunSucceededEvent) + else "\n[IterationRunFailedEvent]", + color="blue", + ) + self.print_text(f"Node ID: {event.iteration_id}", color="blue") - def print_text( - self, text: str, color: Optional[str] = None, end: str = "\n" - ) -> None: + def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None: """Print text with highlighting and no end characters.""" text_to_print = self._get_colored_text(text, color) if color else text - print(f'{text_to_print}', end=end) + print(f"{text_to_print}", end=end) def _get_colored_text(self, text: str, color: str) -> str: """Get colored text.""" diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 6a1ab23041..ab8d4e374e 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -15,13 +15,14 @@ class InvokeFrom(Enum): """ Invoke From. """ - SERVICE_API = 'service-api' - WEB_APP = 'web-app' - EXPLORE = 'explore' - DEBUGGER = 'debugger' + + SERVICE_API = "service-api" + WEB_APP = "web-app" + EXPLORE = "explore" + DEBUGGER = "debugger" @classmethod - def value_of(cls, value: str) -> 'InvokeFrom': + def value_of(cls, value: str) -> "InvokeFrom": """ Get value of given mode. @@ -31,7 +32,7 @@ class InvokeFrom(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid invoke from value {value}') + raise ValueError(f"invalid invoke from value {value}") def to_source(self) -> str: """ @@ -40,21 +41,22 @@ class InvokeFrom(Enum): :return: source """ if self == InvokeFrom.WEB_APP: - return 'web_app' + return "web_app" elif self == InvokeFrom.DEBUGGER: - return 'dev' + return "dev" elif self == InvokeFrom.EXPLORE: - return 'explore_app' + return "explore_app" elif self == InvokeFrom.SERVICE_API: - return 'api' + return "api" - return 'dev' + return "dev" class ModelConfigWithCredentialsEntity(BaseModel): """ Model Config With Credentials Entity. """ + provider: str model: str model_schema: AIModelEntity @@ -72,6 +74,7 @@ class AppGenerateEntity(BaseModel): """ App Generate Entity. """ + task_id: str # app config @@ -102,6 +105,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity): """ Chat Application Generate Entity. """ + # app config app_config: EasyUIBasedAppConfig model_conf: ModelConfigWithCredentialsEntity @@ -116,6 +120,7 @@ class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity): """ Chat Application Generate Entity. """ + conversation_id: Optional[str] = None @@ -123,6 +128,7 @@ class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity): """ Completion Application Generate Entity. """ + pass @@ -130,6 +136,7 @@ class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity): """ Agent Chat Application Generate Entity. """ + conversation_id: Optional[str] = None @@ -137,6 +144,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity): """ Advanced Chat Application Generate Entity. """ + # app config app_config: WorkflowUIBasedAppConfig @@ -147,15 +155,18 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity): """ Single Iteration Run Entity. """ + node_id: str inputs: dict single_iteration_run: Optional[SingleIterationRunEntity] = None + class WorkflowAppGenerateEntity(AppGenerateEntity): """ Workflow Application Generate Entity. """ + # app config app_config: WorkflowUIBasedAppConfig @@ -163,6 +174,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): """ Single Iteration Run Entity. """ + node_id: str inputs: dict diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 4c86b7eee1..4577e28535 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -14,6 +14,7 @@ class QueueEvent(str, Enum): """ QueueEvent enum """ + LLM_CHUNK = "llm_chunk" TEXT_CHUNK = "text_chunk" AGENT_MESSAGE = "agent_message" @@ -45,6 +46,7 @@ class AppQueueEvent(BaseModel): """ QueueEvent abstract entity """ + event: QueueEvent @@ -53,13 +55,16 @@ class QueueLLMChunkEvent(AppQueueEvent): QueueLLMChunkEvent entity Only for basic mode apps """ + event: QueueEvent = QueueEvent.LLM_CHUNK chunk: LLMResultChunk + class QueueIterationStartEvent(AppQueueEvent): """ QueueIterationStartEvent entity """ + event: QueueEvent = QueueEvent.ITERATION_START node_execution_id: str node_id: str @@ -80,10 +85,12 @@ class QueueIterationStartEvent(AppQueueEvent): predecessor_node_id: Optional[str] = None metadata: Optional[dict[str, Any]] = None + class QueueIterationNextEvent(AppQueueEvent): """ QueueIterationNextEvent entity """ + event: QueueEvent = QueueEvent.ITERATION_NEXT index: int @@ -101,9 +108,9 @@ class QueueIterationNextEvent(AppQueueEvent): """parent parallel start node id if node is in parallel""" node_run_index: int - output: Optional[Any] = None # output for the current iteration + output: Optional[Any] = None # output for the current iteration - @field_validator('output', mode='before') + @field_validator("output", mode="before") @classmethod def set_output(cls, v): """ @@ -113,12 +120,14 @@ class QueueIterationNextEvent(AppQueueEvent): return None if isinstance(v, int | float | str | bool | dict | list): return v - raise ValueError('output must be a valid type') + raise ValueError("output must be a valid type") + class QueueIterationCompletedEvent(AppQueueEvent): """ QueueIterationCompletedEvent entity """ + event: QueueEvent = QueueEvent.ITERATION_COMPLETED node_execution_id: str @@ -134,7 +143,7 @@ class QueueIterationCompletedEvent(AppQueueEvent): parent_parallel_start_node_id: Optional[str] = None """parent parallel start node id if node is in parallel""" start_at: datetime - + node_run_index: int inputs: Optional[dict[str, Any]] = None outputs: Optional[dict[str, Any]] = None @@ -148,6 +157,7 @@ class QueueTextChunkEvent(AppQueueEvent): """ QueueTextChunkEvent entity """ + event: QueueEvent = QueueEvent.TEXT_CHUNK text: str from_variable_selector: Optional[list[str]] = None @@ -160,14 +170,16 @@ class QueueAgentMessageEvent(AppQueueEvent): """ QueueMessageEvent entity """ + event: QueueEvent = QueueEvent.AGENT_MESSAGE chunk: LLMResultChunk - + class QueueMessageReplaceEvent(AppQueueEvent): """ QueueMessageReplaceEvent entity """ + event: QueueEvent = QueueEvent.MESSAGE_REPLACE text: str @@ -176,6 +188,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent): """ QueueRetrieverResourcesEvent entity """ + event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES retriever_resources: list[dict] in_iteration_id: Optional[str] = None @@ -186,6 +199,7 @@ class QueueAnnotationReplyEvent(AppQueueEvent): """ QueueAnnotationReplyEvent entity """ + event: QueueEvent = QueueEvent.ANNOTATION_REPLY message_annotation_id: str @@ -194,6 +208,7 @@ class QueueMessageEndEvent(AppQueueEvent): """ QueueMessageEndEvent entity """ + event: QueueEvent = QueueEvent.MESSAGE_END llm_result: Optional[LLMResult] = None @@ -202,6 +217,7 @@ class QueueAdvancedChatMessageEndEvent(AppQueueEvent): """ QueueAdvancedChatMessageEndEvent entity """ + event: QueueEvent = QueueEvent.ADVANCED_CHAT_MESSAGE_END @@ -209,6 +225,7 @@ class QueueWorkflowStartedEvent(AppQueueEvent): """ QueueWorkflowStartedEvent entity """ + event: QueueEvent = QueueEvent.WORKFLOW_STARTED graph_runtime_state: GraphRuntimeState @@ -217,6 +234,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent): """ QueueWorkflowSucceededEvent entity """ + event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED outputs: Optional[dict[str, Any]] = None @@ -225,6 +243,7 @@ class QueueWorkflowFailedEvent(AppQueueEvent): """ QueueWorkflowFailedEvent entity """ + event: QueueEvent = QueueEvent.WORKFLOW_FAILED error: str @@ -233,6 +252,7 @@ class QueueNodeStartedEvent(AppQueueEvent): """ QueueNodeStartedEvent entity """ + event: QueueEvent = QueueEvent.NODE_STARTED node_execution_id: str @@ -258,6 +278,7 @@ class QueueNodeSucceededEvent(AppQueueEvent): """ QueueNodeSucceededEvent entity """ + event: QueueEvent = QueueEvent.NODE_SUCCEEDED node_execution_id: str @@ -288,6 +309,7 @@ class QueueNodeFailedEvent(AppQueueEvent): """ QueueNodeFailedEvent entity """ + event: QueueEvent = QueueEvent.NODE_FAILED node_execution_id: str @@ -317,6 +339,7 @@ class QueueAgentThoughtEvent(AppQueueEvent): """ QueueAgentThoughtEvent entity """ + event: QueueEvent = QueueEvent.AGENT_THOUGHT agent_thought_id: str @@ -325,6 +348,7 @@ class QueueMessageFileEvent(AppQueueEvent): """ QueueAgentThoughtEvent entity """ + event: QueueEvent = QueueEvent.MESSAGE_FILE message_file_id: str @@ -333,6 +357,7 @@ class QueueErrorEvent(AppQueueEvent): """ QueueErrorEvent entity """ + event: QueueEvent = QueueEvent.ERROR error: Any = None @@ -341,6 +366,7 @@ class QueuePingEvent(AppQueueEvent): """ QueuePingEvent entity """ + event: QueueEvent = QueueEvent.PING @@ -348,10 +374,12 @@ class QueueStopEvent(AppQueueEvent): """ QueueStopEvent entity """ + class StopBy(Enum): """ Stop by enum """ + USER_MANUAL = "user-manual" ANNOTATION_REPLY = "annotation-reply" OUTPUT_MODERATION = "output-moderation" @@ -365,19 +393,20 @@ class QueueStopEvent(AppQueueEvent): To stop reason """ reason_mapping = { - QueueStopEvent.StopBy.USER_MANUAL: 'Stopped by user.', - QueueStopEvent.StopBy.ANNOTATION_REPLY: 'Stopped by annotation reply.', - QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.', - QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.' + QueueStopEvent.StopBy.USER_MANUAL: "Stopped by user.", + QueueStopEvent.StopBy.ANNOTATION_REPLY: "Stopped by annotation reply.", + QueueStopEvent.StopBy.OUTPUT_MODERATION: "Stopped by output moderation.", + QueueStopEvent.StopBy.INPUT_MODERATION: "Stopped by input moderation.", } - return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.') + return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.") class QueueMessage(BaseModel): """ QueueMessage abstract entity """ + task_id: str app_mode: str event: AppQueueEvent @@ -387,6 +416,7 @@ class MessageQueueMessage(QueueMessage): """ MessageQueueMessage entity """ + message_id: str conversation_id: str @@ -395,6 +425,7 @@ class WorkflowQueueMessage(QueueMessage): """ WorkflowQueueMessage entity """ + pass @@ -402,6 +433,7 @@ class QueueParallelBranchRunStartedEvent(AppQueueEvent): """ QueueParallelBranchRunStartedEvent entity """ + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED parallel_id: str @@ -418,6 +450,7 @@ class QueueParallelBranchRunSucceededEvent(AppQueueEvent): """ QueueParallelBranchRunSucceededEvent entity """ + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED parallel_id: str @@ -434,6 +467,7 @@ class QueueParallelBranchRunFailedEvent(AppQueueEvent): """ QueueParallelBranchRunFailedEvent entity """ + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED parallel_id: str diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 7cab6ca4e0..0135c97172 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -12,6 +12,7 @@ class TaskState(BaseModel): """ TaskState entity """ + metadata: dict = {} @@ -19,6 +20,7 @@ class EasyUITaskState(TaskState): """ EasyUITaskState entity """ + llm_result: LLMResult @@ -26,6 +28,7 @@ class WorkflowTaskState(TaskState): """ WorkflowTaskState entity """ + answer: str = "" @@ -33,6 +36,7 @@ class StreamEvent(Enum): """ Stream event """ + PING = "ping" ERROR = "error" MESSAGE = "message" @@ -60,6 +64,7 @@ class StreamResponse(BaseModel): """ StreamResponse entity """ + event: StreamEvent task_id: str @@ -71,6 +76,7 @@ class ErrorStreamResponse(StreamResponse): """ ErrorStreamResponse entity """ + event: StreamEvent = StreamEvent.ERROR err: Exception model_config = ConfigDict(arbitrary_types_allowed=True) @@ -80,6 +86,7 @@ class MessageStreamResponse(StreamResponse): """ MessageStreamResponse entity """ + event: StreamEvent = StreamEvent.MESSAGE id: str answer: str @@ -89,6 +96,7 @@ class MessageAudioStreamResponse(StreamResponse): """ MessageStreamResponse entity """ + event: StreamEvent = StreamEvent.TTS_MESSAGE audio: str @@ -97,6 +105,7 @@ class MessageAudioEndStreamResponse(StreamResponse): """ MessageStreamResponse entity """ + event: StreamEvent = StreamEvent.TTS_MESSAGE_END audio: str @@ -105,6 +114,7 @@ class MessageEndStreamResponse(StreamResponse): """ MessageEndStreamResponse entity """ + event: StreamEvent = StreamEvent.MESSAGE_END id: str metadata: dict = {} @@ -114,6 +124,7 @@ class MessageFileStreamResponse(StreamResponse): """ MessageFileStreamResponse entity """ + event: StreamEvent = StreamEvent.MESSAGE_FILE id: str type: str @@ -125,6 +136,7 @@ class MessageReplaceStreamResponse(StreamResponse): """ MessageReplaceStreamResponse entity """ + event: StreamEvent = StreamEvent.MESSAGE_REPLACE answer: str @@ -133,6 +145,7 @@ class AgentThoughtStreamResponse(StreamResponse): """ AgentThoughtStreamResponse entity """ + event: StreamEvent = StreamEvent.AGENT_THOUGHT id: str position: int @@ -148,6 +161,7 @@ class AgentMessageStreamResponse(StreamResponse): """ AgentMessageStreamResponse entity """ + event: StreamEvent = StreamEvent.AGENT_MESSAGE id: str answer: str @@ -162,6 +176,7 @@ class WorkflowStartStreamResponse(StreamResponse): """ Data entity """ + id: str workflow_id: str sequence_number: int @@ -182,6 +197,7 @@ class WorkflowFinishStreamResponse(StreamResponse): """ Data entity """ + id: str workflow_id: str sequence_number: int @@ -210,6 +226,7 @@ class NodeStartStreamResponse(StreamResponse): """ Data entity """ + id: str node_id: str node_type: str @@ -249,7 +266,7 @@ class NodeStartStreamResponse(StreamResponse): "parent_parallel_id": self.data.parent_parallel_id, "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, "iteration_id": self.data.iteration_id, - } + }, } @@ -262,6 +279,7 @@ class NodeFinishStreamResponse(StreamResponse): """ Data entity """ + id: str node_id: str node_type: str @@ -315,9 +333,9 @@ class NodeFinishStreamResponse(StreamResponse): "parent_parallel_id": self.data.parent_parallel_id, "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, "iteration_id": self.data.iteration_id, - } + }, } - + class ParallelBranchStartStreamResponse(StreamResponse): """ @@ -328,6 +346,7 @@ class ParallelBranchStartStreamResponse(StreamResponse): """ Data entity """ + parallel_id: str parallel_branch_id: str parent_parallel_id: Optional[str] = None @@ -349,6 +368,7 @@ class ParallelBranchFinishedStreamResponse(StreamResponse): """ Data entity """ + parallel_id: str parallel_branch_id: str parent_parallel_id: Optional[str] = None @@ -372,6 +392,7 @@ class IterationNodeStartStreamResponse(StreamResponse): """ Data entity """ + id: str node_id: str node_type: str @@ -397,6 +418,7 @@ class IterationNodeNextStreamResponse(StreamResponse): """ Data entity """ + id: str node_id: str node_type: str @@ -422,6 +444,7 @@ class IterationNodeCompletedStreamResponse(StreamResponse): """ Data entity """ + id: str node_id: str node_type: str @@ -454,6 +477,7 @@ class TextChunkStreamResponse(StreamResponse): """ Data entity """ + text: str event: StreamEvent = StreamEvent.TEXT_CHUNK @@ -469,6 +493,7 @@ class TextReplaceStreamResponse(StreamResponse): """ Data entity """ + text: str event: StreamEvent = StreamEvent.TEXT_REPLACE @@ -479,6 +504,7 @@ class PingStreamResponse(StreamResponse): """ PingStreamResponse entity """ + event: StreamEvent = StreamEvent.PING @@ -486,6 +512,7 @@ class AppStreamResponse(BaseModel): """ AppStreamResponse entity """ + stream_response: StreamResponse @@ -493,6 +520,7 @@ class ChatbotAppStreamResponse(AppStreamResponse): """ ChatbotAppStreamResponse entity """ + conversation_id: str message_id: str created_at: int @@ -502,6 +530,7 @@ class CompletionAppStreamResponse(AppStreamResponse): """ CompletionAppStreamResponse entity """ + message_id: str created_at: int @@ -510,6 +539,7 @@ class WorkflowAppStreamResponse(AppStreamResponse): """ WorkflowAppStreamResponse entity """ + workflow_run_id: Optional[str] = None @@ -517,6 +547,7 @@ class AppBlockingResponse(BaseModel): """ AppBlockingResponse entity """ + task_id: str def to_dict(self) -> dict: @@ -532,6 +563,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse): """ Data entity """ + id: str mode: str conversation_id: str @@ -552,6 +584,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse): """ Data entity """ + id: str mode: str message_id: str @@ -571,6 +604,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): """ Data entity """ + id: str workflow_id: str status: str diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 19ff94de5e..2e37a126c3 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -13,11 +13,9 @@ logger = logging.getLogger(__name__) class AnnotationReplyFeature: - def query(self, app_record: App, - message: Message, - query: str, - user_id: str, - invoke_from: InvokeFrom) -> Optional[MessageAnnotation]: + def query( + self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom + ) -> Optional[MessageAnnotation]: """ Query app annotations to reply :param app_record: app record @@ -27,8 +25,9 @@ class AnnotationReplyFeature: :param invoke_from: invoke from :return: """ - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_record.id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first() + ) if not annotation_setting: return None @@ -41,55 +40,50 @@ class AnnotationReplyFeature: embedding_model_name = collection_binding_detail.model_name dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_provider_name, - embedding_model_name, - 'annotation' + embedding_provider_name, embedding_model_name, "annotation" ) dataset = Dataset( id=app_record.id, tenant_id=app_record.tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) documents = vector.search_by_vector( - query=query, - top_k=1, - score_threshold=score_threshold, - filter={ - 'group_id': [dataset.id] - } + query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]} ) if documents: - annotation_id = documents[0].metadata['annotation_id'] - score = documents[0].metadata['score'] + annotation_id = documents[0].metadata["annotation_id"] + score = documents[0].metadata["score"] annotation = AppAnnotationService.get_annotation_by_id(annotation_id) if annotation: if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]: - from_source = 'api' + from_source = "api" else: - from_source = 'console' + from_source = "console" # insert annotation history - AppAnnotationService.add_annotation_history(annotation.id, - app_record.id, - annotation.question, - annotation.content, - query, - user_id, - message.id, - from_source, - score) + AppAnnotationService.add_annotation_history( + annotation.id, + app_record.id, + annotation.question, + annotation.content, + query, + user_id, + message.id, + from_source, + score, + ) return annotation except Exception as e: - logger.warning(f'Query annotation failed, exception: {str(e)}.') + logger.warning(f"Query annotation failed, exception: {str(e)}.") return None return None diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py index b8f3e0e1f6..ba14b61201 100644 --- a/api/core/app/features/hosting_moderation/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -8,8 +8,9 @@ logger = logging.getLogger(__name__) class HostingModerationFeature: - def check(self, application_generate_entity: EasyUIBasedAppGenerateEntity, - prompt_messages: list[PromptMessage]) -> bool: + def check( + self, application_generate_entity: EasyUIBasedAppGenerateEntity, prompt_messages: list[PromptMessage] + ) -> bool: """ Check hosting moderation :param application_generate_entity: application generate entity @@ -23,9 +24,6 @@ class HostingModerationFeature: if isinstance(prompt_message.content, str): text += prompt_message.content + "\n" - moderation_result = moderation.check_moderation( - model_config, - text - ) + moderation_result = moderation.check_moderation(model_config, text) return moderation_result diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index f11e8021f0..227182f5ab 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -19,7 +19,7 @@ class RateLimit: _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes _instance_dict = {} - def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int): + def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int): if client_id not in cls._instance_dict: instance = super().__new__(cls) cls._instance_dict[client_id] = instance @@ -27,13 +27,13 @@ class RateLimit: def __init__(self, client_id: str, max_active_requests: int): self.max_active_requests = max_active_requests - if hasattr(self, 'initialized'): + if hasattr(self, "initialized"): return self.initialized = True self.client_id = client_id self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id) self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id) - self.last_recalculate_time = float('-inf') + self.last_recalculate_time = float("-inf") self.flush_cache(use_local_value=True) def flush_cache(self, use_local_value=False): @@ -46,7 +46,7 @@ class RateLimit: pipe.execute() else: with redis_client.pipeline() as pipe: - self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8')) + self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8")) redis_client.expire(self.max_active_requests_key, timedelta(days=1)) # flush max active requests (in-transit request list) @@ -54,8 +54,11 @@ class RateLimit: return request_details = redis_client.hgetall(self.active_requests_key) redis_client.expire(self.active_requests_key, timedelta(days=1)) - timeout_requests = [k for k, v in request_details.items() if - time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME] + timeout_requests = [ + k + for k, v in request_details.items() + if time.time() - float(v.decode("utf-8")) > RateLimit._REQUEST_MAX_ALIVE_TIME + ] if timeout_requests: redis_client.hdel(self.active_requests_key, *timeout_requests) @@ -69,8 +72,10 @@ class RateLimit: active_requests_count = redis_client.hlen(self.active_requests_key) if active_requests_count >= self.max_active_requests: - raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum " - "concurrent requests allowed is {}.".format(self.max_active_requests)) + raise AppInvokeQuotaExceededError( + "Too many requests. Please try again later. The current maximum " + "concurrent requests allowed is {}.".format(self.max_active_requests) + ) redis_client.hset(self.active_requests_key, request_id, str(time.time())) return request_id @@ -116,5 +121,5 @@ class RateLimitGenerator: if not self.closed: self.closed = True self.rate_limit.exit(self.request_id) - if self.generator is not None and hasattr(self.generator, 'close'): + if self.generator is not None and hasattr(self.generator, "close"): self.generator.close() diff --git a/api/core/app/segments/__init__.py b/api/core/app/segments/__init__.py index 7de06dfb96..652ef243b4 100644 --- a/api/core/app/segments/__init__.py +++ b/api/core/app/segments/__init__.py @@ -25,25 +25,25 @@ from .variables import ( ) __all__ = [ - 'IntegerVariable', - 'FloatVariable', - 'ObjectVariable', - 'SecretVariable', - 'StringVariable', - 'ArrayAnyVariable', - 'Variable', - 'SegmentType', - 'SegmentGroup', - 'Segment', - 'NoneSegment', - 'NoneVariable', - 'IntegerSegment', - 'FloatSegment', - 'ObjectSegment', - 'ArrayAnySegment', - 'StringSegment', - 'ArrayStringVariable', - 'ArrayNumberVariable', - 'ArrayObjectVariable', - 'ArraySegment', + "IntegerVariable", + "FloatVariable", + "ObjectVariable", + "SecretVariable", + "StringVariable", + "ArrayAnyVariable", + "Variable", + "SegmentType", + "SegmentGroup", + "Segment", + "NoneSegment", + "NoneVariable", + "IntegerSegment", + "FloatSegment", + "ObjectSegment", + "ArrayAnySegment", + "StringSegment", + "ArrayStringVariable", + "ArrayNumberVariable", + "ArrayObjectVariable", + "ArraySegment", ] diff --git a/api/core/app/segments/factory.py b/api/core/app/segments/factory.py index e6e9ce9774..40a69ed4eb 100644 --- a/api/core/app/segments/factory.py +++ b/api/core/app/segments/factory.py @@ -28,12 +28,12 @@ from .variables import ( def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: - if (value_type := mapping.get('value_type')) is None: - raise VariableError('missing value type') - if not mapping.get('name'): - raise VariableError('missing name') - if (value := mapping.get('value')) is None: - raise VariableError('missing value') + if (value_type := mapping.get("value_type")) is None: + raise VariableError("missing value type") + if not mapping.get("name"): + raise VariableError("missing name") + if (value := mapping.get("value")) is None: + raise VariableError("missing value") match value_type: case SegmentType.STRING: result = StringVariable.model_validate(mapping) @@ -44,7 +44,7 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: case SegmentType.NUMBER if isinstance(value, float): result = FloatVariable.model_validate(mapping) case SegmentType.NUMBER if not isinstance(value, float | int): - raise VariableError(f'invalid number value {value}') + raise VariableError(f"invalid number value {value}") case SegmentType.OBJECT if isinstance(value, dict): result = ObjectVariable.model_validate(mapping) case SegmentType.ARRAY_STRING if isinstance(value, list): @@ -54,9 +54,9 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: case SegmentType.ARRAY_OBJECT if isinstance(value, list): result = ArrayObjectVariable.model_validate(mapping) case _: - raise VariableError(f'not supported value type {value_type}') + raise VariableError(f"not supported value type {value_type}") if result.size > dify_config.MAX_VARIABLE_SIZE: - raise VariableError(f'variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}') + raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") return result @@ -73,4 +73,4 @@ def build_segment(value: Any, /) -> Segment: return ObjectSegment(value=value) if isinstance(value, list): return ArrayAnySegment(value=value) - raise ValueError(f'not supported value {value}') + raise ValueError(f"not supported value {value}") diff --git a/api/core/app/segments/parser.py b/api/core/app/segments/parser.py index de6c796652..3c4d7046f4 100644 --- a/api/core/app/segments/parser.py +++ b/api/core/app/segments/parser.py @@ -4,14 +4,14 @@ from core.workflow.entities.variable_pool import VariablePool from . import SegmentGroup, factory -VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}') +VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") def convert_template(*, template: str, variable_pool: VariablePool): parts = re.split(VARIABLE_PATTERN, template) segments = [] for part in filter(lambda x: x, parts): - if '.' in part and (value := variable_pool.get(part.split('.'))): + if "." in part and (value := variable_pool.get(part.split("."))): segments.append(value) else: segments.append(factory.build_segment(part)) diff --git a/api/core/app/segments/segment_group.py b/api/core/app/segments/segment_group.py index b4ff09b6d3..b363255b2c 100644 --- a/api/core/app/segments/segment_group.py +++ b/api/core/app/segments/segment_group.py @@ -8,15 +8,15 @@ class SegmentGroup(Segment): @property def text(self): - return ''.join([segment.text for segment in self.value]) + return "".join([segment.text for segment in self.value]) @property def log(self): - return ''.join([segment.log for segment in self.value]) + return "".join([segment.log for segment in self.value]) @property def markdown(self): - return ''.join([segment.markdown for segment in self.value]) + return "".join([segment.markdown for segment in self.value]) def to_object(self): return [segment.to_object() for segment in self.value] diff --git a/api/core/app/segments/segments.py b/api/core/app/segments/segments.py index 5c713cac67..b71924b2d3 100644 --- a/api/core/app/segments/segments.py +++ b/api/core/app/segments/segments.py @@ -14,13 +14,13 @@ class Segment(BaseModel): value_type: SegmentType value: Any - @field_validator('value_type') + @field_validator("value_type") def validate_value_type(cls, value): """ This validator checks if the provided value is equal to the default value of the 'value_type' field. If the value is different, a ValueError is raised. """ - if value != cls.model_fields['value_type'].default: + if value != cls.model_fields["value_type"].default: raise ValueError("Cannot modify 'value_type'") return value @@ -50,15 +50,15 @@ class NoneSegment(Segment): @property def text(self) -> str: - return 'null' + return "null" @property def log(self) -> str: - return 'null' + return "null" @property def markdown(self) -> str: - return 'null' + return "null" class StringSegment(Segment): @@ -76,24 +76,21 @@ class IntegerSegment(Segment): value: int - - - class ObjectSegment(Segment): value_type: SegmentType = SegmentType.OBJECT value: Mapping[str, Any] @property def text(self) -> str: - return json.dumps(self.model_dump()['value'], ensure_ascii=False) + return json.dumps(self.model_dump()["value"], ensure_ascii=False) @property def log(self) -> str: - return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) + return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2) @property def markdown(self) -> str: - return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) + return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2) class ArraySegment(Segment): @@ -101,11 +98,11 @@ class ArraySegment(Segment): def markdown(self) -> str: items = [] for item in self.value: - if hasattr(item, 'to_markdown'): + if hasattr(item, "to_markdown"): items.append(item.to_markdown()) else: items.append(str(item)) - return '\n'.join(items) + return "\n".join(items) class ArrayAnySegment(ArraySegment): @@ -126,4 +123,3 @@ class ArrayNumberSegment(ArraySegment): class ArrayObjectSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_OBJECT value: Sequence[Mapping[str, Any]] - diff --git a/api/core/app/segments/types.py b/api/core/app/segments/types.py index cdd2b0b4b0..9cf0856df5 100644 --- a/api/core/app/segments/types.py +++ b/api/core/app/segments/types.py @@ -2,14 +2,14 @@ from enum import Enum class SegmentType(str, Enum): - NONE = 'none' - NUMBER = 'number' - STRING = 'string' - SECRET = 'secret' - ARRAY_ANY = 'array[any]' - ARRAY_STRING = 'array[string]' - ARRAY_NUMBER = 'array[number]' - ARRAY_OBJECT = 'array[object]' - OBJECT = 'object' + NONE = "none" + NUMBER = "number" + STRING = "string" + SECRET = "secret" + ARRAY_ANY = "array[any]" + ARRAY_STRING = "array[string]" + ARRAY_NUMBER = "array[number]" + ARRAY_OBJECT = "array[object]" + OBJECT = "object" - GROUP = 'group' + GROUP = "group" diff --git a/api/core/app/segments/variables.py b/api/core/app/segments/variables.py index 8fef707fcf..f0e403ab8d 100644 --- a/api/core/app/segments/variables.py +++ b/api/core/app/segments/variables.py @@ -23,11 +23,11 @@ class Variable(Segment): """ id: str = Field( - default='', + default="", description="Unique identity for variable. It's only used by environment variables now.", ) name: str - description: str = Field(default='', description='Description of the variable.') + description: str = Field(default="", description="Description of the variable.") class StringVariable(StringSegment, Variable): @@ -62,7 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable): pass - class SecretVariable(StringVariable): value_type: SegmentType = SegmentType.SECRET diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 2f74a180d1..49f58af12c 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -32,10 +32,13 @@ class BasedGenerateTaskPipeline: _task_state: TaskState _application_generate_entity: AppGenerateEntity - def __init__(self, application_generate_entity: AppGenerateEntity, - queue_manager: AppQueueManager, - user: Union[Account, EndUser], - stream: bool) -> None: + def __init__( + self, + application_generate_entity: AppGenerateEntity, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + stream: bool, + ) -> None: """ Initialize GenerateTaskPipeline. :param application_generate_entity: application generate entity @@ -61,18 +64,18 @@ class BasedGenerateTaskPipeline: e = event.error if isinstance(e, InvokeAuthorizationError): - err = InvokeAuthorizationError('Incorrect API key provided') + err = InvokeAuthorizationError("Incorrect API key provided") elif isinstance(e, InvokeError) or isinstance(e, ValueError): err = e else: - err = Exception(e.description if getattr(e, 'description', None) is not None else str(e)) + err = Exception(e.description if getattr(e, "description", None) is not None else str(e)) if message: refetch_message = db.session.query(Message).filter(Message.id == message.id).first() if refetch_message: err_desc = self._error_to_desc(err) - refetch_message.status = 'error' + refetch_message.status = "error" refetch_message.error = err_desc db.session.commit() @@ -86,12 +89,14 @@ class BasedGenerateTaskPipeline: :return: """ if isinstance(e, QuotaExceededError): - return ("Your quota for Dify Hosted Model Provider has been exhausted. " - "Please go to Settings -> Model Provider to complete your own provider credentials.") + return ( + "Your quota for Dify Hosted Model Provider has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials." + ) - message = getattr(e, 'description', str(e)) + message = getattr(e, "description", str(e)) if not message: - message = 'Internal Server Error, please contact support.' + message = "Internal Server Error, please contact support." return message @@ -101,10 +106,7 @@ class BasedGenerateTaskPipeline: :param e: exception :return: """ - return ErrorStreamResponse( - task_id=self._application_generate_entity.task_id, - err=e - ) + return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e) def _ping_stream_response(self) -> PingStreamResponse: """ @@ -125,11 +127,8 @@ class BasedGenerateTaskPipeline: return OutputModeration( tenant_id=app_config.tenant_id, app_id=app_config.app_id, - rule=ModerationRule( - type=sensitive_word_avoidance.type, - config=sensitive_word_avoidance.config - ), - queue_manager=self._queue_manager + rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config), + queue_manager=self._queue_manager, ) def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: @@ -143,8 +142,7 @@ class BasedGenerateTaskPipeline: self._output_moderation_handler.stop_thread() completion = self._output_moderation_handler.moderation_completion( - completion=completion, - public_event=False + completion=completion, public_event=False ) self._output_moderation_handler = None diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 8d91a507a9..61e920845c 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -64,23 +64,21 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan """ EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application. """ - _task_state: EasyUITaskState - _application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity - ] - def __init__(self, application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity - ], - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool) -> None: + _task_state: EasyUITaskState + _application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity] + + def __init__( + self, + application_generate_entity: Union[ + ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity + ], + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool, + ) -> None: """ Initialize GenerateTaskPipeline. :param application_generate_entity: application generate entity @@ -101,18 +99,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan model=self._model_config.model, prompt_messages=[], message=AssistantPromptMessage(content=""), - usage=LLMUsage.empty_usage() + usage=LLMUsage.empty_usage(), ) ) self._conversation_name_generate_thread = None def process( - self, + self, ) -> Union[ ChatbotAppBlockingResponse, CompletionAppBlockingResponse, - Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None] + Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None], ]: """ Process generate task pipeline. @@ -125,22 +123,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, - self._application_generate_entity.query + self._conversation, self._application_generate_entity.query ) - generator = self._wrapper_process_stream_response( - trace_manager=self._application_generate_entity.trace_manager - ) + generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) if self._stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) - def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> Union[ - ChatbotAppBlockingResponse, - CompletionAppBlockingResponse - ]: + def _to_blocking_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]: """ Process blocking response. :return: @@ -149,11 +143,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err elif isinstance(stream_response, MessageEndStreamResponse): - extras = { - 'usage': jsonable_encoder(self._task_state.llm_result.usage) - } + extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)} if self._task_state.metadata: - extras['metadata'] = self._task_state.metadata + extras["metadata"] = self._task_state.metadata if self._conversation.mode == AppMode.COMPLETION.value: response = CompletionAppBlockingResponse( @@ -164,8 +156,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan message_id=self._message.id, answer=self._task_state.llm_result.message.content, created_at=int(self._message.created_at.timestamp()), - **extras - ) + **extras, + ), ) else: response = ChatbotAppBlockingResponse( @@ -177,18 +169,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan message_id=self._message.id, answer=self._task_state.llm_result.message.content, created_at=int(self._message.created_at.timestamp()), - **extras - ) + **extras, + ), ) return response else: continue - raise Exception('Queue listening stopped unexpectedly.') + raise Exception("Queue listening stopped unexpectedly.") - def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \ - -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]: + def _to_stream_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]: """ To stream response. :return: @@ -198,14 +191,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan yield CompletionAppStreamResponse( message_id=self._message.id, created_at=int(self._message.created_at.timestamp()), - stream_response=stream_response + stream_response=stream_response, ) else: yield ChatbotAppStreamResponse( conversation_id=self._conversation.id, message_id=self._message.id, created_at=int(self._message.created_at.timestamp()), - stream_response=stream_response + stream_response=stream_response, ) def _listenAudioMsg(self, publisher, task_id: str): @@ -217,15 +210,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None - def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ - Generator[StreamResponse, None, None]: - + def _wrapper_process_stream_response( + self, trace_manager: Optional[TraceQueueManager] = None + ) -> Generator[StreamResponse, None, None]: tenant_id = self._application_generate_entity.app_config.tenant_id task_id = self._application_generate_entity.task_id publisher = None - text_to_speech_dict = self._app_config.app_model_config_dict.get('text_to_speech') - if text_to_speech_dict and text_to_speech_dict.get('autoPlay') == 'enabled' and text_to_speech_dict.get('enabled'): - publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get('voice', None)) + text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech") + if ( + text_to_speech_dict + and text_to_speech_dict.get("autoPlay") == "enabled" + and text_to_speech_dict.get("enabled") + ): + publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None)) for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): while True: audio_response = self._listenAudioMsg(publisher, task_id) @@ -250,14 +247,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan break else: start_listener_time = time.time() - yield MessageAudioStreamResponse(audio=audio.audio, - task_id=task_id) - yield MessageAudioEndStreamResponse(audio='', task_id=task_id) + yield MessageAudioStreamResponse(audio=audio.audio, task_id=task_id) + yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( - self, - publisher: AppGeneratorTTSPublisher, - trace_manager: Optional[TraceQueueManager] = None + self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None ) -> Generator[StreamResponse, None, None]: """ Process stream response. @@ -333,9 +327,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message( - self, trace_manager: Optional[TraceQueueManager] = None - ) -> None: + def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None: """ Save message. :return: @@ -347,31 +339,32 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( - self._model_config.mode, - self._task_state.llm_result.prompt_messages + self._model_config.mode, self._task_state.llm_result.prompt_messages ) self._message.message_tokens = usage.prompt_tokens self._message.message_unit_price = usage.prompt_unit_price self._message.message_price_unit = usage.prompt_price_unit - self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \ - if llm_result.message.content else '' + self._message.answer = ( + PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) + if llm_result.message.content + else "" + ) self._message.answer_tokens = usage.completion_tokens self._message.answer_unit_price = usage.completion_unit_price self._message.answer_price_unit = usage.completion_price_unit self._message.provider_response_latency = time.perf_counter() - self._start_at self._message.total_price = usage.total_price self._message.currency = usage.currency - self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ - if self._task_state.metadata else None + self._message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) db.session.commit() if trace_manager: trace_manager.add_trace_task( TraceTask( - TraceTaskName.MESSAGE_TRACE, - conversation_id=self._conversation.id, - message_id=self._message.id + TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id ) ) @@ -379,11 +372,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._message, application_generate_entity=self._application_generate_entity, conversation=self._conversation, - is_first_message=self._application_generate_entity.app_config.app_mode in [ - AppMode.AGENT_CHAT, - AppMode.CHAT - ] and self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras + is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT] + and self._application_generate_entity.conversation_id is None, + extras=self._application_generate_entity.extras, ) def _handle_stop(self, event: QueueStopEvent) -> None: @@ -395,22 +386,17 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan model = model_config.model model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) # calculate num tokens prompt_tokens = 0 if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY: - prompt_tokens = model_instance.get_llm_num_tokens( - self._task_state.llm_result.prompt_messages - ) + prompt_tokens = model_instance.get_llm_num_tokens(self._task_state.llm_result.prompt_messages) completion_tokens = 0 if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL: - completion_tokens = model_instance.get_llm_num_tokens( - [self._task_state.llm_result.message] - ) + completion_tokens = model_instance.get_llm_num_tokens([self._task_state.llm_result.message]) credentials = model_config.credentials @@ -418,10 +404,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) self._task_state.llm_result.usage = model_type_instance._calc_response_usage( - model, - credentials, - prompt_tokens, - completion_tokens + model, credentials, prompt_tokens, completion_tokens ) def _message_end_to_stream_response(self) -> MessageEndStreamResponse: @@ -429,16 +412,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan Message end to stream response. :return: """ - self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage) + self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage) extras = {} if self._task_state.metadata: - extras['metadata'] = self._task_state.metadata + extras["metadata"] = self._task_state.metadata return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, - id=self._message.id, - **extras + task_id=self._application_generate_entity.task_id, id=self._message.id, **extras ) def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: @@ -449,9 +430,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan :return: """ return AgentMessageStreamResponse( - task_id=self._application_generate_entity.task_id, - id=message_id, - answer=answer + task_id=self._application_generate_entity.task_id, id=message_id, answer=answer ) def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]: @@ -461,9 +440,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan :return: """ agent_thought: MessageAgentThought = ( - db.session.query(MessageAgentThought) - .filter(MessageAgentThought.id == event.agent_thought_id) - .first() + db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first() ) db.session.refresh(agent_thought) db.session.close() @@ -478,7 +455,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan tool=agent_thought.tool, tool_labels=agent_thought.tool_labels, tool_input=agent_thought.tool_input, - message_files=agent_thought.files + message_files=agent_thought.files, ) return None @@ -500,15 +477,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan prompt_messages=self._task_state.llm_result.prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage(content=self._task_state.llm_result.message.content) - ) + message=AssistantPromptMessage(content=self._task_state.llm_result.message.content), + ), ) - ), PublishFrom.TASK_PIPELINE + ), + PublishFrom.TASK_PIPELINE, ) self._queue_manager.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), - PublishFrom.TASK_PIPELINE + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE ) return True else: diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index 8ff50dd174..011daba687 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -30,10 +30,7 @@ from services.annotation_service import AppAnnotationService class MessageCycleManage: _application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity + ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity ] _task_state: Union[EasyUITaskState, WorkflowTaskState] @@ -49,15 +46,18 @@ class MessageCycleManage: is_first_message = self._application_generate_entity.conversation_id is None extras = self._application_generate_entity.extras - auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True) + auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True) if auto_generate_conversation_name and is_first_message: # start generate thread - thread = Thread(target=self._generate_conversation_name_worker, kwargs={ - 'flask_app': current_app._get_current_object(), # type: ignore - 'conversation_id': conversation.id, - 'query': query - }) + thread = Thread( + target=self._generate_conversation_name_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "conversation_id": conversation.id, + "query": query, + }, + ) thread.start() @@ -65,17 +65,10 @@ class MessageCycleManage: return None - def _generate_conversation_name_worker(self, - flask_app: Flask, - conversation_id: str, - query: str): + def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str): with flask_app.app_context(): # get conversation and message - conversation = ( - db.session.query(Conversation) - .filter(Conversation.id == conversation_id) - .first() - ) + conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() if not conversation: return @@ -105,12 +98,9 @@ class MessageCycleManage: annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) if annotation: account = annotation.account - self._task_state.metadata['annotation_reply'] = { - 'id': annotation.id, - 'account': { - 'id': annotation.account_id, - 'name': account.name if account else 'Dify user' - } + self._task_state.metadata["annotation_reply"] = { + "id": annotation.id, + "account": {"id": annotation.account_id, "name": account.name if account else "Dify user"}, } return annotation @@ -124,7 +114,7 @@ class MessageCycleManage: :return: """ if self._application_generate_entity.app_config.additional_features.show_retrieve_source: - self._task_state.metadata['retriever_resources'] = event.retriever_resources + self._task_state.metadata["retriever_resources"] = event.retriever_resources def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: """ @@ -132,27 +122,23 @@ class MessageCycleManage: :param event: event :return: """ - message_file = ( - db.session.query(MessageFile) - .filter(MessageFile.id == event.message_file_id) - .first() - ) + message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first() if message_file: # get tool file id - tool_file_id = message_file.url.split('/')[-1] + tool_file_id = message_file.url.split("/")[-1] # trim extension - tool_file_id = tool_file_id.split('.')[0] + tool_file_id = tool_file_id.split(".")[0] # get extension - if '.' in message_file.url: + if "." in message_file.url: extension = f'.{message_file.url.split(".")[-1]}' if len(extension) > 10: - extension = '.bin' + extension = ".bin" else: - extension = '.bin' + extension = ".bin" # add sign url to local file - if message_file.url.startswith('http'): + if message_file.url.startswith("http"): url = message_file.url else: url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension) @@ -161,8 +147,8 @@ class MessageCycleManage: task_id=self._application_generate_entity.task_id, id=message_file.id, type=message_file.type, - belongs_to=message_file.belongs_to or 'user', - url=url + belongs_to=message_file.belongs_to or "user", + url=url, ) return None @@ -174,11 +160,7 @@ class MessageCycleManage: :param message_id: message id :return: """ - return MessageStreamResponse( - task_id=self._application_generate_entity.task_id, - id=message_id, - answer=answer - ) + return MessageStreamResponse(task_id=self._application_generate_entity.task_id, id=message_id, answer=answer) def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse: """ @@ -186,7 +168,4 @@ class MessageCycleManage: :param answer: answer :return: """ - return MessageReplaceStreamResponse( - task_id=self._application_generate_entity.task_id, - answer=answer - ) + return MessageReplaceStreamResponse(task_id=self._application_generate_entity.task_id, answer=answer) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index ed3225310a..a030d5dcbf 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -70,14 +70,14 @@ class WorkflowCycleManage: inputs = {**self._application_generate_entity.inputs} for key, value in (self._workflow_system_variables or {}).items(): - if key.value == 'conversation': + if key.value == "conversation": continue - inputs[f'sys.{key.value}'] = value + inputs[f"sys.{key.value}"] = value inputs = WorkflowEntry.handle_special_values(inputs) - triggered_from= ( + triggered_from = ( WorkflowRunTriggeredFrom.DEBUGGING if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN @@ -185,20 +185,26 @@ class WorkflowCycleManage: db.session.commit() - running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, - WorkflowNodeExecution.app_id == workflow_run.app_id, - WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.workflow_run_id == workflow_run.id, - WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value - ).all() + running_workflow_node_executions = ( + db.session.query(WorkflowNodeExecution) + .filter( + WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, + WorkflowNodeExecution.app_id == workflow_run.app_id, + WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.workflow_run_id == workflow_run.id, + WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, + ) + .all() + ) for workflow_node_execution in running_workflow_node_executions: workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.error = error workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) - workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds() + workflow_node_execution.elapsed_time = ( + workflow_node_execution.finished_at - workflow_node_execution.created_at + ).total_seconds() db.session.commit() db.session.refresh(workflow_run) @@ -216,7 +222,9 @@ class WorkflowCycleManage: return workflow_run - def _handle_node_execution_start(self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: + def _handle_node_execution_start( + self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent + ) -> WorkflowNodeExecution: # init workflow node execution workflow_node_execution = WorkflowNodeExecution() workflow_node_execution.tenant_id = workflow_run.tenant_id @@ -333,16 +341,16 @@ class WorkflowCycleManage: created_by_account = workflow_run.created_by_account if created_by_account: created_by = { - 'id': created_by_account.id, - 'name': created_by_account.name, - 'email': created_by_account.email, + "id": created_by_account.id, + "name": created_by_account.name, + "email": created_by_account.email, } else: created_by_end_user = workflow_run.created_by_end_user if created_by_end_user: created_by = { - 'id': created_by_end_user.id, - 'user': created_by_end_user.session_id, + "id": created_by_end_user.id, + "user": created_by_end_user.session_id, } return WorkflowFinishStreamResponse( @@ -401,7 +409,7 @@ class WorkflowCycleManage: # extras logic if event.node_type == NodeType.TOOL: node_data = cast(ToolNodeData, event.node_data) - response.data.extras['icon'] = ToolManager.get_tool_icon( + response.data.extras["icon"] = ToolManager.get_tool_icon( tenant_id=self._application_generate_entity.app_config.tenant_id, provider_type=node_data.provider_type, provider_id=node_data.provider_id, @@ -410,10 +418,10 @@ class WorkflowCycleManage: return response def _workflow_node_finish_to_stream_response( - self, - event: QueueNodeSucceededEvent | QueueNodeFailedEvent, - task_id: str, - workflow_node_execution: WorkflowNodeExecution + self, + event: QueueNodeSucceededEvent | QueueNodeFailedEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeFinishStreamResponse]: """ Workflow node finish to stream response. @@ -424,7 +432,7 @@ class WorkflowCycleManage: """ if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: return None - + return NodeFinishStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_run_id, @@ -452,13 +460,10 @@ class WorkflowCycleManage: iteration_id=event.in_iteration_id, ), ) - + def _workflow_parallel_branch_start_to_stream_response( - self, - task_id: str, - workflow_run: WorkflowRun, - event: QueueParallelBranchRunStartedEvent - ) -> ParallelBranchStartStreamResponse: + self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent + ) -> ParallelBranchStartStreamResponse: """ Workflow parallel branch start to stream response :param task_id: task id @@ -476,15 +481,15 @@ class WorkflowCycleManage: parent_parallel_start_node_id=event.parent_parallel_start_node_id, iteration_id=event.in_iteration_id, created_at=int(time.time()), - ) + ), ) - + def _workflow_parallel_branch_finished_to_stream_response( - self, - task_id: str, - workflow_run: WorkflowRun, - event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent - ) -> ParallelBranchFinishedStreamResponse: + self, + task_id: str, + workflow_run: WorkflowRun, + event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, + ) -> ParallelBranchFinishedStreamResponse: """ Workflow parallel branch finished to stream response :param task_id: task id @@ -501,18 +506,15 @@ class WorkflowCycleManage: parent_parallel_id=event.parent_parallel_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id, iteration_id=event.in_iteration_id, - status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed', + status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed", error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None, created_at=int(time.time()), - ) + ), ) def _workflow_iteration_start_to_stream_response( - self, - task_id: str, - workflow_run: WorkflowRun, - event: QueueIterationStartEvent - ) -> IterationNodeStartStreamResponse: + self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent + ) -> IterationNodeStartStreamResponse: """ Workflow iteration start to stream response :param task_id: task id @@ -534,10 +536,12 @@ class WorkflowCycleManage: metadata=event.metadata or {}, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, - ) + ), ) - def _workflow_iteration_next_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent) -> IterationNodeNextStreamResponse: + def _workflow_iteration_next_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent + ) -> IterationNodeNextStreamResponse: """ Workflow iteration next to stream response :param task_id: task id @@ -559,10 +563,12 @@ class WorkflowCycleManage: extras={}, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, - ) + ), ) - def _workflow_iteration_completed_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent) -> IterationNodeCompletedStreamResponse: + def _workflow_iteration_completed_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent + ) -> IterationNodeCompletedStreamResponse: """ Workflow iteration completed to stream response :param task_id: task id @@ -585,13 +591,13 @@ class WorkflowCycleManage: status=WorkflowNodeExecutionStatus.SUCCEEDED, error=None, elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(), - total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0, + total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, execution_metadata=event.metadata, finished_at=int(time.time()), steps=event.steps, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, - ) + ), ) def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]: @@ -643,7 +649,7 @@ class WorkflowCycleManage: return None if isinstance(value, dict): - if '__variant' in value and value['__variant'] == FileVar.__name__: + if "__variant" in value and value["__variant"] == FileVar.__name__: return value elif isinstance(value, FileVar): return value.to_dict() @@ -656,11 +662,10 @@ class WorkflowCycleManage: :param workflow_run_id: workflow run id :return: """ - workflow_run = db.session.query(WorkflowRun).filter( - WorkflowRun.id == workflow_run_id).first() + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() if not workflow_run: - raise Exception(f'Workflow run not found: {workflow_run_id}') + raise Exception(f"Workflow run not found: {workflow_run_id}") return workflow_run @@ -683,6 +688,6 @@ class WorkflowCycleManage: ) if not workflow_node_execution: - raise Exception(f'Workflow node execution not found: {node_execution_id}') + raise Exception(f"Workflow node execution not found: {node_execution_id}") - return workflow_node_execution \ No newline at end of file + return workflow_node_execution diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index 5789965747..99e992fd89 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -16,31 +16,32 @@ _TEXT_COLOR_MAPPING = { "red": "31;1", } + def get_colored_text(text: str, color: str) -> str: """Get colored text.""" color_str = _TEXT_COLOR_MAPPING[color] return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" -def print_text( - text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None -) -> None: +def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None: """Print text with highlighting and no end characters.""" text_to_print = get_colored_text(text, color) if color else text print(text_to_print, end=end, file=file) if file: file.flush() # ensure all printed content are written to file + class DifyAgentCallbackHandler(BaseModel): """Callback Handler that prints to std out.""" - color: Optional[str] = '' + + color: Optional[str] = "" current_loop: int = 1 def __init__(self, color: Optional[str] = None) -> None: super().__init__() """Initialize callback handler.""" # use a specific color is not specified - self.color = color or 'green' + self.color = color or "green" self.current_loop = 1 def on_tool_start( @@ -58,7 +59,7 @@ class DifyAgentCallbackHandler(BaseModel): tool_outputs: Sequence[ToolInvokeMessage], message_id: Optional[str] = None, timer: Optional[Any] = None, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> None: """If not the final action, print out observation.""" print_text("\n[on_tool_end]\n", color=self.color) @@ -79,26 +80,21 @@ class DifyAgentCallbackHandler(BaseModel): ) ) - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None: """Do nothing.""" - print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='red') + print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red") - def on_agent_start( - self, thought: str - ) -> None: + def on_agent_start(self, thought: str) -> None: """Run on agent start.""" if thought: - print_text("\n[on_agent_start] \nCurrent Loop: " + \ - str(self.current_loop) + \ - "\nThought: " + thought + "\n", color=self.color) + print_text( + "\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\nThought: " + thought + "\n", + color=self.color, + ) else: print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color) - def on_agent_finish( - self, color: Optional[str] = None, **kwargs: Any - ) -> None: + def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None: """Run on agent end.""" print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color) @@ -107,9 +103,9 @@ class DifyAgentCallbackHandler(BaseModel): @property def ignore_agent(self) -> bool: """Whether to ignore agent callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' + return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true" @property def ignore_chat_model(self) -> bool: """Whether to ignore chat model callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' + return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true" diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 8e1f496b22..50cde18c54 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,4 +1,3 @@ - from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueRetrieverResourcesEvent @@ -11,11 +10,9 @@ from models.model import DatasetRetrieverResource class DatasetIndexToolCallbackHandler: """Callback handler for dataset tool.""" - def __init__(self, queue_manager: AppQueueManager, - app_id: str, - message_id: str, - user_id: str, - invoke_from: InvokeFrom) -> None: + def __init__( + self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom + ) -> None: self._queue_manager = queue_manager self._app_id = app_id self._message_id = message_id @@ -29,11 +26,12 @@ class DatasetIndexToolCallbackHandler: dataset_query = DatasetQuery( dataset_id=dataset_id, content=query, - source='app', + source="app", source_app_id=self._app_id, - created_by_role=('account' - if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'), - created_by=self._user_id + created_by_role=( + "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user" + ), + created_by=self._user_id, ) db.session.add(dataset_query) @@ -43,18 +41,15 @@ class DatasetIndexToolCallbackHandler: """Handle tool end.""" for document in documents: query = db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == document.metadata['doc_id'] + DocumentSegment.index_node_id == document.metadata["doc_id"] ) # if 'dataset_id' in document.metadata: - if 'dataset_id' in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id']) + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) # add hit count to document segment - query.update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False - ) + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) db.session.commit() @@ -64,26 +59,25 @@ class DatasetIndexToolCallbackHandler: for item in resource: dataset_retriever_resource = DatasetRetrieverResource( message_id=self._message_id, - position=item.get('position'), - dataset_id=item.get('dataset_id'), - dataset_name=item.get('dataset_name'), - document_id=item.get('document_id'), - document_name=item.get('document_name'), - data_source_type=item.get('data_source_type'), - segment_id=item.get('segment_id'), - score=item.get('score') if 'score' in item else None, - hit_count=item.get('hit_count') if 'hit_count' else None, - word_count=item.get('word_count') if 'word_count' in item else None, - segment_position=item.get('segment_position') if 'segment_position' in item else None, - index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None, - content=item.get('content'), - retriever_from=item.get('retriever_from'), - created_by=self._user_id + position=item.get("position"), + dataset_id=item.get("dataset_id"), + dataset_name=item.get("dataset_name"), + document_id=item.get("document_id"), + document_name=item.get("document_name"), + data_source_type=item.get("data_source_type"), + segment_id=item.get("segment_id"), + score=item.get("score") if "score" in item else None, + hit_count=item.get("hit_count") if "hit_count" else None, + word_count=item.get("word_count") if "word_count" in item else None, + segment_position=item.get("segment_position") if "segment_position" in item else None, + index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None, + content=item.get("content"), + retriever_from=item.get("retriever_from"), + created_by=self._user_id, ) db.session.add(dataset_retriever_resource) db.session.commit() self._queue_manager.publish( - QueueRetrieverResourcesEvent(retriever_resources=resource), - PublishFrom.APPLICATION_MANAGER + QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER ) diff --git a/api/core/callback_handler/workflow_tool_callback_handler.py b/api/core/callback_handler/workflow_tool_callback_handler.py index 84bab7e1a3..8ac12f72f2 100644 --- a/api/core/callback_handler/workflow_tool_callback_handler.py +++ b/api/core/callback_handler/workflow_tool_callback_handler.py @@ -2,4 +2,4 @@ from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackH class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler): - """Callback Handler that prints to std out.""" \ No newline at end of file + """Callback Handler that prints to std out.""" diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index b7e0cc0c2b..4cc793b0d7 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -29,9 +29,13 @@ class CacheEmbedding(Embeddings): embedding_queue_indices = [] for i, text in enumerate(texts): hash = helper.generate_text_hash(text) - embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, - hash=hash, - provider_name=self._model_instance.provider).first() + embedding = ( + db.session.query(Embedding) + .filter_by( + model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider + ) + .first() + ) if embedding: text_embeddings[i] = embedding.get_embedding() else: @@ -41,17 +45,18 @@ class CacheEmbedding(Embeddings): embedding_queue_embeddings = [] try: model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance) - model_schema = model_type_instance.get_model_schema(self._model_instance.model, - self._model_instance.credentials) - max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \ - if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1 + model_schema = model_type_instance.get_model_schema( + self._model_instance.model, self._model_instance.credentials + ) + max_chunks = ( + model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties + else 1 + ) for i in range(0, len(embedding_queue_texts), max_chunks): - batch_texts = embedding_queue_texts[i:i + max_chunks] + batch_texts = embedding_queue_texts[i : i + max_chunks] - embedding_result = self._model_instance.invoke_text_embedding( - texts=batch_texts, - user=self._user - ) + embedding_result = self._model_instance.invoke_text_embedding(texts=batch_texts, user=self._user) for vector in embedding_result.embeddings: try: @@ -60,16 +65,18 @@ class CacheEmbedding(Embeddings): except IntegrityError: db.session.rollback() except Exception as e: - logging.exception('Failed transform embedding: ', e) + logging.exception("Failed transform embedding: ", e) cache_embeddings = [] try: for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings): text_embeddings[i] = embedding hash = helper.generate_text_hash(texts[i]) if hash not in cache_embeddings: - embedding_cache = Embedding(model_name=self._model_instance.model, - hash=hash, - provider_name=self._model_instance.provider) + embedding_cache = Embedding( + model_name=self._model_instance.model, + hash=hash, + provider_name=self._model_instance.provider, + ) embedding_cache.set_embedding(embedding) db.session.add(embedding_cache) cache_embeddings.append(hash) @@ -78,7 +85,7 @@ class CacheEmbedding(Embeddings): db.session.rollback() except Exception as ex: db.session.rollback() - logger.error('Failed to embed documents: ', ex) + logger.error("Failed to embed documents: ", ex) raise ex return text_embeddings @@ -87,16 +94,13 @@ class CacheEmbedding(Embeddings): """Embed query text.""" # use doc embedding cache or store if not exists hash = helper.generate_text_hash(text) - embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}' + embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}" embedding = redis_client.get(embedding_cache_key) if embedding: redis_client.expire(embedding_cache_key, 600) return list(np.frombuffer(base64.b64decode(embedding), dtype="float")) try: - embedding_result = self._model_instance.invoke_text_embedding( - texts=[text], - user=self._user - ) + embedding_result = self._model_instance.invoke_text_embedding(texts=[text], user=self._user) embedding_results = embedding_result.embeddings[0] embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() @@ -116,6 +120,6 @@ class CacheEmbedding(Embeddings): except IntegrityError: db.session.rollback() except: - logging.exception('Failed to add embedding to redis') + logging.exception("Failed to add embedding to redis") return embedding_results diff --git a/api/core/entities/agent_entities.py b/api/core/entities/agent_entities.py index 0cdf8670c4..656bf4aa72 100644 --- a/api/core/entities/agent_entities.py +++ b/api/core/entities/agent_entities.py @@ -2,7 +2,7 @@ from enum import Enum class PlanningStrategy(Enum): - ROUTER = 'router' - REACT_ROUTER = 'react_router' - REACT = 'react' - FUNCTION_CALL = 'function_call' + ROUTER = "router" + REACT_ROUTER = "react_router" + REACT = "react" + FUNCTION_CALL = "function_call" diff --git a/api/core/entities/message_entities.py b/api/core/entities/message_entities.py index 370aeee463..10bc9f6ed7 100644 --- a/api/core/entities/message_entities.py +++ b/api/core/entities/message_entities.py @@ -5,7 +5,7 @@ from pydantic import BaseModel class PromptMessageFileType(enum.Enum): - IMAGE = 'image' + IMAGE = "image" @staticmethod def value_of(value): @@ -22,8 +22,8 @@ class PromptMessageFile(BaseModel): class ImagePromptMessageFile(PromptMessageFile): class DETAIL(enum.Enum): - LOW = 'low' - HIGH = 'high' + LOW = "low" + HIGH = "high" type: PromptMessageFileType = PromptMessageFileType.IMAGE detail: DETAIL = DETAIL.LOW diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 22a21ecf93..9ed5528e43 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -12,6 +12,7 @@ class ModelStatus(Enum): """ Enum class for model status. """ + ACTIVE = "active" NO_CONFIGURE = "no-configure" QUOTA_EXCEEDED = "quota-exceeded" @@ -23,6 +24,7 @@ class SimpleModelProviderEntity(BaseModel): """ Simple provider. """ + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -40,7 +42,7 @@ class SimpleModelProviderEntity(BaseModel): label=provider_entity.label, icon_small=provider_entity.icon_small, icon_large=provider_entity.icon_large, - supported_model_types=provider_entity.supported_model_types + supported_model_types=provider_entity.supported_model_types, ) @@ -48,6 +50,7 @@ class ProviderModelWithStatusEntity(ProviderModel): """ Model class for model response. """ + status: ModelStatus load_balancing_enabled: bool = False @@ -56,6 +59,7 @@ class ModelWithProviderEntity(ProviderModelWithStatusEntity): """ Model with provider entity. """ + provider: SimpleModelProviderEntity @@ -63,6 +67,7 @@ class DefaultModelProviderEntity(BaseModel): """ Default model provider entity. """ + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -74,6 +79,7 @@ class DefaultModelEntity(BaseModel): """ Default model entity. """ + model: str model_type: ModelType provider: DefaultModelProviderEntity diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 778ef2e1ac..4797b69b85 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -47,6 +47,7 @@ class ProviderConfiguration(BaseModel): """ Model class for provider configuration. """ + tenant_id: str provider: ProviderEntity preferred_provider_type: ProviderType @@ -67,9 +68,13 @@ class ProviderConfiguration(BaseModel): original_provider_configurate_methods[self.provider.provider].append(configurate_method) if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: - if (any(len(quota_configuration.restrict_models) > 0 - for quota_configuration in self.system_configuration.quota_configurations) - and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods): + if ( + any( + len(quota_configuration.restrict_models) > 0 + for quota_configuration in self.system_configuration.quota_configurations + ) + and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods + ): self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: @@ -83,10 +88,9 @@ class ProviderConfiguration(BaseModel): if self.model_settings: # check if model is disabled by admin for model_setting in self.model_settings: - if (model_setting.model_type == model_type - and model_setting.model == model): + if model_setting.model_type == model_type and model_setting.model == model: if not model_setting.enabled: - raise ValueError(f'Model {model} is disabled.') + raise ValueError(f"Model {model} is disabled.") if self.using_provider_type == ProviderType.SYSTEM: restrict_models = [] @@ -99,10 +103,12 @@ class ProviderConfiguration(BaseModel): copy_credentials = self.system_configuration.credentials.copy() if restrict_models: for restrict_model in restrict_models: - if (restrict_model.model_type == model_type - and restrict_model.model == model - and restrict_model.base_model_name): - copy_credentials['base_model_name'] = restrict_model.base_model_name + if ( + restrict_model.model_type == model_type + and restrict_model.model == model + and restrict_model.base_model_name + ): + copy_credentials["base_model_name"] = restrict_model.base_model_name return copy_credentials else: @@ -128,20 +134,21 @@ class ProviderConfiguration(BaseModel): current_quota_type = self.system_configuration.current_quota_type current_quota_configuration = next( - (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), - None + (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None ) - return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \ - SystemConfigurationStatus.QUOTA_EXCEEDED + return ( + SystemConfigurationStatus.ACTIVE + if current_quota_configuration.is_valid + else SystemConfigurationStatus.QUOTA_EXCEEDED + ) def is_custom_configuration_available(self) -> bool: """ Check custom configuration available. :return: """ - return (self.custom_configuration.provider is not None - or len(self.custom_configuration.models) > 0) + return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0 def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: """ @@ -161,7 +168,8 @@ class ProviderConfiguration(BaseModel): return self.obfuscated_credentials( credentials=credentials, credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas - if self.provider.provider_credential_schema else [] + if self.provider.provider_credential_schema + else [], ) def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]: @@ -171,17 +179,21 @@ class ProviderConfiguration(BaseModel): :return: """ # get provider - provider_record = db.session.query(Provider) \ + provider_record = ( + db.session.query(Provider) .filter( - Provider.tenant_id == self.tenant_id, - Provider.provider_name == self.provider.provider, - Provider.provider_type == ProviderType.CUSTOM.value - ).first() + Provider.tenant_id == self.tenant_id, + Provider.provider_name == self.provider.provider, + Provider.provider_type == ProviderType.CUSTOM.value, + ) + .first() + ) # Get provider credential secret variables provider_credential_secret_variables = self.extract_secret_variables( self.provider.provider_credential_schema.credential_form_schemas - if self.provider.provider_credential_schema else [] + if self.provider.provider_credential_schema + else [] ) if provider_record: @@ -189,9 +201,7 @@ class ProviderConfiguration(BaseModel): # fix origin data if provider_record.encrypted_config: if not provider_record.encrypted_config.startswith("{"): - original_credentials = { - "openai_api_key": provider_record.encrypted_config - } + original_credentials = {"openai_api_key": provider_record.encrypted_config} else: original_credentials = json.loads(provider_record.encrypted_config) else: @@ -207,8 +217,7 @@ class ProviderConfiguration(BaseModel): credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) credentials = model_provider_factory.provider_credentials_validate( - provider=self.provider.provider, - credentials=credentials + provider=self.provider.provider, credentials=credentials ) for key, value in credentials.items(): @@ -239,15 +248,13 @@ class ProviderConfiguration(BaseModel): provider_name=self.provider.provider, provider_type=ProviderType.CUSTOM.value, encrypted_config=json.dumps(credentials), - is_valid=True + is_valid=True, ) db.session.add(provider_record) db.session.commit() provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER + tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER ) provider_model_credentials_cache.delete() @@ -260,12 +267,15 @@ class ProviderConfiguration(BaseModel): :return: """ # get provider - provider_record = db.session.query(Provider) \ + provider_record = ( + db.session.query(Provider) .filter( - Provider.tenant_id == self.tenant_id, - Provider.provider_name == self.provider.provider, - Provider.provider_type == ProviderType.CUSTOM.value - ).first() + Provider.tenant_id == self.tenant_id, + Provider.provider_name == self.provider.provider, + Provider.provider_type == ProviderType.CUSTOM.value, + ) + .first() + ) # delete provider if provider_record: @@ -277,13 +287,14 @@ class ProviderConfiguration(BaseModel): provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER + cache_type=ProviderCredentialsCacheType.PROVIDER, ) provider_model_credentials_cache.delete() - def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \ - -> Optional[dict]: + def get_custom_model_credentials( + self, model_type: ModelType, model: str, obfuscated: bool = False + ) -> Optional[dict]: """ Get custom model credentials. @@ -305,13 +316,15 @@ class ProviderConfiguration(BaseModel): return self.obfuscated_credentials( credentials=credentials, credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema else [] + if self.provider.model_credential_schema + else [], ) return None - def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \ - -> tuple[ProviderModel, dict]: + def custom_model_credentials_validate( + self, model_type: ModelType, model: str, credentials: dict + ) -> tuple[ProviderModel, dict]: """ Validate custom model credentials. @@ -321,24 +334,29 @@ class ProviderConfiguration(BaseModel): :return: """ # get provider model - provider_model_record = db.session.query(ProviderModel) \ + provider_model_record = ( + db.session.query(ProviderModel) .filter( - ProviderModel.tenant_id == self.tenant_id, - ProviderModel.provider_name == self.provider.provider, - ProviderModel.model_name == model, - ProviderModel.model_type == model_type.to_origin_model_type() - ).first() + ProviderModel.tenant_id == self.tenant_id, + ProviderModel.provider_name == self.provider.provider, + ProviderModel.model_name == model, + ProviderModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) # Get provider credential secret variables provider_credential_secret_variables = self.extract_secret_variables( self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema else [] + if self.provider.model_credential_schema + else [] ) if provider_model_record: try: - original_credentials = json.loads( - provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} + original_credentials = ( + json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} + ) except JSONDecodeError: original_credentials = {} @@ -350,10 +368,7 @@ class ProviderConfiguration(BaseModel): credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) credentials = model_provider_factory.model_credentials_validate( - provider=self.provider.provider, - model_type=model_type, - model=model, - credentials=credentials + provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) for key, value in credentials.items(): @@ -388,7 +403,7 @@ class ProviderConfiguration(BaseModel): model_name=model, model_type=model_type.to_origin_model_type(), encrypted_config=json.dumps(credentials), - is_valid=True + is_valid=True, ) db.session.add(provider_model_record) db.session.commit() @@ -396,7 +411,7 @@ class ProviderConfiguration(BaseModel): provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL + cache_type=ProviderCredentialsCacheType.MODEL, ) provider_model_credentials_cache.delete() @@ -409,13 +424,16 @@ class ProviderConfiguration(BaseModel): :return: """ # get provider model - provider_model_record = db.session.query(ProviderModel) \ + provider_model_record = ( + db.session.query(ProviderModel) .filter( - ProviderModel.tenant_id == self.tenant_id, - ProviderModel.provider_name == self.provider.provider, - ProviderModel.model_name == model, - ProviderModel.model_type == model_type.to_origin_model_type() - ).first() + ProviderModel.tenant_id == self.tenant_id, + ProviderModel.provider_name == self.provider.provider, + ProviderModel.model_name == model, + ProviderModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) # delete provider model if provider_model_record: @@ -425,7 +443,7 @@ class ProviderConfiguration(BaseModel): provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL + cache_type=ProviderCredentialsCacheType.MODEL, ) provider_model_credentials_cache.delete() @@ -437,13 +455,16 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_setting = db.session.query(ProviderModelSetting) \ + model_setting = ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) if model_setting: model_setting.enabled = True @@ -455,7 +476,7 @@ class ProviderConfiguration(BaseModel): provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, - enabled=True + enabled=True, ) db.session.add(model_setting) db.session.commit() @@ -469,13 +490,16 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_setting = db.session.query(ProviderModelSetting) \ + model_setting = ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) if model_setting: model_setting.enabled = False @@ -487,7 +511,7 @@ class ProviderConfiguration(BaseModel): provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, - enabled=False + enabled=False, ) db.session.add(model_setting) db.session.commit() @@ -501,13 +525,16 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - return db.session.query(ProviderModelSetting) \ + return ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting: """ @@ -516,24 +543,30 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \ + load_balancing_config_count = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == self.tenant_id, - LoadBalancingModelConfig.provider_name == self.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model - ).count() + LoadBalancingModelConfig.tenant_id == self.tenant_id, + LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + ) + .count() + ) if load_balancing_config_count <= 1: - raise ValueError('Model load balancing configuration must be more than 1.') + raise ValueError("Model load balancing configuration must be more than 1.") - model_setting = db.session.query(ProviderModelSetting) \ + model_setting = ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) if model_setting: model_setting.load_balancing_enabled = True @@ -545,7 +578,7 @@ class ProviderConfiguration(BaseModel): provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, - load_balancing_enabled=True + load_balancing_enabled=True, ) db.session.add(model_setting) db.session.commit() @@ -559,13 +592,16 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_setting = db.session.query(ProviderModelSetting) \ + model_setting = ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) if model_setting: model_setting.load_balancing_enabled = False @@ -577,7 +613,7 @@ class ProviderConfiguration(BaseModel): provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, - load_balancing_enabled=False + load_balancing_enabled=False, ) db.session.add(model_setting) db.session.commit() @@ -617,11 +653,14 @@ class ProviderConfiguration(BaseModel): return # get preferred provider - preferred_model_provider = db.session.query(TenantPreferredModelProvider) \ + preferred_model_provider = ( + db.session.query(TenantPreferredModelProvider) .filter( - TenantPreferredModelProvider.tenant_id == self.tenant_id, - TenantPreferredModelProvider.provider_name == self.provider.provider - ).first() + TenantPreferredModelProvider.tenant_id == self.tenant_id, + TenantPreferredModelProvider.provider_name == self.provider.provider, + ) + .first() + ) if preferred_model_provider: preferred_model_provider.preferred_provider_type = provider_type.value @@ -629,7 +668,7 @@ class ProviderConfiguration(BaseModel): preferred_model_provider = TenantPreferredModelProvider( tenant_id=self.tenant_id, provider_name=self.provider.provider, - preferred_provider_type=provider_type.value + preferred_provider_type=provider_type.value, ) db.session.add(preferred_model_provider) @@ -658,9 +697,7 @@ class ProviderConfiguration(BaseModel): :return: """ # Get provider credential secret variables - credential_secret_variables = self.extract_secret_variables( - credential_form_schemas - ) + credential_secret_variables = self.extract_secret_variables(credential_form_schemas) # Obfuscate provider credentials copy_credentials = credentials.copy() @@ -670,9 +707,9 @@ class ProviderConfiguration(BaseModel): return copy_credentials - def get_provider_model(self, model_type: ModelType, - model: str, - only_active: bool = False) -> Optional[ModelWithProviderEntity]: + def get_provider_model( + self, model_type: ModelType, model: str, only_active: bool = False + ) -> Optional[ModelWithProviderEntity]: """ Get provider model. :param model_type: model type @@ -688,8 +725,9 @@ class ProviderConfiguration(BaseModel): return None - def get_provider_models(self, model_type: Optional[ModelType] = None, - only_active: bool = False) -> list[ModelWithProviderEntity]: + def get_provider_models( + self, model_type: Optional[ModelType] = None, only_active: bool = False + ) -> list[ModelWithProviderEntity]: """ Get provider models. :param model_type: model type @@ -711,15 +749,11 @@ class ProviderConfiguration(BaseModel): if self.using_provider_type == ProviderType.SYSTEM: provider_models = self._get_system_provider_models( - model_types=model_types, - provider_instance=provider_instance, - model_setting_map=model_setting_map + model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map ) else: provider_models = self._get_custom_provider_models( - model_types=model_types, - provider_instance=provider_instance, - model_setting_map=model_setting_map + model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map ) if only_active: @@ -728,11 +762,12 @@ class ProviderConfiguration(BaseModel): # resort provider_models return sorted(provider_models, key=lambda x: x.model_type.value) - def _get_system_provider_models(self, - model_types: list[ModelType], - provider_instance: ModelProvider, - model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \ - -> list[ModelWithProviderEntity]: + def _get_system_provider_models( + self, + model_types: list[ModelType], + provider_instance: ModelProvider, + model_setting_map: dict[ModelType, dict[str, ModelSettings]], + ) -> list[ModelWithProviderEntity]: """ Get system provider models. @@ -760,7 +795,7 @@ class ProviderConfiguration(BaseModel): model_properties=m.model_properties, deprecated=m.deprecated, provider=SimpleModelProviderEntity(self.provider), - status=status + status=status, ) ) @@ -783,23 +818,20 @@ class ProviderConfiguration(BaseModel): if should_use_custom_model: if original_provider_configurate_methods[self.provider.provider] == [ - ConfigurateMethod.CUSTOMIZABLE_MODEL]: + ConfigurateMethod.CUSTOMIZABLE_MODEL + ]: # only customizable model for restrict_model in restrict_models: copy_credentials = self.system_configuration.credentials.copy() if restrict_model.base_model_name: - copy_credentials['base_model_name'] = restrict_model.base_model_name + copy_credentials["base_model_name"] = restrict_model.base_model_name try: - custom_model_schema = ( - provider_instance.get_model_instance(restrict_model.model_type) - .get_customizable_model_schema_from_credentials( - restrict_model.model, - copy_credentials - ) - ) + custom_model_schema = provider_instance.get_model_instance( + restrict_model.model_type + ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials) except Exception as ex: - logger.warning(f'get custom model schema failed, {ex}') + logger.warning(f"get custom model schema failed, {ex}") continue if not custom_model_schema: @@ -809,8 +841,10 @@ class ProviderConfiguration(BaseModel): continue status = ModelStatus.ACTIVE - if (custom_model_schema.model_type in model_setting_map - and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]): + if ( + custom_model_schema.model_type in model_setting_map + and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] + ): model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] if model_setting.enabled is False: status = ModelStatus.DISABLED @@ -825,7 +859,7 @@ class ProviderConfiguration(BaseModel): model_properties=custom_model_schema.model_properties, deprecated=custom_model_schema.deprecated, provider=SimpleModelProviderEntity(self.provider), - status=status + status=status, ) ) @@ -839,11 +873,12 @@ class ProviderConfiguration(BaseModel): return provider_models - def _get_custom_provider_models(self, - model_types: list[ModelType], - provider_instance: ModelProvider, - model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \ - -> list[ModelWithProviderEntity]: + def _get_custom_provider_models( + self, + model_types: list[ModelType], + provider_instance: ModelProvider, + model_setting_map: dict[ModelType, dict[str, ModelSettings]], + ) -> list[ModelWithProviderEntity]: """ Get custom provider models. @@ -885,7 +920,7 @@ class ProviderConfiguration(BaseModel): deprecated=m.deprecated, provider=SimpleModelProviderEntity(self.provider), status=status, - load_balancing_enabled=load_balancing_enabled + load_balancing_enabled=load_balancing_enabled, ) ) @@ -895,15 +930,13 @@ class ProviderConfiguration(BaseModel): continue try: - custom_model_schema = ( - provider_instance.get_model_instance(model_configuration.model_type) - .get_customizable_model_schema_from_credentials( - model_configuration.model, - model_configuration.credentials - ) + custom_model_schema = provider_instance.get_model_instance( + model_configuration.model_type + ).get_customizable_model_schema_from_credentials( + model_configuration.model, model_configuration.credentials ) except Exception as ex: - logger.warning(f'get custom model schema failed, {ex}') + logger.warning(f"get custom model schema failed, {ex}") continue if not custom_model_schema: @@ -911,8 +944,10 @@ class ProviderConfiguration(BaseModel): status = ModelStatus.ACTIVE load_balancing_enabled = False - if (custom_model_schema.model_type in model_setting_map - and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]): + if ( + custom_model_schema.model_type in model_setting_map + and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] + ): model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] if model_setting.enabled is False: status = ModelStatus.DISABLED @@ -931,7 +966,7 @@ class ProviderConfiguration(BaseModel): deprecated=custom_model_schema.deprecated, provider=SimpleModelProviderEntity(self.provider), status=status, - load_balancing_enabled=load_balancing_enabled + load_balancing_enabled=load_balancing_enabled, ) ) @@ -942,17 +977,16 @@ class ProviderConfigurations(BaseModel): """ Model class for provider configuration dict. """ + tenant_id: str configurations: dict[str, ProviderConfiguration] = {} def __init__(self, tenant_id: str): super().__init__(tenant_id=tenant_id) - def get_models(self, - provider: Optional[str] = None, - model_type: Optional[ModelType] = None, - only_active: bool = False) \ - -> list[ModelWithProviderEntity]: + def get_models( + self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False + ) -> list[ModelWithProviderEntity]: """ Get available models. @@ -1019,10 +1053,10 @@ class ProviderModelBundle(BaseModel): """ Provider model bundle. """ + configuration: ProviderConfiguration provider_instance: ModelProvider model_type_instance: AIModel # pydantic configs - model_config = ConfigDict(arbitrary_types_allowed=True, - protected_namespaces=()) + model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=()) diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 0d5b0a1b2c..44725623dc 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -8,18 +8,19 @@ from models.provider import ProviderQuotaType class QuotaUnit(Enum): - TIMES = 'times' - TOKENS = 'tokens' - CREDITS = 'credits' + TIMES = "times" + TOKENS = "tokens" + CREDITS = "credits" class SystemConfigurationStatus(Enum): """ Enum class for system configuration status. """ - ACTIVE = 'active' - QUOTA_EXCEEDED = 'quota-exceeded' - UNSUPPORTED = 'unsupported' + + ACTIVE = "active" + QUOTA_EXCEEDED = "quota-exceeded" + UNSUPPORTED = "unsupported" class RestrictModel(BaseModel): @@ -35,6 +36,7 @@ class QuotaConfiguration(BaseModel): """ Model class for provider quota configuration. """ + quota_type: ProviderQuotaType quota_unit: QuotaUnit quota_limit: int @@ -47,6 +49,7 @@ class SystemConfiguration(BaseModel): """ Model class for provider system configuration. """ + enabled: bool current_quota_type: Optional[ProviderQuotaType] = None quota_configurations: list[QuotaConfiguration] = [] @@ -57,6 +60,7 @@ class CustomProviderConfiguration(BaseModel): """ Model class for provider custom configuration. """ + credentials: dict @@ -64,6 +68,7 @@ class CustomModelConfiguration(BaseModel): """ Model class for provider custom model configuration. """ + model: str model_type: ModelType credentials: dict @@ -76,6 +81,7 @@ class CustomConfiguration(BaseModel): """ Model class for provider custom configuration. """ + provider: Optional[CustomProviderConfiguration] = None models: list[CustomModelConfiguration] = [] @@ -84,6 +90,7 @@ class ModelLoadBalancingConfiguration(BaseModel): """ Class for model load balancing configuration. """ + id: str name: str credentials: dict @@ -93,6 +100,7 @@ class ModelSettings(BaseModel): """ Model class for model settings. """ + model: str model_type: ModelType enabled: bool = True diff --git a/api/core/errors/error.py b/api/core/errors/error.py index 53323a2eeb..3b186476eb 100644 --- a/api/core/errors/error.py +++ b/api/core/errors/error.py @@ -3,6 +3,7 @@ from typing import Optional class LLMError(Exception): """Base class for all LLM exceptions.""" + description: Optional[str] = None def __init__(self, description: Optional[str] = None) -> None: @@ -11,6 +12,7 @@ class LLMError(Exception): class LLMBadRequestError(LLMError): """Raised when the LLM returns bad request.""" + description = "Bad Request" @@ -18,6 +20,7 @@ class ProviderTokenNotInitError(Exception): """ Custom exception raised when the provider token is not initialized. """ + description = "Provider Token Not Init" def __init__(self, *args, **kwargs): @@ -28,6 +31,7 @@ class QuotaExceededError(Exception): """ Custom exception raised when the quota for a provider has been exceeded. """ + description = "Quota Exceeded" @@ -35,6 +39,7 @@ class AppInvokeQuotaExceededError(Exception): """ Custom exception raised when the quota for an app has been exceeded. """ + description = "App Invoke Quota Exceeded" @@ -42,9 +47,11 @@ class ModelCurrentlyNotSupportError(Exception): """ Custom exception raised when the model not support """ + description = "Model Currently Not Support" class InvokeRateLimitError(Exception): """Raised when the Invoke returns rate limit error.""" + description = "Rate Limit Error" diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index 4db7a99973..38cebb6b6b 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -20,10 +20,7 @@ class APIBasedExtensionRequestor: :param params: the request params :return: the response json """ - headers = { - "Content-Type": "application/json", - "Authorization": "Bearer {}".format(self.api_key) - } + headers = {"Content-Type": "application/json", "Authorization": "Bearer {}".format(self.api_key)} url = self.api_endpoint @@ -32,20 +29,17 @@ class APIBasedExtensionRequestor: proxies = None if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: proxies = { - 'http': dify_config.SSRF_PROXY_HTTP_URL, - 'https': dify_config.SSRF_PROXY_HTTPS_URL, + "http": dify_config.SSRF_PROXY_HTTP_URL, + "https": dify_config.SSRF_PROXY_HTTPS_URL, } response = requests.request( - method='POST', + method="POST", url=url, - json={ - 'point': point.value, - 'params': params - }, + json={"point": point.value, "params": params}, headers=headers, timeout=self.timeout, - proxies=proxies + proxies=proxies, ) except requests.exceptions.Timeout: raise ValueError("request timeout") @@ -53,9 +47,8 @@ class APIBasedExtensionRequestor: raise ValueError("request connection error") if response.status_code != 200: - raise ValueError("request error, status_code: {}, content: {}".format( - response.status_code, - response.text[:100] - )) + raise ValueError( + "request error, status_code: {}, content: {}".format(response.status_code, response.text[:100]) + ) return response.json() diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 8d73aa2b8b..f1a49c4921 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -11,8 +11,8 @@ from core.helper.position_helper import sort_to_dict_by_position_map class ExtensionModule(enum.Enum): - MODERATION = 'moderation' - EXTERNAL_DATA_TOOL = 'external_data_tool' + MODERATION = "moderation" + EXTERNAL_DATA_TOOL = "external_data_tool" class ModuleExtension(BaseModel): @@ -41,12 +41,12 @@ class Extensible: position_map = {} # get the path of the current class - current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py') + current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py") current_dir_path = os.path.dirname(current_path) # traverse subdirectories for subdir_name in os.listdir(current_dir_path): - if subdir_name.startswith('__'): + if subdir_name.startswith("__"): continue subdir_path = os.path.join(current_dir_path, subdir_name) @@ -58,21 +58,21 @@ class Extensible: # in the front-end page and business logic, there are special treatments. builtin = False position = None - if '__builtin__' in file_names: + if "__builtin__" in file_names: builtin = True - builtin_file_path = os.path.join(subdir_path, '__builtin__') + builtin_file_path = os.path.join(subdir_path, "__builtin__") if os.path.exists(builtin_file_path): - with open(builtin_file_path, encoding='utf-8') as f: + with open(builtin_file_path, encoding="utf-8") as f: position = int(f.read().strip()) position_map[extension_name] = position - if (extension_name + '.py') not in file_names: + if (extension_name + ".py") not in file_names: logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") continue # Dynamic loading {subdir_name}.py file and find the subclass of Extensible - py_path = os.path.join(subdir_path, extension_name + '.py') + py_path = os.path.join(subdir_path, extension_name + ".py") spec = importlib.util.spec_from_file_location(extension_name, py_path) if not spec or not spec.loader: raise Exception(f"Failed to load module {extension_name} from {py_path}") @@ -91,25 +91,29 @@ class Extensible: json_data = {} if not builtin: - if 'schema.json' not in file_names: + if "schema.json" not in file_names: logging.warning(f"Missing schema.json file in {subdir_path}, Skip.") continue - json_path = os.path.join(subdir_path, 'schema.json') + json_path = os.path.join(subdir_path, "schema.json") json_data = {} if os.path.exists(json_path): - with open(json_path, encoding='utf-8') as f: + with open(json_path, encoding="utf-8") as f: json_data = json.load(f) - extensions.append(ModuleExtension( - extension_class=extension_class, - name=extension_name, - label=json_data.get('label'), - form_schema=json_data.get('form_schema'), - builtin=builtin, - position=position - )) + extensions.append( + ModuleExtension( + extension_class=extension_class, + name=extension_name, + label=json_data.get("label"), + form_schema=json_data.get("form_schema"), + builtin=builtin, + position=position, + ) + ) - sorted_extensions = sort_to_dict_by_position_map(position_map=position_map, data=extensions, name_func=lambda x: x.name) + sorted_extensions = sort_to_dict_by_position_map( + position_map=position_map, data=extensions, name_func=lambda x: x.name + ) return sorted_extensions diff --git a/api/core/extension/extension.py b/api/core/extension/extension.py index 29e892c58a..3da170455e 100644 --- a/api/core/extension/extension.py +++ b/api/core/extension/extension.py @@ -6,10 +6,7 @@ from core.moderation.base import Moderation class Extension: __module_extensions: dict[str, dict[str, ModuleExtension]] = {} - module_classes = { - ExtensionModule.MODERATION: Moderation, - ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool - } + module_classes = {ExtensionModule.MODERATION: Moderation, ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool} def init(self): for module, module_class in self.module_classes.items(): diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 58c82502ea..54ec97a493 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -30,10 +30,11 @@ class ApiExternalDataTool(ExternalDataTool): raise ValueError("api_based_extension_id is required") # get api_based_extension - api_based_extension = db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + api_based_extension = ( + db.session.query(APIBasedExtension) + .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .first() + ) if not api_based_extension: raise ValueError("api_based_extension_id is invalid") @@ -50,47 +51,42 @@ class ApiExternalDataTool(ExternalDataTool): api_based_extension_id = self.config.get("api_based_extension_id") # get api_based_extension - api_based_extension = db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == self.tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + api_based_extension = ( + db.session.query(APIBasedExtension) + .filter(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id) + .first() + ) if not api_based_extension: - raise ValueError("[External data tool] API query failed, variable: {}, " - "error: api_based_extension_id is invalid" - .format(self.variable)) + raise ValueError( + "[External data tool] API query failed, variable: {}, " + "error: api_based_extension_id is invalid".format(self.variable) + ) # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=self.tenant_id, - token=api_based_extension.api_key - ) + api_key = encrypter.decrypt_token(tenant_id=self.tenant_id, token=api_based_extension.api_key) try: # request api - requestor = APIBasedExtensionRequestor( - api_endpoint=api_based_extension.api_endpoint, - api_key=api_key - ) + requestor = APIBasedExtensionRequestor(api_endpoint=api_based_extension.api_endpoint, api_key=api_key) except Exception as e: - raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format( - self.variable, - e - )) + raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(self.variable, e)) - response_json = requestor.request(point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, params={ - 'app_id': self.app_id, - 'tool_variable': self.variable, - 'inputs': inputs, - 'query': query - }) + response_json = requestor.request( + point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, + params={"app_id": self.app_id, "tool_variable": self.variable, "inputs": inputs, "query": query}, + ) - if 'result' not in response_json: - raise ValueError("[External data tool] API query failed, variable: {}, error: result not found in response" - .format(self.variable)) + if "result" not in response_json: + raise ValueError( + "[External data tool] API query failed, variable: {}, error: result not found in response".format( + self.variable + ) + ) - if not isinstance(response_json['result'], str): - raise ValueError("[External data tool] API query failed, variable: {}, error: result is not string" - .format(self.variable)) + if not isinstance(response_json["result"], str): + raise ValueError( + "[External data tool] API query failed, variable: {}, error: result is not string".format(self.variable) + ) - return response_json['result'] + return response_json["result"] diff --git a/api/core/external_data_tool/external_data_fetch.py b/api/core/external_data_tool/external_data_fetch.py index 8601cb34e7..84b94e117f 100644 --- a/api/core/external_data_tool/external_data_fetch.py +++ b/api/core/external_data_tool/external_data_fetch.py @@ -12,11 +12,14 @@ logger = logging.getLogger(__name__) class ExternalDataFetch: - def fetch(self, tenant_id: str, - app_id: str, - external_data_tools: list[ExternalDataVariableEntity], - inputs: dict, - query: str) -> dict: + def fetch( + self, + tenant_id: str, + app_id: str, + external_data_tools: list[ExternalDataVariableEntity], + inputs: dict, + query: str, + ) -> dict: """ Fill in variable inputs from external data tools if exists. @@ -38,7 +41,7 @@ class ExternalDataFetch: app_id, tool, inputs, - query + query, ) futures[future] = tool @@ -50,12 +53,15 @@ class ExternalDataFetch: inputs.update(results) return inputs - def _query_external_data_tool(self, flask_app: Flask, - tenant_id: str, - app_id: str, - external_data_tool: ExternalDataVariableEntity, - inputs: dict, - query: str) -> tuple[Optional[str], Optional[str]]: + def _query_external_data_tool( + self, + flask_app: Flask, + tenant_id: str, + app_id: str, + external_data_tool: ExternalDataVariableEntity, + inputs: dict, + query: str, + ) -> tuple[Optional[str], Optional[str]]: """ Query external data tool. :param flask_app: flask app @@ -72,17 +78,10 @@ class ExternalDataFetch: tool_config = external_data_tool.config external_data_tool_factory = ExternalDataToolFactory( - name=tool_type, - tenant_id=tenant_id, - app_id=app_id, - variable=tool_variable, - config=tool_config + name=tool_type, tenant_id=tenant_id, app_id=app_id, variable=tool_variable, config=tool_config ) # query external data tool - result = external_data_tool_factory.query( - inputs=inputs, - query=query - ) + result = external_data_tool_factory.query(inputs=inputs, query=query) return tool_variable, result diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py index 979f243af6..2872109859 100644 --- a/api/core/external_data_tool/factory.py +++ b/api/core/external_data_tool/factory.py @@ -5,14 +5,10 @@ from extensions.ext_code_based_extension import code_based_extension class ExternalDataToolFactory: - def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None: extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) self.__extension_instance = extension_class( - tenant_id=tenant_id, - app_id=app_id, - variable=variable, - config=config + tenant_id=tenant_id, app_id=app_id, variable=variable, config=config ) @classmethod diff --git a/api/core/file/file_obj.py b/api/core/file/file_obj.py index 3959f4b4a0..5c4e694025 100644 --- a/api/core/file/file_obj.py +++ b/api/core/file/file_obj.py @@ -13,11 +13,12 @@ class FileExtraConfig(BaseModel): """ File Upload Entity. """ + image_config: Optional[dict[str, Any]] = None class FileType(enum.Enum): - IMAGE = 'image' + IMAGE = "image" @staticmethod def value_of(value): @@ -28,9 +29,9 @@ class FileType(enum.Enum): class FileTransferMethod(enum.Enum): - REMOTE_URL = 'remote_url' - LOCAL_FILE = 'local_file' - TOOL_FILE = 'tool_file' + REMOTE_URL = "remote_url" + LOCAL_FILE = "local_file" + TOOL_FILE = "tool_file" @staticmethod def value_of(value): @@ -39,9 +40,10 @@ class FileTransferMethod(enum.Enum): return member raise ValueError(f"No matching enum found for value '{value}'") + class FileBelongsTo(enum.Enum): - USER = 'user' - ASSISTANT = 'assistant' + USER = "user" + ASSISTANT = "assistant" @staticmethod def value_of(value): @@ -65,16 +67,16 @@ class FileVar(BaseModel): def to_dict(self) -> dict: return { - '__variant': self.__class__.__name__, - 'tenant_id': self.tenant_id, - 'type': self.type.value, - 'transfer_method': self.transfer_method.value, - 'url': self.preview_url, - 'remote_url': self.url, - 'related_id': self.related_id, - 'filename': self.filename, - 'extension': self.extension, - 'mime_type': self.mime_type, + "__variant": self.__class__.__name__, + "tenant_id": self.tenant_id, + "type": self.type.value, + "transfer_method": self.transfer_method.value, + "url": self.preview_url, + "remote_url": self.url, + "related_id": self.related_id, + "filename": self.filename, + "extension": self.extension, + "mime_type": self.mime_type, } def to_markdown(self) -> str: @@ -86,7 +88,7 @@ class FileVar(BaseModel): if self.type == FileType.IMAGE: text = f'![{self.filename or ""}]({preview_url})' else: - text = f'[{self.filename or preview_url}]({preview_url})' + text = f"[{self.filename or preview_url}]({preview_url})" return text @@ -115,28 +117,29 @@ class FileVar(BaseModel): return ImagePromptMessageContent( data=self.data, detail=ImagePromptMessageContent.DETAIL.HIGH - if image_config.get("detail") == "high" else ImagePromptMessageContent.DETAIL.LOW + if image_config.get("detail") == "high" + else ImagePromptMessageContent.DETAIL.LOW, ) def _get_data(self, force_url: bool = False) -> Optional[str]: from models.model import UploadFile + if self.type == FileType.IMAGE: if self.transfer_method == FileTransferMethod.REMOTE_URL: return self.url elif self.transfer_method == FileTransferMethod.LOCAL_FILE: - upload_file = (db.session.query(UploadFile) - .filter( - UploadFile.id == self.related_id, - UploadFile.tenant_id == self.tenant_id - ).first()) - - return UploadFileParser.get_image_data( - upload_file=upload_file, - force_url=force_url + upload_file = ( + db.session.query(UploadFile) + .filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id) + .first() ) + + return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url) elif self.transfer_method == FileTransferMethod.TOOL_FILE: extension = self.extension # add sign url - return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=self.related_id, extension=extension) + return ToolFileParser.get_tool_file_manager().sign_file( + tool_file_id=self.related_id, extension=extension + ) return None diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py index 085ff07cfd..8feaabedbb 100644 --- a/api/core/file/message_file_parser.py +++ b/api/core/file/message_file_parser.py @@ -13,13 +13,13 @@ from services.file_service import IMAGE_EXTENSIONS class MessageFileParser: - def __init__(self, tenant_id: str, app_id: str) -> None: self.tenant_id = tenant_id self.app_id = app_id - def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, - user: Union[Account, EndUser]) -> list[FileVar]: + def validate_and_transform_files_arg( + self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser] + ) -> list[FileVar]: """ validate and transform files arg @@ -30,22 +30,22 @@ class MessageFileParser: """ for file in files: if not isinstance(file, dict): - raise ValueError('Invalid file format, must be dict') - if not file.get('type'): - raise ValueError('Missing file type') - FileType.value_of(file.get('type')) - if not file.get('transfer_method'): - raise ValueError('Missing file transfer method') - FileTransferMethod.value_of(file.get('transfer_method')) - if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value: - if not file.get('url'): - raise ValueError('Missing file url') - if not file.get('url').startswith('http'): - raise ValueError('Invalid file url') - if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'): - raise ValueError('Missing file upload_file_id') - if file.get('transform_method') == FileTransferMethod.TOOL_FILE.value and not file.get('tool_file_id'): - raise ValueError('Missing file tool_file_id') + raise ValueError("Invalid file format, must be dict") + if not file.get("type"): + raise ValueError("Missing file type") + FileType.value_of(file.get("type")) + if not file.get("transfer_method"): + raise ValueError("Missing file transfer method") + FileTransferMethod.value_of(file.get("transfer_method")) + if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value: + if not file.get("url"): + raise ValueError("Missing file url") + if not file.get("url").startswith("http"): + raise ValueError("Invalid file url") + if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"): + raise ValueError("Missing file upload_file_id") + if file.get("transform_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"): + raise ValueError("Missing file tool_file_id") # transform files to file objs type_file_objs = self._to_file_objs(files, file_extra_config) @@ -62,17 +62,17 @@ class MessageFileParser: continue # Validate number of files - if len(files) > image_config['number_limits']: + if len(files) > image_config["number_limits"]: raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}") for file_obj in file_objs: # Validate transfer method - if file_obj.transfer_method.value not in image_config['transfer_methods']: - raise ValueError(f'Invalid transfer method: {file_obj.transfer_method.value}') + if file_obj.transfer_method.value not in image_config["transfer_methods"]: + raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}") # Validate file type if file_obj.type != FileType.IMAGE: - raise ValueError(f'Invalid file type: {file_obj.type}') + raise ValueError(f"Invalid file type: {file_obj.type}") if file_obj.transfer_method == FileTransferMethod.REMOTE_URL: # check remote url valid and is image @@ -81,18 +81,21 @@ class MessageFileParser: raise ValueError(error) elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE: # get upload file from upload_file_id - upload_file = (db.session.query(UploadFile) - .filter( - UploadFile.id == file_obj.related_id, - UploadFile.tenant_id == self.tenant_id, - UploadFile.created_by == user.id, - UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - UploadFile.extension.in_(IMAGE_EXTENSIONS) - ).first()) + upload_file = ( + db.session.query(UploadFile) + .filter( + UploadFile.id == file_obj.related_id, + UploadFile.tenant_id == self.tenant_id, + UploadFile.created_by == user.id, + UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + UploadFile.extension.in_(IMAGE_EXTENSIONS), + ) + .first() + ) # check upload file is belong to tenant and user if not upload_file: - raise ValueError('Invalid upload file') + raise ValueError("Invalid upload file") new_files.append(file_obj) @@ -113,8 +116,9 @@ class MessageFileParser: # return all file objs return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs] - def _to_file_objs(self, files: list[Union[dict, MessageFile]], - file_extra_config: FileExtraConfig) -> dict[FileType, list[FileVar]]: + def _to_file_objs( + self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig + ) -> dict[FileType, list[FileVar]]: """ transform files to file objs @@ -152,23 +156,23 @@ class MessageFileParser: :return: """ if isinstance(file, dict): - transfer_method = FileTransferMethod.value_of(file.get('transfer_method')) + transfer_method = FileTransferMethod.value_of(file.get("transfer_method")) if transfer_method != FileTransferMethod.TOOL_FILE: return FileVar( tenant_id=self.tenant_id, - type=FileType.value_of(file.get('type')), + type=FileType.value_of(file.get("type")), transfer_method=transfer_method, - url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, - related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, - extra_config=file_extra_config + url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None, + related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None, + extra_config=file_extra_config, ) return FileVar( tenant_id=self.tenant_id, - type=FileType.value_of(file.get('type')), + type=FileType.value_of(file.get("type")), transfer_method=transfer_method, url=None, - related_id=file.get('tool_file_id'), - extra_config=file_extra_config + related_id=file.get("tool_file_id"), + extra_config=file_extra_config, ) else: return FileVar( @@ -178,7 +182,7 @@ class MessageFileParser: transfer_method=FileTransferMethod.value_of(file.transfer_method), url=file.url, related_id=file.upload_file_id or None, - extra_config=file_extra_config + extra_config=file_extra_config, ) def _check_image_remote_url(self, url): @@ -190,17 +194,17 @@ class MessageFileParser: def is_s3_presigned_url(url): try: parsed_url = urlparse(url) - if 'amazonaws.com' not in parsed_url.netloc: + if "amazonaws.com" not in parsed_url.netloc: return False query_params = parse_qs(parsed_url.query) - required_params = ['Signature', 'Expires'] + required_params = ["Signature", "Expires"] for param in required_params: if param not in query_params: return False - if not query_params['Expires'][0].isdigit(): + if not query_params["Expires"][0].isdigit(): return False - signature = query_params['Signature'][0] - if not re.match(r'^[A-Za-z0-9+/]+={0,2}$', signature): + signature = query_params["Signature"][0] + if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature): return False return True except Exception: diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py index ea8605ac57..1efaf5529d 100644 --- a/api/core/file/tool_file_parser.py +++ b/api/core/file/tool_file_parser.py @@ -1,8 +1,7 @@ -tool_file_manager = { - 'manager': None -} +tool_file_manager = {"manager": None} + class ToolFileParser: @staticmethod - def get_tool_file_manager() -> 'ToolFileManager': - return tool_file_manager['manager'] \ No newline at end of file + def get_tool_file_manager() -> "ToolFileManager": + return tool_file_manager["manager"] diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py index 737a11e426..a8c1fd4d02 100644 --- a/api/core/file/upload_file_parser.py +++ b/api/core/file/upload_file_parser.py @@ -9,7 +9,7 @@ from typing import Optional from configs import dify_config from extensions.ext_storage import storage -IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] +IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) @@ -22,18 +22,18 @@ class UploadFileParser: if upload_file.extension not in IMAGE_EXTENSIONS: return None - if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == 'url' or force_url: + if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url: return cls.get_signed_temp_image_url(upload_file.id) else: # get image file base64 try: data = storage.load(upload_file.key) except FileNotFoundError: - logging.error(f'File not found: {upload_file.key}') + logging.error(f"File not found: {upload_file.key}") return None - encoded_string = base64.b64encode(data).decode('utf-8') - return f'data:{upload_file.mime_type};base64,{encoded_string}' + encoded_string = base64.b64encode(data).decode("utf-8") + return f"data:{upload_file.mime_type};base64,{encoded_string}" @classmethod def get_signed_temp_image_url(cls, upload_file_id) -> str: @@ -44,7 +44,7 @@ class UploadFileParser: :return: """ base_url = dify_config.FILES_URL - image_preview_url = f'{base_url}/files/{upload_file_id}/image-preview' + image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview" timestamp = str(int(time.time())) nonce = os.urandom(16).hex() diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 4662ebb47a..4a80a3ffe9 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -15,9 +15,11 @@ from core.helper.code_executor.template_transformer import TemplateTransformer logger = logging.getLogger(__name__) + class CodeExecutionException(Exception): pass + class CodeExecutionResponse(BaseModel): class Data(BaseModel): stdout: Optional[str] = None @@ -29,9 +31,9 @@ class CodeExecutionResponse(BaseModel): class CodeLanguage(str, Enum): - PYTHON3 = 'python3' - JINJA2 = 'jinja2' - JAVASCRIPT = 'javascript' + PYTHON3 = "python3" + JINJA2 = "jinja2" + JAVASCRIPT = "javascript" class CodeExecutor: @@ -45,63 +47,65 @@ class CodeExecutor: } code_language_to_running_language = { - CodeLanguage.JAVASCRIPT: 'nodejs', + CodeLanguage.JAVASCRIPT: "nodejs", CodeLanguage.JINJA2: CodeLanguage.PYTHON3, CodeLanguage.PYTHON3: CodeLanguage.PYTHON3, } - supported_dependencies_languages: set[CodeLanguage] = { - CodeLanguage.PYTHON3 - } + supported_dependencies_languages: set[CodeLanguage] = {CodeLanguage.PYTHON3} @classmethod - def execute_code(cls, - language: CodeLanguage, - preload: str, - code: str) -> str: + def execute_code(cls, language: CodeLanguage, preload: str, code: str) -> str: """ Execute code :param language: code language :param code: code :return: """ - url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / 'v1' / 'sandbox' / 'run' + url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / "v1" / "sandbox" / "run" - headers = { - 'X-Api-Key': dify_config.CODE_EXECUTION_API_KEY - } + headers = {"X-Api-Key": dify_config.CODE_EXECUTION_API_KEY} data = { - 'language': cls.code_language_to_running_language.get(language), - 'code': code, - 'preload': preload, - 'enable_network': True + "language": cls.code_language_to_running_language.get(language), + "code": code, + "preload": preload, + "enable_network": True, } try: - response = post(str(url), json=data, headers=headers, - timeout=Timeout( - connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT, - read=dify_config.CODE_EXECUTION_READ_TIMEOUT, - write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT, - pool=None)) + response = post( + str(url), + json=data, + headers=headers, + timeout=Timeout( + connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT, + read=dify_config.CODE_EXECUTION_READ_TIMEOUT, + write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT, + pool=None, + ), + ) if response.status_code == 503: - raise CodeExecutionException('Code execution service is unavailable') + raise CodeExecutionException("Code execution service is unavailable") elif response.status_code != 200: - raise Exception(f'Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running') + raise Exception( + f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running" + ) except CodeExecutionException as e: raise e except Exception as e: - raise CodeExecutionException('Failed to execute code, which is likely a network issue,' - ' please check if the sandbox service is running.' - f' ( Error: {str(e)} )') + raise CodeExecutionException( + "Failed to execute code, which is likely a network issue," + " please check if the sandbox service is running." + f" ( Error: {str(e)} )" + ) try: response = response.json() except: - raise CodeExecutionException('Failed to parse response') + raise CodeExecutionException("Failed to parse response") - if (code := response.get('code')) != 0: + if (code := response.get("code")) != 0: raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}") response = CodeExecutionResponse(**response) @@ -109,7 +113,7 @@ class CodeExecutor: if response.data.error: raise CodeExecutionException(response.data.error) - return response.data.stdout or '' + return response.data.stdout or "" @classmethod def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict) -> dict: @@ -122,7 +126,7 @@ class CodeExecutor: """ template_transformer = cls.code_template_transformers.get(language) if not template_transformer: - raise CodeExecutionException(f'Unsupported language {language}') + raise CodeExecutionException(f"Unsupported language {language}") runner, preload = template_transformer.transform_caller(code, inputs) diff --git a/api/core/helper/code_executor/code_node_provider.py b/api/core/helper/code_executor/code_node_provider.py index 3f099b7ac5..e233a596b9 100644 --- a/api/core/helper/code_executor/code_node_provider.py +++ b/api/core/helper/code_executor/code_node_provider.py @@ -26,23 +26,9 @@ class CodeNodeProvider(BaseModel): return { "type": "code", "config": { - "variables": [ - { - "variable": "arg1", - "value_selector": [] - }, - { - "variable": "arg2", - "value_selector": [] - } - ], + "variables": [{"variable": "arg1", "value_selector": []}, {"variable": "arg2", "value_selector": []}], "code_language": cls.get_language(), "code": cls.get_default_code(), - "outputs": { - "result": { - "type": "string", - "children": None - } - } - } + "outputs": {"result": {"type": "string", "children": None}}, + }, } diff --git a/api/core/helper/code_executor/javascript/javascript_code_provider.py b/api/core/helper/code_executor/javascript/javascript_code_provider.py index a157fcc6d1..ae324b83a9 100644 --- a/api/core/helper/code_executor/javascript/javascript_code_provider.py +++ b/api/core/helper/code_executor/javascript/javascript_code_provider.py @@ -18,4 +18,5 @@ class JavascriptCodeProvider(CodeNodeProvider): result: arg1 + arg2 } } - """) + """ + ) diff --git a/api/core/helper/code_executor/javascript/javascript_transformer.py b/api/core/helper/code_executor/javascript/javascript_transformer.py index a4d2551972..d67a0903aa 100644 --- a/api/core/helper/code_executor/javascript/javascript_transformer.py +++ b/api/core/helper/code_executor/javascript/javascript_transformer.py @@ -21,5 +21,6 @@ class NodeJsTemplateTransformer(TemplateTransformer): var output_json = JSON.stringify(output_obj) var result = `<>${{output_json}}<>` console.log(result) - """) + """ + ) return runner_script diff --git a/api/core/helper/code_executor/jinja2/jinja2_formatter.py b/api/core/helper/code_executor/jinja2/jinja2_formatter.py index f1e5da584c..db2eb5ebb6 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_formatter.py +++ b/api/core/helper/code_executor/jinja2/jinja2_formatter.py @@ -10,8 +10,6 @@ class Jinja2Formatter: :param inputs: inputs :return: """ - result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, code=template, inputs=inputs - ) + result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs) - return result['result'] + return result["result"] diff --git a/api/core/helper/code_executor/jinja2/jinja2_transformer.py b/api/core/helper/code_executor/jinja2/jinja2_transformer.py index b8cb29600e..63d58edbc7 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_transformer.py +++ b/api/core/helper/code_executor/jinja2/jinja2_transformer.py @@ -11,9 +11,7 @@ class Jinja2TemplateTransformer(TemplateTransformer): :param response: response :return: """ - return { - 'result': cls.extract_result_str_from_response(response) - } + return {"result": cls.extract_result_str_from_response(response)} @classmethod def get_runner_script(cls) -> str: diff --git a/api/core/helper/code_executor/python3/python3_code_provider.py b/api/core/helper/code_executor/python3/python3_code_provider.py index 923724b49d..9cca8af7c6 100644 --- a/api/core/helper/code_executor/python3/python3_code_provider.py +++ b/api/core/helper/code_executor/python3/python3_code_provider.py @@ -17,4 +17,5 @@ class Python3CodeProvider(CodeNodeProvider): return { "result": arg1 + arg2, } - """) + """ + ) diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index cf66558b65..6f016f27bc 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -5,9 +5,9 @@ from base64 import b64encode class TemplateTransformer(ABC): - _code_placeholder: str = '{{code}}' - _inputs_placeholder: str = '{{inputs}}' - _result_tag: str = '<>' + _code_placeholder: str = "{{code}}" + _inputs_placeholder: str = "{{inputs}}" + _result_tag: str = "<>" @classmethod def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]: @@ -24,9 +24,9 @@ class TemplateTransformer(ABC): @classmethod def extract_result_str_from_response(cls, response: str) -> str: - result = re.search(rf'{cls._result_tag}(.*){cls._result_tag}', response, re.DOTALL) + result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL) if not result: - raise ValueError('Failed to parse result') + raise ValueError("Failed to parse result") result = result.group(1) return result @@ -50,7 +50,7 @@ class TemplateTransformer(ABC): @classmethod def serialize_inputs(cls, inputs: dict) -> str: inputs_json_str = json.dumps(inputs, ensure_ascii=False).encode() - input_base64_encoded = b64encode(inputs_json_str).decode('utf-8') + input_base64_encoded = b64encode(inputs_json_str).decode("utf-8") return input_base64_encoded @classmethod @@ -67,4 +67,4 @@ class TemplateTransformer(ABC): """ Get preload script """ - return '' + return "" diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index 5e5deb86b4..96341a1b78 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -8,14 +8,15 @@ def obfuscated_token(token: str): if not token: return token if len(token) <= 8: - return '*' * 20 - return token[:6] + '*' * 12 + token[-2:] + return "*" * 20 + return token[:6] + "*" * 12 + token[-2:] def encrypt_token(tenant_id: str, token: str): from models.account import Tenant + if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()): - raise ValueError(f'Tenant with id {tenant_id} not found') + raise ValueError(f"Tenant with id {tenant_id} not found") encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) return base64.b64encode(encrypted_token).decode() diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py index 29cb4acc7d..5e274f8916 100644 --- a/api/core/helper/model_provider_cache.py +++ b/api/core/helper/model_provider_cache.py @@ -25,7 +25,7 @@ class ProviderCredentialsCache: cached_provider_credentials = redis_client.get(self.cache_key) if cached_provider_credentials: try: - cached_provider_credentials = cached_provider_credentials.decode('utf-8') + cached_provider_credentials = cached_provider_credentials.decode("utf-8") cached_provider_credentials = json.loads(cached_provider_credentials) except JSONDecodeError: return None diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index 20feae8554..b880590de2 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -12,19 +12,20 @@ logger = logging.getLogger(__name__) def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool: moderation_config = hosting_configuration.moderation_config - if (moderation_config and moderation_config.enabled is True - and 'openai' in hosting_configuration.provider_map - and hosting_configuration.provider_map['openai'].enabled is True + if ( + moderation_config + and moderation_config.enabled is True + and "openai" in hosting_configuration.provider_map + and hosting_configuration.provider_map["openai"].enabled is True ): using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type provider_name = model_config.provider - if using_provider_type == ProviderType.SYSTEM \ - and provider_name in moderation_config.providers: - hosting_openai_config = hosting_configuration.provider_map['openai'] + if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers: + hosting_openai_config = hosting_configuration.provider_map["openai"] # 2000 text per chunk length = 2000 - text_chunks = [text[i:i + length] for i in range(0, len(text), length)] + text_chunks = [text[i : i + length] for i in range(0, len(text), length)] if len(text_chunks) == 0: return True @@ -34,15 +35,13 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) try: model_type_instance = OpenAIModerationModel() moderation_result = model_type_instance.invoke( - model='text-moderation-stable', - credentials=hosting_openai_config.credentials, - text=text_chunk + model="text-moderation-stable", credentials=hosting_openai_config.credentials, text=text_chunk ) if moderation_result is True: return True except Exception as ex: logger.exception(ex) - raise InvokeBadRequestError('Rate limit exceeded, please try again later.') + raise InvokeBadRequestError("Rate limit exceeded, please try again later.") return False diff --git a/api/core/helper/module_import_helper.py b/api/core/helper/module_import_helper.py index 2000577a40..e6e1491548 100644 --- a/api/core/helper/module_import_helper.py +++ b/api/core/helper/module_import_helper.py @@ -37,8 +37,9 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type] """ Get all the subclasses of the parent type from the module """ - classes = [x for _, x in vars(mod).items() - if isinstance(x, type) and x != parent_type and issubclass(x, parent_type)] + classes = [ + x for _, x in vars(mod).items() if isinstance(x, type) and x != parent_type and issubclass(x, parent_type) + ] return classes @@ -56,6 +57,6 @@ def load_single_subclass_from_source( case 1: return subclasses[0] case 0: - raise Exception(f'Missing subclass of {parent_type.__name__} in {script_path}') + raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path}") case _: - raise Exception(f'Multiple subclasses of {parent_type.__name__} in {script_path}') \ No newline at end of file + raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path}") diff --git a/api/core/helper/position_helper.py b/api/core/helper/position_helper.py index 32e3806231..3efdc8aa47 100644 --- a/api/core/helper/position_helper.py +++ b/api/core/helper/position_helper.py @@ -73,10 +73,10 @@ def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) def is_filtered( - include_set: set[str], - exclude_set: set[str], - data: Any, - name_func: Callable[[Any], str], + include_set: set[str], + exclude_set: set[str], + data: Any, + name_func: Callable[[Any], str], ) -> bool: """ Check if the object should be filtered out. @@ -102,9 +102,9 @@ def is_filtered( def sort_by_position_map( - position_map: dict[str, int], - data: list[Any], - name_func: Callable[[Any], str], + position_map: dict[str, int], + data: list[Any], + name_func: Callable[[Any], str], ) -> list[Any]: """ Sort the objects by the position map. @@ -117,13 +117,13 @@ def sort_by_position_map( if not position_map or not data: return data - return sorted(data, key=lambda x: position_map.get(name_func(x), float('inf'))) + return sorted(data, key=lambda x: position_map.get(name_func(x), float("inf"))) def sort_to_dict_by_position_map( - position_map: dict[str, int], - data: list[Any], - name_func: Callable[[Any], str], + position_map: dict[str, int], + data: list[Any], + name_func: Callable[[Any], str], ) -> OrderedDict[str, Any]: """ Sort the objects into a ordered dict by the position map. diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 14ca8e943c..4e6d58904e 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -1,31 +1,34 @@ """ Proxy requests to avoid SSRF """ + import logging import os import time import httpx -SSRF_PROXY_ALL_URL = os.getenv('SSRF_PROXY_ALL_URL', '') -SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '') -SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '') -SSRF_DEFAULT_MAX_RETRIES = int(os.getenv('SSRF_DEFAULT_MAX_RETRIES', '3')) +SSRF_PROXY_ALL_URL = os.getenv("SSRF_PROXY_ALL_URL", "") +SSRF_PROXY_HTTP_URL = os.getenv("SSRF_PROXY_HTTP_URL", "") +SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "") +SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3")) -proxies = { - 'http://': SSRF_PROXY_HTTP_URL, - 'https://': SSRF_PROXY_HTTPS_URL -} if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None +proxies = ( + {"http://": SSRF_PROXY_HTTP_URL, "https://": SSRF_PROXY_HTTPS_URL} + if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL + else None +) BACKOFF_FACTOR = 0.5 STATUS_FORCELIST = [429, 500, 502, 503, 504] + def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): if "allow_redirects" in kwargs: allow_redirects = kwargs.pop("allow_redirects") if "follow_redirects" not in kwargs: kwargs["follow_redirects"] = allow_redirects - + retries = 0 while retries <= max_retries: try: @@ -52,24 +55,24 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('GET', url, max_retries=max_retries, **kwargs) + return make_request("GET", url, max_retries=max_retries, **kwargs) def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('POST', url, max_retries=max_retries, **kwargs) + return make_request("POST", url, max_retries=max_retries, **kwargs) def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('PUT', url, max_retries=max_retries, **kwargs) + return make_request("PUT", url, max_retries=max_retries, **kwargs) def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('PATCH', url, max_retries=max_retries, **kwargs) + return make_request("PATCH", url, max_retries=max_retries, **kwargs) def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('DELETE', url, max_retries=max_retries, **kwargs) + return make_request("DELETE", url, max_retries=max_retries, **kwargs) def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('HEAD', url, max_retries=max_retries, **kwargs) + return make_request("HEAD", url, max_retries=max_retries, **kwargs) diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py index a6f486e81d..4c3b736186 100644 --- a/api/core/helper/tool_parameter_cache.py +++ b/api/core/helper/tool_parameter_cache.py @@ -9,14 +9,11 @@ from extensions.ext_redis import redis_client class ToolParameterCacheType(Enum): PARAMETER = "tool_parameter" + class ToolParameterCache: - def __init__(self, - tenant_id: str, - provider: str, - tool_name: str, - cache_type: ToolParameterCacheType, - identity_id: str - ): + def __init__( + self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str + ): self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}:identity_id:{identity_id}" def get(self) -> Optional[dict]: @@ -28,7 +25,7 @@ class ToolParameterCache: cached_tool_parameter = redis_client.get(self.cache_key) if cached_tool_parameter: try: - cached_tool_parameter = cached_tool_parameter.decode('utf-8') + cached_tool_parameter = cached_tool_parameter.decode("utf-8") cached_tool_parameter = json.loads(cached_tool_parameter) except JSONDecodeError: return None @@ -52,4 +49,4 @@ class ToolParameterCache: :return: """ - redis_client.delete(self.cache_key) \ No newline at end of file + redis_client.delete(self.cache_key) diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py index 6c5d3b8fb6..94b02cf985 100644 --- a/api/core/helper/tool_provider_cache.py +++ b/api/core/helper/tool_provider_cache.py @@ -9,6 +9,7 @@ from extensions.ext_redis import redis_client class ToolProviderCredentialsCacheType(Enum): PROVIDER = "tool_provider" + class ToolProviderCredentialsCache: def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType): self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}" @@ -22,7 +23,7 @@ class ToolProviderCredentialsCache: cached_provider_credentials = redis_client.get(self.cache_key) if cached_provider_credentials: try: - cached_provider_credentials = cached_provider_credentials.decode('utf-8') + cached_provider_credentials = cached_provider_credentials.decode("utf-8") cached_provider_credentials = json.loads(cached_provider_credentials) except JSONDecodeError: return None @@ -46,4 +47,4 @@ class ToolProviderCredentialsCache: :return: """ - redis_client.delete(self.cache_key) \ No newline at end of file + redis_client.delete(self.cache_key) diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index ddcd751286..eeeccc2349 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -46,7 +46,7 @@ class HostingConfiguration: def init_app(self, app: Flask) -> None: config = app.config - if config.get('EDITION') != 'CLOUD': + if config.get("EDITION") != "CLOUD": return self.provider_map["azure_openai"] = self.init_azure_openai(config) @@ -65,7 +65,7 @@ class HostingConfiguration: credentials = { "openai_api_key": app_config.get("HOSTED_AZURE_OPENAI_API_KEY"), "openai_api_base": app_config.get("HOSTED_AZURE_OPENAI_API_BASE"), - "base_model_name": "gpt-35-turbo" + "base_model_name": "gpt-35-turbo", } quotas = [] @@ -77,26 +77,45 @@ class HostingConfiguration: RestrictModel(model="gpt-4o", base_model_name="gpt-4o", model_type=ModelType.LLM), RestrictModel(model="gpt-4o-mini", base_model_name="gpt-4o-mini", model_type=ModelType.LLM), RestrictModel(model="gpt-4-32k", base_model_name="gpt-4-32k", model_type=ModelType.LLM), - RestrictModel(model="gpt-4-1106-preview", base_model_name="gpt-4-1106-preview", model_type=ModelType.LLM), - RestrictModel(model="gpt-4-vision-preview", base_model_name="gpt-4-vision-preview", model_type=ModelType.LLM), + RestrictModel( + model="gpt-4-1106-preview", base_model_name="gpt-4-1106-preview", model_type=ModelType.LLM + ), + RestrictModel( + model="gpt-4-vision-preview", base_model_name="gpt-4-vision-preview", model_type=ModelType.LLM + ), RestrictModel(model="gpt-35-turbo", base_model_name="gpt-35-turbo", model_type=ModelType.LLM), - RestrictModel(model="gpt-35-turbo-1106", base_model_name="gpt-35-turbo-1106", model_type=ModelType.LLM), - RestrictModel(model="gpt-35-turbo-instruct", base_model_name="gpt-35-turbo-instruct", model_type=ModelType.LLM), - RestrictModel(model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM), - RestrictModel(model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM), - RestrictModel(model="text-embedding-ada-002", base_model_name="text-embedding-ada-002", model_type=ModelType.TEXT_EMBEDDING), - RestrictModel(model="text-embedding-3-small", base_model_name="text-embedding-3-small", model_type=ModelType.TEXT_EMBEDDING), - RestrictModel(model="text-embedding-3-large", base_model_name="text-embedding-3-large", model_type=ModelType.TEXT_EMBEDDING), - ] + RestrictModel( + model="gpt-35-turbo-1106", base_model_name="gpt-35-turbo-1106", model_type=ModelType.LLM + ), + RestrictModel( + model="gpt-35-turbo-instruct", base_model_name="gpt-35-turbo-instruct", model_type=ModelType.LLM + ), + RestrictModel( + model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM + ), + RestrictModel( + model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM + ), + RestrictModel( + model="text-embedding-ada-002", + base_model_name="text-embedding-ada-002", + model_type=ModelType.TEXT_EMBEDDING, + ), + RestrictModel( + model="text-embedding-3-small", + base_model_name="text-embedding-3-small", + model_type=ModelType.TEXT_EMBEDDING, + ), + RestrictModel( + model="text-embedding-3-large", + base_model_name="text-embedding-3-large", + model_type=ModelType.TEXT_EMBEDDING, + ), + ], ) quotas.append(trial_quota) - return HostingProvider( - enabled=True, - credentials=credentials, - quota_unit=quota_unit, - quotas=quotas - ) + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) return HostingProvider( enabled=False, @@ -110,17 +129,12 @@ class HostingConfiguration: if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"): hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200")) trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS") - trial_quota = TrialHostingQuota( - quota_limit=hosted_quota_limit, - restrict_models=trial_models - ) + trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models) quotas.append(trial_quota) if app_config.get("HOSTED_OPENAI_PAID_ENABLED"): paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS") - paid_quota = PaidHostingQuota( - restrict_models=paid_models - ) + paid_quota = PaidHostingQuota(restrict_models=paid_models) quotas.append(paid_quota) if len(quotas) > 0: @@ -134,12 +148,7 @@ class HostingConfiguration: if app_config.get("HOSTED_OPENAI_API_ORGANIZATION"): credentials["openai_organization"] = app_config.get("HOSTED_OPENAI_API_ORGANIZATION") - return HostingProvider( - enabled=True, - credentials=credentials, - quota_unit=quota_unit, - quotas=quotas - ) + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) return HostingProvider( enabled=False, @@ -153,9 +162,7 @@ class HostingConfiguration: if app_config.get("HOSTED_ANTHROPIC_TRIAL_ENABLED"): hosted_quota_limit = int(app_config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0")) - trial_quota = TrialHostingQuota( - quota_limit=hosted_quota_limit - ) + trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit) quotas.append(trial_quota) if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"): @@ -170,12 +177,7 @@ class HostingConfiguration: if app_config.get("HOSTED_ANTHROPIC_API_BASE"): credentials["anthropic_api_url"] = app_config.get("HOSTED_ANTHROPIC_API_BASE") - return HostingProvider( - enabled=True, - credentials=credentials, - quota_unit=quota_unit, - quotas=quotas - ) + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) return HostingProvider( enabled=False, @@ -192,7 +194,7 @@ class HostingConfiguration: enabled=True, credentials=None, # use credentials from the provider quota_unit=quota_unit, - quotas=quotas + quotas=quotas, ) return HostingProvider( @@ -210,7 +212,7 @@ class HostingConfiguration: enabled=True, credentials=None, # use credentials from the provider quota_unit=quota_unit, - quotas=quotas + quotas=quotas, ) return HostingProvider( @@ -228,7 +230,7 @@ class HostingConfiguration: enabled=True, credentials=None, # use credentials from the provider quota_unit=quota_unit, - quotas=quotas + quotas=quotas, ) return HostingProvider( @@ -238,21 +240,19 @@ class HostingConfiguration: @staticmethod def init_moderation_config(app_config: Config) -> HostedModerationConfig: - if app_config.get("HOSTED_MODERATION_ENABLED") \ - and app_config.get("HOSTED_MODERATION_PROVIDERS"): + if app_config.get("HOSTED_MODERATION_ENABLED") and app_config.get("HOSTED_MODERATION_PROVIDERS"): return HostedModerationConfig( - enabled=True, - providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(',') + enabled=True, providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(",") ) - return HostedModerationConfig( - enabled=False - ) + return HostedModerationConfig(enabled=False) @staticmethod def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]: models_str = app_config.get(env_var) models_list = models_str.split(",") if models_str else [] - return [RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) for model_name in models_list if - model_name.strip()] - + return [ + RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) + for model_name in models_list + if model_name.strip() + ] diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index df563f609b..b6968e46cd 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -39,7 +39,6 @@ from services.feature_service import FeatureService class IndexingRunner: - def __init__(self): self.storage = storage self.model_manager = ModelManager() @@ -49,25 +48,26 @@ class IndexingRunner: for dataset_document in dataset_documents: try: # get dataset - dataset = Dataset.query.filter_by( - id=dataset_document.dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get the process rule - processing_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ - first() + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .first() + ) index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() # extract text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) # transform - documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language, - processing_rule.to_dict()) + documents = self._transform( + index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict() + ) # save segment self._load_segments(dataset, dataset_document, documents) @@ -76,20 +76,20 @@ class IndexingRunner: index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, - documents=documents + documents=documents, ) except DocumentIsPausedException: - raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) + raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id)) except ProviderTokenNotInitError as e: - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e.description) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() except ObjectDeletedError: - logging.warning('Document deleted, document id: {}'.format(dataset_document.id)) + logging.warning("Document deleted, document id: {}".format(dataset_document.id)) except Exception as e: logging.exception("consume document failed") - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() @@ -98,26 +98,25 @@ class IndexingRunner: """Run the indexing process when the index_status is splitting.""" try: # get dataset - dataset = Dataset.query.filter_by( - id=dataset_document.dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get exist document_segment list and delete document_segments = DocumentSegment.query.filter_by( - dataset_id=dataset.id, - document_id=dataset_document.id + dataset_id=dataset.id, document_id=dataset_document.id ).all() for document_segment in document_segments: db.session.delete(document_segment) db.session.commit() # get the process rule - processing_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ - first() + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .first() + ) index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -125,28 +124,26 @@ class IndexingRunner: text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) # transform - documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language, - processing_rule.to_dict()) + documents = self._transform( + index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict() + ) # save segment self._load_segments(dataset, dataset_document, documents) # load self._load( - index_processor=index_processor, - dataset=dataset, - dataset_document=dataset_document, - documents=documents + index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents ) except DocumentIsPausedException: - raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) + raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id)) except ProviderTokenNotInitError as e: - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e.description) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() except Exception as e: logging.exception("consume document failed") - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() @@ -155,17 +152,14 @@ class IndexingRunner: """Run the indexing process when the index_status is indexing.""" try: # get dataset - dataset = Dataset.query.filter_by( - id=dataset_document.dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get exist document_segment list and delete document_segments = DocumentSegment.query.filter_by( - dataset_id=dataset.id, - document_id=dataset_document.id + dataset_id=dataset.id, document_id=dataset_document.id ).all() documents = [] @@ -180,42 +174,48 @@ class IndexingRunner: "doc_hash": document_segment.index_node_hash, "document_id": document_segment.document_id, "dataset_id": document_segment.dataset_id, - } + }, ) documents.append(document) # build index # get the process rule - processing_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ - first() + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .first() + ) index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() self._load( - index_processor=index_processor, - dataset=dataset, - dataset_document=dataset_document, - documents=documents + index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents ) except DocumentIsPausedException: - raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) + raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id)) except ProviderTokenNotInitError as e: - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e.description) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() except Exception as e: logging.exception("consume document failed") - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() - def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict, - doc_form: str = None, doc_language: str = 'English', dataset_id: str = None, - indexing_technique: str = 'economy') -> dict: + def indexing_estimate( + self, + tenant_id: str, + extract_settings: list[ExtractSetting], + tmp_processing_rule: dict, + doc_form: str = None, + doc_language: str = "English", + dataset_id: str = None, + indexing_technique: str = "economy", + ) -> dict: """ Estimate the indexing for the document. """ @@ -229,18 +229,16 @@ class IndexingRunner: embedding_model_instance = None if dataset_id: - dataset = Dataset.query.filter_by( - id=dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_id).first() if not dataset: - raise ValueError('Dataset not found.') - if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality': + raise ValueError("Dataset not found.") + if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality": if dataset.embedding_model_provider: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) else: embedding_model_instance = self.model_manager.get_default_model_instance( @@ -248,7 +246,7 @@ class IndexingRunner: model_type=ModelType.TEXT_EMBEDDING, ) else: - if indexing_technique == 'high_quality': + if indexing_technique == "high_quality": embedding_model_instance = self.model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, @@ -263,8 +261,7 @@ class IndexingRunner: text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) all_text_docs.extend(text_docs) processing_rule = DatasetProcessRule( - mode=tmp_processing_rule["mode"], - rules=json.dumps(tmp_processing_rule["rules"]) + mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) ) # get splitter @@ -272,9 +269,7 @@ class IndexingRunner: # split to documents documents = self._split_to_documents_for_estimate( - text_docs=text_docs, - splitter=splitter, - processing_rule=processing_rule + text_docs=text_docs, splitter=splitter, processing_rule=processing_rule ) total_segments += len(documents) @@ -282,110 +277,110 @@ class IndexingRunner: if len(preview_texts) < 5: preview_texts.append(document.page_content) - if doc_form and doc_form == 'qa_model': - + if doc_form and doc_form == "qa_model": if len(preview_texts) > 0: # qa model document - response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], - doc_language) + response = LLMGenerator.generate_qa_document( + current_user.current_tenant_id, preview_texts[0], doc_language + ) document_qa_list = self.format_split_text(response) - return { - "total_segments": total_segments * 20, - "qa_preview": document_qa_list, - "preview": preview_texts - } - return { - "total_segments": total_segments, - "preview": preview_texts - } + return {"total_segments": total_segments * 20, "qa_preview": document_qa_list, "preview": preview_texts} + return {"total_segments": total_segments, "preview": preview_texts} - def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \ - -> list[Document]: + def _extract( + self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict + ) -> list[Document]: # load file if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]: return [] data_source_info = dataset_document.data_source_info_dict text_docs = [] - if dataset_document.data_source_type == 'upload_file': - if not data_source_info or 'upload_file_id' not in data_source_info: + if dataset_document.data_source_type == "upload_file": + if not data_source_info or "upload_file_id" not in data_source_info: raise ValueError("no upload file found") - file_detail = db.session.query(UploadFile). \ - filter(UploadFile.id == data_source_info['upload_file_id']). \ - one_or_none() + file_detail = ( + db.session.query(UploadFile).filter(UploadFile.id == data_source_info["upload_file_id"]).one_or_none() + ) if file_detail: extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=file_detail, - document_model=dataset_document.doc_form + datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) - elif dataset_document.data_source_type == 'notion_import': - if (not data_source_info or 'notion_workspace_id' not in data_source_info - or 'notion_page_id' not in data_source_info): + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + elif dataset_document.data_source_type == "notion_import": + if ( + not data_source_info + or "notion_workspace_id" not in data_source_info + or "notion_page_id" not in data_source_info + ): raise ValueError("no notion import info found") extract_setting = ExtractSetting( datasource_type="notion_import", notion_info={ - "notion_workspace_id": data_source_info['notion_workspace_id'], - "notion_obj_id": data_source_info['notion_page_id'], - "notion_page_type": data_source_info['type'], + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], "document": dataset_document, - "tenant_id": dataset_document.tenant_id + "tenant_id": dataset_document.tenant_id, }, - document_model=dataset_document.doc_form + document_model=dataset_document.doc_form, ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) - elif dataset_document.data_source_type == 'website_crawl': - if (not data_source_info or 'provider' not in data_source_info - or 'url' not in data_source_info or 'job_id' not in data_source_info): + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + elif dataset_document.data_source_type == "website_crawl": + if ( + not data_source_info + or "provider" not in data_source_info + or "url" not in data_source_info + or "job_id" not in data_source_info + ): raise ValueError("no website import info found") extract_setting = ExtractSetting( datasource_type="website_crawl", website_info={ - "provider": data_source_info['provider'], - "job_id": data_source_info['job_id'], + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], "tenant_id": dataset_document.tenant_id, - "url": data_source_info['url'], - "mode": data_source_info['mode'], - "only_main_content": data_source_info['only_main_content'] + "url": data_source_info["url"], + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], }, - document_model=dataset_document.doc_form + document_model=dataset_document.doc_form, ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) # update document status to splitting self._update_document_index_status( document_id=dataset_document.id, after_indexing_status="splitting", extra_update_params={ DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs), - DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - } + DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + }, ) # replace doc id to document model id text_docs = cast(list[Document], text_docs) for text_doc in text_docs: - text_doc.metadata['document_id'] = dataset_document.id - text_doc.metadata['dataset_id'] = dataset_document.dataset_id + text_doc.metadata["document_id"] = dataset_document.id + text_doc.metadata["dataset_id"] = dataset_document.dataset_id return text_docs @staticmethod def filter_string(text): - text = re.sub(r'<\|', '<', text) - text = re.sub(r'\|>', '>', text) - text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text) + text = re.sub(r"<\|", "<", text) + text = re.sub(r"\|>", ">", text) + text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]", "", text) # Unicode U+FFFE - text = re.sub('\uFFFE', '', text) + text = re.sub("\ufffe", "", text) return text @staticmethod - def _get_splitter(processing_rule: DatasetProcessRule, - embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: + def _get_splitter( + processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance] + ) -> TextSplitter: """ Get the NodeParser object according to the processing rule. """ @@ -399,10 +394,10 @@ class IndexingRunner: separator = segmentation["separator"] if separator: - separator = separator.replace('\\n', '\n') + separator = separator.replace("\\n", "\n") - if segmentation.get('chunk_overlap'): - chunk_overlap = segmentation['chunk_overlap'] + if segmentation.get("chunk_overlap"): + chunk_overlap = segmentation["chunk_overlap"] else: chunk_overlap = 0 @@ -411,22 +406,27 @@ class IndexingRunner: chunk_overlap=chunk_overlap, fixed_separator=separator, separators=["\n\n", "。", ". ", " ", ""], - embedding_model_instance=embedding_model_instance + embedding_model_instance=embedding_model_instance, ) else: # Automatic segmentation character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( - chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], - chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'], + chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"], + chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"], separators=["\n\n", "。", ". ", " ", ""], - embedding_model_instance=embedding_model_instance + embedding_model_instance=embedding_model_instance, ) return character_splitter - def _step_split(self, text_docs: list[Document], splitter: TextSplitter, - dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \ - -> list[Document]: + def _step_split( + self, + text_docs: list[Document], + splitter: TextSplitter, + dataset: Dataset, + dataset_document: DatasetDocument, + processing_rule: DatasetProcessRule, + ) -> list[Document]: """ Split the text documents into documents and save them to the document segment. """ @@ -436,14 +436,12 @@ class IndexingRunner: processing_rule=processing_rule, tenant_id=dataset.tenant_id, document_form=dataset_document.doc_form, - document_language=dataset_document.doc_language + document_language=dataset_document.doc_language, ) # save node to document segment doc_store = DatasetDocumentStore( - dataset=dataset, - user_id=dataset_document.created_by, - document_id=dataset_document.id + dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id ) # add document segments @@ -457,7 +455,7 @@ class IndexingRunner: extra_update_params={ DatasetDocument.cleaning_completed_at: cur_time, DatasetDocument.splitting_completed_at: cur_time, - } + }, ) # update segment status to indexing @@ -465,15 +463,21 @@ class IndexingRunner: dataset_document_id=dataset_document.id, update_params={ DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - } + DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + }, ) return documents - def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter, - processing_rule: DatasetProcessRule, tenant_id: str, - document_form: str, document_language: str) -> list[Document]: + def _split_to_documents( + self, + text_docs: list[Document], + splitter: TextSplitter, + processing_rule: DatasetProcessRule, + tenant_id: str, + document_form: str, + document_language: str, + ) -> list[Document]: """ Split the text documents into nodes. """ @@ -488,12 +492,11 @@ class IndexingRunner: documents = splitter.split_documents([text_doc]) split_documents = [] for document_node in documents: - if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata['doc_id'] = doc_id - document_node.metadata['doc_hash'] = hash + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = document_node.page_content if page_content.startswith(".") or page_content.startswith("。"): @@ -506,15 +509,21 @@ class IndexingRunner: split_documents.append(document_node) all_documents.extend(split_documents) # processing qa document - if document_form == 'qa_model': + if document_form == "qa_model": for i in range(0, len(all_documents), 10): threads = [] - sub_documents = all_documents[i:i + 10] + sub_documents = all_documents[i : i + 10] for doc in sub_documents: - document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={ - 'flask_app': current_app._get_current_object(), - 'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents, - 'document_language': document_language}) + document_format_thread = threading.Thread( + target=self.format_qa_document, + kwargs={ + "flask_app": current_app._get_current_object(), + "tenant_id": tenant_id, + "document_node": doc, + "all_qa_documents": all_qa_documents, + "document_language": document_language, + }, + ) threads.append(document_format_thread) document_format_thread.start() for thread in threads: @@ -533,12 +542,14 @@ class IndexingRunner: document_qa_list = self.format_split_text(response) qa_documents = [] for result in document_qa_list: - qa_document = Document(page_content=result['question'], metadata=document_node.metadata.model_copy()) + qa_document = Document( + page_content=result["question"], metadata=document_node.metadata.model_copy() + ) doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(result['question']) - qa_document.metadata['answer'] = result['answer'] - qa_document.metadata['doc_id'] = doc_id - qa_document.metadata['doc_hash'] = hash + hash = helper.generate_text_hash(result["question"]) + qa_document.metadata["answer"] = result["answer"] + qa_document.metadata["doc_id"] = doc_id + qa_document.metadata["doc_hash"] = hash qa_documents.append(qa_document) format_documents.extend(qa_documents) except Exception as e: @@ -546,8 +557,9 @@ class IndexingRunner: all_qa_documents.extend(format_documents) - def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: TextSplitter, - processing_rule: DatasetProcessRule) -> list[Document]: + def _split_to_documents_for_estimate( + self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule + ) -> list[Document]: """ Split the text documents into nodes. """ @@ -567,8 +579,8 @@ class IndexingRunner: doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document.page_content) - document.metadata['doc_id'] = doc_id - document.metadata['doc_hash'] = hash + document.metadata["doc_id"] = doc_id + document.metadata["doc_hash"] = hash split_documents.append(document) @@ -586,23 +598,23 @@ class IndexingRunner: else: rules = json.loads(processing_rule.rules) if processing_rule.rules else {} - if 'pre_processing_rules' in rules: + if "pre_processing_rules" in rules: pre_processing_rules = rules["pre_processing_rules"] for pre_processing_rule in pre_processing_rules: if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True: # Remove extra spaces - pattern = r'\n{3,}' - text = re.sub(pattern, '\n\n', text) - pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}' - text = re.sub(pattern, ' ', text) + pattern = r"\n{3,}" + text = re.sub(pattern, "\n\n", text) + pattern = r"[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}" + text = re.sub(pattern, " ", text) elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True: # Remove email - pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)' - text = re.sub(pattern, '', text) + pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)" + text = re.sub(pattern, "", text) # Remove URL - pattern = r'https?://[^\s]+' - text = re.sub(pattern, '', text) + pattern = r"https?://[^\s]+" + text = re.sub(pattern, "", text) return text @@ -611,27 +623,26 @@ class IndexingRunner: regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" matches = re.findall(regex, text, re.UNICODE) - return [ - { - "question": q, - "answer": re.sub(r"\n\s*", "\n", a.strip()) - } - for q, a in matches if q and a - ] + return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a] - def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset, - dataset_document: DatasetDocument, documents: list[Document]) -> None: + def _load( + self, + index_processor: BaseIndexProcessor, + dataset: Dataset, + dataset_document: DatasetDocument, + documents: list[Document], + ) -> None: """ insert index and update document/segment status to completed """ embedding_model_instance = None - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": embedding_model_instance = self.model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) # chunk nodes by chunk size @@ -640,18 +651,27 @@ class IndexingRunner: chunk_size = 10 # create keyword index - create_keyword_thread = threading.Thread(target=self._process_keyword_index, - args=(current_app._get_current_object(), - dataset.id, dataset_document.id, documents)) + create_keyword_thread = threading.Thread( + target=self._process_keyword_index, + args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), + ) create_keyword_thread.start() - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: futures = [] for i in range(0, len(documents), chunk_size): - chunk_documents = documents[i:i + chunk_size] - futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor, - chunk_documents, dataset, - dataset_document, embedding_model_instance)) + chunk_documents = documents[i : i + chunk_size] + futures.append( + executor.submit( + self._process_chunk, + current_app._get_current_object(), + index_processor, + chunk_documents, + dataset, + dataset_document, + embedding_model_instance, + ) + ) for future in futures: tokens += future.result() @@ -668,7 +688,7 @@ class IndexingRunner: DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at, DatasetDocument.error: None, - } + }, ) @staticmethod @@ -679,23 +699,26 @@ class IndexingRunner: raise ValueError("no dataset found") keyword = Keyword(dataset) keyword.create(documents) - if dataset.indexing_technique != 'high_quality': - document_ids = [document.metadata['doc_id'] for document in documents] + if dataset.indexing_technique != "high_quality": + document_ids = [document.metadata["doc_id"] for document in documents] db.session.query(DocumentSegment).filter( DocumentSegment.document_id == document_id, DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id.in_(document_ids), - DocumentSegment.status == "indexing" - ).update({ - DocumentSegment.status: "completed", - DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - }) + DocumentSegment.status == "indexing", + ).update( + { + DocumentSegment.status: "completed", + DocumentSegment.enabled: True, + DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + } + ) db.session.commit() - def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document, - embedding_model_instance): + def _process_chunk( + self, flask_app, index_processor, chunk_documents, dataset, dataset_document, embedding_model_instance + ): with flask_app.app_context(): # check document is paused self._check_document_paused_status(dataset_document.id) @@ -703,26 +726,26 @@ class IndexingRunner: tokens = 0 if embedding_model_instance: tokens += sum( - embedding_model_instance.get_text_embedding_num_tokens( - [document.page_content] - ) + embedding_model_instance.get_text_embedding_num_tokens([document.page_content]) for document in chunk_documents ) # load index index_processor.load(dataset, chunk_documents, with_keywords=False) - document_ids = [document.metadata['doc_id'] for document in chunk_documents] + document_ids = [document.metadata["doc_id"] for document in chunk_documents] db.session.query(DocumentSegment).filter( DocumentSegment.document_id == dataset_document.id, DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(document_ids), - DocumentSegment.status == "indexing" - ).update({ - DocumentSegment.status: "completed", - DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - }) + DocumentSegment.status == "indexing", + ).update( + { + DocumentSegment.status: "completed", + DocumentSegment.enabled: True, + DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + } + ) db.session.commit() @@ -730,14 +753,15 @@ class IndexingRunner: @staticmethod def _check_document_paused_status(document_id: str): - indexing_cache_key = 'document_{}_is_paused'.format(document_id) + indexing_cache_key = "document_{}_is_paused".format(document_id) result = redis_client.get(indexing_cache_key) if result: raise DocumentIsPausedException() @staticmethod - def _update_document_index_status(document_id: str, after_indexing_status: str, - extra_update_params: Optional[dict] = None) -> None: + def _update_document_index_status( + document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None + ) -> None: """ Update the document indexing status. """ @@ -748,9 +772,7 @@ class IndexingRunner: if not document: raise DocumentIsDeletedPausedException() - update_params = { - DatasetDocument.indexing_status: after_indexing_status - } + update_params = {DatasetDocument.indexing_status: after_indexing_status} if extra_update_params: update_params.update(extra_update_params) @@ -780,7 +802,7 @@ class IndexingRunner: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) # save vector index @@ -788,17 +810,23 @@ class IndexingRunner: index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor.load(dataset, documents) - def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset, - text_docs: list[Document], doc_language: str, process_rule: dict) -> list[Document]: + def _transform( + self, + index_processor: BaseIndexProcessor, + dataset: Dataset, + text_docs: list[Document], + doc_language: str, + process_rule: dict, + ) -> list[Document]: # get embedding model instance embedding_model_instance = None - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": if dataset.embedding_model_provider: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) else: embedding_model_instance = self.model_manager.get_default_model_instance( @@ -806,18 +834,20 @@ class IndexingRunner: model_type=ModelType.TEXT_EMBEDDING, ) - documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance, - process_rule=process_rule, tenant_id=dataset.tenant_id, - doc_language=doc_language) + documents = index_processor.transform( + text_docs, + embedding_model_instance=embedding_model_instance, + process_rule=process_rule, + tenant_id=dataset.tenant_id, + doc_language=doc_language, + ) return documents def _load_segments(self, dataset, dataset_document, documents): # save node to document segment doc_store = DatasetDocumentStore( - dataset=dataset, - user_id=dataset_document.created_by, - document_id=dataset_document.id + dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id ) # add document segments @@ -831,7 +861,7 @@ class IndexingRunner: extra_update_params={ DatasetDocument.cleaning_completed_at: cur_time, DatasetDocument.splitting_completed_at: cur_time, - } + }, ) # update segment status to indexing @@ -839,8 +869,8 @@ class IndexingRunner: dataset_document_id=dataset_document.id, update_params={ DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - } + DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + }, ) pass diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 8c13b4a45c..78a6d6e683 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -43,21 +43,16 @@ class LLMGenerator: with measure_time() as timer: response = model_instance.invoke_llm( - prompt_messages=prompts, - model_parameters={ - "max_tokens": 100, - "temperature": 1 - }, - stream=False + prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False ) answer = response.message.content - cleaned_answer = re.sub(r'^.*(\{.*\}).*$', r'\1', answer, flags=re.DOTALL) + cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL) result_dict = json.loads(cleaned_answer) - answer = result_dict['Your Output'] + answer = result_dict["Your Output"] name = answer.strip() if len(name) > 75: - name = name[:75] + '...' + name = name[:75] + "..." # get tracing instance trace_manager = TraceQueueManager(app_id=app_id) @@ -79,14 +74,9 @@ class LLMGenerator: output_parser = SuggestedQuestionsAfterAnswerOutputParser() format_instructions = output_parser.get_format_instructions() - prompt_template = PromptTemplateParser( - template="{{histories}}\n{{format_instructions}}\nquestions:\n" - ) + prompt_template = PromptTemplateParser(template="{{histories}}\n{{format_instructions}}\nquestions:\n") - prompt = prompt_template.format({ - "histories": histories, - "format_instructions": format_instructions - }) + prompt = prompt_template.format({"histories": histories, "format_instructions": format_instructions}) try: model_manager = ModelManager() @@ -101,12 +91,7 @@ class LLMGenerator: try: response = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters={ - "max_tokens": 256, - "temperature": 0 - }, - stream=False + prompt_messages=prompt_messages, model_parameters={"max_tokens": 256, "temperature": 0}, stream=False ) questions = output_parser.parse(response.message.content) @@ -119,32 +104,24 @@ class LLMGenerator: return questions @classmethod - def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool, rule_config_max_tokens: int = 512) -> dict: + def generate_rule_config( + cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool, rule_config_max_tokens: int = 512 + ) -> dict: output_parser = RuleConfigGeneratorOutputParser() error = "" error_step = "" - rule_config = { - "prompt": "", - "variables": [], - "opening_statement": "", - "error": "" - } - model_parameters = { - "max_tokens": rule_config_max_tokens, - "temperature": 0.01 - } + rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""} + model_parameters = {"max_tokens": rule_config_max_tokens, "temperature": 0.01} if no_variable: - prompt_template = PromptTemplateParser( - WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE - ) + prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE) prompt_generate = prompt_template.format( inputs={ "TASK_DESCRIPTION": instruction, }, - remove_template_variables=False + remove_template_variables=False, ) prompt_messages = [UserPromptMessage(content=prompt_generate)] @@ -158,13 +135,11 @@ class LLMGenerator: try: response = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=model_parameters, - stream=False + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False ) rule_config["prompt"] = response.message.content - + except InvokeError as e: error = str(e) error_step = "generate rule config" @@ -179,24 +154,18 @@ class LLMGenerator: # get rule config prompt, parameter and statement prompt_generate, parameter_generate, statement_generate = output_parser.get_format_instructions() - prompt_template = PromptTemplateParser( - prompt_generate - ) + prompt_template = PromptTemplateParser(prompt_generate) - parameter_template = PromptTemplateParser( - parameter_generate - ) + parameter_template = PromptTemplateParser(parameter_generate) - statement_template = PromptTemplateParser( - statement_generate - ) + statement_template = PromptTemplateParser(statement_generate) # format the prompt_generate_prompt prompt_generate_prompt = prompt_template.format( inputs={ "TASK_DESCRIPTION": instruction, }, - remove_template_variables=False + remove_template_variables=False, ) prompt_messages = [UserPromptMessage(content=prompt_generate_prompt)] @@ -213,9 +182,7 @@ class LLMGenerator: try: # the first step to generate the task prompt prompt_content = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=model_parameters, - stream=False + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False ) except InvokeError as e: error = str(e) @@ -230,7 +197,7 @@ class LLMGenerator: inputs={ "INPUT_TEXT": prompt_content.message.content, }, - remove_template_variables=False + remove_template_variables=False, ) parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)] @@ -240,15 +207,13 @@ class LLMGenerator: "TASK_DESCRIPTION": instruction, "INPUT_TEXT": prompt_content.message.content, }, - remove_template_variables=False + remove_template_variables=False, ) statement_messages = [UserPromptMessage(content=statement_generate_prompt)] try: parameter_content = model_instance.invoke_llm( - prompt_messages=parameter_messages, - model_parameters=model_parameters, - stream=False + prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False ) rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.content) except InvokeError as e: @@ -257,9 +222,7 @@ class LLMGenerator: try: statement_content = model_instance.invoke_llm( - prompt_messages=statement_messages, - model_parameters=model_parameters, - stream=False + prompt_messages=statement_messages, model_parameters=model_parameters, stream=False ) rule_config["opening_statement"] = statement_content.message.content except InvokeError as e: @@ -284,18 +247,10 @@ class LLMGenerator: model_type=ModelType.LLM, ) - prompt_messages = [ - SystemPromptMessage(content=prompt), - UserPromptMessage(content=query) - ] + prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] response = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters={ - 'temperature': 0.01, - "max_tokens": 2000 - }, - stream=False + prompt_messages=prompt_messages, model_parameters={"temperature": 0.01, "max_tokens": 2000}, stream=False ) answer = response.message.content diff --git a/api/core/llm_generator/output_parser/rule_config_generator.py b/api/core/llm_generator/output_parser/rule_config_generator.py index 8856f0c685..b6932698cb 100644 --- a/api/core/llm_generator/output_parser/rule_config_generator.py +++ b/api/core/llm_generator/output_parser/rule_config_generator.py @@ -10,9 +10,12 @@ from libs.json_in_md_parser import parse_and_check_json_markdown class RuleConfigGeneratorOutputParser: - def get_format_instructions(self) -> tuple[str, str, str]: - return RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE + return ( + RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, + RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, + RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE, + ) def parse(self, text: str) -> Any: try: @@ -21,16 +24,9 @@ class RuleConfigGeneratorOutputParser: if not isinstance(parsed["prompt"], str): raise ValueError("Expected 'prompt' to be a string.") if not isinstance(parsed["variables"], list): - raise ValueError( - "Expected 'variables' to be a list." - ) + raise ValueError("Expected 'variables' to be a list.") if not isinstance(parsed["opening_statement"], str): - raise ValueError( - "Expected 'opening_statement' to be a str." - ) + raise ValueError("Expected 'opening_statement' to be a str.") return parsed except Exception as e: - raise OutputParserException( - f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}" - ) - + raise OutputParserException(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}") diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py index 3f046c68fc..182aeed98f 100644 --- a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -6,7 +6,6 @@ from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCT class SuggestedQuestionsAfterAnswerOutputParser: - def get_format_instructions(self) -> str: return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT @@ -15,7 +14,7 @@ class SuggestedQuestionsAfterAnswerOutputParser: if action_match is not None: json_obj = json.loads(action_match.group(0).strip()) else: - json_obj= [] + json_obj = [] print(f"Could not parse LLM output: {text}") return json_obj diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index dbd6e26c7c..7ab257872f 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -66,19 +66,19 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( "and keeping each question under 20 characters.\n" "MAKE SURE your output is the SAME language as the Assistant's latest response(if the main response is written in Chinese, then the language of your output must be using Chinese.)!\n" "The output must be an array in JSON format following the specified schema:\n" - "[\"question1\",\"question2\",\"question3\"]\n" + '["question1","question2","question3"]\n' ) GENERATOR_QA_PROMPT = ( - ' The user will send a long text. Generate a Question and Answer pairs only using the knowledge in the long text. Please think step by step.' - 'Step 1: Understand and summarize the main content of this text.\n' - 'Step 2: What key information or concepts are mentioned in this text?\n' - 'Step 3: Decompose or combine multiple pieces of information and concepts.\n' - 'Step 4: Generate questions and answers based on these key information and concepts.\n' - ' The questions should be clear and detailed, and the answers should be detailed and complete. ' - 'You must answer in {language}, in a style that is clear and detailed in {language}. No language other than {language} should be used. \n' - ' Use the following format: Q1:\nA1:\nQ2:\nA2:...\n' - '' + " The user will send a long text. Generate a Question and Answer pairs only using the knowledge in the long text. Please think step by step." + "Step 1: Understand and summarize the main content of this text.\n" + "Step 2: What key information or concepts are mentioned in this text?\n" + "Step 3: Decompose or combine multiple pieces of information and concepts.\n" + "Step 4: Generate questions and answers based on these key information and concepts.\n" + " The questions should be clear and detailed, and the answers should be detailed and complete. " + "You must answer in {language}, in a style that is clear and detailed in {language}. No language other than {language} should be used. \n" + " Use the following format: Q1:\nA1:\nQ2:\nA2:...\n" + "" ) WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE = """ diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index b33d4dd7cb..54b1d8212b 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -21,8 +21,9 @@ class TokenBufferMemory: self.conversation = conversation self.model_instance = model_instance - def get_history_prompt_messages(self, max_token_limit: int = 2000, - message_limit: Optional[int] = None) -> list[PromptMessage]: + def get_history_prompt_messages( + self, max_token_limit: int = 2000, message_limit: Optional[int] = None + ) -> list[PromptMessage]: """ Get history prompt messages. :param max_token_limit: max token limit @@ -31,16 +32,11 @@ class TokenBufferMemory: app_record = self.conversation.app # fetch limited messages, and return reversed - query = db.session.query( - Message.id, - Message.query, - Message.answer, - Message.created_at, - Message.workflow_run_id - ).filter( - Message.conversation_id == self.conversation.id, - Message.answer != '' - ).order_by(Message.created_at.desc()) + query = ( + db.session.query(Message.id, Message.query, Message.answer, Message.created_at, Message.workflow_run_id) + .filter(Message.conversation_id == self.conversation.id, Message.answer != "") + .order_by(Message.created_at.desc()) + ) if message_limit and message_limit > 0: message_limit = message_limit if message_limit <= 500 else 500 @@ -50,10 +46,7 @@ class TokenBufferMemory: messages = query.limit(message_limit).all() messages = list(reversed(messages)) - message_file_parser = MessageFileParser( - tenant_id=app_record.tenant_id, - app_id=app_record.id - ) + message_file_parser = MessageFileParser(tenant_id=app_record.tenant_id, app_id=app_record.id) prompt_messages = [] for message in messages: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() @@ -63,20 +56,17 @@ class TokenBufferMemory: file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) else: if message.workflow_run_id: - workflow_run = (db.session.query(WorkflowRun) - .filter(WorkflowRun.id == message.workflow_run_id).first()) + workflow_run = ( + db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first() + ) if workflow_run: file_extra_config = FileUploadConfigManager.convert( - workflow_run.workflow.features_dict, - is_vision=False + workflow_run.workflow.features_dict, is_vision=False ) if file_extra_config: - file_objs = message_file_parser.transform_message_files( - files, - file_extra_config - ) + file_objs = message_file_parser.transform_message_files(files, file_extra_config) else: file_objs = [] @@ -97,24 +87,23 @@ class TokenBufferMemory: return [] # prune the chat message if it exceeds the max token limit - curr_message_tokens = self.model_instance.get_llm_num_tokens( - prompt_messages - ) + curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) if curr_message_tokens > max_token_limit: pruned_memory = [] - while curr_message_tokens > max_token_limit and len(prompt_messages)>1: + while curr_message_tokens > max_token_limit and len(prompt_messages) > 1: pruned_memory.append(prompt_messages.pop(0)) - curr_message_tokens = self.model_instance.get_llm_num_tokens( - prompt_messages - ) + curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) return prompt_messages - def get_history_prompt_text(self, human_prefix: str = "Human", - ai_prefix: str = "Assistant", - max_token_limit: int = 2000, - message_limit: Optional[int] = None) -> str: + def get_history_prompt_text( + self, + human_prefix: str = "Human", + ai_prefix: str = "Assistant", + max_token_limit: int = 2000, + message_limit: Optional[int] = None, + ) -> str: """ Get history prompt text. :param human_prefix: human prefix @@ -123,10 +112,7 @@ class TokenBufferMemory: :param message_limit: message limit :return: """ - prompt_messages = self.get_history_prompt_messages( - max_token_limit=max_token_limit, - message_limit=message_limit - ) + prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit) string_messages = [] for m in prompt_messages: diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py index bba004a32a..92da53c9a4 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/core/model_runtime/callbacks/base_callback.py @@ -18,12 +18,21 @@ class Callback: Base class for callbacks. Only for LLM. """ + raise_error: bool = False - def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_before_invoke( + self, + llm_instance: AIModel, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Before invoke callback @@ -39,10 +48,19 @@ class Callback: """ raise NotImplementedError() - def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None): + def on_new_chunk( + self, + llm_instance: AIModel, + chunk: LLMResultChunk, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ): """ On new chunk callback @@ -59,10 +77,19 @@ class Callback: """ raise NotImplementedError() - def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_after_invoke( + self, + llm_instance: AIModel, + result: LLMResult, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ After invoke callback @@ -79,10 +106,19 @@ class Callback: """ raise NotImplementedError() - def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_invoke_error( + self, + llm_instance: AIModel, + ex: Exception, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Invoke error callback @@ -99,9 +135,7 @@ class Callback: """ raise NotImplementedError() - def print_text( - self, text: str, color: Optional[str] = None, end: str = "" - ) -> None: + def print_text(self, text: str, color: Optional[str] = None, end: str = "") -> None: """Print text with highlighting and no end characters.""" text_to_print = self._get_colored_text(text, color) if color else text print(text_to_print, end=end) diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py index 0406853b88..3b6b825244 100644 --- a/api/core/model_runtime/callbacks/logging_callback.py +++ b/api/core/model_runtime/callbacks/logging_callback.py @@ -10,11 +10,20 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) + class LoggingCallback(Callback): - def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_before_invoke( + self, + llm_instance: AIModel, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Before invoke callback @@ -28,40 +37,49 @@ class LoggingCallback(Callback): :param stream: is stream response :param user: unique user id """ - self.print_text("\n[on_llm_before_invoke]\n", color='blue') - self.print_text(f"Model: {model}\n", color='blue') - self.print_text("Parameters:\n", color='blue') + self.print_text("\n[on_llm_before_invoke]\n", color="blue") + self.print_text(f"Model: {model}\n", color="blue") + self.print_text("Parameters:\n", color="blue") for key, value in model_parameters.items(): - self.print_text(f"\t{key}: {value}\n", color='blue') + self.print_text(f"\t{key}: {value}\n", color="blue") if stop: - self.print_text(f"\tstop: {stop}\n", color='blue') + self.print_text(f"\tstop: {stop}\n", color="blue") if tools: - self.print_text("\tTools:\n", color='blue') + self.print_text("\tTools:\n", color="blue") for tool in tools: - self.print_text(f"\t\t{tool.name}\n", color='blue') + self.print_text(f"\t\t{tool.name}\n", color="blue") - self.print_text(f"Stream: {stream}\n", color='blue') + self.print_text(f"Stream: {stream}\n", color="blue") if user: - self.print_text(f"User: {user}\n", color='blue') + self.print_text(f"User: {user}\n", color="blue") - self.print_text("Prompt messages:\n", color='blue') + self.print_text("Prompt messages:\n", color="blue") for prompt_message in prompt_messages: if prompt_message.name: - self.print_text(f"\tname: {prompt_message.name}\n", color='blue') + self.print_text(f"\tname: {prompt_message.name}\n", color="blue") - self.print_text(f"\trole: {prompt_message.role.value}\n", color='blue') - self.print_text(f"\tcontent: {prompt_message.content}\n", color='blue') + self.print_text(f"\trole: {prompt_message.role.value}\n", color="blue") + self.print_text(f"\tcontent: {prompt_message.content}\n", color="blue") if stream: self.print_text("\n[on_llm_new_chunk]") - def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None): + def on_new_chunk( + self, + llm_instance: AIModel, + chunk: LLMResultChunk, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ): """ On new chunk callback @@ -79,10 +97,19 @@ class LoggingCallback(Callback): sys.stdout.write(chunk.delta.message.content) sys.stdout.flush() - def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_after_invoke( + self, + llm_instance: AIModel, + result: LLMResult, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ After invoke callback @@ -97,24 +124,33 @@ class LoggingCallback(Callback): :param stream: is stream response :param user: unique user id """ - self.print_text("\n[on_llm_after_invoke]\n", color='yellow') - self.print_text(f"Content: {result.message.content}\n", color='yellow') + self.print_text("\n[on_llm_after_invoke]\n", color="yellow") + self.print_text(f"Content: {result.message.content}\n", color="yellow") if result.message.tool_calls: - self.print_text("Tool calls:\n", color='yellow') + self.print_text("Tool calls:\n", color="yellow") for tool_call in result.message.tool_calls: - self.print_text(f"\t{tool_call.id}\n", color='yellow') - self.print_text(f"\t{tool_call.function.name}\n", color='yellow') - self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color='yellow') + self.print_text(f"\t{tool_call.id}\n", color="yellow") + self.print_text(f"\t{tool_call.function.name}\n", color="yellow") + self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color="yellow") - self.print_text(f"Model: {result.model}\n", color='yellow') - self.print_text(f"Usage: {result.usage}\n", color='yellow') - self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color='yellow') + self.print_text(f"Model: {result.model}\n", color="yellow") + self.print_text(f"Usage: {result.usage}\n", color="yellow") + self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color="yellow") - def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_invoke_error( + self, + llm_instance: AIModel, + ex: Exception, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Invoke error callback @@ -129,5 +165,5 @@ class LoggingCallback(Callback): :param stream: is stream response :param user: unique user id """ - self.print_text("\n[on_llm_invoke_error]\n", color='red') + self.print_text("\n[on_llm_invoke_error]\n", color="red") logger.exception(ex) diff --git a/api/core/model_runtime/entities/common_entities.py b/api/core/model_runtime/entities/common_entities.py index 175c13cfdc..659ad59bd6 100644 --- a/api/core/model_runtime/entities/common_entities.py +++ b/api/core/model_runtime/entities/common_entities.py @@ -7,6 +7,7 @@ class I18nObject(BaseModel): """ Model class for i18n object. """ + zh_Hans: Optional[str] = None en_US: str diff --git a/api/core/model_runtime/entities/defaults.py b/api/core/model_runtime/entities/defaults.py index e04d9fcbbb..e94be6f918 100644 --- a/api/core/model_runtime/entities/defaults.py +++ b/api/core/model_runtime/entities/defaults.py @@ -2,123 +2,123 @@ from core.model_runtime.entities.model_entities import DefaultParameterName PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { DefaultParameterName.TEMPERATURE: { - 'label': { - 'en_US': 'Temperature', - 'zh_Hans': '温度', + "label": { + "en_US": "Temperature", + "zh_Hans": "温度", }, - 'type': 'float', - 'help': { - 'en_US': 'Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.', - 'zh_Hans': '温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。', + "type": "float", + "help": { + "en_US": "Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.", + "zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。", }, - 'required': False, - 'default': 0.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "required": False, + "default": 0.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.TOP_P: { - 'label': { - 'en_US': 'Top P', - 'zh_Hans': 'Top P', + "label": { + "en_US": "Top P", + "zh_Hans": "Top P", }, - 'type': 'float', - 'help': { - 'en_US': 'Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.', - 'zh_Hans': '通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。', + "type": "float", + "help": { + "en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.", + "zh_Hans": "通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。", }, - 'required': False, - 'default': 1.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "required": False, + "default": 1.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.TOP_K: { - 'label': { - 'en_US': 'Top K', - 'zh_Hans': 'Top K', + "label": { + "en_US": "Top K", + "zh_Hans": "Top K", }, - 'type': 'int', - 'help': { - 'en_US': 'Limits the number of tokens to consider for each step by keeping only the k most likely tokens.', - 'zh_Hans': '通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。', + "type": "int", + "help": { + "en_US": "Limits the number of tokens to consider for each step by keeping only the k most likely tokens.", + "zh_Hans": "通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。", }, - 'required': False, - 'default': 50, - 'min': 1, - 'max': 100, - 'precision': 0, + "required": False, + "default": 50, + "min": 1, + "max": 100, + "precision": 0, }, DefaultParameterName.PRESENCE_PENALTY: { - 'label': { - 'en_US': 'Presence Penalty', - 'zh_Hans': '存在惩罚', + "label": { + "en_US": "Presence Penalty", + "zh_Hans": "存在惩罚", }, - 'type': 'float', - 'help': { - 'en_US': 'Applies a penalty to the log-probability of tokens already in the text.', - 'zh_Hans': '对文本中已有的标记的对数概率施加惩罚。', + "type": "float", + "help": { + "en_US": "Applies a penalty to the log-probability of tokens already in the text.", + "zh_Hans": "对文本中已有的标记的对数概率施加惩罚。", }, - 'required': False, - 'default': 0.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "required": False, + "default": 0.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.FREQUENCY_PENALTY: { - 'label': { - 'en_US': 'Frequency Penalty', - 'zh_Hans': '频率惩罚', + "label": { + "en_US": "Frequency Penalty", + "zh_Hans": "频率惩罚", }, - 'type': 'float', - 'help': { - 'en_US': 'Applies a penalty to the log-probability of tokens that appear in the text.', - 'zh_Hans': '对文本中出现的标记的对数概率施加惩罚。', + "type": "float", + "help": { + "en_US": "Applies a penalty to the log-probability of tokens that appear in the text.", + "zh_Hans": "对文本中出现的标记的对数概率施加惩罚。", }, - 'required': False, - 'default': 0.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "required": False, + "default": 0.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.MAX_TOKENS: { - 'label': { - 'en_US': 'Max Tokens', - 'zh_Hans': '最大标记', + "label": { + "en_US": "Max Tokens", + "zh_Hans": "最大标记", }, - 'type': 'int', - 'help': { - 'en_US': 'Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.', - 'zh_Hans': '指定生成结果长度的上限。如果生成结果截断,可以调大该参数。', + "type": "int", + "help": { + "en_US": "Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.", + "zh_Hans": "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。", }, - 'required': False, - 'default': 64, - 'min': 1, - 'max': 2048, - 'precision': 0, + "required": False, + "default": 64, + "min": 1, + "max": 2048, + "precision": 0, }, DefaultParameterName.RESPONSE_FORMAT: { - 'label': { - 'en_US': 'Response Format', - 'zh_Hans': '回复格式', + "label": { + "en_US": "Response Format", + "zh_Hans": "回复格式", }, - 'type': 'string', - 'help': { - 'en_US': 'Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.', - 'zh_Hans': '设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等', + "type": "string", + "help": { + "en_US": "Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.", + "zh_Hans": "设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等", }, - 'required': False, - 'options': ['JSON', 'XML'], + "required": False, + "options": ["JSON", "XML"], }, DefaultParameterName.JSON_SCHEMA: { - 'label': { - 'en_US': 'JSON Schema', + "label": { + "en_US": "JSON Schema", }, - 'type': 'text', - 'help': { - 'en_US': 'Set a response json schema will ensure LLM to adhere it.', - 'zh_Hans': '设置返回的json schema,llm将按照它返回', + "type": "text", + "help": { + "en_US": "Set a response json schema will ensure LLM to adhere it.", + "zh_Hans": "设置返回的json schema,llm将按照它返回", }, - 'required': False, + "required": False, }, } diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index 59a4c103a2..52b590f66a 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -12,11 +12,12 @@ class LLMMode(Enum): """ Enum class for large language model mode. """ + COMPLETION = "completion" CHAT = "chat" @classmethod - def value_of(cls, value: str) -> 'LLMMode': + def value_of(cls, value: str) -> "LLMMode": """ Get value of given mode. @@ -26,13 +27,14 @@ class LLMMode(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") class LLMUsage(ModelUsage): """ Model class for llm usage. """ + prompt_tokens: int prompt_unit_price: Decimal prompt_price_unit: Decimal @@ -50,20 +52,20 @@ class LLMUsage(ModelUsage): def empty_usage(cls): return cls( prompt_tokens=0, - prompt_unit_price=Decimal('0.0'), - prompt_price_unit=Decimal('0.0'), - prompt_price=Decimal('0.0'), + prompt_unit_price=Decimal("0.0"), + prompt_price_unit=Decimal("0.0"), + prompt_price=Decimal("0.0"), completion_tokens=0, - completion_unit_price=Decimal('0.0'), - completion_price_unit=Decimal('0.0'), - completion_price=Decimal('0.0'), + completion_unit_price=Decimal("0.0"), + completion_price_unit=Decimal("0.0"), + completion_price=Decimal("0.0"), total_tokens=0, - total_price=Decimal('0.0'), - currency='USD', - latency=0.0 + total_price=Decimal("0.0"), + currency="USD", + latency=0.0, ) - def plus(self, other: 'LLMUsage') -> 'LLMUsage': + def plus(self, other: "LLMUsage") -> "LLMUsage": """ Add two LLMUsage instances together. @@ -85,10 +87,10 @@ class LLMUsage(ModelUsage): total_tokens=self.total_tokens + other.total_tokens, total_price=self.total_price + other.total_price, currency=other.currency, - latency=self.latency + other.latency + latency=self.latency + other.latency, ) - def __add__(self, other: 'LLMUsage') -> 'LLMUsage': + def __add__(self, other: "LLMUsage") -> "LLMUsage": """ Overload the + operator to add two LLMUsage instances. @@ -97,10 +99,12 @@ class LLMUsage(ModelUsage): """ return self.plus(other) + class LLMResult(BaseModel): """ Model class for llm result. """ + model: str prompt_messages: list[PromptMessage] message: AssistantPromptMessage @@ -112,6 +116,7 @@ class LLMResultChunkDelta(BaseModel): """ Model class for llm result chunk delta. """ + index: int message: AssistantPromptMessage usage: Optional[LLMUsage] = None @@ -122,6 +127,7 @@ class LLMResultChunk(BaseModel): """ Model class for llm result chunk. """ + model: str prompt_messages: list[PromptMessage] system_fingerprint: Optional[str] = None @@ -132,4 +138,5 @@ class NumTokensResult(PriceInfo): """ Model class for number of tokens result. """ + tokens: int diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index e8e6963b56..e51bb18deb 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -9,13 +9,14 @@ class PromptMessageRole(Enum): """ Enum class for prompt message. """ + SYSTEM = "system" USER = "user" ASSISTANT = "assistant" TOOL = "tool" @classmethod - def value_of(cls, value: str) -> 'PromptMessageRole': + def value_of(cls, value: str) -> "PromptMessageRole": """ Get value of given mode. @@ -25,13 +26,14 @@ class PromptMessageRole(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid prompt message type value {value}') + raise ValueError(f"invalid prompt message type value {value}") class PromptMessageTool(BaseModel): """ Model class for prompt message tool. """ + name: str description: str parameters: dict @@ -41,7 +43,8 @@ class PromptMessageFunction(BaseModel): """ Model class for prompt message function. """ - type: str = 'function' + + type: str = "function" function: PromptMessageTool @@ -49,14 +52,16 @@ class PromptMessageContentType(Enum): """ Enum class for prompt message content type. """ - TEXT = 'text' - IMAGE = 'image' + + TEXT = "text" + IMAGE = "image" class PromptMessageContent(BaseModel): """ Model class for prompt message content. """ + type: PromptMessageContentType data: str @@ -65,6 +70,7 @@ class TextPromptMessageContent(PromptMessageContent): """ Model class for text prompt message content. """ + type: PromptMessageContentType = PromptMessageContentType.TEXT @@ -72,9 +78,10 @@ class ImagePromptMessageContent(PromptMessageContent): """ Model class for image prompt message content. """ + class DETAIL(Enum): - LOW = 'low' - HIGH = 'high' + LOW = "low" + HIGH = "high" type: PromptMessageContentType = PromptMessageContentType.IMAGE detail: DETAIL = DETAIL.LOW @@ -84,6 +91,7 @@ class PromptMessage(ABC, BaseModel): """ Model class for prompt message. """ + role: PromptMessageRole content: Optional[str | list[PromptMessageContent]] = None name: Optional[str] = None @@ -101,6 +109,7 @@ class UserPromptMessage(PromptMessage): """ Model class for user prompt message. """ + role: PromptMessageRole = PromptMessageRole.USER @@ -108,14 +117,17 @@ class AssistantPromptMessage(PromptMessage): """ Model class for assistant prompt message. """ + class ToolCall(BaseModel): """ Model class for assistant prompt message tool call. """ + class ToolCallFunction(BaseModel): """ Model class for assistant prompt message tool call function. """ + name: str arguments: str @@ -123,7 +135,7 @@ class AssistantPromptMessage(PromptMessage): type: str function: ToolCallFunction - @field_validator('id', mode='before') + @field_validator("id", mode="before") @classmethod def transform_id_to_str(cls, value) -> str: if not isinstance(value, str): @@ -145,10 +157,12 @@ class AssistantPromptMessage(PromptMessage): return True + class SystemPromptMessage(PromptMessage): """ Model class for system prompt message. """ + role: PromptMessageRole = PromptMessageRole.SYSTEM @@ -156,6 +170,7 @@ class ToolPromptMessage(PromptMessage): """ Model class for tool prompt message. """ + role: PromptMessageRole = PromptMessageRole.TOOL tool_call_id: str diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index d6377d7e88..d898ef1490 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -11,6 +11,7 @@ class ModelType(Enum): """ Enum class for model type. """ + LLM = "llm" TEXT_EMBEDDING = "text-embedding" RERANK = "rerank" @@ -26,22 +27,22 @@ class ModelType(Enum): :return: model type """ - if origin_model_type == 'text-generation' or origin_model_type == cls.LLM.value: + if origin_model_type == "text-generation" or origin_model_type == cls.LLM.value: return cls.LLM - elif origin_model_type == 'embeddings' or origin_model_type == cls.TEXT_EMBEDDING.value: + elif origin_model_type == "embeddings" or origin_model_type == cls.TEXT_EMBEDDING.value: return cls.TEXT_EMBEDDING - elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value: + elif origin_model_type == "reranking" or origin_model_type == cls.RERANK.value: return cls.RERANK - elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value: + elif origin_model_type == "speech2text" or origin_model_type == cls.SPEECH2TEXT.value: return cls.SPEECH2TEXT - elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value: + elif origin_model_type == "tts" or origin_model_type == cls.TTS.value: return cls.TTS - elif origin_model_type == 'text2img' or origin_model_type == cls.TEXT2IMG.value: + elif origin_model_type == "text2img" or origin_model_type == cls.TEXT2IMG.value: return cls.TEXT2IMG elif origin_model_type == cls.MODERATION.value: return cls.MODERATION else: - raise ValueError(f'invalid origin model type {origin_model_type}') + raise ValueError(f"invalid origin model type {origin_model_type}") def to_origin_model_type(self) -> str: """ @@ -50,26 +51,28 @@ class ModelType(Enum): :return: origin model type """ if self == self.LLM: - return 'text-generation' + return "text-generation" elif self == self.TEXT_EMBEDDING: - return 'embeddings' + return "embeddings" elif self == self.RERANK: - return 'reranking' + return "reranking" elif self == self.SPEECH2TEXT: - return 'speech2text' + return "speech2text" elif self == self.TTS: - return 'tts' + return "tts" elif self == self.MODERATION: - return 'moderation' + return "moderation" elif self == self.TEXT2IMG: - return 'text2img' + return "text2img" else: - raise ValueError(f'invalid model type {self}') + raise ValueError(f"invalid model type {self}") + class FetchFrom(Enum): """ Enum class for fetch from. """ + PREDEFINED_MODEL = "predefined-model" CUSTOMIZABLE_MODEL = "customizable-model" @@ -78,6 +81,7 @@ class ModelFeature(Enum): """ Enum class for llm feature. """ + TOOL_CALL = "tool-call" MULTI_TOOL_CALL = "multi-tool-call" AGENT_THOUGHT = "agent-thought" @@ -89,6 +93,7 @@ class DefaultParameterName(str, Enum): """ Enum class for parameter template variable. """ + TEMPERATURE = "temperature" TOP_P = "top_p" TOP_K = "top_k" @@ -99,7 +104,7 @@ class DefaultParameterName(str, Enum): JSON_SCHEMA = "json_schema" @classmethod - def value_of(cls, value: Any) -> 'DefaultParameterName': + def value_of(cls, value: Any) -> "DefaultParameterName": """ Get parameter name from value. @@ -109,13 +114,14 @@ class DefaultParameterName(str, Enum): for name in cls: if name.value == value: return name - raise ValueError(f'invalid parameter name {value}') + raise ValueError(f"invalid parameter name {value}") class ParameterType(Enum): """ Enum class for parameter type. """ + FLOAT = "float" INT = "int" STRING = "string" @@ -127,6 +133,7 @@ class ModelPropertyKey(Enum): """ Enum class for model property key. """ + MODE = "mode" CONTEXT_SIZE = "context_size" MAX_CHUNKS = "max_chunks" @@ -144,6 +151,7 @@ class ProviderModel(BaseModel): """ Model class for provider model. """ + model: str label: I18nObject model_type: ModelType @@ -158,6 +166,7 @@ class ParameterRule(BaseModel): """ Model class for parameter rule. """ + name: str use_template: Optional[str] = None label: I18nObject @@ -175,6 +184,7 @@ class PriceConfig(BaseModel): """ Model class for pricing info. """ + input: Decimal output: Optional[Decimal] = None unit: Decimal @@ -185,6 +195,7 @@ class AIModelEntity(ProviderModel): """ Model class for AI model. """ + parameter_rules: list[ParameterRule] = [] pricing: Optional[PriceConfig] = None @@ -197,6 +208,7 @@ class PriceType(Enum): """ Enum class for price type. """ + INPUT = "input" OUTPUT = "output" @@ -205,6 +217,7 @@ class PriceInfo(BaseModel): """ Model class for price info. """ + unit_price: Decimal unit: Decimal total_amount: Decimal diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index f88f89d588..bfe861a97f 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -12,6 +12,7 @@ class ConfigurateMethod(Enum): """ Enum class for configurate method of provider model. """ + PREDEFINED_MODEL = "predefined-model" CUSTOMIZABLE_MODEL = "customizable-model" @@ -20,6 +21,7 @@ class FormType(Enum): """ Enum class for form type. """ + TEXT_INPUT = "text-input" SECRET_INPUT = "secret-input" SELECT = "select" @@ -31,6 +33,7 @@ class FormShowOnObject(BaseModel): """ Model class for form show on. """ + variable: str value: str @@ -39,6 +42,7 @@ class FormOption(BaseModel): """ Model class for form option. """ + label: I18nObject value: str show_on: list[FormShowOnObject] = [] @@ -46,15 +50,14 @@ class FormOption(BaseModel): def __init__(self, **data): super().__init__(**data) if not self.label: - self.label = I18nObject( - en_US=self.value - ) + self.label = I18nObject(en_US=self.value) class CredentialFormSchema(BaseModel): """ Model class for credential form schema. """ + variable: str label: I18nObject type: FormType @@ -70,6 +73,7 @@ class ProviderCredentialSchema(BaseModel): """ Model class for provider credential schema. """ + credential_form_schemas: list[CredentialFormSchema] @@ -82,6 +86,7 @@ class ModelCredentialSchema(BaseModel): """ Model class for model credential schema. """ + model: FieldModelSchema credential_form_schemas: list[CredentialFormSchema] @@ -90,6 +95,7 @@ class SimpleProviderEntity(BaseModel): """ Simple model class for provider. """ + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -102,6 +108,7 @@ class ProviderHelpEntity(BaseModel): """ Model class for provider help. """ + title: I18nObject url: I18nObject @@ -110,6 +117,7 @@ class ProviderEntity(BaseModel): """ Model class for provider. """ + provider: str label: I18nObject description: Optional[I18nObject] = None @@ -138,7 +146,7 @@ class ProviderEntity(BaseModel): icon_small=self.icon_small, icon_large=self.icon_large, supported_model_types=self.supported_model_types, - models=self.models + models=self.models, ) @@ -146,5 +154,6 @@ class ProviderConfig(BaseModel): """ Model class for provider config. """ + provider: str credentials: dict diff --git a/api/core/model_runtime/entities/rerank_entities.py b/api/core/model_runtime/entities/rerank_entities.py index d51efd2b3b..99709e1bcd 100644 --- a/api/core/model_runtime/entities/rerank_entities.py +++ b/api/core/model_runtime/entities/rerank_entities.py @@ -5,6 +5,7 @@ class RerankDocument(BaseModel): """ Model class for rerank document. """ + index: int text: str score: float @@ -14,5 +15,6 @@ class RerankResult(BaseModel): """ Model class for rerank result. """ + model: str docs: list[RerankDocument] diff --git a/api/core/model_runtime/entities/text_embedding_entities.py b/api/core/model_runtime/entities/text_embedding_entities.py index 7be3def379..846b89d658 100644 --- a/api/core/model_runtime/entities/text_embedding_entities.py +++ b/api/core/model_runtime/entities/text_embedding_entities.py @@ -9,6 +9,7 @@ class EmbeddingUsage(ModelUsage): """ Model class for embedding usage. """ + tokens: int total_tokens: int unit_price: Decimal @@ -22,7 +23,7 @@ class TextEmbeddingResult(BaseModel): """ Model class for text embedding result. """ + model: str embeddings: list[list[float]] usage: EmbeddingUsage - diff --git a/api/core/model_runtime/errors/invoke.py b/api/core/model_runtime/errors/invoke.py index 0513cfaf67..edfb19c7d0 100644 --- a/api/core/model_runtime/errors/invoke.py +++ b/api/core/model_runtime/errors/invoke.py @@ -3,6 +3,7 @@ from typing import Optional class InvokeError(Exception): """Base class for all LLM exceptions.""" + description: Optional[str] = None def __init__(self, description: Optional[str] = None) -> None: @@ -14,24 +15,29 @@ class InvokeError(Exception): class InvokeConnectionError(InvokeError): """Raised when the Invoke returns connection error.""" + description = "Connection Error" class InvokeServerUnavailableError(InvokeError): """Raised when the Invoke returns server unavailable error.""" + description = "Server Unavailable Error" class InvokeRateLimitError(InvokeError): """Raised when the Invoke returns rate limit error.""" + description = "Rate Limit Error" class InvokeAuthorizationError(InvokeError): """Raised when the Invoke returns authorization error.""" + description = "Incorrect model credentials provided, please check and try again. " class InvokeBadRequestError(InvokeError): """Raised when the Invoke returns bad request.""" + description = "Bad Request Error" diff --git a/api/core/model_runtime/errors/validate.py b/api/core/model_runtime/errors/validate.py index 8db79a52bb..7fcd2133f9 100644 --- a/api/core/model_runtime/errors/validate.py +++ b/api/core/model_runtime/errors/validate.py @@ -2,4 +2,5 @@ class CredentialsValidateFailedError(Exception): """ Credentials validate failed error """ + pass diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 716bb63566..09d2d7e54d 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -66,12 +66,14 @@ class AIModel(ABC): :param error: model invoke error :return: unified error """ - provider_name = self.__class__.__module__.split('.')[-3] + provider_name = self.__class__.__module__.split(".")[-3] for invoke_error, model_errors in self._invoke_error_mapping.items(): if isinstance(error, tuple(model_errors)): if invoke_error == InvokeAuthorizationError: - return invoke_error(description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. ") + return invoke_error( + description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. " + ) return invoke_error(description=f"[{provider_name}] {invoke_error.description}, {str(error)}") @@ -115,7 +117,7 @@ class AIModel(ABC): if not price_config: raise ValueError(f"Price config not found for model {model}") total_amount = tokens * unit_price * price_config.unit - total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) + total_amount = total_amount.quantize(decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP) return PriceInfo( unit_price=unit_price, @@ -136,24 +138,26 @@ class AIModel(ABC): model_schemas = [] # get module name - model_type = self.__class__.__module__.split('.')[-1] + model_type = self.__class__.__module__.split(".")[-1] # get provider name - provider_name = self.__class__.__module__.split('.')[-3] + provider_name = self.__class__.__module__.split(".")[-3] # get the path of current classes current_path = os.path.abspath(__file__) # get parent path of the current path - provider_model_type_path = os.path.join(os.path.dirname(os.path.dirname(current_path)), provider_name, model_type) + provider_model_type_path = os.path.join( + os.path.dirname(os.path.dirname(current_path)), provider_name, model_type + ) # get all yaml files path under provider_model_type_path that do not start with __ model_schema_yaml_paths = [ os.path.join(provider_model_type_path, model_schema_yaml) for model_schema_yaml in os.listdir(provider_model_type_path) - if not model_schema_yaml.startswith('__') - and not model_schema_yaml.startswith('_') + if not model_schema_yaml.startswith("__") + and not model_schema_yaml.startswith("_") and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) - and model_schema_yaml.endswith('.yaml') + and model_schema_yaml.endswith(".yaml") ] # get _position.yaml file path @@ -165,10 +169,10 @@ class AIModel(ABC): yaml_data = load_yaml_file(model_schema_yaml_path) new_parameter_rules = [] - for parameter_rule in yaml_data.get('parameter_rules', []): - if 'use_template' in parameter_rule: + for parameter_rule in yaml_data.get("parameter_rules", []): + if "use_template" in parameter_rule: try: - default_parameter_name = DefaultParameterName.value_of(parameter_rule['use_template']) + default_parameter_name = DefaultParameterName.value_of(parameter_rule["use_template"]) default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name) copy_default_parameter_rule = default_parameter_rule.copy() copy_default_parameter_rule.update(parameter_rule) @@ -176,31 +180,26 @@ class AIModel(ABC): except ValueError: pass - if 'label' not in parameter_rule: - parameter_rule['label'] = { - 'zh_Hans': parameter_rule['name'], - 'en_US': parameter_rule['name'] - } + if "label" not in parameter_rule: + parameter_rule["label"] = {"zh_Hans": parameter_rule["name"], "en_US": parameter_rule["name"]} new_parameter_rules.append(parameter_rule) - yaml_data['parameter_rules'] = new_parameter_rules + yaml_data["parameter_rules"] = new_parameter_rules - if 'label' not in yaml_data: - yaml_data['label'] = { - 'zh_Hans': yaml_data['model'], - 'en_US': yaml_data['model'] - } + if "label" not in yaml_data: + yaml_data["label"] = {"zh_Hans": yaml_data["model"], "en_US": yaml_data["model"]} - yaml_data['fetch_from'] = FetchFrom.PREDEFINED_MODEL.value + yaml_data["fetch_from"] = FetchFrom.PREDEFINED_MODEL.value try: # yaml_data to entity model_schema = AIModelEntity(**yaml_data) except Exception as e: model_schema_yaml_file_name = os.path.basename(model_schema_yaml_path).rstrip(".yaml") - raise Exception(f'Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}:' - f' {str(e)}') + raise Exception( + f"Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}:" f" {str(e)}" + ) # cache model schema model_schemas.append(model_schema) @@ -235,7 +234,9 @@ class AIModel(ABC): return None - def get_customizable_model_schema_from_credentials(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]: + def get_customizable_model_schema_from_credentials( + self, model: str, credentials: Mapping + ) -> Optional[AIModelEntity]: """ Get customizable model schema from credentials @@ -261,19 +262,19 @@ class AIModel(ABC): try: default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template) default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name) - if not parameter_rule.max and 'max' in default_parameter_rule: - parameter_rule.max = default_parameter_rule['max'] - if not parameter_rule.min and 'min' in default_parameter_rule: - parameter_rule.min = default_parameter_rule['min'] - if not parameter_rule.default and 'default' in default_parameter_rule: - parameter_rule.default = default_parameter_rule['default'] - if not parameter_rule.precision and 'precision' in default_parameter_rule: - parameter_rule.precision = default_parameter_rule['precision'] - if not parameter_rule.required and 'required' in default_parameter_rule: - parameter_rule.required = default_parameter_rule['required'] - if not parameter_rule.help and 'help' in default_parameter_rule: + if not parameter_rule.max and "max" in default_parameter_rule: + parameter_rule.max = default_parameter_rule["max"] + if not parameter_rule.min and "min" in default_parameter_rule: + parameter_rule.min = default_parameter_rule["min"] + if not parameter_rule.default and "default" in default_parameter_rule: + parameter_rule.default = default_parameter_rule["default"] + if not parameter_rule.precision and "precision" in default_parameter_rule: + parameter_rule.precision = default_parameter_rule["precision"] + if not parameter_rule.required and "required" in default_parameter_rule: + parameter_rule.required = default_parameter_rule["required"] + if not parameter_rule.help and "help" in default_parameter_rule: parameter_rule.help = I18nObject( - en_US=default_parameter_rule['help']['en_US'], + en_US=default_parameter_rule["help"]["en_US"], ) if ( parameter_rule.help diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index cfc8942c79..5c39186e65 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -35,16 +35,24 @@ class LargeLanguageModel(AIModel): """ Model class for large language model. """ + model_type: ModelType = ModelType.LLM # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \ - -> Union[LLMResult, Generator]: + def invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: Optional[dict] = None, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -69,7 +77,7 @@ class LargeLanguageModel(AIModel): callbacks = callbacks or [] - if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): + if bool(os.environ.get("DEBUG", "False").lower() == "true"): callbacks.append(LoggingCallback()) # trigger before invoke callbacks @@ -82,7 +90,7 @@ class LargeLanguageModel(AIModel): stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) try: @@ -96,7 +104,7 @@ class LargeLanguageModel(AIModel): stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) else: result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -111,7 +119,7 @@ class LargeLanguageModel(AIModel): stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) raise self._transform_invoke_error(e) @@ -127,7 +135,7 @@ class LargeLanguageModel(AIModel): stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) elif isinstance(result, LLMResult): self._trigger_after_invoke_callbacks( @@ -140,15 +148,23 @@ class LargeLanguageModel(AIModel): stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) return result - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper, ensure the response is a code block with output markdown quote @@ -183,7 +199,7 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) model_parameters.pop("response_format") @@ -195,15 +211,16 @@ if you are not sure about the structure. if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): # override the system message prompt_messages[0] = SystemPromptMessage( - content=block_prompts - .replace("{{instructions}}", str(prompt_messages[0].content)) + content=block_prompts.replace("{{instructions}}", str(prompt_messages[0].content)) ) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=block_prompts - .replace("{{instructions}}", f"Please output a valid {code_block} object.") - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=block_prompts.replace("{{instructions}}", f"Please output a valid {code_block} object.") + ), + ) if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): # add ```JSON\n to the last text message @@ -216,9 +233,7 @@ if you are not sure about the structure. break else: # append a user message - prompt_messages.append(UserPromptMessage( - content=f"```{code_block}\n" - )) + prompt_messages.append(UserPromptMessage(content=f"```{code_block}\n")) response = self._invoke( model=model, @@ -228,33 +243,30 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) if isinstance(response, Generator): first_chunk = next(response) + def new_generator(): yield first_chunk yield from response if first_chunk.delta.message.content and first_chunk.delta.message.content.startswith("`"): return self._code_block_mode_stream_processor_with_backtick( - model=model, - prompt_messages=prompt_messages, - input_generator=new_generator() + model=model, prompt_messages=prompt_messages, input_generator=new_generator() ) else: return self._code_block_mode_stream_processor( - model=model, - prompt_messages=prompt_messages, - input_generator=new_generator() + model=model, prompt_messages=prompt_messages, input_generator=new_generator() ) return response - def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage], - input_generator: Generator[LLMResultChunk, None, None] - ) -> Generator[LLMResultChunk, None, None]: + def _code_block_mode_stream_processor( + self, model: str, prompt_messages: list[PromptMessage], input_generator: Generator[LLMResultChunk, None, None] + ) -> Generator[LLMResultChunk, None, None]: """ Code block mode stream processor, ensure the response is a code block with output markdown quote @@ -303,16 +315,13 @@ if you are not sure about the structure. prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=new_piece, - tool_calls=[] - ), - ) + message=AssistantPromptMessage(content=new_piece, tool_calls=[]), + ), ) - def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list, - input_generator: Generator[LLMResultChunk, None, None]) \ - -> Generator[LLMResultChunk, None, None]: + def _code_block_mode_stream_processor_with_backtick( + self, model: str, prompt_messages: list, input_generator: Generator[LLMResultChunk, None, None] + ) -> Generator[LLMResultChunk, None, None]: """ Code block mode stream processor, ensure the response is a code block with output markdown quote. This version skips the language identifier that follows the opening triple backticks. @@ -378,18 +387,23 @@ if you are not sure about the structure. prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=new_piece, - tool_calls=[] - ), - ) + message=AssistantPromptMessage(content=new_piece, tool_calls=[]), + ), ) - def _invoke_result_generator(self, model: str, result: Generator, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Generator: + def _invoke_result_generator( + self, + model: str, + result: Generator, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Generator: """ Invoke result generator @@ -397,9 +411,7 @@ if you are not sure about the structure. :return: result generator """ callbacks = callbacks or [] - prompt_message = AssistantPromptMessage( - content="" - ) + prompt_message = AssistantPromptMessage(content="") usage = None system_fingerprint = None real_model = model @@ -418,7 +430,7 @@ if you are not sure about the structure. stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) prompt_message.content += chunk.delta.message.content @@ -438,7 +450,7 @@ if you are not sure about the structure. prompt_messages=prompt_messages, message=prompt_message, usage=usage if usage else LLMUsage.empty_usage(), - system_fingerprint=system_fingerprint + system_fingerprint=system_fingerprint, ), credentials=credentials, prompt_messages=prompt_messages, @@ -447,15 +459,21 @@ if you are not sure about the structure. stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) @abstractmethod - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -472,8 +490,13 @@ if you are not sure about the structure. raise NotImplementedError @abstractmethod - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -519,7 +542,9 @@ if you are not sure about the structure. return mode - def _calc_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage: + def _calc_response_usage( + self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int + ) -> LLMUsage: """ Calculate response usage @@ -539,10 +564,7 @@ if you are not sure about the structure. # get completion price info completion_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.OUTPUT, - tokens=completion_tokens + model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens ) # transform usage @@ -558,16 +580,23 @@ if you are not sure about the structure. total_tokens=prompt_tokens + completion_tokens, total_price=prompt_price_info.total_amount + completion_price_info.total_amount, currency=prompt_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage - def _trigger_before_invoke_callbacks(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: + def _trigger_before_invoke_callbacks( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> None: """ Trigger before invoke callbacks @@ -593,7 +622,7 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: @@ -601,11 +630,19 @@ if you are not sure about the structure. else: logger.warning(f"Callback {callback.__class__.__name__} on_before_invoke failed with error {e}") - def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: + def _trigger_new_chunk_callbacks( + self, + chunk: LLMResultChunk, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> None: """ Trigger new chunk callbacks @@ -632,7 +669,7 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: @@ -640,11 +677,19 @@ if you are not sure about the structure. else: logger.warning(f"Callback {callback.__class__.__name__} on_new_chunk failed with error {e}") - def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: + def _trigger_after_invoke_callbacks( + self, + model: str, + result: LLMResult, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> None: """ Trigger after invoke callbacks @@ -672,7 +717,7 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: @@ -680,11 +725,19 @@ if you are not sure about the structure. else: logger.warning(f"Callback {callback.__class__.__name__} on_after_invoke failed with error {e}") - def _trigger_invoke_error_callbacks(self, model: str, ex: Exception, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: + def _trigger_invoke_error_callbacks( + self, + model: str, + ex: Exception, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> None: """ Trigger invoke error callbacks @@ -712,7 +765,7 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: @@ -758,11 +811,13 @@ if you are not sure about the structure. # validate parameter value range if parameter_rule.min is not None and parameter_value < parameter_rule.min: raise ValueError( - f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.") + f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}." + ) if parameter_rule.max is not None and parameter_value > parameter_rule.max: raise ValueError( - f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.") + f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}." + ) elif parameter_rule.type == ParameterType.FLOAT: if not isinstance(parameter_value, float | int): raise ValueError(f"Model Parameter {parameter_name} should be float.") @@ -775,16 +830,19 @@ if you are not sure about the structure. else: if parameter_value != round(parameter_value, parameter_rule.precision): raise ValueError( - f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places.") + f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places." + ) # validate parameter value range if parameter_rule.min is not None and parameter_value < parameter_rule.min: raise ValueError( - f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.") + f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}." + ) if parameter_rule.max is not None and parameter_value > parameter_rule.max: raise ValueError( - f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.") + f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}." + ) elif parameter_rule.type == ParameterType.BOOLEAN: if not isinstance(parameter_value, bool): raise ValueError(f"Model Parameter {parameter_name} should be bool.") diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py index 780460a3f7..4374093de4 100644 --- a/api/core/model_runtime/model_providers/__base/model_provider.py +++ b/api/core/model_runtime/model_providers/__base/model_provider.py @@ -29,32 +29,32 @@ class ModelProvider(ABC): def get_provider_schema(self) -> ProviderEntity: """ Get provider schema - + :return: provider schema """ if self.provider_schema: return self.provider_schema - + # get dirname of the current path - provider_name = self.__class__.__module__.split('.')[-1] + provider_name = self.__class__.__module__.split(".")[-1] # get the path of the model_provider classes base_path = os.path.abspath(__file__) current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name) - + # read provider schema from yaml file - yaml_path = os.path.join(current_path, f'{provider_name}.yaml') + yaml_path = os.path.join(current_path, f"{provider_name}.yaml") yaml_data = load_yaml_file(yaml_path) - + try: # yaml_data to entity provider_schema = ProviderEntity(**yaml_data) except Exception as e: - raise Exception(f'Invalid provider schema for {provider_name}: {str(e)}') + raise Exception(f"Invalid provider schema for {provider_name}: {str(e)}") # cache schema self.provider_schema = provider_schema - + return provider_schema def models(self, model_type: ModelType) -> list[AIModelEntity]: @@ -92,15 +92,15 @@ class ModelProvider(ABC): # get the path of the model type classes base_path = os.path.abspath(__file__) - model_type_name = model_type.value.replace('-', '_') + model_type_name = model_type.value.replace("-", "_") model_type_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name, model_type_name) - model_type_py_path = os.path.join(model_type_path, f'{model_type_name}.py') + model_type_py_path = os.path.join(model_type_path, f"{model_type_name}.py") if not os.path.isdir(model_type_path) or not os.path.exists(model_type_py_path): - raise Exception(f'Invalid model type {model_type} for provider {provider_name}') + raise Exception(f"Invalid model type {model_type} for provider {provider_name}") # Dynamic loading {model_type_name}.py file and find the subclass of AIModel - parent_module = '.'.join(self.__class__.__module__.split('.')[:-1]) + parent_module = ".".join(self.__class__.__module__.split(".")[:-1]) mod = import_module_from_source( module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path ) diff --git a/api/core/model_runtime/model_providers/__base/moderation_model.py b/api/core/model_runtime/model_providers/__base/moderation_model.py index 2b17f292c5..d04414ccb8 100644 --- a/api/core/model_runtime/model_providers/__base/moderation_model.py +++ b/api/core/model_runtime/model_providers/__base/moderation_model.py @@ -12,14 +12,13 @@ class ModerationModel(AIModel): """ Model class for moderation model. """ + model_type: ModelType = ModelType.MODERATION # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, - text: str, user: Optional[str] = None) \ - -> bool: + def invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool: """ Invoke moderation model @@ -37,9 +36,7 @@ class ModerationModel(AIModel): raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - text: str, user: Optional[str] = None) \ - -> bool: + def _invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool: """ Invoke large language model @@ -50,4 +47,3 @@ class ModerationModel(AIModel): :return: false if text is safe, true otherwise """ raise NotImplementedError - diff --git a/api/core/model_runtime/model_providers/__base/rerank_model.py b/api/core/model_runtime/model_providers/__base/rerank_model.py index 2c86f25180..5fb9604742 100644 --- a/api/core/model_runtime/model_providers/__base/rerank_model.py +++ b/api/core/model_runtime/model_providers/__base/rerank_model.py @@ -11,12 +11,19 @@ class RerankModel(AIModel): """ Base Model class for rerank model. """ + model_type: ModelType = ModelType.RERANK - def invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -37,10 +44,16 @@ class RerankModel(AIModel): raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model diff --git a/api/core/model_runtime/model_providers/__base/speech2text_model.py b/api/core/model_runtime/model_providers/__base/speech2text_model.py index 4fb11025fe..b6b0b73743 100644 --- a/api/core/model_runtime/model_providers/__base/speech2text_model.py +++ b/api/core/model_runtime/model_providers/__base/speech2text_model.py @@ -12,14 +12,13 @@ class Speech2TextModel(AIModel): """ Model class for speech2text model. """ + model_type: ModelType = ModelType.SPEECH2TEXT # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke large language model @@ -35,9 +34,7 @@ class Speech2TextModel(AIModel): raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke large language model @@ -59,4 +56,4 @@ class Speech2TextModel(AIModel): current_dir = os.path.dirname(os.path.abspath(__file__)) # Construct the path to the audio file - return os.path.join(current_dir, 'audio.mp3') + return os.path.join(current_dir, "audio.mp3") diff --git a/api/core/model_runtime/model_providers/__base/text2img_model.py b/api/core/model_runtime/model_providers/__base/text2img_model.py index e0f1adb1c4..a5810e2f0e 100644 --- a/api/core/model_runtime/model_providers/__base/text2img_model.py +++ b/api/core/model_runtime/model_providers/__base/text2img_model.py @@ -11,14 +11,15 @@ class Text2ImageModel(AIModel): """ Model class for text2img model. """ + model_type: ModelType = ModelType.TEXT2IMG # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, prompt: str, - model_parameters: dict, user: Optional[str] = None) \ - -> list[IO[bytes]]: + def invoke( + self, model: str, credentials: dict, prompt: str, model_parameters: dict, user: Optional[str] = None + ) -> list[IO[bytes]]: """ Invoke Text2Image model @@ -36,9 +37,9 @@ class Text2ImageModel(AIModel): raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, prompt: str, - model_parameters: dict, user: Optional[str] = None) \ - -> list[IO[bytes]]: + def _invoke( + self, model: str, credentials: dict, prompt: str, model_parameters: dict, user: Optional[str] = None + ) -> list[IO[bytes]]: """ Invoke Text2Image model diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index 381d2f6cd1..54a4486023 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -13,14 +13,15 @@ class TextEmbeddingModel(AIModel): """ Model class for text embedding model. """ + model_type: ModelType = ModelType.TEXT_EMBEDDING # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke large language model @@ -38,9 +39,9 @@ class TextEmbeddingModel(AIModel): raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke large language model diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py index 6059b3f561..5fe6dda6ad 100644 --- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py +++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py @@ -7,27 +7,28 @@ from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer _tokenizer = None _lock = Lock() + class GPT2Tokenizer: @staticmethod def _get_num_tokens_by_gpt2(text: str) -> int: """ - use gpt2 tokenizer to get num tokens + use gpt2 tokenizer to get num tokens """ _tokenizer = GPT2Tokenizer.get_encoder() tokens = _tokenizer.encode(text, verbose=False) return len(tokens) - + @staticmethod def get_num_tokens(text: str) -> int: return GPT2Tokenizer._get_num_tokens_by_gpt2(text) - + @staticmethod def get_encoder() -> Any: global _tokenizer, _lock with _lock: if _tokenizer is None: base_path = abspath(__file__) - gpt2_tokenizer_path = join(dirname(base_path), 'gpt2') + gpt2_tokenizer_path = join(dirname(base_path), "gpt2") _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path) - return _tokenizer \ No newline at end of file + return _tokenizer diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index 2dfd323a47..70be9322a7 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -15,13 +15,15 @@ class TTSModel(AIModel): """ Model class for TTS model. """ + model_type: ModelType = ModelType.TTS # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None): + def invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ): """ Invoke large language model @@ -35,14 +37,21 @@ class TTSModel(AIModel): :return: translated audio file """ try: - return self._invoke(model=model, credentials=credentials, user=user, - content_text=content_text, voice=voice, tenant_id=tenant_id) + return self._invoke( + model=model, + credentials=credentials, + user=user, + content_text=content_text, + voice=voice, + tenant_id=tenant_id, + ) except Exception as e: raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None): + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ): """ Invoke large language model @@ -71,10 +80,13 @@ class TTSModel(AIModel): if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties: voices = model_schema.model_properties[ModelPropertyKey.VOICES] if language: - return [{'name': d['name'], 'value': d['mode']} for d in voices if - language and language in d.get('language')] + return [ + {"name": d["name"], "value": d["mode"]} + for d in voices + if language and language in d.get("language") + ] else: - return [{'name': d['name'], 'value': d['mode']} for d in voices] + return [{"name": d["name"], "value": d["mode"]} for d in voices] def _get_model_default_voice(self, model: str, credentials: dict) -> any: """ @@ -123,23 +135,23 @@ class TTSModel(AIModel): return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS] @staticmethod - def _split_text_into_sentences(org_text, max_length=2000, pattern=r'[。.!?]'): + def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"): match = re.compile(pattern) tx = match.finditer(org_text) start = 0 result = [] - one_sentence = '' + one_sentence = "" for i in tx: end = i.regs[0][1] tmp = org_text[start:end] if len(one_sentence + tmp) > max_length: result.append(one_sentence) - one_sentence = '' + one_sentence = "" one_sentence += tmp start = end last_sens = org_text[start:] if last_sens: one_sentence += last_sens - if one_sentence != '': + if one_sentence != "": result.append(one_sentence) return result diff --git a/api/core/model_runtime/model_providers/anthropic/anthropic.py b/api/core/model_runtime/model_providers/anthropic/anthropic.py index 325c6c060e..5b12f04a3e 100644 --- a/api/core/model_runtime/model_providers/anthropic/anthropic.py +++ b/api/core/model_runtime/model_providers/anthropic/anthropic.py @@ -20,12 +20,9 @@ class AnthropicProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `claude-3-opus-20240229` model for validate, - model_instance.validate_credentials( - model='claude-3-opus-20240229', - credentials=credentials - ) + model_instance.validate_credentials(model="claude-3-opus-20240229", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index 81be1a06a7..30e9d2e9f2 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -55,11 +55,17 @@ if you are not sure about the structure. class AnthropicLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -76,10 +82,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): # invoke model return self._chat_generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -96,41 +109,39 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): credentials_kwargs = self._to_credential_kwargs(credentials) # transform model parameters from completion api of anthropic to chat api - if 'max_tokens_to_sample' in model_parameters: - model_parameters['max_tokens'] = model_parameters.pop('max_tokens_to_sample') + if "max_tokens_to_sample" in model_parameters: + model_parameters["max_tokens"] = model_parameters.pop("max_tokens_to_sample") # init model client client = Anthropic(**credentials_kwargs) extra_model_kwargs = {} if stop: - extra_model_kwargs['stop_sequences'] = stop + extra_model_kwargs["stop_sequences"] = stop if user: - extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user) + extra_model_kwargs["metadata"] = completion_create_params.Metadata(user_id=user) system, prompt_message_dicts = self._convert_prompt_messages(prompt_messages) if system: - extra_model_kwargs['system'] = system + extra_model_kwargs["system"] = system # Add the new header for claude-3-5-sonnet-20240620 model extra_headers = {} if model == "claude-3-5-sonnet-20240620": - if model_parameters.get('max_tokens') > 4096: + if model_parameters.get("max_tokens") > 4096: extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15" if tools: - extra_model_kwargs['tools'] = [ - self._transform_tool_prompt(tool) for tool in tools - ] + extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools] response = client.beta.tools.messages.create( model=model, messages=prompt_message_dicts, stream=stream, extra_headers=extra_headers, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) else: # chat model @@ -140,22 +151,30 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): stream=stream, extra_headers=extra_headers, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) if stream: return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages) return self._handle_chat_generate_response(model, credentials, response, prompt_messages) - - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ - if model_parameters.get('response_format'): + if model_parameters.get("response_format"): stop = stop or [] # chat model self._transform_chat_json_prompts( @@ -167,24 +186,27 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) - model_parameters.pop('response_format') + model_parameters.pop("response_format") return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def _transform_tool_prompt(self, tool: PromptMessageTool) -> dict: - return { - 'name': tool.name, - 'description': tool.description, - 'input_schema': tool.parameters - } + return {"name": tool.name, "description": tool.description, "input_schema": tool.parameters} - def _transform_chat_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + def _transform_chat_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts """ @@ -197,22 +219,30 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): # override the system message prompt_messages[0] = SystemPromptMessage( - content=ANTHROPIC_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + content=ANTHROPIC_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=ANTHROPIC_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=ANTHROPIC_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object." + ).replace("{{block}}", response_format) + ), + ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -228,9 +258,9 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): tokens = client.count_tokens(prompt) tool_call_inner_prompts_tokens_map = { - 'claude-3-opus-20240229': 395, - 'claude-3-haiku-20240307': 264, - 'claude-3-sonnet-20240229': 159 + "claude-3-opus-20240229": 395, + "claude-3-haiku-20240307": 264, + "claude-3-sonnet-20240229": 159, } if model in tool_call_inner_prompts_tokens_map and tools: @@ -257,13 +287,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): "temperature": 0, "max_tokens": 20, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _handle_chat_generate_response(self, model: str, credentials: dict, response: Union[Message, ToolsBetaMessage], - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: Union[Message, ToolsBetaMessage], + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm chat response @@ -274,22 +309,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :return: llm response """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content='', - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content="", tool_calls=[]) for content in response.content: - if content.type == 'text': + if content.type == "text": assistant_prompt_message.content += content.text - elif content.type == 'tool_use': + elif content.type == "tool_use": tool_call = AssistantPromptMessage.ToolCall( id=content.id, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=content.name, - arguments=json.dumps(content.input) - ) + name=content.name, arguments=json.dumps(content.input) + ), ) assistant_prompt_message.tool_calls.append(tool_call) @@ -308,17 +339,14 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): # transform response response = LLMResult( - model=response.model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage + model=response.model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage ) return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, - response: Stream[MessageStreamEvent], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_chat_generate_stream_response( + self, model: str, credentials: dict, response: Stream[MessageStreamEvent], prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm chat stream response @@ -327,7 +355,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" return_model = None input_tokens = 0 output_tokens = 0 @@ -338,24 +366,23 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): for chunk in response: if isinstance(chunk, MessageStartEvent): - if hasattr(chunk, 'content_block'): + if hasattr(chunk, "content_block"): content_block = chunk.content_block if isinstance(content_block, dict): - if content_block.get('type') == 'tool_use': + if content_block.get("type") == "tool_use": tool_call = AssistantPromptMessage.ToolCall( - id=content_block.get('id'), - type='function', + id=content_block.get("id"), + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=content_block.get('name'), - arguments='' - ) + name=content_block.get("name"), arguments="" + ), ) tool_calls.append(tool_call) - elif hasattr(chunk, 'delta'): + elif hasattr(chunk, "delta"): delta = chunk.delta if isinstance(delta, dict) and len(tool_calls) > 0: - if delta.get('type') == 'input_json_delta': - tool_calls[-1].function.arguments += delta.get('partial_json', '') + if delta.get("type") == "input_json_delta": + tool_calls[-1].function.arguments += delta.get("partial_json", "") elif chunk.message: return_model = chunk.message.model input_tokens = chunk.message.usage.input_tokens @@ -369,29 +396,24 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): # transform empty tool call arguments to {} for tool_call in tool_calls: if not tool_call.function.arguments: - tool_call.function.arguments = '{}' + tool_call.function.arguments = "{}" yield LLMResultChunk( model=return_model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index + 1, - message=AssistantPromptMessage( - content='', - tool_calls=tool_calls - ), + message=AssistantPromptMessage(content="", tool_calls=tool_calls), finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) elif isinstance(chunk, ContentBlockDeltaEvent): - chunk_text = chunk.delta.text if chunk.delta.text else '' + chunk_text = chunk.delta.text if chunk.delta.text else "" full_assistant_content += chunk_text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=chunk_text - ) + assistant_prompt_message = AssistantPromptMessage(content=chunk_text) index = chunk.index @@ -401,7 +423,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=chunk.index, message=assistant_prompt_message, - ) + ), ) def _to_credential_kwargs(self, credentials: dict) -> dict: @@ -412,14 +434,14 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :return: """ credentials_kwargs = { - "api_key": credentials['anthropic_api_key'], + "api_key": credentials["anthropic_api_key"], "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "max_retries": 1, } - if credentials.get('anthropic_api_url'): - credentials['anthropic_api_url'] = credentials['anthropic_api_url'].rstrip('/') - credentials_kwargs['base_url'] = credentials['anthropic_api_url'] + if credentials.get("anthropic_api_url"): + credentials["anthropic_api_url"] = credentials["anthropic_api_url"].rstrip("/") + credentials_kwargs["base_url"] = credentials["anthropic_api_url"] return credentials_kwargs @@ -452,10 +474,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -465,25 +484,25 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): image_content = requests.get(message_content.data).content with Image.open(io.BytesIO(image_content)) as img: mime_type = f"image/{img.format.lower()}" - base64_data = base64.b64encode(image_content).decode('utf-8') + base64_data = base64.b64encode(image_content).decode("utf-8") except Exception as ex: - raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") + raise ValueError( + f"Failed to fetch image data from url {message_content.data}, {ex}" + ) else: data_split = message_content.data.split(";base64,") mime_type = data_split[0].replace("data:", "") base64_data = data_split[1] if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: - raise ValueError(f"Unsupported image type {mime_type}, " - f"only support image/jpeg, image/png, image/gif, and image/webp") + raise ValueError( + f"Unsupported image type {mime_type}, " + f"only support image/jpeg, image/png, image/gif, and image/webp" + ) sub_message_dict = { "type": "image", - "source": { - "type": "base64", - "media_type": mime_type, - "data": base64_data - } + "source": {"type": "base64", "media_type": mime_type, "data": base64_data}, } sub_messages.append(sub_message_dict) prompt_message_dicts.append({"role": "user", "content": sub_messages}) @@ -492,34 +511,28 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): content = [] if message.tool_calls: for tool_call in message.tool_calls: - content.append({ - "type": "tool_use", - "id": tool_call.id, - "name": tool_call.function.name, - "input": json.loads(tool_call.function.arguments) - }) + content.append( + { + "type": "tool_use", + "id": tool_call.id, + "name": tool_call.function.name, + "input": json.loads(tool_call.function.arguments), + } + ) if message.content: - content.append({ - "type": "text", - "text": message.content - }) - + content.append({"type": "text", "text": message.content}) + if prompt_message_dicts[-1]["role"] == "assistant": prompt_message_dicts[-1]["content"].extend(content) else: - prompt_message_dicts.append({ - "role": "assistant", - "content": content - }) + prompt_message_dicts.append({"role": "assistant", "content": content}) elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = { "role": "user", - "content": [{ - "type": "tool_result", - "tool_use_id": message.tool_call_id, - "content": message.content - }] + "content": [ + {"type": "tool_result", "tool_use_id": message.tool_call_id, "content": message.content} + ], } prompt_message_dicts.append(message_dict) else: @@ -576,16 +589,13 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :return: Combined string with necessary human_prompt and ai_prompt tags. """ if not messages: - return '' + return "" messages = messages.copy() # don't mutate the original list if not isinstance(messages[-1], AssistantPromptMessage): messages.append(AssistantPromptMessage(content="")) - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() @@ -601,24 +611,14 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - anthropic.APIConnectionError, - anthropic.APITimeoutError - ], - InvokeServerUnavailableError: [ - anthropic.InternalServerError - ], - InvokeRateLimitError: [ - anthropic.RateLimitError - ], - InvokeAuthorizationError: [ - anthropic.AuthenticationError, - anthropic.PermissionDeniedError - ], + InvokeConnectionError: [anthropic.APIConnectionError, anthropic.APITimeoutError], + InvokeServerUnavailableError: [anthropic.InternalServerError], + InvokeRateLimitError: [anthropic.RateLimitError], + InvokeAuthorizationError: [anthropic.AuthenticationError, anthropic.PermissionDeniedError], InvokeBadRequestError: [ anthropic.BadRequestError, anthropic.NotFoundError, anthropic.UnprocessableEntityError, - anthropic.APIError - ] + anthropic.APIError, + ], } diff --git a/api/core/model_runtime/model_providers/azure_openai/_common.py b/api/core/model_runtime/model_providers/azure_openai/_common.py index 31c788d226..32a0269af4 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_common.py +++ b/api/core/model_runtime/model_providers/azure_openai/_common.py @@ -15,10 +15,10 @@ from core.model_runtime.model_providers.azure_openai._constant import AZURE_OPEN class _CommonAzureOpenAI: @staticmethod def _to_credential_kwargs(credentials: dict) -> dict: - api_version = credentials.get('openai_api_version', AZURE_OPENAI_API_VERSION) + api_version = credentials.get("openai_api_version", AZURE_OPENAI_API_VERSION) credentials_kwargs = { - "api_key": credentials['openai_api_key'], - "azure_endpoint": credentials['openai_api_base'], + "api_key": credentials["openai_api_key"], + "azure_endpoint": credentials["openai_api_base"], "api_version": api_version, "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "max_retries": 1, @@ -29,24 +29,14 @@ class _CommonAzureOpenAI: @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - openai.APIConnectionError, - openai.APITimeoutError - ], - InvokeServerUnavailableError: [ - openai.InternalServerError - ], - InvokeRateLimitError: [ - openai.RateLimitError - ], - InvokeAuthorizationError: [ - openai.AuthenticationError, - openai.PermissionDeniedError - ], + InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError], + InvokeServerUnavailableError: [openai.InternalServerError], + InvokeRateLimitError: [openai.RateLimitError], + InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError], InvokeBadRequestError: [ openai.BadRequestError, openai.NotFoundError, openai.UnprocessableEntityError, - openai.APIError - ] + openai.APIError, + ], } diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index f4f7d964ef..c2744691c3 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -14,11 +14,12 @@ from core.model_runtime.entities.model_entities import ( PriceConfig, ) -AZURE_OPENAI_API_VERSION = '2024-02-15-preview' +AZURE_OPENAI_API_VERSION = "2024-02-15-preview" + def _get_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule: rule = ParameterRule( - name='max_tokens', + name="max_tokens", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.MAX_TOKENS], ) rule.default = default @@ -34,11 +35,11 @@ class AzureBaseModel(BaseModel): LLM_BASE_MODELS = [ AzureBaseModel( - base_model_name='gpt-35-turbo', + base_model_name="gpt-35-turbo", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -53,51 +54,47 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.0005, output=0.0015, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-35-turbo-16k', + base_model_name="gpt-35-turbo-16k", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -112,37 +109,37 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), - _get_max_tokens(default=512, min_val=1, max_val=16385) + _get_max_tokens(default=512, min_val=1, max_val=16385), ], pricing=PriceConfig( input=0.003, output=0.004, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-35-turbo-0125', + base_model_name="gpt-35-turbo-0125", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -157,51 +154,47 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.0005, output=0.0015, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4', + base_model_name="gpt-4", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -216,32 +209,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=8192), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -249,34 +239,30 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.03, output=0.06, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-32k', + base_model_name="gpt-4-32k", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -291,32 +277,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=32768), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -324,34 +307,30 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.06, output=0.12, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-0125-preview', + base_model_name="gpt-4-0125-preview", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -366,32 +345,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -399,34 +375,30 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-1106-preview', + base_model_name="gpt-4-1106-preview", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -441,32 +413,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -474,34 +443,30 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o-mini', + base_model_name="gpt-4o-mini", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -517,32 +482,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=16384), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -550,34 +512,30 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.150, output=0.600, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o-mini-2024-07-18', + base_model_name="gpt-4o-mini-2024-07-18", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -593,32 +551,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=16384), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -626,46 +581,40 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object', 'json_schema'] + options=["text", "json_object", "json_schema"], ), ParameterRule( - name='json_schema', - label=I18nObject( - en_US='JSON Schema' - ), - type='text', + name="json_schema", + label=I18nObject(en_US="JSON Schema"), + type="text", help=I18nObject( - zh_Hans='设置返回的json schema,llm将按照它返回', - en_US='Set a response json schema will ensure LLM to adhere it.' + zh_Hans="设置返回的json schema,llm将按照它返回", + en_US="Set a response json schema will ensure LLM to adhere it.", ), - required=False + required=False, ), ], pricing=PriceConfig( input=0.150, output=0.600, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o', + base_model_name="gpt-4o", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -681,32 +630,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -714,34 +660,30 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=5.00, output=15.00, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o-2024-05-13', + base_model_name="gpt-4o-2024-05-13", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -757,32 +699,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -790,34 +729,30 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=5.00, output=15.00, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o-2024-08-06', + base_model_name="gpt-4o-2024-08-06", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -833,32 +768,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -866,46 +798,40 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object', 'json_schema'] + options=["text", "json_object", "json_schema"], ), ParameterRule( - name='json_schema', - label=I18nObject( - en_US='JSON Schema' - ), - type='text', + name="json_schema", + label=I18nObject(en_US="JSON Schema"), + type="text", help=I18nObject( - zh_Hans='设置返回的json schema,llm将按照它返回', - en_US='Set a response json schema will ensure LLM to adhere it.' + zh_Hans="设置返回的json schema,llm将按照它返回", + en_US="Set a response json schema will ensure LLM to adhere it.", ), - required=False + required=False, ), ], pricing=PriceConfig( input=5.00, output=15.00, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-turbo', + base_model_name="gpt-4-turbo", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -921,32 +847,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -954,34 +877,30 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-turbo-2024-04-09', + base_model_name="gpt-4-turbo-2024-04-09", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -997,32 +916,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -1030,39 +946,33 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-vision-preview', + base_model_name="gpt-4-vision-preview", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, - features=[ - ModelFeature.VISION - ], + features=[ModelFeature.VISION], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ ModelPropertyKey.MODE: LLMMode.CHAT.value, @@ -1070,32 +980,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -1103,34 +1010,30 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-35-turbo-instruct', + base_model_name="gpt-35-turbo-instruct", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, @@ -1140,19 +1043,19 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), @@ -1161,16 +1064,16 @@ LLM_BASE_MODELS = [ input=0.0015, output=0.002, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='text-davinci-003', + base_model_name="text-davinci-003", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, @@ -1180,19 +1083,19 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), @@ -1201,20 +1104,18 @@ LLM_BASE_MODELS = [ input=0.02, output=0.02, unit=0.001, - currency='USD', - ) - ) - ) + currency="USD", + ), + ), + ), ] EMBEDDING_BASE_MODELS = [ AzureBaseModel( - base_model_name='text-embedding-ada-002', + base_model_name="text-embedding-ada-002", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ @@ -1224,17 +1125,15 @@ EMBEDDING_BASE_MODELS = [ pricing=PriceConfig( input=0.0001, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='text-embedding-3-small', + base_model_name="text-embedding-3-small", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ @@ -1244,17 +1143,15 @@ EMBEDDING_BASE_MODELS = [ pricing=PriceConfig( input=0.00002, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='text-embedding-3-large', + base_model_name="text-embedding-3-large", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ @@ -1264,135 +1161,129 @@ EMBEDDING_BASE_MODELS = [ pricing=PriceConfig( input=0.00013, unit=0.001, - currency='USD', - ) - ) - ) + currency="USD", + ), + ), + ), ] SPEECH2TEXT_BASE_MODELS = [ AzureBaseModel( - base_model_name='whisper-1', + base_model_name="whisper-1", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.SPEECH2TEXT, model_properties={ ModelPropertyKey.FILE_UPLOAD_LIMIT: 25, - ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: 'flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm' - } - ) + ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: "flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm", + }, + ), ) ] TTS_BASE_MODELS = [ AzureBaseModel( - base_model_name='tts-1', + base_model_name="tts-1", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TTS, model_properties={ - ModelPropertyKey.DEFAULT_VOICE: 'alloy', + ModelPropertyKey.DEFAULT_VOICE: "alloy", ModelPropertyKey.VOICES: [ { - 'mode': 'alloy', - 'name': 'Alloy', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "alloy", + "name": "Alloy", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'echo', - 'name': 'Echo', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "echo", + "name": "Echo", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'fable', - 'name': 'Fable', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "fable", + "name": "Fable", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'onyx', - 'name': 'Onyx', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "onyx", + "name": "Onyx", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'nova', - 'name': 'Nova', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "nova", + "name": "Nova", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'shimmer', - 'name': 'Shimmer', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "shimmer", + "name": "Shimmer", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, ], ModelPropertyKey.WORD_LIMIT: 120, - ModelPropertyKey.AUDIO_TYPE: 'mp3', - ModelPropertyKey.MAX_WORKERS: 5 + ModelPropertyKey.AUDIO_TYPE: "mp3", + ModelPropertyKey.MAX_WORKERS: 5, }, pricing=PriceConfig( input=0.015, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='tts-1-hd', + base_model_name="tts-1-hd", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TTS, model_properties={ - ModelPropertyKey.DEFAULT_VOICE: 'alloy', + ModelPropertyKey.DEFAULT_VOICE: "alloy", ModelPropertyKey.VOICES: [ { - 'mode': 'alloy', - 'name': 'Alloy', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "alloy", + "name": "Alloy", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'echo', - 'name': 'Echo', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "echo", + "name": "Echo", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'fable', - 'name': 'Fable', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "fable", + "name": "Fable", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'onyx', - 'name': 'Onyx', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "onyx", + "name": "Onyx", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'nova', - 'name': 'Nova', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "nova", + "name": "Nova", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'shimmer', - 'name': 'Shimmer', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "shimmer", + "name": "Shimmer", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, ], ModelPropertyKey.WORD_LIMIT: 120, - ModelPropertyKey.AUDIO_TYPE: 'mp3', - ModelPropertyKey.MAX_WORKERS: 5 + ModelPropertyKey.AUDIO_TYPE: "mp3", + ModelPropertyKey.MAX_WORKERS: 5, }, pricing=PriceConfig( input=0.03, unit=0.001, - currency='USD', - ) - ) - ) + currency="USD", + ), + ), + ), ] diff --git a/api/core/model_runtime/model_providers/azure_openai/azure_openai.py b/api/core/model_runtime/model_providers/azure_openai/azure_openai.py index 68977b2266..2e3c6aab05 100644 --- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.py +++ b/api/core/model_runtime/model_providers/azure_openai/azure_openai.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class AzureOpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index c0c782e42b..2ad2289869 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -34,16 +34,20 @@ logger = logging.getLogger(__name__) class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: - - base_model_name = credentials.get('base_model_name') + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + base_model_name = credentials.get("base_model_name") if not base_model_name: - raise ValueError('Base Model Name is required') + raise ValueError("Base Model Name is required") ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: @@ -56,7 +60,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) else: # text completion model @@ -67,7 +71,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) def get_num_tokens( @@ -75,14 +79,14 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None + tools: Optional[list[PromptMessageTool]] = None, ) -> int: - base_model_name = credentials.get('base_model_name') + base_model_name = credentials.get("base_model_name") if not base_model_name: - raise ValueError('Base Model Name is required') + raise ValueError("Base Model Name is required") model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) if not model_entity: - raise ValueError(f'Base Model Name {base_model_name} is invalid') + raise ValueError(f"Base Model Name {base_model_name} is invalid") model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE) if model_mode == LLMMode.CHAT.value: @@ -92,21 +96,21 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): # text completion model, do not support tool calling content = prompt_messages[0].content assert isinstance(content, str) - return self._num_tokens_from_string(credentials,content) + return self._num_tokens_from_string(credentials, content) def validate_credentials(self, model: str, credentials: dict) -> None: - if 'openai_api_base' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required') + if "openai_api_base" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required") - if 'openai_api_key' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API key is required') + if "openai_api_key" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API key is required") - if 'base_model_name' not in credentials: - raise CredentialsValidateFailedError('Base Model Name is required') + if "base_model_name" not in credentials: + raise CredentialsValidateFailedError("Base Model Name is required") - base_model_name = credentials.get('base_model_name') + base_model_name = credentials.get("base_model_name") if not base_model_name: - raise CredentialsValidateFailedError('Base Model Name is required') + raise CredentialsValidateFailedError("Base Model Name is required") ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) if not ai_model_entity: @@ -118,7 +122,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: # chat model client.chat.completions.create( - messages=[{"role": "user", "content": 'ping'}], + messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=20, @@ -127,7 +131,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): else: # text completion model client.completions.create( - prompt='ping', + prompt="ping", model=model, temperature=0, max_tokens=20, @@ -137,33 +141,35 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - base_model_name = credentials.get('base_model_name') + base_model_name = credentials.get("base_model_name") if not base_model_name: - raise ValueError('Base Model Name is required') + raise ValueError("Base Model Name is required") ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) return ai_model_entity.entity if ai_model_entity else None - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: - + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) extra_model_kwargs = {} if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user # text completion model response = client.completions.create( - prompt=prompt_messages[0].content, - model=model, - stream=stream, - **model_parameters, - **extra_model_kwargs + prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs ) if stream: @@ -172,15 +178,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return self._handle_generate_response(model, credentials, response, prompt_messages) def _handle_generate_response( - self, model: str, credentials: dict, response: Completion, - prompt_messages: list[PromptMessage] + self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage] ): assistant_text = response.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) # calculate num tokens if response.usage: @@ -209,24 +212,21 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return result def _handle_generate_stream_response( - self, model: str, credentials: dict, response: Stream[Completion], - prompt_messages: list[PromptMessage] + self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: list[PromptMessage] ) -> Generator: - full_text = '' + full_text = "" for chunk in response: if len(chunk.choices) == 0: continue delta = chunk.choices[0] - if delta.finish_reason is None and (delta.text is None or delta.text == ''): + if delta.finish_reason is None and (delta.text is None or delta.text == ""): continue # transform assistant message to prompt message - text = delta.text if delta.text else '' - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + text = delta.text if delta.text else "" + assistant_prompt_message = AssistantPromptMessage(content=text) full_text += text @@ -254,8 +254,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage - ) + usage=usage, + ), ) else: yield LLMResultChunk( @@ -265,14 +265,20 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: - + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) response_format = model_parameters.get("response_format") @@ -293,7 +299,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): extra_model_kwargs = {} if tools: - extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] + extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] # extra_model_kwargs['functions'] = [{ # "name": tool.name, # "description": tool.description, @@ -301,10 +307,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): # } for tool in tools] if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user # chat model messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] @@ -322,9 +328,12 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) def _handle_chat_generate_response( - self, model: str, credentials: dict, response: ChatCompletion, + self, + model: str, + credentials: dict, + response: ChatCompletion, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None + tools: Optional[list[PromptMessageTool]] = None, ): assistant_message = response.choices[0].message assistant_message_tool_calls = assistant_message.tool_calls @@ -334,10 +343,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) # calculate num tokens if response.usage: @@ -369,13 +375,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): credentials: dict, response: Stream[ChatCompletionChunk], prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None + tools: Optional[list[PromptMessageTool]] = None, ): index = 0 - full_assistant_content = '' + full_assistant_content = "" real_model = model system_fingerprint = None - completion = '' + completion = "" tool_calls = [] for chunk in response: if len(chunk.choices) == 0: @@ -386,7 +392,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): if delta.delta is None: continue - # extract tool calls from response self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls) @@ -396,15 +401,14 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=tool_calls + content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls ) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content if delta.delta.content else "" real_model = chunk.model system_fingerprint = chunk.system_fingerprint - completion += delta.delta.content if delta.delta.content else '' + completion += delta.delta.content if delta.delta.content else "" yield LLMResultChunk( model=real_model, @@ -413,7 +417,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) index += 0 @@ -421,9 +425,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): # calculate num tokens prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools) - full_assistant_prompt_message = AssistantPromptMessage( - content=completion - ) + full_assistant_prompt_message = AssistantPromptMessage(content=completion) completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message]) # transform usage @@ -434,27 +436,24 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): prompt_messages=prompt_messages, system_fingerprint=system_fingerprint, delta=LLMResultChunkDelta( - index=index, - message=AssistantPromptMessage(content=''), - finish_reason='stop', - usage=usage - ) + index=index, message=AssistantPromptMessage(content=""), finish_reason="stop", usage=usage + ), ) @staticmethod - def _update_tool_calls(tool_calls: list[AssistantPromptMessage.ToolCall], tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]]) -> None: + def _update_tool_calls( + tool_calls: list[AssistantPromptMessage.ToolCall], + tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]], + ) -> None: if tool_calls_response: for response_tool_call in tool_calls_response: if isinstance(response_tool_call, ChatCompletionMessageToolCall): function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) elif isinstance(response_tool_call, ChoiceDeltaToolCall): @@ -463,8 +462,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): tool_calls[index].id = response_tool_call.id or tool_calls[index].id tool_calls[index].type = response_tool_call.type or tool_calls[index].type if response_tool_call.function: - tool_calls[index].function.name = response_tool_call.function.name or tool_calls[index].function.name - tool_calls[index].function.arguments += response_tool_call.function.arguments or '' + tool_calls[index].function.name = ( + response_tool_call.function.name or tool_calls[index].function.name + ) + tool_calls[index].function.arguments += response_tool_call.function.arguments or "" else: assert response_tool_call.id is not None assert response_tool_call.type is not None @@ -473,13 +474,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): assert response_tool_call.function.arguments is not None function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) @@ -495,19 +493,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -525,7 +517,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): "role": "tool", "name": message.name, "content": message.content, - "tool_call_id": message.tool_call_id + "tool_call_id": message.tool_call_id, } else: raise ValueError(f"Got unknown type {message}") @@ -535,10 +527,11 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return message_dict - def _num_tokens_from_string(self, credentials: dict, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string( + self, credentials: dict, text: str, tools: Optional[list[PromptMessageTool]] = None + ) -> int: try: - encoding = tiktoken.encoding_for_model(credentials['base_model_name']) + encoding = tiktoken.encoding_for_model(credentials["base_model_name"]) except KeyError: encoding = tiktoken.get_encoding("cl100k_base") @@ -550,14 +543,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return num_tokens def _num_tokens_from_messages( - self, credentials: dict, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None + self, credentials: dict, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/ main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" - model = credentials['base_model_name'] + model = credentials["base_model_name"] try: encoding = tiktoken.encoding_for_model(model) except KeyError: @@ -591,10 +583,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -626,40 +618,39 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): @staticmethod def _num_tokens_for_tools(encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int: - num_tokens = 0 for tool in tools: - num_tokens += len(encoding.encode('type')) - num_tokens += len(encoding.encode('function')) + num_tokens += len(encoding.encode("type")) + num_tokens += len(encoding.encode("function")) # calculate num tokens for function object - num_tokens += len(encoding.encode('name')) + num_tokens += len(encoding.encode("name")) num_tokens += len(encoding.encode(tool.name)) - num_tokens += len(encoding.encode('description')) + num_tokens += len(encoding.encode("description")) num_tokens += len(encoding.encode(tool.description)) parameters = tool.parameters - num_tokens += len(encoding.encode('parameters')) - if 'title' in parameters: - num_tokens += len(encoding.encode('title')) - num_tokens += len(encoding.encode(parameters['title'])) - num_tokens += len(encoding.encode('type')) - num_tokens += len(encoding.encode(parameters['type'])) - if 'properties' in parameters: - num_tokens += len(encoding.encode('properties')) - for key, value in parameters['properties'].items(): + num_tokens += len(encoding.encode("parameters")) + if "title" in parameters: + num_tokens += len(encoding.encode("title")) + num_tokens += len(encoding.encode(parameters["title"])) + num_tokens += len(encoding.encode("type")) + num_tokens += len(encoding.encode(parameters["type"])) + if "properties" in parameters: + num_tokens += len(encoding.encode("properties")) + for key, value in parameters["properties"].items(): num_tokens += len(encoding.encode(key)) for field_key, field_value in value.items(): num_tokens += len(encoding.encode(field_key)) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += len(encoding.encode(enum_field)) else: num_tokens += len(encoding.encode(field_key)) num_tokens += len(encoding.encode(str(field_value))) - if 'required' in parameters: - num_tokens += len(encoding.encode('required')) - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += len(encoding.encode("required")) + for required_field in parameters["required"]: num_tokens += 3 num_tokens += len(encoding.encode(required_field)) diff --git a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py index 8aebcb90e4..a2b14cf3db 100644 --- a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py @@ -15,9 +15,7 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel): Model class for OpenAI Speech to text model. """ - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -40,7 +38,7 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel): try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._speech2text_invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -65,10 +63,9 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel): return response.text def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) return ai_model_entity.entity - @staticmethod def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel: for ai_model_entity in SPEECH2TEXT_BASE_MODELS: diff --git a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py index e073bef014..d9cff8ecbb 100644 --- a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py @@ -16,19 +16,18 @@ from core.model_runtime.model_providers.azure_openai._constant import EMBEDDING_ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): - - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: - base_model_name = credentials['base_model_name'] + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: + base_model_name = credentials["base_model_name"] credentials_kwargs = self._to_credential_kwargs(credentials) client = AzureOpenAI(**credentials_kwargs) extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'base64' + extra_model_kwargs["encoding_format"] = "base64" context_size = self._get_context_size(model, credentials) max_chunks = self._get_max_chunks(model, credentials) @@ -44,11 +43,9 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): enc = tiktoken.get_encoding("cl100k_base") for i, text in enumerate(texts): - token = enc.encode( - text - ) + token = enc.encode(text) for j in range(0, len(token), context_size): - tokens += [token[j: j + context_size]] + tokens += [token[j : j + context_size]] indices += [i] batched_embeddings = [] @@ -56,10 +53,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): for i in _iter: embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - client=client, - texts=tokens[i: i + max_chunks], - extra_model_kwargs=extra_model_kwargs + model=model, client=client, texts=tokens[i : i + max_chunks], extra_model_kwargs=extra_model_kwargs ) used_tokens += embedding_used_tokens @@ -75,10 +69,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): _result = results[i] if len(_result) == 0: embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - client=client, - texts="", - extra_model_kwargs=extra_model_kwargs + model=model, client=client, texts="", extra_model_kwargs=extra_model_kwargs ) used_tokens += embedding_used_tokens @@ -88,24 +79,16 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): embeddings[i] = (average / np.linalg.norm(average)).tolist() # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=base_model_name - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=base_model_name) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: if len(texts) == 0: return 0 try: - enc = tiktoken.encoding_for_model(credentials['base_model_name']) + enc = tiktoken.encoding_for_model(credentials["base_model_name"]) except KeyError: enc = tiktoken.get_encoding("cl100k_base") @@ -118,57 +101,52 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): return total_num_tokens def validate_credentials(self, model: str, credentials: dict) -> None: - if 'openai_api_base' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required') + if "openai_api_base" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required") - if 'openai_api_key' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API key is required') + if "openai_api_key" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API key is required") - if 'base_model_name' not in credentials: - raise CredentialsValidateFailedError('Base Model Name is required') + if "base_model_name" not in credentials: + raise CredentialsValidateFailedError("Base Model Name is required") - if not self._get_ai_model_entity(credentials['base_model_name'], model): + if not self._get_ai_model_entity(credentials["base_model_name"], model): raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid') try: credentials_kwargs = self._to_credential_kwargs(credentials) client = AzureOpenAI(**credentials_kwargs) - self._embedding_invoke( - model=model, - client=client, - texts=['ping'], - extra_model_kwargs={} - ) + self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={}) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) return ai_model_entity.entity @staticmethod - def _embedding_invoke(model: str, client: AzureOpenAI, texts: Union[list[str], str], - extra_model_kwargs: dict) -> tuple[list[list[float]], int]: + def _embedding_invoke( + model: str, client: AzureOpenAI, texts: Union[list[str], str], extra_model_kwargs: dict + ) -> tuple[list[list[float]], int]: response = client.embeddings.create( input=texts, model=model, **extra_model_kwargs, ) - if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64': + if "encoding_format" in extra_model_kwargs and extra_model_kwargs["encoding_format"] == "base64": # decode base64 embedding - return ([list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], - response.usage.total_tokens) + return ( + [list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], + response.usage.total_tokens, + ) return [data.embedding for data in response.data], response.usage.total_tokens def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -179,7 +157,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py index f9ddd86f68..bbad726467 100644 --- a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py @@ -17,8 +17,9 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): Model class for OpenAI Speech to text model. """ - def _invoke(self, model: str, tenant_id: str, credentials: dict, - content_text: str, voice: str, user: Optional[str] = None) -> any: + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ) -> any: """ _invoke text2speech model @@ -30,13 +31,12 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): :param user: unique user id :return: text translated to audio file """ - if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]: + if not voice or voice not in [ + d["value"] for d in self.get_tts_model_voices(model=model, credentials=credentials) + ]: voice = self._get_model_default_voice(model, credentials) - return self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - voice=voice) + return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -50,14 +50,13 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): self._tts_invoke_streaming( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ _tts_invoke_streaming text2speech model :param model: model name @@ -75,23 +74,29 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): if len(content_text) > max_length: sentences = self._split_text_into_sentences(content_text, max_length=max_length) executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) - futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model, - response_format="mp3", - input=sentences[i], voice=voice) for i in range(len(sentences))] + futures = [ + executor.submit( + client.audio.speech.with_streaming_response.create, + model=model, + response_format="mp3", + input=sentences[i], + voice=voice, + ) + for i in range(len(sentences)) + ] for index, future in enumerate(futures): yield from future.result().__enter__().iter_bytes(1024) else: - response = client.audio.speech.with_streaming_response.create(model=model, voice=voice, - response_format="mp3", - input=content_text.strip()) + response = client.audio.speech.with_streaming_response.create( + model=model, voice=voice, response_format="mp3", input=content_text.strip() + ) yield from response.__enter__().iter_bytes(1024) except Exception as ex: raise InvokeBadRequestError(str(ex)) - def _process_sentence(self, sentence: str, model: str, - voice, credentials: dict): + def _process_sentence(self, sentence: str, model: str, voice, credentials: dict): """ _tts_invoke openai text2speech model api @@ -108,10 +113,9 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): return response.read() def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) return ai_model_entity.entity - @staticmethod def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel | None: for ai_model_entity in TTS_BASE_MODELS: diff --git a/api/core/model_runtime/model_providers/baichuan/baichuan.py b/api/core/model_runtime/model_providers/baichuan/baichuan.py index 71bd6b5d92..626fc811cf 100644 --- a/api/core/model_runtime/model_providers/baichuan/baichuan.py +++ b/api/core/model_runtime/model_providers/baichuan/baichuan.py @@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) + class BaichuanProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -19,12 +20,9 @@ class BaichuanProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `baichuan2-turbo` model for validate, - model_instance.validate_credentials( - model='baichuan2-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="baichuan2-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py index 7549b2fb60..bea6777f83 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py @@ -4,17 +4,17 @@ import re class BaichuanTokenizer: @classmethod def count_chinese_characters(cls, text: str) -> int: - return len(re.findall(r'[\u4e00-\u9fa5]', text)) + return len(re.findall(r"[\u4e00-\u9fa5]", text)) @classmethod def count_english_vocabularies(cls, text: str) -> int: # remove all non-alphanumeric characters but keep spaces and other symbols like !, ., etc. - text = re.sub(r'[^a-zA-Z0-9\s]', '', text) + text = re.sub(r"[^a-zA-Z0-9\s]", "", text) # count the number of words not characters return len(text.split()) - + @classmethod def _get_num_tokens(cls, text: str) -> int: # tokens = number of Chinese characters + number of English words * 1.3 (for estimation only, subject to actual return) # https://platform.baichuan-ai.com/docs/text-Embedding - return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3) \ No newline at end of file + return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3) diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py index a8fd9dce91..6e181ac5f8 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py @@ -94,7 +94,6 @@ class BaichuanModel: timeout: int, tools: Optional[list[PromptMessageTool]] = None, ) -> Union[Iterator, dict]: - if model in self._model_mapping.keys(): api_base = "https://api.baichuan-ai.com/v1/chat/completions" else: @@ -120,9 +119,7 @@ class BaichuanModel: err = resp["error"]["type"] msg = resp["error"]["message"] except Exception as e: - raise InternalServerError( - f"Failed to convert response to json: {e} with text: {response.text}" - ) + raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") if err == "invalid_api_key": raise InvalidAPIKeyError(msg) diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py index 67d76b4a29..4e56e58d7e 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py @@ -1,17 +1,22 @@ class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass + class InsufficientAccountBalance(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/baichuan/llm/llm.py b/api/core/model_runtime/model_providers/baichuan/llm/llm.py index 36c7003d1b..3291fe2b2e 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/llm.py @@ -38,17 +38,16 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor class BaichuanLanguageModel(LargeLanguageModel): - def _invoke( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, - stream: bool = True, - user: str | None = None, + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, ) -> LLMResult | Generator: return self._generate( model=model, @@ -60,17 +59,17 @@ class BaichuanLanguageModel(LargeLanguageModel): ) def get_num_tokens( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None, + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, ) -> int: return self._num_tokens_from_messages(prompt_messages) def _num_tokens_from_messages( - self, - messages: list[PromptMessage], + self, + messages: list[PromptMessage], ) -> int: """Calculate num tokens for baichuan model""" @@ -111,18 +110,13 @@ class BaichuanLanguageModel(LargeLanguageModel): message = cast(AssistantPromptMessage, message) message_dict = {"role": "assistant", "content": message.content} if message.tool_calls: - message_dict["tool_calls"] = [tool_call.dict() for tool_call in - message.tool_calls] + message_dict["tool_calls"] = [tool_call.dict() for tool_call in message.tool_calls] elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) message_dict = {"role": "system", "content": message.content} elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) - message_dict = { - "role": "tool", - "content": message.content, - "tool_call_id": message.tool_call_id - } + message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} else: raise ValueError(f"Unknown message type {type(message)}") @@ -146,15 +140,14 @@ class BaichuanLanguageModel(LargeLanguageModel): raise CredentialsValidateFailedError(f"Invalid API key: {e}") def _generate( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stream: bool = True, + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stream: bool = True, ) -> LLMResult | Generator: - instance = BaichuanModel(api_key=credentials["api_key"]) messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] @@ -169,23 +162,19 @@ class BaichuanLanguageModel(LargeLanguageModel): ) if stream: - return self._handle_chat_generate_stream_response( - model, prompt_messages, credentials, response - ) + return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response) - return self._handle_chat_generate_response( - model, prompt_messages, credentials, response - ) + return self._handle_chat_generate_response(model, prompt_messages, credentials, response) def _handle_chat_generate_response( - self, - model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: dict, + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: dict, ) -> LLMResult: choices = response.get("choices", []) - assistant_message = AssistantPromptMessage(content='', tool_calls=[]) + assistant_message = AssistantPromptMessage(content="", tool_calls=[]) if choices and choices[0]["finish_reason"] == "tool_calls": for choice in choices: for tool_call in choice["message"]["tool_calls"]: @@ -194,7 +183,7 @@ class BaichuanLanguageModel(LargeLanguageModel): type=tool_call.get("type", ""), function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=tool_call.get("function", {}).get("name", ""), - arguments=tool_call.get("function", {}).get("arguments", "") + arguments=tool_call.get("function", {}).get("arguments", ""), ), ) assistant_message.tool_calls.append(tool) @@ -228,11 +217,11 @@ class BaichuanLanguageModel(LargeLanguageModel): ) def _handle_chat_generate_stream_response( - self, - model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Iterator, + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Iterator, ) -> Generator: for line in response: if not line: @@ -260,9 +249,7 @@ class BaichuanLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=choice["delta"]["content"], tool_calls=[] - ), + message=AssistantPromptMessage(content=choice["delta"]["content"], tool_calls=[]), finish_reason=stop_reason, ), ) diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py index 81bd58e3ce..b7276fabb5 100644 --- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py @@ -31,11 +31,12 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): """ Model class for BaiChuan text embedding model. """ - api_base: str = 'http://api.baichuan-ai.com/v1/embeddings' - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "http://api.baichuan-ai.com/v1/embeddings" + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -45,28 +46,23 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - api_key = credentials['api_key'] - if model != 'baichuan-text-embedding': - raise ValueError('Invalid model name') + api_key = credentials["api_key"] + if model != "baichuan-text-embedding": + raise ValueError("Invalid model name") if not api_key: - raise CredentialsValidateFailedError('api_key is required') - + raise CredentialsValidateFailedError("api_key is required") + # split into chunks of batch size 16 chunks = [] for i in range(0, len(texts), 16): - chunks.append(texts[i:i + 16]) + chunks.append(texts[i : i + 16]) embeddings = [] token_usage = 0 for chunk in chunks: # embedding chunk - chunk_embeddings, chunk_usage = self.embedding( - model=model, - api_key=api_key, - texts=chunk, - user=user - ) + chunk_embeddings, chunk_usage = self.embedding(model=model, api_key=api_key, texts=chunk, user=user) embeddings.extend(chunk_embeddings) token_usage += chunk_usage @@ -74,17 +70,14 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): result = TextEmbeddingResult( model=model, embeddings=embeddings, - usage=self._calc_response_usage( - model=model, - credentials=credentials, - tokens=token_usage - ) + usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage), ) return result - - def embedding(self, model: str, api_key, texts: list[str], user: Optional[str] = None) \ - -> tuple[list[list[float]], int]: + + def embedding( + self, model: str, api_key, texts: list[str], user: Optional[str] = None + ) -> tuple[list[list[float]], int]: """ Embed given texts @@ -95,56 +88,47 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): :return: embeddings result """ url = self.api_base - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} - data = { - 'model': 'Baichuan-Text-Embedding', - 'input': texts - } + data = {"model": "Baichuan-Text-Embedding", "input": texts} try: response = post(url, headers=headers, data=dumps(data)) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: try: resp = response.json() # try to parse error message - err = resp['error']['code'] - msg = resp['error']['message'] + err = resp["error"]["code"] + msg = resp["error"]["message"] except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") - if err == 'invalid_api_key': + if err == "invalid_api_key": raise InvalidAPIKeyError(msg) - elif err == 'insufficient_quota': + elif err == "insufficient_quota": raise InsufficientAccountBalance(msg) - elif err == 'invalid_authentication': - raise InvalidAuthenticationError(msg) - elif err and 'rate' in err: + elif err == "invalid_authentication": + raise InvalidAuthenticationError(msg) + elif err and "rate" in err: raise RateLimitReachedError(msg) - elif err and 'internal' in err: + elif err and "internal" in err: raise InternalServerError(msg) - elif err == 'api_key_empty': + elif err == "api_key_empty": raise InvalidAPIKeyError(msg) else: raise InternalServerError(f"Unknown error: {err} with message: {msg}") - + try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") - return [ - data['embedding'] for data in embeddings - ], usage['total_tokens'] - + return [data["embedding"] for data in embeddings], usage["total_tokens"] def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -170,32 +154,24 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvalidAPIKeyError: - raise CredentialsValidateFailedError('Invalid api key') + raise CredentialsValidateFailedError("Invalid api key") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalance, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -207,10 +183,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -221,7 +194,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/bedrock/bedrock.py b/api/core/model_runtime/model_providers/bedrock/bedrock.py index e99bc52ff8..1cfc1d199c 100644 --- a/api/core/model_runtime/model_providers/bedrock/bedrock.py +++ b/api/core/model_runtime/model_providers/bedrock/bedrock.py @@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) + class BedrockProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -19,13 +20,10 @@ class BedrockProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `amazon.titan-text-lite-v1` model by default for validating credentials - model_for_validation = credentials.get('model_for_validation', 'amazon.titan-text-lite-v1') - model_instance.validate_credentials( - model=model_for_validation, - credentials=credentials - ) + model_for_validation = credentials.get("model_for_validation", "amazon.titan-text-lite-v1") + model_instance.validate_credentials(model=model_for_validation, credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index c325ac3cec..a2a69b86bb 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -45,36 +45,42 @@ from core.model_runtime.model_providers.__base.large_language_model import Large logger = logging.getLogger(__name__) -class BedrockLargeLanguageModel(LargeLanguageModel): +class BedrockLargeLanguageModel(LargeLanguageModel): # please refer to the documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html # TODO There is invoke issue: context limit on Cohere Model, will add them after fixed. - CONVERSE_API_ENABLED_MODEL_INFO=[ - {'prefix': 'anthropic.claude-v2', 'support_system_prompts': True, 'support_tool_use': False}, - {'prefix': 'anthropic.claude-v1', 'support_system_prompts': True, 'support_tool_use': False}, - {'prefix': 'anthropic.claude-3', 'support_system_prompts': True, 'support_tool_use': True}, - {'prefix': 'meta.llama', 'support_system_prompts': True, 'support_tool_use': False}, - {'prefix': 'mistral.mistral-7b-instruct', 'support_system_prompts': False, 'support_tool_use': False}, - {'prefix': 'mistral.mixtral-8x7b-instruct', 'support_system_prompts': False, 'support_tool_use': False}, - {'prefix': 'mistral.mistral-large', 'support_system_prompts': True, 'support_tool_use': True}, - {'prefix': 'mistral.mistral-small', 'support_system_prompts': True, 'support_tool_use': True}, - {'prefix': 'cohere.command-r', 'support_system_prompts': True, 'support_tool_use': True}, - {'prefix': 'amazon.titan', 'support_system_prompts': False, 'support_tool_use': False} + CONVERSE_API_ENABLED_MODEL_INFO = [ + {"prefix": "anthropic.claude-v2", "support_system_prompts": True, "support_tool_use": False}, + {"prefix": "anthropic.claude-v1", "support_system_prompts": True, "support_tool_use": False}, + {"prefix": "anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "meta.llama", "support_system_prompts": True, "support_tool_use": False}, + {"prefix": "mistral.mistral-7b-instruct", "support_system_prompts": False, "support_tool_use": False}, + {"prefix": "mistral.mixtral-8x7b-instruct", "support_system_prompts": False, "support_tool_use": False}, + {"prefix": "mistral.mistral-large", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "mistral.mistral-small", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "cohere.command-r", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "amazon.titan", "support_system_prompts": False, "support_tool_use": False}, ] @staticmethod def _find_model_info(model_id): for model in BedrockLargeLanguageModel.CONVERSE_API_ENABLED_MODEL_INFO: - if model_id.startswith(model['prefix']): + if model_id.startswith(model["prefix"]): return model logger.info(f"current model id: {model_id} did not support by Converse API") return None - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -88,17 +94,28 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param user: unique user id :return: full response or stream response chunk generator result """ - - model_info= BedrockLargeLanguageModel._find_model_info(model) + + model_info = BedrockLargeLanguageModel._find_model_info(model) if model_info: - model_info['model'] = model + model_info["model"] = model # invoke models via boto3 converse API - return self._generate_with_converse(model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools) + return self._generate_with_converse( + model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools + ) # invoke other models via boto3 client return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) - def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]: + def _generate_with_converse( + self, + model_info: dict, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + tools: Optional[list[PromptMessageTool]] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model with converse API @@ -110,35 +127,39 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param stream: is stream response :return: full response or stream response chunk generator result """ - bedrock_client = boto3.client(service_name='bedrock-runtime', - aws_access_key_id=credentials.get("aws_access_key_id"), - aws_secret_access_key=credentials.get("aws_secret_access_key"), - region_name=credentials["aws_region"]) + bedrock_client = boto3.client( + service_name="bedrock-runtime", + aws_access_key_id=credentials.get("aws_access_key_id"), + aws_secret_access_key=credentials.get("aws_secret_access_key"), + region_name=credentials["aws_region"], + ) system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages) inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop) parameters = { - 'modelId': model_info['model'], - 'messages': prompt_message_dicts, - 'inferenceConfig': inference_config, - 'additionalModelRequestFields': additional_model_fields, + "modelId": model_info["model"], + "messages": prompt_message_dicts, + "inferenceConfig": inference_config, + "additionalModelRequestFields": additional_model_fields, } - if model_info['support_system_prompts'] and system and len(system) > 0: - parameters['system'] = system + if model_info["support_system_prompts"] and system and len(system) > 0: + parameters["system"] = system - if model_info['support_tool_use'] and tools: - parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools) + if model_info["support_tool_use"] and tools: + parameters["toolConfig"] = self._convert_converse_tool_config(tools=tools) try: if stream: response = bedrock_client.converse_stream(**parameters) - return self._handle_converse_stream_response(model_info['model'], credentials, response, prompt_messages) + return self._handle_converse_stream_response( + model_info["model"], credentials, response, prompt_messages + ) else: response = bedrock_client.converse(**parameters) - return self._handle_converse_response(model_info['model'], credentials, response, prompt_messages) + return self._handle_converse_response(model_info["model"], credentials, response, prompt_messages) except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" raise self._map_client_to_invoke_error(error_code, full_error_msg) except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: @@ -149,8 +170,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel): except Exception as ex: raise InvokeError(str(ex)) - def _handle_converse_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage]) -> LLMResult: + + def _handle_converse_response( + self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm chat response @@ -160,36 +183,30 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt messages :return: full response chunk generator result """ - response_content = response['output']['message']['content'] + response_content = response["output"]["message"]["content"] # transform assistant message to prompt message - if response['stopReason'] == 'tool_use': + if response["stopReason"] == "tool_use": tool_calls = [] text, tool_use = self._extract_tool_use(response_content) tool_call = AssistantPromptMessage.ToolCall( - id=tool_use['toolUseId'], - type='function', + id=tool_use["toolUseId"], + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool_use['name'], - arguments=json.dumps(tool_use['input']) - ) + name=tool_use["name"], arguments=json.dumps(tool_use["input"]) + ), ) tool_calls.append(tool_call) - assistant_prompt_message = AssistantPromptMessage( - content=text, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=text, tool_calls=tool_calls) else: - assistant_prompt_message = AssistantPromptMessage( - content=response_content[0]['text'] - ) + assistant_prompt_message = AssistantPromptMessage(content=response_content[0]["text"]) # calculate num tokens - if response['usage']: + if response["usage"]: # transform usage - prompt_tokens = response['usage']['inputTokens'] - completion_tokens = response['usage']['outputTokens'] + prompt_tokens = response["usage"]["inputTokens"] + completion_tokens = response["usage"]["outputTokens"] else: # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -206,20 +223,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel): ) return result - def _extract_tool_use(self, content:dict)-> tuple[str, dict]: + def _extract_tool_use(self, content: dict) -> tuple[str, dict]: tool_use = {} - text = '' + text = "" for item in content: - if 'toolUse' in item: - tool_use = item['toolUse'] - elif 'text' in item: - text = item['text'] + if "toolUse" in item: + tool_use = item["toolUse"] + elif "text" in item: + text = item["text"] else: raise ValueError(f"Got unknown item: {item}") return text, tool_use - def _handle_converse_stream_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage], ) -> Generator: + def _handle_converse_stream_response( + self, + model: str, + credentials: dict, + response: dict, + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm chat stream response @@ -231,7 +253,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): """ try: - full_assistant_content = '' + full_assistant_content = "" return_model = None input_tokens = 0 output_tokens = 0 @@ -240,87 +262,85 @@ class BedrockLargeLanguageModel(LargeLanguageModel): tool_calls: list[AssistantPromptMessage.ToolCall] = [] tool_use = {} - for chunk in response['stream']: - if 'messageStart' in chunk: + for chunk in response["stream"]: + if "messageStart" in chunk: return_model = model - elif 'messageStop' in chunk: - finish_reason = chunk['messageStop']['stopReason'] - elif 'contentBlockStart' in chunk: - tool = chunk['contentBlockStart']['start']['toolUse'] - tool_use['toolUseId'] = tool['toolUseId'] - tool_use['name'] = tool['name'] - elif 'metadata' in chunk: - input_tokens = chunk['metadata']['usage']['inputTokens'] - output_tokens = chunk['metadata']['usage']['outputTokens'] + elif "messageStop" in chunk: + finish_reason = chunk["messageStop"]["stopReason"] + elif "contentBlockStart" in chunk: + tool = chunk["contentBlockStart"]["start"]["toolUse"] + tool_use["toolUseId"] = tool["toolUseId"] + tool_use["name"] = tool["name"] + elif "metadata" in chunk: + input_tokens = chunk["metadata"]["usage"]["inputTokens"] + output_tokens = chunk["metadata"]["usage"]["outputTokens"] usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens) yield LLMResultChunk( model=return_model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, - message=AssistantPromptMessage( - content='', - tool_calls=tool_calls - ), + message=AssistantPromptMessage(content="", tool_calls=tool_calls), finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) - elif 'contentBlockDelta' in chunk: - delta = chunk['contentBlockDelta']['delta'] - if 'text' in delta: - chunk_text = delta['text'] if delta['text'] else '' + elif "contentBlockDelta" in chunk: + delta = chunk["contentBlockDelta"]["delta"] + if "text" in delta: + chunk_text = delta["text"] if delta["text"] else "" full_assistant_content += chunk_text assistant_prompt_message = AssistantPromptMessage( - content=chunk_text if chunk_text else '', + content=chunk_text if chunk_text else "", ) - index = chunk['contentBlockDelta']['contentBlockIndex'] + index = chunk["contentBlockDelta"]["contentBlockIndex"] yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=index+1, + index=index + 1, message=assistant_prompt_message, - ) + ), ) - elif 'toolUse' in delta: - if 'input' not in tool_use: - tool_use['input'] = '' - tool_use['input'] += delta['toolUse']['input'] - elif 'contentBlockStop' in chunk: - if 'input' in tool_use: + elif "toolUse" in delta: + if "input" not in tool_use: + tool_use["input"] = "" + tool_use["input"] += delta["toolUse"]["input"] + elif "contentBlockStop" in chunk: + if "input" in tool_use: tool_call = AssistantPromptMessage.ToolCall( - id=tool_use['toolUseId'], - type='function', + id=tool_use["toolUseId"], + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool_use['name'], - arguments=tool_use['input'] - ) + name=tool_use["name"], arguments=tool_use["input"] + ), ) tool_calls.append(tool_call) tool_use = {} except Exception as ex: raise InvokeError(str(ex)) - - def _convert_converse_api_model_parameters(self, model_parameters: dict, stop: Optional[list[str]] = None) -> tuple[dict, dict]: + + def _convert_converse_api_model_parameters( + self, model_parameters: dict, stop: Optional[list[str]] = None + ) -> tuple[dict, dict]: inference_config = {} additional_model_fields = {} - if 'max_tokens' in model_parameters: - inference_config['maxTokens'] = model_parameters['max_tokens'] + if "max_tokens" in model_parameters: + inference_config["maxTokens"] = model_parameters["max_tokens"] - if 'temperature' in model_parameters: - inference_config['temperature'] = model_parameters['temperature'] - - if 'top_p' in model_parameters: - inference_config['topP'] = model_parameters['temperature'] + if "temperature" in model_parameters: + inference_config["temperature"] = model_parameters["temperature"] + + if "top_p" in model_parameters: + inference_config["topP"] = model_parameters["temperature"] if stop: - inference_config['stopSequences'] = stop - - if 'top_k' in model_parameters: - additional_model_fields['top_k'] = model_parameters['top_k'] - + inference_config["stopSequences"] = stop + + if "top_k" in model_parameters: + additional_model_fields["top_k"] = model_parameters["top_k"] + return inference_config, additional_model_fields def _convert_converse_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]: @@ -332,7 +352,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): prompt_message_dicts = [] for message in prompt_messages: if isinstance(message, SystemPromptMessage): - message.content=message.content.strip() + message.content = message.content.strip() system.append({"text": message.content}) else: prompt_message_dicts.append(self._convert_prompt_message_to_dict(message)) @@ -349,15 +369,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel): "toolSpec": { "name": tool.name, "description": tool.description, - "inputSchema": { - "json": tool.parameters - } + "inputSchema": {"json": tool.parameters}, } } ) tool_config["tools"] = configs return tool_config - + def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: """ Convert PromptMessage to dict @@ -365,15 +383,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel): if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) if isinstance(message.content, str): - message_dict = {"role": "user", "content": [{'text': message.content}]} + message_dict = {"role": "user", "content": [{"text": message.content}]} else: sub_messages = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "text": message_content.data - } + sub_message_dict = {"text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -384,7 +400,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): image_content = requests.get(url).content with Image.open(io.BytesIO(image_content)) as img: mime_type = f"image/{img.format.lower()}" - base64_data = base64.b64encode(image_content).decode('utf-8') + base64_data = base64.b64encode(image_content).decode("utf-8") except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") else: @@ -394,16 +410,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel): image_content = base64.b64decode(base64_data) if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: - raise ValueError(f"Unsupported image type {mime_type}, " - f"only support image/jpeg, image/png, image/gif, and image/webp") + raise ValueError( + f"Unsupported image type {mime_type}, " + f"only support image/jpeg, image/png, image/gif, and image/webp" + ) sub_message_dict = { - "image": { - "format": mime_type.replace('image/', ''), - "source": { - "bytes": image_content - } - } + "image": {"format": mime_type.replace("image/", ""), "source": {"bytes": image_content}} } sub_messages.append(sub_message_dict) @@ -412,36 +425,46 @@ class BedrockLargeLanguageModel(LargeLanguageModel): message = cast(AssistantPromptMessage, message) if message.tool_calls: message_dict = { - "role": "assistant", "content":[{ - "toolUse": { - "toolUseId": message.tool_calls[0].id, - "name": message.tool_calls[0].function.name, - "input": json.loads(message.tool_calls[0].function.arguments) + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": message.tool_calls[0].id, + "name": message.tool_calls[0].function.name, + "input": json.loads(message.tool_calls[0].function.arguments), + } } - }] + ], } else: - message_dict = {"role": "assistant", "content": [{'text': message.content}]} + message_dict = {"role": "assistant", "content": [{"text": message.content}]} elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - message_dict = [{'text': message.content}] + message_dict = [{"text": message.content}] elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = { "role": "user", - "content": [{ - "toolResult": { - "toolUseId": message.tool_call_id, - "content": [{"json": {"text": message.content}}] - } - }] + "content": [ + { + "toolResult": { + "toolUseId": message.tool_call_id, + "content": [{"json": {"text": message.content}}], + } + } + ], } else: raise ValueError(f"Got unknown type {message}") return message_dict - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage] | str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage] | str, + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -451,15 +474,14 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param tools: tools for tool calling :return:md = genai.GenerativeModel(model) """ - prefix = model.split('.')[0] - model_name = model.split('.')[1] - + prefix = model.split(".")[0] + model_name = model.split(".")[1] + if isinstance(prompt_messages, str): prompt = prompt_messages else: prompt = self._convert_messages_to_prompt(prompt_messages, prefix, model_name) - return self._get_num_tokens_by_gpt2(prompt) def validate_credentials(self, model: str, credentials: dict) -> None: @@ -482,24 +504,28 @@ class BedrockLargeLanguageModel(LargeLanguageModel): "topP": 0.9, "maxTokens": 32, } - + try: ping_message = UserPromptMessage(content="ping") - self._invoke(model=model, - credentials=credentials, - prompt_messages=[ping_message], - model_parameters=required_params, - stream=False) - + self._invoke( + model=model, + credentials=credentials, + prompt_messages=[ping_message], + model_parameters=required_params, + stream=False, + ) + except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" raise CredentialsValidateFailedError(str(self._map_client_to_invoke_error(error_code, full_error_msg))) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None) -> str: + def _convert_one_message_to_text( + self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None + ) -> str: """ Convert a single message to a string. @@ -514,7 +540,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): if isinstance(message, UserPromptMessage): body = content - if (isinstance(content, list)): + if isinstance(content, list): body = "".join([c.data for c in content if c.type == PromptMessageContentType.TEXT]) message_text = f"{human_prompt_prefix} {body} {human_prompt_postfix}" elif isinstance(message, AssistantPromptMessage): @@ -528,7 +554,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel): return message_text - def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None) -> str: + def _convert_messages_to_prompt( + self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None + ) -> str: """ Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models @@ -537,27 +565,31 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :return: Combined string with necessary human_prompt and ai_prompt tags. """ if not messages: - return '' + return "" messages = messages.copy() # don't mutate the original list if not isinstance(messages[-1], AssistantPromptMessage): messages.append(AssistantPromptMessage(content="")) - text = "".join( - self._convert_one_message_to_text(message, model_prefix, model_name) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message, model_prefix, model_name) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() - def _create_payload(self, model: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True): + def _create_payload( + self, + model: str, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + ): """ Create payload for bedrock api call depending on model provider """ payload = {} - model_prefix = model.split('.')[0] - model_name = model.split('.')[1] + model_prefix = model.split(".")[0] + model_name = model.split(".")[1] if model_prefix == "ai21": payload["temperature"] = model_parameters.get("temperature") @@ -571,21 +603,27 @@ class BedrockLargeLanguageModel(LargeLanguageModel): payload["frequencyPenalty"] = {model_parameters.get("frequencyPenalty")} if model_parameters.get("countPenalty"): payload["countPenalty"] = {model_parameters.get("countPenalty")} - + elif model_prefix == "cohere": - payload = { **model_parameters } + payload = {**model_parameters} payload["prompt"] = prompt_messages[0].content payload["stream"] = stream - + else: raise ValueError(f"Got unknown model prefix {model_prefix}") - + return payload - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -598,18 +636,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param user: unique user id :return: full response or stream response chunk generator result """ - client_config = Config( - region_name=credentials["aws_region"] - ) + client_config = Config(region_name=credentials["aws_region"]) runtime_client = boto3.client( - service_name='bedrock-runtime', + service_name="bedrock-runtime", config=client_config, aws_access_key_id=credentials.get("aws_access_key_id"), - aws_secret_access_key=credentials.get("aws_secret_access_key") + aws_secret_access_key=credentials.get("aws_secret_access_key"), ) - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] payload = self._create_payload(model, prompt_messages, model_parameters, stop, stream) # need workaround for ai21 models which doesn't support streaming @@ -619,18 +655,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel): invoke = runtime_client.invoke_model try: - body_jsonstr=json.dumps(payload) - response = invoke( - modelId=model, - contentType="application/json", - accept= "*/*", - body=body_jsonstr - ) + body_jsonstr = json.dumps(payload) + response = invoke(modelId=model, contentType="application/json", accept="*/*", body=body_jsonstr) except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" raise self._map_client_to_invoke_error(error_code, full_error_msg) - + except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: raise InvokeConnectionError(str(ex)) @@ -639,15 +670,15 @@ class BedrockLargeLanguageModel(LargeLanguageModel): except Exception as ex: raise InvokeError(str(ex)) - if stream: return self._handle_generate_stream_response(model, credentials, response, prompt_messages) return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -657,7 +688,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response """ - response_body = json.loads(response.get('body').read().decode('utf-8')) + response_body = json.loads(response.get("body").read().decode("utf-8")) finish_reason = response_body.get("error") @@ -665,25 +696,23 @@ class BedrockLargeLanguageModel(LargeLanguageModel): raise InvokeError(finish_reason) # get output text and calculate num tokens based on model / provider - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] if model_prefix == "ai21": - output = response_body.get('completions')[0].get('data').get('text') + output = response_body.get("completions")[0].get("data").get("text") prompt_tokens = len(response_body.get("prompt").get("tokens")) - completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens')) - + completion_tokens = len(response_body.get("completions")[0].get("data").get("tokens")) + elif model_prefix == "cohere": output = response_body.get("generations")[0].get("text") prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, output if output else '') - + completion_tokens = self.get_num_tokens(model, credentials, output if output else "") + else: raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") # construct assistant message from output - assistant_prompt_message = AssistantPromptMessage( - content=output - ) + assistant_prompt_message = AssistantPromptMessage(content=output) # calculate usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) @@ -698,8 +727,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel): return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -709,65 +739,59 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator result """ - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] if model_prefix == "ai21": - response_body = json.loads(response.get('body').read().decode('utf-8')) + response_body = json.loads(response.get("body").read().decode("utf-8")) - content = response_body.get('completions')[0].get('data').get('text') - finish_reason = response_body.get('completions')[0].get('finish_reason') + content = response_body.get("completions")[0].get("data").get("text") + finish_reason = response_body.get("completions")[0].get("finish_reason") prompt_tokens = len(response_body.get("prompt").get("tokens")) - completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens')) + completion_tokens = len(response_body.get("completions")[0].get("data").get("tokens")) usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) yield LLMResultChunk( - model=model, - prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content=content), - finish_reason=finish_reason, - usage=usage - ) - ) + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, message=AssistantPromptMessage(content=content), finish_reason=finish_reason, usage=usage + ), + ) return - - stream = response.get('body') + + stream = response.get("body") if not stream: - raise InvokeError('No response body') - + raise InvokeError("No response body") + index = -1 for event in stream: - chunk = event.get('chunk') - + chunk = event.get("chunk") + if not chunk: exception_name = next(iter(event)) full_ex_msg = f"{exception_name}: {event[exception_name]['message']}" raise self._map_client_to_invoke_error(exception_name, full_ex_msg) - payload = json.loads(chunk.get('bytes').decode()) + payload = json.loads(chunk.get("bytes").decode()) - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] if model_prefix == "cohere": content_delta = payload.get("text") finish_reason = payload.get("finish_reason") - + else: raise ValueError(f"Got unknown model prefix {model_prefix} when handling stream response") # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content = content_delta if content_delta else '', + content=content_delta if content_delta else "", ) index += 1 - + if not finish_reason: yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) else: @@ -777,18 +801,15 @@ class BedrockLargeLanguageModel(LargeLanguageModel): # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - finish_reason=finish_reason, - usage=usage - ) + index=index, message=assistant_prompt_message, finish_reason=finish_reason, usage=usage + ), ) - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -804,9 +825,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel): InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } - + def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]: """ Map client error to invoke error @@ -822,7 +843,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel): return InvokeBadRequestError(error_msg) elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: return InvokeRateLimitError(error_msg) - elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]: + elif error_code in [ + "ModelTimeoutException", + "ModelErrorException", + "InternalServerException", + "ModelNotReadyException", + ]: return InvokeServerUnavailableError(error_msg) elif error_code == "ModelStreamErrorException": return InvokeConnectionError(error_msg) diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py index ef22a9c868..2d898e3aaa 100644 --- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py @@ -27,12 +27,11 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE logger = logging.getLogger(__name__) + class BedrockTextEmbeddingModel(TextEmbeddingModel): - - - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -42,67 +41,56 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - client_config = Config( - region_name=credentials["aws_region"] - ) + client_config = Config(region_name=credentials["aws_region"]) bedrock_runtime = boto3.client( - service_name='bedrock-runtime', + service_name="bedrock-runtime", config=client_config, aws_access_key_id=credentials.get("aws_access_key_id"), - aws_secret_access_key=credentials.get("aws_secret_access_key") + aws_secret_access_key=credentials.get("aws_secret_access_key"), ) embeddings = [] token_usage = 0 - - model_prefix = model.split('.')[0] - - if model_prefix == "amazon" : + + model_prefix = model.split(".")[0] + + if model_prefix == "amazon": for text in texts: body = { - "inputText": text, + "inputText": text, } response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) - embeddings.extend([response_body.get('embedding')]) - token_usage += response_body.get('inputTextTokenCount') - logger.warning(f'Total Tokens: {token_usage}') + embeddings.extend([response_body.get("embedding")]) + token_usage += response_body.get("inputTextTokenCount") + logger.warning(f"Total Tokens: {token_usage}") result = TextEmbeddingResult( model=model, embeddings=embeddings, - usage=self._calc_response_usage( - model=model, - credentials=credentials, - tokens=token_usage - ) + usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage), ) return result - if model_prefix == "cohere" : - input_type = 'search_document' if len(texts) > 1 else 'search_query' + if model_prefix == "cohere": + input_type = "search_document" if len(texts) > 1 else "search_query" for text in texts: body = { - "texts": [text], - "input_type": input_type, + "texts": [text], + "input_type": input_type, } response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) - embeddings.extend(response_body.get('embeddings')) + embeddings.extend(response_body.get("embeddings")) token_usage += len(text) result = TextEmbeddingResult( model=model, embeddings=embeddings, - usage=self._calc_response_usage( - model=model, - credentials=credentials, - tokens=token_usage - ) + usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage), ) return result - #others + # others raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") - def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ Get number of tokens for given prompt messages @@ -125,7 +113,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): :param credentials: model credentials :return: """ - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -141,19 +129,25 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } - - def _create_payload(self, model_prefix: str, texts: list[str], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True): + + def _create_payload( + self, + model_prefix: str, + texts: list[str], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + ): """ Create payload for bedrock api call depending on model provider """ payload = {} if model_prefix == "amazon": - payload['inputText'] = texts + payload["inputText"] = texts - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -165,10 +159,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -179,7 +170,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -199,31 +190,37 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): return InvokeBadRequestError(error_msg) elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: return InvokeRateLimitError(error_msg) - elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]: + elif error_code in [ + "ModelTimeoutException", + "ModelErrorException", + "InternalServerException", + "ModelNotReadyException", + ]: return InvokeServerUnavailableError(error_msg) elif error_code == "ModelStreamErrorException": return InvokeConnectionError(error_msg) return InvokeError(error_msg) - - def _invoke_bedrock_embedding(self, model: str, bedrock_runtime, body: dict, ): - accept = 'application/json' - content_type = 'application/json' + def _invoke_bedrock_embedding( + self, + model: str, + bedrock_runtime, + body: dict, + ): + accept = "application/json" + content_type = "application/json" try: response = bedrock_runtime.invoke_model( - body=json.dumps(body), - modelId=model, - accept=accept, - contentType=content_type + body=json.dumps(body), modelId=model, accept=accept, contentType=content_type ) - response_body = json.loads(response.get('body').read().decode('utf-8')) + response_body = json.loads(response.get("body").read().decode("utf-8")) return response_body except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" raise self._map_client_to_invoke_error(error_code, full_error_msg) - + except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: raise InvokeConnectionError(str(ex)) diff --git a/api/core/model_runtime/model_providers/chatglm/chatglm.py b/api/core/model_runtime/model_providers/chatglm/chatglm.py index e9dd5794f3..71d9a15322 100644 --- a/api/core/model_runtime/model_providers/chatglm/chatglm.py +++ b/api/core/model_runtime/model_providers/chatglm/chatglm.py @@ -20,12 +20,9 @@ class ChatGLMProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `chatglm3-6b` model for validate, - model_instance.validate_credentials( - model='chatglm3-6b', - credentials=credentials - ) + model_instance.validate_credentials(model="chatglm3-6b", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/chatglm/llm/llm.py b/api/core/model_runtime/model_providers/chatglm/llm/llm.py index e83d08af71..114acc1ec3 100644 --- a/api/core/model_runtime/model_providers/chatglm/llm/llm.py +++ b/api/core/model_runtime/model_providers/chatglm/llm/llm.py @@ -43,12 +43,19 @@ from core.model_runtime.utils import helper logger = logging.getLogger(__name__) + class ChatGLMLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ Invoke large language model @@ -71,11 +78,16 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -96,11 +108,16 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): :return: """ try: - self._invoke(model=model, credentials=credentials, prompt_messages=[ - UserPromptMessage(content="ping"), - ], model_parameters={ - "max_tokens": 16, - }) + self._invoke( + model=model, + credentials=credentials, + prompt_messages=[ + UserPromptMessage(content="ping"), + ], + model_parameters={ + "max_tokens": 16, + }, + ) except Exception as e: raise CredentialsValidateFailedError(str(e)) @@ -124,24 +141,24 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): ConflictError, NotFoundError, UnprocessableEntityError, - PermissionDeniedError + PermissionDeniedError, ], - InvokeRateLimitError: [ - RateLimitError - ], - InvokeAuthorizationError: [ - AuthenticationError - ], - InvokeBadRequestError: [ - ValueError - ] + InvokeRateLimitError: [RateLimitError], + InvokeAuthorizationError: [AuthenticationError], + InvokeBadRequestError: [ValueError], } - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ Invoke large language model @@ -163,35 +180,31 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): extra_model_kwargs = {} if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if tools and len(tools) > 0: - extra_model_kwargs['functions'] = [ - helper.dump_model(tool) for tool in tools - ] + extra_model_kwargs["functions"] = [helper.dump_model(tool) for tool in tools] result = client.chat.completions.create( messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], model=model, stream=stream, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) if stream: return self._handle_chat_generate_stream_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) - + return self._handle_chat_generate_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) - + def _check_chatglm_parameters(self, model: str, model_parameters: dict, tools: list[PromptMessageTool]) -> None: if model.find("chatglm2") != -1 and tools is not None and len(tools) > 0: raise InvokeBadRequestError("ChatGLM2 does not support function calling") @@ -212,7 +225,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): if message.tool_calls and len(message.tool_calls) > 0: message_dict["function_call"] = { "name": message.tool_calls[0].function.name, - "arguments": message.tool_calls[0].function.arguments + "arguments": message.tool_calls[0].function.arguments, } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) @@ -223,12 +236,12 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): message_dict = {"role": "function", "content": message.content} else: raise ValueError(f"Unknown message type {type(message)}") - + return message_dict - - def _extract_response_tool_calls(self, - response_function_calls: list[FunctionCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + + def _extract_response_tool_calls( + self, response_function_calls: list[FunctionCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -239,19 +252,14 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): if response_function_calls: for response_tool_call in response_function_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.name, - arguments=response_tool_call.arguments + name=response_tool_call.name, arguments=response_tool_call.arguments ) - tool_call = AssistantPromptMessage.ToolCall( - id=0, - type='function', - function=function - ) + tool_call = AssistantPromptMessage.ToolCall(id=0, type="function", function=function) tool_calls.append(tool_call) return tool_calls - + def _to_client_kwargs(self, credentials: dict) -> dict: """ Convert invoke kwargs to client kwargs @@ -265,17 +273,20 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): client_kwargs = { "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "api_key": "1", - "base_url": str(URL(credentials['api_base']) / 'v1') + "base_url": str(URL(credentials["api_base"]) / "v1"), } return client_kwargs - - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk], - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) \ - -> Generator: - - full_response = '' + + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[ChatCompletionChunk], + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> Generator: + full_response = "" for chunk in response: if len(chunk.choices) == 0: @@ -283,9 +294,9 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""): continue - + # check if there is a tool call in the response function_calls = None if delta.delta.function_call: @@ -295,23 +306,25 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_message_tool_calls + content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls ) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=assistant_message_tool_calls + content=full_response, tool_calls=assistant_message_tool_calls ) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) - + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -320,7 +333,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage + usage=usage, ), ) else: @@ -335,11 +348,15 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): ) full_response += delta.delta.content - - def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) \ - -> LLMResult: + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: ChatCompletion, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> LLMResult: """ Handle llm chat response @@ -359,15 +376,14 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): tool_calls = self._extract_response_tool_calls([function_calls] if function_calls else []) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -378,7 +394,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): ) return response - + def _num_tokens_from_string(self, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -395,17 +411,19 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for chatglm2 and chatglm3 with GPT2 tokenizer. it's too complex to calculate num tokens for chatglm2 and chatglm3 with ChatGLM tokenizer, As a temporary solution we use GPT2 tokenizer instead. """ + def tokens(text: str): return self._get_num_tokens_by_gpt2(text) - + tokens_per_message = 3 tokens_per_name = 1 num_tokens = 0 @@ -414,10 +432,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text if key == "function_call": @@ -452,36 +470,37 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): :param tools: tools for tool calling :return: number of tokens """ + def tokens(text: str): return self._get_num_tokens_by_gpt2(text) num_tokens = 0 for tool in tools: # calculate num tokens for function object - num_tokens += tokens('name') + num_tokens += tokens("name") num_tokens += tokens(tool.name) - num_tokens += tokens('description') + num_tokens += tokens("description") num_tokens += tokens(tool.description) parameters = tool.parameters - num_tokens += tokens('parameters') - num_tokens += tokens('type') + num_tokens += tokens("parameters") + num_tokens += tokens("type") num_tokens += tokens(parameters.get("type")) - if 'properties' in parameters: - num_tokens += tokens('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += tokens("properties") + for key, value in parameters.get("properties").items(): num_tokens += tokens(key) for field_key, field_value in value.items(): num_tokens += tokens(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += tokens(enum_field) else: num_tokens += tokens(field_key) num_tokens += tokens(str(field_value)) - if 'required' in parameters: - num_tokens += tokens('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += tokens("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += tokens(required_field) diff --git a/api/core/model_runtime/model_providers/cohere/cohere.py b/api/core/model_runtime/model_providers/cohere/cohere.py index cfbcb94d26..8394a45fcf 100644 --- a/api/core/model_runtime/model_providers/cohere/cohere.py +++ b/api/core/model_runtime/model_providers/cohere/cohere.py @@ -20,12 +20,9 @@ class CohereProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.RERANK) # Use `rerank-english-v2.0` model for validate, - model_instance.validate_credentials( - model='rerank-english-v2.0', - credentials=credentials - ) + model_instance.validate_credentials(model="rerank-english-v2.0", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/cohere/llm/llm.py b/api/core/model_runtime/model_providers/cohere/llm/llm.py index 89b04c0279..203ca9c4a0 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/llm.py +++ b/api/core/model_runtime/model_providers/cohere/llm/llm.py @@ -55,11 +55,17 @@ class CohereLargeLanguageModel(LargeLanguageModel): Model class for Cohere large language model. """ - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -85,7 +91,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) else: return self._generate( @@ -95,11 +101,16 @@ class CohereLargeLanguageModel(LargeLanguageModel): model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -136,30 +147,37 @@ class CohereLargeLanguageModel(LargeLanguageModel): self._chat_generate( model=model, credentials=credentials, - prompt_messages=[UserPromptMessage(content='ping')], + prompt_messages=[UserPromptMessage(content="ping")], model_parameters={ - 'max_tokens': 20, - 'temperature': 0, + "max_tokens": 20, + "temperature": 0, }, - stream=False + stream=False, ) else: self._generate( model=model, credentials=credentials, - prompt_messages=[UserPromptMessage(content='ping')], + prompt_messages=[UserPromptMessage(content="ping")], model_parameters={ - 'max_tokens': 20, - 'temperature': 0, + "max_tokens": 20, + "temperature": 0, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm model @@ -173,17 +191,17 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) if stop: - model_parameters['end_sequences'] = stop + model_parameters["end_sequences"] = stop if stream: response = client.generate_stream( prompt=prompt_messages[0].content, model=model, **model_parameters, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) return self._handle_generate_stream_response(model, credentials, response, prompt_messages) @@ -192,14 +210,14 @@ class CohereLargeLanguageModel(LargeLanguageModel): prompt=prompt_messages[0].content, model=model, **model_parameters, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: Generation, - prompt_messages: list[PromptMessage]) \ - -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: Generation, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -212,9 +230,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): assistant_text = response.generations[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) # calculate num tokens prompt_tokens = int(response.meta.billed_units.input_tokens) @@ -225,17 +241,18 @@ class CohereLargeLanguageModel(LargeLanguageModel): # transform response response = LLMResult( - model=model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage + model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage ) return response - def _handle_generate_stream_response(self, model: str, credentials: dict, - response: Iterator[GenerateStreamedResponse], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + response: Iterator[GenerateStreamedResponse], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -245,7 +262,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: llm response chunk generator """ index = 1 - full_assistant_content = '' + full_assistant_content = "" for chunk in response: if isinstance(chunk, GenerateStreamedResponse_TextGeneration): chunk = cast(GenerateStreamedResponse_TextGeneration, chunk) @@ -255,9 +272,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): continue # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + assistant_prompt_message = AssistantPromptMessage(content=text) full_assistant_content += text @@ -267,7 +282,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) index += 1 @@ -277,9 +292,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): # calculate num tokens prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages) completion_tokens = self._num_tokens_from_messages( - model, - credentials, - [AssistantPromptMessage(content=full_assistant_content)] + model, credentials, [AssistantPromptMessage(content=full_assistant_content)] ) # transform usage @@ -290,20 +303,27 @@ class CohereLargeLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, - message=AssistantPromptMessage(content=''), + message=AssistantPromptMessage(content=""), finish_reason=chunk.finish_reason, - usage=usage - ) + usage=usage, + ), ) break elif isinstance(chunk, GenerateStreamedResponse_StreamError): chunk = cast(GenerateStreamedResponse_StreamError, chunk) raise InvokeBadRequestError(chunk.err) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -318,27 +338,28 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) if stop: - model_parameters['stop_sequences'] = stop + model_parameters["stop_sequences"] = stop if tools: if len(tools) == 1: raise ValueError("Cohere tool call requires at least two tools to be specified.") - model_parameters['tools'] = self._convert_tools(tools) + model_parameters["tools"] = self._convert_tools(tools) - message, chat_histories, tool_results \ - = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages) + message, chat_histories, tool_results = self._convert_prompt_messages_to_message_and_chat_histories( + prompt_messages + ) if tool_results: - model_parameters['tool_results'] = tool_results + model_parameters["tool_results"] = tool_results # chat model real_model = model if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL: - real_model = model.removesuffix('-chat') + real_model = model.removesuffix("-chat") if stream: response = client.chat_stream( @@ -346,7 +367,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): chat_history=chat_histories, model=real_model, **model_parameters, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages) @@ -356,14 +377,14 @@ class CohereLargeLanguageModel(LargeLanguageModel): chat_history=chat_histories, model=real_model, **model_parameters, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) return self._handle_chat_generate_response(model, credentials, response, prompt_messages) - def _handle_chat_generate_response(self, model: str, credentials: dict, response: NonStreamedChatResponse, - prompt_messages: list[PromptMessage]) \ - -> LLMResult: + def _handle_chat_generate_response( + self, model: str, credentials: dict, response: NonStreamedChatResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm chat response @@ -380,19 +401,15 @@ class CohereLargeLanguageModel(LargeLanguageModel): for cohere_tool_call in response.tool_calls: tool_call = AssistantPromptMessage.ToolCall( id=cohere_tool_call.name, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=cohere_tool_call.name, - arguments=json.dumps(cohere_tool_call.parameters) - ) + name=cohere_tool_call.name, arguments=json.dumps(cohere_tool_call.parameters) + ), ) tool_calls.append(tool_call) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text, tool_calls=tool_calls) # calculate num tokens prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages) @@ -403,17 +420,18 @@ class CohereLargeLanguageModel(LargeLanguageModel): # transform response response = LLMResult( - model=model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage + model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage ) return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, - response: Iterator[StreamedChatResponse], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Iterator[StreamedChatResponse], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm chat stream response @@ -423,17 +441,16 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: llm response chunk generator """ - def final_response(full_text: str, - tool_calls: list[AssistantPromptMessage.ToolCall], - index: int, - finish_reason: Optional[str] = None) -> LLMResultChunk: + def final_response( + full_text: str, + tool_calls: list[AssistantPromptMessage.ToolCall], + index: int, + finish_reason: Optional[str] = None, + ) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages) - full_assistant_prompt_message = AssistantPromptMessage( - content=full_text, - tool_calls=tool_calls - ) + full_assistant_prompt_message = AssistantPromptMessage(content=full_text, tool_calls=tool_calls) completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message]) # transform usage @@ -444,14 +461,14 @@ class CohereLargeLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, - message=AssistantPromptMessage(content='', tool_calls=tool_calls), + message=AssistantPromptMessage(content="", tool_calls=tool_calls), finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) index = 1 - full_assistant_content = '' + full_assistant_content = "" tool_calls = [] for chunk in response: if isinstance(chunk, StreamedChatResponse_TextGeneration): @@ -462,9 +479,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): continue # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + assistant_prompt_message = AssistantPromptMessage(content=text) full_assistant_content += text @@ -474,7 +489,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) index += 1 @@ -484,11 +499,10 @@ class CohereLargeLanguageModel(LargeLanguageModel): for cohere_tool_call in chunk.tool_calls: tool_call = AssistantPromptMessage.ToolCall( id=cohere_tool_call.name, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=cohere_tool_call.name, - arguments=json.dumps(cohere_tool_call.parameters) - ) + name=cohere_tool_call.name, arguments=json.dumps(cohere_tool_call.parameters) + ), ) tool_calls.append(tool_call) elif isinstance(chunk, StreamedChatResponse_StreamEnd): @@ -496,8 +510,9 @@ class CohereLargeLanguageModel(LargeLanguageModel): yield final_response(full_assistant_content, tool_calls, index, chunk.finish_reason) index += 1 - def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \ - -> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]: + def _convert_prompt_messages_to_message_and_chat_histories( + self, prompt_messages: list[PromptMessage] + ) -> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]: """ Convert prompt messages to message and chat histories :param prompt_messages: prompt messages @@ -510,13 +525,14 @@ class CohereLargeLanguageModel(LargeLanguageModel): prompt_message = cast(AssistantPromptMessage, prompt_message) if prompt_message.tool_calls: for tool_call in prompt_message.tool_calls: - latest_tool_call_n_outputs.append(ChatStreamRequestToolResultsItem( - call=ToolCall( - name=tool_call.function.name, - parameters=json.loads(tool_call.function.arguments) - ), - outputs=[] - )) + latest_tool_call_n_outputs.append( + ChatStreamRequestToolResultsItem( + call=ToolCall( + name=tool_call.function.name, parameters=json.loads(tool_call.function.arguments) + ), + outputs=[], + ) + ) else: cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message) if cohere_prompt_message: @@ -529,12 +545,9 @@ class CohereLargeLanguageModel(LargeLanguageModel): if tool_call_n_outputs.call.name == prompt_message.tool_call_id: latest_tool_call_n_outputs[i] = ChatStreamRequestToolResultsItem( call=ToolCall( - name=tool_call_n_outputs.call.name, - parameters=tool_call_n_outputs.call.parameters + name=tool_call_n_outputs.call.name, parameters=tool_call_n_outputs.call.parameters ), - outputs=[{ - "result": prompt_message.content - }] + outputs=[{"result": prompt_message.content}], ) break i += 1 @@ -556,7 +569,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): latest_message = chat_histories.pop() message = latest_message.message else: - raise ValueError('Prompt messages is empty') + raise ValueError("Prompt messages is empty") return message, chat_histories, latest_tool_call_n_outputs @@ -569,7 +582,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): if isinstance(message.content, str): chat_message = ChatMessage(role="USER", message=message.content) else: - sub_message_text = '' + sub_message_text = "" for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) @@ -597,8 +610,8 @@ class CohereLargeLanguageModel(LargeLanguageModel): """ cohere_tools = [] for tool in tools: - properties = tool.parameters['properties'] - required_properties = tool.parameters['required'] + properties = tool.parameters["properties"] + required_properties = tool.parameters["required"] parameter_definitions = {} for p_key, p_val in properties.items(): @@ -606,21 +619,16 @@ class CohereLargeLanguageModel(LargeLanguageModel): if p_key in required_properties: required = True - desc = p_val['description'] - if 'enum' in p_val: - desc += (f"; Only accepts one of the following predefined options: " - f"[{', '.join(p_val['enum'])}]") + desc = p_val["description"] + if "enum" in p_val: + desc += f"; Only accepts one of the following predefined options: " f"[{', '.join(p_val['enum'])}]" parameter_definitions[p_key] = ToolParameterDefinitionsValue( - description=desc, - type=p_val['type'], - required=required + description=desc, type=p_val["type"], required=required ) cohere_tool = Tool( - name=tool.name, - description=tool.description, - parameter_definitions=parameter_definitions + name=tool.name, description=tool.description, parameter_definitions=parameter_definitions ) cohere_tools.append(cohere_tool) @@ -637,12 +645,9 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: number of tokens """ # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) - response = client.tokenize( - text=text, - model=model - ) + response = client.tokenize(text=text, model=model) return len(response.tokens) @@ -658,30 +663,30 @@ class CohereLargeLanguageModel(LargeLanguageModel): real_model = model if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL: - real_model = model.removesuffix('-chat') + real_model = model.removesuffix("-chat") return self._num_tokens_from_string(real_model, credentials, message_str) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - Cohere supports fine-tuning of their models. This method returns the schema of the base model - but renamed to the fine-tuned model name. + Cohere supports fine-tuning of their models. This method returns the schema of the base model + but renamed to the fine-tuned model name. - :param model: model name - :param credentials: credentials + :param model: model name + :param credentials: credentials - :return: model schema + :return: model schema """ # get model schema models = self.predefined_models() model_map = {model.model: model for model in models} - mode = credentials.get('mode') + mode = credentials.get("mode") - if mode == 'chat': - base_model_schema = model_map['command-light-chat'] + if mode == "chat": + base_model_schema = model_map["command-light-chat"] else: - base_model_schema = model_map['command-light'] + base_model_schema = model_map["command-light"] base_model_schema = cast(AIModelEntity, base_model_schema) @@ -691,16 +696,13 @@ class CohereLargeLanguageModel(LargeLanguageModel): entity = AIModelEntity( model=model, - label=I18nObject( - zh_Hans=model, - en_US=model - ), + label=I18nObject(zh_Hans=model, en_US=model), model_type=ModelType.LLM, features=list(base_model_schema_features), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties=dict(base_model_schema_model_properties.items()), parameter_rules=list(base_model_schema_parameters_rules), - pricing=base_model_schema.pricing + pricing=base_model_schema.pricing, ) return entity @@ -716,22 +718,16 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - cohere.errors.service_unavailable_error.ServiceUnavailableError - ], - InvokeServerUnavailableError: [ - cohere.errors.internal_server_error.InternalServerError - ], - InvokeRateLimitError: [ - cohere.errors.too_many_requests_error.TooManyRequestsError - ], + InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError], + InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError], + InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError], InvokeAuthorizationError: [ cohere.errors.unauthorized_error.UnauthorizedError, - cohere.errors.forbidden_error.ForbiddenError + cohere.errors.forbidden_error.ForbiddenError, ], InvokeBadRequestError: [ cohere.core.api_error.ApiError, cohere.errors.bad_request_error.BadRequestError, cohere.errors.not_found_error.NotFoundError, - ] + ], } diff --git a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py index d2fdb30c6f..aba8fedbc0 100644 --- a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py @@ -21,10 +21,16 @@ class CohereRerankModel(RerankModel): Model class for Cohere rerank model. """ - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -38,20 +44,17 @@ class CohereRerankModel(RerankModel): :return: rerank result """ if len(docs) == 0: - return RerankResult( - model=model, - docs=docs - ) + return RerankResult(model=model, docs=docs) # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) response = client.rerank( query=query, documents=docs, model=model, top_n=top_n, return_documents=True, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) rerank_documents = [] @@ -70,10 +73,7 @@ class CohereRerankModel(RerankModel): else: rerank_documents.append(rerank_document) - return RerankResult( - model=model, - docs=rerank_documents - ) + return RerankResult(model=model, docs=rerank_documents) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -94,7 +94,7 @@ class CohereRerankModel(RerankModel): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -110,22 +110,16 @@ class CohereRerankModel(RerankModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - cohere.errors.service_unavailable_error.ServiceUnavailableError - ], - InvokeServerUnavailableError: [ - cohere.errors.internal_server_error.InternalServerError - ], - InvokeRateLimitError: [ - cohere.errors.too_many_requests_error.TooManyRequestsError - ], + InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError], + InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError], + InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError], InvokeAuthorizationError: [ cohere.errors.unauthorized_error.UnauthorizedError, - cohere.errors.forbidden_error.ForbiddenError + cohere.errors.forbidden_error.ForbiddenError, ], InvokeBadRequestError: [ cohere.core.api_error.ApiError, cohere.errors.bad_request_error.BadRequestError, cohere.errors.not_found_error.NotFoundError, - ] + ], } diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py index 0540fb740f..a1c5e98118 100644 --- a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py @@ -24,9 +24,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): Model class for Cohere text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -46,14 +46,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): used_tokens = 0 for i, text in enumerate(texts): - tokenize_response = self._tokenize( - model=model, - credentials=credentials, - text=text - ) + tokenize_response = self._tokenize(model=model, credentials=credentials, text=text) for j in range(0, len(tokenize_response), context_size): - tokens += [tokenize_response[j: j + context_size]] + tokens += [tokenize_response[j : j + context_size]] indices += [i] batched_embeddings = [] @@ -62,9 +58,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): for i in _iter: # call embedding model embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - credentials=credentials, - texts=["".join(token) for token in tokens[i: i + max_chunks]] + model=model, credentials=credentials, texts=["".join(token) for token in tokens[i : i + max_chunks]] ) used_tokens += embedding_used_tokens @@ -80,9 +74,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): _result = results[i] if len(_result) == 0: embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - credentials=credentials, - texts=[" "] + model=model, credentials=credentials, texts=[" "] ) used_tokens += embedding_used_tokens @@ -92,17 +84,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): embeddings[i] = (average / np.linalg.norm(average)).tolist() # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -116,14 +100,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): if len(texts) == 0: return 0 - full_text = ' '.join(texts) + full_text = " ".join(texts) try: - response = self._tokenize( - model=model, - credentials=credentials, - text=full_text - ) + response = self._tokenize(model=model, credentials=credentials, text=full_text) except Exception as e: raise self._transform_invoke_error(e) @@ -141,14 +121,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): return [] # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) - response = client.tokenize( - text=text, - model=model, - offline=False, - request_options=RequestOptions(max_retries=0) - ) + response = client.tokenize(text=text, model=model, offline=False, request_options=RequestOptions(max_retries=0)) return response.token_strings @@ -162,11 +137,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): """ try: # call embedding model - self._embedding_invoke( - model=model, - credentials=credentials, - texts=['ping'] - ) + self._embedding_invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -180,14 +151,14 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): :return: embeddings and used tokens """ # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) # call embedding model response = client.embed( texts=texts, model=model, - input_type='search_document' if len(texts) > 1 else 'search_query', - request_options=RequestOptions(max_retries=1) + input_type="search_document" if len(texts) > 1 else "search_query", + request_options=RequestOptions(max_retries=1), ) return response.embeddings, int(response.meta.billed_units.input_tokens) @@ -203,10 +174,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -217,7 +185,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -233,22 +201,16 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - cohere.errors.service_unavailable_error.ServiceUnavailableError - ], - InvokeServerUnavailableError: [ - cohere.errors.internal_server_error.InternalServerError - ], - InvokeRateLimitError: [ - cohere.errors.too_many_requests_error.TooManyRequestsError - ], + InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError], + InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError], + InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError], InvokeAuthorizationError: [ cohere.errors.unauthorized_error.UnauthorizedError, - cohere.errors.forbidden_error.ForbiddenError + cohere.errors.forbidden_error.ForbiddenError, ], InvokeBadRequestError: [ cohere.core.api_error.ApiError, cohere.errors.bad_request_error.BadRequestError, cohere.errors.not_found_error.NotFoundError, - ] + ], } diff --git a/api/core/model_runtime/model_providers/deepseek/deepseek.py b/api/core/model_runtime/model_providers/deepseek/deepseek.py index d61fd4ddc8..10feef8972 100644 --- a/api/core/model_runtime/model_providers/deepseek/deepseek.py +++ b/api/core/model_runtime/model_providers/deepseek/deepseek.py @@ -7,9 +7,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) - class DeepSeekProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -22,12 +20,9 @@ class DeepSeekProvider(ModelProvider): # Use `deepseek-chat` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='deepseek-chat', - credentials=credentials - ) + model_instance.validate_credentials(model="deepseek-chat", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/deepseek/llm/llm.py b/api/core/model_runtime/model_providers/deepseek/llm/llm.py index bdb3823b60..6d0a3ee262 100644 --- a/api/core/model_runtime/model_providers/deepseek/llm/llm.py +++ b/api/core/model_runtime/model_providers/deepseek/llm/llm.py @@ -13,12 +13,17 @@ from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguag class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -27,10 +32,8 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel): self._add_custom_parameters(credentials) super().validate_credentials(model, credentials) - # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -48,8 +51,9 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel): return num_tokens # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/ @@ -69,10 +73,10 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel): # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -103,11 +107,10 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel): @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['openai_api_key']=credentials['api_key'] - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - credentials['openai_api_base']='https://api.deepseek.com' + credentials["mode"] = "chat" + credentials["openai_api_key"] = credentials["api_key"] + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + credentials["openai_api_base"] = "https://api.deepseek.com" else: - parsed_url = urlparse(credentials['endpoint_url']) - credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}" - + parsed_url = urlparse(credentials["endpoint_url"]) + credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" diff --git a/api/core/model_runtime/model_providers/fishaudio/__init__.py b/api/core/model_runtime/model_providers/fishaudio/__init__.py index 5f282702bb..e69de29bb2 100644 --- a/api/core/model_runtime/model_providers/fishaudio/__init__.py +++ b/api/core/model_runtime/model_providers/fishaudio/__init__.py @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/fishaudio/fishaudio.py b/api/core/model_runtime/model_providers/fishaudio/fishaudio.py index 9f80996d9d..3bc4b533e0 100644 --- a/api/core/model_runtime/model_providers/fishaudio/fishaudio.py +++ b/api/core/model_runtime/model_providers/fishaudio/fishaudio.py @@ -1,4 +1,4 @@ -import logging +import logging from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError @@ -18,11 +18,9 @@ class FishAudioProvider(ModelProvider): """ try: model_instance = self.get_model_instance(ModelType.TTS) - model_instance.validate_credentials( - credentials=credentials - ) + model_instance.validate_credentials(credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/fishaudio/tts/tts.py b/api/core/model_runtime/model_providers/fishaudio/tts/tts.py index 5b673ce186..895a7a914c 100644 --- a/api/core/model_runtime/model_providers/fishaudio/tts/tts.py +++ b/api/core/model_runtime/model_providers/fishaudio/tts/tts.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional import httpx @@ -12,9 +12,7 @@ class FishAudioText2SpeechModel(TTSModel): Model class for Fish.audio Text to Speech model. """ - def get_tts_model_voices( - self, model: str, credentials: dict, language: Optional[str] = None - ) -> list: + def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: api_base = credentials.get("api_base", "https://api.fish.audio") api_key = credentials.get("api_key") use_public_models = credentials.get("use_public_models", "false") == "true" @@ -68,9 +66,7 @@ class FishAudioText2SpeechModel(TTSModel): voice=voice, ) - def validate_credentials( - self, credentials: dict, user: Optional[str] = None - ) -> None: + def validate_credentials(self, credentials: dict, user: Optional[str] = None) -> None: """ Validate credentials for text2speech model @@ -91,9 +87,7 @@ class FishAudioText2SpeechModel(TTSModel): except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming( - self, model: str, credentials: dict, content_text: str, voice: str - ) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ Invoke streaming text2speech model :param model: model name @@ -106,12 +100,10 @@ class FishAudioText2SpeechModel(TTSModel): try: word_limit = self._get_model_word_limit(model, credentials) if len(content_text) > word_limit: - sentences = self._split_text_into_sentences( - content_text, max_length=word_limit - ) + sentences = self._split_text_into_sentences(content_text, max_length=word_limit) else: sentences = [content_text.strip()] - + for i in range(len(sentences)): yield from self._tts_invoke_streaming_sentence( credentials=credentials, content_text=sentences[i], voice=voice @@ -120,9 +112,7 @@ class FishAudioText2SpeechModel(TTSModel): except Exception as ex: raise InvokeBadRequestError(str(ex)) - def _tts_invoke_streaming_sentence( - self, credentials: dict, content_text: str, voice: Optional[str] = None - ) -> any: + def _tts_invoke_streaming_sentence(self, credentials: dict, content_text: str, voice: Optional[str] = None) -> any: """ Invoke streaming text2speech model @@ -141,20 +131,14 @@ class FishAudioText2SpeechModel(TTSModel): with httpx.stream( "POST", api_url + "/v1/tts", - json={ - "text": content_text, - "reference_id": voice, - "latency": latency - }, + json={"text": content_text, "reference_id": voice, "latency": latency}, headers={ "Authorization": f"Bearer {api_key}", }, timeout=None, ) as response: if response.status_code != 200: - raise InvokeBadRequestError( - f"Error: {response.status_code} - {response.text}" - ) + raise InvokeBadRequestError(f"Error: {response.status_code} - {response.text}") yield from response.iter_bytes() @property diff --git a/api/core/model_runtime/model_providers/google/google.py b/api/core/model_runtime/model_providers/google/google.py index ba25c74e71..70f56a8337 100644 --- a/api/core/model_runtime/model_providers/google/google.py +++ b/api/core/model_runtime/model_providers/google/google.py @@ -20,12 +20,9 @@ class GoogleProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `gemini-pro` model for validate, - model_instance.validate_credentials( - model='gemini-pro', - credentials=credentials - ) + model_instance.validate_credentials(model="gemini-pro", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index 11f9f32f96..274ff02095 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -49,12 +49,17 @@ if you are not sure about the structure. class GoogleLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -70,9 +75,14 @@ class GoogleLargeLanguageModel(LargeLanguageModel): """ # invoke model return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -85,7 +95,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): prompt = self._convert_messages_to_prompt(prompt_messages) return self._get_num_tokens_by_gpt2(prompt) - + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Google model @@ -95,13 +105,10 @@ class GoogleLargeLanguageModel(LargeLanguageModel): """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() - + def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool: """ Convert tool messages to glm tools @@ -117,14 +124,16 @@ class GoogleLargeLanguageModel(LargeLanguageModel): type=glm.Type.OBJECT, properties={ key: { - 'type_': value.get('type', 'string').upper(), - 'description': value.get('description', ''), - 'enum': value.get('enum', []) - } for key, value in tool.parameters.get('properties', {}).items() + "type_": value.get("type", "string").upper(), + "description": value.get("description", ""), + "enum": value.get("enum", []), + } + for key, value in tool.parameters.get("properties", {}).items() }, - required=tool.parameters.get('required', []) + required=tool.parameters.get("required", []), ), - ) for tool in tools + ) + for tool in tools ] ) @@ -136,20 +145,25 @@ class GoogleLargeLanguageModel(LargeLanguageModel): :param credentials: model credentials :return: """ - + try: ping_message = SystemPromptMessage(content="ping") self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5}) - + except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None - ) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -163,14 +177,12 @@ class GoogleLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ config_kwargs = model_parameters.copy() - config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None) + config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None) if stop: config_kwargs["stop_sequences"] = stop - google_model = genai.GenerativeModel( - model_name=model - ) + google_model = genai.GenerativeModel(model_name=model) history = [] @@ -180,7 +192,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): content = self._format_message_to_glm_content(last_msg) history.append(content) else: - for msg in prompt_messages: # makes message roles strictly alternating + for msg in prompt_messages: # makes message roles strictly alternating content = self._format_message_to_glm_content(msg) if history and history[-1]["role"] == content["role"]: history[-1]["parts"].extend(content["parts"]) @@ -194,7 +206,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): google_model._client = new_custom_client - safety_settings={ + safety_settings = { HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, @@ -203,13 +215,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel): response = google_model.generate_content( contents=history, - generation_config=genai.types.GenerationConfig( - **config_kwargs - ), + generation_config=genai.types.GenerationConfig(**config_kwargs), stream=stream, safety_settings=safety_settings, tools=self._convert_tools_to_glm_tool(tools) if tools else None, - request_options={"timeout": 600} + request_options={"timeout": 600}, ) if stream: @@ -217,8 +227,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel): return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: GenerateContentResponse, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -229,9 +240,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): :return: llm response """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=response.text - ) + assistant_prompt_message = AssistantPromptMessage(content=response.text) # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -250,8 +259,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel): return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: GenerateContentResponse, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -264,9 +274,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): index = -1 for chunk in response: for part in chunk.parts: - assistant_prompt_message = AssistantPromptMessage( - content='' - ) + assistant_prompt_message = AssistantPromptMessage(content="") if part.text: assistant_prompt_message.content += part.text @@ -275,36 +283,31 @@ class GoogleLargeLanguageModel(LargeLanguageModel): assistant_prompt_message.tool_calls = [ AssistantPromptMessage.ToolCall( id=part.function_call.name, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=part.function_call.name, - arguments=json.dumps(dict(part.function_call.args.items())) - ) + arguments=json.dumps(dict(part.function_call.args.items())), + ), ) ] index += 1 - + if not response._done: - # transform assistant message to prompt message yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) else: - # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -312,8 +315,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel): index=index, message=assistant_prompt_message, finish_reason=str(chunk.candidates[0].finish_reason), - usage=usage - ) + usage=usage, + ), ) def _convert_one_message_to_text(self, message: PromptMessage) -> str: @@ -328,9 +331,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): content = message.content if isinstance(content, list): - content = "".join( - c.data for c in content if c.type != PromptMessageContentType.IMAGE - ) + content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE) if isinstance(message, UserPromptMessage): message_text = f"{human_prompt} {content}" @@ -353,65 +354,61 @@ class GoogleLargeLanguageModel(LargeLanguageModel): :return: glm Content representation of message """ if isinstance(message, UserPromptMessage): - glm_content = { - "role": "user", - "parts": [] - } - if (isinstance(message.content, str)): - glm_content['parts'].append(to_part(message.content)) + glm_content = {"role": "user", "parts": []} + if isinstance(message.content, str): + glm_content["parts"].append(to_part(message.content)) else: for c in message.content: if c.type == PromptMessageContentType.TEXT: - glm_content['parts'].append(to_part(c.data)) + glm_content["parts"].append(to_part(c.data)) elif c.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, c) if message_content.data.startswith("data:"): - metadata, base64_data = c.data.split(',', 1) - mime_type = metadata.split(';', 1)[0].split(':')[1] + metadata, base64_data = c.data.split(",", 1) + mime_type = metadata.split(";", 1)[0].split(":")[1] else: # fetch image data from url try: image_content = requests.get(message_content.data).content with Image.open(io.BytesIO(image_content)) as img: mime_type = f"image/{img.format.lower()}" - base64_data = base64.b64encode(image_content).decode('utf-8') + base64_data = base64.b64encode(image_content).decode("utf-8") except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") - blob = {"inline_data":{"mime_type":mime_type,"data":base64_data}} - glm_content['parts'].append(blob) + blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}} + glm_content["parts"].append(blob) return glm_content elif isinstance(message, AssistantPromptMessage): - glm_content = { - "role": "model", - "parts": [] - } + glm_content = {"role": "model", "parts": []} if message.content: - glm_content['parts'].append(to_part(message.content)) + glm_content["parts"].append(to_part(message.content)) if message.tool_calls: - glm_content["parts"].append(to_part(glm.FunctionCall( - name=message.tool_calls[0].function.name, - args=json.loads(message.tool_calls[0].function.arguments), - ))) + glm_content["parts"].append( + to_part( + glm.FunctionCall( + name=message.tool_calls[0].function.name, + args=json.loads(message.tool_calls[0].function.arguments), + ) + ) + ) return glm_content elif isinstance(message, SystemPromptMessage): - return { - "role": "user", - "parts": [to_part(message.content)] - } + return {"role": "user", "parts": [to_part(message.content)]} elif isinstance(message, ToolPromptMessage): return { "role": "function", - "parts": [glm.Part(function_response=glm.FunctionResponse( - name=message.name, - response={ - "response": message.content - } - ))] + "parts": [ + glm.Part( + function_response=glm.FunctionResponse( + name=message.name, response={"response": message.content} + ) + ) + ], } else: raise ValueError(f"Got unknown type {message}") - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -423,25 +420,20 @@ class GoogleLargeLanguageModel(LargeLanguageModel): :return: Invoke emd = genai.GenerativeModel(model) error mapping """ return { - InvokeConnectionError: [ - exceptions.RetryError - ], + InvokeConnectionError: [exceptions.RetryError], InvokeServerUnavailableError: [ exceptions.ServiceUnavailable, exceptions.InternalServerError, exceptions.BadGateway, exceptions.GatewayTimeout, - exceptions.DeadlineExceeded - ], - InvokeRateLimitError: [ - exceptions.ResourceExhausted, - exceptions.TooManyRequests + exceptions.DeadlineExceeded, ], + InvokeRateLimitError: [exceptions.ResourceExhausted, exceptions.TooManyRequests], InvokeAuthorizationError: [ exceptions.Unauthenticated, exceptions.PermissionDenied, exceptions.Unauthenticated, - exceptions.Forbidden + exceptions.Forbidden, ], InvokeBadRequestError: [ exceptions.BadRequest, @@ -457,5 +449,5 @@ class GoogleLargeLanguageModel(LargeLanguageModel): exceptions.PreconditionFailed, exceptions.RequestRangeNotSatisfiable, exceptions.Cancelled, - ] + ], } diff --git a/api/core/model_runtime/model_providers/groq/groq.py b/api/core/model_runtime/model_providers/groq/groq.py index b3f37b3967..d0d5ff68f8 100644 --- a/api/core/model_runtime/model_providers/groq/groq.py +++ b/api/core/model_runtime/model_providers/groq/groq.py @@ -6,8 +6,8 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) -class GroqProvider(ModelProvider): +class GroqProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -18,12 +18,9 @@ class GroqProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='llama3-8b-8192', - credentials=credentials - ) + model_instance.validate_credentials(model="llama3-8b-8192", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/groq/llm/llm.py b/api/core/model_runtime/model_providers/groq/llm/llm.py index 915f7a4e1a..352a7b519e 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llm.py +++ b/api/core/model_runtime/model_providers/groq/llm/llm.py @@ -7,11 +7,17 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -21,6 +27,5 @@ class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel): @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.groq.com/openai/v1' - + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.groq.com/openai/v1" diff --git a/api/core/model_runtime/model_providers/huggingface_hub/_common.py b/api/core/model_runtime/model_providers/huggingface_hub/_common.py index dd8ae526e6..3c4020b6ee 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/_common.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/_common.py @@ -4,12 +4,6 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError class _CommonHuggingfaceHub: - @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: - return { - InvokeBadRequestError: [ - HfHubHTTPError, - BadRequestError - ] - } + return {InvokeBadRequestError: [HfHubHTTPError, BadRequestError]} diff --git a/api/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py b/api/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py index 15e2a4fed4..54d2a2bf39 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class HuggingfaceHubProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py index f43a8aedaf..10c6d553f3 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py @@ -29,16 +29,23 @@ from core.model_runtime.model_providers.huggingface_hub._common import _CommonHu class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + client = InferenceClient(token=credentials["huggingfacehub_api_token"]) - client = InferenceClient(token=credentials['huggingfacehub_api_token']) + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + model = credentials["huggingfacehub_endpoint_url"] - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - model = credentials['huggingfacehub_endpoint_url'] - - if 'baichuan' in model.lower(): + if "baichuan" in model.lower(): stream = False response = client.text_generation( @@ -47,98 +54,100 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel stream=stream, model=model, stop_sequences=stop, - **model_parameters) + **model_parameters, + ) if stream: return self._handle_generate_stream_response(model, credentials, prompt_messages, response) return self._handle_generate_response(model, credentials, prompt_messages, response) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: prompt = self._convert_messages_to_prompt(prompt_messages) return self._get_num_tokens_by_gpt2(prompt) def validate_credentials(self, model: str, credentials: dict) -> None: try: - if 'huggingfacehub_api_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type must be provided.') + if "huggingfacehub_api_type" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.") - if credentials['huggingfacehub_api_type'] not in ('inference_endpoints', 'hosted_inference_api'): - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type is invalid.') + if credentials["huggingfacehub_api_type"] not in ("inference_endpoints", "hosted_inference_api"): + raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.") - if 'huggingfacehub_api_token' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Access Token must be provided.') + if "huggingfacehub_api_token" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Access Token must be provided.") - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - if 'huggingfacehub_endpoint_url' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint URL must be provided.') + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + if "huggingfacehub_endpoint_url" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint URL must be provided.") - if 'task_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Task Type must be provided.') - elif credentials['huggingfacehub_api_type'] == 'hosted_inference_api': - credentials['task_type'] = self._get_hosted_model_task_type(credentials['huggingfacehub_api_token'], - model) + if "task_type" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Task Type must be provided.") + elif credentials["huggingfacehub_api_type"] == "hosted_inference_api": + credentials["task_type"] = self._get_hosted_model_task_type( + credentials["huggingfacehub_api_token"], model + ) - if credentials['task_type'] not in ("text2text-generation", "text-generation"): - raise CredentialsValidateFailedError('Huggingface Hub Task Type must be one of text2text-generation, ' - 'text-generation.') + if credentials["task_type"] not in ("text2text-generation", "text-generation"): + raise CredentialsValidateFailedError( + "Huggingface Hub Task Type must be one of text2text-generation, " "text-generation." + ) - client = InferenceClient(token=credentials['huggingfacehub_api_token']) + client = InferenceClient(token=credentials["huggingfacehub_api_token"]) - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - model = credentials['huggingfacehub_endpoint_url'] + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + model = credentials["huggingfacehub_endpoint_url"] try: - client.text_generation( - prompt='Who are you?', - stream=True, - model=model) + client.text_generation(prompt="Who are you?", stream=True, model=model) except BadRequestError as e: - raise CredentialsValidateFailedError('Only available for models running on with the `text-generation-inference`. ' - 'To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.') + raise CredentialsValidateFailedError( + "Only available for models running on with the `text-generation-inference`. " + "To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference." + ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - model_properties={ - ModelPropertyKey.MODE: LLMMode.COMPLETION.value - }, - parameter_rules=self._get_customizable_model_parameter_rules() + model_properties={ModelPropertyKey.MODE: LLMMode.COMPLETION.value}, + parameter_rules=self._get_customizable_model_parameter_rules(), ) return entity @staticmethod def _get_customizable_model_parameter_rules() -> list[ParameterRule]: - temperature_rule_dict = PARAMETER_RULE_TEMPLATE.get( - DefaultParameterName.TEMPERATURE).copy() - temperature_rule_dict['name'] = 'temperature' + temperature_rule_dict = PARAMETER_RULE_TEMPLATE.get(DefaultParameterName.TEMPERATURE).copy() + temperature_rule_dict["name"] = "temperature" temperature_rule = ParameterRule(**temperature_rule_dict) temperature_rule.default = 0.5 top_p_rule_dict = PARAMETER_RULE_TEMPLATE.get(DefaultParameterName.TOP_P).copy() - top_p_rule_dict['name'] = 'top_p' + top_p_rule_dict["name"] = "top_p" top_p_rule = ParameterRule(**top_p_rule_dict) top_p_rule.default = 0.5 top_k_rule = ParameterRule( - name='top_k', + name="top_k", label={ - 'en_US': 'Top K', - 'zh_Hans': 'Top K', + "en_US": "Top K", + "zh_Hans": "Top K", }, - type='int', + type="int", help={ - 'en_US': 'The number of highest probability vocabulary tokens to keep for top-k-filtering.', - 'zh_Hans': '保留的最高概率词汇标记的数量。', + "en_US": "The number of highest probability vocabulary tokens to keep for top-k-filtering.", + "zh_Hans": "保留的最高概率词汇标记的数量。", }, required=False, default=2, @@ -148,15 +157,15 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel ) max_new_tokens = ParameterRule( - name='max_new_tokens', + name="max_new_tokens", label={ - 'en_US': 'Max New Tokens', - 'zh_Hans': '最大新标记', + "en_US": "Max New Tokens", + "zh_Hans": "最大新标记", }, - type='int', + type="int", help={ - 'en_US': 'Maximum number of generated tokens.', - 'zh_Hans': '生成的标记的最大数量。', + "en_US": "Maximum number of generated tokens.", + "zh_Hans": "生成的标记的最大数量。", }, required=False, default=20, @@ -166,30 +175,30 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel ) seed = ParameterRule( - name='seed', + name="seed", label={ - 'en_US': 'Random sampling seed', - 'zh_Hans': '随机采样种子', + "en_US": "Random sampling seed", + "zh_Hans": "随机采样种子", }, - type='int', + type="int", help={ - 'en_US': 'Random sampling seed.', - 'zh_Hans': '随机采样种子。', + "en_US": "Random sampling seed.", + "zh_Hans": "随机采样种子。", }, required=False, precision=0, ) repetition_penalty = ParameterRule( - name='repetition_penalty', + name="repetition_penalty", label={ - 'en_US': 'Repetition Penalty', - 'zh_Hans': '重复惩罚', + "en_US": "Repetition Penalty", + "zh_Hans": "重复惩罚", }, - type='float', + type="float", help={ - 'en_US': 'The parameter for repetition penalty. 1.0 means no penalty.', - 'zh_Hans': '重复惩罚的参数。1.0 表示没有惩罚。', + "en_US": "The parameter for repetition penalty. 1.0 means no penalty.", + "zh_Hans": "重复惩罚的参数。1.0 表示没有惩罚。", }, required=False, precision=1, @@ -197,11 +206,9 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel return [temperature_rule, top_k_rule, top_p_rule, max_new_tokens, seed, repetition_penalty] - def _handle_generate_stream_response(self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - response: Generator) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: Generator + ) -> Generator: index = -1 for chunk in response: # skip special tokens @@ -210,9 +217,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel index += 1 - assistant_prompt_message = AssistantPromptMessage( - content=chunk.token.text - ) + assistant_prompt_message = AssistantPromptMessage(content=chunk.token.text) if chunk.details: prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -240,15 +245,15 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel ), ) - def _handle_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any + ) -> LLMResult: if isinstance(response, str): content = response else: content = response.generated_text - assistant_prompt_message = AssistantPromptMessage( - content=content - ) + assistant_prompt_message = AssistantPromptMessage(content=content) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) @@ -270,15 +275,14 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel try: if not model_info: - raise ValueError(f'Model {model_name} not found.') + raise ValueError(f"Model {model_name} not found.") - if 'inference' in model_info.cardData and not model_info.cardData['inference']: - raise ValueError(f'Inference API has been turned off for this model {model_name}.') + if "inference" in model_info.cardData and not model_info.cardData["inference"]: + raise ValueError(f"Inference API has been turned off for this model {model_name}.") valid_tasks = ("text2text-generation", "text-generation") if model_info.pipeline_tag not in valid_tasks: - raise ValueError(f"Model {model_name} is not a valid task, " - f"must be one of {valid_tasks}.") + raise ValueError(f"Model {model_name} is not a valid task, " f"must be one of {valid_tasks}.") except Exception as e: raise CredentialsValidateFailedError(f"{str(e)}") @@ -287,10 +291,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() diff --git a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py index 0f0c166f3e..cb7a30bbe5 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py @@ -13,40 +13,30 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub -HUGGINGFACE_ENDPOINT_API = 'https://api.endpoints.huggingface.cloud/v2/endpoint/' +HUGGINGFACE_ENDPOINT_API = "https://api.endpoints.huggingface.cloud/v2/endpoint/" class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel): - - def _invoke(self, model: str, credentials: dict, texts: list[str], - user: Optional[str] = None) -> TextEmbeddingResult: - client = InferenceClient(token=credentials['huggingfacehub_api_token']) + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: + client = InferenceClient(token=credentials["huggingfacehub_api_token"]) execute_model = model - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - execute_model = credentials['huggingfacehub_endpoint_url'] + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + execute_model = credentials["huggingfacehub_endpoint_url"] output = client.post( - json={ - "inputs": texts, - "options": { - "wait_for_model": False, - "use_cache": False - } - }, - model=execute_model) + json={"inputs": texts, "options": {"wait_for_model": False, "use_cache": False}}, model=execute_model + ) embeddings = json.loads(output.decode()) tokens = self.get_num_tokens(model, credentials, texts) usage = self._calc_response_usage(model, credentials, tokens) - return TextEmbeddingResult( - embeddings=self._mean_pooling(embeddings), - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=self._mean_pooling(embeddings), usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: num_tokens = 0 @@ -56,52 +46,48 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel def validate_credentials(self, model: str, credentials: dict) -> None: try: - if 'huggingfacehub_api_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type must be provided.') + if "huggingfacehub_api_type" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.") - if 'huggingfacehub_api_token' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub API Token must be provided.') + if "huggingfacehub_api_token" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub API Token must be provided.") - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - if 'huggingface_namespace' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub User Name / Organization Name must be provided.') + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + if "huggingface_namespace" not in credentials: + raise CredentialsValidateFailedError( + "Huggingface Hub User Name / Organization Name must be provided." + ) - if 'huggingfacehub_endpoint_url' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint URL must be provided.') + if "huggingfacehub_endpoint_url" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint URL must be provided.") - if 'task_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Task Type must be provided.') + if "task_type" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Task Type must be provided.") - if credentials['task_type'] != 'feature-extraction': - raise CredentialsValidateFailedError('Huggingface Hub Task Type is invalid.') + if credentials["task_type"] != "feature-extraction": + raise CredentialsValidateFailedError("Huggingface Hub Task Type is invalid.") self._check_endpoint_url_model_repository_name(credentials, model) - model = credentials['huggingfacehub_endpoint_url'] + model = credentials["huggingfacehub_endpoint_url"] - elif credentials['huggingfacehub_api_type'] == 'hosted_inference_api': - self._check_hosted_model_task_type(credentials['huggingfacehub_api_token'], - model) + elif credentials["huggingfacehub_api_type"] == "hosted_inference_api": + self._check_hosted_model_task_type(credentials["huggingfacehub_api_token"], model) else: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type is invalid.') + raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.") - client = InferenceClient(token=credentials['huggingfacehub_api_token']) - client.feature_extraction(text='hello world', model=model) + client = InferenceClient(token=credentials["huggingfacehub_api_token"]) + client.feature_extraction(text="hello world", model=model) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, - model_properties={ - 'context_size': 10000, - 'max_chunks': 1 - } + model_properties={"context_size": 10000, "max_chunks": 1}, ) return entity @@ -128,24 +114,20 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel try: if not model_info: - raise ValueError(f'Model {model_name} not found.') + raise ValueError(f"Model {model_name} not found.") - if 'inference' in model_info.cardData and not model_info.cardData['inference']: - raise ValueError(f'Inference API has been turned off for this model {model_name}.') + if "inference" in model_info.cardData and not model_info.cardData["inference"]: + raise ValueError(f"Inference API has been turned off for this model {model_name}.") valid_tasks = "feature-extraction" if model_info.pipeline_tag not in valid_tasks: - raise ValueError(f"Model {model_name} is not a valid task, " - f"must be one of {valid_tasks}.") + raise ValueError(f"Model {model_name} is not a valid task, " f"must be one of {valid_tasks}.") except Exception as e: raise CredentialsValidateFailedError(f"{str(e)}") def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -156,7 +138,7 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -166,25 +148,26 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel try: url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}' headers = { - 'Authorization': f'Bearer {credentials["huggingfacehub_api_token"]}', - 'Content-Type': 'application/json' + "Authorization": f'Bearer {credentials["huggingfacehub_api_token"]}', + "Content-Type": "application/json", } response = requests.get(url=url, headers=headers) if response.status_code != 200: - raise ValueError('User Name or Organization Name is invalid.') + raise ValueError("User Name or Organization Name is invalid.") - model_repository_name = '' + model_repository_name = "" for item in response.json().get("items", []): - if item.get("status", {}).get("url") == credentials['huggingfacehub_endpoint_url']: + if item.get("status", {}).get("url") == credentials["huggingfacehub_endpoint_url"]: model_repository_name = item.get("model", {}).get("repository") break if model_repository_name != model_name: raise ValueError( - f'Model Name {model_name} is invalid. Please check it on the inference endpoints console.') + f"Model Name {model_name} is invalid. Please check it on the inference endpoints console." + ) except Exception as e: raise ValueError(str(e)) diff --git a/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py index 9454466250..97d7e28dc6 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class HuggingfaceTeiProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py index 34013426de..c128c35f6d 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py @@ -47,29 +47,29 @@ class HuggingfaceTeiRerankModel(RerankModel): """ if len(docs) == 0: return RerankResult(model=model, docs=[]) - server_url = credentials['server_url'] + server_url = credentials["server_url"] - if server_url.endswith('/'): + if server_url.endswith("/"): server_url = server_url[:-1] try: results = TeiHelper.invoke_rerank(server_url, query, docs) rerank_documents = [] - for result in results: + for result in results: rerank_document = RerankDocument( - index=result['index'], - text=result['text'], - score=result['score'], + index=result["index"], + text=result["text"], + score=result["score"], ) - if score_threshold is None or result['score'] >= score_threshold: + if score_threshold is None or result["score"] >= score_threshold: rerank_documents.append(rerank_document) if top_n is not None and len(rerank_documents) >= top_n: break return RerankResult(model=model, docs=rerank_documents) except httpx.HTTPStatusError as e: - raise InvokeServerUnavailableError(str(e)) + raise InvokeServerUnavailableError(str(e)) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -80,21 +80,21 @@ class HuggingfaceTeiRerankModel(RerankModel): :return: """ try: - server_url = credentials['server_url'] + server_url = credentials["server_url"] extra_args = TeiHelper.get_tei_extra_parameter(server_url, model) - if extra_args.model_type != 'reranker': - raise CredentialsValidateFailedError('Current model is not a rerank model') + if extra_args.model_type != "reranker": + raise CredentialsValidateFailedError("Current model is not a rerank model") - credentials['context_size'] = extra_args.max_input_length + credentials["context_size"] = extra_args.max_input_length self.invoke( model=model, credentials=credentials, - query='Whose kasumi', + query="Whose kasumi", docs=[ 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', - 'Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ', - 'and she leads a team named PopiParty.', + "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", + "and she leads a team named PopiParty.", ], score_threshold=0.8, ) @@ -129,7 +129,7 @@ class HuggingfaceTeiRerankModel(RerankModel): fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.RERANK, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)), }, parameter_rules=[], ) diff --git a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py index 2aa785c89d..56c51e8888 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py @@ -31,16 +31,16 @@ class TeiHelper: with cache_lock: if model_name not in cache: cache[model_name] = { - 'expires': time() + 300, - 'value': TeiHelper._get_tei_extra_parameter(server_url), + "expires": time() + 300, + "value": TeiHelper._get_tei_extra_parameter(server_url), } - return cache[model_name]['value'] + return cache[model_name]["value"] @staticmethod def _clean_cache() -> None: try: with cache_lock: - expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()] + expired_keys = [model_uid for model_uid, model in cache.items() if model["expires"] < time()] for model_uid in expired_keys: del cache[model_uid] except RuntimeError as e: @@ -52,40 +52,38 @@ class TeiHelper: get tei model extra parameter like model_type, max_input_length, max_batch_requests """ - url = str(URL(server_url) / 'info') + url = str(URL(server_url) / "info") # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3 session = Session() - session.mount('http://', HTTPAdapter(max_retries=3)) - session.mount('https://', HTTPAdapter(max_retries=3)) + session.mount("http://", HTTPAdapter(max_retries=3)) + session.mount("https://", HTTPAdapter(max_retries=3)) try: response = session.get(url, timeout=10) except (MissingSchema, ConnectionError, Timeout) as e: - raise RuntimeError(f'get tei model extra parameter failed, url: {url}, error: {e}') + raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}") if response.status_code != 200: raise RuntimeError( - f'get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}' + f"get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}" ) response_json = response.json() - model_type = response_json.get('model_type', {}) + model_type = response_json.get("model_type", {}) if len(model_type.keys()) < 1: - raise RuntimeError('model_type is empty') + raise RuntimeError("model_type is empty") model_type = list(model_type.keys())[0] - if model_type not in ['embedding', 'reranker']: - raise RuntimeError(f'invalid model_type: {model_type}') - - max_input_length = response_json.get('max_input_length', 512) - max_client_batch_size = response_json.get('max_client_batch_size', 1) + if model_type not in ["embedding", "reranker"]: + raise RuntimeError(f"invalid model_type: {model_type}") + + max_input_length = response_json.get("max_input_length", 512) + max_client_batch_size = response_json.get("max_client_batch_size", 1) return TeiModelExtraParameter( - model_type=model_type, - max_input_length=max_input_length, - max_client_batch_size=max_client_batch_size + model_type=model_type, max_input_length=max_input_length, max_client_batch_size=max_client_batch_size ) - + @staticmethod def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]: """ @@ -116,12 +114,12 @@ class TeiHelper: :param texts: texts to tokenize """ resp = httpx.post( - f'{server_url}/tokenize', - json={'inputs': texts}, + f"{server_url}/tokenize", + json={"inputs": texts}, ) resp.raise_for_status() return resp.json() - + @staticmethod def invoke_embeddings(server_url: str, texts: list[str]) -> dict: """ @@ -149,8 +147,8 @@ class TeiHelper: """ # Use OpenAI compatible API here, which has usage tracking resp = httpx.post( - f'{server_url}/v1/embeddings', - json={'input': texts}, + f"{server_url}/v1/embeddings", + json={"input": texts}, ) resp.raise_for_status() return resp.json() @@ -173,11 +171,11 @@ class TeiHelper: :param texts: texts to rerank :param candidates: candidates to rerank """ - params = {'query': query, 'texts': docs, 'return_text': True} + params = {"query": query, "texts": docs, "return_text": True} response = httpx.post( - server_url + '/rerank', + server_url + "/rerank", json=params, ) - response.raise_for_status() + response.raise_for_status() return response.json() diff --git a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py index 6897b87f6d..2d04abb277 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py @@ -40,12 +40,11 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - server_url = credentials['server_url'] + server_url = credentials["server_url"] - if server_url.endswith('/'): + if server_url.endswith("/"): server_url = server_url[:-1] - # get model properties context_size = self._get_context_size(model, credentials) max_chunks = self._get_max_chunks(model, credentials) @@ -58,7 +57,6 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts) for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)): - # Check if the number of tokens is larger than the context size num_tokens = len(tokenize_result) @@ -66,20 +64,22 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): # Find the best cutoff point pre_special_token_count = 0 for token in tokenize_result: - if token['special']: + if token["special"]: pre_special_token_count += 1 else: break - rest_special_token_count = len([token for token in tokenize_result if token['special']]) - pre_special_token_count + rest_special_token_count = ( + len([token for token in tokenize_result if token["special"]]) - pre_special_token_count + ) # Calculate the cutoff point, leave 20 extra space to avoid exceeding the limit token_cutoff = context_size - rest_special_token_count - 20 # Find the cutoff index cutpoint_token = tokenize_result[token_cutoff] - cutoff = cutpoint_token['start'] + cutoff = cutpoint_token["start"] - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) indices += [i] @@ -92,12 +92,12 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): for i in _iter: iter_texts = inputs[i : i + max_chunks] results = TeiHelper.invoke_embeddings(server_url, iter_texts) - embeddings = results['data'] - embeddings = [embedding['embedding'] for embedding in embeddings] + embeddings = results["data"] + embeddings = [embedding["embedding"] for embedding in embeddings] batched_embeddings.extend(embeddings) - usage = results['usage'] - used_tokens += usage['total_tokens'] + usage = results["usage"] + used_tokens += usage["total_tokens"] except RuntimeError as e: raise InvokeServerUnavailableError(str(e)) @@ -117,9 +117,9 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): :return: """ num_tokens = 0 - server_url = credentials['server_url'] + server_url = credentials["server_url"] - if server_url.endswith('/'): + if server_url.endswith("/"): server_url = server_url[:-1] batch_tokens = TeiHelper.invoke_tokenize(server_url, texts) @@ -135,15 +135,15 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - server_url = credentials['server_url'] + server_url = credentials["server_url"] extra_args = TeiHelper.get_tei_extra_parameter(server_url, model) print(extra_args) - if extra_args.model_type != 'embedding': - raise CredentialsValidateFailedError('Current model is not a embedding model') + if extra_args.model_type != "embedding": + raise CredentialsValidateFailedError("Current model is not a embedding model") - credentials['context_size'] = extra_args.max_input_length - credentials['max_chunks'] = extra_args.max_client_batch_size - self._invoke(model=model, credentials=credentials, texts=['ping']) + credentials["context_size"] = extra_args.max_input_length + credentials["max_chunks"] = extra_args.max_client_batch_size + self._invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -195,8 +195,8 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ - ModelPropertyKey.MAX_CHUNKS: int(credentials.get('max_chunks', 1)), - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)), + ModelPropertyKey.MAX_CHUNKS: int(credentials.get("max_chunks", 1)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)), }, parameter_rules=[], ) diff --git a/api/core/model_runtime/model_providers/hunyuan/hunyuan.py b/api/core/model_runtime/model_providers/hunyuan/hunyuan.py index 5a298d33ac..e65772e7dd 100644 --- a/api/core/model_runtime/model_providers/hunyuan/hunyuan.py +++ b/api/core/model_runtime/model_providers/hunyuan/hunyuan.py @@ -6,8 +6,8 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) -class HunyuanProvider(ModelProvider): +class HunyuanProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +19,9 @@ class HunyuanProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `hunyuan-standard` model for validate, - model_instance.validate_credentials( - model='hunyuan-standard', - credentials=credentials - ) + model_instance.validate_credentials(model="hunyuan-standard", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py index 0bdf6ec005..c056ab7a08 100644 --- a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py @@ -23,21 +23,27 @@ from core.model_runtime.model_providers.__base.large_language_model import Large logger = logging.getLogger(__name__) + class HunyuanLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: client = self._setup_hunyuan_client(credentials) request = models.ChatCompletionsRequest() messages_dict = self._convert_prompt_messages_to_dicts(prompt_messages) custom_parameters = { - 'Temperature': model_parameters.get('temperature', 0.0), - 'TopP': model_parameters.get('top_p', 1.0), - 'EnableEnhancement': model_parameters.get('enable_enhance', True) + "Temperature": model_parameters.get("temperature", 0.0), + "TopP": model_parameters.get("top_p", 1.0), + "EnableEnhancement": model_parameters.get("enable_enhance", True), } params = { @@ -47,16 +53,19 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): **custom_parameters, } # add Tools and ToolChoice - if (tools and len(tools) > 0): - params['ToolChoice'] = "auto" - params['Tools'] = [{ - "Type": "function", - "Function": { - "Name": tool.name, - "Description": tool.description, - "Parameters": json.dumps(tool.parameters) + if tools and len(tools) > 0: + params["ToolChoice"] = "auto" + params["Tools"] = [ + { + "Type": "function", + "Function": { + "Name": tool.name, + "Description": tool.description, + "Parameters": json.dumps(tool.parameters), + }, } - } for tool in tools] + for tool in tools + ] request.from_json_string(json.dumps(params)) response = client.ChatCompletions(request) @@ -76,22 +85,19 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): req = models.ChatCompletionsRequest() params = { "Model": model, - "Messages": [{ - "Role": "user", - "Content": "hello" - }], + "Messages": [{"Role": "user", "Content": "hello"}], "TopP": 1, "Temperature": 0, - "Stream": False + "Stream": False, } req.from_json_string(json.dumps(params)) client.ChatCompletions(req) except Exception as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") def _setup_hunyuan_client(self, credentials): - secret_id = credentials['secret_id'] - secret_key = credentials['secret_key'] + secret_id = credentials["secret_id"] + secret_key = credentials["secret_key"] cred = credential.Credential(secret_id, secret_key) httpProfile = HttpProfile() httpProfile.endpoint = "hunyuan.tencentcloudapi.com" @@ -106,92 +112,96 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): for message in prompt_messages: if isinstance(message, AssistantPromptMessage): tool_calls = message.tool_calls - if (tool_calls and len(tool_calls) > 0): + if tool_calls and len(tool_calls) > 0: dict_tool_calls = [ { "Id": tool_call.id, "Type": tool_call.type, "Function": { "Name": tool_call.function.name, - "Arguments": tool_call.function.arguments if (tool_call.function.arguments == "") else "{}" - } - } for tool_call in tool_calls] - - dict_list.append({ - "Role": message.role.value, - # fix set content = "" while tool_call request - # fix [hunyuan] None, [TencentCloudSDKException] code:InvalidParameter message:Messages Content and Contents not allowed empty at the same time. - "Content": " ", # message.content if (message.content is not None) else "", - "ToolCalls": dict_tool_calls - }) + "Arguments": tool_call.function.arguments + if (tool_call.function.arguments == "") + else "{}", + }, + } + for tool_call in tool_calls + ] + + dict_list.append( + { + "Role": message.role.value, + # fix set content = "" while tool_call request + # fix [hunyuan] None, [TencentCloudSDKException] code:InvalidParameter message:Messages Content and Contents not allowed empty at the same time. + "Content": " ", # message.content if (message.content is not None) else "", + "ToolCalls": dict_tool_calls, + } + ) else: - dict_list.append({ "Role": message.role.value, "Content": message.content }) + dict_list.append({"Role": message.role.value, "Content": message.content}) elif isinstance(message, ToolPromptMessage): - tool_execute_result = { "result": message.content } - content =json.dumps(tool_execute_result, ensure_ascii=False) - dict_list.append({ "Role": message.role.value, "Content": content, "ToolCallId": message.tool_call_id }) + tool_execute_result = {"result": message.content} + content = json.dumps(tool_execute_result, ensure_ascii=False) + dict_list.append({"Role": message.role.value, "Content": content, "ToolCallId": message.tool_call_id}) else: - dict_list.append({ "Role": message.role.value, "Content": message.content }) + dict_list.append({"Role": message.role.value, "Content": message.content}) return dict_list def _handle_stream_chat_response(self, model, credentials, prompt_messages, resp): - tool_call = None tool_calls = [] for index, event in enumerate(resp): logging.debug("_handle_stream_chat_response, event: %s", event) - data_str = event['data'] + data_str = event["data"] data = json.loads(data_str) - choices = data.get('Choices', []) + choices = data.get("Choices", []) if not choices: continue choice = choices[0] - delta = choice.get('Delta', {}) - message_content = delta.get('Content', '') - finish_reason = choice.get('FinishReason', '') + delta = choice.get("Delta", {}) + message_content = delta.get("Content", "") + finish_reason = choice.get("FinishReason", "") - usage = data.get('Usage', {}) - prompt_tokens = usage.get('PromptTokens', 0) - completion_tokens = usage.get('CompletionTokens', 0) + usage = data.get("Usage", {}) + prompt_tokens = usage.get("PromptTokens", 0) + completion_tokens = usage.get("CompletionTokens", 0) - response_tool_calls = delta.get('ToolCalls') - if (response_tool_calls is not None): + response_tool_calls = delta.get("ToolCalls") + if response_tool_calls is not None: new_tool_calls = self._extract_response_tool_calls(response_tool_calls) - if (len(new_tool_calls) > 0): + if len(new_tool_calls) > 0: new_tool_call = new_tool_calls[0] - if (tool_call is None): tool_call = new_tool_call - elif (tool_call.id != new_tool_call.id): + if tool_call is None: + tool_call = new_tool_call + elif tool_call.id != new_tool_call.id: tool_calls.append(tool_call) tool_call = new_tool_call else: tool_call.function.name += new_tool_call.function.name tool_call.function.arguments += new_tool_call.function.arguments - if (tool_call is not None and len(tool_call.function.name) > 0 and len(tool_call.function.arguments) > 0): + if tool_call is not None and len(tool_call.function.name) > 0 and len(tool_call.function.arguments) > 0: tool_calls.append(tool_call) tool_call = None - assistant_prompt_message = AssistantPromptMessage( - content=message_content, - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=message_content, tool_calls=[]) # rewrite content = "" while tool_call to avoid show content on web page - if (len(tool_calls) > 0): assistant_prompt_message.content = "" - + if len(tool_calls) > 0: + assistant_prompt_message.content = "" + # add tool_calls to assistant_prompt_message - if (finish_reason == 'tool_calls'): + if finish_reason == "tool_calls": assistant_prompt_message.tool_calls = tool_calls tool_call = None tool_calls = [] - if (len(finish_reason) > 0): + if len(finish_reason) > 0: usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) delta_chunk = LLMResultChunkDelta( index=index, - role=delta.get('Role', 'assistant'), + role=delta.get("Role", "assistant"), message=assistant_prompt_message, usage=usage, finish_reason=finish_reason, @@ -212,8 +222,9 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): ) def _handle_chat_response(self, credentials, model, prompt_messages, response): - usage = self._calc_response_usage(model, credentials, response.Usage.PromptTokens, - response.Usage.CompletionTokens) + usage = self._calc_response_usage( + model, credentials, response.Usage.PromptTokens, response.Usage.CompletionTokens + ) assistant_prompt_message = AssistantPromptMessage() assistant_prompt_message.content = response.Choices[0].Message.Content result = LLMResult( @@ -225,8 +236,13 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): return result - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: if len(prompt_messages) == 0: return 0 prompt = self._convert_messages_to_prompt(prompt_messages) @@ -241,10 +257,7 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() @@ -287,10 +300,8 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): return { InvokeError: [TencentCloudSDKException], } - - def _extract_response_tool_calls(self, - response_tool_calls: list[dict]) \ - -> list[AssistantPromptMessage.ToolCall]: + + def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -300,17 +311,14 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): tool_calls = [] if response_tool_calls: for response_tool_call in response_tool_calls: - response_function = response_tool_call.get('Function', {}) + response_function = response_tool_call.get("Function", {}) function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function.get('Name', ''), - arguments=response_function.get('Arguments', '') + name=response_function.get("Name", ""), arguments=response_function.get("Arguments", "") ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.get('Id', 0), - type='function', - function=function + id=response_tool_call.get("Id", 0), type="function", function=function ) tool_calls.append(tool_call) - return tool_calls \ No newline at end of file + return tool_calls diff --git a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py index 64d8dcf795..1396e59e18 100644 --- a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py @@ -19,14 +19,15 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE logger = logging.getLogger(__name__) + class HunyuanTextEmbeddingModel(TextEmbeddingModel): """ Model class for Hunyuan text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,9 +38,9 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): :return: embeddings result """ - if model != 'hunyuan-embedding': - raise ValueError('Invalid model name') - + if model != "hunyuan-embedding": + raise ValueError("Invalid model name") + client = self._setup_hunyuan_client(credentials) embeddings = [] @@ -47,9 +48,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): for input in texts: request = models.GetEmbeddingRequest() - params = { - "Input": input - } + params = {"Input": input} request.from_json_string(json.dumps(params)) response = client.GetEmbedding(request) usage = response.Usage.TotalTokens @@ -60,11 +59,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): result = TextEmbeddingResult( model=model, embeddings=embeddings, - usage=self._calc_response_usage( - model=model, - credentials=credentials, - tokens=token_usage - ) + usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage), ) return result @@ -79,22 +74,19 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): req = models.ChatCompletionsRequest() params = { "Model": model, - "Messages": [{ - "Role": "user", - "Content": "hello" - }], + "Messages": [{"Role": "user", "Content": "hello"}], "TopP": 1, "Temperature": 0, - "Stream": False + "Stream": False, } req.from_json_string(json.dumps(params)) client.ChatCompletions(req) except Exception as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") def _setup_hunyuan_client(self, credentials): - secret_id = credentials['secret_id'] - secret_key = credentials['secret_key'] + secret_id = credentials["secret_id"] + secret_key = credentials["secret_key"] cred = credential.Credential(secret_id, secret_key) httpProfile = HttpProfile() httpProfile.endpoint = "hunyuan.tencentcloudapi.com" @@ -102,7 +94,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): clientProfile.httpProfile = httpProfile client = hunyuan_client.HunyuanClient(cred, "", clientProfile) return client - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -114,10 +106,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -128,11 +117,11 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -146,7 +135,7 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): return { InvokeError: [TencentCloudSDKException], } - + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ Get number of tokens for given prompt messages @@ -170,4 +159,4 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel): # response = client.GetTokenCount(request) # num_tokens += response.TokenCount - return num_tokens \ No newline at end of file + return num_tokens diff --git a/api/core/model_runtime/model_providers/jina/jina.py b/api/core/model_runtime/model_providers/jina/jina.py index cde4313495..33977b6a33 100644 --- a/api/core/model_runtime/model_providers/jina/jina.py +++ b/api/core/model_runtime/model_providers/jina/jina.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class JinaProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -21,12 +20,9 @@ class JinaProvider(ModelProvider): # Use `jina-embeddings-v2-base-en` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='jina-embeddings-v2-base-en', - credentials=credentials - ) + model_instance.validate_credentials(model="jina-embeddings-v2-base-en", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/jina/rerank/rerank.py b/api/core/model_runtime/model_providers/jina/rerank/rerank.py index de7e038b9f..d8394f7a4c 100644 --- a/api/core/model_runtime/model_providers/jina/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/jina/rerank/rerank.py @@ -22,9 +22,16 @@ class JinaRerankModel(RerankModel): Model class for Jina rerank model. """ - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -40,37 +47,32 @@ class JinaRerankModel(RerankModel): if len(docs) == 0: return RerankResult(model=model, docs=[]) - base_url = credentials.get('base_url', 'https://api.jina.ai/v1') - if base_url.endswith('/'): + base_url = credentials.get("base_url", "https://api.jina.ai/v1") + if base_url.endswith("/"): base_url = base_url[:-1] try: response = httpx.post( - base_url + '/rerank', - json={ - "model": model, - "query": query, - "documents": docs, - "top_n": top_n - }, - headers={"Authorization": f"Bearer {credentials.get('api_key')}"} + base_url + "/rerank", + json={"model": model, "query": query, "documents": docs, "top_n": top_n}, + headers={"Authorization": f"Bearer {credentials.get('api_key')}"}, ) - response.raise_for_status() + response.raise_for_status() results = response.json() rerank_documents = [] - for result in results['results']: + for result in results["results"]: rerank_document = RerankDocument( - index=result['index'], - text=result['document']['text'], - score=result['relevance_score'], + index=result["index"], + text=result["document"]["text"], + score=result["relevance_score"], ) - if score_threshold is None or result['relevance_score'] >= score_threshold: + if score_threshold is None or result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) return RerankResult(model=model, docs=rerank_documents) except httpx.HTTPStatusError as e: - raise InvokeServerUnavailableError(str(e)) + raise InvokeServerUnavailableError(str(e)) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -81,7 +83,6 @@ class JinaRerankModel(RerankModel): :return: """ try: - self._invoke( model=model, credentials=credentials, @@ -92,7 +93,7 @@ class JinaRerankModel(RerankModel): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -105,23 +106,21 @@ class JinaRerankModel(RerankModel): return { InvokeConnectionError: [httpx.ConnectError], InvokeServerUnavailableError: [httpx.RemoteProtocolError], - InvokeRateLimitError: [], - InvokeAuthorizationError: [httpx.HTTPStatusError], - InvokeBadRequestError: [httpx.RequestError] + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, label=I18nObject(en_US=model), model_type=ModelType.RERANK, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')) - } + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py index 50f8c73ed9..d80cbfa83d 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py @@ -14,19 +14,19 @@ class JinaTokenizer: with cls._lock: if cls._tokenizer is None: base_path = abspath(__file__) - gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer') + gpt2_tokenizer_path = join(dirname(base_path), "tokenizer") cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path) return cls._tokenizer @classmethod def _get_num_tokens_by_jina_base(cls, text: str) -> int: """ - use jina tokenizer to get num tokens + use jina tokenizer to get num tokens """ tokenizer = cls._get_tokenizer() tokens = tokenizer.encode(text) return len(tokens) - + @classmethod def get_num_tokens(cls, text: str) -> int: - return cls._get_num_tokens_by_jina_base(text) \ No newline at end of file + return cls._get_num_tokens_by_jina_base(text) diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py index 23203491e6..7ed3e4d384 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py @@ -24,11 +24,12 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): """ Model class for Jina text embedding model. """ - api_base: str = 'https://api.jina.ai/v1' - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "https://api.jina.ai/v1" + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -38,29 +39,23 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - api_key = credentials['api_key'] + api_key = credentials["api_key"] if not api_key: - raise CredentialsValidateFailedError('api_key is required') + raise CredentialsValidateFailedError("api_key is required") - base_url = credentials.get('base_url', self.api_base) - if base_url.endswith('/'): + base_url = credentials.get("base_url", self.api_base) + if base_url.endswith("/"): base_url = base_url[:-1] - url = base_url + '/embeddings' - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + url = base_url + "/embeddings" + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} def transform_jina_input_text(model, text): - if model == 'jina-clip-v1': + if model == "jina-clip-v1": return {"text": text} return text - data = { - 'model': model, - 'input': [transform_jina_input_text(model, text) for text in texts] - } + data = {"model": model, "input": [transform_jina_input_text(model, text) for text in texts]} try: response = post(url, headers=headers, data=dumps(data)) @@ -70,7 +65,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): if response.status_code != 200: try: resp = response.json() - msg = resp['detail'] + msg = resp["detail"] if response.status_code == 401: raise InvokeAuthorizationError(msg) elif response.status_code == 429: @@ -81,25 +76,20 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): raise InvokeBadRequestError(msg) except JSONDecodeError as e: raise InvokeServerUnavailableError( - f"Failed to convert response to json: {e} with text: {response.text}") + f"Failed to convert response to json: {e} with text: {response.text}" + ) try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: - raise InvokeServerUnavailableError( - f"Failed to convert response to json: {e} with text: {response.text}") + raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") - usage = self._calc_response_usage( - model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) result = TextEmbeddingResult( - model=model, - embeddings=[[ - float(data) for data in x['embedding'] - ] for x in embeddings], - usage=usage + model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage ) return result @@ -128,30 +118,18 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as e: - raise CredentialsValidateFailedError( - f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError, - InvokeBadRequestError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError, InvokeBadRequestError], } def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: @@ -165,10 +143,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -179,24 +154,21 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, label=I18nObject(en_US=model), model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int( - credentials.get('context_size')) - } + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, ) return entity diff --git a/api/core/model_runtime/model_providers/leptonai/leptonai.py b/api/core/model_runtime/model_providers/leptonai/leptonai.py index b035c31ac5..34a55ff192 100644 --- a/api/core/model_runtime/model_providers/leptonai/leptonai.py +++ b/api/core/model_runtime/model_providers/leptonai/leptonai.py @@ -6,8 +6,8 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) -class LeptonAIProvider(ModelProvider): +class LeptonAIProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -18,12 +18,9 @@ class LeptonAIProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='llama2-7b', - credentials=credentials - ) + model_instance.validate_credentials(model="llama2-7b", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/leptonai/llm/llm.py b/api/core/model_runtime/model_providers/leptonai/llm/llm.py index 523309bac5..3d69417e45 100644 --- a/api/core/model_runtime/model_providers/leptonai/llm/llm.py +++ b/api/core/model_runtime/model_providers/leptonai/llm/llm.py @@ -8,18 +8,25 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class LeptonAILargeLanguageModel(OAIAPICompatLargeLanguageModel): MODEL_PREFIX_MAP = { - 'llama2-7b': 'llama2-7b', - 'gemma-7b': 'gemma-7b', - 'mistral-7b': 'mistral-7b', - 'mixtral-8x7b': 'mixtral-8x7b', - 'llama3-70b': 'llama3-70b', - 'llama2-13b': 'llama2-13b', - } - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + "llama2-7b": "llama2-7b", + "gemma-7b": "gemma-7b", + "mistral-7b": "mistral-7b", + "mixtral-8x7b": "mixtral-8x7b", + "llama3-70b": "llama3-70b", + "llama2-13b": "llama2-13b", + } + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials, model) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -29,6 +36,5 @@ class LeptonAILargeLanguageModel(OAIAPICompatLargeLanguageModel): @classmethod def _add_custom_parameters(cls, credentials: dict, model: str) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = f'https://{cls.MODEL_PREFIX_MAP[model]}.lepton.run/api/v1' - \ No newline at end of file + credentials["mode"] = "chat" + credentials["endpoint_url"] = f"https://{cls.MODEL_PREFIX_MAP[model]}.lepton.run/api/v1" diff --git a/api/core/model_runtime/model_providers/localai/llm/llm.py b/api/core/model_runtime/model_providers/localai/llm/llm.py index 1009995c58..94c03efe7b 100644 --- a/api/core/model_runtime/model_providers/localai/llm/llm.py +++ b/api/core/model_runtime/model_providers/localai/llm/llm.py @@ -52,29 +52,48 @@ from core.model_runtime.utils import helper class LocalAILanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + return self._generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: # tools is not supported yet return self._num_tokens_from_messages(prompt_messages, tools=tools) def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: """ - Calculate num tokens for baichuan model - LocalAI does not supports + Calculate num tokens for baichuan model + LocalAI does not supports """ def tokens(text: str): """ - We could not determine which tokenizer to use, cause the model is customized. - So we use gpt2 tokenizer to calculate the num tokens for convenience. + We could not determine which tokenizer to use, cause the model is customized. + So we use gpt2 tokenizer to calculate the num tokens for convenience. """ return self._get_num_tokens_by_gpt2(text) @@ -87,10 +106,10 @@ class LocalAILanguageModel(LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -142,30 +161,30 @@ class LocalAILanguageModel(LargeLanguageModel): num_tokens = 0 for tool in tools: # calculate num tokens for function object - num_tokens += tokens('name') + num_tokens += tokens("name") num_tokens += tokens(tool.name) - num_tokens += tokens('description') + num_tokens += tokens("description") num_tokens += tokens(tool.description) parameters = tool.parameters - num_tokens += tokens('parameters') - num_tokens += tokens('type') + num_tokens += tokens("parameters") + num_tokens += tokens("type") num_tokens += tokens(parameters.get("type")) - if 'properties' in parameters: - num_tokens += tokens('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += tokens("properties") + for key, value in parameters.get("properties").items(): num_tokens += tokens(key) for field_key, field_value in value.items(): num_tokens += tokens(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += tokens(enum_field) else: num_tokens += tokens(field_key) num_tokens += tokens(str(field_value)) - if 'required' in parameters: - num_tokens += tokens('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += tokens("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += tokens(required_field) @@ -180,102 +199,104 @@ class LocalAILanguageModel(LargeLanguageModel): :return: """ try: - self._invoke(model=model, credentials=credentials, prompt_messages=[ - UserPromptMessage(content='ping') - ], model_parameters={ - 'max_tokens': 10, - }, stop=[], stream=False) + self._invoke( + model=model, + credentials=credentials, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={ + "max_tokens": 10, + }, + stop=[], + stream=False, + ) except Exception as ex: - raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}') + raise CredentialsValidateFailedError(f"Invalid credentials {str(ex)}") def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: completion_model = None - if credentials['completion_type'] == 'chat_completion': + if credentials["completion_type"] == "chat_completion": completion_model = LLMMode.CHAT.value - elif credentials['completion_type'] == 'completion': + elif credentials["completion_type"] == "completion": completion_model = LLMMode.COMPLETION.value else: raise ValueError(f"Unknown completion type {credentials['completion_type']}") rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ) + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, max=2048, default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] - model_properties = { - ModelPropertyKey.MODE: completion_model, - } if completion_model else {} + model_properties = ( + { + ModelPropertyKey.MODE: completion_model, + } + if completion_model + else {} + ) - model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048')) + model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get("context_size", "2048")) entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, model_properties=model_properties, - parameter_rules=rules + parameter_rules=rules, ) return entity - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: kwargs = self._to_client_kwargs(credentials) # init model client client = OpenAI(**kwargs) model_name = model - completion_type = credentials['completion_type'] + completion_type = credentials["completion_type"] extra_model_kwargs = { "timeout": 60, } if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if tools and len(tools) > 0: - extra_model_kwargs['functions'] = [ - helper.dump_model(tool) for tool in tools - ] + extra_model_kwargs["functions"] = [helper.dump_model(tool) for tool in tools] - if completion_type == 'chat_completion': + if completion_type == "chat_completion": result = client.chat.completions.create( messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], model=model_name, @@ -283,36 +304,32 @@ class LocalAILanguageModel(LargeLanguageModel): **model_parameters, **extra_model_kwargs, ) - elif completion_type == 'completion': + elif completion_type == "completion": result = client.completions.create( prompt=self._convert_prompt_message_to_completion_prompts(prompt_messages), model=model, stream=stream, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) else: raise ValueError(f"Unknown completion type {completion_type}") if stream: - if completion_type == 'completion': + if completion_type == "completion": return self._handle_completion_generate_stream_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) return self._handle_chat_generate_stream_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) - if completion_type == 'completion': + if completion_type == "completion": return self._handle_completion_generate_response( - model=model, credentials=credentials, response=result, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, prompt_messages=prompt_messages ) return self._handle_chat_generate_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) def _to_client_kwargs(self, credentials: dict) -> dict: @@ -322,13 +339,13 @@ class LocalAILanguageModel(LargeLanguageModel): :param credentials: credentials dict :return: client kwargs """ - if not credentials['server_url'].endswith('/'): - credentials['server_url'] += '/' + if not credentials["server_url"].endswith("/"): + credentials["server_url"] += "/" client_kwargs = { "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "api_key": "1", - "base_url": str(URL(credentials['server_url']) / 'v1'), + "base_url": str(URL(credentials["server_url"]) / "v1"), } return client_kwargs @@ -349,7 +366,7 @@ class LocalAILanguageModel(LargeLanguageModel): if message.tool_calls and len(message.tool_calls) > 0: message_dict["function_call"] = { "name": message.tool_calls[0].function.name, - "arguments": message.tool_calls[0].function.arguments + "arguments": message.tool_calls[0].function.arguments, } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) @@ -359,11 +376,7 @@ class LocalAILanguageModel(LargeLanguageModel): message = cast(ToolPromptMessage, message) message_dict = { "role": "user", - "content": [{ - "type": "tool_result", - "tool_use_id": message.tool_call_id, - "content": message.content - }] + "content": [{"type": "tool_result", "tool_use_id": message.tool_call_id, "content": message.content}], } else: raise ValueError(f"Unknown message type {type(message)}") @@ -374,27 +387,29 @@ class LocalAILanguageModel(LargeLanguageModel): """ Convert PromptMessage to completion prompts """ - prompts = '' + prompts = "" for message in messages: if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) - prompts += f'{message.content}\n' + prompts += f"{message.content}\n" elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) - prompts += f'{message.content}\n' + prompts += f"{message.content}\n" elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - prompts += f'{message.content}\n' + prompts += f"{message.content}\n" else: raise ValueError(f"Unknown message type {type(message)}") return prompts - def _handle_completion_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Completion, - ) -> LLMResult: + def _handle_completion_generate_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Completion, + ) -> LLMResult: """ Handle llm chat response @@ -411,18 +426,16 @@ class LocalAILanguageModel(LargeLanguageModel): assistant_message = response.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message, - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message, tool_calls=[]) prompt_tokens = self._get_num_tokens_by_gpt2( self._convert_prompt_message_to_completion_prompts(prompt_messages) ) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -434,11 +447,14 @@ class LocalAILanguageModel(LargeLanguageModel): return response - def _handle_chat_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: ChatCompletion, - tools: list[PromptMessageTool]) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: ChatCompletion, + tools: list[PromptMessageTool], + ) -> LLMResult: """ Handle llm chat response @@ -459,16 +475,14 @@ class LocalAILanguageModel(LargeLanguageModel): tool_calls = self._extract_response_tool_calls([function_calls] if function_calls else []) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -480,12 +494,15 @@ class LocalAILanguageModel(LargeLanguageModel): return response - def _handle_completion_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Stream[Completion], - tools: list[PromptMessageTool]) -> Generator: - full_response = '' + def _handle_completion_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Stream[Completion], + tools: list[PromptMessageTool], + ) -> Generator: + full_response = "" for chunk in response: if len(chunk.choices) == 0: @@ -494,17 +511,11 @@ class LocalAILanguageModel(LargeLanguageModel): delta = chunk.choices[0] # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=delta.text if delta.text else '', - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=delta.text if delta.text else "", tool_calls=[]) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage - temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=[] - ) + temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[]) prompt_tokens = self._get_num_tokens_by_gpt2( self._convert_prompt_message_to_completion_prompts(prompt_messages) @@ -512,8 +523,12 @@ class LocalAILanguageModel(LargeLanguageModel): completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, @@ -523,7 +538,7 @@ class LocalAILanguageModel(LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage + usage=usage, ), ) else: @@ -539,12 +554,15 @@ class LocalAILanguageModel(LargeLanguageModel): full_response += delta.text - def _handle_chat_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Stream[ChatCompletionChunk], - tools: list[PromptMessageTool]) -> Generator: - full_response = '' + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Stream[ChatCompletionChunk], + tools: list[PromptMessageTool], + ) -> Generator: + full_response = "" for chunk in response: if len(chunk.choices) == 0: @@ -552,7 +570,7 @@ class LocalAILanguageModel(LargeLanguageModel): delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""): continue # check if there is a tool call in the response @@ -564,22 +582,24 @@ class LocalAILanguageModel(LargeLanguageModel): # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_message_tool_calls + content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls ) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=assistant_message_tool_calls + content=full_response, tool_calls=assistant_message_tool_calls ) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, @@ -589,7 +609,7 @@ class LocalAILanguageModel(LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage + usage=usage, ), ) else: @@ -605,9 +625,9 @@ class LocalAILanguageModel(LargeLanguageModel): full_response += delta.delta.content - def _extract_response_tool_calls(self, - response_function_calls: list[FunctionCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_function_calls: list[FunctionCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -618,15 +638,10 @@ class LocalAILanguageModel(LargeLanguageModel): if response_function_calls: for response_tool_call in response_function_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.name, - arguments=response_tool_call.arguments + name=response_tool_call.name, arguments=response_tool_call.arguments ) - tool_call = AssistantPromptMessage.ToolCall( - id=0, - type='function', - function=function - ) + tool_call = AssistantPromptMessage.ToolCall(id=0, type="function", function=function) tool_calls.append(tool_call) return tool_calls @@ -651,15 +666,9 @@ class LocalAILanguageModel(LargeLanguageModel): ConflictError, NotFoundError, UnprocessableEntityError, - PermissionDeniedError + PermissionDeniedError, ], - InvokeRateLimitError: [ - RateLimitError - ], - InvokeAuthorizationError: [ - AuthenticationError - ], - InvokeBadRequestError: [ - ValueError - ] + InvokeRateLimitError: [RateLimitError], + InvokeAuthorizationError: [AuthenticationError], + InvokeBadRequestError: [ValueError], } diff --git a/api/core/model_runtime/model_providers/localai/localai.py b/api/core/model_runtime/model_providers/localai/localai.py index 6d2278fd54..4ff898052b 100644 --- a/api/core/model_runtime/model_providers/localai/localai.py +++ b/api/core/model_runtime/model_providers/localai/localai.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class LocalAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/localai/rerank/rerank.py b/api/core/model_runtime/model_providers/localai/rerank/rerank.py index c8ba9a6c7c..2b0f53bc19 100644 --- a/api/core/model_runtime/model_providers/localai/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/localai/rerank/rerank.py @@ -25,9 +25,16 @@ class LocalaiRerankModel(RerankModel): LocalAI rerank model API is compatible with Jina rerank model API. So just copy the JinaRerankModel class code here. """ - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -43,45 +50,37 @@ class LocalaiRerankModel(RerankModel): if len(docs) == 0: return RerankResult(model=model, docs=[]) - server_url = credentials['server_url'] + server_url = credentials["server_url"] model_name = model - - if not server_url: - raise CredentialsValidateFailedError('server_url is required') - if not model_name: - raise CredentialsValidateFailedError('model_name is required') - - url = server_url - headers = { - 'Authorization': f"Bearer {credentials.get('api_key')}", - 'Content-Type': 'application/json' - } - data = { - "model": model_name, - "query": query, - "documents": docs, - "top_n": top_n - } + if not server_url: + raise CredentialsValidateFailedError("server_url is required") + if not model_name: + raise CredentialsValidateFailedError("model_name is required") + + url = server_url + headers = {"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"} + + data = {"model": model_name, "query": query, "documents": docs, "top_n": top_n} try: - response = post(str(URL(url) / 'rerank'), headers=headers, data=dumps(data), timeout=10) - response.raise_for_status() + response = post(str(URL(url) / "rerank"), headers=headers, data=dumps(data), timeout=10) + response.raise_for_status() results = response.json() rerank_documents = [] - for result in results['results']: + for result in results["results"]: rerank_document = RerankDocument( - index=result['index'], - text=result['document']['text'], - score=result['relevance_score'], + index=result["index"], + text=result["document"]["text"], + score=result["relevance_score"], ) - if score_threshold is None or result['relevance_score'] >= score_threshold: + if score_threshold is None or result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) return RerankResult(model=model, docs=rerank_documents) except httpx.HTTPStatusError as e: - raise InvokeServerUnavailableError(str(e)) + raise InvokeServerUnavailableError(str(e)) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -92,7 +91,6 @@ class LocalaiRerankModel(RerankModel): :return: """ try: - self._invoke( model=model, credentials=credentials, @@ -103,7 +101,7 @@ class LocalaiRerankModel(RerankModel): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -116,21 +114,21 @@ class LocalaiRerankModel(RerankModel): return { InvokeConnectionError: [httpx.ConnectError], InvokeServerUnavailableError: [httpx.RemoteProtocolError], - InvokeRateLimitError: [], - InvokeAuthorizationError: [httpx.HTTPStatusError], - InvokeBadRequestError: [httpx.RequestError] + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], } - + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, label=I18nObject(en_US=model), model_type=ModelType.RERANK, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={} + model_properties={}, ) return entity diff --git a/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py index d7403aff4f..4b9d0f5bfe 100644 --- a/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py @@ -32,8 +32,8 @@ class LocalAISpeech2text(Speech2TextModel): :param user: unique user id :return: text for given audio file """ - - url = str(URL(credentials['server_url']) / "v1/audio/transcriptions") + + url = str(URL(credentials["server_url"]) / "v1/audio/transcriptions") data = {"model": model} files = {"file": file} @@ -42,7 +42,7 @@ class LocalAISpeech2text(Speech2TextModel): prepared_request = session.prepare_request(request) response = session.send(prepared_request) - if 'error' in response.json(): + if "error" in response.json(): raise InvokeServerUnavailableError("Empty response") return response.json()["text"] @@ -58,7 +58,7 @@ class LocalAISpeech2text(Speech2TextModel): try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -66,36 +66,24 @@ class LocalAISpeech2text(Speech2TextModel): @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError - ], + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.SPEECH2TEXT, model_properties={}, - parameter_rules=[] + parameter_rules=[], ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py index 954c9d10f2..7d258be81e 100644 --- a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py @@ -24,9 +24,10 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): """ Model class for Jina text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,39 +38,33 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): :return: embeddings result """ if len(texts) != 1: - raise InvokeBadRequestError('Only one text is supported') + raise InvokeBadRequestError("Only one text is supported") - server_url = credentials['server_url'] + server_url = credentials["server_url"] model_name = model if not server_url: - raise CredentialsValidateFailedError('server_url is required') + raise CredentialsValidateFailedError("server_url is required") if not model_name: - raise CredentialsValidateFailedError('model_name is required') - - url = server_url - headers = { - 'Authorization': 'Bearer 123', - 'Content-Type': 'application/json' - } + raise CredentialsValidateFailedError("model_name is required") - data = { - 'model': model_name, - 'input': texts[0] - } + url = server_url + headers = {"Authorization": "Bearer 123", "Content-Type": "application/json"} + + data = {"model": model_name, "input": texts[0]} try: - response = post(str(URL(url) / 'embeddings'), headers=headers, data=dumps(data), timeout=10) + response = post(str(URL(url) / "embeddings"), headers=headers, data=dumps(data), timeout=10) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: try: resp = response.json() - code = resp['error']['code'] - msg = resp['error']['message'] + code = resp["error"]["code"] + msg = resp["error"]["message"] if code == 500: raise InvokeServerUnavailableError(msg) - + if response.status_code == 401: raise InvokeAuthorizationError(msg) elif response.status_code == 429: @@ -79,23 +74,21 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): else: raise InvokeError(msg) except JSONDecodeError as e: - raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") + raise InvokeServerUnavailableError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) result = TextEmbeddingResult( - model=model, - embeddings=[[ - float(data) for data in x['embedding'] - ] for x in embeddings], - usage=usage + model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage ) return result @@ -114,7 +107,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): # use GPT2Tokenizer to get num tokens num_tokens += self._get_num_tokens_by_gpt2(text) return num_tokens - + def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ Get customizable model schema @@ -130,10 +123,10 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): features=[], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', '512')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "512")), ModelPropertyKey.MAX_CHUNKS: 1, }, - parameter_rules=[] + parameter_rules=[], ) def validate_credentials(self, model: str, credentials: dict) -> None: @@ -145,32 +138,22 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError: - raise CredentialsValidateFailedError('Invalid credentials') + raise CredentialsValidateFailedError("Invalid credentials") except InvokeConnectionError as e: - raise CredentialsValidateFailedError(f'Invalid credentials: {e}') + raise CredentialsValidateFailedError(f"Invalid credentials: {e}") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -182,10 +165,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -196,7 +176,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py index 6c41e0d2a5..96f99c8929 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py @@ -17,42 +17,48 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage class MinimaxChatCompletion: """ - Minimax Chat Completion API + Minimax Chat Completion API """ - def generate(self, model: str, api_key: str, group_id: str, - prompt_messages: list[MinimaxMessage], model_parameters: dict, - tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \ - -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: + + def generate( + self, + model: str, + api_key: str, + group_id: str, + prompt_messages: list[MinimaxMessage], + model_parameters: dict, + tools: list[dict[str, Any]], + stop: list[str] | None, + stream: bool, + user: str, + ) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: """ - generate chat completion + generate chat completion """ if not api_key or not group_id: - raise InvalidAPIKeyError('Invalid API key or group ID') - - url = f'https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}' + raise InvalidAPIKeyError("Invalid API key or group ID") + + url = f"https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}" extra_kwargs = {} - if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int: - extra_kwargs['tokens_to_generate'] = model_parameters['max_tokens'] + if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: + extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"] - if 'temperature' in model_parameters and type(model_parameters['temperature']) == float: - extra_kwargs['temperature'] = model_parameters['temperature'] + if "temperature" in model_parameters and type(model_parameters["temperature"]) == float: + extra_kwargs["temperature"] = model_parameters["temperature"] - if 'top_p' in model_parameters and type(model_parameters['top_p']) == float: - extra_kwargs['top_p'] = model_parameters['top_p'] + if "top_p" in model_parameters and type(model_parameters["top_p"]) == float: + extra_kwargs["top_p"] = model_parameters["top_p"] - prompt = '你是一个什么都懂的专家' + prompt = "你是一个什么都懂的专家" - role_meta = { - 'user_name': '我', - 'bot_name': '专家' - } + role_meta = {"user_name": "我", "bot_name": "专家"} # check if there is a system message if len(prompt_messages) == 0: - raise BadRequestError('At least one message is required') - + raise BadRequestError("At least one message is required") + if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value: if prompt_messages[0].content: prompt = prompt_messages[0].content @@ -60,40 +66,39 @@ class MinimaxChatCompletion: # check if there is a user message if len(prompt_messages) == 0: - raise BadRequestError('At least one user message is required') - - messages = [{ - 'sender_type': message.role, - 'text': message.content, - } for message in prompt_messages] + raise BadRequestError("At least one user message is required") - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + messages = [ + { + "sender_type": message.role, + "text": message.content, + } + for message in prompt_messages + ] + + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} body = { - 'model': model, - 'messages': messages, - 'prompt': prompt, - 'role_meta': role_meta, - 'stream': stream, - **extra_kwargs + "model": model, + "messages": messages, + "prompt": prompt, + "role_meta": role_meta, + "stream": stream, + **extra_kwargs, } try: - response = post( - url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) + response = post(url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) except Exception as e: raise InternalServerError(e) - + if response.status_code != 200: raise InternalServerError(response.text) - + if stream: return self._handle_stream_chat_generate_response(response) return self._handle_chat_generate_response(response) - + def _handle_error(self, code: int, msg: str): if code == 1000 or code == 1001 or code == 1013 or code == 1027: raise InternalServerError(msg) @@ -110,65 +115,52 @@ class MinimaxChatCompletion: def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage: """ - handle chat generate response + handle chat generate response """ response = response.json() - if 'base_resp' in response and response['base_resp']['status_code'] != 0: - code = response['base_resp']['status_code'] - msg = response['base_resp']['status_msg'] + if "base_resp" in response and response["base_resp"]["status_code"] != 0: + code = response["base_resp"]["status_code"] + msg = response["base_resp"]["status_msg"] self._handle_error(code, msg) - - message = MinimaxMessage( - content=response['reply'], - role=MinimaxMessage.Role.ASSISTANT.value - ) + + message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value) message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': response['usage']['total_tokens'], - 'total_tokens': response['usage']['total_tokens'] + "prompt_tokens": 0, + "completion_tokens": response["usage"]["total_tokens"], + "total_tokens": response["usage"]["total_tokens"], } - message.stop_reason = response['choices'][0]['finish_reason'] + message.stop_reason = response["choices"][0]["finish_reason"] return message def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]: """ - handle stream chat generate response + handle stream chat generate response """ for line in response.iter_lines(): if not line: continue - line: str = line.decode('utf-8') - if line.startswith('data: '): + line: str = line.decode("utf-8") + if line.startswith("data: "): line = line[6:].strip() data = loads(line) - if 'base_resp' in data and data['base_resp']['status_code'] != 0: - code = data['base_resp']['status_code'] - msg = data['base_resp']['status_msg'] + if "base_resp" in data and data["base_resp"]["status_code"] != 0: + code = data["base_resp"]["status_code"] + msg = data["base_resp"]["status_msg"] self._handle_error(code, msg) - if data['reply']: - total_tokens = data['usage']['total_tokens'] - message = MinimaxMessage( - role=MinimaxMessage.Role.ASSISTANT.value, - content='' - ) - message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': total_tokens, - 'total_tokens': total_tokens - } - message.stop_reason = data['choices'][0]['finish_reason'] + if data["reply"]: + total_tokens = data["usage"]["total_tokens"] + message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content="") + message.usage = {"prompt_tokens": 0, "completion_tokens": total_tokens, "total_tokens": total_tokens} + message.stop_reason = data["choices"][0]["finish_reason"] yield message return - choices = data.get('choices', []) + choices = data.get("choices", []) if len(choices) == 0: continue for choice in choices: - message = choice['delta'] - yield MinimaxMessage( - content=message, - role=MinimaxMessage.Role.ASSISTANT.value - ) \ No newline at end of file + message = choice["delta"] + yield MinimaxMessage(content=message, role=MinimaxMessage.Role.ASSISTANT.value) diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py index 55747057c9..0a2a67a56d 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py @@ -17,86 +17,83 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage class MinimaxChatCompletionPro: """ - Minimax Chat Completion Pro API, supports function calling - however, we do not have enough time and energy to implement it, but the parameters are reserved + Minimax Chat Completion Pro API, supports function calling + however, we do not have enough time and energy to implement it, but the parameters are reserved """ - def generate(self, model: str, api_key: str, group_id: str, - prompt_messages: list[MinimaxMessage], model_parameters: dict, - tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \ - -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: + + def generate( + self, + model: str, + api_key: str, + group_id: str, + prompt_messages: list[MinimaxMessage], + model_parameters: dict, + tools: list[dict[str, Any]], + stop: list[str] | None, + stream: bool, + user: str, + ) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: """ - generate chat completion + generate chat completion """ if not api_key or not group_id: - raise InvalidAPIKeyError('Invalid API key or group ID') + raise InvalidAPIKeyError("Invalid API key or group ID") - url = f'https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}' + url = f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}" extra_kwargs = {} - if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int: - extra_kwargs['tokens_to_generate'] = model_parameters['max_tokens'] + if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: + extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"] - if 'temperature' in model_parameters and type(model_parameters['temperature']) == float: - extra_kwargs['temperature'] = model_parameters['temperature'] + if "temperature" in model_parameters and type(model_parameters["temperature"]) == float: + extra_kwargs["temperature"] = model_parameters["temperature"] - if 'top_p' in model_parameters and type(model_parameters['top_p']) == float: - extra_kwargs['top_p'] = model_parameters['top_p'] + if "top_p" in model_parameters and type(model_parameters["top_p"]) == float: + extra_kwargs["top_p"] = model_parameters["top_p"] - if 'mask_sensitive_info' in model_parameters and type(model_parameters['mask_sensitive_info']) == bool: - extra_kwargs['mask_sensitive_info'] = model_parameters['mask_sensitive_info'] - - if model_parameters.get('plugin_web_search'): - extra_kwargs['plugins'] = [ - 'plugin_web_search' - ] + if "mask_sensitive_info" in model_parameters and type(model_parameters["mask_sensitive_info"]) == bool: + extra_kwargs["mask_sensitive_info"] = model_parameters["mask_sensitive_info"] - bot_setting = { - 'bot_name': '专家', - 'content': '你是一个什么都懂的专家' - } + if model_parameters.get("plugin_web_search"): + extra_kwargs["plugins"] = ["plugin_web_search"] - reply_constraints = { - 'sender_type': 'BOT', - 'sender_name': '专家' - } + bot_setting = {"bot_name": "专家", "content": "你是一个什么都懂的专家"} + + reply_constraints = {"sender_type": "BOT", "sender_name": "专家"} # check if there is a system message if len(prompt_messages) == 0: - raise BadRequestError('At least one message is required') + raise BadRequestError("At least one message is required") if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value: if prompt_messages[0].content: - bot_setting['content'] = prompt_messages[0].content + bot_setting["content"] = prompt_messages[0].content prompt_messages = prompt_messages[1:] # check if there is a user message if len(prompt_messages) == 0: - raise BadRequestError('At least one user message is required') + raise BadRequestError("At least one user message is required") messages = [message.to_dict() for message in prompt_messages] - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} body = { - 'model': model, - 'messages': messages, - 'bot_setting': [bot_setting], - 'reply_constraints': reply_constraints, - 'stream': stream, - **extra_kwargs + "model": model, + "messages": messages, + "bot_setting": [bot_setting], + "reply_constraints": reply_constraints, + "stream": stream, + **extra_kwargs, } if tools: - body['functions'] = tools - body['function_call'] = {'type': 'auto'} + body["functions"] = tools + body["function_call"] = {"type": "auto"} try: - response = post( - url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) + response = post(url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) except Exception as e: raise InternalServerError(e) @@ -123,78 +120,72 @@ class MinimaxChatCompletionPro: def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage: """ - handle chat generate response + handle chat generate response """ response = response.json() - if 'base_resp' in response and response['base_resp']['status_code'] != 0: - code = response['base_resp']['status_code'] - msg = response['base_resp']['status_msg'] + if "base_resp" in response and response["base_resp"]["status_code"] != 0: + code = response["base_resp"]["status_code"] + msg = response["base_resp"]["status_msg"] self._handle_error(code, msg) - message = MinimaxMessage( - content=response['reply'], - role=MinimaxMessage.Role.ASSISTANT.value - ) + message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value) message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': response['usage']['total_tokens'], - 'total_tokens': response['usage']['total_tokens'] + "prompt_tokens": 0, + "completion_tokens": response["usage"]["total_tokens"], + "total_tokens": response["usage"]["total_tokens"], } - message.stop_reason = response['choices'][0]['finish_reason'] + message.stop_reason = response["choices"][0]["finish_reason"] return message def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]: """ - handle stream chat generate response + handle stream chat generate response """ for line in response.iter_lines(): if not line: continue - line: str = line.decode('utf-8') - if line.startswith('data: '): + line: str = line.decode("utf-8") + if line.startswith("data: "): line = line[6:].strip() data = loads(line) - if 'base_resp' in data and data['base_resp']['status_code'] != 0: - code = data['base_resp']['status_code'] - msg = data['base_resp']['status_msg'] + if "base_resp" in data and data["base_resp"]["status_code"] != 0: + code = data["base_resp"]["status_code"] + msg = data["base_resp"]["status_msg"] self._handle_error(code, msg) # final chunk - if data['reply'] or data.get('usage'): - total_tokens = data['usage']['total_tokens'] - minimax_message = MinimaxMessage( - role=MinimaxMessage.Role.ASSISTANT.value, - content='' - ) + if data["reply"] or data.get("usage"): + total_tokens = data["usage"]["total_tokens"] + minimax_message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content="") minimax_message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': total_tokens, - 'total_tokens': total_tokens + "prompt_tokens": 0, + "completion_tokens": total_tokens, + "total_tokens": total_tokens, } - minimax_message.stop_reason = data['choices'][0]['finish_reason'] + minimax_message.stop_reason = data["choices"][0]["finish_reason"] - choices = data.get('choices', []) + choices = data.get("choices", []) if len(choices) > 0: for choice in choices: - message = choice['messages'][0] + message = choice["messages"][0] # append function_call message - if 'function_call' in message: - function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value) - function_call_message.function_call = message['function_call'] + if "function_call" in message: + function_call_message = MinimaxMessage(content="", role=MinimaxMessage.Role.ASSISTANT.value) + function_call_message.function_call = message["function_call"] yield function_call_message yield minimax_message return # partial chunk - choices = data.get('choices', []) + choices = data.get("choices", []) if len(choices) == 0: continue for choice in choices: - message = choice['messages'][0] + message = choice["messages"][0] # append text message - if 'text' in message: - minimax_message = MinimaxMessage(content=message['text'], role=MinimaxMessage.Role.ASSISTANT.value) + if "text" in message: + minimax_message = MinimaxMessage(content=message["text"], role=MinimaxMessage.Role.ASSISTANT.value) yield minimax_message diff --git a/api/core/model_runtime/model_providers/minimax/llm/errors.py b/api/core/model_runtime/model_providers/minimax/llm/errors.py index d9d279e6ca..309b5cf413 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/errors.py +++ b/api/core/model_runtime/model_providers/minimax/llm/errors.py @@ -1,17 +1,22 @@ class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass + class InsufficientAccountBalanceError(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/minimax/llm/llm.py b/api/core/model_runtime/model_providers/minimax/llm/llm.py index feeba75f49..76ed704a75 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/llm.py +++ b/api/core/model_runtime/model_providers/minimax/llm/llm.py @@ -34,18 +34,25 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage class MinimaxLargeLanguageModel(LargeLanguageModel): model_apis = { - 'abab6.5s-chat': MinimaxChatCompletionPro, - 'abab6.5-chat': MinimaxChatCompletionPro, - 'abab6-chat': MinimaxChatCompletionPro, - 'abab5.5s-chat': MinimaxChatCompletionPro, - 'abab5.5-chat': MinimaxChatCompletionPro, - 'abab5-chat': MinimaxChatCompletion + "abab6.5s-chat": MinimaxChatCompletionPro, + "abab6.5-chat": MinimaxChatCompletionPro, + "abab6-chat": MinimaxChatCompletionPro, + "abab5.5s-chat": MinimaxChatCompletionPro, + "abab5.5-chat": MinimaxChatCompletionPro, + "abab5-chat": MinimaxChatCompletion, } - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def validate_credentials(self, model: str, credentials: dict) -> None: @@ -53,82 +60,97 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): Validate credentials for Baichuan model """ if model not in self.model_apis: - raise CredentialsValidateFailedError(f'Invalid model: {model}') + raise CredentialsValidateFailedError(f"Invalid model: {model}") - if not credentials.get('minimax_api_key'): - raise CredentialsValidateFailedError('Invalid API key') + if not credentials.get("minimax_api_key"): + raise CredentialsValidateFailedError("Invalid API key") + + if not credentials.get("minimax_group_id"): + raise CredentialsValidateFailedError("Invalid group ID") - if not credentials.get('minimax_group_id'): - raise CredentialsValidateFailedError('Invalid group ID') - # ping instance = MinimaxChatCompletionPro() try: instance.generate( - model=model, api_key=credentials['minimax_api_key'], group_id=credentials['minimax_group_id'], - prompt_messages=[ - MinimaxMessage(content='ping', role='USER') - ], + model=model, + api_key=credentials["minimax_api_key"], + group_id=credentials["minimax_group_id"], + prompt_messages=[MinimaxMessage(content="ping", role="USER")], model_parameters={}, - tools=[], stop=[], + tools=[], + stop=[], stream=False, - user='' + user="", ) except (InvalidAuthenticationError, InsufficientAccountBalanceError) as e: raise CredentialsValidateFailedError(f"Invalid API key: {e}") - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: return self._num_tokens_from_messages(prompt_messages, tools) def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: """ - Calculate num tokens for minimax model + Calculate num tokens for minimax model - not like ChatGLM, Minimax has a special prompt structure, we could not find a proper way - to calculate the num tokens, so we use str() to convert the prompt to string + not like ChatGLM, Minimax has a special prompt structure, we could not find a proper way + to calculate the num tokens, so we use str() to convert the prompt to string - Minimax does not provide their own tokenizer of adab5.5 and abab5 model - therefore, we use gpt2 tokenizer instead + Minimax does not provide their own tokenizer of adab5.5 and abab5 model + therefore, we use gpt2 tokenizer instead """ messages_dict = [self._convert_prompt_message_to_minimax_message(m).to_dict() for m in messages] return self._get_num_tokens_by_gpt2(str(messages_dict)) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface + use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface """ client: MinimaxChatCompletionPro = self.model_apis[model]() if tools: - tools = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] + tools = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools + ] response = client.generate( model=model, - api_key=credentials['minimax_api_key'], - group_id=credentials['minimax_group_id'], + api_key=credentials["minimax_api_key"], + group_id=credentials["minimax_group_id"], prompt_messages=[self._convert_prompt_message_to_minimax_message(message) for message in prompt_messages], model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, - user=user + user=user, ) if stream: - return self._handle_chat_generate_stream_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) - return self._handle_chat_generate_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) + return self._handle_chat_generate_stream_response( + model=model, prompt_messages=prompt_messages, credentials=credentials, response=response + ) + return self._handle_chat_generate_response( + model=model, prompt_messages=prompt_messages, credentials=credentials, response=response + ) def _convert_prompt_message_to_minimax_message(self, prompt_message: PromptMessage) -> MinimaxMessage: """ - convert PromptMessage to MinimaxMessage so that we can use MinimaxChatCompletionPro interface + convert PromptMessage to MinimaxMessage so that we can use MinimaxChatCompletionPro interface """ if isinstance(prompt_message, SystemPromptMessage): return MinimaxMessage(role=MinimaxMessage.Role.SYSTEM.value, content=prompt_message.content) @@ -136,26 +158,27 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content) elif isinstance(prompt_message, AssistantPromptMessage): if prompt_message.tool_calls: - message = MinimaxMessage( - role=MinimaxMessage.Role.ASSISTANT.value, - content='' - ) - message.function_call={ - 'name': prompt_message.tool_calls[0].function.name, - 'arguments': prompt_message.tool_calls[0].function.arguments + message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content="") + message.function_call = { + "name": prompt_message.tool_calls[0].function.name, + "arguments": prompt_message.tool_calls[0].function.arguments, } return message return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content) elif isinstance(prompt_message, ToolPromptMessage): return MinimaxMessage(role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content) else: - raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported') + raise NotImplementedError(f"Prompt message type {type(prompt_message)} is not supported") - def _handle_chat_generate_response(self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: MinimaxMessage) -> LLMResult: - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=response.usage['prompt_tokens'], - completion_tokens=response.usage['completion_tokens'] - ) + def _handle_chat_generate_response( + self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: MinimaxMessage + ) -> LLMResult: + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=response.usage["prompt_tokens"], + completion_tokens=response.usage["completion_tokens"], + ) return LLMResult( model=model, prompt_messages=prompt_messages, @@ -166,31 +189,33 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, prompt_messages: list[PromptMessage], - credentials: dict, response: Generator[MinimaxMessage, None, None]) \ - -> Generator[LLMResultChunk, None, None]: + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Generator[MinimaxMessage, None, None], + ) -> Generator[LLMResultChunk, None, None]: for message in response: if message.usage: usage = self._calc_response_usage( - model=model, credentials=credentials, - prompt_tokens=message.usage['prompt_tokens'], - completion_tokens=message.usage['completion_tokens'] + model=model, + credentials=credentials, + prompt_tokens=message.usage["prompt_tokens"], + completion_tokens=message.usage["completion_tokens"], ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), usage=usage, finish_reason=message.stop_reason if message.stop_reason else None, ), ) elif message.function_call: - if 'name' not in message.function_call or 'arguments' not in message.function_call: + if "name" not in message.function_call or "arguments" not in message.function_call: continue yield LLMResultChunk( @@ -199,15 +224,16 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=0, message=AssistantPromptMessage( - content='', - tool_calls=[AssistantPromptMessage.ToolCall( - id='', - type='function', - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=message.function_call['name'], - arguments=message.function_call['arguments'] + content="", + tool_calls=[ + AssistantPromptMessage.ToolCall( + id="", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=message.function_call["name"], arguments=message.function_call["arguments"] + ), ) - )] + ], ), ), ) @@ -217,10 +243,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), finish_reason=message.stop_reason if message.stop_reason else None, ), ) @@ -236,22 +259,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - diff --git a/api/core/model_runtime/model_providers/minimax/llm/types.py b/api/core/model_runtime/model_providers/minimax/llm/types.py index b33a7ca9ac..88ebe5e2e0 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/types.py +++ b/api/core/model_runtime/model_providers/minimax/llm/types.py @@ -4,32 +4,27 @@ from typing import Any class MinimaxMessage: class Role(Enum): - USER = 'USER' - ASSISTANT = 'BOT' - SYSTEM = 'SYSTEM' - FUNCTION = 'FUNCTION' + USER = "USER" + ASSISTANT = "BOT" + SYSTEM = "SYSTEM" + FUNCTION = "FUNCTION" role: str = Role.USER.value content: str usage: dict[str, int] = None - stop_reason: str = '' + stop_reason: str = "" function_call: dict[str, Any] = None def to_dict(self) -> dict[str, Any]: if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value: - return { - 'sender_type': 'BOT', - 'sender_name': '专家', - 'text': '', - 'function_call': self.function_call - } - + return {"sender_type": "BOT", "sender_name": "专家", "text": "", "function_call": self.function_call} + return { - 'sender_type': self.role, - 'sender_name': '我' if self.role == 'USER' else '专家', - 'text': self.content, + "sender_type": self.role, + "sender_name": "我" if self.role == "USER" else "专家", + "text": self.content, } - - def __init__(self, content: str, role: str = 'USER') -> None: + + def __init__(self, content: str, role: str = "USER") -> None: self.content = content - self.role = role \ No newline at end of file + self.role = role diff --git a/api/core/model_runtime/model_providers/minimax/minimax.py b/api/core/model_runtime/model_providers/minimax/minimax.py index 52f6c2f1d3..5a761903a1 100644 --- a/api/core/model_runtime/model_providers/minimax/minimax.py +++ b/api/core/model_runtime/model_providers/minimax/minimax.py @@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) + class MinimaxProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -19,12 +20,9 @@ class MinimaxProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `abab5.5-chat` model for validate, - model_instance.validate_credentials( - model='abab5.5-chat', - credentials=credentials - ) + model_instance.validate_credentials(model="abab5.5-chat", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') - raise CredentialsValidateFailedError(f'{ex}') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise CredentialsValidateFailedError(f"{ex}") diff --git a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py index 85dc6ef51d..02a53708be 100644 --- a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py @@ -30,11 +30,12 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): """ Model class for Minimax text embedding model. """ - api_base: str = 'https://api.minimax.chat/v1/embeddings' - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "https://api.minimax.chat/v1/embeddings" + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -44,54 +45,43 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - api_key = credentials['minimax_api_key'] - group_id = credentials['minimax_group_id'] - if model != 'embo-01': - raise ValueError('Invalid model name') + api_key = credentials["minimax_api_key"] + group_id = credentials["minimax_group_id"] + if model != "embo-01": + raise ValueError("Invalid model name") if not api_key: - raise CredentialsValidateFailedError('api_key is required') - url = f'{self.api_base}?GroupId={group_id}' - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + raise CredentialsValidateFailedError("api_key is required") + url = f"{self.api_base}?GroupId={group_id}" + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} - data = { - 'model': 'embo-01', - 'texts': texts, - 'type': 'db' - } + data = {"model": "embo-01", "texts": texts, "type": "db"} try: response = post(url, headers=headers, data=dumps(data)) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: raise InvokeServerUnavailableError(response.text) - + try: resp = response.json() # check if there is an error - if resp['base_resp']['status_code'] != 0: - code = resp['base_resp']['status_code'] - msg = resp['base_resp']['status_msg'] + if resp["base_resp"]["status_code"] != 0: + code = resp["base_resp"]["status_code"] + msg = resp["base_resp"]["status_msg"] self._handle_error(code, msg) - embeddings = resp['vectors'] - total_tokens = resp['total_tokens'] + embeddings = resp["vectors"] + total_tokens = resp["total_tokens"] except InvalidAuthenticationError: - raise InvalidAPIKeyError('Invalid api key') + raise InvalidAPIKeyError("Invalid api key") except KeyError as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") usage = self._calc_response_usage(model=model, credentials=credentials, tokens=total_tokens) - result = TextEmbeddingResult( - model=model, - embeddings=embeddings, - usage=usage - ) + result = TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage) return result @@ -119,9 +109,9 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvalidAPIKeyError: - raise CredentialsValidateFailedError('Invalid api key') + raise CredentialsValidateFailedError("Invalid api key") def _handle_error(self, code: int, msg: str): if code == 1000 or code == 1001: @@ -148,25 +138,17 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -178,10 +160,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -192,7 +171,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/mistralai/llm/llm.py b/api/core/model_runtime/model_providers/mistralai/llm/llm.py index 01ed8010de..da60bd7661 100644 --- a/api/core/model_runtime/model_providers/mistralai/llm/llm.py +++ b/api/core/model_runtime/model_providers/mistralai/llm/llm.py @@ -7,14 +7,19 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: - + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) - + # mistral dose not support user/stop arguments stop = [] user = None @@ -27,5 +32,5 @@ class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel): @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.mistral.ai/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.mistral.ai/v1" diff --git a/api/core/model_runtime/model_providers/mistralai/mistralai.py b/api/core/model_runtime/model_providers/mistralai/mistralai.py index f1d825f6c6..7f9db8da1c 100644 --- a/api/core/model_runtime/model_providers/mistralai/mistralai.py +++ b/api/core/model_runtime/model_providers/mistralai/mistralai.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class MistralAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ class MistralAIProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='open-mistral-7b', - credentials=credentials - ) + model_instance.validate_credentials(model="open-mistral-7b", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/moonshot/llm/llm.py b/api/core/model_runtime/model_providers/moonshot/llm/llm.py index c233596637..3ea46c2967 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/llm.py +++ b/api/core/model_runtime/model_providers/moonshot/llm/llm.py @@ -30,11 +30,17 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) self._add_function_call(model, credentials) user = user[:32] if user else None @@ -49,50 +55,50 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): model=model, label=I18nObject(en_US=model, zh_Hans=model), model_type=ModelType.LLM, - features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] - if credentials.get('function_calling_type') == 'tool_call' - else [], + features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] + if credentials.get("function_calling_type") == "tool_call" + else [], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)), ModelPropertyKey.MODE: LLMMode.CHAT.value, }, parameter_rules=[ ParameterRule( - name='temperature', - use_template='temperature', - label=I18nObject(en_US='Temperature', zh_Hans='温度'), + name="temperature", + use_template="temperature", + label=I18nObject(en_US="Temperature", zh_Hans="温度"), type=ParameterType.FLOAT, ), ParameterRule( - name='max_tokens', - use_template='max_tokens', + name="max_tokens", + use_template="max_tokens", default=512, min=1, - max=int(credentials.get('max_tokens', 4096)), - label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'), + max=int(credentials.get("max_tokens", 4096)), + label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"), type=ParameterType.INT, ), ParameterRule( - name='top_p', - use_template='top_p', - label=I18nObject(en_US='Top P', zh_Hans='Top P'), + name="top_p", + use_template="top_p", + label=I18nObject(en_US="Top P", zh_Hans="Top P"), type=ParameterType.FLOAT, ), - ] + ], ) def _add_custom_parameters(self, credentials: dict) -> None: - credentials['mode'] = 'chat' - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - credentials['endpoint_url'] = 'https://api.moonshot.cn/v1' + credentials["mode"] = "chat" + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + credentials["endpoint_url"] = "https://api.moonshot.cn/v1" def _add_function_call(self, model: str, credentials: dict) -> None: model_schema = self.get_model_schema(model, credentials) - if model_schema and { - ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL - }.intersection(model_schema.features or []): - credentials['function_calling_type'] = 'tool_call' + if model_schema and {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection( + model_schema.features or [] + ): + credentials["function_calling_type"] = "tool_call" def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict: """ @@ -107,19 +113,13 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -129,14 +129,16 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): if message.tool_calls: message_dict["tool_calls"] = [] for function_call in message.tool_calls: - message_dict["tool_calls"].append({ - "id": function_call.id, - "type": function_call.type, - "function": { - "name": function_call.function.name, - "arguments": function_call.function.arguments + message_dict["tool_calls"].append( + { + "id": function_call.id, + "type": function_call.type, + "function": { + "name": function_call.function.name, + "arguments": function_call.function.arguments, + }, } - }) + ) elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} @@ -162,21 +164,26 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "", - arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else "" + name=response_tool_call["function"]["name"] + if response_tool_call.get("function", {}).get("name") + else "", + arguments=response_tool_call["function"]["arguments"] + if response_tool_call.get("function", {}).get("arguments") + else "", ) tool_call = AssistantPromptMessage.ToolCall( id=response_tool_call["id"] if response_tool_call.get("id") else "", type=response_tool_call["type"] if response_tool_call.get("type") else "", - function=function + function=function, ) tool_calls.append(tool_call) return tool_calls - def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -186,11 +193,12 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" chunk_index = 0 - def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \ - -> LLMResultChunk: + def create_final_llm_result_chunk( + index: int, message: AssistantPromptMessage, finish_reason: str + ) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) completion_tokens = self._num_tokens_from_string(model, full_assistant_content) @@ -201,12 +209,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): return LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=message, - finish_reason=finish_reason, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage), ) tools_calls: list[AssistantPromptMessage.ToolCall] = [] @@ -220,9 +223,9 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None) if tool_call is None: tool_call = AssistantPromptMessage.ToolCall( - id='', - type='', - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="") + id="", + type="", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""), ) tools_calls.append(tool_call) @@ -244,9 +247,9 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"): if chunk: # ignore sse comments - if chunk.startswith(':'): + if chunk.startswith(":"): continue - decoded_chunk = chunk.strip().lstrip('data: ').lstrip() + decoded_chunk = chunk.strip().lstrip("data: ").lstrip() chunk_json = None try: chunk_json = json.loads(decoded_chunk) @@ -255,21 +258,21 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): yield create_final_llm_result_chunk( index=chunk_index + 1, message=AssistantPromptMessage(content=""), - finish_reason="Non-JSON encountered." + finish_reason="Non-JSON encountered.", ) break - if not chunk_json or len(chunk_json['choices']) == 0: + if not chunk_json or len(chunk_json["choices"]) == 0: continue - choice = chunk_json['choices'][0] - finish_reason = chunk_json['choices'][0].get('finish_reason') + choice = chunk_json["choices"][0] + finish_reason = chunk_json["choices"][0].get("finish_reason") chunk_index += 1 - if 'delta' in choice: - delta = choice['delta'] - delta_content = delta.get('content') + if "delta" in choice: + delta = choice["delta"] + delta_content = delta.get("content") - assistant_message_tool_calls = delta.get('tool_calls', None) + assistant_message_tool_calls = delta.get("tool_calls", None) # assistant_message_function_call = delta.delta.function_call # extract tool calls from response @@ -277,19 +280,18 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) increase_tool_call(tool_calls) - if delta_content is None or delta_content == '': + if delta_content is None or delta_content == "": continue # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta_content, - tool_calls=tool_calls if assistant_message_tool_calls else [] + content=delta_content, tool_calls=tool_calls if assistant_message_tool_calls else [] ) full_assistant_content += delta_content - elif 'text' in choice: - choice_text = choice.get('text', '') - if choice_text == '': + elif "text" in choice: + choice_text = choice.get("text", "") + if choice_text == "": continue # transform assistant message to prompt message @@ -305,26 +307,21 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): delta=LLMResultChunkDelta( index=chunk_index, message=assistant_prompt_message, - ) + ), ) chunk_index += 1 - + if tools_calls: yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=chunk_index, - message=AssistantPromptMessage( - tool_calls=tools_calls, - content="" - ), - ) + message=AssistantPromptMessage(tool_calls=tools_calls, content=""), + ), ) yield create_final_llm_result_chunk( - index=chunk_index, - message=AssistantPromptMessage(content=""), - finish_reason=finish_reason - ) \ No newline at end of file + index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason + ) diff --git a/api/core/model_runtime/model_providers/moonshot/moonshot.py b/api/core/model_runtime/model_providers/moonshot/moonshot.py index 5654ae1459..4995e235f5 100644 --- a/api/core/model_runtime/model_providers/moonshot/moonshot.py +++ b/api/core/model_runtime/model_providers/moonshot/moonshot.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class MoonshotProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ class MoonshotProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='moonshot-v1-8k', - credentials=credentials - ) + model_instance.validate_credentials(model="moonshot-v1-8k", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/novita/llm/llm.py b/api/core/model_runtime/model_providers/novita/llm/llm.py index 7662bf914a..23367ed1b4 100644 --- a/api/core/model_runtime/model_providers/novita/llm/llm.py +++ b/api/core/model_runtime/model_providers/novita/llm/llm.py @@ -8,20 +8,25 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class NovitaLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _update_endpoint_url(self, credentials: dict): - - credentials['endpoint_url'] = "https://api.novita.ai/v3/openai" - credentials['extra_headers'] = { 'X-Novita-Source': 'dify.ai' } + credentials["endpoint_url"] = "https://api.novita.ai/v3/openai" + credentials["extra_headers"] = {"X-Novita-Source": "dify.ai"} return credentials - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) + def validate_credentials(self, model: str, credentials: dict) -> None: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) self._add_custom_parameters(credentials, model) @@ -29,21 +34,36 @@ class NovitaLargeLanguageModel(OAIAPICompatLargeLanguageModel): @classmethod def _add_custom_parameters(cls, credentials: dict, model: str) -> None: - credentials['mode'] = 'chat' + credentials["mode"] = "chat" - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) - return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) + return super()._generate( + model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user + ) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super().get_customizable_model_schema(model, cred_with_endpoint) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools) diff --git a/api/core/model_runtime/model_providers/novita/novita.py b/api/core/model_runtime/model_providers/novita/novita.py index f1b7224605..76a75b01e2 100644 --- a/api/core/model_runtime/model_providers/novita/novita.py +++ b/api/core/model_runtime/model_providers/novita/novita.py @@ -20,12 +20,9 @@ class NovitaProvider(ModelProvider): # Use `meta-llama/llama-3-8b-instruct` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='meta-llama/llama-3-8b-instruct', - credentials=credentials - ) + model_instance.validate_credentials(model="meta-llama/llama-3-8b-instruct", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/nvidia/llm/llm.py b/api/core/model_runtime/model_providers/nvidia/llm/llm.py index bc42eaca65..4d3747dc84 100644 --- a/api/core/model_runtime/model_providers/nvidia/llm/llm.py +++ b/api/core/model_runtime/model_providers/nvidia/llm/llm.py @@ -21,31 +21,36 @@ from core.model_runtime.utils import helper class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): MODEL_SUFFIX_MAP = { - 'fuyu-8b': 'vlm/adept/fuyu-8b', - 'mistralai/mistral-large': '', - 'mistralai/mixtral-8x7b-instruct-v0.1': '', - 'mistralai/mixtral-8x22b-instruct-v0.1': '', - 'google/gemma-7b': '', - 'google/codegemma-7b': '', - 'snowflake/arctic':'', - 'meta/llama2-70b': '', - 'meta/llama3-8b-instruct': '', - 'meta/llama3-70b-instruct': '', - 'meta/llama-3.1-8b-instruct': '', - 'meta/llama-3.1-70b-instruct': '', - 'meta/llama-3.1-405b-instruct': '', - 'google/recurrentgemma-2b': '', - 'nvidia/nemotron-4-340b-instruct': '', - 'microsoft/phi-3-medium-128k-instruct':'', - 'microsoft/phi-3-mini-128k-instruct':'' + "fuyu-8b": "vlm/adept/fuyu-8b", + "mistralai/mistral-large": "", + "mistralai/mixtral-8x7b-instruct-v0.1": "", + "mistralai/mixtral-8x22b-instruct-v0.1": "", + "google/gemma-7b": "", + "google/codegemma-7b": "", + "snowflake/arctic": "", + "meta/llama2-70b": "", + "meta/llama3-8b-instruct": "", + "meta/llama3-70b-instruct": "", + "meta/llama-3.1-8b-instruct": "", + "meta/llama-3.1-70b-instruct": "", + "meta/llama-3.1-405b-instruct": "", + "google/recurrentgemma-2b": "", + "nvidia/nemotron-4-340b-instruct": "", + "microsoft/phi-3-medium-128k-instruct": "", + "microsoft/phi-3-mini-128k-instruct": "", } - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: - + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials, model) prompt_messages = self._transform_prompt_messages(prompt_messages) stop = [] @@ -60,16 +65,14 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): for i, p in enumerate(prompt_messages): if isinstance(p, UserPromptMessage) and isinstance(p.content, list): content = p.content - content_text = '' + content_text = "" for prompt_content in content: if prompt_content.type == PromptMessageContentType.TEXT: content_text += prompt_content.data else: content_text += f' ' - prompt_message = UserPromptMessage( - content=content_text - ) + prompt_message = UserPromptMessage(content=content_text) prompt_messages[i] = prompt_message return prompt_messages @@ -78,15 +81,15 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): self._validate_credentials(model, credentials) def _add_custom_parameters(self, credentials: dict, model: str) -> None: - credentials['mode'] = 'chat' - - if self.MODEL_SUFFIX_MAP[model]: - credentials['server_url'] = f'https://ai.api.nvidia.com/v1/{self.MODEL_SUFFIX_MAP[model]}' - credentials.pop('endpoint_url') - else: - credentials['endpoint_url'] = 'https://integrate.api.nvidia.com/v1' + credentials["mode"] = "chat" - credentials['stream_mode_delimiter'] = '\n' + if self.MODEL_SUFFIX_MAP[model]: + credentials["server_url"] = f"https://ai.api.nvidia.com/v1/{self.MODEL_SUFFIX_MAP[model]}" + credentials.pop("endpoint_url") + else: + credentials["endpoint_url"] = "https://integrate.api.nvidia.com/v1" + + credentials["stream_mode_delimiter"] = "\n" def _validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -97,72 +100,67 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials.get('endpoint_url') - if endpoint_url and not endpoint_url.endswith('/'): - endpoint_url += '/' - server_url = credentials.get('server_url') + endpoint_url = credentials.get("endpoint_url") + if endpoint_url and not endpoint_url.endswith("/"): + endpoint_url += "/" + server_url = credentials.get("server_url") # prepare the payload for a simple ping to the model - data = { - 'model': model, - 'max_tokens': 5 - } + data = {"model": model, "max_tokens": 5} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - data['messages'] = [ - { - "role": "user", - "content": "ping" - }, + data["messages"] = [ + {"role": "user", "content": "ping"}, ] - if 'endpoint_url' in credentials: - endpoint_url = str(URL(endpoint_url) / 'chat' / 'completions') - elif 'server_url' in credentials: + if "endpoint_url" in credentials: + endpoint_url = str(URL(endpoint_url) / "chat" / "completions") + elif "server_url" in credentials: endpoint_url = server_url elif completion_type is LLMMode.COMPLETION: - data['prompt'] = 'ping' - if 'endpoint_url' in credentials: - endpoint_url = str(URL(endpoint_url) / 'completions') - elif 'server_url' in credentials: + data["prompt"] = "ping" + if "endpoint_url" in credentials: + endpoint_url = str(URL(endpoint_url) / "completions") + elif "server_url" in credentials: endpoint_url = server_url else: raise ValueError("Unsupported completion type for model configuration.") # send a post request to validate the credentials - response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 300) - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300)) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") except CredentialsValidateFailedError: raise except Exception as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}") - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, \ - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -176,57 +174,51 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): :return: full response or stream response chunk generator result """ headers = { - 'Content-Type': 'application/json', - 'Accept-Charset': 'utf-8', + "Content-Type": "application/json", + "Accept-Charset": "utf-8", } - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: - headers['Authorization'] = f'Bearer {api_key}' + headers["Authorization"] = f"Bearer {api_key}" if stream: - headers['Accept'] = 'text/event-stream' + headers["Accept"] = "text/event-stream" - endpoint_url = credentials.get('endpoint_url') - if endpoint_url and not endpoint_url.endswith('/'): - endpoint_url += '/' - server_url = credentials.get('server_url') + endpoint_url = credentials.get("endpoint_url") + if endpoint_url and not endpoint_url.endswith("/"): + endpoint_url += "/" + server_url = credentials.get("server_url") - data = { - "model": model, - "stream": stream, - **model_parameters - } + data = {"model": model, "stream": stream, **model_parameters} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - if 'endpoint_url' in credentials: - endpoint_url = str(URL(endpoint_url) / 'chat' / 'completions') - elif 'server_url' in credentials: + if "endpoint_url" in credentials: + endpoint_url = str(URL(endpoint_url) / "chat" / "completions") + elif "server_url" in credentials: endpoint_url = server_url - data['messages'] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages] + data["messages"] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages] elif completion_type is LLMMode.COMPLETION: - data['prompt'] = 'ping' - if 'endpoint_url' in credentials: - endpoint_url = str(URL(endpoint_url) / 'completions') - elif 'server_url' in credentials: + data["prompt"] = "ping" + if "endpoint_url" in credentials: + endpoint_url = str(URL(endpoint_url) / "completions") + elif "server_url" in credentials: endpoint_url = server_url else: raise ValueError("Unsupported completion type for model configuration.") - # annotate tools with names, descriptions, etc. - function_calling_type = credentials.get('function_calling_type', 'no_call') + function_calling_type = credentials.get("function_calling_type", "no_call") formatted_tools = [] if tools: - if function_calling_type == 'function_call': - data['functions'] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] - elif function_calling_type == 'tool_call': + if function_calling_type == "function_call": + data["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} + for tool in tools + ] + elif function_calling_type == "tool_call": data["tool_choice"] = "auto" for tool in tools: @@ -240,16 +232,10 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): if user: data["user"] = user - response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 300), - stream=stream - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream) - if response.encoding is None or response.encoding == 'ISO-8859-1': - response.encoding = 'utf-8' + if response.encoding is None or response.encoding == "ISO-8859-1": + response.encoding = "utf-8" if not response.ok: raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}") diff --git a/api/core/model_runtime/model_providers/nvidia/nvidia.py b/api/core/model_runtime/model_providers/nvidia/nvidia.py index e83f8badb5..058fa00346 100644 --- a/api/core/model_runtime/model_providers/nvidia/nvidia.py +++ b/api/core/model_runtime/model_providers/nvidia/nvidia.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class MistralAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ class MistralAIProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='mistralai/mixtral-8x7b-instruct-v0.1', - credentials=credentials - ) + model_instance.validate_credentials(model="mistralai/mixtral-8x7b-instruct-v0.1", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py b/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py index 80c24b0555..fabebc67ab 100644 --- a/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py @@ -22,11 +22,18 @@ class NvidiaRerankModel(RerankModel): """ def _sigmoid(self, logit: float) -> float: - return 1/(1+exp(-logit)) + return 1 / (1 + exp(-logit)) - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -60,9 +67,9 @@ class NvidiaRerankModel(RerankModel): results = response.json() rerank_documents = [] - for result in results['rankings']: - index = result['index'] - logit = result['logit'] + for result in results["rankings"]: + index = result["index"] + logit = result["logit"] rerank_document = RerankDocument( index=index, text=docs[index], @@ -110,5 +117,5 @@ class NvidiaRerankModel(RerankModel): InvokeServerUnavailableError: [requests.HTTPError], InvokeRateLimitError: [], InvokeAuthorizationError: [requests.HTTPError], - InvokeBadRequestError: [requests.RequestException] + InvokeBadRequestError: [requests.RequestException], } diff --git a/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py index a2adef400d..00cec265d5 100644 --- a/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py @@ -22,12 +22,13 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel): """ Model class for Nvidia text embedding model. """ - api_base: str = 'https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings' - models: list[str] = ['NV-Embed-QA'] - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings" + models: list[str] = ["NV-Embed-QA"] + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,32 +38,25 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - api_key = credentials['api_key'] + api_key = credentials["api_key"] if model not in self.models: - raise InvokeBadRequestError('Invalid model name') + raise InvokeBadRequestError("Invalid model name") if not api_key: - raise CredentialsValidateFailedError('api_key is required') + raise CredentialsValidateFailedError("api_key is required") url = self.api_base - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} - data = { - 'model': model, - 'input': texts[0], - 'input_type': 'query' - } + data = {"model": model, "input": texts[0], "input_type": "query"} try: response = post(url, headers=headers, data=dumps(data)) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: try: resp = response.json() - msg = resp['detail'] + msg = resp["detail"] if response.status_code == 401: raise InvokeAuthorizationError(msg) elif response.status_code == 429: @@ -72,23 +66,21 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel): else: raise InvokeError(msg) except JSONDecodeError as e: - raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") + raise InvokeServerUnavailableError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) result = TextEmbeddingResult( - model=model, - embeddings=[[ - float(data) for data in x['embedding'] - ] for x in embeddings], - usage=usage + model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage ) return result @@ -117,30 +109,20 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError: - raise CredentialsValidateFailedError('Invalid api key') + raise CredentialsValidateFailedError("Invalid api key") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -152,10 +134,7 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -166,7 +145,7 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py b/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py index f7b849fbe2..6ff380bdd9 100644 --- a/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py +++ b/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py @@ -9,4 +9,5 @@ class NVIDIANIMProvider(OAIAPICompatLargeLanguageModel): """ Model class for NVIDIA NIM large language model. """ + pass diff --git a/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py b/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py index 25ab3e8e20..ad890ada22 100644 --- a/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py +++ b/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class NVIDIANIMProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/oci/llm/llm.py b/api/core/model_runtime/model_providers/oci/llm/llm.py index 37787c459d..ad5197a154 100644 --- a/api/core/model_runtime/model_providers/oci/llm/llm.py +++ b/api/core/model_runtime/model_providers/oci/llm/llm.py @@ -33,31 +33,29 @@ logger = logging.getLogger(__name__) request_template = { "compartmentId": "", - "servingMode": { - "modelId": "cohere.command-r-plus", - "servingType": "ON_DEMAND" - }, + "servingMode": {"modelId": "cohere.command-r-plus", "servingType": "ON_DEMAND"}, "chatRequest": { "apiFormat": "COHERE", - #"preambleOverride": "You are a helpful assistant.", - #"message": "Hello!", - #"chatHistory": [], + # "preambleOverride": "You are a helpful assistant.", + # "message": "Hello!", + # "chatHistory": [], "maxTokens": 600, "isStream": False, "frequencyPenalty": 0, "presencePenalty": 0, "temperature": 1, - "topP": 0.75 - } + "topP": 0.75, + }, } oci_config_template = { - "user": "", - "fingerprint": "", - "tenancy": "", - "region": "", - "compartment_id": "", - "key_content": "" - } + "user": "", + "fingerprint": "", + "tenancy": "", + "region": "", + "compartment_id": "", + "key_content": "", +} + class OCILargeLanguageModel(LargeLanguageModel): # https://docs.oracle.com/en-us/iaas/Content/generative-ai/pretrained-models.htm @@ -100,11 +98,17 @@ class OCILargeLanguageModel(LargeLanguageModel): return False return feature["system"] - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -118,22 +122,27 @@ class OCILargeLanguageModel(LargeLanguageModel): :param user: unique user id :return: full response or stream response chunk generator result """ - #print("model"+"*"*20) - #print(model) - #print("credentials"+"*"*20) - #print(credentials) - #print("model_parameters"+"*"*20) - #print(model_parameters) - #print("prompt_messages"+"*"*200) - #print(prompt_messages) - #print("tools"+"*"*20) - #print(tools) + # print("model"+"*"*20) + # print(model) + # print("credentials"+"*"*20) + # print(credentials) + # print("model_parameters"+"*"*20) + # print(model_parameters) + # print("prompt_messages"+"*"*200) + # print(prompt_messages) + # print("tools"+"*"*20) + # print(tools) # invoke model return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -147,8 +156,13 @@ class OCILargeLanguageModel(LargeLanguageModel): return self._get_num_tokens_by_gpt2(prompt) - def get_num_characters(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_characters( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -169,10 +183,7 @@ class OCILargeLanguageModel(LargeLanguageModel): """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() @@ -192,11 +203,17 @@ class OCILargeLanguageModel(LargeLanguageModel): except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None - ) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -218,10 +235,12 @@ class OCILargeLanguageModel(LargeLanguageModel): # ref: https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai-inference/20231130/ChatResult/Chat oci_config = copy.deepcopy(oci_config_template) if "oci_config_content" in credentials: - oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8') + oci_config_content = base64.b64decode(credentials.get("oci_config_content")).decode("utf-8") config_items = oci_config_content.split("/") if len(config_items) != 5: - raise CredentialsValidateFailedError("oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))") + raise CredentialsValidateFailedError( + "oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))" + ) oci_config["user"] = config_items[0] oci_config["fingerprint"] = config_items[1] oci_config["tenancy"] = config_items[2] @@ -230,12 +249,12 @@ class OCILargeLanguageModel(LargeLanguageModel): else: raise CredentialsValidateFailedError("need to set oci_config_content in credentials ") if "oci_key_content" in credentials: - oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8') + oci_key_content = base64.b64decode(credentials.get("oci_key_content")).decode("utf-8") oci_config["key_content"] = oci_key_content.encode(encoding="utf-8") else: raise CredentialsValidateFailedError("need to set oci_config_content in credentials ") - #oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile')) + # oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile')) compartment_id = oci_config["compartment_id"] client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config) # call embedding model @@ -245,9 +264,9 @@ class OCILargeLanguageModel(LargeLanguageModel): chat_history = [] system_prompts = [] - #if "meta.llama" in model: + # if "meta.llama" in model: # request_args["chatRequest"]["apiFormat"] = "GENERIC" - request_args["chatRequest"]["maxTokens"] = model_parameters.pop('maxTokens', 600) + request_args["chatRequest"]["maxTokens"] = model_parameters.pop("maxTokens", 600) request_args["chatRequest"].update(model_parameters) frequency_penalty = model_parameters.get("frequencyPenalty", 0) presence_penalty = model_parameters.get("presencePenalty", 0) @@ -267,7 +286,7 @@ class OCILargeLanguageModel(LargeLanguageModel): if not valid_value: raise InvokeBadRequestError("Does not support function calling") if model.startswith("cohere"): - #print("run cohere " * 10) + # print("run cohere " * 10) for message in prompt_messages[:-1]: text = "" if isinstance(message.content, str): @@ -279,37 +298,37 @@ class OCILargeLanguageModel(LargeLanguageModel): if isinstance(message, SystemPromptMessage): if isinstance(message.content, str): system_prompts.append(message.content) - args = {"apiFormat": "COHERE", - "preambleOverride": ' '.join(system_prompts), - "message": prompt_messages[-1].content, - "chatHistory": chat_history, } + args = { + "apiFormat": "COHERE", + "preambleOverride": " ".join(system_prompts), + "message": prompt_messages[-1].content, + "chatHistory": chat_history, + } request_args["chatRequest"].update(args) elif model.startswith("meta"): - #print("run meta " * 10) + # print("run meta " * 10) meta_messages = [] for message in prompt_messages: text = message.content meta_messages.append({"role": message.role.name, "content": [{"type": "TEXT", "text": text}]}) - args = {"apiFormat": "GENERIC", - "messages": meta_messages, - "numGenerations": 1, - "topK": -1} + args = {"apiFormat": "GENERIC", "messages": meta_messages, "numGenerations": 1, "topK": -1} request_args["chatRequest"].update(args) if stream: request_args["chatRequest"]["isStream"] = True - #print("final request" + "|" * 20) - #print(request_args) + # print("final request" + "|" * 20) + # print(request_args) response = client.chat(request_args) - #print(vars(response)) + # print(vars(response)) if stream: return self._handle_generate_stream_response(model, credentials, response, prompt_messages) return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: BaseChatResponse, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: BaseChatResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -320,9 +339,7 @@ class OCILargeLanguageModel(LargeLanguageModel): :return: llm response """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=response.data.chat_response.text - ) + assistant_prompt_message = AssistantPromptMessage(content=response.data.chat_response.text) # calculate num tokens prompt_tokens = self.get_num_characters(model, credentials, prompt_messages) @@ -341,8 +358,9 @@ class OCILargeLanguageModel(LargeLanguageModel): return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: BaseChatResponse, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: BaseChatResponse, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -356,14 +374,12 @@ class OCILargeLanguageModel(LargeLanguageModel): events = response.data.events() for stream in events: chunk = json.loads(stream.data) - #print(chunk) - #chunk: {'apiFormat': 'COHERE', 'text': 'Hello'} + # print(chunk) + # chunk: {'apiFormat': 'COHERE', 'text': 'Hello'} - - - #for chunk in response: - #for part in chunk.parts: - #if part.function_call: + # for chunk in response: + # for part in chunk.parts: + # if part.function_call: # assistant_prompt_message.tool_calls = [ # AssistantPromptMessage.ToolCall( # id=part.function_call.name, @@ -376,9 +392,7 @@ class OCILargeLanguageModel(LargeLanguageModel): # ] if "finishReason" not in chunk: - assistant_prompt_message = AssistantPromptMessage( - content='' - ) + assistant_prompt_message = AssistantPromptMessage(content="") if model.startswith("cohere"): if chunk["text"]: assistant_prompt_message.content += chunk["text"] @@ -389,10 +403,7 @@ class OCILargeLanguageModel(LargeLanguageModel): yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) else: # calculate num tokens @@ -409,8 +420,8 @@ class OCILargeLanguageModel(LargeLanguageModel): index=index, message=assistant_prompt_message, finish_reason=str(chunk["finishReason"]), - usage=usage - ) + usage=usage, + ), ) def _convert_one_message_to_text(self, message: PromptMessage) -> str: @@ -425,9 +436,7 @@ class OCILargeLanguageModel(LargeLanguageModel): content = message.content if isinstance(content, list): - content = "".join( - c.data for c in content if c.type != PromptMessageContentType.IMAGE - ) + content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE) if isinstance(message, UserPromptMessage): message_text = f"{human_prompt} {content}" @@ -457,5 +466,5 @@ class OCILargeLanguageModel(LargeLanguageModel): InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } diff --git a/api/core/model_runtime/model_providers/oci/oci.py b/api/core/model_runtime/model_providers/oci/oci.py index 11d67790a0..e182d2d043 100644 --- a/api/core/model_runtime/model_providers/oci/oci.py +++ b/api/core/model_runtime/model_providers/oci/oci.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class OCIGENAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -21,14 +20,9 @@ class OCIGENAIProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `cohere.command-r-plus` model for validate, - model_instance.validate_credentials( - model='cohere.command-r-plus', - credentials=credentials - ) + model_instance.validate_credentials(model="cohere.command-r-plus", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex - - diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py index 5e0a85583e..df77db47d9 100644 --- a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py @@ -21,29 +21,28 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE request_template = { "compartmentId": "", - "servingMode": { - "modelId": "cohere.embed-english-light-v3.0", - "servingType": "ON_DEMAND" - }, + "servingMode": {"modelId": "cohere.embed-english-light-v3.0", "servingType": "ON_DEMAND"}, "truncate": "NONE", - "inputs": [""] + "inputs": [""], } oci_config_template = { - "user": "", - "fingerprint": "", - "tenancy": "", - "region": "", - "compartment_id": "", - "key_content": "" - } + "user": "", + "fingerprint": "", + "tenancy": "", + "region": "", + "compartment_id": "", + "key_content": "", +} + + class OCITextEmbeddingModel(TextEmbeddingModel): """ Model class for Cohere text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -62,14 +61,13 @@ class OCITextEmbeddingModel(TextEmbeddingModel): used_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer num_tokens = self._get_num_tokens_by_gpt2(text) if num_tokens >= context_size: cutoff = int(len(text) * (np.floor(context_size / num_tokens))) # if num tokens is larger than context length, only use the start - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) indices += [i] @@ -80,26 +78,16 @@ class OCITextEmbeddingModel(TextEmbeddingModel): for i in _iter: # call embedding model embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - credentials=credentials, - texts=inputs[i: i + max_chunks] + model=model, credentials=credentials, texts=inputs[i : i + max_chunks] ) used_tokens += embedding_used_tokens batched_embeddings += embeddings_batch # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=batched_embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -125,6 +113,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel): for text in texts: characters += len(text) return characters + def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate model credentials @@ -135,11 +124,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel): """ try: # call embedding model - self._embedding_invoke( - model=model, - credentials=credentials, - texts=['ping'] - ) + self._embedding_invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -157,10 +142,12 @@ class OCITextEmbeddingModel(TextEmbeddingModel): # initialize client oci_config = copy.deepcopy(oci_config_template) if "oci_config_content" in credentials: - oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8') + oci_config_content = base64.b64decode(credentials.get("oci_config_content")).decode("utf-8") config_items = oci_config_content.split("/") if len(config_items) != 5: - raise CredentialsValidateFailedError("oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))") + raise CredentialsValidateFailedError( + "oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))" + ) oci_config["user"] = config_items[0] oci_config["fingerprint"] = config_items[1] oci_config["tenancy"] = config_items[2] @@ -169,7 +156,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel): else: raise CredentialsValidateFailedError("need to set oci_config_content in credentials ") if "oci_key_content" in credentials: - oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8') + oci_key_content = base64.b64decode(credentials.get("oci_key_content")).decode("utf-8") oci_config["key_content"] = oci_key_content.encode(encoding="utf-8") else: raise CredentialsValidateFailedError("need to set oci_config_content in credentials ") @@ -195,10 +182,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -209,7 +193,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -224,19 +208,9 @@ class OCITextEmbeddingModel(TextEmbeddingModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } diff --git a/api/core/model_runtime/model_providers/ollama/llm/llm.py b/api/core/model_runtime/model_providers/ollama/llm/llm.py index 42a588e3dd..160eea0148 100644 --- a/api/core/model_runtime/model_providers/ollama/llm/llm.py +++ b/api/core/model_runtime/model_providers/ollama/llm/llm.py @@ -121,9 +121,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): text = "" for message_content in first_prompt_message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast( - TextPromptMessageContent, message_content - ) + message_content = cast(TextPromptMessageContent, message_content) text = message_content.data break return self._get_num_tokens_by_gpt2(text) @@ -145,13 +143,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel): stream=False, ) except InvokeError as ex: - raise CredentialsValidateFailedError( - f"An error occurred during credentials validation: {ex.description}" - ) + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {ex.description}") except Exception as ex: - raise CredentialsValidateFailedError( - f"An error occurred during credentials validation: {str(ex)}" - ) + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}") def _generate( self, @@ -201,9 +195,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): if completion_type is LLMMode.CHAT: endpoint_url = urljoin(endpoint_url, "api/chat") - data["messages"] = [ - self._convert_prompt_message_to_dict(m) for m in prompt_messages - ] + data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] else: endpoint_url = urljoin(endpoint_url, "api/generate") first_prompt_message = prompt_messages[0] @@ -216,14 +208,10 @@ class OllamaLargeLanguageModel(LargeLanguageModel): images = [] for message_content in first_prompt_message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast( - TextPromptMessageContent, message_content - ) + message_content = cast(TextPromptMessageContent, message_content) text = message_content.data elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast( - ImagePromptMessageContent, message_content - ) + message_content = cast(ImagePromptMessageContent, message_content) image_data = re.sub( r"^data:image\/[a-zA-Z]+;base64,", "", @@ -235,24 +223,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel): data["images"] = images # send a post request to validate the credentials - response = requests.post( - endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream) response.encoding = "utf-8" if response.status_code != 200: - raise InvokeError( - f"API request failed with status code {response.status_code}: {response.text}" - ) + raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}") if stream: - return self._handle_generate_stream_response( - model, credentials, completion_type, response, prompt_messages - ) + return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages) - return self._handle_generate_response( - model, credentials, completion_type, response, prompt_messages - ) + return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages) def _handle_generate_response( self, @@ -292,9 +272,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content) # transform usage - usage = self._calc_response_usage( - model, credentials, prompt_tokens, completion_tokens - ) + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) # transform response result = LLMResult( @@ -335,9 +313,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): completion_tokens = self._get_num_tokens_by_gpt2(full_text) # transform usage - usage = self._calc_response_usage( - model, credentials, prompt_tokens, completion_tokens - ) + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) return LLMResultChunk( model=model, @@ -394,15 +370,11 @@ class OllamaLargeLanguageModel(LargeLanguageModel): completion_tokens = chunk_json["eval_count"] else: # calculate num tokens - prompt_tokens = self._get_num_tokens_by_gpt2( - prompt_messages[0].content - ) + prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content) completion_tokens = self._get_num_tokens_by_gpt2(full_text) # transform usage - usage = self._calc_response_usage( - model, credentials, prompt_tokens, completion_tokens - ) + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) yield LLMResultChunk( model=chunk_json["model"], @@ -439,17 +411,11 @@ class OllamaLargeLanguageModel(LargeLanguageModel): images = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast( - TextPromptMessageContent, message_content - ) + message_content = cast(TextPromptMessageContent, message_content) text = message_content.data elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast( - ImagePromptMessageContent, message_content - ) - image_data = re.sub( - r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data - ) + message_content = cast(ImagePromptMessageContent, message_content) + image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data) images.append(image_data) message_dict = {"role": "user", "content": text, "images": images} @@ -479,9 +445,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): return num_tokens - def get_customizable_model_schema( - self, model: str, credentials: dict - ) -> AIModelEntity: + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ Get customizable model schema. @@ -502,9 +466,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ ModelPropertyKey.MODE: credentials.get("mode"), - ModelPropertyKey.CONTEXT_SIZE: int( - credentials.get("context_size", 4096) - ), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)), }, parameter_rules=[ ParameterRule( @@ -568,9 +530,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): en_US="Maximum number of tokens to predict when generating text. " "(Default: 128, -1 = infinite generation, -2 = fill context)" ), - default=( - 512 if int(credentials.get("max_tokens", 4096)) >= 768 else 128 - ), + default=(512 if int(credentials.get("max_tokens", 4096)) >= 768 else 128), min=-2, max=int(credentials.get("max_tokens", 4096)), ), @@ -612,22 +572,23 @@ class OllamaLargeLanguageModel(LargeLanguageModel): label=I18nObject(en_US="Size of context window"), type=ParameterType.INT, help=I18nObject( - en_US="Sets the size of the context window used to generate the next token. " - "(Default: 2048)" + en_US="Sets the size of the context window used to generate the next token. " "(Default: 2048)" ), default=2048, min=1, ), ParameterRule( - name='num_gpu', + name="num_gpu", label=I18nObject(en_US="GPU Layers"), type=ParameterType.INT, - help=I18nObject(en_US="The number of layers to offload to the GPU(s). " - "On macOS it defaults to 1 to enable metal support, 0 to disable." - "As long as a model fits into one gpu it stays in one. " - "It does not set the number of GPU(s). "), + help=I18nObject( + en_US="The number of layers to offload to the GPU(s). " + "On macOS it defaults to 1 to enable metal support, 0 to disable." + "As long as a model fits into one gpu it stays in one. " + "It does not set the number of GPU(s). " + ), min=-1, - default=1 + default=1, ), ParameterRule( name="num_thread", @@ -688,8 +649,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): label=I18nObject(en_US="Format"), type=ParameterType.STRING, help=I18nObject( - en_US="the format to return a response in." - " Currently the only accepted value is json." + en_US="the format to return a response in." " Currently the only accepted value is json." ), options=["json"], ), diff --git a/api/core/model_runtime/model_providers/ollama/ollama.py b/api/core/model_runtime/model_providers/ollama/ollama.py index f8a17b98a0..115280193a 100644 --- a/api/core/model_runtime/model_providers/ollama/ollama.py +++ b/api/core/model_runtime/model_providers/ollama/ollama.py @@ -6,7 +6,6 @@ logger = logging.getLogger(__name__) class OpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py index 8f7d54c516..60b85197be 100644 --- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -37,9 +37,9 @@ class OllamaEmbeddingModel(TextEmbeddingModel): Model class for an Ollama text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -51,15 +51,13 @@ class OllamaEmbeddingModel(TextEmbeddingModel): """ # Prepare headers and payload for the request - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - endpoint_url = credentials.get('base_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("base_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'api/embed') + endpoint_url = urljoin(endpoint_url, "api/embed") # get model properties context_size = self._get_context_size(model, credentials) @@ -74,46 +72,34 @@ class OllamaEmbeddingModel(TextEmbeddingModel): if num_tokens >= context_size: cutoff = int(np.floor(len(text) * (context_size / num_tokens))) # if num tokens is larger than context length, only use the start - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) # Prepare the payload for the request payload = { - 'input': inputs, - 'model': model, + "input": inputs, + "model": model, } # Make the request to the OpenAI API response = requests.post( - endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300), - options={"use_mmap": "true"} + endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300), options={"use_mmap": "true"} ) response.raise_for_status() # Raise an exception for HTTP errors response_data = response.json() # Extract embeddings and used tokens from the response - embeddings = response_data['embeddings'] + embeddings = response_data["embeddings"] embedding_used_tokens = self.get_num_tokens(model, credentials, inputs) used_tokens += embedding_used_tokens # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -135,19 +121,15 @@ class OllamaEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke( - model=model, - credentials=credentials, - texts=['ping'] - ) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeError as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}') + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {ex.description}") except Exception as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}") def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, @@ -155,15 +137,15 @@ class OllamaEmbeddingModel(TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity @@ -179,10 +161,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -193,7 +172,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -221,10 +200,10 @@ class OllamaEmbeddingModel(TextEmbeddingModel): ], InvokeServerUnavailableError: [ requests.exceptions.ConnectionError, # Engine Overloaded - requests.exceptions.HTTPError # Server Error + requests.exceptions.HTTPError, # Server Error ], InvokeConnectionError: [ requests.exceptions.ConnectTimeout, # Timeout - requests.exceptions.ReadTimeout # Timeout - ] + requests.exceptions.ReadTimeout, # Timeout + ], } diff --git a/api/core/model_runtime/model_providers/openai/_common.py b/api/core/model_runtime/model_providers/openai/_common.py index 467a51daf2..2181bb4f08 100644 --- a/api/core/model_runtime/model_providers/openai/_common.py +++ b/api/core/model_runtime/model_providers/openai/_common.py @@ -22,7 +22,7 @@ class _CommonOpenAI: :return: """ credentials_kwargs = { - "api_key": credentials['openai_api_key'], + "api_key": credentials["openai_api_key"], "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "max_retries": 1, } @@ -31,8 +31,8 @@ class _CommonOpenAI: openai_api_base = credentials["openai_api_base"].rstrip("/") credentials_kwargs["base_url"] = openai_api_base + "/v1" - if 'openai_organization' in credentials: - credentials_kwargs['organization'] = credentials['openai_organization'] + if "openai_organization" in credentials: + credentials_kwargs["organization"] = credentials["openai_organization"] return credentials_kwargs diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index dc85f7c9f2..5950b77a96 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -39,16 +39,23 @@ if you are not sure about the structure. """ + class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): """ Model class for OpenAI large language model. """ - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -64,8 +71,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): """ # handle fine tune remote models base_model = model - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] # get model mode model_mode = self.get_model_mode(base_model, credentials) @@ -80,7 +87,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) else: # text completion model @@ -91,26 +98,34 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ # handle fine tune remote models base_model = model - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] # get model mode model_mode = self.get_model_mode(base_model, credentials) # transform response format - if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']: + if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: stop = stop or [] if model_mode == LLMMode.CHAT: # chat model @@ -123,7 +138,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) else: self._transform_completion_json_prompts( @@ -135,9 +150,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) - model_parameters.pop('response_format') + model_parameters.pop("response_format") return self._invoke( model=model, @@ -147,14 +162,21 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def _transform_chat_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + def _transform_chat_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts """ @@ -167,25 +189,35 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): # override the system message prompt_messages[0] = SystemPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + content=OPENAI_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n")) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=OPENAI_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object." + ).replace("{{block}}", response_format) + ), + ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) - - def _transform_completion_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + + def _transform_completion_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts """ @@ -202,25 +234,30 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): break if user_message: - if prompt_messages[i].content[-11:] == 'Assistant: ': + if prompt_messages[i].content[-11:] == "Assistant: ": # now we are in the chat app, remove the last assistant message prompt_messages[i].content = prompt_messages[i].content[:-11] prompt_messages[i] = UserPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", user_message.content) - .replace("{{block}}", response_format) + content=OPENAI_BLOCK_MODE_PROMPT.replace("{{instructions}}", user_message.content).replace( + "{{block}}", response_format + ) ) prompt_messages[i].content += f"Assistant:\n```{response_format}\n" else: prompt_messages[i] = UserPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", user_message.content) - .replace("{{block}}", response_format) + content=OPENAI_BLOCK_MODE_PROMPT.replace("{{instructions}}", user_message.content).replace( + "{{block}}", response_format + ) ) prompt_messages[i].content += f"\n```{response_format}\n" - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -231,8 +268,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): :return: """ # handle fine tune remote models - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] else: base_model = model @@ -262,14 +299,14 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # handle fine tune remote models base_model = model # fine-tuned model name likes ft:gpt-3.5-turbo-0613:personal::xxxxx - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] # check if model exists remote_models = self.remote_models(credentials) remote_model_map = {model.model: model for model in remote_models} if model not in remote_model_map: - raise CredentialsValidateFailedError(f'Fine-tuned model {model} not found') + raise CredentialsValidateFailedError(f"Fine-tuned model {model} not found") # get model mode model_mode = self.get_model_mode(base_model, credentials) @@ -277,7 +314,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if model_mode == LLMMode.CHAT: # chat model client.chat.completions.create( - messages=[{"role": "user", "content": 'ping'}], + messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=20, @@ -286,7 +323,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): else: # text completion model client.completions.create( - prompt='ping', + prompt="ping", model=model, temperature=0, max_tokens=20, @@ -313,11 +350,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # get all remote models remote_models = client.models.list() - fine_tune_models = [model for model in remote_models if model.id.startswith('ft:')] + fine_tune_models = [model for model in remote_models if model.id.startswith("ft:")] ai_model_entities = [] for model in fine_tune_models: - base_model = model.id.split(':')[1] + base_model = model.id.split(":")[1] base_model_schema = None for predefined_model_name, predefined_model in predefined_models_map.items(): @@ -329,30 +366,29 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): ai_model_entity = AIModelEntity( model=model.id, - label=I18nObject( - zh_Hans=model.id, - en_US=model.id - ), + label=I18nObject(zh_Hans=model.id, en_US=model.id), model_type=ModelType.LLM, features=base_model_schema.features, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties=base_model_schema.model_properties, parameter_rules=base_model_schema.parameter_rules, - pricing=PriceConfig( - input=0.003, - output=0.006, - unit=0.001, - currency='USD' - ) + pricing=PriceConfig(input=0.003, output=0.006, unit=0.001, currency="USD"), ) ai_model_entities.append(ai_model_entity) return ai_model_entities - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -374,23 +410,17 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): extra_model_kwargs = {} if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if stream: - extra_model_kwargs['stream_options'] = { - "include_usage": True - } - + extra_model_kwargs["stream_options"] = {"include_usage": True} + # text completion model response = client.completions.create( - prompt=prompt_messages[0].content, - model=model, - stream=stream, - **model_parameters, - **extra_model_kwargs + prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs ) if stream: @@ -398,8 +428,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: Completion, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm completion response @@ -412,9 +443,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): assistant_text = response.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) # calculate num tokens if response.usage: @@ -440,8 +469,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm completion stream response @@ -451,7 +481,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator result """ - full_text = '' + full_text = "" prompt_tokens = 0 completion_tokens = 0 @@ -460,8 +490,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage(content=''), - ) + message=AssistantPromptMessage(content=""), + ), ) for chunk in response: @@ -474,14 +504,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): delta = chunk.choices[0] - if delta.finish_reason is None and (delta.text is None or delta.text == ''): + if delta.finish_reason is None and (delta.text is None or delta.text == ""): continue # transform assistant message to prompt message - text = delta.text if delta.text else '' - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + text = delta.text if delta.text else "" + assistant_prompt_message = AssistantPromptMessage(content=text) full_text += text @@ -494,7 +522,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - ) + ), ) else: yield LLMResultChunk( @@ -504,7 +532,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) if not prompt_tokens: @@ -520,10 +548,17 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): yield final_chunk - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -562,22 +597,18 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if tools: # extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] - extra_model_kwargs['functions'] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] + extra_model_kwargs["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools + ] if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if stream: - extra_model_kwargs['stream_options'] = { - 'include_usage': True - } + extra_model_kwargs["stream_options"] = {"include_usage": True} # clear illegal prompt messages prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages) @@ -596,9 +627,14 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) - def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: ChatCompletion, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> LLMResult: """ Handle llm chat response @@ -619,10 +655,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): tool_calls = [function_call] if function_call else [] # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) # calculate num tokens if response.usage: @@ -648,9 +681,14 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk], - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[ChatCompletionChunk], + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> Generator: """ Handle llm chat stream response @@ -660,7 +698,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): :param tools: tools for tool calling :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None prompt_tokens = 0 completion_tokens = 0 @@ -670,8 +708,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage(content=''), - ) + message=AssistantPromptMessage(content=""), + ), ) for chunk in response: @@ -685,8 +723,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): delta = chunk.choices[0] has_finish_reason = delta.finish_reason is not None - if not has_finish_reason and (delta.delta.content is None or delta.delta.content == '') and \ - delta.delta.function_call is None: + if ( + not has_finish_reason + and (delta.delta.content is None or delta.delta.content == "") + and delta.delta.function_call is None + ): continue # assistant_message_tool_calls = delta.delta.tool_calls @@ -708,7 +749,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # start of stream function call delta_assistant_message_function_call_storage = assistant_message_function_call if delta_assistant_message_function_call_storage.arguments is None: - delta_assistant_message_function_call_storage.arguments = '' + delta_assistant_message_function_call_storage.arguments = "" if not has_finish_reason: continue @@ -720,11 +761,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=tool_calls + content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls ) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content if delta.delta.content else "" if has_finish_reason: final_chunk = LLMResultChunk( @@ -735,7 +775,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - ) + ), ) else: yield LLMResultChunk( @@ -745,7 +785,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) if not prompt_tokens: @@ -753,8 +793,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if not completion_tokens: full_assistant_prompt_message = AssistantPromptMessage( - content=full_assistant_content, - tool_calls=final_tool_calls + content=full_assistant_content, tool_calls=final_tool_calls ) completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message]) @@ -764,9 +803,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): yield final_chunk - def _extract_response_tool_calls(self, - response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -777,21 +816,19 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call( + self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall + ) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -801,14 +838,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.name, - arguments=response_function_call.arguments + name=response_function_call.name, arguments=response_function_call.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.name, - type="function", - function=function + id=response_function_call.name, type="function", function=function ) return tool_call @@ -821,7 +855,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): :param prompt_messages: prompt messages :return: cleaned prompt messages """ - checklist = ['gpt-4-turbo', 'gpt-4-turbo-2024-04-09'] + checklist = ["gpt-4-turbo", "gpt-4-turbo-2024-04-09"] if model in checklist: # count how many user messages are there @@ -830,11 +864,16 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): for prompt_message in prompt_messages: if isinstance(prompt_message, UserPromptMessage): if isinstance(prompt_message.content, list): - prompt_message.content = '\n'.join([ - item.data if item.type == PromptMessageContentType.TEXT else - '[IMAGE]' if item.type == PromptMessageContentType.IMAGE else '' - for item in prompt_message.content - ]) + prompt_message.content = "\n".join( + [ + item.data + if item.type == PromptMessageContentType.TEXT + else "[IMAGE]" + if item.type == PromptMessageContentType.IMAGE + else "" + for item in prompt_message.content + ] + ) return prompt_messages @@ -851,19 +890,13 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) @@ -889,11 +922,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # "content": message.content, # "tool_call_id": message.tool_call_id # } - message_dict = { - "role": "function", - "content": message.content, - "name": message.tool_call_id - } + message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id} else: raise ValueError(f"Got unknown type {message}") @@ -902,8 +931,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return message_dict - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -924,13 +952,14 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" - if model.startswith('ft:'): - model = model.split(':')[1] + if model.startswith("ft:"): + model = model.split(":")[1] # Currently, we can use gpt4o to calculate chatgpt-4o-latest's token. if model == "chatgpt-4o-latest": @@ -969,10 +998,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -1011,37 +1040,37 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): """ num_tokens = 0 for tool in tools: - num_tokens += len(encoding.encode('type')) - num_tokens += len(encoding.encode('function')) + num_tokens += len(encoding.encode("type")) + num_tokens += len(encoding.encode("function")) # calculate num tokens for function object - num_tokens += len(encoding.encode('name')) + num_tokens += len(encoding.encode("name")) num_tokens += len(encoding.encode(tool.name)) - num_tokens += len(encoding.encode('description')) + num_tokens += len(encoding.encode("description")) num_tokens += len(encoding.encode(tool.description)) parameters = tool.parameters - num_tokens += len(encoding.encode('parameters')) - if 'title' in parameters: - num_tokens += len(encoding.encode('title')) + num_tokens += len(encoding.encode("parameters")) + if "title" in parameters: + num_tokens += len(encoding.encode("title")) num_tokens += len(encoding.encode(parameters.get("title"))) - num_tokens += len(encoding.encode('type')) + num_tokens += len(encoding.encode("type")) num_tokens += len(encoding.encode(parameters.get("type"))) - if 'properties' in parameters: - num_tokens += len(encoding.encode('properties')) - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += len(encoding.encode("properties")) + for key, value in parameters.get("properties").items(): num_tokens += len(encoding.encode(key)) for field_key, field_value in value.items(): num_tokens += len(encoding.encode(field_key)) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += len(encoding.encode(enum_field)) else: num_tokens += len(encoding.encode(field_key)) num_tokens += len(encoding.encode(str(field_value))) - if 'required' in parameters: - num_tokens += len(encoding.encode('required')) - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += len(encoding.encode("required")) + for required_field in parameters["required"]: num_tokens += 3 num_tokens += len(encoding.encode(required_field)) @@ -1049,26 +1078,26 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - OpenAI supports fine-tuning of their models. This method returns the schema of the base model - but renamed to the fine-tuned model name. + OpenAI supports fine-tuning of their models. This method returns the schema of the base model + but renamed to the fine-tuned model name. - :param model: model name - :param credentials: credentials + :param model: model name + :param credentials: credentials - :return: model schema + :return: model schema """ - if not model.startswith('ft:'): + if not model.startswith("ft:"): base_model = model else: # get base_model - base_model = model.split(':')[1] + base_model = model.split(":")[1] # get model schema models = self.predefined_models() model_map = {model.model: model for model in models} if base_model not in model_map: - raise ValueError(f'Base model {base_model} not found') - + raise ValueError(f"Base model {base_model} not found") + base_model_schema = model_map[base_model] base_model_schema_features = base_model_schema.features or [] @@ -1077,16 +1106,13 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): entity = AIModelEntity( model=model, - label=I18nObject( - zh_Hans=model, - en_US=model - ), + label=I18nObject(zh_Hans=model, en_US=model), model_type=ModelType.LLM, features=list(base_model_schema_features), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties=dict(base_model_schema_model_properties.items()), parameter_rules=list(base_model_schema_parameters_rules), - pricing=base_model_schema.pricing + pricing=base_model_schema.pricing, ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/openai/moderation/moderation.py b/api/core/model_runtime/model_providers/openai/moderation/moderation.py index b1d0e57ad2..619044d808 100644 --- a/api/core/model_runtime/model_providers/openai/moderation/moderation.py +++ b/api/core/model_runtime/model_providers/openai/moderation/moderation.py @@ -14,9 +14,7 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel): Model class for OpenAI text moderation model. """ - def _invoke(self, model: str, credentials: dict, - text: str, user: Optional[str] = None) \ - -> bool: + def _invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool: """ Invoke moderation model @@ -34,10 +32,10 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel): # chars per chunk length = self._get_max_characters_per_chunk(model, credentials) - text_chunks = [text[i:i + length] for i in range(0, len(text), length)] + text_chunks = [text[i : i + length] for i in range(0, len(text), length)] max_text_chunks = self._get_max_chunks(model, credentials) - chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)] + chunks = [text_chunks[i : i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)] for text_chunk in chunks: moderation_result = self._moderation_invoke(model=model, client=client, texts=text_chunk) @@ -65,7 +63,7 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel): self._moderation_invoke( model=model, client=client, - texts=['ping'], + texts=["ping"], ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) diff --git a/api/core/model_runtime/model_providers/openai/openai.py b/api/core/model_runtime/model_providers/openai/openai.py index 66efd4797f..175d7db73c 100644 --- a/api/core/model_runtime/model_providers/openai/openai.py +++ b/api/core/model_runtime/model_providers/openai/openai.py @@ -9,7 +9,6 @@ logger = logging.getLogger(__name__) class OpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: Mapping) -> None: """ Validate provider credentials @@ -22,12 +21,9 @@ class OpenAIProvider(ModelProvider): # Use `gpt-3.5-turbo` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='gpt-3.5-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="gpt-3.5-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py index efbdd054f9..18f97e45f3 100644 --- a/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py @@ -12,9 +12,7 @@ class OpenAISpeech2TextModel(_CommonOpenAI, Speech2TextModel): Model class for OpenAI Speech to text model. """ - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -37,7 +35,7 @@ class OpenAISpeech2TextModel(_CommonOpenAI, Speech2TextModel): try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._speech2text_invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) diff --git a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py index e23a2edf87..535d8388bc 100644 --- a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py @@ -18,9 +18,9 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): Model class for OpenAI text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,9 +37,9 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'base64' + extra_model_kwargs["encoding_format"] = "base64" # get model properties context_size = self._get_context_size(model, credentials) @@ -56,11 +56,9 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): enc = tiktoken.get_encoding("cl100k_base") for i, text in enumerate(texts): - token = enc.encode( - text - ) + token = enc.encode(text) for j in range(0, len(token), context_size): - tokens += [token[j: j + context_size]] + tokens += [token[j : j + context_size]] indices += [i] batched_embeddings = [] @@ -69,10 +67,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): for i in _iter: # call embedding model embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - client=client, - texts=tokens[i: i + max_chunks], - extra_model_kwargs=extra_model_kwargs + model=model, client=client, texts=tokens[i : i + max_chunks], extra_model_kwargs=extra_model_kwargs ) used_tokens += embedding_used_tokens @@ -88,10 +83,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): _result = results[i] if len(_result) == 0: embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - client=client, - texts="", - extra_model_kwargs=extra_model_kwargs + model=model, client=client, texts="", extra_model_kwargs=extra_model_kwargs ) used_tokens += embedding_used_tokens @@ -101,17 +93,9 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): embeddings[i] = (average / np.linalg.norm(average)).tolist() # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -152,17 +136,13 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): client = OpenAI(**credentials_kwargs) # call embedding model - self._embedding_invoke( - model=model, - client=client, - texts=['ping'], - extra_model_kwargs={} - ) + self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={}) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _embedding_invoke(self, model: str, client: OpenAI, texts: Union[list[str], str], - extra_model_kwargs: dict) -> tuple[list[list[float]], int]: + def _embedding_invoke( + self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict + ) -> tuple[list[list[float]], int]: """ Invoke embedding model @@ -179,10 +159,12 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): **extra_model_kwargs, ) - if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64': + if "encoding_format" in extra_model_kwargs and extra_model_kwargs["encoding_format"] == "base64": # decode base64 embedding - return ([list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], - response.usage.total_tokens) + return ( + [list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], + response.usage.total_tokens, + ) return [data.embedding for data in response.data], response.usage.total_tokens @@ -197,10 +179,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -211,7 +190,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/openai/tts/tts.py b/api/core/model_runtime/model_providers/openai/tts/tts.py index afa5d4b88a..bfb443698c 100644 --- a/api/core/model_runtime/model_providers/openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/openai/tts/tts.py @@ -14,8 +14,9 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): Model class for OpenAI Speech to text model. """ - def _invoke(self, model: str, tenant_id: str, credentials: dict, - content_text: str, voice: str, user: Optional[str] = None) -> any: + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ) -> any: """ _invoke text2speech model @@ -28,14 +29,12 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): :return: text translated to audio file """ - if not voice or voice not in [d['value'] for d in - self.get_tts_model_voices(model=model, credentials=credentials)]: + if not voice or voice not in [ + d["value"] for d in self.get_tts_model_voices(model=model, credentials=credentials) + ]: voice = self._get_model_default_voice(model, credentials) # if streaming: - return self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - voice=voice) + return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice) def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: """ @@ -50,14 +49,13 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): self._tts_invoke_streaming( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ _tts_invoke_streaming text2speech model @@ -71,31 +69,38 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): # doc: https://platform.openai.com/docs/guides/text-to-speech credentials_kwargs = self._to_credential_kwargs(credentials) client = OpenAI(**credentials_kwargs) - model_support_voice = [x.get("value") for x in - self.get_tts_model_voices(model=model, credentials=credentials)] + model_support_voice = [ + x.get("value") for x in self.get_tts_model_voices(model=model, credentials=credentials) + ] if not voice or voice not in model_support_voice: voice = self._get_model_default_voice(model, credentials) word_limit = self._get_model_word_limit(model, credentials) if len(content_text) > word_limit: sentences = self._split_text_into_sentences(content_text, max_length=word_limit) executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) - futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model, - response_format="mp3", - input=sentences[i], voice=voice) for i in range(len(sentences))] + futures = [ + executor.submit( + client.audio.speech.with_streaming_response.create, + model=model, + response_format="mp3", + input=sentences[i], + voice=voice, + ) + for i in range(len(sentences)) + ] for index, future in enumerate(futures): yield from future.result().__enter__().iter_bytes(1024) else: - response = client.audio.speech.with_streaming_response.create(model=model, voice=voice, - response_format="mp3", - input=content_text.strip()) + response = client.audio.speech.with_streaming_response.create( + model=model, voice=voice, response_format="mp3", input=content_text.strip() + ) yield from response.__enter__().iter_bytes(1024) except Exception as ex: raise InvokeBadRequestError(str(ex)) - def _process_sentence(self, sentence: str, model: str, - voice, credentials: dict): + def _process_sentence(self, sentence: str, model: str, voice, credentials: dict): """ _tts_invoke openai text2speech model api diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py index 51950ca377..257dffa30d 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py @@ -1,4 +1,3 @@ - import requests from core.model_runtime.errors.invoke import ( @@ -35,10 +34,10 @@ class _CommonOAI_API_Compat: ], InvokeServerUnavailableError: [ requests.exceptions.ConnectionError, # Engine Overloaded - requests.exceptions.HTTPError # Server Error + requests.exceptions.HTTPError, # Server Error ], InvokeConnectionError: [ requests.exceptions.ConnectTimeout, # Timeout - requests.exceptions.ReadTimeout # Timeout - ] - } \ No newline at end of file + requests.exceptions.ReadTimeout, # Timeout + ], + } diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index 6279125f46..75929af590 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -46,11 +46,17 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): Model class for OpenAI large language model. """ - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -77,8 +83,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): user=user, ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -99,93 +110,85 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials['endpoint_url'] - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials["endpoint_url"] + if not endpoint_url.endswith("/"): + endpoint_url += "/" # prepare the payload for a simple ping to the model - data = { - 'model': model, - 'max_tokens': 5 - } + data = {"model": model, "max_tokens": 5} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - data['messages'] = [ - { - "role": "user", - "content": "ping" - }, + data["messages"] = [ + {"role": "user", "content": "ping"}, ] - endpoint_url = urljoin(endpoint_url, 'chat/completions') + endpoint_url = urljoin(endpoint_url, "chat/completions") elif completion_type is LLMMode.COMPLETION: - data['prompt'] = 'ping' - endpoint_url = urljoin(endpoint_url, 'completions') + data["prompt"] = "ping" + endpoint_url = urljoin(endpoint_url, "completions") else: raise ValueError("Unsupported completion type for model configuration.") # send a post request to validate the credentials - response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 300) - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300)) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") - if (completion_type is LLMMode.CHAT and json_result.get('object','') == ''): - json_result['object'] = 'chat.completion' - elif (completion_type is LLMMode.COMPLETION and json_result.get('object','') == ''): - json_result['object'] = 'text_completion' + if completion_type is LLMMode.CHAT and json_result.get("object", "") == "": + json_result["object"] = "chat.completion" + elif completion_type is LLMMode.COMPLETION and json_result.get("object", "") == "": + json_result["object"] = "text_completion" - if (completion_type is LLMMode.CHAT - and ('object' not in json_result or json_result['object'] != 'chat.completion')): + if completion_type is LLMMode.CHAT and ( + "object" not in json_result or json_result["object"] != "chat.completion" + ): raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response object, must be \'chat.completion\'') - elif (completion_type is LLMMode.COMPLETION - and ('object' not in json_result or json_result['object'] != 'text_completion')): + "Credentials validation failed: invalid response object, must be 'chat.completion'" + ) + elif completion_type is LLMMode.COMPLETION and ( + "object" not in json_result or json_result["object"] != "text_completion" + ): raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response object, must be \'text_completion\'') + "Credentials validation failed: invalid response object, must be 'text_completion'" + ) except CredentialsValidateFailedError: raise except Exception as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}") def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ features = [] - function_calling_type = credentials.get('function_calling_type', 'no_call') - if function_calling_type in ['function_call']: + function_calling_type = credentials.get("function_calling_type", "no_call") + if function_calling_type in ["function_call"]: features.append(ModelFeature.TOOL_CALL) - elif function_calling_type in ['tool_call']: + elif function_calling_type in ["tool_call"]: features.append(ModelFeature.MULTI_TOOL_CALL) - stream_function_calling = credentials.get('stream_function_calling', 'supported') - if stream_function_calling == 'supported': + stream_function_calling = credentials.get("stream_function_calling", "supported") + if stream_function_calling == "supported": features.append(ModelFeature.STREAM_TOOL_CALL) - vision_support = credentials.get('vision_support', 'not_support') - if vision_support == 'support': + vision_support = credentials.get("vision_support", "not_support") + if vision_support == "support": features.append(ModelFeature.VISION) entity = AIModelEntity( @@ -195,43 +198,43 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, features=features, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")), - ModelPropertyKey.MODE: credentials.get('mode'), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "4096")), + ModelPropertyKey.MODE: credentials.get("mode"), }, parameter_rules=[ ParameterRule( name=DefaultParameterName.TEMPERATURE.value, label=I18nObject(en_US="Temperature"), type=ParameterType.FLOAT, - default=float(credentials.get('temperature', 0.7)), + default=float(credentials.get("temperature", 0.7)), min=0, max=2, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.TOP_P.value, label=I18nObject(en_US="Top P"), type=ParameterType.FLOAT, - default=float(credentials.get('top_p', 1)), + default=float(credentials.get("top_p", 1)), min=0, max=1, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.FREQUENCY_PENALTY.value, label=I18nObject(en_US="Frequency Penalty"), type=ParameterType.FLOAT, - default=float(credentials.get('frequency_penalty', 0)), + default=float(credentials.get("frequency_penalty", 0)), min=-2, - max=2 + max=2, ), ParameterRule( name=DefaultParameterName.PRESENCE_PENALTY.value, label=I18nObject(en_US="Presence Penalty"), type=ParameterType.FLOAT, - default=float(credentials.get('presence_penalty', 0)), + default=float(credentials.get("presence_penalty", 0)), min=-2, - max=2 + max=2, ), ParameterRule( name=DefaultParameterName.MAX_TOKENS.value, @@ -239,20 +242,20 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): type=ParameterType.INT, default=512, min=1, - max=int(credentials.get('max_tokens_to_sample', 4096)), - ) + max=int(credentials.get("max_tokens_to_sample", 4096)), + ), ], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - output=Decimal(credentials.get('output_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") + input=Decimal(credentials.get("input_price", 0)), + output=Decimal(credentials.get("output_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), ), ) - if credentials['mode'] == 'chat': + if credentials["mode"] == "chat": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value - elif credentials['mode'] == 'completion': + elif credentials["mode"] == "completion": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value else: raise ValueError(f"Unknown completion type {credentials['completion_type']}") @@ -260,10 +263,17 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): return entity # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard. - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, \ - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -277,52 +287,47 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): :return: full response or stream response chunk generator result """ headers = { - 'Content-Type': 'application/json', - 'Accept-Charset': 'utf-8', + "Content-Type": "application/json", + "Accept-Charset": "utf-8", } - extra_headers = credentials.get('extra_headers') + extra_headers = credentials.get("extra_headers") if extra_headers is not None: headers = { - **headers, - **extra_headers, + **headers, + **extra_headers, } - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" endpoint_url = credentials["endpoint_url"] - if not endpoint_url.endswith('/'): - endpoint_url += '/' + if not endpoint_url.endswith("/"): + endpoint_url += "/" - data = { - "model": model, - "stream": stream, - **model_parameters - } + data = {"model": model, "stream": stream, **model_parameters} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - endpoint_url = urljoin(endpoint_url, 'chat/completions') - data['messages'] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages] + endpoint_url = urljoin(endpoint_url, "chat/completions") + data["messages"] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages] elif completion_type is LLMMode.COMPLETION: - endpoint_url = urljoin(endpoint_url, 'completions') - data['prompt'] = prompt_messages[0].content + endpoint_url = urljoin(endpoint_url, "completions") + data["prompt"] = prompt_messages[0].content else: raise ValueError("Unsupported completion type for model configuration.") # annotate tools with names, descriptions, etc. - function_calling_type = credentials.get('function_calling_type', 'no_call') + function_calling_type = credentials.get("function_calling_type", "no_call") formatted_tools = [] if tools: - if function_calling_type == 'function_call': - data['functions'] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] - elif function_calling_type == 'tool_call': + if function_calling_type == "function_call": + data["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} + for tool in tools + ] + elif function_calling_type == "tool_call": data["tool_choice"] = "auto" for tool in tools: @@ -336,16 +341,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): if user: data["user"] = user - response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 300), - stream=stream - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream) - if response.encoding is None or response.encoding == 'ISO-8859-1': - response.encoding = 'utf-8' + if response.encoding is None or response.encoding == "ISO-8859-1": + response.encoding = "utf-8" if response.status_code != 200: raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}") @@ -355,8 +354,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -366,11 +366,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" chunk_index = 0 - def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \ - -> LLMResultChunk: + def create_final_llm_result_chunk( + index: int, message: AssistantPromptMessage, finish_reason: str + ) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) completion_tokens = self._num_tokens_from_string(model, full_assistant_content) @@ -381,16 +382,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): return LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=message, - finish_reason=finish_reason, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage), ) # delimiter for stream response, need unicode_escape import codecs + delimiter = credentials.get("stream_mode_delimiter", "\n\n") delimiter = codecs.decode(delimiter, "unicode_escape") @@ -406,10 +403,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): tool_call = AssistantPromptMessage.ToolCall( id=tool_call_id, type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name="", - arguments="" - ) + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""), ) tools_calls.append(tool_call) @@ -434,10 +428,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): chunk = chunk.strip() if chunk: # ignore sse comments - if chunk.startswith(':'): + if chunk.startswith(":"): continue - decoded_chunk = chunk.strip().lstrip('data: ').lstrip() - if decoded_chunk == '[DONE]': # Some provider returns "data: [DONE]" + decoded_chunk = chunk.strip().lstrip("data: ").lstrip() + if decoded_chunk == "[DONE]": # Some provider returns "data: [DONE]" continue try: @@ -447,30 +441,31 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): yield create_final_llm_result_chunk( index=chunk_index + 1, message=AssistantPromptMessage(content=""), - finish_reason="Non-JSON encountered." + finish_reason="Non-JSON encountered.", ) break - if not chunk_json or len(chunk_json['choices']) == 0: + if not chunk_json or len(chunk_json["choices"]) == 0: continue - choice = chunk_json['choices'][0] - finish_reason = chunk_json['choices'][0].get('finish_reason') + choice = chunk_json["choices"][0] + finish_reason = chunk_json["choices"][0].get("finish_reason") chunk_index += 1 - if 'delta' in choice: - delta = choice['delta'] - delta_content = delta.get('content') + if "delta" in choice: + delta = choice["delta"] + delta_content = delta.get("content") assistant_message_tool_calls = None - if 'tool_calls' in delta and credentials.get('function_calling_type', 'no_call') == 'tool_call': - assistant_message_tool_calls = delta.get('tool_calls', None) - elif 'function_call' in delta and credentials.get('function_calling_type', 'no_call') == 'function_call': - assistant_message_tool_calls = [{ - 'id': 'tool_call_id', - 'type': 'function', - 'function': delta.get('function_call', {}) - }] + if "tool_calls" in delta and credentials.get("function_calling_type", "no_call") == "tool_call": + assistant_message_tool_calls = delta.get("tool_calls", None) + elif ( + "function_call" in delta + and credentials.get("function_calling_type", "no_call") == "function_call" + ): + assistant_message_tool_calls = [ + {"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})} + ] # assistant_message_function_call = delta.delta.function_call @@ -479,7 +474,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) increase_tool_call(tool_calls) - if delta_content is None or delta_content == '': + if delta_content is None or delta_content == "": continue # transform assistant message to prompt message @@ -490,9 +485,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): # reset tool calls tool_calls = [] full_assistant_content += delta_content - elif 'text' in choice: - choice_text = choice.get('text', '') - if choice_text == '': + elif "text" in choice: + choice_text = choice.get("text", "") + if choice_text == "": continue # transform assistant message to prompt message @@ -507,7 +502,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): delta=LLMResultChunkDelta( index=chunk_index, message=assistant_prompt_message, - ) + ), ) chunk_index += 1 @@ -518,47 +513,42 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=chunk_index, - message=AssistantPromptMessage( - tool_calls=tools_calls, - content="" - ), - ) + message=AssistantPromptMessage(tool_calls=tools_calls, content=""), + ), ) yield create_final_llm_result_chunk( - index=chunk_index, - message=AssistantPromptMessage(content=""), - finish_reason=finish_reason + index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason ) - def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: list[PromptMessage]) -> LLMResult: - + def _handle_generate_response( + self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] + ) -> LLMResult: response_json = response.json() - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) - output = response_json['choices'][0] + output = response_json["choices"][0] - response_content = '' + response_content = "" tool_calls = None - function_calling_type = credentials.get('function_calling_type', 'no_call') + function_calling_type = credentials.get("function_calling_type", "no_call") if completion_type is LLMMode.CHAT: - response_content = output.get('message', {})['content'] - if function_calling_type == 'tool_call': - tool_calls = output.get('message', {}).get('tool_calls') - elif function_calling_type == 'function_call': - tool_calls = output.get('message', {}).get('function_call') + response_content = output.get("message", {})["content"] + if function_calling_type == "tool_call": + tool_calls = output.get("message", {}).get("tool_calls") + elif function_calling_type == "function_call": + tool_calls = output.get("message", {}).get("function_call") elif completion_type is LLMMode.COMPLETION: - response_content = output['text'] + response_content = output["text"] assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[]) if tool_calls: - if function_calling_type == 'tool_call': + if function_calling_type == "tool_call": assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls) - elif function_calling_type == 'function_call': + elif function_calling_type == "function_call": assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)] usage = response_json.get("usage") @@ -597,19 +587,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) @@ -618,11 +602,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): message = cast(AssistantPromptMessage, message) message_dict = {"role": "assistant", "content": message.content} if message.tool_calls: - function_calling_type = credentials.get('function_calling_type', 'no_call') - if function_calling_type == 'tool_call': - message_dict["tool_calls"] = [tool_call.dict() for tool_call in - message.tool_calls] - elif function_calling_type == 'function_call': + function_calling_type = credentials.get("function_calling_type", "no_call") + if function_calling_type == "tool_call": + message_dict["tool_calls"] = [tool_call.dict() for tool_call in message.tool_calls] + elif function_calling_type == "function_call": function_call = message.tool_calls[0] message_dict["function_call"] = { "name": function_call.function.name, @@ -633,19 +616,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): message_dict = {"role": "system", "content": message.content} elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) - function_calling_type = credentials.get('function_calling_type', 'no_call') - if function_calling_type == 'tool_call': - message_dict = { - "role": "tool", - "content": message.content, - "tool_call_id": message.tool_call_id - } - elif function_calling_type == 'function_call': - message_dict = { - "role": "function", - "content": message.content, - "name": message.tool_call_id - } + function_calling_type = credentials.get("function_calling_type", "no_call") + if function_calling_type == "tool_call": + message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} + elif function_calling_type == "function_call": + message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id} else: raise ValueError(f"Got unknown type {message}") @@ -654,8 +629,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): return message_dict - def _num_tokens_from_string(self, model: str, text: Union[str, list[PromptMessageContent]], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string( + self, model: str, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """ Approximate num tokens for model with gpt2 tokenizer. @@ -667,7 +643,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): if isinstance(text, str): full_text = text else: - full_text = '' + full_text = "" for message_content in text: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) @@ -680,8 +656,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, credentials: dict = None) -> int: + def _num_tokens_from_messages( + self, + model: str, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + credentials: dict = None, + ) -> int: """ Approximate num tokens with GPT2 tokenizer. """ @@ -700,10 +681,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -741,46 +722,44 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): """ num_tokens = 0 for tool in tools: - num_tokens += self._get_num_tokens_by_gpt2('type') - num_tokens += self._get_num_tokens_by_gpt2('function') - num_tokens += self._get_num_tokens_by_gpt2('function') + num_tokens += self._get_num_tokens_by_gpt2("type") + num_tokens += self._get_num_tokens_by_gpt2("function") + num_tokens += self._get_num_tokens_by_gpt2("function") # calculate num tokens for function object - num_tokens += self._get_num_tokens_by_gpt2('name') + num_tokens += self._get_num_tokens_by_gpt2("name") num_tokens += self._get_num_tokens_by_gpt2(tool.name) - num_tokens += self._get_num_tokens_by_gpt2('description') + num_tokens += self._get_num_tokens_by_gpt2("description") num_tokens += self._get_num_tokens_by_gpt2(tool.description) parameters = tool.parameters - num_tokens += self._get_num_tokens_by_gpt2('parameters') - if 'title' in parameters: - num_tokens += self._get_num_tokens_by_gpt2('title') + num_tokens += self._get_num_tokens_by_gpt2("parameters") + if "title" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("title") num_tokens += self._get_num_tokens_by_gpt2(parameters.get("title")) - num_tokens += self._get_num_tokens_by_gpt2('type') + num_tokens += self._get_num_tokens_by_gpt2("type") num_tokens += self._get_num_tokens_by_gpt2(parameters.get("type")) - if 'properties' in parameters: - num_tokens += self._get_num_tokens_by_gpt2('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("properties") + for key, value in parameters.get("properties").items(): num_tokens += self._get_num_tokens_by_gpt2(key) for field_key, field_value in value.items(): num_tokens += self._get_num_tokens_by_gpt2(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += self._get_num_tokens_by_gpt2(enum_field) else: num_tokens += self._get_num_tokens_by_gpt2(field_key) num_tokens += self._get_num_tokens_by_gpt2(str(field_value)) - if 'required' in parameters: - num_tokens += self._get_num_tokens_by_gpt2('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += self._get_num_tokens_by_gpt2(required_field) return num_tokens - def _extract_response_tool_calls(self, - response_tool_calls: list[dict]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -792,20 +771,17 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( name=response_tool_call.get("function", {}).get("name", ""), - arguments=response_tool_call.get("function", {}).get("arguments", "") + arguments=response_tool_call.get("function", {}).get("arguments", ""), ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.get("id", ""), - type=response_tool_call.get("type", ""), - function=function + id=response_tool_call.get("id", ""), type=response_tool_call.get("type", ""), function=function ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call(self, response_function_call) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -815,14 +791,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.get('name', ''), - arguments=response_function_call.get('arguments', '') + name=response_function_call.get("name", ""), arguments=response_function_call.get("arguments", "") ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.get('id', ''), - type="function", - function=function + id=response_function_call.get("id", ""), type="function", function=function ) return tool_call diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py index 3445ebbaf7..ca6f185287 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class OAICompatProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py index 00702ba936..2e8b4ddd72 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py @@ -14,9 +14,7 @@ class OAICompatSpeech2TextModel(_CommonOAI_API_Compat, Speech2TextModel): Model class for OpenAI Compatible Speech to text model. """ - def _invoke( - self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None - ) -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py index 363054b084..ab358cf70a 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -27,9 +27,9 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): Model class for an OpenAI API-compatible text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -39,27 +39,25 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - - # Prepare headers and payload for the request - headers = { - 'Content-Type': 'application/json' - } - api_key = credentials.get('api_key') + # Prepare headers and payload for the request + headers = {"Content-Type": "application/json"} + + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'float' + extra_model_kwargs["encoding_format"] = "float" # get model properties context_size = self._get_context_size(model, credentials) @@ -70,7 +68,6 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): used_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer # TODO: Optimize for better token estimation and chunking num_tokens = self._get_num_tokens_by_gpt2(text) @@ -78,7 +75,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): if num_tokens >= context_size: cutoff = int(np.floor(len(text) * (context_size / num_tokens))) # if num tokens is larger than context length, only use the start - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) indices += [i] @@ -88,42 +85,25 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): for i in _iter: # Prepare the payload for the request - payload = { - 'input': inputs[i: i + max_chunks], - 'model': model, - **extra_model_kwargs - } + payload = {"input": inputs[i : i + max_chunks], "model": model, **extra_model_kwargs} # Make the request to the OpenAI API - response = requests.post( - endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) response.raise_for_status() # Raise an exception for HTTP errors response_data = response.json() # Extract embeddings and used tokens from the response - embeddings_batch = [data['embedding'] for data in response_data['data']] - embedding_used_tokens = response_data['usage']['total_tokens'] + embeddings_batch = [data["embedding"] for data in response_data["data"]] + embedding_used_tokens = response_data["usage"]["total_tokens"] used_tokens += embedding_used_tokens batched_embeddings += embeddings_batch # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) - - return TextEmbeddingResult( - embeddings=batched_embeddings, - usage=usage, - model=model - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) + + return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -145,45 +125,35 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") - payload = { - 'input': 'ping', - 'model': model - } + payload = {"input": "ping", "model": model} - response = requests.post( - url=endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + response = requests.post(url=endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") - if 'model' not in json_result: - raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response') + if "model" not in json_result: + raise CredentialsValidateFailedError("Credentials validation failed: invalid response") except CredentialsValidateFailedError: raise except Exception as ex: @@ -191,7 +161,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, @@ -199,20 +169,19 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -224,10 +193,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -238,7 +204,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/openllm/llm/llm.py b/api/core/model_runtime/model_providers/openllm/llm/llm.py index 8ea5819bde..b560afca39 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/llm.py +++ b/api/core/model_runtime/model_providers/openllm/llm/llm.py @@ -38,88 +38,115 @@ from core.model_runtime.model_providers.openllm.llm.openllm_generate_errors impo class OpenLLMLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate credentials for Baichuan model """ - if not credentials.get('server_url'): - raise CredentialsValidateFailedError('Invalid server URL') + if not credentials.get("server_url"): + raise CredentialsValidateFailedError("Invalid server URL") # ping instance = OpenLLMGenerate() try: instance.generate( - server_url=credentials['server_url'], - model_name=model, - prompt_messages=[ - OpenLLMGenerateMessage(content='ping\nAnswer: ', role='user') - ], + server_url=credentials["server_url"], + model_name=model, + prompt_messages=[OpenLLMGenerateMessage(content="ping\nAnswer: ", role="user")], model_parameters={ - 'max_tokens': 64, - 'temperature': 0.8, - 'top_p': 0.9, - 'top_k': 15, + "max_tokens": 64, + "temperature": 0.8, + "top_p": 0.9, + "top_k": 15, }, stream=False, - user='', + user="", stop=[], ) except InvalidAuthenticationError as e: raise CredentialsValidateFailedError(f"Invalid API key: {e}") - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: return self._num_tokens_from_messages(prompt_messages, tools) def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: """ - Calculate num tokens for OpenLLM model - it's a generate model, so we just join them by spe + Calculate num tokens for OpenLLM model + it's a generate model, so we just join them by spe """ - messages = ','.join([message.content for message in messages]) + messages = ",".join([message.content for message in messages]) return self._get_num_tokens_by_gpt2(messages) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: client = OpenLLMGenerate() response = client.generate( model_name=model, - server_url=credentials['server_url'], + server_url=credentials["server_url"], prompt_messages=[self._convert_prompt_message_to_openllm_message(message) for message in prompt_messages], model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) if stream: - return self._handle_chat_generate_stream_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) - return self._handle_chat_generate_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) + return self._handle_chat_generate_stream_response( + model=model, prompt_messages=prompt_messages, credentials=credentials, response=response + ) + return self._handle_chat_generate_response( + model=model, prompt_messages=prompt_messages, credentials=credentials, response=response + ) def _convert_prompt_message_to_openllm_message(self, prompt_message: PromptMessage) -> OpenLLMGenerateMessage: """ - convert PromptMessage to OpenLLMGenerateMessage so that we can use OpenLLMGenerateMessage interface + convert PromptMessage to OpenLLMGenerateMessage so that we can use OpenLLMGenerateMessage interface """ if isinstance(prompt_message, UserPromptMessage): return OpenLLMGenerateMessage(role=OpenLLMGenerateMessage.Role.USER.value, content=prompt_message.content) elif isinstance(prompt_message, AssistantPromptMessage): - return OpenLLMGenerateMessage(role=OpenLLMGenerateMessage.Role.ASSISTANT.value, content=prompt_message.content) + return OpenLLMGenerateMessage( + role=OpenLLMGenerateMessage.Role.ASSISTANT.value, content=prompt_message.content + ) else: - raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported') + raise NotImplementedError(f"Prompt message type {type(prompt_message)} is not supported") - def _handle_chat_generate_response(self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: OpenLLMGenerateMessage) -> LLMResult: - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=response.usage['prompt_tokens'], - completion_tokens=response.usage['completion_tokens'] - ) + def _handle_chat_generate_response( + self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: OpenLLMGenerateMessage + ) -> LLMResult: + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=response.usage["prompt_tokens"], + completion_tokens=response.usage["completion_tokens"], + ) return LLMResult( model=model, prompt_messages=prompt_messages, @@ -130,25 +157,27 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel): usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, prompt_messages: list[PromptMessage], - credentials: dict, response: Generator[OpenLLMGenerateMessage, None, None]) \ - -> Generator[LLMResultChunk, None, None]: + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Generator[OpenLLMGenerateMessage, None, None], + ) -> Generator[LLMResultChunk, None, None]: for message in response: if message.usage: usage = self._calc_response_usage( - model=model, credentials=credentials, - prompt_tokens=message.usage['prompt_tokens'], - completion_tokens=message.usage['completion_tokens'] + model=model, + credentials=credentials, + prompt_tokens=message.usage["prompt_tokens"], + completion_tokens=message.usage["completion_tokens"], ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), usage=usage, finish_reason=message.stop_reason if message.stop_reason else None, ), @@ -159,73 +188,55 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), finish_reason=message.stop_reason if message.stop_reason else None, ), ) - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ) + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='top_k', + name="top_k", type=ParameterType.INT, - use_template='top_k', + use_template="top_k", min=1, default=1, - label=I18nObject( - zh_Hans='Top K', - en_US='Top K' - ) + label=I18nObject(zh_Hans="Top K", en_US="Top K"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - model_properties={ + model_properties={ ModelPropertyKey.MODE: LLMMode.COMPLETION.value, }, - parameter_rules=rules + parameter_rules=rules, ) return entity @@ -241,22 +252,13 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py index 1c3f084207..e754479ec0 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py +++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py @@ -15,32 +15,38 @@ from core.model_runtime.model_providers.openllm.llm.openllm_generate_errors impo class OpenLLMGenerateMessage: class Role(Enum): - USER = 'user' - ASSISTANT = 'assistant' + USER = "user" + ASSISTANT = "assistant" role: str = Role.USER.value content: str usage: dict[str, int] = None - stop_reason: str = '' + stop_reason: str = "" def to_dict(self) -> dict[str, Any]: return { - 'role': self.role, - 'content': self.content, + "role": self.role, + "content": self.content, } - - def __init__(self, content: str, role: str = 'user') -> None: + + def __init__(self, content: str, role: str = "user") -> None: self.content = content self.role = role class OpenLLMGenerate: def generate( - self, server_url: str, model_name: str, stream: bool, model_parameters: dict[str, Any], - stop: list[str], prompt_messages: list[OpenLLMGenerateMessage], user: str, + self, + server_url: str, + model_name: str, + stream: bool, + model_parameters: dict[str, Any], + stop: list[str], + prompt_messages: list[OpenLLMGenerateMessage], + user: str, ) -> Union[Generator[OpenLLMGenerateMessage, None, None], OpenLLMGenerateMessage]: if not server_url: - raise InvalidAuthenticationError('Invalid server URL') + raise InvalidAuthenticationError("Invalid server URL") default_llm_config = { "max_new_tokens": 128, @@ -72,40 +78,37 @@ class OpenLLMGenerate: "frequency_penalty": 0, "use_beam_search": False, "ignore_eos": False, - "skip_special_tokens": True + "skip_special_tokens": True, } - if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int: - default_llm_config['max_new_tokens'] = model_parameters['max_tokens'] + if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: + default_llm_config["max_new_tokens"] = model_parameters["max_tokens"] - if 'temperature' in model_parameters and type(model_parameters['temperature']) == float: - default_llm_config['temperature'] = model_parameters['temperature'] + if "temperature" in model_parameters and type(model_parameters["temperature"]) == float: + default_llm_config["temperature"] = model_parameters["temperature"] - if 'top_p' in model_parameters and type(model_parameters['top_p']) == float: - default_llm_config['top_p'] = model_parameters['top_p'] + if "top_p" in model_parameters and type(model_parameters["top_p"]) == float: + default_llm_config["top_p"] = model_parameters["top_p"] - if 'top_k' in model_parameters and type(model_parameters['top_k']) == int: - default_llm_config['top_k'] = model_parameters['top_k'] + if "top_k" in model_parameters and type(model_parameters["top_k"]) == int: + default_llm_config["top_k"] = model_parameters["top_k"] - if 'use_cache' in model_parameters and type(model_parameters['use_cache']) == bool: - default_llm_config['use_cache'] = model_parameters['use_cache'] + if "use_cache" in model_parameters and type(model_parameters["use_cache"]) == bool: + default_llm_config["use_cache"] = model_parameters["use_cache"] - headers = { - 'Content-Type': 'application/json', - 'accept': 'application/json' - } + headers = {"Content-Type": "application/json", "accept": "application/json"} if stream: - url = f'{server_url}/v1/generate_stream' + url = f"{server_url}/v1/generate_stream" timeout = 10 else: - url = f'{server_url}/v1/generate' + url = f"{server_url}/v1/generate" timeout = 120 data = { - 'stop': stop if stop else [], - 'prompt': '\n'.join([message.content for message in prompt_messages]), - 'llm_config': default_llm_config, + "stop": stop if stop else [], + "prompt": "\n".join([message.content for message in prompt_messages]), + "llm_config": default_llm_config, } try: @@ -113,10 +116,10 @@ class OpenLLMGenerate: except (ConnectionError, InvalidSchema, MissingSchema) as e: # cloud not connect to the server raise InvalidAuthenticationError(f"Invalid server URL: {e}") - + if not response.ok: resp = response.json() - msg = resp['msg'] + msg = resp["msg"] if response.status_code == 400: raise BadRequestError(msg) elif response.status_code == 404: @@ -125,69 +128,71 @@ class OpenLLMGenerate: raise InternalServerError(msg) else: raise InternalServerError(msg) - + if stream: return self._handle_chat_stream_generate_response(response) return self._handle_chat_generate_response(response) - + def _handle_chat_generate_response(self, response: Response) -> OpenLLMGenerateMessage: try: data = response.json() except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") - message = data['outputs'][0] - text = message['text'] - token_ids = message['token_ids'] - prompt_token_ids = data['prompt_token_ids'] - stop_reason = message['finish_reason'] + message = data["outputs"][0] + text = message["text"] + token_ids = message["token_ids"] + prompt_token_ids = data["prompt_token_ids"] + stop_reason = message["finish_reason"] message = OpenLLMGenerateMessage(content=text, role=OpenLLMGenerateMessage.Role.ASSISTANT.value) message.stop_reason = stop_reason message.usage = { - 'prompt_tokens': len(prompt_token_ids), - 'completion_tokens': len(token_ids), - 'total_tokens': len(prompt_token_ids) + len(token_ids), + "prompt_tokens": len(prompt_token_ids), + "completion_tokens": len(token_ids), + "total_tokens": len(prompt_token_ids) + len(token_ids), } return message - def _handle_chat_stream_generate_response(self, response: Response) -> Generator[OpenLLMGenerateMessage, None, None]: + def _handle_chat_stream_generate_response( + self, response: Response + ) -> Generator[OpenLLMGenerateMessage, None, None]: completion_usage = 0 for line in response.iter_lines(): if not line: continue - line: str = line.decode('utf-8') - if line.startswith('data: '): + line: str = line.decode("utf-8") + if line.startswith("data: "): line = line[6:].strip() - if line == '[DONE]': + if line == "[DONE]": return try: data = loads(line) except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {line}") - - output = data['outputs'] + + output = data["outputs"] for choice in output: - text = choice['text'] - token_ids = choice['token_ids'] + text = choice["text"] + token_ids = choice["token_ids"] completion_usage += len(token_ids) message = OpenLLMGenerateMessage(content=text, role=OpenLLMGenerateMessage.Role.ASSISTANT.value) - if choice.get('finish_reason'): - finish_reason = choice['finish_reason'] - prompt_token_ids = data['prompt_token_ids'] + if choice.get("finish_reason"): + finish_reason = choice["finish_reason"] + prompt_token_ids = data["prompt_token_ids"] message.stop_reason = finish_reason message.usage = { - 'prompt_tokens': len(prompt_token_ids), - 'completion_tokens': completion_usage, - 'total_tokens': completion_usage + len(prompt_token_ids), + "prompt_tokens": len(prompt_token_ids), + "completion_tokens": completion_usage, + "total_tokens": completion_usage + len(prompt_token_ids), } - - yield message \ No newline at end of file + + yield message diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py index d9d279e6ca..309b5cf413 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py +++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py @@ -1,17 +1,22 @@ class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass + class InsufficientAccountBalanceError(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py index 4dbd0678e7..00e583cc79 100644 --- a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py @@ -23,9 +23,10 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): """ Model class for OpenLLM text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -35,16 +36,13 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - server_url = credentials['server_url'] + server_url = credentials["server_url"] if not server_url: - raise CredentialsValidateFailedError('server_url is required') - - headers = { - 'Content-Type': 'application/json', - 'accept': 'application/json' - } + raise CredentialsValidateFailedError("server_url is required") - url = f'{server_url}/v1/embeddings' + headers = {"Content-Type": "application/json", "accept": "application/json"} + + url = f"{server_url}/v1/embeddings" data = texts try: @@ -54,7 +52,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): raise InvokeAuthorizationError(f"Invalid server URL: {e}") except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: if response.status_code == 400: raise InvokeBadRequestError(response.text) @@ -62,21 +60,17 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): raise InvokeAuthorizationError(response.text) elif response.status_code == 500: raise InvokeServerUnavailableError(response.text) - + try: resp = response.json()[0] - embeddings = resp['embeddings'] - total_tokens = resp['num_tokens'] + embeddings = resp["embeddings"] + total_tokens = resp["num_tokens"] except KeyError as e: raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") usage = self._calc_response_usage(model=model, credentials=credentials, tokens=total_tokens) - result = TextEmbeddingResult( - model=model, - embeddings=embeddings, - usage=usage - ) + result = TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage) return result @@ -104,9 +98,9 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError: - raise CredentialsValidateFailedError('Invalid server_url') + raise CredentialsValidateFailedError("Invalid server_url") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: @@ -119,23 +113,13 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -147,10 +131,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -161,7 +142,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llm.py b/api/core/model_runtime/model_providers/openrouter/llm/llm.py index e78ac4caf1..71b5745f7d 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llm.py +++ b/api/core/model_runtime/model_providers/openrouter/llm/llm.py @@ -8,18 +8,23 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _update_credential(self, model: str, credentials: dict): - credentials['endpoint_url'] = "https://openrouter.ai/api/v1" - credentials['mode'] = self.get_model_mode(model).value - credentials['function_calling_type'] = 'tool_call' + credentials["endpoint_url"] = "https://openrouter.ai/api/v1" + credentials["mode"] = self.get_model_mode(model).value + credentials["function_calling_type"] = "tool_call" return - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._update_credential(model, credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -29,9 +34,17 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel): return super().validate_credentials(model, credentials) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._update_credential(model, credentials) return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -41,8 +54,13 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel): return super().get_customizable_model_schema(model, credentials) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: self._update_credential(model, credentials) return super().get_num_tokens(model, credentials, prompt_messages, tools) diff --git a/api/core/model_runtime/model_providers/openrouter/openrouter.py b/api/core/model_runtime/model_providers/openrouter/openrouter.py index 613f71deb1..2e59ab5059 100644 --- a/api/core/model_runtime/model_providers/openrouter/openrouter.py +++ b/api/core/model_runtime/model_providers/openrouter/openrouter.py @@ -8,17 +8,13 @@ logger = logging.getLogger(__name__) class OpenRouterProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='openai/gpt-3.5-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="openai/gpt-3.5-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') - raise ex \ No newline at end of file + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/llm.py b/api/core/model_runtime/model_providers/perfxcloud/llm/llm.py index c9116bf685..89cac665aa 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/llm.py +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/llm.py @@ -13,11 +13,17 @@ from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguag class PerfXCloudLargeLanguageModel(OpenAILargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -27,8 +33,7 @@ class PerfXCloudLargeLanguageModel(OpenAILargeLanguageModel): super().validate_credentials(model, credentials) # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -46,8 +51,9 @@ class PerfXCloudLargeLanguageModel(OpenAILargeLanguageModel): return num_tokens # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/ @@ -67,10 +73,10 @@ class PerfXCloudLargeLanguageModel(OpenAILargeLanguageModel): # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -101,10 +107,10 @@ class PerfXCloudLargeLanguageModel(OpenAILargeLanguageModel): @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['openai_api_key']=credentials['api_key'] - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - credentials['openai_api_base']='https://cloud.perfxlab.cn' + credentials["mode"] = "chat" + credentials["openai_api_key"] = credentials["api_key"] + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + credentials["openai_api_base"] = "https://cloud.perfxlab.cn" else: - parsed_url = urlparse(credentials['endpoint_url']) - credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}" + parsed_url = urlparse(credentials["endpoint_url"]) + credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" diff --git a/api/core/model_runtime/model_providers/perfxcloud/perfxcloud.py b/api/core/model_runtime/model_providers/perfxcloud/perfxcloud.py index 0854ef5185..450d22fb75 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/perfxcloud.py +++ b/api/core/model_runtime/model_providers/perfxcloud/perfxcloud.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class PerfXCloudProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -21,12 +20,9 @@ class PerfXCloudProvider(ModelProvider): # Use `Qwen2_72B_Chat_GPTQ_Int4` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='Qwen2-72B-Instruct-GPTQ-Int4', - credentials=credentials - ) + model_instance.validate_credentials(model="Qwen2-72B-Instruct-GPTQ-Int4", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py index 11d57e3749..d0522233e3 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py @@ -27,9 +27,9 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): Model class for an OpenAI API-compatible text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -39,30 +39,28 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - - # Prepare headers and payload for the request - headers = { - 'Content-Type': 'application/json' - } - api_key = credentials.get('api_key') + # Prepare headers and payload for the request + headers = {"Content-Type": "application/json"} + + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - endpoint_url='https://cloud.perfxlab.cn/v1/' + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + endpoint_url = "https://cloud.perfxlab.cn/v1/" else: - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'float' + extra_model_kwargs["encoding_format"] = "float" # get model properties context_size = self._get_context_size(model, credentials) @@ -73,7 +71,6 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): used_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer # TODO: Optimize for better token estimation and chunking num_tokens = self._get_num_tokens_by_gpt2(text) @@ -81,7 +78,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): if num_tokens >= context_size: cutoff = int(np.floor(len(text) * (context_size / num_tokens))) # if num tokens is larger than context length, only use the start - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) indices += [i] @@ -91,42 +88,25 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): for i in _iter: # Prepare the payload for the request - payload = { - 'input': inputs[i: i + max_chunks], - 'model': model, - **extra_model_kwargs - } + payload = {"input": inputs[i : i + max_chunks], "model": model, **extra_model_kwargs} # Make the request to the OpenAI API - response = requests.post( - endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) response.raise_for_status() # Raise an exception for HTTP errors response_data = response.json() # Extract embeddings and used tokens from the response - embeddings_batch = [data['embedding'] for data in response_data['data']] - embedding_used_tokens = response_data['usage']['total_tokens'] + embeddings_batch = [data["embedding"] for data in response_data["data"]] + embedding_used_tokens = response_data["usage"]["total_tokens"] used_tokens += embedding_used_tokens batched_embeddings += embeddings_batch # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) - - return TextEmbeddingResult( - embeddings=batched_embeddings, - usage=usage, - model=model - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) + + return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -148,48 +128,38 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - endpoint_url='https://cloud.perfxlab.cn/v1/' + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + endpoint_url = "https://cloud.perfxlab.cn/v1/" else: - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") - payload = { - 'input': 'ping', - 'model': model - } + payload = {"input": "ping", "model": model} - response = requests.post( - url=endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + response = requests.post(url=endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") - if 'model' not in json_result: - raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response') + if "model" not in json_result: + raise CredentialsValidateFailedError("Credentials validation failed: invalid response") except CredentialsValidateFailedError: raise except Exception as ex: @@ -197,7 +167,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, @@ -205,20 +175,19 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -230,10 +199,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -244,7 +210,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/replicate/_common.py b/api/core/model_runtime/model_providers/replicate/_common.py index 29d8427d8e..915f6e0eef 100644 --- a/api/core/model_runtime/model_providers/replicate/_common.py +++ b/api/core/model_runtime/model_providers/replicate/_common.py @@ -4,12 +4,6 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError class _CommonReplicate: - @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: - return { - InvokeBadRequestError: [ - ReplicateError, - ModelError - ] - } + return {InvokeBadRequestError: [ReplicateError, ModelError]} diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py index 31b81a829e..87c8bc4a91 100644 --- a/api/core/model_runtime/model_providers/replicate/llm/llm.py +++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py @@ -28,16 +28,22 @@ from core.model_runtime.model_providers.replicate._common import _CommonReplicat class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + model_version = "" + if "model_version" in credentials: + model_version = credentials["model_version"] - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: - - model_version = '' - if 'model_version' in credentials: - model_version = credentials['model_version'] - - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) model_info = client.models.get(model) if model_version: @@ -48,39 +54,43 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): inputs = {**model_parameters} if prompt_messages[0].role == PromptMessageRole.SYSTEM: - if 'system_prompt' in model_info_version.openapi_schema['components']['schemas']['Input']['properties']: - inputs['system_prompt'] = prompt_messages[0].content - inputs['prompt'] = prompt_messages[1].content + if "system_prompt" in model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"]: + inputs["system_prompt"] = prompt_messages[0].content + inputs["prompt"] = prompt_messages[1].content else: - inputs['prompt'] = prompt_messages[0].content + inputs["prompt"] = prompt_messages[0].content - prediction = client.predictions.create( - version=model_info_version, input=inputs - ) + prediction = client.predictions.create(version=model_info_version, input=inputs) if stream: return self._handle_generate_stream_response(model, credentials, prediction, stop, prompt_messages) return self._handle_generate_response(model, credentials, prediction, stop, prompt_messages) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: prompt = self._convert_messages_to_prompt(prompt_messages) return self._get_num_tokens_by_gpt2(prompt) def validate_credentials(self, model: str, credentials: dict) -> None: - if 'replicate_api_token' not in credentials: - raise CredentialsValidateFailedError('Replicate Access Token must be provided.') + if "replicate_api_token" not in credentials: + raise CredentialsValidateFailedError("Replicate Access Token must be provided.") - model_version = '' - if 'model_version' in credentials: - model_version = credentials['model_version'] + model_version = "" + if "model_version" in credentials: + model_version = credentials["model_version"] if model.count("/") != 1: - raise CredentialsValidateFailedError('Replicate Model Name must be provided, ' - 'format: {user_name}/{model_name}') + raise CredentialsValidateFailedError( + "Replicate Model Name must be provided, " "format: {user_name}/{model_name}" + ) try: - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) model_info = client.models.get(model) if model_version: @@ -91,45 +101,44 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): self._check_text_generation_model(model_info_version, model, model_version, model_info.description) except ReplicateError as e: raise CredentialsValidateFailedError( - f"Model {model}:{model_version} not exists, cause: {e.__class__.__name__}:{str(e)}") + f"Model {model}:{model_version} not exists, cause: {e.__class__.__name__}:{str(e)}" + ) except Exception as e: raise CredentialsValidateFailedError(str(e)) @staticmethod def _check_text_generation_model(model_info_version, model_name, version, description): - if 'language model' in description.lower(): + if "language model" in description.lower(): return - if 'temperature' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \ - or 'top_p' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \ - or 'top_k' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties']: + if ( + "temperature" not in model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"] + or "top_p" not in model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"] + or "top_k" not in model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"] + ): raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Text Generation model.") def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - model_type = LLMMode.CHAT if model.endswith('-chat') else LLMMode.COMPLETION + model_type = LLMMode.CHAT if model.endswith("-chat") else LLMMode.COMPLETION entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - model_properties={ - ModelPropertyKey.MODE: model_type.value - }, - parameter_rules=self._get_customizable_model_parameter_rules(model, credentials) + model_properties={ModelPropertyKey.MODE: model_type.value}, + parameter_rules=self._get_customizable_model_parameter_rules(model, credentials), ) return entity @classmethod def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict) -> list[ParameterRule]: - model_version = '' - if 'model_version' in credentials: - model_version = credentials['model_version'] + model_version = "" + if "model_version" in credentials: + model_version = credentials["model_version"] - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) model_info = client.models.get(model) if model_version: @@ -140,15 +149,13 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): parameter_rules = [] input_properties = sorted( - model_info_version.openapi_schema["components"]["schemas"]["Input"][ - "properties" - ].items(), + model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"].items(), key=lambda item: item[1].get("x-order", 0), ) for key, value in input_properties: - if key not in ['system_prompt', 'prompt'] and 'stop' not in key: - value_type = value.get('type') + if key not in ["system_prompt", "prompt"] and "stop" not in key: + value_type = value.get("type") if not value_type: continue @@ -157,28 +164,28 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): rule = ParameterRule( name=key, - label={ - 'en_US': value['title'] - }, + label={"en_US": value["title"]}, type=param_type, help={ - 'en_US': value.get('description'), + "en_US": value.get("description"), }, required=False, - default=value.get('default'), - min=value.get('minimum'), - max=value.get('maximum') + default=value.get("default"), + min=value.get("minimum"), + max=value.get("maximum"), ) parameter_rules.append(rule) return parameter_rules - def _handle_generate_stream_response(self, - model: str, - credentials: dict, - prediction: Prediction, - stop: list[str], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + prediction: Prediction, + stop: list[str], + prompt_messages: list[PromptMessage], + ) -> Generator: index = -1 current_completion: str = "" stop_condition_reached = False @@ -189,7 +196,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): for output in prediction.output_iterator(): current_completion += output - if not is_prediction_output_finished and prediction.status == 'succeeded': + if not is_prediction_output_finished and prediction.status == "succeeded": prediction_output_length = len(prediction.output) - 1 is_prediction_output_finished = True @@ -207,18 +214,13 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): index += 1 - assistant_prompt_message = AssistantPromptMessage( - content=output if output else '' - ) + assistant_prompt_message = AssistantPromptMessage(content=output if output else "") if index < prediction_output_length: yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) else: prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -229,15 +231,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message, usage=usage), ) - def _handle_generate_response(self, model: str, credentials: dict, prediction: Prediction, stop: list[str], - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, + model: str, + credentials: dict, + prediction: Prediction, + stop: list[str], + prompt_messages: list[PromptMessage], + ) -> LLMResult: current_completion: str = "" stop_condition_reached = False for output in prediction.output_iterator(): @@ -255,9 +259,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): if stop_condition_reached: break - assistant_prompt_message = AssistantPromptMessage( - content=current_completion - ) + assistant_prompt_message = AssistantPromptMessage(content=current_completion) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) @@ -275,21 +277,13 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): @classmethod def _get_parameter_type(cls, param_type: str) -> str: - type_mapping = { - 'integer': 'int', - 'number': 'float', - 'boolean': 'boolean', - 'string': 'string' - } + type_mapping = {"integer": "int", "number": "float", "boolean": "boolean", "string": "string"} return type_mapping.get(param_type) def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() diff --git a/api/core/model_runtime/model_providers/replicate/replicate.py b/api/core/model_runtime/model_providers/replicate/replicate.py index 3a5c9b84a0..ca137579c9 100644 --- a/api/core/model_runtime/model_providers/replicate/replicate.py +++ b/api/core/model_runtime/model_providers/replicate/replicate.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class ReplicateProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index 0e4cdbf5bc..f6b7754d74 100644 --- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -13,32 +13,27 @@ from core.model_runtime.model_providers.replicate._common import _CommonReplicat class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): - def _invoke(self, model: str, credentials: dict, texts: list[str], - user: Optional[str] = None) -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) - - if 'model_version' in credentials: - model_version = credentials['model_version'] + if "model_version" in credentials: + model_version = credentials["model_version"] else: model_info = client.models.get(model) model_version = model_info.latest_version.id - replicate_model_version = f'{model}:{model_version}' + replicate_model_version = f"{model}:{model_version}" text_input_key = self._get_text_input_key(model, model_version, client) - embeddings = self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, - texts) + embeddings = self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, texts) tokens = self.get_num_tokens(model, credentials, texts) usage = self._calc_response_usage(model, credentials, tokens) - return TextEmbeddingResult( - model=model, - embeddings=embeddings, - usage=usage - ) + return TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: num_tokens = 0 @@ -47,39 +42,35 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): return num_tokens def validate_credentials(self, model: str, credentials: dict) -> None: - if 'replicate_api_token' not in credentials: - raise CredentialsValidateFailedError('Replicate Access Token must be provided.') + if "replicate_api_token" not in credentials: + raise CredentialsValidateFailedError("Replicate Access Token must be provided.") try: - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) - if 'model_version' in credentials: - model_version = credentials['model_version'] + if "model_version" in credentials: + model_version = credentials["model_version"] else: model_info = client.models.get(model) model_version = model_info.latest_version.id - replicate_model_version = f'{model}:{model_version}' + replicate_model_version = f"{model}:{model_version}" text_input_key = self._get_text_input_key(model, model_version, client) - self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, - ['Hello worlds!']) + self._generate_embeddings_by_text_input_key( + client, replicate_model_version, text_input_key, ["Hello worlds!"] + ) except Exception as e: raise CredentialsValidateFailedError(str(e)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, - model_properties={ - 'context_size': 4096, - 'max_chunks': 1 - } + model_properties={"context_size": 4096, "max_chunks": 1}, ) return entity @@ -90,49 +81,45 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): # sort through the openapi schema to get the name of text, texts or inputs input_properties = sorted( - model_info_version.openapi_schema["components"]["schemas"]["Input"][ - "properties" - ].items(), + model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"].items(), key=lambda item: item[1].get("x-order", 0), ) for input_property in input_properties: - if input_property[0] in ('text', 'texts', 'inputs'): + if input_property[0] in ("text", "texts", "inputs"): text_input_key = input_property[0] return text_input_key - return '' + return "" @staticmethod - def _generate_embeddings_by_text_input_key(client: ReplicateClient, replicate_model_version: str, - text_input_key: str, texts: list[str]) -> list[list[float]]: - - if text_input_key in ('text', 'inputs'): + def _generate_embeddings_by_text_input_key( + client: ReplicateClient, replicate_model_version: str, text_input_key: str, texts: list[str] + ) -> list[list[float]]: + if text_input_key in ("text", "inputs"): embeddings = [] for text in texts: - result = client.run(replicate_model_version, input={ - text_input_key: text - }) - embeddings.append(result[0].get('embedding')) + result = client.run(replicate_model_version, input={text_input_key: text}) + embeddings.append(result[0].get("embedding")) return [list(map(float, e)) for e in embeddings] - elif 'texts' == text_input_key: - result = client.run(replicate_model_version, input={ - 'texts': json.dumps(texts), - "batch_size": 4, - "convert_to_numpy": False, - "normalize_embeddings": True - }) + elif "texts" == text_input_key: + result = client.run( + replicate_model_version, + input={ + "texts": json.dumps(texts), + "batch_size": 4, + "convert_to_numpy": False, + "normalize_embeddings": True, + }, + ) return result else: - raise ValueError(f'embeddings input key is invalid: {text_input_key}') + raise ValueError(f"embeddings input key is invalid: {text_input_key}") def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -143,7 +130,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py index 3d4c5825af..2edd13d56d 100644 --- a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py +++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py @@ -44,10 +44,11 @@ from core.model_runtime.model_providers.__base.large_language_model import Large logger = logging.getLogger(__name__) -def inference(predictor, messages:list[dict[str,Any]], params:dict[str,Any], stop:list, stream=False): - """ + +def inference(predictor, messages: list[dict[str, Any]], params: dict[str, Any], stop: list, stream=False): + """ params: - predictor : Sagemaker Predictor + predictor : Sagemaker Predictor messages (List[Dict[str,Any]]): message list。 messages = [ {"role": "system", "content":"please answer in Chinese"}, @@ -55,19 +56,19 @@ def inference(predictor, messages:list[dict[str,Any]], params:dict[str,Any], sto ] params (Dict[str,Any]): model parameters for LLM。 stream (bool): False by default。 - + response: result of inference if stream is False Iterator of Chunks if stream is True """ payload = { - "model" : params.get('model_name'), - "stop" : stop, + "model": params.get("model_name"), + "stop": stop, "messages": messages, - "stream" : stream, - "max_tokens" : params.get('max_new_tokens', params.get('max_tokens', 2048)), - "temperature" : params.get('temperature', 0.1), - "top_p" : params.get('top_p', 0.9), + "stream": stream, + "max_tokens": params.get("max_new_tokens", params.get("max_tokens", 2048)), + "temperature": params.get("temperature", 0.1), + "top_p": params.get("top_p", 0.9), } if not stream: @@ -77,36 +78,41 @@ def inference(predictor, messages:list[dict[str,Any]], params:dict[str,Any], sto response_stream = predictor.predict_stream(payload) return response_stream + class SageMakerLargeLanguageModel(LargeLanguageModel): """ Model class for Cohere large language model. """ - sagemaker_client: Any = None - sagemaker_sess : Any = None - predictor : Any = None - def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: bytes) -> LLMResult: + sagemaker_client: Any = None + sagemaker_sess: Any = None + predictor: Any = None + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: bytes, + ) -> LLMResult: """ - handle normal chat generate response + handle normal chat generate response """ - resp_obj = json.loads(resp.decode('utf-8')) - resp_str = resp_obj.get('choices')[0].get('message').get('content') + resp_obj = json.loads(resp.decode("utf-8")) + resp_str = resp_obj.get("choices")[0].get("message").get("content") if len(resp_str) == 0: raise InvokeServerUnavailableError("Empty response") - assistant_prompt_message = AssistantPromptMessage( - content=resp_str, - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=resp_str, tool_calls=[]) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -118,37 +124,43 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): return response - def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Iterator[bytes]) -> Generator: + def _handle_chat_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Iterator[bytes], + ) -> Generator: """ - handle stream chat generate response + handle stream chat generate response """ - full_response = '' + full_response = "" buffer = "" for chunk_bytes in resp: - buffer += chunk_bytes.decode('utf-8') + buffer += chunk_bytes.decode("utf-8") last_idx = 0 - for match in re.finditer(r'^data:\s*(.+?)(\n\n)', buffer): + for match in re.finditer(r"^data:\s*(.+?)(\n\n)", buffer): try: data = json.loads(match.group(1).strip()) last_idx = match.span()[1] if "content" in data["choices"][0]["delta"]: chunk_content = data["choices"][0]["delta"]["content"] - assistant_prompt_message = AssistantPromptMessage( - content=chunk_content, - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[]) - if data["choices"][0]['finish_reason'] is not None: - temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=[] - ) + if data["choices"][0]["finish_reason"] is not None: + temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[]) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) - completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + completion_tokens = self._num_tokens_from_messages( + messages=[temp_assistant_prompt_message], tools=[] + ) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, @@ -157,8 +169,8 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=0, message=assistant_prompt_message, - finish_reason=data["choices"][0]['finish_reason'], - usage=usage + finish_reason=data["choices"][0]["finish_reason"], + usage=usage, ), ) else: @@ -166,10 +178,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): model=model, prompt_messages=prompt_messages, system_fingerprint=None, - delta=LLMResultChunkDelta( - index=0, - message=assistant_prompt_message - ), + delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message), ) full_response += chunk_content @@ -179,11 +188,17 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): buffer = buffer[last_idx:] - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -198,15 +213,17 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ if not self.sagemaker_client: - access_key = credentials.get('access_key') - secret_key = credentials.get('secret_key') - aws_region = credentials.get('aws_region') + access_key = credentials.get("access_key") + secret_key = credentials.get("secret_key") + aws_region = credentials.get("aws_region") if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client("sagemaker-runtime", + self.sagemaker_client = boto3.client( + "sagemaker-runtime", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) + region_name=aws_region, + ) else: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) else: @@ -214,25 +231,26 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): sagemaker_session = Session(sagemaker_runtime_client=self.sagemaker_client) self.predictor = Predictor( - endpoint_name=credentials.get('sagemaker_endpoint'), + endpoint_name=credentials.get("sagemaker_endpoint"), sagemaker_session=sagemaker_session, serializer=serializers.JSONSerializer(), ) - - messages:list[dict[str,Any]] = [ {"role": p.role.value, "content": p.content} for p in prompt_messages ] - response = inference(predictor=self.predictor, messages=messages, params=model_parameters, stop=stop, stream=stream) + messages: list[dict[str, Any]] = [{"role": p.role.value, "content": p.content} for p in prompt_messages] + response = inference( + predictor=self.predictor, messages=messages, params=model_parameters, stop=stop, stream=stream + ) if stream: if tools and len(tools) > 0: raise InvokeBadRequestError(f"{model}'s tool calls does not support stream mode") - return self._handle_chat_stream_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=response) - return self._handle_chat_generate_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=response) + return self._handle_chat_stream_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response + ) + return self._handle_chat_generate_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response + ) def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: """ @@ -247,19 +265,13 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -269,7 +281,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): if message.tool_calls and len(message.tool_calls) > 0: message_dict["function_call"] = { "name": message.tool_calls[0].function.name, - "arguments": message.tool_calls[0].function.arguments + "arguments": message.tool_calls[0].function.arguments, } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) @@ -282,8 +294,9 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): return message_dict - def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool], - is_completion_model: bool = False) -> int: + def _num_tokens_from_messages( + self, messages: list[PromptMessage], tools: list[PromptMessageTool], is_completion_model: bool = False + ) -> int: def tokens(text: str): return self._get_num_tokens_by_gpt2(text) @@ -299,10 +312,10 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -339,8 +352,13 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): return num_tokens - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -381,89 +399,63 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ), + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, - max=credentials.get('context_length', 2048), + max=credentials.get("context_length", 2048), default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] completion_type = LLMMode.value_of(credentials["mode"]).value features = [] - support_function_call = credentials.get('support_function_call', False) + support_function_call = credentials.get("support_function_call", False) if support_function_call: features.append(ModelFeature.TOOL_CALL) - support_vision = credentials.get('support_vision', False) + support_vision = credentials.get("support_vision", False) if support_vision: features.append(ModelFeature.VISION) - context_length = credentials.get('context_length', 2048) + context_length = credentials.get("context_length", 2048) entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, features=features, - model_properties={ - ModelPropertyKey.MODE: completion_type, - ModelPropertyKey.CONTEXT_SIZE: context_length - }, - parameter_rules=rules + model_properties={ModelPropertyKey.MODE: completion_type, ModelPropertyKey.CONTEXT_SIZE: context_length}, + parameter_rules=rules, ) return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py index 6b7cfc210b..7e7614055c 100644 --- a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py @@ -20,34 +20,36 @@ from core.model_runtime.model_providers.__base.rerank_model import RerankModel logger = logging.getLogger(__name__) + class SageMakerRerankModel(RerankModel): """ Model class for SageMaker rerank model. """ + sagemaker_client: Any = None - def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str): - inputs = [query_input]*len(docs) + def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str): + inputs = [query_input] * len(docs) response_model = self.sagemaker_client.invoke_endpoint( EndpointName=rerank_endpoint, - Body=json.dumps( - { - "inputs": inputs, - "docs": docs - } - ), + Body=json.dumps({"inputs": inputs, "docs": docs}), ContentType="application/json", ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) - scores = json_obj['scores'] + scores = json_obj["scores"] return scores if isinstance(scores, list) else [scores] - - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -63,22 +65,21 @@ class SageMakerRerankModel(RerankModel): line = 0 try: if len(docs) == 0: - return RerankResult( - model=model, - docs=docs - ) + return RerankResult(model=model, docs=docs) line = 1 if not self.sagemaker_client: - access_key = credentials.get('aws_access_key_id') - secret_key = credentials.get('aws_secret_access_key') - aws_region = credentials.get('aws_region') + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") + aws_region = credentials.get("aws_region") if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client("sagemaker-runtime", + self.sagemaker_client = boto3.client( + "sagemaker-runtime", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) + region_name=aws_region, + ) else: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) else: @@ -86,22 +87,20 @@ class SageMakerRerankModel(RerankModel): line = 2 - sagemaker_endpoint = credentials.get('sagemaker_endpoint') + sagemaker_endpoint = credentials.get("sagemaker_endpoint") candidate_docs = [] scores = self._sagemaker_rerank(query, docs, sagemaker_endpoint) for idx in range(len(scores)): - candidate_docs.append({"content" : docs[idx], "score": scores[idx]}) + candidate_docs.append({"content": docs[idx], "score": scores[idx]}) - sorted(candidate_docs, key=lambda x: x['score'], reverse=True) + sorted(candidate_docs, key=lambda x: x["score"], reverse=True) line = 3 rerank_documents = [] for idx, result in enumerate(candidate_docs): rerank_document = RerankDocument( - index=idx, - text=result.get('content'), - score=result.get('score', -100.0) + index=idx, text=result.get("content"), score=result.get("score", -100.0) ) if score_threshold is not None: @@ -110,13 +109,10 @@ class SageMakerRerankModel(RerankModel): else: rerank_documents.append(rerank_document) - return RerankResult( - model=model, - docs=rerank_documents - ) + return RerankResult(model=model, docs=rerank_documents) except Exception as e: - logger.exception(f'Exception {e}, line : {line}') + logger.exception(f"Exception {e}, line : {line}") def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -137,7 +133,7 @@ class SageMakerRerankModel(RerankModel): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -153,38 +149,24 @@ class SageMakerRerankModel(RerankModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.RERANK, - model_properties={ }, - parameter_rules=[] + model_properties={}, + parameter_rules=[], ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/sagemaker.py b/api/core/model_runtime/model_providers/sagemaker/sagemaker.py index 6f3e02489f..042155b152 100644 --- a/api/core/model_runtime/model_providers/sagemaker/sagemaker.py +++ b/api/core/model_runtime/model_providers/sagemaker/sagemaker.py @@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) + class SageMakerProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -17,27 +18,24 @@ class SageMakerProvider(ModelProvider): """ pass -def buffer_to_s3(s3_client:Any, file: IO[bytes], bucket:str, s3_prefix:str) -> str: - ''' - return s3_uri of this file - ''' - s3_key = f'{s3_prefix}{uuid.uuid4()}.mp3' - s3_client.put_object( - Body=file.read(), - Bucket=bucket, - Key=s3_key, - ContentType='audio/mp3' - ) + +def buffer_to_s3(s3_client: Any, file: IO[bytes], bucket: str, s3_prefix: str) -> str: + """ + return s3_uri of this file + """ + s3_key = f"{s3_prefix}{uuid.uuid4()}.mp3" + s3_client.put_object(Body=file.read(), Bucket=bucket, Key=s3_key, ContentType="audio/mp3") return s3_key -def generate_presigned_url(s3_client:Any, file: IO[bytes], bucket_name:str, s3_prefix:str, expiration=600) -> str: + +def generate_presigned_url(s3_client: Any, file: IO[bytes], bucket_name: str, s3_prefix: str, expiration=600) -> str: object_key = buffer_to_s3(s3_client, file, bucket_name, s3_prefix) try: - response = s3_client.generate_presigned_url('get_object', - Params={'Bucket': bucket_name, 'Key': object_key}, - ExpiresIn=expiration) + response = s3_client.generate_presigned_url( + "get_object", Params={"Bucket": bucket_name, "Key": object_key}, ExpiresIn=expiration + ) except Exception as e: print(f"Error generating presigned URL: {e}") return None - return response \ No newline at end of file + return response diff --git a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py index 8b57f182fe..6aa8c9995f 100644 --- a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py @@ -19,16 +19,16 @@ from core.model_runtime.model_providers.sagemaker.sagemaker import generate_pres logger = logging.getLogger(__name__) + class SageMakerSpeech2TextModel(Speech2TextModel): """ Model class for Xinference speech to text model. """ - sagemaker_client: Any = None - s3_client : Any = None - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + sagemaker_client: Any = None + s3_client: Any = None + + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -42,19 +42,20 @@ class SageMakerSpeech2TextModel(Speech2TextModel): try: if not self.sagemaker_client: - access_key = credentials.get('aws_access_key_id') - secret_key = credentials.get('aws_secret_access_key') - aws_region = credentials.get('aws_region') + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") + aws_region = credentials.get("aws_region") if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client("sagemaker-runtime", + self.sagemaker_client = boto3.client( + "sagemaker-runtime", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) - self.s3_client = boto3.client("s3", - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, - region_name=aws_region) + region_name=aws_region, + ) + self.s3_client = boto3.client( + "s3", aws_access_key_id=access_key, aws_secret_access_key=secret_key, region_name=aws_region + ) else: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) self.s3_client = boto3.client("s3", region_name=aws_region) @@ -62,25 +63,21 @@ class SageMakerSpeech2TextModel(Speech2TextModel): self.sagemaker_client = boto3.client("sagemaker-runtime") self.s3_client = boto3.client("s3") - s3_prefix='dify/speech2text/' - sagemaker_endpoint = credentials.get('sagemaker_endpoint') - bucket = credentials.get('audio_s3_cache_bucket') + s3_prefix = "dify/speech2text/" + sagemaker_endpoint = credentials.get("sagemaker_endpoint") + bucket = credentials.get("audio_s3_cache_bucket") s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix) - payload = { - "audio_s3_presign_uri" : s3_presign_url - } + payload = {"audio_s3_presign_uri": s3_presign_url} response_model = self.sagemaker_client.invoke_endpoint( - EndpointName=sagemaker_endpoint, - Body=json.dumps(payload), - ContentType="application/json" + EndpointName=sagemaker_endpoint, Body=json.dumps(payload), ContentType="application/json" ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) - asr_text = json_obj['text'] + asr_text = json_obj["text"] except Exception as e: - logger.exception(f'Exception {e}, line : {line}') + logger.exception(f"Exception {e}, line : {line}") return asr_text @@ -105,38 +102,24 @@ class SageMakerSpeech2TextModel(Speech2TextModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.SPEECH2TEXT, - model_properties={ }, - parameter_rules=[] + model_properties={}, + parameter_rules=[], ) return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py index 4b2858b1a2..d55144f8a7 100644 --- a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py @@ -10,21 +10,22 @@ from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel BATCH_SIZE = 20 -CONTEXT_SIZE=8192 +CONTEXT_SIZE = 8192 logger = logging.getLogger(__name__) + def batch_generator(generator, batch_size): while True: batch = list(itertools.islice(generator, batch_size)) @@ -32,33 +33,28 @@ def batch_generator(generator, batch_size): break yield batch + class SageMakerEmbeddingModel(TextEmbeddingModel): """ Model class for Cohere text embedding model. """ + sagemaker_client: Any = None - def _sagemaker_embedding(self, sm_client, endpoint_name, content_list:list[str]): + def _sagemaker_embedding(self, sm_client, endpoint_name, content_list: list[str]): response_model = sm_client.invoke_endpoint( EndpointName=endpoint_name, - Body=json.dumps( - { - "inputs": content_list, - "parameters": {}, - "is_query" : False, - "instruction" : '' - } - ), + Body=json.dumps({"inputs": content_list, "parameters": {}, "is_query": False, "instruction": ""}), ContentType="application/json", ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) - embeddings = json_obj['embeddings'] + embeddings = json_obj["embeddings"] return embeddings - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -72,25 +68,27 @@ class SageMakerEmbeddingModel(TextEmbeddingModel): try: line = 1 if not self.sagemaker_client: - access_key = credentials.get('aws_access_key_id') - secret_key = credentials.get('aws_secret_access_key') - aws_region = credentials.get('aws_region') + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") + aws_region = credentials.get("aws_region") if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client("sagemaker-runtime", + self.sagemaker_client = boto3.client( + "sagemaker-runtime", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) + region_name=aws_region, + ) else: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) else: self.sagemaker_client = boto3.client("sagemaker-runtime") line = 2 - sagemaker_endpoint = credentials.get('sagemaker_endpoint') + sagemaker_endpoint = credentials.get("sagemaker_endpoint") line = 3 - truncated_texts = [ item[:CONTEXT_SIZE] for item in texts ] + truncated_texts = [item[:CONTEXT_SIZE] for item in texts] batches = batch_generator((text for text in truncated_texts), batch_size=BATCH_SIZE) all_embeddings = [] @@ -105,18 +103,14 @@ class SageMakerEmbeddingModel(TextEmbeddingModel): usage = self._calc_response_usage( model=model, credentials=credentials, - tokens=0 # It's not SAAS API, usage is meaningless + tokens=0, # It's not SAAS API, usage is meaningless ) line = 6 - return TextEmbeddingResult( - embeddings=all_embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=all_embeddings, usage=usage, model=model) except Exception as e: - logger.exception(f'Exception {e}, line : {line}') + logger.exception(f"Exception {e}, line : {line}") def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -153,10 +147,7 @@ class SageMakerEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -167,7 +158,7 @@ class SageMakerEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -175,40 +166,28 @@ class SageMakerEmbeddingModel(TextEmbeddingModel): @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ - + entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ ModelPropertyKey.CONTEXT_SIZE: CONTEXT_SIZE, ModelPropertyKey.MAX_CHUNKS: BATCH_SIZE, }, - parameter_rules=[] + parameter_rules=[], ) return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/tts/tts.py b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py index 315b31fd85..3dd5f8f64c 100644 --- a/api/core/model_runtime/model_providers/sagemaker/tts/tts.py +++ b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py @@ -22,89 +22,93 @@ from core.model_runtime.model_providers.__base.tts_model import TTSModel logger = logging.getLogger(__name__) + class TTSModelType(Enum): PresetVoice = "PresetVoice" CloneVoice = "CloneVoice" CloneVoice_CrossLingual = "CloneVoice_CrossLingual" InstructVoice = "InstructVoice" -class SageMakerText2SpeechModel(TTSModel): +class SageMakerText2SpeechModel(TTSModel): sagemaker_client: Any = None - s3_client : Any = None - comprehend_client : Any = None + s3_client: Any = None + comprehend_client: Any = None def __init__(self): # preset voices, need support custom voice self.model_voices = { - '__default': { - 'all': [ - {'name': 'Default', 'value': 'default'}, + "__default": { + "all": [ + {"name": "Default", "value": "default"}, ] }, - 'CosyVoice': { - 'zh-Hans': [ - {'name': '中文男', 'value': '中文男'}, - {'name': '中文女', 'value': '中文女'}, - {'name': '粤语女', 'value': '粤语女'}, + "CosyVoice": { + "zh-Hans": [ + {"name": "中文男", "value": "中文男"}, + {"name": "中文女", "value": "中文女"}, + {"name": "粤语女", "value": "粤语女"}, ], - 'zh-Hant': [ - {'name': '中文男', 'value': '中文男'}, - {'name': '中文女', 'value': '中文女'}, - {'name': '粤语女', 'value': '粤语女'}, + "zh-Hant": [ + {"name": "中文男", "value": "中文男"}, + {"name": "中文女", "value": "中文女"}, + {"name": "粤语女", "value": "粤语女"}, ], - 'en-US': [ - {'name': '英文男', 'value': '英文男'}, - {'name': '英文女', 'value': '英文女'}, + "en-US": [ + {"name": "英文男", "value": "英文男"}, + {"name": "英文女", "value": "英文女"}, ], - 'ja-JP': [ - {'name': '日语男', 'value': '日语男'}, + "ja-JP": [ + {"name": "日语男", "value": "日语男"}, ], - 'ko-KR': [ - {'name': '韩语女', 'value': '韩语女'}, - ] - } + "ko-KR": [ + {"name": "韩语女", "value": "韩语女"}, + ], + }, } def validate_credentials(self, model: str, credentials: dict) -> None: """ - Validate model credentials + Validate model credentials - :param model: model name - :param credentials: model credentials - :return: - """ + :param model: model name + :param credentials: model credentials + :return: + """ pass - def _detect_lang_code(self, content:str, map_dict:dict=None): - map_dict = { - "zh" : "<|zh|>", - "en" : "<|en|>", - "ja" : "<|jp|>", - "zh-TW" : "<|yue|>", - "ko" : "<|ko|>" - } + def _detect_lang_code(self, content: str, map_dict: dict = None): + map_dict = {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"} response = self.comprehend_client.detect_dominant_language(Text=content) - language_code = response['Languages'][0]['LanguageCode'] + language_code = response["Languages"][0]["LanguageCode"] - return map_dict.get(language_code, '<|zh|>') + return map_dict.get(language_code, "<|zh|>") - def _build_tts_payload(self, model_type:str, content_text:str, model_role:str, prompt_text:str, prompt_audio:str, instruct_text:str): + def _build_tts_payload( + self, + model_type: str, + content_text: str, + model_role: str, + prompt_text: str, + prompt_audio: str, + instruct_text: str, + ): if model_type == TTSModelType.PresetVoice.value and model_role: - return { "tts_text" : content_text, "role" : model_role } + return {"tts_text": content_text, "role": model_role} if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio: - return { "tts_text" : content_text, "prompt_text": prompt_text, "prompt_audio" : prompt_audio } - if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio: + return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio} + if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio: lang_tag = self._detect_lang_code(content_text) - return { "tts_text" : f"{content_text}", "prompt_audio" : prompt_audio, "lang_tag" : lang_tag } - if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role: - return { "tts_text" : content_text, "role" : model_role, "instruct_text" : instruct_text } + return {"tts_text": f"{content_text}", "prompt_audio": prompt_audio, "lang_tag": lang_tag} + if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role: + return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text} raise RuntimeError(f"Invalid params for {model_type}") - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None): + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ): """ _invoke text2speech model @@ -117,61 +121,55 @@ class SageMakerText2SpeechModel(TTSModel): :return: text translated to audio file """ if not self.sagemaker_client: - access_key = credentials.get('aws_access_key_id') - secret_key = credentials.get('aws_secret_access_key') - aws_region = credentials.get('aws_region') + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") + aws_region = credentials.get("aws_region") if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client("sagemaker-runtime", + self.sagemaker_client = boto3.client( + "sagemaker-runtime", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) - self.s3_client = boto3.client("s3", + region_name=aws_region, + ) + self.s3_client = boto3.client( + "s3", aws_access_key_id=access_key, aws_secret_access_key=secret_key, region_name=aws_region + ) + self.comprehend_client = boto3.client( + "comprehend", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) - self.comprehend_client = boto3.client('comprehend', - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, - region_name=aws_region) + region_name=aws_region, + ) else: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) self.s3_client = boto3.client("s3", region_name=aws_region) - self.comprehend_client = boto3.client('comprehend', region_name=aws_region) + self.comprehend_client = boto3.client("comprehend", region_name=aws_region) else: self.sagemaker_client = boto3.client("sagemaker-runtime") self.s3_client = boto3.client("s3") - self.comprehend_client = boto3.client('comprehend') + self.comprehend_client = boto3.client("comprehend") - model_type = credentials.get('audio_model_type', 'PresetVoice') - prompt_text = credentials.get('prompt_text') - prompt_audio = credentials.get('prompt_audio') - instruct_text = credentials.get('instruct_text') - sagemaker_endpoint = credentials.get('sagemaker_endpoint') - payload = self._build_tts_payload( - model_type, - content_text, - voice, - prompt_text, - prompt_audio, - instruct_text - ) + model_type = credentials.get("audio_model_type", "PresetVoice") + prompt_text = credentials.get("prompt_text") + prompt_audio = credentials.get("prompt_audio") + instruct_text = credentials.get("instruct_text") + sagemaker_endpoint = credentials.get("sagemaker_endpoint") + payload = self._build_tts_payload(model_type, content_text, voice, prompt_text, prompt_audio, instruct_text) return self._tts_invoke_streaming(model_type, payload, sagemaker_endpoint) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TTS, model_properties={}, - parameter_rules=[] + parameter_rules=[], ) return entity @@ -187,23 +185,11 @@ class SageMakerText2SpeechModel(TTSModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def _get_model_default_voice(self, model: str, credentials: dict) -> any: @@ -219,27 +205,27 @@ class SageMakerText2SpeechModel(TTSModel): return 5 def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: - audio_model_name = 'CosyVoice' + audio_model_name = "CosyVoice" for key, voices in self.model_voices.items(): if key in audio_model_name: if language and language in voices: return voices[language] - elif 'all' in voices: - return voices['all'] + elif "all" in voices: + return voices["all"] - return self.model_voices['__default']['all'] + return self.model_voices["__default"]["all"] - def _invoke_sagemaker(self, payload:dict, endpoint:str): + def _invoke_sagemaker(self, payload: dict, endpoint: str): response_model = self.sagemaker_client.invoke_endpoint( EndpointName=endpoint, Body=json.dumps(payload), ContentType="application/json", ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) return json_obj - def _tts_invoke_streaming(self, model_type:str, payload:dict, sagemaker_endpoint:str) -> any: + def _tts_invoke_streaming(self, model_type: str, payload: dict, sagemaker_endpoint: str) -> any: """ _tts_invoke_streaming text2speech model @@ -250,38 +236,40 @@ class SageMakerText2SpeechModel(TTSModel): :return: text translated to audio file """ try: - lang_tag = '' + lang_tag = "" if model_type == TTSModelType.CloneVoice_CrossLingual.value: - lang_tag = payload.pop('lang_tag') - - word_limit = self._get_model_word_limit(model='', credentials={}) + lang_tag = payload.pop("lang_tag") + + word_limit = self._get_model_word_limit(model="", credentials={}) content_text = payload.get("tts_text") if len(content_text) > word_limit: split_sentences = self._split_text_into_sentences(content_text, max_length=word_limit) - sentences = [ f"{lang_tag}{s}" for s in split_sentences if len(s) ] + sentences = [f"{lang_tag}{s}" for s in split_sentences if len(s)] len_sent = len(sentences) executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(4, len_sent)) - payloads = [ copy.deepcopy(payload) for i in range(len_sent) ] + payloads = [copy.deepcopy(payload) for i in range(len_sent)] for idx in range(len_sent): payloads[idx]["tts_text"] = sentences[idx] - futures = [ executor.submit( - self._invoke_sagemaker, - payload=payload, - endpoint=sagemaker_endpoint, - ) - for payload in payloads] + futures = [ + executor.submit( + self._invoke_sagemaker, + payload=payload, + endpoint=sagemaker_endpoint, + ) + for payload in payloads + ] for index, future in enumerate(futures): resp = future.result() - audio_bytes = requests.get(resp.get('s3_presign_url')).content + audio_bytes = requests.get(resp.get("s3_presign_url")).content for i in range(0, len(audio_bytes), 1024): - yield audio_bytes[i:i + 1024] + yield audio_bytes[i : i + 1024] else: resp = self._invoke_sagemaker(payload, sagemaker_endpoint) - audio_bytes = requests.get(resp.get('s3_presign_url')).content + audio_bytes = requests.get(resp.get("s3_presign_url")).content for i in range(0, len(audio_bytes), 1024): - yield audio_bytes[i:i + 1024] + yield audio_bytes[i : i + 1024] except Exception as ex: raise InvokeBadRequestError(str(ex)) diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py index a9ce7b98c3..c1868b6ad0 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py +++ b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py @@ -7,11 +7,17 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -21,5 +27,5 @@ class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel): @classmethod def _add_custom_parameters(cls, credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.siliconflow.cn/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.siliconflow.cn/v1" diff --git a/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py b/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py index 6835915816..6f652e9d52 100644 --- a/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py @@ -16,39 +16,39 @@ from core.model_runtime.model_providers.__base.rerank_model import RerankModel class SiliconflowRerankModel(RerankModel): - - def _invoke(self, model: str, credentials: dict, query: str, docs: list[str], - score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: if len(docs) == 0: return RerankResult(model=model, docs=[]) - base_url = credentials.get('base_url', 'https://api.siliconflow.cn/v1') - if base_url.endswith('/'): + base_url = credentials.get("base_url", "https://api.siliconflow.cn/v1") + if base_url.endswith("/"): base_url = base_url[:-1] try: response = httpx.post( - base_url + '/rerank', - json={ - "model": model, - "query": query, - "documents": docs, - "top_n": top_n, - "return_documents": True - }, - headers={"Authorization": f"Bearer {credentials.get('api_key')}"} + base_url + "/rerank", + json={"model": model, "query": query, "documents": docs, "top_n": top_n, "return_documents": True}, + headers={"Authorization": f"Bearer {credentials.get('api_key')}"}, ) response.raise_for_status() results = response.json() rerank_documents = [] - for result in results['results']: + for result in results["results"]: rerank_document = RerankDocument( - index=result['index'], - text=result['document']['text'], - score=result['relevance_score'], + index=result["index"], + text=result["document"]["text"], + score=result["relevance_score"], ) - if score_threshold is None or result['relevance_score'] >= score_threshold: + if score_threshold is None or result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) return RerankResult(model=model, docs=rerank_documents) @@ -57,7 +57,6 @@ class SiliconflowRerankModel(RerankModel): def validate_credentials(self, model: str, credentials: dict) -> None: try: - self._invoke( model=model, credentials=credentials, @@ -68,7 +67,7 @@ class SiliconflowRerankModel(RerankModel): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -83,5 +82,5 @@ class SiliconflowRerankModel(RerankModel): InvokeServerUnavailableError: [httpx.RemoteProtocolError], InvokeRateLimitError: [], InvokeAuthorizationError: [httpx.HTTPStatusError], - InvokeBadRequestError: [httpx.RequestError] - } \ No newline at end of file + InvokeBadRequestError: [httpx.RequestError], + } diff --git a/api/core/model_runtime/model_providers/siliconflow/siliconflow.py b/api/core/model_runtime/model_providers/siliconflow/siliconflow.py index dd0eea362a..e121ab8c7e 100644 --- a/api/core/model_runtime/model_providers/siliconflow/siliconflow.py +++ b/api/core/model_runtime/model_providers/siliconflow/siliconflow.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class SiliconflowProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ class SiliconflowProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials=credentials - ) + model_instance.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py b/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py index 6ad3cab587..8d1932863e 100644 --- a/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py @@ -8,9 +8,7 @@ class SiliconflowSpeech2TextModel(OAICompatSpeech2TextModel): Model class for Siliconflow Speech to text model. """ - def _invoke( - self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None - ) -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model diff --git a/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py index c58765cecb..6cdf4933b4 100644 --- a/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py @@ -10,20 +10,21 @@ class SiliconflowTextEmbeddingModel(OAICompatEmbeddingModel): """ Model class for Siliconflow text embedding model. """ + def validate_credentials(self, model: str, credentials: dict) -> None: self._add_custom_parameters(credentials) super().validate_credentials(model, credentials) - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, texts, user) - + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: self._add_custom_parameters(credentials) return super().get_num_tokens(model, credentials, texts) - + @classmethod def _add_custom_parameters(cls, credentials: dict) -> None: - credentials['endpoint_url'] = 'https://api.siliconflow.cn/v1' \ No newline at end of file + credentials["endpoint_url"] = "https://api.siliconflow.cn/v1" diff --git a/api/core/model_runtime/model_providers/spark/llm/_client.py b/api/core/model_runtime/model_providers/spark/llm/_client.py index d57766a87a..25223e8340 100644 --- a/api/core/model_runtime/model_providers/spark/llm/_client.py +++ b/api/core/model_runtime/model_providers/spark/llm/_client.py @@ -15,54 +15,35 @@ import websocket class SparkLLMClient: def __init__(self, model: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None): - domain = 'spark-api.xf-yun.com' - endpoint = 'chat' + domain = "spark-api.xf-yun.com" + endpoint = "chat" if api_domain: domain = api_domain model_api_configs = { - 'spark-lite': { - 'version': 'v1.1', - 'chat_domain': 'general' - }, - 'spark-pro': { - 'version': 'v3.1', - 'chat_domain': 'generalv3' - }, - 'spark-pro-128k': { - 'version': 'pro-128k', - 'chat_domain': 'pro-128k' - }, - 'spark-max': { - 'version': 'v3.5', - 'chat_domain': 'generalv3.5' - }, - 'spark-4.0-ultra': { - 'version': 'v4.0', - 'chat_domain': '4.0Ultra' - } + "spark-lite": {"version": "v1.1", "chat_domain": "general"}, + "spark-pro": {"version": "v3.1", "chat_domain": "generalv3"}, + "spark-pro-128k": {"version": "pro-128k", "chat_domain": "pro-128k"}, + "spark-max": {"version": "v3.5", "chat_domain": "generalv3.5"}, + "spark-4.0-ultra": {"version": "v4.0", "chat_domain": "4.0Ultra"}, } - api_version = model_api_configs[model]['version'] + api_version = model_api_configs[model]["version"] - self.chat_domain = model_api_configs[model]['chat_domain'] + self.chat_domain = model_api_configs[model]["chat_domain"] - if model == 'spark-pro-128k': + if model == "spark-pro-128k": self.api_base = f"wss://{domain}/{endpoint}/{api_version}" else: self.api_base = f"wss://{domain}/{api_version}/{endpoint}" self.app_id = app_id self.ws_url = self.create_url( - urlparse(self.api_base).netloc, - urlparse(self.api_base).path, - self.api_base, - api_key, - api_secret + urlparse(self.api_base).netloc, urlparse(self.api_base).path, self.api_base, api_key, api_secret ) self.queue = queue.Queue() - self.blocking_message = '' + self.blocking_message = "" def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str: # generate timestamp by RFC1123 @@ -74,33 +55,29 @@ class SparkLLMClient: signature_origin += "GET " + path + " HTTP/1.1" # encrypt using hmac-sha256 - signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'), - digestmod=hashlib.sha256).digest() + signature_sha = hmac.new( + api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256 + ).digest() - signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8') + signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8") authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"' - authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') + authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8") - v = { - "authorization": authorization, - "date": date, - "host": host - } + v = {"authorization": authorization, "date": date, "host": host} # generate url - url = api_base + '?' + urlencode(v) + url = api_base + "?" + urlencode(v) return url - def run(self, messages: list, user_id: str, - model_kwargs: Optional[dict] = None, streaming: bool = False): + def run(self, messages: list, user_id: str, model_kwargs: Optional[dict] = None, streaming: bool = False): websocket.enableTrace(False) ws = websocket.WebSocketApp( self.ws_url, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, - on_open=self.on_open + on_open=self.on_open, ) ws.messages = messages ws.user_id = user_id @@ -109,86 +86,71 @@ class SparkLLMClient: ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) def on_error(self, ws, error): - self.queue.put({ - 'status_code': error.status_code, - 'error': error.resp_body.decode('utf-8') - }) + self.queue.put({"status_code": error.status_code, "error": error.resp_body.decode("utf-8")}) ws.close() def on_close(self, ws, close_status_code, close_reason): - self.queue.put({'done': True}) + self.queue.put({"done": True}) def on_open(self, ws): - self.blocking_message = '' - data = json.dumps(self.gen_params( - messages=ws.messages, - user_id=ws.user_id, - model_kwargs=ws.model_kwargs - )) + self.blocking_message = "" + data = json.dumps(self.gen_params(messages=ws.messages, user_id=ws.user_id, model_kwargs=ws.model_kwargs)) ws.send(data) def on_message(self, ws, message): data = json.loads(message) - code = data['header']['code'] + code = data["header"]["code"] if code != 0: - self.queue.put({ - 'status_code': 400, - 'error': f"Code: {code}, Error: {data['header']['message']}" - }) + self.queue.put({"status_code": 400, "error": f"Code: {code}, Error: {data['header']['message']}"}) ws.close() else: choices = data["payload"]["choices"] status = choices["status"] content = choices["text"][0]["content"] if ws.streaming: - self.queue.put({'data': content}) + self.queue.put({"data": content}) else: self.blocking_message += content if status == 2: if not ws.streaming: - self.queue.put({'data': self.blocking_message}) + self.queue.put({"data": self.blocking_message}) ws.close() - def gen_params(self, messages: list, user_id: str, - model_kwargs: Optional[dict] = None) -> dict: + def gen_params(self, messages: list, user_id: str, model_kwargs: Optional[dict] = None) -> dict: data = { "header": { "app_id": self.app_id, # resolve this error message => $.header.uid' length must be less or equal than 32 - "uid": user_id[:32] if user_id else None + "uid": user_id[:32] if user_id else None, }, - "parameter": { - "chat": { - "domain": self.chat_domain - } - }, - "payload": { - "message": { - "text": messages - } - } + "parameter": {"chat": {"domain": self.chat_domain}}, + "payload": {"message": {"text": messages}}, } if model_kwargs: - data['parameter']['chat'].update(model_kwargs) + data["parameter"]["chat"].update(model_kwargs) return data def subscribe(self): while True: content = self.queue.get() - if 'error' in content: - if content['status_code'] == 401: - raise SparkError('[Spark] The credentials you provided are incorrect. ' - 'Please double-check and fill them in again.') - elif content['status_code'] == 403: - raise SparkError("[Spark] Sorry, the credentials you provided are access denied. " - "Please try again after obtaining the necessary permissions.") + if "error" in content: + if content["status_code"] == 401: + raise SparkError( + "[Spark] The credentials you provided are incorrect. " + "Please double-check and fill them in again." + ) + elif content["status_code"] == 403: + raise SparkError( + "[Spark] Sorry, the credentials you provided are access denied. " + "Please try again after obtaining the necessary permissions." + ) else: raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}") - if 'data' not in content: + if "data" not in content: break yield content diff --git a/api/core/model_runtime/model_providers/spark/llm/llm.py b/api/core/model_runtime/model_providers/spark/llm/llm.py index 65beae517c..0c42acf5aa 100644 --- a/api/core/model_runtime/model_providers/spark/llm/llm.py +++ b/api/core/model_runtime/model_providers/spark/llm/llm.py @@ -25,12 +25,17 @@ from ._client import SparkLLMClient class SparkLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -47,8 +52,13 @@ class SparkLargeLanguageModel(LargeLanguageModel): # invoke model return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -80,15 +90,21 @@ class SparkLargeLanguageModel(LargeLanguageModel): model_parameters={ "temperature": 0.5, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -103,7 +119,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): """ extra_model_kwargs = {} if stop: - extra_model_kwargs['stop_sequences'] = stop + extra_model_kwargs["stop_sequences"] = stop # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) @@ -113,21 +129,33 @@ class SparkLargeLanguageModel(LargeLanguageModel): **credentials_kwargs, ) - thread = threading.Thread(target=client.run, args=( - [{ 'role': prompt_message.role.value, 'content': prompt_message.content } for prompt_message in prompt_messages], - user, - model_parameters, - stream - )) + thread = threading.Thread( + target=client.run, + args=( + [ + {"role": prompt_message.role.value, "content": prompt_message.content} + for prompt_message in prompt_messages + ], + user, + model_parameters, + stream, + ), + ) thread.start() if stream: return self._handle_generate_stream_response(thread, model, credentials, client, prompt_messages) return self._handle_generate_response(thread, model, credentials, client, prompt_messages) - - def _handle_generate_response(self, thread: threading.Thread, model: str, credentials: dict, client: SparkLLMClient, - prompt_messages: list[PromptMessage]) -> LLMResult: + + def _handle_generate_response( + self, + thread: threading.Thread, + model: str, + credentials: dict, + client: SparkLLMClient, + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm response @@ -140,7 +168,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): for content in client.subscribe(): if isinstance(content, dict): - delta = content['data'] + delta = content["data"] else: delta = content @@ -148,9 +176,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): thread.join() # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=completion - ) + assistant_prompt_message = AssistantPromptMessage(content=completion) # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -168,9 +194,15 @@ class SparkLargeLanguageModel(LargeLanguageModel): ) return result - - def _handle_generate_stream_response(self, thread: threading.Thread, model: str, credentials: dict, client: SparkLLMClient, - prompt_messages: list[PromptMessage]) -> Generator: + + def _handle_generate_stream_response( + self, + thread: threading.Thread, + model: str, + credentials: dict, + client: SparkLLMClient, + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -183,12 +215,12 @@ class SparkLargeLanguageModel(LargeLanguageModel): """ for index, content in enumerate(client.subscribe()): if isinstance(content, dict): - delta = content['data'] + delta = content["data"] else: delta = content assistant_prompt_message = AssistantPromptMessage( - content=delta if delta else '', + content=delta if delta else "", ) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -199,11 +231,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message, usage=usage), ) thread.join() @@ -216,9 +244,9 @@ class SparkLargeLanguageModel(LargeLanguageModel): :return: """ credentials_kwargs = { - "app_id": credentials['app_id'], - "api_secret": credentials['api_secret'], - "api_key": credentials['api_key'], + "app_id": credentials["app_id"], + "api_secret": credentials["api_secret"], + "api_key": credentials["api_key"], } return credentials_kwargs @@ -244,7 +272,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): raise ValueError(f"Got unknown type {message}") return message_text - + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Anthropic model @@ -254,10 +282,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() @@ -277,5 +302,5 @@ class SparkLargeLanguageModel(LargeLanguageModel): InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } diff --git a/api/core/model_runtime/model_providers/stepfun/llm/llm.py b/api/core/model_runtime/model_providers/stepfun/llm/llm.py index 6f6ffc8faa..dab666e4d0 100644 --- a/api/core/model_runtime/model_providers/stepfun/llm/llm.py +++ b/api/core/model_runtime/model_providers/stepfun/llm/llm.py @@ -30,11 +30,17 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) self._add_function_call(model, credentials) user = user[:32] if user else None @@ -49,51 +55,51 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): model=model, label=I18nObject(en_US=model, zh_Hans=model), model_type=ModelType.LLM, - features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] - if credentials.get('function_calling_type') == 'tool_call' - else [], + features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] + if credentials.get("function_calling_type") == "tool_call" + else [], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 8000)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000)), ModelPropertyKey.MODE: LLMMode.CHAT.value, }, parameter_rules=[ ParameterRule( - name='temperature', - use_template='temperature', - label=I18nObject(en_US='Temperature', zh_Hans='温度'), + name="temperature", + use_template="temperature", + label=I18nObject(en_US="Temperature", zh_Hans="温度"), type=ParameterType.FLOAT, ), ParameterRule( - name='max_tokens', - use_template='max_tokens', + name="max_tokens", + use_template="max_tokens", default=512, min=1, - max=int(credentials.get('max_tokens', 1024)), - label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'), + max=int(credentials.get("max_tokens", 1024)), + label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"), type=ParameterType.INT, ), ParameterRule( - name='top_p', - use_template='top_p', - label=I18nObject(en_US='Top P', zh_Hans='Top P'), + name="top_p", + use_template="top_p", + label=I18nObject(en_US="Top P", zh_Hans="Top P"), type=ParameterType.FLOAT, ), - ] + ], ) def _add_custom_parameters(self, credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.stepfun.com/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.stepfun.com/v1" def _add_function_call(self, model: str, credentials: dict) -> None: model_schema = self.get_model_schema(model, credentials) - if model_schema and { - ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL - }.intersection(model_schema.features or []): - credentials['function_calling_type'] = 'tool_call' + if model_schema and {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection( + model_schema.features or [] + ): + credentials["function_calling_type"] = "tool_call" - def _convert_prompt_message_to_dict(self, message: PromptMessage,credentials: Optional[dict] = None) -> dict: + def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict: """ Convert PromptMessage to dict for OpenAI API format """ @@ -106,10 +112,7 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -117,7 +120,7 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): "type": "image_url", "image_url": { "url": message_content.data, - } + }, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -127,14 +130,16 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): if message.tool_calls: message_dict["tool_calls"] = [] for function_call in message.tool_calls: - message_dict["tool_calls"].append({ - "id": function_call.id, - "type": function_call.type, - "function": { - "name": function_call.function.name, - "arguments": function_call.function.arguments + message_dict["tool_calls"].append( + { + "id": function_call.id, + "type": function_call.type, + "function": { + "name": function_call.function.name, + "arguments": function_call.function.arguments, + }, } - }) + ) elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} @@ -160,21 +165,26 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "", - arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else "" + name=response_tool_call["function"]["name"] + if response_tool_call.get("function", {}).get("name") + else "", + arguments=response_tool_call["function"]["arguments"] + if response_tool_call.get("function", {}).get("arguments") + else "", ) tool_call = AssistantPromptMessage.ToolCall( id=response_tool_call["id"] if response_tool_call.get("id") else "", type=response_tool_call["type"] if response_tool_call.get("type") else "", - function=function + function=function, ) tool_calls.append(tool_call) return tool_calls - def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -184,11 +194,12 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" chunk_index = 0 - def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \ - -> LLMResultChunk: + def create_final_llm_result_chunk( + index: int, message: AssistantPromptMessage, finish_reason: str + ) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) completion_tokens = self._num_tokens_from_string(model, full_assistant_content) @@ -199,12 +210,7 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): return LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=message, - finish_reason=finish_reason, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage), ) tools_calls: list[AssistantPromptMessage.ToolCall] = [] @@ -218,9 +224,9 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None) if tool_call is None: tool_call = AssistantPromptMessage.ToolCall( - id='', - type='', - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="") + id="", + type="", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""), ) tools_calls.append(tool_call) @@ -242,9 +248,9 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"): if chunk: # ignore sse comments - if chunk.startswith(':'): + if chunk.startswith(":"): continue - decoded_chunk = chunk.strip().lstrip('data: ').lstrip() + decoded_chunk = chunk.strip().lstrip("data: ").lstrip() chunk_json = None try: chunk_json = json.loads(decoded_chunk) @@ -253,21 +259,21 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): yield create_final_llm_result_chunk( index=chunk_index + 1, message=AssistantPromptMessage(content=""), - finish_reason="Non-JSON encountered." + finish_reason="Non-JSON encountered.", ) break - if not chunk_json or len(chunk_json['choices']) == 0: + if not chunk_json or len(chunk_json["choices"]) == 0: continue - choice = chunk_json['choices'][0] - finish_reason = chunk_json['choices'][0].get('finish_reason') + choice = chunk_json["choices"][0] + finish_reason = chunk_json["choices"][0].get("finish_reason") chunk_index += 1 - if 'delta' in choice: - delta = choice['delta'] - delta_content = delta.get('content') + if "delta" in choice: + delta = choice["delta"] + delta_content = delta.get("content") - assistant_message_tool_calls = delta.get('tool_calls', None) + assistant_message_tool_calls = delta.get("tool_calls", None) # assistant_message_function_call = delta.delta.function_call # extract tool calls from response @@ -275,19 +281,18 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) increase_tool_call(tool_calls) - if delta_content is None or delta_content == '': + if delta_content is None or delta_content == "": continue # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta_content, - tool_calls=tool_calls if assistant_message_tool_calls else [] + content=delta_content, tool_calls=tool_calls if assistant_message_tool_calls else [] ) full_assistant_content += delta_content - elif 'text' in choice: - choice_text = choice.get('text', '') - if choice_text == '': + elif "text" in choice: + choice_text = choice.get("text", "") + if choice_text == "": continue # transform assistant message to prompt message @@ -303,26 +308,21 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): delta=LLMResultChunkDelta( index=chunk_index, message=assistant_prompt_message, - ) + ), ) chunk_index += 1 - + if tools_calls: yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=chunk_index, - message=AssistantPromptMessage( - tool_calls=tools_calls, - content="" - ), - ) + message=AssistantPromptMessage(tool_calls=tools_calls, content=""), + ), ) yield create_final_llm_result_chunk( - index=chunk_index, - message=AssistantPromptMessage(content=""), - finish_reason=finish_reason - ) \ No newline at end of file + index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason + ) diff --git a/api/core/model_runtime/model_providers/stepfun/stepfun.py b/api/core/model_runtime/model_providers/stepfun/stepfun.py index 50b17392b5..e1c41a9153 100644 --- a/api/core/model_runtime/model_providers/stepfun/stepfun.py +++ b/api/core/model_runtime/model_providers/stepfun/stepfun.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class StepfunProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ class StepfunProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='step-1-8k', - credentials=credentials - ) + model_instance.validate_credentials(model="step-1-8k", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py b/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py index b62b9860cb..9fd4a45f45 100644 --- a/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py +++ b/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py @@ -69,8 +69,8 @@ class FlashRecognizer: """ response: request_id string - status Integer - message String + status Integer + message String audio_duration Integer flash_result Result Array @@ -81,16 +81,16 @@ class FlashRecognizer: Sentence: text String - start_time Integer - end_time Integer - speaker_id Integer + start_time Integer + end_time Integer + speaker_id Integer word_list Word Array Word: - word String - start_time Integer - end_time Integer - stable_flag: Integer + word String + start_time Integer + end_time Integer + stable_flag: Integer """ def __init__(self, appid, credential): @@ -100,13 +100,13 @@ class FlashRecognizer: def _format_sign_string(self, param): signstr = "POSTasr.cloud.tencent.com/asr/flash/v1/" for t in param: - if 'appid' in t: + if "appid" in t: signstr += str(t[1]) break signstr += "?" for x in param: tmp = x - if 'appid' in x: + if "appid" in x: continue for t in tmp: signstr += str(t) @@ -121,10 +121,9 @@ class FlashRecognizer: return header def _sign(self, signstr, secret_key): - hmacstr = hmac.new(secret_key.encode('utf-8'), - signstr.encode('utf-8'), hashlib.sha1).digest() + hmacstr = hmac.new(secret_key.encode("utf-8"), signstr.encode("utf-8"), hashlib.sha1).digest() s = base64.b64encode(hmacstr) - s = s.decode('utf-8') + s = s.decode("utf-8") return s def _build_req_with_signature(self, secret_key, params, header): @@ -138,14 +137,22 @@ class FlashRecognizer: def _create_query_arr(self, req): return { - 'appid': self.appid, 'secretid': self.credential.secret_id, 'timestamp': str(int(time.time())), - 'engine_type': req.engine_type, 'voice_format': req.voice_format, - 'speaker_diarization': req.speaker_diarization, 'hotword_id': req.hotword_id, - 'customization_id': req.customization_id, 'filter_dirty': req.filter_dirty, - 'filter_modal': req.filter_modal, 'filter_punc': req.filter_punc, - 'convert_num_mode': req.convert_num_mode, 'word_info': req.word_info, - 'first_channel_only': req.first_channel_only, 'reinforce_hotword': req.reinforce_hotword, - 'sentence_max_length': req.sentence_max_length + "appid": self.appid, + "secretid": self.credential.secret_id, + "timestamp": str(int(time.time())), + "engine_type": req.engine_type, + "voice_format": req.voice_format, + "speaker_diarization": req.speaker_diarization, + "hotword_id": req.hotword_id, + "customization_id": req.customization_id, + "filter_dirty": req.filter_dirty, + "filter_modal": req.filter_modal, + "filter_punc": req.filter_punc, + "convert_num_mode": req.convert_num_mode, + "word_info": req.word_info, + "first_channel_only": req.first_channel_only, + "reinforce_hotword": req.reinforce_hotword, + "sentence_max_length": req.sentence_max_length, } def recognize(self, req, data): diff --git a/api/core/model_runtime/model_providers/tencent/speech2text/speech2text.py b/api/core/model_runtime/model_providers/tencent/speech2text/speech2text.py index 00ec5aa9c8..5b427663ca 100644 --- a/api/core/model_runtime/model_providers/tencent/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/tencent/speech2text/speech2text.py @@ -18,9 +18,7 @@ from core.model_runtime.model_providers.tencent.speech2text.flash_recognizer imp class TencentSpeech2TextModel(Speech2TextModel): - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -43,7 +41,7 @@ class TencentSpeech2TextModel(Speech2TextModel): try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._speech2text_invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -83,10 +81,6 @@ class TencentSpeech2TextModel(Speech2TextModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - requests.exceptions.ConnectionError - ], - InvokeAuthorizationError: [ - CredentialsValidateFailedError - ] + InvokeConnectionError: [requests.exceptions.ConnectionError], + InvokeAuthorizationError: [CredentialsValidateFailedError], } diff --git a/api/core/model_runtime/model_providers/tencent/tencent.py b/api/core/model_runtime/model_providers/tencent/tencent.py index dd9f90bb47..79c6f577b8 100644 --- a/api/core/model_runtime/model_providers/tencent/tencent.py +++ b/api/core/model_runtime/model_providers/tencent/tencent.py @@ -18,12 +18,9 @@ class TencentProvider(ModelProvider): """ try: model_instance = self.get_model_instance(ModelType.SPEECH2TEXT) - model_instance.validate_credentials( - model='tencent', - credentials=credentials - ) + model_instance.validate_credentials(model="tencent", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/togetherai/llm/llm.py b/api/core/model_runtime/model_providers/togetherai/llm/llm.py index bb802d4071..b96d43979e 100644 --- a/api/core/model_runtime/model_providers/togetherai/llm/llm.py +++ b/api/core/model_runtime/model_providers/togetherai/llm/llm.py @@ -22,16 +22,21 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _update_endpoint_url(self, credentials: dict): - credentials['endpoint_url'] = "https://api.together.xyz/v1" + credentials["endpoint_url"] = "https://api.together.xyz/v1" return credentials - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) @@ -41,12 +46,22 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): return super().validate_credentials(model, cred_with_endpoint) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) - return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) + return super()._generate( + model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user + ) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) @@ -61,45 +76,45 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, features=features, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(cred_with_endpoint.get('context_size', "4096")), - ModelPropertyKey.MODE: cred_with_endpoint.get('mode'), + ModelPropertyKey.CONTEXT_SIZE: int(cred_with_endpoint.get("context_size", "4096")), + ModelPropertyKey.MODE: cred_with_endpoint.get("mode"), }, parameter_rules=[ ParameterRule( name=DefaultParameterName.TEMPERATURE.value, label=I18nObject(en_US="Temperature"), type=ParameterType.FLOAT, - default=float(cred_with_endpoint.get('temperature', 0.7)), + default=float(cred_with_endpoint.get("temperature", 0.7)), min=0, max=2, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.TOP_P.value, label=I18nObject(en_US="Top P"), type=ParameterType.FLOAT, - default=float(cred_with_endpoint.get('top_p', 1)), + default=float(cred_with_endpoint.get("top_p", 1)), min=0, max=1, - precision=2 + precision=2, ), ParameterRule( name=TOP_K, label=I18nObject(en_US="Top K"), type=ParameterType.INT, - default=int(cred_with_endpoint.get('top_k', 50)), + default=int(cred_with_endpoint.get("top_k", 50)), min=-2147483647, max=2147483647, - precision=0 + precision=0, ), ParameterRule( name=REPETITION_PENALTY, label=I18nObject(en_US="Repetition Penalty"), type=ParameterType.FLOAT, - default=float(cred_with_endpoint.get('repetition_penalty', 1)), + default=float(cred_with_endpoint.get("repetition_penalty", 1)), min=-3.4, max=3.4, - precision=1 + precision=1, ), ParameterRule( name=DefaultParameterName.MAX_TOKENS.value, @@ -107,46 +122,49 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): type=ParameterType.INT, default=512, min=1, - max=int(cred_with_endpoint.get('max_tokens_to_sample', 4096)), + max=int(cred_with_endpoint.get("max_tokens_to_sample", 4096)), ), ParameterRule( name=DefaultParameterName.FREQUENCY_PENALTY.value, label=I18nObject(en_US="Frequency Penalty"), type=ParameterType.FLOAT, - default=float(credentials.get('frequency_penalty', 0)), + default=float(credentials.get("frequency_penalty", 0)), min=-2, - max=2 + max=2, ), ParameterRule( name=DefaultParameterName.PRESENCE_PENALTY.value, label=I18nObject(en_US="Presence Penalty"), type=ParameterType.FLOAT, - default=float(credentials.get('presence_penalty', 0)), + default=float(credentials.get("presence_penalty", 0)), min=-2, - max=2 + max=2, ), ], pricing=PriceConfig( - input=Decimal(cred_with_endpoint.get('input_price', 0)), - output=Decimal(cred_with_endpoint.get('output_price', 0)), - unit=Decimal(cred_with_endpoint.get('unit', 0)), - currency=cred_with_endpoint.get('currency', "USD") + input=Decimal(cred_with_endpoint.get("input_price", 0)), + output=Decimal(cred_with_endpoint.get("output_price", 0)), + unit=Decimal(cred_with_endpoint.get("unit", 0)), + currency=cred_with_endpoint.get("currency", "USD"), ), ) - if cred_with_endpoint['mode'] == 'chat': + if cred_with_endpoint["mode"] == "chat": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value - elif cred_with_endpoint['mode'] == 'completion': + elif cred_with_endpoint["mode"] == "completion": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value else: raise ValueError(f"Unknown completion type {cred_with_endpoint['completion_type']}") return entity - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools) - - diff --git a/api/core/model_runtime/model_providers/togetherai/togetherai.py b/api/core/model_runtime/model_providers/togetherai/togetherai.py index ffce4794e7..aa4100a7c9 100644 --- a/api/core/model_runtime/model_providers/togetherai/togetherai.py +++ b/api/core/model_runtime/model_providers/togetherai/togetherai.py @@ -6,6 +6,5 @@ logger = logging.getLogger(__name__) class TogetherAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/tongyi/_common.py b/api/core/model_runtime/model_providers/tongyi/_common.py index fab18b41fd..8a50c7aa05 100644 --- a/api/core/model_runtime/model_providers/tongyi/_common.py +++ b/api/core/model_runtime/model_providers/tongyi/_common.py @@ -21,7 +21,7 @@ class _CommonTongyi: @staticmethod def _to_credential_kwargs(credentials: dict) -> dict: credentials_kwargs = { - "dashscope_api_key": credentials['dashscope_api_key'], + "dashscope_api_key": credentials["dashscope_api_key"], } return credentials_kwargs @@ -51,5 +51,5 @@ class _CommonTongyi: InvalidParameter, UnsupportedModel, UnsupportedHTTPMethod, - ] + ], } diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index 6667d40440..72c319d395 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -45,11 +45,17 @@ from core.model_runtime.model_providers.__base.large_language_model import Large class TongyiLargeLanguageModel(LargeLanguageModel): tokenizers = {} - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -65,8 +71,14 @@ class TongyiLargeLanguageModel(LargeLanguageModel): """ # invoke model without code wrapper return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -76,10 +88,10 @@ class TongyiLargeLanguageModel(LargeLanguageModel): :param tools: tools for tool calling :return: """ - if model in ['qwen-turbo-chat', 'qwen-plus-chat']: - model = model.replace('-chat', '') - if model == 'farui-plus': - model = 'qwen-farui-plus' + if model in ["qwen-turbo-chat", "qwen-plus-chat"]: + model = model.replace("-chat", "") + if model == "farui-plus": + model = "qwen-farui-plus" if model in self.tokenizers: tokenizer = self.tokenizers[model] @@ -110,16 +122,22 @@ class TongyiLargeLanguageModel(LargeLanguageModel): model_parameters={ "temperature": 0.5, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -138,18 +156,18 @@ class TongyiLargeLanguageModel(LargeLanguageModel): mode = self.get_model_mode(model, credentials) - if model in ['qwen-turbo-chat', 'qwen-plus-chat']: - model = model.replace('-chat', '') + if model in ["qwen-turbo-chat", "qwen-plus-chat"]: + model = model.replace("-chat", "") extra_model_kwargs = {} if tools: - extra_model_kwargs['tools'] = self._convert_tools(tools) + extra_model_kwargs["tools"] = self._convert_tools(tools) if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop params = { - 'model': model, + "model": model, **model_parameters, **credentials_kwargs, **extra_model_kwargs, @@ -157,23 +175,22 @@ class TongyiLargeLanguageModel(LargeLanguageModel): model_schema = self.get_model_schema(model, credentials) if ModelFeature.VISION in (model_schema.features or []): - params['messages'] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages, rich_content=True) + params["messages"] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages, rich_content=True) response = MultiModalConversation.call(**params, stream=stream) else: # nothing different between chat model and completion model in tongyi - params['messages'] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages) - response = Generation.call(**params, - result_format='message', - stream=stream) + params["messages"] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages) + response = Generation.call(**params, result_format="message", stream=stream) if stream: return self._handle_generate_stream_response(model, credentials, response, prompt_messages) return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: GenerationResponse, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: GenerationResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -184,9 +201,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): :return: llm response """ if response.status_code != 200 and response.status_code != HTTPStatus.OK: - raise ServiceUnavailableError( - response.message - ) + raise ServiceUnavailableError(response.message) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( content=response.output.choices[0].message.content, @@ -205,9 +220,13 @@ class TongyiLargeLanguageModel(LargeLanguageModel): return result - def _handle_generate_stream_response(self, model: str, credentials: dict, - responses: Generator[GenerationResponse, None, None], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + responses: Generator[GenerationResponse, None, None], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -217,7 +236,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator result """ - full_text = '' + full_text = "" tool_calls = [] for index, response in enumerate(responses): if response.status_code != 200 and response.status_code != HTTPStatus.OK: @@ -228,22 +247,22 @@ class TongyiLargeLanguageModel(LargeLanguageModel): resp_finish_reason = response.output.choices[0].finish_reason - if resp_finish_reason is not None and resp_finish_reason != 'null': + if resp_finish_reason is not None and resp_finish_reason != "null": resp_content = response.output.choices[0].message.content assistant_prompt_message = AssistantPromptMessage( - content='', + content="", ) - if 'tool_calls' in response.output.choices[0].message: - tool_calls = response.output.choices[0].message['tool_calls'] + if "tool_calls" in response.output.choices[0].message: + tool_calls = response.output.choices[0].message["tool_calls"] elif resp_content: # special for qwen-vl if isinstance(resp_content, list): - resp_content = resp_content[0]['text'] + resp_content = resp_content[0]["text"] # transform assistant message to prompt message - assistant_prompt_message.content = resp_content.replace(full_text, '', 1) + assistant_prompt_message.content = resp_content.replace(full_text, "", 1) full_text = resp_content @@ -251,12 +270,11 @@ class TongyiLargeLanguageModel(LargeLanguageModel): message_tool_calls = [] for tool_call_obj in tool_calls: message_tool_call = AssistantPromptMessage.ToolCall( - id=tool_call_obj['function']['name'], - type='function', + id=tool_call_obj["function"]["name"], + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool_call_obj['function']['name'], - arguments=tool_call_obj['function']['arguments'] - ) + name=tool_call_obj["function"]["name"], arguments=tool_call_obj["function"]["arguments"] + ), ) message_tool_calls.append(message_tool_call) @@ -270,26 +288,23 @@ class TongyiLargeLanguageModel(LargeLanguageModel): model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - finish_reason=resp_finish_reason, - usage=usage - ) + index=index, message=assistant_prompt_message, finish_reason=resp_finish_reason, usage=usage + ), ) else: resp_content = response.output.choices[0].message.content if not resp_content: - if 'tool_calls' in response.output.choices[0].message: - tool_calls = response.output.choices[0].message['tool_calls'] + if "tool_calls" in response.output.choices[0].message: + tool_calls = response.output.choices[0].message["tool_calls"] continue # special for qwen-vl if isinstance(resp_content, list): - resp_content = resp_content[0]['text'] + resp_content = resp_content[0]["text"] # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=resp_content.replace(full_text, '', 1), + content=resp_content.replace(full_text, "", 1), ) full_text = resp_content @@ -297,10 +312,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) def _to_credential_kwargs(self, credentials: dict) -> dict: @@ -311,7 +323,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): :return: """ credentials_kwargs = { - "api_key": credentials['dashscope_api_key'], + "api_key": credentials["dashscope_api_key"], } return credentials_kwargs @@ -356,16 +368,14 @@ class TongyiLargeLanguageModel(LargeLanguageModel): """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() - def _convert_prompt_messages_to_tongyi_messages(self, prompt_messages: list[PromptMessage], - rich_content: bool = False) -> list[dict]: + def _convert_prompt_messages_to_tongyi_messages( + self, prompt_messages: list[PromptMessage], rich_content: bool = False + ) -> list[dict]: """ Convert prompt messages to tongyi messages @@ -375,24 +385,28 @@ class TongyiLargeLanguageModel(LargeLanguageModel): tongyi_messages = [] for prompt_message in prompt_messages: if isinstance(prompt_message, SystemPromptMessage): - tongyi_messages.append({ - 'role': 'system', - 'content': prompt_message.content if not rich_content else [{"text": prompt_message.content}], - }) + tongyi_messages.append( + { + "role": "system", + "content": prompt_message.content if not rich_content else [{"text": prompt_message.content}], + } + ) elif isinstance(prompt_message, UserPromptMessage): if isinstance(prompt_message.content, str): - tongyi_messages.append({ - 'role': 'user', - 'content': prompt_message.content if not rich_content else [{"text": prompt_message.content}], - }) + tongyi_messages.append( + { + "role": "user", + "content": prompt_message.content + if not rich_content + else [{"text": prompt_message.content}], + } + ) else: sub_messages = [] for message_content in prompt_message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "text": message_content.data - } + sub_message_dict = {"text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -402,35 +416,25 @@ class TongyiLargeLanguageModel(LargeLanguageModel): # convert image base64 data to file in /tmp image_url = self._save_base64_image_to_file(message_content.data) - sub_message_dict = { - "image": image_url - } + sub_message_dict = {"image": image_url} sub_messages.append(sub_message_dict) # resort sub_messages to ensure text is always at last - sub_messages = sorted(sub_messages, key=lambda x: 'text' in x) + sub_messages = sorted(sub_messages, key=lambda x: "text" in x) - tongyi_messages.append({ - 'role': 'user', - 'content': sub_messages - }) + tongyi_messages.append({"role": "user", "content": sub_messages}) elif isinstance(prompt_message, AssistantPromptMessage): content = prompt_message.content if not content: - content = ' ' - message = { - 'role': 'assistant', - 'content': content if not rich_content else [{"text": content}] - } + content = " " + message = {"role": "assistant", "content": content if not rich_content else [{"text": content}]} if prompt_message.tool_calls: - message['tool_calls'] = [tool_call.model_dump() for tool_call in prompt_message.tool_calls] + message["tool_calls"] = [tool_call.model_dump() for tool_call in prompt_message.tool_calls] tongyi_messages.append(message) elif isinstance(prompt_message, ToolPromptMessage): - tongyi_messages.append({ - "role": "tool", - "content": prompt_message.content, - "name": prompt_message.tool_call_id - }) + tongyi_messages.append( + {"role": "tool", "content": prompt_message.content, "name": prompt_message.tool_call_id} + ) else: raise ValueError(f"Got unknown type {prompt_message}") @@ -445,7 +449,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): :return: image file path """ # get mime type and encoded string - mime_type, encoded_string = base64_image.split(',')[0].split(';')[0].split(':')[1], base64_image.split(',')[1] + mime_type, encoded_string = base64_image.split(",")[0].split(";")[0].split(":")[1], base64_image.split(",")[1] # save image to file temp_dir = tempfile.gettempdir() @@ -463,19 +467,18 @@ class TongyiLargeLanguageModel(LargeLanguageModel): """ tool_definitions = [] for tool in tools: - properties = tool.parameters['properties'] - required_properties = tool.parameters['required'] + properties = tool.parameters["properties"] + required_properties = tool.parameters["required"] properties_definitions = {} for p_key, p_val in properties.items(): - desc = p_val['description'] - if 'enum' in p_val: - desc += (f"; Only accepts one of the following predefined options: " - f"[{', '.join(p_val['enum'])}]") + desc = p_val["description"] + if "enum" in p_val: + desc += f"; Only accepts one of the following predefined options: " f"[{', '.join(p_val['enum'])}]" properties_definitions[p_key] = { - 'description': desc, - 'type': p_val['type'], + "description": desc, + "type": p_val["type"], } tool_definition = { @@ -484,8 +487,8 @@ class TongyiLargeLanguageModel(LargeLanguageModel): "name": tool.name, "description": tool.description, "parameters": properties_definitions, - "required": required_properties - } + "required": required_properties, + }, } tool_definitions.append(tool_definition) @@ -517,5 +520,5 @@ class TongyiLargeLanguageModel(LargeLanguageModel): InvalidParameter, UnsupportedModel, UnsupportedHTTPMethod, - ] + ], } diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py index 97dcb72f7c..5783d2e383 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py @@ -46,7 +46,6 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): used_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer num_tokens = self._get_num_tokens_by_gpt2(text) @@ -71,12 +70,8 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): batched_embeddings += embeddings_batch # calc usage - usage = self._calc_response_usage( - model=model, credentials=credentials, tokens=used_tokens - ) - return TextEmbeddingResult( - embeddings=batched_embeddings, usage=usage, model=model - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) + return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -108,16 +103,12 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): credentials_kwargs = self._to_credential_kwargs(credentials) # call embedding model - self.embed_documents( - credentials_kwargs=credentials_kwargs, model=model, texts=["ping"] - ) + self.embed_documents(credentials_kwargs=credentials_kwargs, model=model, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @staticmethod - def embed_documents( - credentials_kwargs: dict, model: str, texts: list[str] - ) -> tuple[list[list[float]], int]: + def embed_documents(credentials_kwargs: dict, model: str, texts: list[str]) -> tuple[list[list[float]], int]: """Call out to Tongyi's embedding endpoint. Args: @@ -145,7 +136,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): raise ValueError("Embedding data is missing in the response.") else: raise ValueError("Response output is missing or does not contain embeddings.") - + if response.usage and "total_tokens" in response.usage: embedding_used_tokens += response.usage["total_tokens"] else: @@ -153,9 +144,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel): return [list(map(float, e)) for e in embeddings], embedding_used_tokens - def _calc_response_usage( - self, model: str, credentials: dict, tokens: int - ) -> EmbeddingUsage: + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage diff --git a/api/core/model_runtime/model_providers/tongyi/tongyi.py b/api/core/model_runtime/model_providers/tongyi/tongyi.py index d5e25e6ecf..a084512de9 100644 --- a/api/core/model_runtime/model_providers/tongyi/tongyi.py +++ b/api/core/model_runtime/model_providers/tongyi/tongyi.py @@ -20,12 +20,9 @@ class TongyiProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `qwen-turbo` model for validate, - model_instance.validate_credentials( - model='qwen-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="qwen-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/tongyi/tts/tts.py b/api/core/model_runtime/model_providers/tongyi/tts/tts.py index 664b02cd92..48a38897a8 100644 --- a/api/core/model_runtime/model_providers/tongyi/tts/tts.py +++ b/api/core/model_runtime/model_providers/tongyi/tts/tts.py @@ -18,8 +18,9 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): Model class for Tongyi Speech to text model. """ - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None) -> any: + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ) -> any: """ _invoke text2speech model @@ -31,14 +32,12 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): :param user: unique user id :return: text translated to audio file """ - if not voice or voice not in [d['value'] for d in - self.get_tts_model_voices(model=model, credentials=credentials)]: + if not voice or voice not in [ + d["value"] for d in self.get_tts_model_voices(model=model, credentials=credentials) + ]: voice = self._get_model_default_voice(model, credentials) - return self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - voice=voice) + return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice) def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: """ @@ -53,14 +52,13 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): self._tts_invoke_streaming( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ _tts_invoke_streaming text2speech model @@ -82,15 +80,21 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): else: sentences = list(self._split_text_into_sentences(org_text=content, max_length=wl)) for sentence in sentences: - SpeechSynthesizer.call(model=v, sample_rate=16000, - api_key=api_key, - text=sentence.strip(), - callback=cb, - format=at, word_timestamp_enabled=True, - phoneme_timestamp_enabled=True) + SpeechSynthesizer.call( + model=v, + sample_rate=16000, + api_key=api_key, + text=sentence.strip(), + callback=cb, + format=at, + word_timestamp_enabled=True, + phoneme_timestamp_enabled=True, + ) - threading.Thread(target=invoke_remote, args=( - content_text, voice, credentials.get('dashscope_api_key'), callback, audio_type, word_limit)).start() + threading.Thread( + target=invoke_remote, + args=(content_text, voice, credentials.get("dashscope_api_key"), callback, audio_type, word_limit), + ).start() while True: audio = audio_queue.get() @@ -112,16 +116,18 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): :param audio_type: audio file type :return: text translated to audio file """ - response = dashscope.audio.tts.SpeechSynthesizer.call(model=voice, sample_rate=48000, - api_key=credentials.get('dashscope_api_key'), - text=sentence.strip(), - format=audio_type) + response = dashscope.audio.tts.SpeechSynthesizer.call( + model=voice, + sample_rate=48000, + api_key=credentials.get("dashscope_api_key"), + text=sentence.strip(), + format=audio_type, + ) if isinstance(response.get_audio_data(), bytes): return response.get_audio_data() class Callback(ResultCallback): - def __init__(self, queue: Queue): self._queue = queue diff --git a/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py b/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py index 95272a41c2..cf7e3f14be 100644 --- a/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py +++ b/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py @@ -33,198 +33,223 @@ from core.model_runtime.model_providers.__base.large_language_model import Large class TritonInferenceAILargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - invoke LLM + invoke LLM - see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` """ return self._generate( - model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, - tools=tools, stop=stop, stream=stream, user=user, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, ) def validate_credentials(self, model: str, credentials: dict) -> None: """ - validate credentials + validate credentials """ - if 'server_url' not in credentials: - raise CredentialsValidateFailedError('server_url is required in credentials') - + if "server_url" not in credentials: + raise CredentialsValidateFailedError("server_url is required in credentials") + try: - self._invoke(model=model, credentials=credentials, prompt_messages=[ - UserPromptMessage(content='ping') - ], model_parameters={}, stream=False) + self._invoke( + model=model, + credentials=credentials, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={}, + stream=False, + ) except InvokeError as ex: - raise CredentialsValidateFailedError(f'An error occurred during connection: {str(ex)}') + raise CredentialsValidateFailedError(f"An error occurred during connection: {str(ex)}") - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: """ - get number of tokens + get number of tokens - cause TritonInference LLM is a customized model, we could net detect which tokenizer to use - so we just take the GPT2 tokenizer as default + cause TritonInference LLM is a customized model, we could net detect which tokenizer to use + so we just take the GPT2 tokenizer as default """ return self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages)) - + def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str: """ - convert prompt message to text + convert prompt message to text """ - text = '' + text = "" for item in message: if isinstance(item, UserPromptMessage): - text += f'User: {item.content}' + text += f"User: {item.content}" elif isinstance(item, SystemPromptMessage): - text += f'System: {item.content}' + text += f"System: {item.content}" elif isinstance(item, AssistantPromptMessage): - text += f'Assistant: {item.content}' + text += f"Assistant: {item.content}" else: - raise NotImplementedError(f'PromptMessage type {type(item)} is not supported') + raise NotImplementedError(f"PromptMessage type {type(item)} is not supported") return text def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ), + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, - max=int(credentials.get('context_length', 2048)), - default=min(512, int(credentials.get('context_length', 2048))), - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + max=int(credentials.get("context_length", 2048)), + default=min(512, int(credentials.get("context_length", 2048))), + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] completion_type = None - if 'completion_type' in credentials: - if credentials['completion_type'] == 'chat': + if "completion_type" in credentials: + if credentials["completion_type"] == "chat": completion_type = LLMMode.CHAT.value - elif credentials['completion_type'] == 'completion': + elif credentials["completion_type"] == "completion": completion_type = LLMMode.COMPLETION.value else: raise ValueError(f'completion_type {credentials["completion_type"]} is not supported') - + entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), parameter_rules=rules, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, model_properties={ ModelPropertyKey.MODE: completion_type, - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_length', 2048)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_length", 2048)), }, ) return entity - - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - generate text from LLM + generate text from LLM """ - if 'server_url' not in credentials: - raise CredentialsValidateFailedError('server_url is required in credentials') - - if 'stream' in credentials and not bool(credentials['stream']) and stream: - raise ValueError(f'stream is not supported by model {model}') + if "server_url" not in credentials: + raise CredentialsValidateFailedError("server_url is required in credentials") + + if "stream" in credentials and not bool(credentials["stream"]) and stream: + raise ValueError(f"stream is not supported by model {model}") try: parameters = {} - if 'temperature' in model_parameters: - parameters['temperature'] = model_parameters['temperature'] - if 'top_p' in model_parameters: - parameters['top_p'] = model_parameters['top_p'] - if 'top_k' in model_parameters: - parameters['top_k'] = model_parameters['top_k'] - if 'presence_penalty' in model_parameters: - parameters['presence_penalty'] = model_parameters['presence_penalty'] - if 'frequency_penalty' in model_parameters: - parameters['frequency_penalty'] = model_parameters['frequency_penalty'] + if "temperature" in model_parameters: + parameters["temperature"] = model_parameters["temperature"] + if "top_p" in model_parameters: + parameters["top_p"] = model_parameters["top_p"] + if "top_k" in model_parameters: + parameters["top_k"] = model_parameters["top_k"] + if "presence_penalty" in model_parameters: + parameters["presence_penalty"] = model_parameters["presence_penalty"] + if "frequency_penalty" in model_parameters: + parameters["frequency_penalty"] = model_parameters["frequency_penalty"] - response = post(str(URL(credentials['server_url']) / 'v2' / 'models' / model / 'generate'), json={ - 'text_input': self._convert_prompt_message_to_text(prompt_messages), - 'max_tokens': model_parameters.get('max_tokens', 512), - 'parameters': { - 'stream': False, - **parameters + response = post( + str(URL(credentials["server_url"]) / "v2" / "models" / model / "generate"), + json={ + "text_input": self._convert_prompt_message_to_text(prompt_messages), + "max_tokens": model_parameters.get("max_tokens", 512), + "parameters": {"stream": False, **parameters}, }, - }, timeout=(10, 120)) + timeout=(10, 120), + ) response.raise_for_status() if response.status_code != 200: - raise InvokeBadRequestError(f'Invoke failed with status code {response.status_code}, {response.text}') - + raise InvokeBadRequestError(f"Invoke failed with status code {response.status_code}, {response.text}") + if stream: - return self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages, - tools=tools, resp=response) - return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages, - tools=tools, resp=response) + return self._handle_chat_stream_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response + ) + return self._handle_chat_generate_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response + ) except Exception as ex: - raise InvokeConnectionError(f'An error occurred during connection: {str(ex)}') - - def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Response) -> LLMResult: + raise InvokeConnectionError(f"An error occurred during connection: {str(ex)}") + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Response, + ) -> LLMResult: """ - handle normal chat generate response + handle normal chat generate response """ - text = resp.json()['text_output'] + text = resp.json()["text_output"] usage = LLMUsage.empty_usage() usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) usage.completion_tokens = self._get_num_tokens_by_gpt2(text) return LLMResult( - model=model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=text - ), - usage=usage + model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage ) - def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Response) -> Generator: + def _handle_chat_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Response, + ) -> Generator: """ - handle normal chat generate response + handle normal chat generate response """ - text = resp.json()['text_output'] + text = resp.json()["text_output"] usage = LLMUsage.empty_usage() usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -233,13 +258,7 @@ class TritonInferenceAILargeLanguageModel(LargeLanguageModel): yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage( - content=text - ), - usage=usage - ) + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=text), usage=usage), ) @property @@ -253,15 +272,9 @@ class TritonInferenceAILargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - ], - InvokeRateLimitError: [ - ], - InvokeAuthorizationError: [ - ], - InvokeBadRequestError: [ - ValueError - ] - } \ No newline at end of file + InvokeConnectionError: [], + InvokeServerUnavailableError: [], + InvokeRateLimitError: [], + InvokeAuthorizationError: [], + InvokeBadRequestError: [ValueError], + } diff --git a/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py index 06846825ab..d85f7c82e7 100644 --- a/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py +++ b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py @@ -4,6 +4,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) + class XinferenceAIProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/upstage/_common.py b/api/core/model_runtime/model_providers/upstage/_common.py index 13b73181e9..47ebaccd84 100644 --- a/api/core/model_runtime/model_providers/upstage/_common.py +++ b/api/core/model_runtime/model_providers/upstage/_common.py @@ -1,4 +1,3 @@ - from collections.abc import Mapping import openai @@ -20,13 +19,13 @@ class _CommonUpstage: Transform credentials to kwargs for model instance :param credentials: - :return: + :return: """ credentials_kwargs = { - "api_key": credentials['upstage_api_key'], + "api_key": credentials["upstage_api_key"], "base_url": "https://api.upstage.ai/v1/solar", "timeout": Timeout(315.0, read=300.0, write=20.0, connect=10.0), - "max_retries": 1 + "max_retries": 1, } return credentials_kwargs @@ -53,5 +52,3 @@ class _CommonUpstage: openai.APIError, ], } - - diff --git a/api/core/model_runtime/model_providers/upstage/llm/llm.py b/api/core/model_runtime/model_providers/upstage/llm/llm.py index d1ed4619d6..1014b53f39 100644 --- a/api/core/model_runtime/model_providers/upstage/llm/llm.py +++ b/api/core/model_runtime/model_providers/upstage/llm/llm.py @@ -36,15 +36,23 @@ if you are not sure about the structure. """ + class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): """ - Model class for Upstage large language model. + Model class for Upstage large language model. """ - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -67,15 +75,25 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def _code_block_mode_wrapper(self, - model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ - if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']: + if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: stop = stop or [] self._transform_chat_json_prompts( model=model, @@ -86,9 +104,9 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) - model_parameters.pop('response_format') + model_parameters.pop("response_format") return self._invoke( model=model, @@ -98,15 +116,23 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def _transform_chat_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') -> None: + def _transform_chat_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ - Transform json prompts + Transform json prompts """ if stop is None: stop = [] @@ -117,20 +143,29 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): prompt_messages[0] = SystemPromptMessage( - content=UPSTAGE_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + content=UPSTAGE_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n")) else: - prompt_messages.insert(0, SystemPromptMessage( - content=UPSTAGE_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=UPSTAGE_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object." + ).replace("{{block}}", response_format) + ), + ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -155,30 +190,31 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): client = OpenAI(**credentials_kwargs) client.chat.completions.create( - messages=[{"role": "user", "content": "ping"}], - model=model, - temperature=0, - max_tokens=10, - stream=False + messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=10, stream=False ) except Exception as e: raise CredentialsValidateFailedError(str(e)) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: credentials_kwargs = self._to_credential_kwargs(credentials) client = OpenAI(**credentials_kwargs) extra_model_kwargs = {} if tools: - extra_model_kwargs["functions"] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] + extra_model_kwargs["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools + ] if stop: extra_model_kwargs["stop"] = stop @@ -198,10 +234,15 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): if stream: return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools) return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) - - def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> LLMResult: + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: ChatCompletion, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> LLMResult: """ Handle llm chat response @@ -222,10 +263,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): tool_calls = [function_call] if function_call else [] # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) # calculate num tokens if response.usage: @@ -251,9 +289,14 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk], - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[ChatCompletionChunk], + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> Generator: """ Handle llm chat stream response @@ -263,7 +306,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): :param tools: tools for tool calling :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" delta_assistant_message_function_call_storage: Optional[ChoiceDeltaFunctionCall] = None prompt_tokens = 0 completion_tokens = 0 @@ -273,8 +316,8 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage(content=''), - ) + message=AssistantPromptMessage(content=""), + ), ) for chunk in response: @@ -288,8 +331,11 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): delta = chunk.choices[0] has_finish_reason = delta.finish_reason is not None - if not has_finish_reason and (delta.delta.content is None or delta.delta.content == '') and \ - delta.delta.function_call is None: + if ( + not has_finish_reason + and (delta.delta.content is None or delta.delta.content == "") + and delta.delta.function_call is None + ): continue # assistant_message_tool_calls = delta.delta.tool_calls @@ -311,7 +357,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): # start of stream function call delta_assistant_message_function_call_storage = assistant_message_function_call if delta_assistant_message_function_call_storage.arguments is None: - delta_assistant_message_function_call_storage.arguments = '' + delta_assistant_message_function_call_storage.arguments = "" if not has_finish_reason: continue @@ -323,11 +369,10 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=tool_calls + content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls ) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content if delta.delta.content else "" if has_finish_reason: final_chunk = LLMResultChunk( @@ -338,7 +383,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - ) + ), ) else: yield LLMResultChunk( @@ -348,7 +393,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) if not prompt_tokens: @@ -356,8 +401,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): if not completion_tokens: full_assistant_prompt_message = AssistantPromptMessage( - content=full_assistant_content, - tool_calls=final_tool_calls + content=full_assistant_content, tool_calls=final_tool_calls ) completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message]) @@ -367,9 +411,9 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): yield final_chunk - def _extract_response_tool_calls(self, - response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -380,21 +424,19 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call( + self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall + ) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -404,14 +446,11 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.name, - arguments=response_function_call.arguments + name=response_function_call.name, arguments=response_function_call.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.name, - type="function", - function=function + id=response_function_call.name, type="function", function=function ) return tool_call @@ -429,19 +468,13 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) @@ -467,11 +500,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): # "content": message.content, # "tool_call_id": message.tool_call_id # } - message_dict = { - "role": "function", - "content": message.content, - "name": message.tool_call_id - } + message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id} else: raise ValueError(f"Got unknown type {message}") @@ -483,16 +512,17 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): def _get_tokenizer(self) -> Tokenizer: return Tokenizer.from_pretrained("upstage/solar-1-mini-tokenizer") - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """ Calculate num tokens for solar with Huggingface Solar tokenizer. - Solar tokenizer is opened in huggingface https://huggingface.co/upstage/solar-1-mini-tokenizer + Solar tokenizer is opened in huggingface https://huggingface.co/upstage/solar-1-mini-tokenizer """ tokenizer = self._get_tokenizer() - tokens_per_message = 5 # <|im_start|>{role}\n{message}<|im_end|> - tokens_prefix = 1 # <|startoftext|> - tokens_suffix = 3 # <|im_start|>assistant\n + tokens_per_message = 5 # <|im_start|>{role}\n{message}<|im_end|> + tokens_prefix = 1 # <|startoftext|> + tokens_suffix = 3 # <|im_start|>assistant\n num_tokens = 0 num_tokens += tokens_prefix @@ -502,10 +532,10 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text if key == "tool_calls": @@ -538,37 +568,37 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): """ num_tokens = 0 for tool in tools: - num_tokens += len(tokenizer.encode('type')) - num_tokens += len(tokenizer.encode('function')) + num_tokens += len(tokenizer.encode("type")) + num_tokens += len(tokenizer.encode("function")) # calculate num tokens for function object - num_tokens += len(tokenizer.encode('name')) + num_tokens += len(tokenizer.encode("name")) num_tokens += len(tokenizer.encode(tool.name)) - num_tokens += len(tokenizer.encode('description')) + num_tokens += len(tokenizer.encode("description")) num_tokens += len(tokenizer.encode(tool.description)) parameters = tool.parameters - num_tokens += len(tokenizer.encode('parameters')) - if 'title' in parameters: - num_tokens += len(tokenizer.encode('title')) + num_tokens += len(tokenizer.encode("parameters")) + if "title" in parameters: + num_tokens += len(tokenizer.encode("title")) num_tokens += len(tokenizer.encode(parameters.get("title"))) - num_tokens += len(tokenizer.encode('type')) + num_tokens += len(tokenizer.encode("type")) num_tokens += len(tokenizer.encode(parameters.get("type"))) - if 'properties' in parameters: - num_tokens += len(tokenizer.encode('properties')) - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += len(tokenizer.encode("properties")) + for key, value in parameters.get("properties").items(): num_tokens += len(tokenizer.encode(key)) for field_key, field_value in value.items(): num_tokens += len(tokenizer.encode(field_key)) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += len(tokenizer.encode(enum_field)) else: num_tokens += len(tokenizer.encode(field_key)) num_tokens += len(tokenizer.encode(str(field_value))) - if 'required' in parameters: - num_tokens += len(tokenizer.encode('required')) - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += len(tokenizer.encode("required")) + for required_field in parameters["required"]: num_tokens += 3 num_tokens += len(tokenizer.encode(required_field)) diff --git a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py index 05ae8665d6..edd4a36d98 100644 --- a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py @@ -18,6 +18,7 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): """ Model class for Upstage text embedding model. """ + def _get_tokenizer(self) -> Tokenizer: return Tokenizer.from_pretrained("upstage/solar-1-mini-tokenizer") @@ -53,9 +54,9 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): for i, text in enumerate(texts): token = tokenizer.encode(text, add_special_tokens=False).tokens for j in range(0, len(token), context_size): - tokens += [token[j:j+context_size]] + tokens += [token[j : j + context_size]] indices += [i] - + batched_embeddings = [] _iter = range(0, len(tokens), max_chunks) @@ -63,20 +64,20 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): embeddings_batch, embedding_used_tokens = self._embedding_invoke( model=model, client=client, - texts=tokens[i:i+max_chunks], + texts=tokens[i : i + max_chunks], extra_model_kwargs=extra_model_kwargs, ) used_tokens += embedding_used_tokens batched_embeddings += embeddings_batch - + results: list[list[list[float]]] = [[] for _ in range(len(texts))] num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))] for i in range(len(indices)): results[indices[i]].append(batched_embeddings[i]) num_tokens_in_batch[indices[i]].append(len(tokens[i])) - + for i in range(len(texts)): _result = results[i] if len(_result) == 0: @@ -91,15 +92,11 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): else: average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) embeddings[i] = (average / np.linalg.norm(average)).tolist() - - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) - + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: tokenizer = self._get_tokenizer() """ @@ -122,7 +119,7 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): total_num_tokens += len(tokenized_text) return total_num_tokens - + def validate_credentials(self, model: str, credentials: Mapping) -> None: """ Validate model credentials @@ -137,16 +134,13 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): client = OpenAI(**credentials_kwargs) # call embedding model - self._embedding_invoke( - model=model, - client=client, - texts=['ping'], - extra_model_kwargs={} - ) + self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={}) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - - def _embedding_invoke(self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict) -> tuple[list[list[float]], int]: + + def _embedding_invoke( + self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict + ) -> tuple[list[list[float]], int]: """ Invoke embedding model :param model: model name @@ -155,17 +149,19 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): :param extra_model_kwargs: extra model kwargs :return: embeddings and used tokens """ - response = client.embeddings.create( - model=model, - input=texts, - **extra_model_kwargs - ) + response = client.embeddings.create(model=model, input=texts, **extra_model_kwargs) + + if "encoding_format" in extra_model_kwargs and extra_model_kwargs["encoding_format"] == "base64": + return ( + [ + list(np.frombuffer(base64.b64decode(embedding.embedding), dtype=np.float32)) + for embedding in response.data + ], + response.usage.total_tokens, + ) - if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64': - return ([list(np.frombuffer(base64.b64decode(embedding.embedding), dtype=np.float32)) for embedding in response.data], response.usage.total_tokens) - return [data.embedding for data in response.data], response.usage.total_tokens - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -176,10 +172,7 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): :return: usage """ input_price_info = self.get_price( - model=model, - credentials=credentials, - tokens=tokens, - price_type=PriceType.INPUT + model=model, credentials=credentials, tokens=tokens, price_type=PriceType.INPUT ) usage = EmbeddingUsage( @@ -189,7 +182,7 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/upstage/upstage.py b/api/core/model_runtime/model_providers/upstage/upstage.py index 56c91c0061..e45d4aae19 100644 --- a/api/core/model_runtime/model_providers/upstage/upstage.py +++ b/api/core/model_runtime/model_providers/upstage/upstage.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class UpstageProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,14 +18,10 @@ class UpstageProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model="solar-1-mini-chat", - credentials=credentials - ) + model_instance.validate_credentials(model="solar-1-mini-chat", credentials=credentials) except CredentialsValidateFailedError as e: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise e except Exception as e: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise e - diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py index af6ec3937c..ecb22e21bd 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py @@ -49,12 +49,17 @@ logger = logging.getLogger(__name__) class VertexAiLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -74,8 +79,16 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): # invoke Gemini model return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate_anthropic( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke Anthropic large language model @@ -92,7 +105,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) project_id = credentials["vertex_project_id"] SCOPES = ["https://www.googleapis.com/auth/cloud-platform"] - token = '' + token = "" # get access token from service account credential if service_account_info: @@ -102,40 +115,32 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): token = credentials.token # Vertex AI Anthropic Claude3 Opus model available in us-east5 region, Sonnet and Haiku available in us-central1 region - if 'opus' or 'claude-3-5-sonnet' in model: - location = 'us-east5' + if "opus" or "claude-3-5-sonnet" in model: + location = "us-east5" else: - location = 'us-central1' - + location = "us-central1" + # use access token to authenticate if token: - client = AnthropicVertex( - region=location, - project_id=project_id, - access_token=token - ) + client = AnthropicVertex(region=location, project_id=project_id, access_token=token) # When access token is empty, try to use the Google Cloud VM's built-in service account or the GOOGLE_APPLICATION_CREDENTIALS environment variable else: client = AnthropicVertex( - region=location, + region=location, project_id=project_id, ) extra_model_kwargs = {} if stop: - extra_model_kwargs['stop_sequences'] = stop + extra_model_kwargs["stop_sequences"] = stop system, prompt_message_dicts = self._convert_claude_prompt_messages(prompt_messages) if system: - extra_model_kwargs['system'] = system + extra_model_kwargs["system"] = system response = client.messages.create( - model=model, - messages=prompt_message_dicts, - stream=stream, - **model_parameters, - **extra_model_kwargs + model=model, messages=prompt_message_dicts, stream=stream, **model_parameters, **extra_model_kwargs ) if stream: @@ -143,8 +148,9 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): return self._handle_claude_response(model, credentials, response, prompt_messages) - def _handle_claude_response(self, model: str, credentials: dict, response: Message, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_claude_response( + self, model: str, credentials: dict, response: Message, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm chat response @@ -156,9 +162,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=response.content[0].text - ) + assistant_prompt_message = AssistantPromptMessage(content=response.content[0].text) # calculate num tokens if response.usage: @@ -175,16 +179,18 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): # transform response response = LLMResult( - model=response.model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage + model=response.model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage ) return response - def _handle_claude_stream_response(self, model: str, credentials: dict, response: Stream[MessageStreamEvent], - prompt_messages: list[PromptMessage], ) -> Generator: + def _handle_claude_stream_response( + self, + model: str, + credentials: dict, + response: Stream[MessageStreamEvent], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm chat stream response @@ -196,7 +202,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): """ try: - full_assistant_content = '' + full_assistant_content = "" return_model = None input_tokens = 0 output_tokens = 0 @@ -217,18 +223,16 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index + 1, - message=AssistantPromptMessage( - content='' - ), + message=AssistantPromptMessage(content=""), finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) elif isinstance(chunk, ContentBlockDeltaEvent): - chunk_text = chunk.delta.text if chunk.delta.text else '' + chunk_text = chunk.delta.text if chunk.delta.text else "" full_assistant_content += chunk_text assistant_prompt_message = AssistantPromptMessage( - content=chunk_text if chunk_text else '', + content=chunk_text if chunk_text else "", ) index = chunk.index yield LLMResultChunk( @@ -237,12 +241,14 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) except Exception as ex: raise InvokeError(str(ex)) - def _calc_claude_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage: + def _calc_claude_response_usage( + self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int + ) -> LLMUsage: """ Calculate response usage @@ -262,10 +268,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): # get completion price info completion_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.OUTPUT, - tokens=completion_tokens + model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens ) # transform usage @@ -281,7 +284,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): total_tokens=prompt_tokens + completion_tokens, total_price=prompt_price_info.total_amount + completion_price_info.total_amount, currency=prompt_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -295,13 +298,13 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): first_loop = True for message in prompt_messages: if isinstance(message, SystemPromptMessage): - message.content=message.content.strip() + message.content = message.content.strip() if first_loop: - system=message.content - first_loop=False + system = message.content + first_loop = False else: - system+="\n" - system+=message.content + system += "\n" + system += message.content prompt_message_dicts = [] for message in prompt_messages: @@ -323,10 +326,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -336,7 +336,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): image_content = requests.get(message_content.data).content with Image.open(io.BytesIO(image_content)) as img: mime_type = f"image/{img.format.lower()}" - base64_data = base64.b64encode(image_content).decode('utf-8') + base64_data = base64.b64encode(image_content).decode("utf-8") except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") else: @@ -345,16 +345,14 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): base64_data = data_split[1] if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: - raise ValueError(f"Unsupported image type {mime_type}, " - f"only support image/jpeg, image/png, image/gif, and image/webp") + raise ValueError( + f"Unsupported image type {mime_type}, " + f"only support image/jpeg, image/png, image/gif, and image/webp" + ) sub_message_dict = { "type": "image", - "source": { - "type": "base64", - "media_type": mime_type, - "data": base64_data - } + "source": {"type": "base64", "media_type": mime_type, "data": base64_data}, } sub_messages.append(sub_message_dict) @@ -370,8 +368,13 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): return message_dict - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -384,7 +387,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): prompt = self._convert_messages_to_prompt(prompt_messages) return self._get_num_tokens_by_gpt2(prompt) - + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Google model @@ -394,13 +397,10 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() - + def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool: """ Convert tool messages to glm tools @@ -416,14 +416,16 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): type=glm.Type.OBJECT, properties={ key: { - 'type_': value.get('type', 'string').upper(), - 'description': value.get('description', ''), - 'enum': value.get('enum', []) - } for key, value in tool.parameters.get('properties', {}).items() + "type_": value.get("type", "string").upper(), + "description": value.get("description", ""), + "enum": value.get("enum", []), + } + for key, value in tool.parameters.get("properties", {}).items() }, - required=tool.parameters.get('required', []) + required=tool.parameters.get("required", []), ), - ) for tool in tools + ) + for tool in tools ] ) @@ -435,20 +437,25 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): :param credentials: model credentials :return: """ - + try: ping_message = SystemPromptMessage(content="ping") self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5}) - + except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None - ) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -462,7 +469,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ config_kwargs = model_parameters.copy() - config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None) + config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None) if stop: config_kwargs["stop_sequences"] = stop @@ -494,26 +501,21 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): else: history.append(content) - safety_settings={ + safety_settings = { HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, } - google_model = glm.GenerativeModel( - model_name=model, - system_instruction=system_instruction - ) + google_model = glm.GenerativeModel(model_name=model, system_instruction=system_instruction) response = google_model.generate_content( contents=history, - generation_config=glm.GenerationConfig( - **config_kwargs - ), + generation_config=glm.GenerationConfig(**config_kwargs), stream=stream, safety_settings=safety_settings, - tools=self._convert_tools_to_glm_tool(tools) if tools else None + tools=self._convert_tools_to_glm_tool(tools) if tools else None, ) if stream: @@ -521,8 +523,9 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: glm.GenerationResponse, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -533,9 +536,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): :return: llm response """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=response.candidates[0].content.parts[0].text - ) + assistant_prompt_message = AssistantPromptMessage(content=response.candidates[0].content.parts[0].text) # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -554,8 +555,9 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: glm.GenerationResponse, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -568,9 +570,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): index = -1 for chunk in response: for part in chunk.candidates[0].content.parts: - assistant_prompt_message = AssistantPromptMessage( - content='' - ) + assistant_prompt_message = AssistantPromptMessage(content="") if part.text: assistant_prompt_message.content += part.text @@ -579,35 +579,31 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): assistant_prompt_message.tool_calls = [ AssistantPromptMessage.ToolCall( id=part.function_call.name, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=part.function_call.name, - arguments=json.dumps(dict(part.function_call.args.items())) - ) + arguments=json.dumps(dict(part.function_call.args.items())), + ), ) ] index += 1 - - if not hasattr(chunk, 'finish_reason') or not chunk.finish_reason: + + if not hasattr(chunk, "finish_reason") or not chunk.finish_reason: # transform assistant message to prompt message yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) else: - # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -615,8 +611,8 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): index=index, message=assistant_prompt_message, finish_reason=chunk.candidates[0].finish_reason, - usage=usage - ) + usage=usage, + ), ) def _convert_one_message_to_text(self, message: PromptMessage) -> str: @@ -631,9 +627,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): content = message.content if isinstance(content, list): - content = "".join( - c.data for c in content if c.type != PromptMessageContentType.IMAGE - ) + content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE) if isinstance(message, UserPromptMessage): message_text = f"{human_prompt} {content}" @@ -658,7 +652,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): if isinstance(message, UserPromptMessage): glm_content = glm.Content(role="user", parts=[]) - if (isinstance(message.content, str)): + if isinstance(message.content, str): glm_content = glm.Content(role="user", parts=[glm.Part.from_text(message.content)]) else: parts = [] @@ -666,8 +660,8 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): if c.type == PromptMessageContentType.TEXT: parts.append(glm.Part.from_text(c.data)) else: - metadata, data = c.data.split(',', 1) - mime_type = metadata.split(';', 1)[0].split(':')[1] + metadata, data = c.data.split(",", 1) + mime_type = metadata.split(";", 1)[0].split(":")[1] parts.append(glm.Part.from_data(mime_type=mime_type, data=data)) glm_content = glm.Content(role="user", parts=parts) return glm_content @@ -675,22 +669,33 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): if message.content: glm_content = glm.Content(role="model", parts=[glm.Part.from_text(message.content)]) if message.tool_calls: - glm_content = glm.Content(role="model", parts=[glm.Part.from_function_response(glm.FunctionCall( - name=message.tool_calls[0].function.name, - args=json.loads(message.tool_calls[0].function.arguments), - ))]) + glm_content = glm.Content( + role="model", + parts=[ + glm.Part.from_function_response( + glm.FunctionCall( + name=message.tool_calls[0].function.name, + args=json.loads(message.tool_calls[0].function.arguments), + ) + ) + ], + ) return glm_content elif isinstance(message, ToolPromptMessage): - glm_content = glm.Content(role="function", parts=[glm.Part(function_response=glm.FunctionResponse( - name=message.name, - response={ - "response": message.content - } - ))]) + glm_content = glm.Content( + role="function", + parts=[ + glm.Part( + function_response=glm.FunctionResponse( + name=message.name, response={"response": message.content} + ) + ) + ], + ) return glm_content else: raise ValueError(f"Got unknown type {message}") - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -702,25 +707,20 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): :return: Invoke emd = gml.GenerativeModel(model) error mapping """ return { - InvokeConnectionError: [ - exceptions.RetryError - ], + InvokeConnectionError: [exceptions.RetryError], InvokeServerUnavailableError: [ exceptions.ServiceUnavailable, exceptions.InternalServerError, exceptions.BadGateway, exceptions.GatewayTimeout, - exceptions.DeadlineExceeded - ], - InvokeRateLimitError: [ - exceptions.ResourceExhausted, - exceptions.TooManyRequests + exceptions.DeadlineExceeded, ], + InvokeRateLimitError: [exceptions.ResourceExhausted, exceptions.TooManyRequests], InvokeAuthorizationError: [ exceptions.Unauthenticated, exceptions.PermissionDenied, exceptions.Unauthenticated, - exceptions.Forbidden + exceptions.Forbidden, ], InvokeBadRequestError: [ exceptions.BadRequest, @@ -736,5 +736,5 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): exceptions.PreconditionFailed, exceptions.RequestRangeNotSatisfiable, exceptions.Cancelled, - ] + ], } diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py index 2404ba5894..519373a7f3 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py @@ -29,9 +29,9 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): Model class for Vertex AI text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -51,23 +51,12 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): client = VertexTextEmbeddingModel.from_pretrained(model) - embeddings_batch, embedding_used_tokens = self._embedding_invoke( - client=client, - texts=texts - ) + embeddings_batch, embedding_used_tokens = self._embedding_invoke(client=client, texts=texts) # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=embedding_used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=embedding_used_tokens) - return TextEmbeddingResult( - embeddings=embeddings_batch, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings_batch, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -115,15 +104,11 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): client = VertexTextEmbeddingModel.from_pretrained(model) # call embedding model - self._embedding_invoke( - model=model, - client=client, - texts=['ping'] - ) + self._embedding_invoke(model=model, client=client, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _embedding_invoke(self, client: VertexTextEmbeddingModel, texts: list[str]) -> [list[float], int]: # type: ignore + def _embedding_invoke(self, client: VertexTextEmbeddingModel, texts: list[str]) -> [list[float], int]: # type: ignore """ Invoke embedding model @@ -154,10 +139,7 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -168,14 +150,14 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage - + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, @@ -183,15 +165,15 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity diff --git a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py index 3cbfb088d1..466a86fd36 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py +++ b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py @@ -20,12 +20,9 @@ class VertexAiProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `gemini-1.0-pro-002` model for validate, - model_instance.validate_credentials( - model='gemini-1.0-pro-002', - credentials=credentials - ) + model_instance.validate_credentials(model="gemini-1.0-pro-002", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/volcengine_maas/client.py b/api/core/model_runtime/model_providers/volcengine_maas/client.py index a4d89dabcb..d6f1356651 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/client.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/client.py @@ -69,31 +69,26 @@ class ArkClientV3: def from_credentials(cls, credentials): """Initialize the client using the credentials provided.""" args = { - "base_url": credentials['api_endpoint_host'], - "region": credentials['volc_region'], + "base_url": credentials["api_endpoint_host"], + "region": credentials["volc_region"], } if credentials.get("auth_method") == "api_key": args = { **args, - "api_key": credentials['volc_api_key'], + "api_key": credentials["volc_api_key"], } else: args = { **args, - "ak": credentials['volc_access_key_id'], - "sk": credentials['volc_secret_access_key'], + "ak": credentials["volc_access_key_id"], + "sk": credentials["volc_secret_access_key"], } if cls.is_compatible_with_legacy(credentials): - args = { - **args, - "base_url": DEFAULT_V3_ENDPOINT - } + args = {**args, "base_url": DEFAULT_V3_ENDPOINT} - client = ArkClientV3( - **args - ) - client.endpoint_id = credentials['endpoint_id'] + client = ArkClientV3(**args) + client.endpoint_id = credentials["endpoint_id"] return client @staticmethod @@ -107,54 +102,48 @@ class ArkClientV3: content = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - content.append(ChatCompletionContentPartTextParam( - text=message_content.text, - type='text', - )) + content.append( + ChatCompletionContentPartTextParam( + text=message_content.text, + type="text", + ) + ) elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast( - ImagePromptMessageContent, message_content) - image_data = re.sub( - r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) - content.append(ChatCompletionContentPartImageParam( - image_url=ImageURL( - url=image_data, - detail=message_content.detail.value, - ), - type='image_url', - )) - message_dict = ChatCompletionUserMessageParam( - role='user', - content=content - ) + message_content = cast(ImagePromptMessageContent, message_content) + image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data) + content.append( + ChatCompletionContentPartImageParam( + image_url=ImageURL( + url=image_data, + detail=message_content.detail.value, + ), + type="image_url", + ) + ) + message_dict = ChatCompletionUserMessageParam(role="user", content=content) elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) message_dict = ChatCompletionAssistantMessageParam( content=message.content, - role='assistant', - tool_calls=None if not message.tool_calls else [ + role="assistant", + tool_calls=None + if not message.tool_calls + else [ ChatCompletionMessageToolCallParam( id=call.id, - function=Function( - name=call.function.name, - arguments=call.function.arguments - ), - type='function' - ) for call in message.tool_calls - ] + function=Function(name=call.function.name, arguments=call.function.arguments), + type="function", + ) + for call in message.tool_calls + ], ) elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - message_dict = ChatCompletionSystemMessageParam( - content=message.content, - role='system' - ) + message_dict = ChatCompletionSystemMessageParam(content=message.content, role="system") elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = ChatCompletionToolMessageParam( - content=message.content, - role='tool', - tool_call_id=message.tool_call_id + content=message.content, role="tool", tool_call_id=message.tool_call_id ) else: raise ValueError(f"Got unknown PromptMessage type {message}") @@ -164,23 +153,25 @@ class ArkClientV3: @staticmethod def _convert_tool_prompt(message: PromptMessageTool) -> ChatCompletionToolParam: return ChatCompletionToolParam( - type='function', + type="function", function=FunctionDefinition( name=message.name, description=message.description, parameters=message.parameters, - ) + ), ) - def chat(self, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, - frequency_penalty: Optional[float] = None, - max_tokens: Optional[int] = None, - presence_penalty: Optional[float] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - ) -> ChatCompletion: + def chat( + self, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + frequency_penalty: Optional[float] = None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + ) -> ChatCompletion: """Block chat""" return self.ark.chat.completions.create( model=self.endpoint_id, @@ -194,15 +185,17 @@ class ArkClientV3: temperature=temperature, ) - def stream_chat(self, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, - frequency_penalty: Optional[float] = None, - max_tokens: Optional[int] = None, - presence_penalty: Optional[float] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - ) -> Generator[ChatCompletionChunk]: + def stream_chat( + self, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + frequency_penalty: Optional[float] = None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + ) -> Generator[ChatCompletionChunk]: """Stream chat""" chunks = self.ark.chat.completions.create( stream=True, diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py index 1978c11680..025b1ed6d2 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py @@ -25,12 +25,12 @@ class MaaSClient(MaasService): self.endpoint_id = endpoint_id @classmethod - def from_credential(cls, credentials: dict) -> 'MaaSClient': - host = credentials['api_endpoint_host'] - region = credentials['volc_region'] - ak = credentials['volc_access_key_id'] - sk = credentials['volc_secret_access_key'] - endpoint_id = credentials['endpoint_id'] + def from_credential(cls, credentials: dict) -> "MaaSClient": + host = credentials["api_endpoint_host"] + region = credentials["volc_region"] + ak = credentials["volc_access_key_id"] + sk = credentials["volc_secret_access_key"] + endpoint_id = credentials["endpoint_id"] client = cls(host, region) client.set_endpoint_id(endpoint_id) @@ -40,8 +40,8 @@ class MaaSClient(MaasService): def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict: req = { - 'parameters': params, - 'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages], + "parameters": params, + "messages": [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages], **extra_model_kwargs, } if not stream: @@ -55,9 +55,7 @@ class MaaSClient(MaasService): ) def embeddings(self, texts: list[str]) -> dict: - req = { - 'input': texts - } + req = {"input": texts} return super().embeddings(self.endpoint_id, req) @staticmethod @@ -65,49 +63,40 @@ class MaaSClient(MaasService): if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) if isinstance(message.content, str): - message_dict = {"role": ChatRole.USER, - "content": message.content} + message_dict = {"role": ChatRole.USER, "content": message.content} else: content = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - raise ValueError( - 'Content object type only support image_url') + raise ValueError("Content object type only support image_url") elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast( - ImagePromptMessageContent, message_content) - image_data = re.sub( - r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) - content.append({ - 'type': 'image_url', - 'image_url': { - 'url': '', - 'image_bytes': image_data, - 'detail': message_content.detail, + message_content = cast(ImagePromptMessageContent, message_content) + image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data) + content.append( + { + "type": "image_url", + "image_url": { + "url": "", + "image_bytes": image_data, + "detail": message_content.detail, + }, } - }) + ) - message_dict = {'role': ChatRole.USER, 'content': content} + message_dict = {"role": ChatRole.USER, "content": content} elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) - message_dict = {'role': ChatRole.ASSISTANT, - 'content': message.content} + message_dict = {"role": ChatRole.ASSISTANT, "content": message.content} if message.tool_calls: - message_dict['tool_calls'] = [ - { - 'name': call.function.name, - 'arguments': call.function.arguments - } for call in message.tool_calls + message_dict["tool_calls"] = [ + {"name": call.function.name, "arguments": call.function.arguments} for call in message.tool_calls ] elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - message_dict = {'role': ChatRole.SYSTEM, - 'content': message.content} + message_dict = {"role": ChatRole.SYSTEM, "content": message.content} elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) - message_dict = {'role': ChatRole.FUNCTION, - 'content': message.content, - 'name': message.tool_call_id} + message_dict = {"role": ChatRole.FUNCTION, "content": message.content, "name": message.tool_call_id} else: raise ValueError(f"Got unknown PromptMessage type {message}") @@ -130,5 +119,5 @@ class MaaSClient(MaasService): "name": tool.name, "description": tool.description, "parameters": tool.parameters, - } + }, } diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py index 21ffaf1258..8b9c346265 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py @@ -102,43 +102,43 @@ class ServiceNotOpen(MaasException): AuthErrors = { - 'SignatureDoesNotMatch': SignatureDoesNotMatch, - 'MissingAuthenticationHeader': MissingAuthenticationHeader, - 'AuthenticationHeaderIsInvalid': AuthenticationHeaderIsInvalid, - 'AuthenticationExpire': AuthenticationExpire, - 'UnauthorizedUserForEndpoint': UnauthorizedUserForEndpoint, + "SignatureDoesNotMatch": SignatureDoesNotMatch, + "MissingAuthenticationHeader": MissingAuthenticationHeader, + "AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalid, + "AuthenticationExpire": AuthenticationExpire, + "UnauthorizedUserForEndpoint": UnauthorizedUserForEndpoint, } BadRequestErrors = { - 'MissingParameter': MissingParameter, - 'InvalidParameter': InvalidParameter, - 'EndpointIsInvalid': EndpointIsInvalid, - 'EndpointIsNotEnable': EndpointIsNotEnable, - 'ModelNotSupportStreamMode': ModelNotSupportStreamMode, - 'ReqTextExistRisk': ReqTextExistRisk, - 'RespTextExistRisk': RespTextExistRisk, - 'InvalidEndpointWithNoURL': InvalidEndpointWithNoURL, - 'ServiceNotOpen': ServiceNotOpen, + "MissingParameter": MissingParameter, + "InvalidParameter": InvalidParameter, + "EndpointIsInvalid": EndpointIsInvalid, + "EndpointIsNotEnable": EndpointIsNotEnable, + "ModelNotSupportStreamMode": ModelNotSupportStreamMode, + "ReqTextExistRisk": ReqTextExistRisk, + "RespTextExistRisk": RespTextExistRisk, + "InvalidEndpointWithNoURL": InvalidEndpointWithNoURL, + "ServiceNotOpen": ServiceNotOpen, } RateLimitErrors = { - 'EndpointRateLimitExceeded': EndpointRateLimitExceeded, - 'EndpointAccountRpmRateLimitExceeded': EndpointAccountRpmRateLimitExceeded, - 'EndpointAccountTpmRateLimitExceeded': EndpointAccountTpmRateLimitExceeded, + "EndpointRateLimitExceeded": EndpointRateLimitExceeded, + "EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceeded, + "EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceeded, } ServerUnavailableErrors = { - 'InternalServiceError': InternalServiceError, - 'EndpointIsPending': EndpointIsPending, - 'ServiceResourceWaitQueueFull': ServiceResourceWaitQueueFull, + "InternalServiceError": InternalServiceError, + "EndpointIsPending": EndpointIsPending, + "ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFull, } ConnectionErrors = { - 'ClientSDKRequestError': ClientSDKRequestError, - 'RequestTimeout': RequestTimeout, - 'ServiceConnectionTimeout': ServiceConnectionTimeout, - 'ServiceConnectionRefused': ServiceConnectionRefused, - 'ServiceConnectionClosed': ServiceConnectionClosed, + "ClientSDKRequestError": ClientSDKRequestError, + "RequestTimeout": RequestTimeout, + "ServiceConnectionTimeout": ServiceConnectionTimeout, + "ServiceConnectionRefused": ServiceConnectionRefused, + "ServiceConnectionClosed": ServiceConnectionClosed, } ErrorCodeMap = { diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py index 64f342f16e..53f320736b 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py @@ -1,4 +1,4 @@ from .common import ChatRole from .maas import MaasException, MaasService -__all__ = ['MaasService', 'ChatRole', 'MaasException'] +__all__ = ["MaasService", "ChatRole", "MaasException"] diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py index 053432a089..8f8139426c 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py @@ -8,12 +8,12 @@ from .util import Util class MetaData: def __init__(self): - self.algorithm = '' - self.credential_scope = '' - self.signed_headers = '' - self.date = '' - self.region = '' - self.service = '' + self.algorithm = "" + self.credential_scope = "" + self.signed_headers = "" + self.date = "" + self.region = "" + self.service = "" def set_date(self, date): self.date = date @@ -36,23 +36,23 @@ class MetaData: class SignResult: def __init__(self): - self.xdate = '' - self.xCredential = '' - self.xAlgorithm = '' - self.xSignedHeaders = '' - self.xSignedQueries = '' - self.xSignature = '' - self.xContextSha256 = '' - self.xSecurityToken = '' + self.xdate = "" + self.xCredential = "" + self.xAlgorithm = "" + self.xSignedHeaders = "" + self.xSignedQueries = "" + self.xSignature = "" + self.xContextSha256 = "" + self.xSecurityToken = "" - self.authorization = '' + self.authorization = "" def __str__(self): - return '\n'.join(['{}:{}'.format(*item) for item in self.__dict__.items()]) + return "\n".join(["{}:{}".format(*item) for item in self.__dict__.items()]) class Credentials: - def __init__(self, ak, sk, service, region, session_token=''): + def __init__(self, ak, sk, service, region, session_token=""): self.ak = ak self.sk = sk self.service = service @@ -72,73 +72,88 @@ class Credentials: class Signer: @staticmethod def sign(request, credentials): - if request.path == '': - request.path = '/' - if request.method != 'GET' and not ('Content-Type' in request.headers): - request.headers['Content-Type'] = 'application/x-www-form-urlencoded; charset=utf-8' + if request.path == "": + request.path = "/" + if request.method != "GET" and not ("Content-Type" in request.headers): + request.headers["Content-Type"] = "application/x-www-form-urlencoded; charset=utf-8" format_date = Signer.get_current_format_date() - request.headers['X-Date'] = format_date - if credentials.session_token != '': - request.headers['X-Security-Token'] = credentials.session_token + request.headers["X-Date"] = format_date + if credentials.session_token != "": + request.headers["X-Security-Token"] = credentials.session_token md = MetaData() - md.set_algorithm('HMAC-SHA256') + md.set_algorithm("HMAC-SHA256") md.set_service(credentials.service) md.set_region(credentials.region) md.set_date(format_date[:8]) hashed_canon_req = Signer.hashed_canonical_request_v4(request, md) - md.set_credential_scope('/'.join([md.date, md.region, md.service, 'request'])) + md.set_credential_scope("/".join([md.date, md.region, md.service, "request"])) - signing_str = '\n'.join([md.algorithm, format_date, md.credential_scope, hashed_canon_req]) + signing_str = "\n".join([md.algorithm, format_date, md.credential_scope, hashed_canon_req]) signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service) sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str)) - request.headers['Authorization'] = Signer.build_auth_header_v4(sign, md, credentials) + request.headers["Authorization"] = Signer.build_auth_header_v4(sign, md, credentials) return @staticmethod def hashed_canonical_request_v4(request, meta): body_hash = Util.sha256(request.body) - request.headers['X-Content-Sha256'] = body_hash + request.headers["X-Content-Sha256"] = body_hash signed_headers = {} for key in request.headers: - if key in ['Content-Type', 'Content-Md5', 'Host'] or key.startswith('X-'): + if key in ["Content-Type", "Content-Md5", "Host"] or key.startswith("X-"): signed_headers[key.lower()] = request.headers[key] - if 'host' in signed_headers: - v = signed_headers['host'] - if v.find(':') != -1: - split = v.split(':') + if "host" in signed_headers: + v = signed_headers["host"] + if v.find(":") != -1: + split = v.split(":") port = split[1] - if str(port) == '80' or str(port) == '443': - signed_headers['host'] = split[0] + if str(port) == "80" or str(port) == "443": + signed_headers["host"] = split[0] - signed_str = '' + signed_str = "" for key in sorted(signed_headers.keys()): - signed_str += key + ':' + signed_headers[key] + '\n' + signed_str += key + ":" + signed_headers[key] + "\n" - meta.set_signed_headers(';'.join(sorted(signed_headers.keys()))) + meta.set_signed_headers(";".join(sorted(signed_headers.keys()))) - canonical_request = '\n'.join( - [request.method, Util.norm_uri(request.path), Util.norm_query(request.query), signed_str, - meta.signed_headers, body_hash]) + canonical_request = "\n".join( + [ + request.method, + Util.norm_uri(request.path), + Util.norm_query(request.query), + signed_str, + meta.signed_headers, + body_hash, + ] + ) return Util.sha256(canonical_request) @staticmethod def get_signing_secret_key_v4(sk, date, region, service): - date = Util.hmac_sha256(bytes(sk, encoding='utf-8'), date) + date = Util.hmac_sha256(bytes(sk, encoding="utf-8"), date) region = Util.hmac_sha256(date, region) service = Util.hmac_sha256(region, service) - return Util.hmac_sha256(service, 'request') + return Util.hmac_sha256(service, "request") @staticmethod def build_auth_header_v4(signature, meta, credentials): - credential = credentials.ak + '/' + meta.credential_scope - return meta.algorithm + ' Credential=' + credential + ', SignedHeaders=' + meta.signed_headers + ', Signature=' + signature + credential = credentials.ak + "/" + meta.credential_scope + return ( + meta.algorithm + + " Credential=" + + credential + + ", SignedHeaders=" + + meta.signed_headers + + ", Signature=" + + signature + ) @staticmethod def get_current_format_date(): - return datetime.datetime.now(tz=pytz.timezone('UTC')).strftime("%Y%m%dT%H%M%SZ") + return datetime.datetime.now(tz=pytz.timezone("UTC")).strftime("%Y%m%dT%H%M%SZ") diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py index 7271ae63fd..096339b3c7 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py @@ -6,7 +6,7 @@ import requests from .auth import Signer -VERSION = 'v1.0.137' +VERSION = "v1.0.137" class Service: @@ -40,8 +40,9 @@ class Service: Signer.sign(r, self.service_info.credentials) url = r.build(doseq) - resp = self.session.get(url, headers=r.headers, - timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout)) + resp = self.session.get( + url, headers=r.headers, timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout) + ) if resp.status_code == 200: return resp.text else: @@ -52,15 +53,19 @@ class Service: raise Exception("no such api") api_info = self.api_info[api] r = self.prepare_request(api_info, params) - r.headers['Content-Type'] = 'application/x-www-form-urlencoded' + r.headers["Content-Type"] = "application/x-www-form-urlencoded" r.form = self.merge(api_info.form, form) r.body = urlencode(r.form, True) Signer.sign(r, self.service_info.credentials) url = r.build() - resp = self.session.post(url, headers=r.headers, data=r.form, - timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout)) + resp = self.session.post( + url, + headers=r.headers, + data=r.form, + timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout), + ) if resp.status_code == 200: return resp.text else: @@ -71,21 +76,25 @@ class Service: raise Exception("no such api") api_info = self.api_info[api] r = self.prepare_request(api_info, params) - r.headers['Content-Type'] = 'application/json' + r.headers["Content-Type"] = "application/json" r.body = body Signer.sign(r, self.service_info.credentials) url = r.build() - resp = self.session.post(url, headers=r.headers, data=r.body, - timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout)) + resp = self.session.post( + url, + headers=r.headers, + data=r.body, + timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout), + ) if resp.status_code == 200: return json.dumps(resp.json()) else: raise Exception(resp.text.encode("utf-8")) def put(self, url, file_path, headers): - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: resp = self.session.put(url, headers=headers, data=f) if resp.status_code == 200: return True, resp.text.encode("utf-8") @@ -105,7 +114,7 @@ class Service: params[key] = str(params[key]) elif type(params[key]) == list: if not doseq: - params[key] = ','.join(params[key]) + params[key] = ",".join(params[key]) connection_timeout = self.service_info.connection_timeout socket_timeout = self.service_info.socket_timeout @@ -117,8 +126,8 @@ class Service: r.set_socket_timeout(socket_timeout) headers = self.merge(api_info.header, self.service_info.header) - headers['Host'] = self.service_info.host - headers['User-Agent'] = 'volc-sdk-python/' + VERSION + headers["Host"] = self.service_info.host + headers["User-Agent"] = "volc-sdk-python/" + VERSION r.set_headers(headers) query = self.merge(api_info.query, params) @@ -143,13 +152,13 @@ class Service: class Request: def __init__(self): - self.schema = '' - self.method = '' - self.host = '' - self.path = '' + self.schema = "" + self.method = "" + self.host = "" + self.path = "" self.headers = OrderedDict() self.query = OrderedDict() - self.body = '' + self.body = "" self.form = {} self.connection_timeout = 0 self.socket_timeout = 0 @@ -182,11 +191,11 @@ class Request: self.socket_timeout = socket_timeout def build(self, doseq=0): - return self.schema + '://' + self.host + self.path + '?' + urlencode(self.query, doseq) + return self.schema + "://" + self.host + self.path + "?" + urlencode(self.query, doseq) class ServiceInfo: - def __init__(self, host, header, credentials, connection_timeout, socket_timeout, scheme='http'): + def __init__(self, host, header, credentials, connection_timeout, socket_timeout, scheme="http"): self.host = host self.header = header self.credentials = credentials @@ -204,4 +213,4 @@ class ApiInfo: self.header = header def __str__(self): - return 'method: ' + self.method + ', path: ' + self.path + return "method: " + self.method + ", path: " + self.path diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py index 7eb5fdfa91..44f9959965 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py @@ -7,28 +7,28 @@ from urllib.parse import quote class Util: @staticmethod def norm_uri(path): - return quote(path).replace('%2F', '/').replace('+', '%20') + return quote(path).replace("%2F", "/").replace("+", "%20") @staticmethod def norm_query(params): - query = '' + query = "" for key in sorted(params.keys()): if type(params[key]) == list: for k in params[key]: - query = query + quote(key, safe='-_.~') + '=' + quote(k, safe='-_.~') + '&' + query = query + quote(key, safe="-_.~") + "=" + quote(k, safe="-_.~") + "&" else: - query = query + quote(key, safe='-_.~') + '=' + quote(params[key], safe='-_.~') + '&' + query = query + quote(key, safe="-_.~") + "=" + quote(params[key], safe="-_.~") + "&" query = query[:-1] - return query.replace('+', '%20') + return query.replace("+", "%20") @staticmethod def hmac_sha256(key, content): - return hmac.new(key, bytes(content, encoding='utf-8'), hashlib.sha256).digest() + return hmac.new(key, bytes(content, encoding="utf-8"), hashlib.sha256).digest() @staticmethod def sha256(content): if isinstance(content, str) is True: - return hashlib.sha256(content.encode('utf-8')).hexdigest() + return hashlib.sha256(content.encode("utf-8")).hexdigest() else: return hashlib.sha256(content).hexdigest() @@ -36,8 +36,8 @@ class Util: def to_hex(content): lst = [] for ch in content: - hv = hex(ch).replace('0x', '') + hv = hex(ch).replace("0x", "") if len(hv) == 1: - hv = '0' + hv + hv = "0" + hv lst.append(hv) return reduce(lambda x, y: x + y, lst) diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py index 8b14d026d9..3825fd6574 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py @@ -43,9 +43,7 @@ def json_to_object(json_str, req_id=None): def gen_req_id(): - return datetime.now().strftime("%Y%m%d%H%M%S") + format( - random.randint(0, 2 ** 64 - 1), "020X" - ) + return datetime.now().strftime("%Y%m%d%H%M%S") + format(random.randint(0, 2**64 - 1), "020X") class SSEDecoder: @@ -53,13 +51,13 @@ class SSEDecoder: self.source = source def _read(self): - data = b'' + data = b"" for chunk in self.source: for line in chunk.splitlines(True): data += line - if data.endswith((b'\r\r', b'\n\n', b'\r\n\r\n')): + if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): yield data - data = b'' + data = b"" if data: yield data @@ -67,13 +65,13 @@ class SSEDecoder: for chunk in self._read(): for line in chunk.splitlines(): # skip comment - if line.startswith(b':'): + if line.startswith(b":"): continue - if b':' in line: - field, value = line.split(b':', 1) + if b":" in line: + field, value = line.split(b":", 1) else: - field, value = line, b'' + field, value = line, b"" - if field == b'data' and len(value) > 0: + if field == b"data" and len(value) > 0: yield value diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py index 3cbe9d9f09..01f15aec24 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py @@ -9,9 +9,7 @@ from .common import SSEDecoder, dict_to_object, gen_req_id, json_to_object class MaasService(Service): def __init__(self, host, region, connection_timeout=60, socket_timeout=60): - service_info = self.get_service_info( - host, region, connection_timeout, socket_timeout - ) + service_info = self.get_service_info(host, region, connection_timeout, socket_timeout) self._apikey = None api_info = self.get_api_info() super().__init__(service_info, api_info) @@ -35,9 +33,7 @@ class MaasService(Service): def get_api_info(): api_info = { "chat": ApiInfo("POST", "/api/v2/endpoint/{endpoint_id}/chat", {}, {}, {}), - "embeddings": ApiInfo( - "POST", "/api/v2/endpoint/{endpoint_id}/embeddings", {}, {}, {} - ), + "embeddings": ApiInfo("POST", "/api/v2/endpoint/{endpoint_id}/embeddings", {}, {}, {}), } return api_info @@ -52,9 +48,7 @@ class MaasService(Service): try: req["stream"] = True - res = self._call( - endpoint_id, "chat", req_id, {}, json.dumps(req).encode("utf-8"), apikey, stream=True - ) + res = self._call(endpoint_id, "chat", req_id, {}, json.dumps(req).encode("utf-8"), apikey, stream=True) decoder = SSEDecoder(res) @@ -64,8 +58,7 @@ class MaasService(Service): return try: - res = json_to_object( - str(data, encoding="utf-8"), req_id=req_id) + res = json_to_object(str(data, encoding="utf-8"), req_id=req_id) except Exception: raise @@ -95,8 +88,7 @@ class MaasService(Service): apikey = self._apikey try: - res = self._call(endpoint_id, api, req_id, params, - json.dumps(req).encode("utf-8"), apikey) + res = self._call(endpoint_id, api, req_id, params, json.dumps(req).encode("utf-8"), apikey) resp = dict_to_object(res.json()) if resp and isinstance(resp, dict): resp["req_id"] = req_id @@ -109,9 +101,9 @@ class MaasService(Service): def _validate(self, api, req_id): credentials_exist = ( - self.service_info.credentials is not None and - self.service_info.credentials.sk is not None and - self.service_info.credentials.ak is not None + self.service_info.credentials is not None + and self.service_info.credentials.sk is not None + and self.service_info.credentials.ak is not None ) if not self._apikey and not credentials_exist: @@ -150,15 +142,12 @@ class MaasService(Service): raw = res.text.encode() res.close() try: - resp = json_to_object( - str(raw, encoding="utf-8"), req_id=req_id) + resp = json_to_object(str(raw, encoding="utf-8"), req_id=req_id) except Exception: raise new_client_sdk_request_error(raw, req_id) if resp.error: - raise MaasException( - resp.error.code_n, resp.error.code, resp.error.message, req_id - ) + raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, req_id) else: raise new_client_sdk_request_error(resp, req_id) @@ -173,11 +162,13 @@ class MaasException(Exception): self.req_id = req_id def __str__(self): - return ("Detailed exception information is listed below.\n" + - "req_id: {}\n" + - "code_n: {}\n" + - "code: {}\n" + - "message: {}").format(self.req_id, self.code_n, self.code, self.message) + return ( + "Detailed exception information is listed below.\n" + + "req_id: {}\n" + + "code_n: {}\n" + + "code: {}\n" + + "message: {}" + ).format(self.req_id, self.code_n, self.code, self.message) def new_client_sdk_request_error(raw, req_id=""): @@ -189,25 +180,19 @@ class BinaryResponseContent: self.response = response self.request_id = request_id - def stream_to_file( - self, - file: str - ) -> None: + def stream_to_file(self, file: str) -> None: is_first = True - error_bytes = b'' + error_bytes = b"" with open(file, mode="wb") as f: for data in self.response: - if len(error_bytes) > 0 or (is_first and "\"error\":" in str(data)): + if len(error_bytes) > 0 or (is_first and '"error":' in str(data)): error_bytes += data else: f.write(data) if len(error_bytes) > 0: - resp = json_to_object( - str(error_bytes, encoding="utf-8"), req_id=self.request_id) - raise MaasException( - resp.error.code_n, resp.error.code, resp.error.message, self.request_id - ) + resp = json_to_object(str(error_bytes, encoding="utf-8"), req_id=self.request_id) + raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, self.request_id) def iter_bytes(self) -> Iterator[bytes]: yield from self.response diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py index 996c66e604..98409ab872 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py @@ -49,10 +49,17 @@ logger = logging.getLogger(__name__) class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: if ArkClientV3.is_legacy(credentials): return self._generate_v2(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) return self._generate_v3(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -71,12 +78,12 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): try: client.chat( { - 'max_new_tokens': 16, - 'temperature': 0.7, - 'top_p': 0.9, - 'top_k': 15, + "max_new_tokens": 16, + "temperature": 0.7, + "top_p": 0.9, + "top_k": 15, }, - [UserPromptMessage(content='ping\nAnswer: ')], + [UserPromptMessage(content="ping\nAnswer: ")], ) except MaasException as e: raise CredentialsValidateFailedError(e.message) @@ -85,13 +92,22 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): def _validate_credentials_v3(credentials: dict) -> None: client = ArkClientV3.from_credentials(credentials) try: - client.chat(max_tokens=16, temperature=0.7, top_p=0.9, - messages=[UserPromptMessage(content='ping\nAnswer: ')], ) + client.chat( + max_tokens=16, + temperature=0.7, + top_p=0.9, + messages=[UserPromptMessage(content="ping\nAnswer: ")], + ) except Exception as e: raise CredentialsValidateFailedError(e) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: if ArkClientV3.is_legacy(credentials): return self._get_num_tokens_v2(prompt_messages) return self._get_num_tokens_v3(prompt_messages) @@ -100,8 +116,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): if len(messages) == 0: return 0 num_tokens = 0 - messages_dict = [ - MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages] + messages_dict = [MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages] for message in messages_dict: for key, value in message.items(): num_tokens += self._get_num_tokens_by_gpt2(str(key)) @@ -113,8 +128,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): if len(messages) == 0: return 0 num_tokens = 0 - messages_dict = [ - ArkClientV3.convert_prompt_message(m) for m in messages] + messages_dict = [ArkClientV3.convert_prompt_message(m) for m in messages] for message in messages_dict: for key, value in message.items(): num_tokens += self._get_num_tokens_by_gpt2(str(key)) @@ -122,97 +136,108 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): return num_tokens - def _generate_v2(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - + def _generate_v2( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: client = MaaSClient.from_credential(credentials) req_params = get_v2_req_params(credentials, model_parameters, stop) extra_model_kwargs = {} if tools: - extra_model_kwargs['tools'] = [ - MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools - ] - resp = MaaSClient.wrap_exception( - lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs)) + extra_model_kwargs["tools"] = [MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools] + resp = MaaSClient.wrap_exception(lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs)) def _handle_stream_chat_response() -> Generator: for index, r in enumerate(resp): - choices = r['choices'] + choices = r["choices"] if not choices: continue choice = choices[0] - message = choice['message'] + message = choice["message"] usage = None - if r.get('usage'): - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=r['usage']['prompt_tokens'], - completion_tokens=r['usage']['completion_tokens'] - ) + if r.get("usage"): + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=r["usage"]["prompt_tokens"], + completion_tokens=r["usage"]["completion_tokens"], + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, message=AssistantPromptMessage( - content=message['content'] if message['content'] else '', - tool_calls=[] + content=message["content"] if message["content"] else "", tool_calls=[] ), usage=usage, - finish_reason=choice.get('finish_reason'), + finish_reason=choice.get("finish_reason"), ), ) def _handle_chat_response() -> LLMResult: - choices = resp['choices'] + choices = resp["choices"] if not choices: raise ValueError("No choices found") choice = choices[0] - message = choice['message'] + message = choice["message"] # parse tool calls tool_calls = [] - if message['tool_calls']: - for call in message['tool_calls']: + if message["tool_calls"]: + for call in message["tool_calls"]: tool_call = AssistantPromptMessage.ToolCall( - id=call['function']['name'], - type=call['type'], + id=call["function"]["name"], + type=call["type"], function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=call['function']['name'], - arguments=call['function']['arguments'] - ) + name=call["function"]["name"], arguments=call["function"]["arguments"] + ), ) tool_calls.append(tool_call) - usage = resp['usage'] + usage = resp["usage"] return LLMResult( model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage( - content=message['content'] if message['content'] else '', + content=message["content"] if message["content"] else "", tool_calls=tool_calls, ), - usage=self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=usage['prompt_tokens'], - completion_tokens=usage['completion_tokens'] - ), + usage=self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=usage["prompt_tokens"], + completion_tokens=usage["completion_tokens"], + ), ) if not stream: return _handle_chat_response() return _handle_stream_chat_response() - def _generate_v3(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - + def _generate_v3( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: client = ArkClientV3.from_credentials(credentials) req_params = get_v3_req_params(credentials, model_parameters, stop) if tools: - req_params['tools'] = tools + req_params["tools"] = tools def _handle_stream_chat_response(chunks: Generator[ChatCompletionChunk]) -> Generator: for chunk in chunks: @@ -225,14 +250,15 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=choice.index, - message=AssistantPromptMessage( - content=choice.delta.content, - tool_calls=[] - ), - usage=self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=chunk.usage.prompt_tokens, - completion_tokens=chunk.usage.completion_tokens - ) if chunk.usage else None, + message=AssistantPromptMessage(content=choice.delta.content, tool_calls=[]), + usage=self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=chunk.usage.prompt_tokens, + completion_tokens=chunk.usage.completion_tokens, + ) + if chunk.usage + else None, finish_reason=choice.finish_reason, ), ) @@ -248,9 +274,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): id=call.id, type=call.type, function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=call.function.name, - arguments=call.function.arguments - ) + name=call.function.name, arguments=call.function.arguments + ), ) tool_calls.append(tool_call) @@ -262,10 +287,12 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): content=message.content if message.content else "", tool_calls=tool_calls, ), - usage=self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=usage.prompt_tokens, - completion_tokens=usage.completion_tokens - ), + usage=self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=usage.prompt_tokens, + completion_tokens=usage.completion_tokens, + ), ) if not stream: @@ -277,72 +304,56 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ model_config = get_model_config(credentials) rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ) + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='top_k', - type=ParameterType.INT, - min=1, - default=1, - label=I18nObject( - zh_Hans='Top K', - en_US='Top K' - ) + name="top_k", type=ParameterType.INT, min=1, default=1, label=I18nObject(zh_Hans="Top K", en_US="Top K") ), ParameterRule( - name='presence_penalty', + name="presence_penalty", type=ParameterType.FLOAT, - use_template='presence_penalty', + use_template="presence_penalty", label=I18nObject( - en_US='Presence Penalty', - zh_Hans='存在惩罚', + en_US="Presence Penalty", + zh_Hans="存在惩罚", ), min=-2.0, max=2.0, ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", type=ParameterType.FLOAT, - use_template='frequency_penalty', + use_template="frequency_penalty", label=I18nObject( - en_US='Frequency Penalty', - zh_Hans='频率惩罚', + en_US="Frequency Penalty", + zh_Hans="频率惩罚", ), min=-2.0, max=2.0, ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, max=model_config.properties.max_tokens, default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), ), ] @@ -352,9 +363,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, model_properties=model_properties, diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py index a882f68a36..d8be14b024 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py @@ -16,138 +16,127 @@ class ModelConfig(BaseModel): configs: dict[str, ModelConfig] = { - 'Doubao-pro-4k': ModelConfig( + "Doubao-pro-4k": ModelConfig( properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Doubao-lite-4k': ModelConfig( + "Doubao-lite-4k": ModelConfig( properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Doubao-pro-32k': ModelConfig( + "Doubao-pro-32k": ModelConfig( properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Doubao-lite-32k': ModelConfig( + "Doubao-lite-32k": ModelConfig( properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Doubao-pro-128k': ModelConfig( + "Doubao-pro-128k": ModelConfig( properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Doubao-lite-128k': ModelConfig( - properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), - features=[] + "Doubao-lite-128k": ModelConfig( + properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), features=[] ), - 'Skylark2-pro-4k': ModelConfig( - properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), - features=[] + "Skylark2-pro-4k": ModelConfig( + properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), features=[] ), - 'Llama3-8B': ModelConfig( - properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), - features=[] + "Llama3-8B": ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), features=[] ), - 'Llama3-70B': ModelConfig( - properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), - features=[] + "Llama3-70B": ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), features=[] ), - 'Moonshot-v1-8k': ModelConfig( + "Moonshot-v1-8k": ModelConfig( properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Moonshot-v1-32k': ModelConfig( + "Moonshot-v1-32k": ModelConfig( properties=ModelProperties(context_size=32768, max_tokens=16384, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'Moonshot-v1-128k': ModelConfig( + "Moonshot-v1-128k": ModelConfig( properties=ModelProperties(context_size=131072, max_tokens=65536, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'GLM3-130B': ModelConfig( + "GLM3-130B": ModelConfig( properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], ), - 'GLM3-130B-Fin': ModelConfig( + "GLM3-130B-Fin": ModelConfig( properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), - features=[ModelFeature.TOOL_CALL] + features=[ModelFeature.TOOL_CALL], + ), + "Mistral-7B": ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT), features=[] ), - 'Mistral-7B': ModelConfig( - properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT), - features=[] - ) } def get_model_config(credentials: dict) -> ModelConfig: - base_model = credentials.get('base_model_name', '') + base_model = credentials.get("base_model_name", "") model_configs = configs.get(base_model) if not model_configs: return ModelConfig( properties=ModelProperties( - context_size=int(credentials.get('context_size', 0)), - max_tokens=int(credentials.get('max_tokens', 0)), - mode=LLMMode.value_of(credentials.get('mode', 'chat')), + context_size=int(credentials.get("context_size", 0)), + max_tokens=int(credentials.get("max_tokens", 0)), + mode=LLMMode.value_of(credentials.get("mode", "chat")), ), - features=[] + features=[], ) return model_configs -def get_v2_req_params(credentials: dict, model_parameters: dict, - stop: list[str] | None = None): +def get_v2_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None): req_params = {} # predefined properties model_configs = get_model_config(credentials) if model_configs: - req_params['max_prompt_tokens'] = model_configs.properties.context_size - req_params['max_new_tokens'] = model_configs.properties.max_tokens + req_params["max_prompt_tokens"] = model_configs.properties.context_size + req_params["max_new_tokens"] = model_configs.properties.max_tokens # model parameters - if model_parameters.get('max_tokens'): - req_params['max_new_tokens'] = model_parameters.get('max_tokens') - if model_parameters.get('temperature'): - req_params['temperature'] = model_parameters.get('temperature') - if model_parameters.get('top_p'): - req_params['top_p'] = model_parameters.get('top_p') - if model_parameters.get('top_k'): - req_params['top_k'] = model_parameters.get('top_k') - if model_parameters.get('presence_penalty'): - req_params['presence_penalty'] = model_parameters.get( - 'presence_penalty') - if model_parameters.get('frequency_penalty'): - req_params['frequency_penalty'] = model_parameters.get( - 'frequency_penalty') + if model_parameters.get("max_tokens"): + req_params["max_new_tokens"] = model_parameters.get("max_tokens") + if model_parameters.get("temperature"): + req_params["temperature"] = model_parameters.get("temperature") + if model_parameters.get("top_p"): + req_params["top_p"] = model_parameters.get("top_p") + if model_parameters.get("top_k"): + req_params["top_k"] = model_parameters.get("top_k") + if model_parameters.get("presence_penalty"): + req_params["presence_penalty"] = model_parameters.get("presence_penalty") + if model_parameters.get("frequency_penalty"): + req_params["frequency_penalty"] = model_parameters.get("frequency_penalty") if stop: - req_params['stop'] = stop + req_params["stop"] = stop return req_params -def get_v3_req_params(credentials: dict, model_parameters: dict, - stop: list[str] | None = None): +def get_v3_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None): req_params = {} # predefined properties model_configs = get_model_config(credentials) if model_configs: - req_params['max_tokens'] = model_configs.properties.max_tokens + req_params["max_tokens"] = model_configs.properties.max_tokens # model parameters - if model_parameters.get('max_tokens'): - req_params['max_tokens'] = model_parameters.get('max_tokens') - if model_parameters.get('temperature'): - req_params['temperature'] = model_parameters.get('temperature') - if model_parameters.get('top_p'): - req_params['top_p'] = model_parameters.get('top_p') - if model_parameters.get('presence_penalty'): - req_params['presence_penalty'] = model_parameters.get( - 'presence_penalty') - if model_parameters.get('frequency_penalty'): - req_params['frequency_penalty'] = model_parameters.get( - 'frequency_penalty') + if model_parameters.get("max_tokens"): + req_params["max_tokens"] = model_parameters.get("max_tokens") + if model_parameters.get("temperature"): + req_params["temperature"] = model_parameters.get("temperature") + if model_parameters.get("top_p"): + req_params["top_p"] = model_parameters.get("top_p") + if model_parameters.get("presence_penalty"): + req_params["presence_penalty"] = model_parameters.get("presence_penalty") + if model_parameters.get("frequency_penalty"): + req_params["frequency_penalty"] = model_parameters.get("frequency_penalty") if stop: - req_params['stop'] = stop + req_params["stop"] = stop return req_params diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py index 74cf26247c..ce4f0c3ab1 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py @@ -11,20 +11,18 @@ class ModelConfig(BaseModel): ModelConfigs = { - 'Doubao-embedding': ModelConfig( - properties=ModelProperties(context_size=4096, max_chunks=32) - ), + "Doubao-embedding": ModelConfig(properties=ModelProperties(context_size=4096, max_chunks=32)), } def get_model_config(credentials: dict) -> ModelConfig: - base_model = credentials.get('base_model_name', '') + base_model = credentials.get("base_model_name", "") model_configs = ModelConfigs.get(base_model) if not model_configs: return ModelConfig( properties=ModelProperties( - context_size=int(credentials.get('context_size', 0)), - max_chunks=int(credentials.get('max_chunks', 0)), + context_size=int(credentials.get("context_size", 0)), + max_chunks=int(credentials.get("max_chunks", 0)), ) ) return model_configs diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py index d54aeeb0b1..3cdcd2740c 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py @@ -40,9 +40,9 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): Model class for VolcengineMaaS text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -57,37 +57,27 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): return self._generate_v3(model, credentials, texts, user) - def _generate_v2(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _generate_v2( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: client = MaaSClient.from_credential(credentials) resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts)) - usage = self._calc_response_usage( - model=model, credentials=credentials, tokens=resp['usage']['total_tokens']) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=resp["usage"]["total_tokens"]) - result = TextEmbeddingResult( - model=model, - embeddings=[v['embedding'] for v in resp['data']], - usage=usage - ) + result = TextEmbeddingResult(model=model, embeddings=[v["embedding"] for v in resp["data"]], usage=usage) return result - def _generate_v3(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _generate_v3( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: client = ArkClientV3.from_credentials(credentials) resp = client.embeddings(texts) - usage = self._calc_response_usage( - model=model, credentials=credentials, tokens=resp.usage.total_tokens) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=resp.usage.total_tokens) - result = TextEmbeddingResult( - model=model, - embeddings=[v.embedding for v in resp.data], - usage=usage - ) + result = TextEmbeddingResult(model=model, embeddings=[v.embedding for v in resp.data], usage=usage) return result @@ -120,13 +110,13 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): def _validate_credentials_v2(self, model: str, credentials: dict) -> None: try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except MaasException as e: raise CredentialsValidateFailedError(e.message) def _validate_credentials_v3(self, model: str, credentials: dict) -> None: try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as e: raise CredentialsValidateFailedError(e) @@ -150,12 +140,12 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ model_config = get_model_config(credentials) model_properties = { ModelPropertyKey.CONTEXT_SIZE: model_config.properties.context_size, - ModelPropertyKey.MAX_CHUNKS: model_config.properties.max_chunks + ModelPropertyKey.MAX_CHUNKS: model_config.properties.max_chunks, } entity = AIModelEntity( model=model, @@ -165,10 +155,10 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): model_properties=model_properties, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity @@ -184,10 +174,7 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -198,7 +185,7 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/wenxin/_common.py b/api/core/model_runtime/model_providers/wenxin/_common.py index 017856bdde..d72d1bd83a 100644 --- a/api/core/model_runtime/model_providers/wenxin/_common.py +++ b/api/core/model_runtime/model_providers/wenxin/_common.py @@ -11,7 +11,7 @@ from core.model_runtime.model_providers.wenxin.wenxin_errors import ( RateLimitReachedError, ) -baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {} +baidu_access_tokens: dict[str, "BaiduAccessToken"] = {} baidu_access_tokens_lock = Lock() @@ -22,49 +22,46 @@ class BaiduAccessToken: def __init__(self, api_key: str) -> None: self.api_key = api_key - self.access_token = '' + self.access_token = "" self.expires = datetime.now() + timedelta(days=3) @staticmethod def _get_access_token(api_key: str, secret_key: str) -> str: """ - request access token from Baidu + request access token from Baidu """ try: response = post( - url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}', - headers={ - 'Content-Type': 'application/json', - 'Accept': 'application/json' - }, + url=f"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}", + headers={"Content-Type": "application/json", "Accept": "application/json"}, ) except Exception as e: - raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}') + raise InvalidAuthenticationError(f"Failed to get access token from Baidu: {e}") resp = response.json() - if 'error' in resp: - if resp['error'] == 'invalid_client': + if "error" in resp: + if resp["error"] == "invalid_client": raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}') - elif resp['error'] == 'unknown_error': + elif resp["error"] == "unknown_error": raise InternalServerError(f'Internal server error: {resp["error_description"]}') - elif resp['error'] == 'invalid_request': + elif resp["error"] == "invalid_request": raise BadRequestError(f'Bad request: {resp["error_description"]}') - elif resp['error'] == 'rate_limit_exceeded': + elif resp["error"] == "rate_limit_exceeded": raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}') else: raise Exception(f'Unknown error: {resp["error_description"]}') - return resp['access_token'] + return resp["access_token"] @staticmethod - def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken': + def get_access_token(api_key: str, secret_key: str) -> "BaiduAccessToken": """ - LLM from Baidu requires access token to invoke the API. - however, we have api_key and secret_key, and access token is valid for 30 days. - so we can cache the access token for 3 days. (avoid memory leak) + LLM from Baidu requires access token to invoke the API. + however, we have api_key and secret_key, and access token is valid for 30 days. + so we can cache the access token for 3 days. (avoid memory leak) - it may be more efficient to use a ticker to refresh access token, but it will cause - more complexity, so we just refresh access tokens when get_access_token is called. + it may be more efficient to use a ticker to refresh access token, but it will cause + more complexity, so we just refresh access tokens when get_access_token is called. """ # loop up cache, remove expired access token @@ -98,49 +95,49 @@ class BaiduAccessToken: class _CommonWenxin: api_bases = { - 'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', - 'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', - 'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', - 'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', - 'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205', - 'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222', - 'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', - 'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k', - 'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed', - 'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k', - 'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas', - 'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', - 'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k', - 'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', - 'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', - 'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k', - 'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview', - 'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat', - 'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1', - 'bge-large-en': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en', - 'bge-large-zh': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh', - 'tao-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k', + "ernie-bot": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205", + "ernie-bot-4": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro", + "ernie-bot-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions", + "ernie-bot-turbo": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant", + "ernie-3.5-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions", + "ernie-3.5-8k-0205": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205", + "ernie-3.5-8k-1222": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222", + "ernie-3.5-4k-0205": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205", + "ernie-3.5-128k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k", + "ernie-4.0-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro", + "ernie-4.0-8k-latest": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro", + "ernie-speed-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed", + "ernie-speed-128k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k", + "ernie-speed-appbuilder": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas", + "ernie-lite-8k-0922": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant", + "ernie-lite-8k-0308": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k", + "ernie-character-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k", + "ernie-character-8k-0321": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k", + "ernie-4.0-turbo-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k", + "ernie-4.0-turbo-8k-preview": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview", + "yi_34b_chat": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat", + "embedding-v1": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1", + "bge-large-en": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en", + "bge-large-zh": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh", + "tao-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k", } function_calling_supports = [ - 'ernie-bot', - 'ernie-bot-8k', - 'ernie-3.5-8k', - 'ernie-3.5-8k-0205', - 'ernie-3.5-8k-1222', - 'ernie-3.5-4k-0205', - 'ernie-3.5-128k', - 'ernie-4.0-8k', - 'ernie-4.0-turbo-8k', - 'ernie-4.0-turbo-8k-preview', - 'yi_34b_chat' + "ernie-bot", + "ernie-bot-8k", + "ernie-3.5-8k", + "ernie-3.5-8k-0205", + "ernie-3.5-8k-1222", + "ernie-3.5-4k-0205", + "ernie-3.5-128k", + "ernie-4.0-8k", + "ernie-4.0-turbo-8k", + "ernie-4.0-turbo-8k-preview", + "yi_34b_chat", ] - api_key: str = '' - secret_key: str = '' + api_key: str = "" + secret_key: str = "" def __init__(self, api_key: str, secret_key: str): self.api_key = api_key @@ -148,10 +145,7 @@ class _CommonWenxin: @staticmethod def _to_credential_kwargs(credentials: dict) -> dict: - credentials_kwargs = { - "api_key": credentials['api_key'], - "secret_key": credentials['secret_key'] - } + credentials_kwargs = {"api_key": credentials["api_key"], "secret_key": credentials["secret_key"]} return credentials_kwargs def _handle_error(self, code: int, msg: str): @@ -187,13 +181,13 @@ class _CommonWenxin: 336105: BadRequestError, 336200: InternalServerError, 336303: BadRequestError, - 337006: BadRequestError + 337006: BadRequestError, } if code in error_map: raise error_map[code](msg) else: - raise InternalServerError(f'Unknown error: {msg}') + raise InternalServerError(f"Unknown error: {msg}") def _get_access_token(self) -> str: token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key) diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py index 8109949b1d..07b970f810 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py @@ -15,33 +15,39 @@ from core.model_runtime.model_providers.wenxin.wenxin_errors import ( class ErnieMessage: class Role(Enum): - USER = 'user' - ASSISTANT = 'assistant' - FUNCTION = 'function' - SYSTEM = 'system' + USER = "user" + ASSISTANT = "assistant" + FUNCTION = "function" + SYSTEM = "system" role: str = Role.USER.value content: str usage: dict[str, int] = None - stop_reason: str = '' + stop_reason: str = "" def to_dict(self) -> dict[str, Any]: return { - 'role': self.role, - 'content': self.content, + "role": self.role, + "content": self.content, } - def __init__(self, content: str, role: str = 'user') -> None: + def __init__(self, content: str, role: str = "user") -> None: self.content = content self.role = role + class ErnieBotModel(_CommonWenxin): - - def generate(self, model: str, stream: bool, messages: list[ErnieMessage], - parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \ - stop: list[str], user: str) \ - -> Union[Generator[ErnieMessage, None, None], ErnieMessage]: - + def generate( + self, + model: str, + stream: bool, + messages: list[ErnieMessage], + parameters: dict[str, Any], + timeout: int, + tools: list[PromptMessageTool], + stop: list[str], + user: str, + ) -> Union[Generator[ErnieMessage, None, None], ErnieMessage]: # check parameters self._check_parameters(model, parameters, tools, stop) @@ -49,22 +55,23 @@ class ErnieBotModel(_CommonWenxin): access_token = self._get_access_token() # generate request body - url = f'{self.api_bases[model]}?access_token={access_token}' + url = f"{self.api_bases[model]}?access_token={access_token}" # clone messages messages_cloned = self._copy_messages(messages=messages) # build body - body = self._build_request_body(model, messages=messages_cloned, stream=stream, - parameters=parameters, tools=tools, stop=stop, user=user) + body = self._build_request_body( + model, messages=messages_cloned, stream=stream, parameters=parameters, tools=tools, stop=stop, user=user + ) headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } resp = post(url=url, data=dumps(body), headers=headers, stream=stream) if resp.status_code != 200: - raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}') + raise InternalServerError(f"Failed to invoke ernie bot: {resp.text}") if stream: return self._handle_chat_stream_generate_response(resp) @@ -73,10 +80,11 @@ class ErnieBotModel(_CommonWenxin): def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]: return [ErnieMessage(message.content, message.role) for message in messages] - def _check_parameters(self, model: str, parameters: dict[str, Any], - tools: list[PromptMessageTool], stop: list[str]) -> None: + def _check_parameters( + self, model: str, parameters: dict[str, Any], tools: list[PromptMessageTool], stop: list[str] + ) -> None: if model not in self.api_bases: - raise BadRequestError(f'Invalid model: {model}') + raise BadRequestError(f"Invalid model: {model}") # if model not in self.function_calling_supports and tools is not None and len(tools) > 0: # raise BadRequestError(f'Model {model} does not support calling function.') @@ -85,86 +93,106 @@ class ErnieBotModel(_CommonWenxin): # so, we just disable function calling for now. if tools is not None and len(tools) > 0: - raise BadRequestError('function calling is not supported yet.') + raise BadRequestError("function calling is not supported yet.") if stop is not None: if len(stop) > 4: - raise BadRequestError('stop list should not exceed 4 items.') + raise BadRequestError("stop list should not exceed 4 items.") for s in stop: if len(s) > 20: - raise BadRequestError('stop item should not exceed 20 characters.') + raise BadRequestError("stop item should not exceed 20 characters.") - def _build_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, parameters: dict[str, Any], - tools: list[PromptMessageTool], stop: list[str], user: str) -> dict[str, Any]: + def _build_request_body( + self, + model: str, + messages: list[ErnieMessage], + stream: bool, + parameters: dict[str, Any], + tools: list[PromptMessageTool], + stop: list[str], + user: str, + ) -> dict[str, Any]: # if model in self.function_calling_supports: # return self._build_function_calling_request_body(model, messages, parameters, tools, stop, user) return self._build_chat_request_body(model, messages, stream, parameters, stop, user) - def _build_function_calling_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, - parameters: dict[str, Any], tools: list[PromptMessageTool], - stop: list[str], user: str) \ - -> dict[str, Any]: + def _build_function_calling_request_body( + self, + model: str, + messages: list[ErnieMessage], + stream: bool, + parameters: dict[str, Any], + tools: list[PromptMessageTool], + stop: list[str], + user: str, + ) -> dict[str, Any]: if len(messages) % 2 == 0: - raise BadRequestError('The number of messages should be odd.') - if messages[0].role == 'function': - raise BadRequestError('The first message should be user message.') + raise BadRequestError("The number of messages should be odd.") + if messages[0].role == "function": + raise BadRequestError("The first message should be user message.") """ TODO: implement function calling """ - def _build_chat_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, - parameters: dict[str, Any], stop: list[str], user: str) \ - -> dict[str, Any]: + def _build_chat_request_body( + self, + model: str, + messages: list[ErnieMessage], + stream: bool, + parameters: dict[str, Any], + stop: list[str], + user: str, + ) -> dict[str, Any]: if len(messages) == 0: - raise BadRequestError('The number of messages should not be zero.') + raise BadRequestError("The number of messages should not be zero.") # check if the first element is system, shift it - system_message = '' - if messages[0].role == 'system': + system_message = "" + if messages[0].role == "system": message = messages.pop(0) system_message = message.content if len(messages) % 2 == 0: - raise BadRequestError('The number of messages should be odd.') - if messages[0].role != 'user': - raise BadRequestError('The first message should be user message.') + raise BadRequestError("The number of messages should be odd.") + if messages[0].role != "user": + raise BadRequestError("The first message should be user message.") body = { - 'messages': [message.to_dict() for message in messages], - 'stream': stream, - 'stop': stop, - 'user_id': user, - **parameters + "messages": [message.to_dict() for message in messages], + "stream": stream, + "stop": stop, + "user_id": user, + **parameters, } - if 'max_tokens' in parameters and type(parameters['max_tokens']) == int: - body['max_output_tokens'] = parameters['max_tokens'] + if "max_tokens" in parameters and type(parameters["max_tokens"]) == int: + body["max_output_tokens"] = parameters["max_tokens"] - if 'presence_penalty' in parameters and type(parameters['presence_penalty']) == float: - body['penalty_score'] = parameters['presence_penalty'] + if "presence_penalty" in parameters and type(parameters["presence_penalty"]) == float: + body["penalty_score"] = parameters["presence_penalty"] if system_message: - body['system'] = system_message + body["system"] = system_message return body def _handle_chat_generate_response(self, response: Response) -> ErnieMessage: data = response.json() - if 'error_code' in data: - code = data['error_code'] - msg = data['error_msg'] + if "error_code" in data: + code = data["error_code"] + msg = data["error_msg"] # raise error self._handle_error(code, msg) - result = data['result'] - usage = data['usage'] + result = data["result"] + usage = data["usage"] - message = ErnieMessage(content=result, role='assistant') + message = ErnieMessage(content=result, role="assistant") message.usage = { - 'prompt_tokens': usage['prompt_tokens'], - 'completion_tokens': usage['completion_tokens'], - 'total_tokens': usage['total_tokens'] + "prompt_tokens": usage["prompt_tokens"], + "completion_tokens": usage["completion_tokens"], + "total_tokens": usage["total_tokens"], } return message @@ -173,19 +201,19 @@ class ErnieBotModel(_CommonWenxin): for line in response.iter_lines(): if len(line) == 0: continue - line = line.decode('utf-8') - if line[0] == '{': + line = line.decode("utf-8") + if line[0] == "{": try: data = loads(line) - if 'error_code' in data: - code = data['error_code'] - msg = data['error_msg'] + if "error_code" in data: + code = data["error_code"] + msg = data["error_msg"] # raise error self._handle_error(code, msg) except Exception as e: - raise InternalServerError(f'Failed to parse response: {e}') + raise InternalServerError(f"Failed to parse response: {e}") - if line.startswith('data:'): + if line.startswith("data:"): line = line[5:].strip() else: continue @@ -195,23 +223,23 @@ class ErnieBotModel(_CommonWenxin): try: data = loads(line) except Exception as e: - raise InternalServerError(f'Failed to parse response: {e}') + raise InternalServerError(f"Failed to parse response: {e}") - result = data['result'] - is_end = data['is_end'] + result = data["result"] + is_end = data["is_end"] if is_end: - usage = data['usage'] - finish_reason = data.get('finish_reason', None) - message = ErnieMessage(content=result, role='assistant') + usage = data["usage"] + finish_reason = data.get("finish_reason", None) + message = ErnieMessage(content=result, role="assistant") message.usage = { - 'prompt_tokens': usage['prompt_tokens'], - 'completion_tokens': usage['completion_tokens'], - 'total_tokens': usage['total_tokens'] + "prompt_tokens": usage["prompt_tokens"], + "completion_tokens": usage["completion_tokens"], + "total_tokens": usage["total_tokens"], } message.stop_reason = finish_reason yield message else: - message = ErnieMessage(content=result, role='assistant') + message = ErnieMessage(content=result, role="assistant") yield message diff --git a/api/core/model_runtime/model_providers/wenxin/llm/llm.py b/api/core/model_runtime/model_providers/wenxin/llm/llm.py index 140606298c..1ff0ac7ad2 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/llm.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/llm.py @@ -30,42 +30,82 @@ if you are not sure about the structure. You should also complete the text started with ``` but not tell ``` directly. """ -class ErnieBotLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: +class ErnieBotLargeLanguageModel(LargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + return self._generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) + + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ - if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']: - response_format = model_parameters['response_format'] + if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: + response_format = model_parameters["response_format"] stop = stop or [] - self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, response_format) - model_parameters.pop('response_format') + self._transform_json_prompts( + model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, response_format + ) + model_parameters.pop("response_format") if stream: return self._code_block_mode_stream_processor( model=model, prompt_messages=prompt_messages, - input_generator=self._invoke(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) + input_generator=self._invoke( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ), ) - + return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def _transform_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + def _transform_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts to model prompts """ @@ -74,34 +114,44 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): # override the system message prompt_messages[0] = SystemPromptMessage( - content=ERNIE_BOT_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + content=ERNIE_BOT_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) ) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=ERNIE_BOT_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=ERNIE_BOT_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object." + ).replace("{{block}}", response_format) + ), + ) if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): # add ```JSON\n to the last message prompt_messages[-1].content += "\n```JSON\n{\n" else: # append a user message - prompt_messages.append(UserPromptMessage( - content="```JSON\n{\n" - )) + prompt_messages.append(UserPromptMessage(content="```JSON\n{\n")) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: # tools is not supported yet return self._num_tokens_from_messages(prompt_messages) - def _num_tokens_from_messages(self, messages: list[PromptMessage],) -> int: + def _num_tokens_from_messages( + self, + messages: list[PromptMessage], + ) -> int: """Calculate num tokens for baichuan model""" + def tokens(text: str): return self._get_num_tokens_by_gpt2(text) @@ -113,10 +163,10 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -126,36 +176,53 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): return num_tokens def validate_credentials(self, model: str, credentials: dict) -> None: - api_key = credentials['api_key'] - secret_key = credentials['secret_key'] + api_key = credentials["api_key"] + secret_key = credentials["secret_key"] try: BaiduAccessToken.get_access_token(api_key, secret_key) except Exception as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: instance = ErnieBotModel( - api_key=credentials['api_key'], - secret_key=credentials['secret_key'], + api_key=credentials["api_key"], + secret_key=credentials["secret_key"], ) - user = user if user else 'ErnieBotDefault' + user = user if user else "ErnieBotDefault" # convert prompt messages to baichuan messages messages = [ ErnieMessage( - content=message.content if isinstance(message.content, str) else ''.join([ - content.data for content in message.content - ]), - role=message.role.value - ) for message in prompt_messages + content=message.content + if isinstance(message.content, str) + else "".join([content.data for content in message.content]), + role=message.role.value, + ) + for message in prompt_messages ] # invoke model - response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters, timeout=60, tools=tools, stop=stop, user=user) + response = instance.generate( + model=model, + stream=stream, + messages=messages, + parameters=model_parameters, + timeout=60, + tools=tools, + stop=stop, + user=user, + ) if stream: return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response) @@ -180,41 +247,47 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): message_dict = {"role": "system", "content": message.content} else: raise ValueError(f"Unknown message type {type(message)}") - + return message_dict - def _handle_chat_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: ErnieMessage) -> LLMResult: + def _handle_chat_generate_response( + self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: ErnieMessage + ) -> LLMResult: # convert baichuan message to llm result - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=response.usage['prompt_tokens'], completion_tokens=response.usage['completion_tokens']) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=response.usage["prompt_tokens"], + completion_tokens=response.usage["completion_tokens"], + ) return LLMResult( model=model, prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=response.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=response.content, tool_calls=[]), usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Generator[ErnieMessage, None, None]) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Generator[ErnieMessage, None, None], + ) -> Generator: for message in response: if message.usage: - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=message.usage['prompt_tokens'], completion_tokens=message.usage['completion_tokens']) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=message.usage["prompt_tokens"], + completion_tokens=message.usage["completion_tokens"], + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), usage=usage, finish_reason=message.stop_reason if message.stop_reason else None, ), @@ -225,10 +298,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), finish_reason=message.stop_reason if message.stop_reason else None, ), ) diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py index 10ac1a1861..db323ae4c1 100644 --- a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py @@ -29,38 +29,38 @@ class TextEmbedding: class WenxinTextEmbedding(_CommonWenxin, TextEmbedding): def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): access_token = self._get_access_token() - url = f'{self.api_bases[model]}?access_token={access_token}' + url = f"{self.api_bases[model]}?access_token={access_token}" body = self._build_embed_request_body(model, texts, user) headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } resp = post(url, data=dumps(body), headers=headers) if resp.status_code != 200: - raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}') + raise InternalServerError(f"Failed to invoke ernie bot: {resp.text}") return self._handle_embed_response(model, resp) def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> dict[str, Any]: if len(texts) == 0: - raise BadRequestError('The number of texts should not be zero.') + raise BadRequestError("The number of texts should not be zero.") body = { - 'input': texts, - 'user_id': user, + "input": texts, + "user_id": user, } return body def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int): data = response.json() - if 'error_code' in data: - code = data['error_code'] - msg = data['error_msg'] + if "error_code" in data: + code = data["error_code"] + msg = data["error_msg"] # raise error self._handle_error(code, msg) - embeddings = [v['embedding'] for v in data['data']] - _usage = data['usage'] - tokens = _usage['prompt_tokens'] - total_tokens = _usage['total_tokens'] + embeddings = [v["embedding"] for v in data["data"]] + _usage = data["usage"] + tokens = _usage["prompt_tokens"] + total_tokens = _usage["total_tokens"] return embeddings, tokens, total_tokens @@ -69,22 +69,23 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel): def _create_text_embedding(self, api_key: str, secret_key: str) -> TextEmbedding: return WenxinTextEmbedding(api_key, secret_key) - def _invoke(self, model: str, credentials: dict, texts: list[str], - user: Optional[str] = None) -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ - Invoke text embedding model + Invoke text embedding model - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :param user: unique user id - :return: embeddings result - """ + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ - api_key = credentials['api_key'] - secret_key = credentials['secret_key'] + api_key = credentials["api_key"] + secret_key = credentials["secret_key"] embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key) - user = user if user else 'ErnieBotDefault' + user = user if user else "ErnieBotDefault" context_size = self._get_context_size(model, credentials) max_chunks = self._get_max_chunks(model, credentials) @@ -94,7 +95,6 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel): used_total_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer num_tokens = self._get_num_tokens_by_gpt2(text) @@ -110,9 +110,8 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel): _iter = range(0, len(inputs), max_chunks) for i in _iter: embeddings_batch, _used_tokens, _total_used_tokens = embedding.embed_documents( - model, - inputs[i: i + max_chunks], - user) + model, inputs[i : i + max_chunks], user + ) used_tokens += _used_tokens used_total_tokens += _total_used_tokens batched_embeddings += embeddings_batch @@ -142,12 +141,12 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel): return total_num_tokens def validate_credentials(self, model: str, credentials: Mapping) -> None: - api_key = credentials['api_key'] - secret_key = credentials['secret_key'] + api_key = credentials["api_key"] + secret_key = credentials["secret_key"] try: BaiduAccessToken.get_access_token(api_key, secret_key) except Exception as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: @@ -164,10 +163,7 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -178,7 +174,7 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin.py b/api/core/model_runtime/model_providers/wenxin/wenxin.py index 04845d06bc..895af20bc8 100644 --- a/api/core/model_runtime/model_providers/wenxin/wenxin.py +++ b/api/core/model_runtime/model_providers/wenxin/wenxin.py @@ -6,6 +6,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvid logger = logging.getLogger(__name__) + class WenxinProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -19,12 +20,9 @@ class WenxinProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) # Use `ernie-bot` model for validate, - model_instance.validate_credentials( - model='ernie-bot', - credentials=credentials - ) + model_instance.validate_credentials(model="ernie-bot", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py b/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py index 0fbd0f55ec..f2e2248680 100644 --- a/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py +++ b/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py @@ -18,40 +18,37 @@ def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]: :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalance, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass + class InsufficientAccountBalance(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 4760e8f118..b2c837dee1 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -65,88 +65,108 @@ from core.model_runtime.utils import helper class XinferenceAILargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - invoke LLM + invoke LLM - see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` """ - if 'temperature' in model_parameters: - if model_parameters['temperature'] < 0.01: - model_parameters['temperature'] = 0.01 - elif model_parameters['temperature'] > 1.0: - model_parameters['temperature'] = 0.99 + if "temperature" in model_parameters: + if model_parameters["temperature"] < 0.01: + model_parameters["temperature"] = 0.01 + elif model_parameters["temperature"] > 1.0: + model_parameters["temperature"] = 0.99 return self._generate( - model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, - tools=tools, stop=stop, stream=stream, user=user, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'], - api_key=credentials.get('api_key'), - ) + server_url=credentials["server_url"], + model_uid=credentials["model_uid"], + api_key=credentials.get("api_key"), + ), ) def validate_credentials(self, model: str, credentials: dict) -> None: """ - validate credentials + validate credentials - credentials should be like: - { - 'model_type': 'text-generation', - 'server_url': 'server url', - 'model_uid': 'model uid', - } + credentials should be like: + { + 'model_type': 'text-generation', + 'server_url': 'server url', + 'model_uid': 'model uid', + } """ try: - if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: + if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") extra_param = XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'], - api_key=credentials.get('api_key') + server_url=credentials["server_url"], + model_uid=credentials["model_uid"], + api_key=credentials.get("api_key"), ) - if 'completion_type' not in credentials: - if 'chat' in extra_param.model_ability: - credentials['completion_type'] = 'chat' - elif 'generate' in extra_param.model_ability: - credentials['completion_type'] = 'completion' + if "completion_type" not in credentials: + if "chat" in extra_param.model_ability: + credentials["completion_type"] = "chat" + elif "generate" in extra_param.model_ability: + credentials["completion_type"] = "completion" else: raise ValueError( - f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type') + f"xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type" + ) if extra_param.support_function_call: - credentials['support_function_call'] = True + credentials["support_function_call"] = True if extra_param.support_vision: - credentials['support_vision'] = True + credentials["support_vision"] = True if extra_param.context_length: - credentials['context_length'] = extra_param.context_length + credentials["context_length"] = extra_param.context_length except RuntimeError as e: - raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') + raise CredentialsValidateFailedError(f"Xinference credentials validate failed: {e}") except KeyError as e: - raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') + raise CredentialsValidateFailedError(f"Xinference credentials validate failed: {e}") except Exception as e: raise e - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: """ - get number of tokens + get number of tokens - cause XinferenceAI LLM is a customized model, we could net detect which tokenizer to use - so we just take the GPT2 tokenizer as default + cause XinferenceAI LLM is a customized model, we could net detect which tokenizer to use + so we just take the GPT2 tokenizer as default """ return self._num_tokens_from_messages(prompt_messages, tools) - def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool], - is_completion_model: bool = False) -> int: + def _num_tokens_from_messages( + self, messages: list[PromptMessage], tools: list[PromptMessageTool], is_completion_model: bool = False + ) -> int: def tokens(text: str): return self._get_num_tokens_by_gpt2(text) @@ -162,10 +182,10 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -217,30 +237,30 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): num_tokens = 0 for tool in tools: # calculate num tokens for function object - num_tokens += tokens('name') + num_tokens += tokens("name") num_tokens += tokens(tool.name) - num_tokens += tokens('description') + num_tokens += tokens("description") num_tokens += tokens(tool.description) parameters = tool.parameters - num_tokens += tokens('parameters') - num_tokens += tokens('type') + num_tokens += tokens("parameters") + num_tokens += tokens("type") num_tokens += tokens(parameters.get("type")) - if 'properties' in parameters: - num_tokens += tokens('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += tokens("properties") + for key, value in parameters.get("properties").items(): num_tokens += tokens(key) for field_key, field_value in value.items(): num_tokens += tokens(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += tokens(enum_field) else: num_tokens += tokens(field_key) num_tokens += tokens(str(field_value)) - if 'required' in parameters: - num_tokens += tokens('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += tokens("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += tokens(required_field) @@ -248,9 +268,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str: """ - convert prompt message to text + convert prompt message to text """ - text = '' + text = "" for item in message: if isinstance(item, UserPromptMessage): text += item.content @@ -259,7 +279,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): elif isinstance(item, AssistantPromptMessage): text += item.content else: - raise NotImplementedError(f'PromptMessage type {type(item)} is not supported') + raise NotImplementedError(f"PromptMessage type {type(item)} is not supported") return text def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: @@ -275,19 +295,13 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -297,7 +311,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): if message.tool_calls and len(message.tool_calls) > 0: message_dict["function_call"] = { "name": message.tool_calls[0].function.name, - "arguments": message.tool_calls[0].function.arguments + "arguments": message.tool_calls[0].function.arguments, } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) @@ -312,151 +326,144 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ), + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, - max=credentials.get('context_length', 2048), + max=credentials.get("context_length", 2048), default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), ), ParameterRule( name=DefaultParameterName.PRESENCE_PENALTY, use_template=DefaultParameterName.PRESENCE_PENALTY, type=ParameterType.FLOAT, label=I18nObject( - en_US='Presence Penalty', - zh_Hans='存在惩罚', + en_US="Presence Penalty", + zh_Hans="存在惩罚", ), required=False, help=I18nObject( - en_US='Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they ' - 'appear in the text so far, increasing the model\'s likelihood to talk about new topics.', - zh_Hans='介于 -2.0 和 2.0 之间的数字。正值会根据新词是否已出现在文本中对其进行惩罚,从而增加模型谈论新话题的可能性。' + en_US="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they " + "appear in the text so far, increasing the model's likelihood to talk about new topics.", + zh_Hans="介于 -2.0 和 2.0 之间的数字。正值会根据新词是否已出现在文本中对其进行惩罚,从而增加模型谈论新话题的可能性。", ), default=0.0, min=-2.0, max=2.0, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.FREQUENCY_PENALTY, use_template=DefaultParameterName.FREQUENCY_PENALTY, type=ParameterType.FLOAT, label=I18nObject( - en_US='Frequency Penalty', - zh_Hans='频率惩罚', + en_US="Frequency Penalty", + zh_Hans="频率惩罚", ), required=False, help=I18nObject( - en_US='Number between -2.0 and 2.0. Positive values penalize new tokens based on their ' - 'existing frequency in the text so far, decreasing the model\'s likelihood to repeat the ' - 'same line verbatim.', - zh_Hans='介于 -2.0 和 2.0 之间的数字。正值会根据新词在文本中的现有频率对其进行惩罚,从而降低模型逐字重复相同内容的可能性。' + en_US="Number between -2.0 and 2.0. Positive values penalize new tokens based on their " + "existing frequency in the text so far, decreasing the model's likelihood to repeat the " + "same line verbatim.", + zh_Hans="介于 -2.0 和 2.0 之间的数字。正值会根据新词在文本中的现有频率对其进行惩罚,从而降低模型逐字重复相同内容的可能性。", ), default=0.0, min=-2.0, max=2.0, - precision=2 - ) + precision=2, + ), ] completion_type = None - if 'completion_type' in credentials: - if credentials['completion_type'] == 'chat': + if "completion_type" in credentials: + if credentials["completion_type"] == "chat": completion_type = LLMMode.CHAT.value - elif credentials['completion_type'] == 'completion': + elif credentials["completion_type"] == "completion": completion_type = LLMMode.COMPLETION.value else: raise ValueError(f'completion_type {credentials["completion_type"]} is not supported') else: extra_args = XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'], - api_key=credentials.get('api_key') + server_url=credentials["server_url"], + model_uid=credentials["model_uid"], + api_key=credentials.get("api_key"), ) - if 'chat' in extra_args.model_ability: + if "chat" in extra_args.model_ability: completion_type = LLMMode.CHAT.value - elif 'generate' in extra_args.model_ability: + elif "generate" in extra_args.model_ability: completion_type = LLMMode.COMPLETION.value else: - raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported') + raise ValueError(f"xinference model ability {extra_args.model_ability} is not supported") features = [] - support_function_call = credentials.get('support_function_call', False) + support_function_call = credentials.get("support_function_call", False) if support_function_call: features.append(ModelFeature.TOOL_CALL) - support_vision = credentials.get('support_vision', False) + support_vision = credentials.get("support_vision", False) if support_vision: features.append(ModelFeature.VISION) - context_length = credentials.get('context_length', 2048) + context_length = credentials.get("context_length", 2048) entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, features=features, - model_properties={ - ModelPropertyKey.MODE: completion_type, - ModelPropertyKey.CONTEXT_SIZE: context_length - }, - parameter_rules=rules + model_properties={ModelPropertyKey.MODE: completion_type, ModelPropertyKey.CONTEXT_SIZE: context_length}, + parameter_rules=rules, ) return entity - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter, - tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + extra_model_kwargs: XinferenceModelExtraParameter, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - generate text from LLM + generate text from LLM - see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate` + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate` - extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter` + extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter` """ - if 'server_url' not in credentials: - raise CredentialsValidateFailedError('server_url is required in credentials') + if "server_url" not in credentials: + raise CredentialsValidateFailedError("server_url is required in credentials") - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + if credentials["server_url"].endswith("/"): + credentials["server_url"] = credentials["server_url"][:-1] - api_key = credentials.get('api_key') or "abc" + api_key = credentials.get("api_key") or "abc" client = OpenAI( base_url=f'{credentials["server_url"]}/v1', @@ -466,34 +473,29 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ) xinference_client = Client( - base_url=credentials['server_url'], - api_key=credentials.get('api_key'), + base_url=credentials["server_url"], + api_key=credentials.get("api_key"), ) - xinference_model = xinference_client.get_model(credentials['model_uid']) + xinference_model = xinference_client.get_model(credentials["model_uid"]) generate_config = { - 'temperature': model_parameters.get('temperature', 1.0), - 'top_p': model_parameters.get('top_p', 0.7), - 'max_tokens': model_parameters.get('max_tokens', 512), - 'presence_penalty': model_parameters.get('presence_penalty', 0.0), - 'frequency_penalty': model_parameters.get('frequency_penalty', 0.0), + "temperature": model_parameters.get("temperature", 1.0), + "top_p": model_parameters.get("top_p", 0.7), + "max_tokens": model_parameters.get("max_tokens", 512), + "presence_penalty": model_parameters.get("presence_penalty", 0.0), + "frequency_penalty": model_parameters.get("frequency_penalty", 0.0), } if stop: - generate_config['stop'] = stop + generate_config["stop"] = stop if tools and len(tools) > 0: - generate_config['tools'] = [ - { - 'type': 'function', - 'function': helper.dump_model(tool) - } for tool in tools - ] - vision = credentials.get('support_vision', False) + generate_config["tools"] = [{"type": "function", "function": helper.dump_model(tool)} for tool in tools] + vision = credentials.get("support_vision", False) if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle): resp = client.chat.completions.create( - model=credentials['model_uid'], + model=credentials["model_uid"], messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages], stream=stream, user=user, @@ -501,34 +503,34 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ) if stream: if tools and len(tools) > 0: - raise InvokeBadRequestError('xinference tool calls does not support stream mode') - return self._handle_chat_stream_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=resp) - return self._handle_chat_generate_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=resp) + raise InvokeBadRequestError("xinference tool calls does not support stream mode") + return self._handle_chat_stream_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=resp + ) + return self._handle_chat_generate_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=resp + ) elif isinstance(xinference_model, RESTfulGenerateModelHandle): resp = client.completions.create( - model=credentials['model_uid'], + model=credentials["model_uid"], prompt=self._convert_prompt_message_to_text(prompt_messages), stream=stream, user=user, **generate_config, ) if stream: - return self._handle_completion_stream_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=resp) - return self._handle_completion_generate_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=resp) + return self._handle_completion_stream_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=resp + ) + return self._handle_completion_generate_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=resp + ) else: - raise NotImplementedError(f'xinference model handle type {type(xinference_model)} is not supported') + raise NotImplementedError(f"xinference model handle type {type(xinference_model)} is not supported") - def _extract_response_tool_calls(self, - response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -539,21 +541,19 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call( + self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall + ) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -563,23 +563,25 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.name, - arguments=response_function_call.arguments + name=response_function_call.name, arguments=response_function_call.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.name, - type="function", - function=function + id=response_function_call.name, type="function", function=function ) return tool_call - def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: ChatCompletion) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: ChatCompletion, + ) -> LLMResult: """ - handle normal chat generate response + handle normal chat generate response """ if len(resp.choices) == 0: raise InvokeServerUnavailableError("Empty response") @@ -595,15 +597,15 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=assistant_prompt_message_tool_calls + content=assistant_message.content, tool_calls=assistant_prompt_message_tool_calls ) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -615,13 +617,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): return response - def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Iterator[ChatCompletionChunk]) -> Generator: + def _handle_chat_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Iterator[ChatCompletionChunk], + ) -> Generator: """ - handle stream chat generate response + handle stream chat generate response """ - full_response = '' + full_response = "" for chunk in resp: if len(chunk.choices) == 0: @@ -629,7 +636,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""): continue # check if there is a tool call in the response @@ -646,32 +653,31 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_message_tool_calls + content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls ) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=assistant_message_tool_calls + content=full_response, tool_calls=assistant_message_tool_calls ) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, system_fingerprint=chunk.system_fingerprint, delta=LLMResultChunkDelta( - index=0, - message=assistant_prompt_message, - finish_reason=delta.finish_reason, - usage=usage + index=0, message=assistant_prompt_message, finish_reason=delta.finish_reason, usage=usage ), ) else: @@ -687,11 +693,16 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): full_response += delta.delta.content - def _handle_completion_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Completion) -> LLMResult: + def _handle_completion_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Completion, + ) -> LLMResult: """ - handle normal completion generate response + handle normal completion generate response """ if len(resp.choices) == 0: raise InvokeServerUnavailableError("Empty response") @@ -699,14 +710,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): assistant_message = resp.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message, - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message, tool_calls=[]) - prompt_tokens = self._get_num_tokens_by_gpt2( - self._convert_prompt_message_to_text(prompt_messages) - ) + prompt_tokens = self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages)) completion_tokens = self._num_tokens_from_messages( messages=[assistant_prompt_message], tools=[], is_completion_model=True ) @@ -724,13 +730,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): return response - def _handle_completion_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Iterator[Completion]) -> Generator: + def _handle_completion_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Iterator[Completion], + ) -> Generator: """ - handle stream completion generate response + handle stream completion generate response """ - full_response = '' + full_response = "" for chunk in resp: if len(chunk.choices) == 0: @@ -739,40 +750,33 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): delta = chunk.choices[0] # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=delta.text if delta.text else '', - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=delta.text if delta.text else "", tool_calls=[]) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage - temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=[] - ) + temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[]) - prompt_tokens = self._get_num_tokens_by_gpt2( - self._convert_prompt_message_to_text(prompt_messages) - ) + prompt_tokens = self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages)) completion_tokens = self._num_tokens_from_messages( messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True ) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, system_fingerprint=chunk.system_fingerprint, delta=LLMResultChunkDelta( - index=0, - message=assistant_prompt_message, - finish_reason=delta.finish_reason, - usage=usage + index=0, message=assistant_prompt_message, finish_reason=delta.finish_reason, usage=usage ), ) else: - if delta.text is None or delta.text == '': + if delta.text is None or delta.text == "": continue yield LLMResultChunk( @@ -807,15 +811,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ConflictError, NotFoundError, UnprocessableEntityError, - PermissionDeniedError + PermissionDeniedError, ], - InvokeRateLimitError: [ - RateLimitError - ], - InvokeAuthorizationError: [ - AuthenticationError - ], - InvokeBadRequestError: [ - ValueError - ] + InvokeRateLimitError: [RateLimitError], + InvokeAuthorizationError: [AuthenticationError], + InvokeBadRequestError: [ValueError], } diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index d809537479..1582fe43b9 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -22,10 +22,16 @@ class XinferenceRerankModel(RerankModel): Model class for Xinference rerank model. """ - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -39,24 +45,16 @@ class XinferenceRerankModel(RerankModel): :return: rerank result """ if len(docs) == 0: - return RerankResult( - model=model, - docs=[] - ) + return RerankResult(model=model, docs=[]) - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - api_key = credentials.get('api_key') - if server_url.endswith('/'): + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] + api_key = credentials.get("api_key") + if server_url.endswith("/"): server_url = server_url[:-1] - auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} - params = { - 'documents': docs, - 'query': query, - 'top_n': top_n, - 'return_documents': True - } + params = {"documents": docs, "query": query, "top_n": top_n, "return_documents": True} try: handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers) response = handle.rerank(**params) @@ -69,27 +67,24 @@ class XinferenceRerankModel(RerankModel): response = handle.rerank(**params) rerank_documents = [] - for idx, result in enumerate(response['results']): + for idx, result in enumerate(response["results"]): # format document - index = result['index'] - page_content = result['document'] if isinstance(result['document'], str) else result['document']['text'] + index = result["index"] + page_content = result["document"] if isinstance(result["document"], str) else result["document"]["text"] rerank_document = RerankDocument( index=index, text=page_content, - score=result['relevance_score'], + score=result["relevance_score"], ) # score threshold check if score_threshold is not None: - if result['relevance_score'] >= score_threshold: + if result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) else: rerank_documents.append(rerank_document) - return RerankResult( - model=model, - docs=rerank_documents - ) + return RerankResult(model=model, docs=rerank_documents) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -100,34 +95,35 @@ class XinferenceRerankModel(RerankModel): :return: """ try: - if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: + if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + if credentials["server_url"].endswith("/"): + credentials["server_url"] = credentials["server_url"][:-1] # initialize client client = Client( - base_url=credentials['server_url'], - api_key=credentials.get('api_key'), + base_url=credentials["server_url"], + api_key=credentials.get("api_key"), ) - xinference_client = client.get_model(model_uid=credentials['model_uid']) + xinference_client = client.get_model(model_uid=credentials["model_uid"]) if not isinstance(xinference_client, RESTfulRerankModelHandle): raise InvokeBadRequestError( - 'please check model type, the model you want to invoke is not a rerank model') + "please check model type, the model you want to invoke is not a rerank model" + ) self.invoke( model=model, credentials=credentials, query="Whose kasumi", docs=[ - "Kasumi is a girl's name of Japanese origin meaning \"mist\".", + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", - "and she leads a team named PopiParty." + "and she leads a team named PopiParty.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -143,53 +139,38 @@ class XinferenceRerankModel(RerankModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.RERANK, model_properties={}, - parameter_rules=[] + parameter_rules=[], ) return entity class RESTfulRerankModelHandleWithoutExtraParameter(RESTfulRerankModelHandle): - def rerank( - self, - documents: list[str], - query: str, - top_n: Optional[int] = None, - max_chunks_per_doc: Optional[int] = None, - return_documents: Optional[bool] = None, - **kwargs + self, + documents: list[str], + query: str, + top_n: Optional[int] = None, + max_chunks_per_doc: Optional[int] = None, + return_documents: Optional[bool] = None, + **kwargs, ): url = f"{self._base_url}/v1/rerank" request_body = { @@ -205,8 +186,6 @@ class RESTfulRerankModelHandleWithoutExtraParameter(RESTfulRerankModelHandle): response = requests.post(url, json=request_body, headers=self.auth_headers) if response.status_code != 200: - raise InvokeServerUnavailableError( - f"Failed to rerank documents, detail: {response.json()['detail']}" - ) + raise InvokeServerUnavailableError(f"Failed to rerank documents, detail: {response.json()['detail']}") response_data = response.json() return response_data diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py index 62b77f22e5..54c8b51654 100644 --- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py @@ -21,9 +21,7 @@ class XinferenceSpeech2TextModel(Speech2TextModel): Model class for Xinference speech to text model. """ - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -44,27 +42,28 @@ class XinferenceSpeech2TextModel(Speech2TextModel): :return: """ try: - if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: + if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + if credentials["server_url"].endswith("/"): + credentials["server_url"] = credentials["server_url"][:-1] # initialize client client = Client( - base_url=credentials['server_url'], - api_key=credentials.get('api_key'), + base_url=credentials["server_url"], + api_key=credentials.get("api_key"), ) - xinference_client = client.get_model(model_uid=credentials['model_uid']) + xinference_client = client.get_model(model_uid=credentials["model_uid"]) if not isinstance(xinference_client, RESTfulAudioModelHandle): raise InvokeBadRequestError( - 'please check model type, the model you want to invoke is not a audio model') + "please check model type, the model you want to invoke is not a audio model" + ) audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self.invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -80,23 +79,11 @@ class XinferenceSpeech2TextModel(Speech2TextModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def _speech2text_invoke( @@ -122,21 +109,17 @@ class XinferenceSpeech2TextModel(Speech2TextModel): :param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output mor e random,while lower values like 0.2 will make it more focused and deterministic.If set to 0, the model wi ll use log probability to automatically increase the temperature until certain thresholds are hit. :return: text for given audio file """ - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - api_key = credentials.get('api_key') - if server_url.endswith('/'): + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] + api_key = credentials.get("api_key") + if server_url.endswith("/"): server_url = server_url[:-1] - auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} try: handle = RESTfulAudioModelHandle(model_uid, server_url, auth_headers) response = handle.transcriptions( - audio=file, - language=language, - prompt=prompt, - response_format=response_format, - temperature=temperature + audio=file, language=language, prompt=prompt, response_format=response_format, temperature=temperature ) except RuntimeError as e: raise InvokeServerUnavailableError(str(e)) @@ -145,17 +128,15 @@ class XinferenceSpeech2TextModel(Speech2TextModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.SPEECH2TEXT, - model_properties={ }, - parameter_rules=[] + model_properties={}, + parameter_rules=[], ) return entity diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index 3a8d704c25..ac704e7de8 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -23,9 +23,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): """ Model class for Xinference text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -41,12 +42,12 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - api_key = credentials.get('api_key') - if server_url.endswith('/'): + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] + api_key = credentials.get("api_key") + if server_url.endswith("/"): server_url = server_url[:-1] - auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} try: handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers) @@ -70,13 +71,11 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): embedding: List[float] """ - usage = embeddings['usage'] - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = embeddings["usage"] + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) result = TextEmbeddingResult( - model=model, - embeddings=[embedding['embedding'] for embedding in embeddings['data']], - usage=usage + model=model, embeddings=[embedding["embedding"] for embedding in embeddings["data"]], usage=usage ) return result @@ -105,12 +104,12 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: + if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - api_key = credentials.get('api_key') + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] + api_key = credentials.get("api_key") extra_args = XinferenceHelper.get_xinference_extra_parameter( server_url=server_url, model_uid=model_uid, @@ -118,8 +117,8 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): ) if extra_args.max_tokens: - credentials['max_tokens'] = extra_args.max_tokens - if server_url.endswith('/'): + credentials["max_tokens"] = extra_args.max_tokens + if server_url.endswith("/"): server_url = server_url[:-1] client = Client( @@ -133,32 +132,24 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): raise InvokeAuthorizationError(e) if not isinstance(handle, RESTfulEmbeddingModelHandle): - raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model') + raise InvokeBadRequestError( + "please check model type, the model you want to invoke is not a text embedding model" + ) - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError as e: - raise CredentialsValidateFailedError(f'Failed to validate credentials for model {model}: {e}') + raise CredentialsValidateFailedError(f"Failed to validate credentials for model {model}: {e}") except RuntimeError as e: raise CredentialsValidateFailedError(e) @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: @@ -172,10 +163,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -186,28 +174,26 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ ModelPropertyKey.MAX_CHUNKS: 1, - ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and credentials['max_tokens'] or 512, + ModelPropertyKey.CONTEXT_SIZE: "max_tokens" in credentials and credentials["max_tokens"] or 512, }, - parameter_rules=[] + parameter_rules=[], ) return entity diff --git a/api/core/model_runtime/model_providers/xinference/tts/tts.py b/api/core/model_runtime/model_providers/xinference/tts/tts.py index 8cc99fef7c..60db151302 100644 --- a/api/core/model_runtime/model_providers/xinference/tts/tts.py +++ b/api/core/model_runtime/model_providers/xinference/tts/tts.py @@ -19,92 +19,91 @@ from core.model_runtime.model_providers.xinference.xinference_helper import Xinf class XinferenceText2SpeechModel(TTSModel): - def __init__(self): # preset voices, need support custom voice self.model_voices = { - '__default': { - 'all': [ - {'name': 'Default', 'value': 'default'}, + "__default": { + "all": [ + {"name": "Default", "value": "default"}, ] }, - 'ChatTTS': { - 'all': [ - {'name': 'Alloy', 'value': 'alloy'}, - {'name': 'Echo', 'value': 'echo'}, - {'name': 'Fable', 'value': 'fable'}, - {'name': 'Onyx', 'value': 'onyx'}, - {'name': 'Nova', 'value': 'nova'}, - {'name': 'Shimmer', 'value': 'shimmer'}, + "ChatTTS": { + "all": [ + {"name": "Alloy", "value": "alloy"}, + {"name": "Echo", "value": "echo"}, + {"name": "Fable", "value": "fable"}, + {"name": "Onyx", "value": "onyx"}, + {"name": "Nova", "value": "nova"}, + {"name": "Shimmer", "value": "shimmer"}, ] }, - 'CosyVoice': { - 'zh-Hans': [ - {'name': '中文男', 'value': '中文男'}, - {'name': '中文女', 'value': '中文女'}, - {'name': '粤语女', 'value': '粤语女'}, + "CosyVoice": { + "zh-Hans": [ + {"name": "中文男", "value": "中文男"}, + {"name": "中文女", "value": "中文女"}, + {"name": "粤语女", "value": "粤语女"}, ], - 'zh-Hant': [ - {'name': '中文男', 'value': '中文男'}, - {'name': '中文女', 'value': '中文女'}, - {'name': '粤语女', 'value': '粤语女'}, + "zh-Hant": [ + {"name": "中文男", "value": "中文男"}, + {"name": "中文女", "value": "中文女"}, + {"name": "粤语女", "value": "粤语女"}, ], - 'en-US': [ - {'name': '英文男', 'value': '英文男'}, - {'name': '英文女', 'value': '英文女'}, + "en-US": [ + {"name": "英文男", "value": "英文男"}, + {"name": "英文女", "value": "英文女"}, ], - 'ja-JP': [ - {'name': '日语男', 'value': '日语男'}, + "ja-JP": [ + {"name": "日语男", "value": "日语男"}, ], - 'ko-KR': [ - {'name': '韩语女', 'value': '韩语女'}, - ] - } + "ko-KR": [ + {"name": "韩语女", "value": "韩语女"}, + ], + }, } def validate_credentials(self, model: str, credentials: dict) -> None: """ - Validate model credentials + Validate model credentials - :param model: model name - :param credentials: model credentials - :return: - """ + :param model: model name + :param credentials: model credentials + :return: + """ try: - if ("/" in credentials['model_uid'] or - "?" in credentials['model_uid'] or - "#" in credentials['model_uid']): + if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + if credentials["server_url"].endswith("/"): + credentials["server_url"] = credentials["server_url"][:-1] extra_param = XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'], - api_key=credentials.get('api_key'), + server_url=credentials["server_url"], + model_uid=credentials["model_uid"], + api_key=credentials.get("api_key"), ) - if 'text-to-audio' not in extra_param.model_ability: + if "text-to-audio" not in extra_param.model_ability: raise InvokeBadRequestError( - 'please check model type, the model you want to invoke is not a text-to-audio model') + "please check model type, the model you want to invoke is not a text-to-audio model" + ) if extra_param.model_family and extra_param.model_family in self.model_voices: - credentials['audio_model_name'] = extra_param.model_family + credentials["audio_model_name"] = extra_param.model_family else: - credentials['audio_model_name'] = '__default' + credentials["audio_model_name"] = "__default" self._tts_invoke_streaming( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None): + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ): """ _invoke text2speech model @@ -120,18 +119,16 @@ class XinferenceText2SpeechModel(TTSModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TTS, model_properties={}, - parameter_rules=[] + parameter_rules=[], ) return entity @@ -147,40 +144,28 @@ class XinferenceText2SpeechModel(TTSModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: - audio_model_name = credentials.get('audio_model_name', '__default') + audio_model_name = credentials.get("audio_model_name", "__default") for key, voices in self.model_voices.items(): if key in audio_model_name: if language and language in voices: return voices[language] - elif 'all' in voices: - return voices['all'] + elif "all" in voices: + return voices["all"] else: all_voices = [] for lang, lang_voices in voices.items(): all_voices.extend(lang_voices) return all_voices - return self.model_voices['__default']['all'] + return self.model_voices["__default"]["all"] def _get_model_default_voice(self, model: str, credentials: dict) -> any: return "" @@ -194,8 +179,7 @@ class XinferenceText2SpeechModel(TTSModel): def _get_model_workers_limit(self, model: str, credentials: dict) -> int: return 5 - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ _tts_invoke_streaming text2speech model @@ -205,48 +189,42 @@ class XinferenceText2SpeechModel(TTSModel): :param voice: model timbre :return: text translated to audio file """ - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + if credentials["server_url"].endswith("/"): + credentials["server_url"] = credentials["server_url"][:-1] try: - api_key = credentials.get('api_key') - auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + api_key = credentials.get("api_key") + auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} handle = RESTfulAudioModelHandle( - credentials['model_uid'], credentials['server_url'], auth_headers=auth_headers + credentials["model_uid"], credentials["server_url"], auth_headers=auth_headers ) - model_support_voice = [x.get("value") for x in - self.get_tts_model_voices(model=model, credentials=credentials)] + model_support_voice = [ + x.get("value") for x in self.get_tts_model_voices(model=model, credentials=credentials) + ] if not voice or voice not in model_support_voice: voice = self._get_model_default_voice(model, credentials) word_limit = self._get_model_word_limit(model, credentials) if len(content_text) > word_limit: sentences = self._split_text_into_sentences(content_text, max_length=word_limit) executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) - futures = [executor.submit( - handle.speech, - input=sentences[i], - voice=voice, - response_format="mp3", - speed=1.0, - stream=False - ) - for i in range(len(sentences))] + futures = [ + executor.submit( + handle.speech, input=sentences[i], voice=voice, response_format="mp3", speed=1.0, stream=False + ) + for i in range(len(sentences)) + ] for index, future in enumerate(futures): response = future.result() for i in range(0, len(response), 1024): - yield response[i:i + 1024] + yield response[i : i + 1024] else: response = handle.speech( - input=content_text.strip(), - voice=voice, - response_format="mp3", - speed=1.0, - stream=False + input=content_text.strip(), voice=voice, response_format="mp3", speed=1.0, stream=False ) for i in range(0, len(response), 1024): - yield response[i:i + 1024] + yield response[i : i + 1024] except Exception as ex: raise InvokeBadRequestError(str(ex)) diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index 151166f165..6ad10e690d 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -18,9 +18,17 @@ class XinferenceModelExtraParameter: support_vision: bool = False model_family: Optional[str] - def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str], - support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int, - model_family: Optional[str]) -> None: + def __init__( + self, + model_format: str, + model_handle_type: str, + model_ability: list[str], + support_function_call: bool, + support_vision: bool, + max_tokens: int, + context_length: int, + model_family: Optional[str], + ) -> None: self.model_format = model_format self.model_handle_type = model_handle_type self.model_ability = model_ability @@ -30,9 +38,11 @@ class XinferenceModelExtraParameter: self.context_length = context_length self.model_family = model_family + cache = {} cache_lock = Lock() + class XinferenceHelper: @staticmethod def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: @@ -40,16 +50,16 @@ class XinferenceHelper: with cache_lock: if model_uid not in cache: cache[model_uid] = { - 'expires': time() + 300, - 'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key) + "expires": time() + 300, + "value": XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key), } - return cache[model_uid]['value'] + return cache[model_uid]["value"] @staticmethod def _clean_cache() -> None: try: with cache_lock: - expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()] + expired_keys = [model_uid for model_uid, model in cache.items() if model["expires"] < time()] for model_uid in expired_keys: del cache[model_uid] except RuntimeError as e: @@ -58,55 +68,57 @@ class XinferenceHelper: @staticmethod def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: """ - get xinference model extra parameter like model_format and model_handle_type + get xinference model extra parameter like model_format and model_handle_type """ if not model_uid or not model_uid.strip() or not server_url or not server_url.strip(): - raise RuntimeError('model_uid is empty') + raise RuntimeError("model_uid is empty") - url = str(URL(server_url) / 'v1' / 'models' / model_uid) + url = str(URL(server_url) / "v1" / "models" / model_uid) # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3 session = Session() - session.mount('http://', HTTPAdapter(max_retries=3)) - session.mount('https://', HTTPAdapter(max_retries=3)) - headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + session.mount("http://", HTTPAdapter(max_retries=3)) + session.mount("https://", HTTPAdapter(max_retries=3)) + headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} try: response = session.get(url, headers=headers, timeout=10) except (MissingSchema, ConnectionError, Timeout) as e: - raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}') + raise RuntimeError(f"get xinference model extra parameter failed, url: {url}, error: {e}") if response.status_code != 200: - raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}') + raise RuntimeError( + f"get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}" + ) response_json = response.json() - model_format = response_json.get('model_format', 'ggmlv3') - model_ability = response_json.get('model_ability', []) - model_family = response_json.get('model_family', None) + model_format = response_json.get("model_format", "ggmlv3") + model_ability = response_json.get("model_ability", []) + model_family = response_json.get("model_family", None) - if response_json.get('model_type') == 'embedding': - model_handle_type = 'embedding' - elif response_json.get('model_type') == 'audio': - model_handle_type = 'audio' - if model_family and model_family in ['ChatTTS', 'CosyVoice', 'FishAudio']: - model_ability.append('text-to-audio') + if response_json.get("model_type") == "embedding": + model_handle_type = "embedding" + elif response_json.get("model_type") == "audio": + model_handle_type = "audio" + if model_family and model_family in ["ChatTTS", "CosyVoice", "FishAudio"]: + model_ability.append("text-to-audio") else: - model_ability.append('audio-to-text') - elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']: - model_handle_type = 'chatglm' - elif 'generate' in model_ability: - model_handle_type = 'generate' - elif 'chat' in model_ability: - model_handle_type = 'chat' + model_ability.append("audio-to-text") + elif model_format == "ggmlv3" and "chatglm" in response_json["model_name"]: + model_handle_type = "chatglm" + elif "generate" in model_ability: + model_handle_type = "generate" + elif "chat" in model_ability: + model_handle_type = "chat" else: - raise NotImplementedError('xinference model handle type is not supported') + raise NotImplementedError("xinference model handle type is not supported") - support_function_call = 'tools' in model_ability - support_vision = 'vision' in model_ability - max_tokens = response_json.get('max_tokens', 512) + support_function_call = "tools" in model_ability + support_vision = "vision" in model_ability + max_tokens = response_json.get("max_tokens", 512) - context_length = response_json.get('context_length', 2048) + context_length = response_json.get("context_length", 2048) return XinferenceModelExtraParameter( model_format=model_format, @@ -116,5 +128,5 @@ class XinferenceHelper: support_vision=support_vision, max_tokens=max_tokens, context_length=context_length, - model_family=model_family + model_family=model_family, ) diff --git a/api/core/model_runtime/model_providers/yi/llm/llm.py b/api/core/model_runtime/model_providers/yi/llm/llm.py index d33f38333b..5ab7fd126e 100644 --- a/api/core/model_runtime/model_providers/yi/llm/llm.py +++ b/api/core/model_runtime/model_providers/yi/llm/llm.py @@ -14,11 +14,17 @@ from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguag class YiLargeLanguageModel(OpenAILargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) # yi-vl-plus not support system prompt yet. @@ -27,7 +33,9 @@ class YiLargeLanguageModel(OpenAILargeLanguageModel): for message in prompt_messages: if not isinstance(message, SystemPromptMessage): prompt_message_except_system.append(message) - return super()._invoke(model, credentials, prompt_message_except_system, model_parameters, tools, stop, stream) + return super()._invoke( + model, credentials, prompt_message_except_system, model_parameters, tools, stop, stream + ) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -36,8 +44,7 @@ class YiLargeLanguageModel(OpenAILargeLanguageModel): super().validate_credentials(model, credentials) # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -55,8 +62,9 @@ class YiLargeLanguageModel(OpenAILargeLanguageModel): return num_tokens # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/ @@ -76,10 +84,10 @@ class YiLargeLanguageModel(OpenAILargeLanguageModel): # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -110,10 +118,10 @@ class YiLargeLanguageModel(OpenAILargeLanguageModel): @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['openai_api_key']=credentials['api_key'] - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - credentials['openai_api_base']='https://api.lingyiwanwu.com' + credentials["mode"] = "chat" + credentials["openai_api_key"] = credentials["api_key"] + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + credentials["openai_api_base"] = "https://api.lingyiwanwu.com" else: - parsed_url = urlparse(credentials['endpoint_url']) - credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}" + parsed_url = urlparse(credentials["endpoint_url"]) + credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" diff --git a/api/core/model_runtime/model_providers/yi/yi.py b/api/core/model_runtime/model_providers/yi/yi.py index 691c7aa371..9599acb22b 100644 --- a/api/core/model_runtime/model_providers/yi/yi.py +++ b/api/core/model_runtime/model_providers/yi/yi.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class YiProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -21,12 +20,9 @@ class YiProvider(ModelProvider): # Use `yi-34b-chat-0205` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='yi-34b-chat-0205', - credentials=credentials - ) + model_instance.validate_credentials(model="yi-34b-chat-0205", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/zhinao/llm/llm.py b/api/core/model_runtime/model_providers/zhinao/llm/llm.py index 6930a5ed01..befc3de021 100644 --- a/api/core/model_runtime/model_providers/zhinao/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhinao/llm/llm.py @@ -7,11 +7,17 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI class ZhinaoLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -21,5 +27,5 @@ class ZhinaoLargeLanguageModel(OAIAPICompatLargeLanguageModel): @classmethod def _add_custom_parameters(cls, credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.360.cn/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.360.cn/v1" diff --git a/api/core/model_runtime/model_providers/zhinao/zhinao.py b/api/core/model_runtime/model_providers/zhinao/zhinao.py index 44b36c9f51..2a263292f9 100644 --- a/api/core/model_runtime/model_providers/zhinao/zhinao.py +++ b/api/core/model_runtime/model_providers/zhinao/zhinao.py @@ -8,7 +8,6 @@ logger = logging.getLogger(__name__) class ZhinaoProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -21,12 +20,9 @@ class ZhinaoProvider(ModelProvider): # Use `360gpt-turbo` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='360gpt-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="360gpt-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/zhipuai/_common.py b/api/core/model_runtime/model_providers/zhipuai/_common.py index 3412d8100f..fa95232f71 100644 --- a/api/core/model_runtime/model_providers/zhipuai/_common.py +++ b/api/core/model_runtime/model_providers/zhipuai/_common.py @@ -17,8 +17,7 @@ class _CommonZhipuaiAI: :return: """ credentials_kwargs = { - "api_key": credentials['api_key'] if 'api_key' in credentials else - credentials.get("zhipuai_api_key"), + "api_key": credentials["api_key"] if "api_key" in credentials else credentials.get("zhipuai_api_key"), } return credentials_kwargs @@ -38,5 +37,5 @@ class _CommonZhipuaiAI: InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index b2cdc7ad7a..484ac088db 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -35,12 +35,17 @@ And you should always end the block with a "```" to indicate the end of the JSON class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -62,9 +67,9 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): # self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user) - # def _transform_json_prompts(self, model: str, credentials: dict, - # prompt_messages: list[PromptMessage], model_parameters: dict, - # tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + # def _transform_json_prompts(self, model: str, credentials: dict, + # prompt_messages: list[PromptMessage], model_parameters: dict, + # tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, # stream: bool = True, user: str | None = None) \ # -> None: # """ @@ -94,8 +99,13 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): # content="```JSON\n" # )) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -130,16 +140,22 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): "temperature": 0.5, }, tools=[], - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials_kwargs: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials_kwargs: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -154,15 +170,13 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): """ extra_model_kwargs = {} # request to glm-4v-plus with stop words will always response "finish_reason":"network_error" - if stop and model!= 'glm-4v-plus': - extra_model_kwargs['stop'] = stop + if stop and model != "glm-4v-plus": + extra_model_kwargs["stop"] = stop - client = ZhipuAI( - api_key=credentials_kwargs['api_key'] - ) + client = ZhipuAI(api_key=credentials_kwargs["api_key"]) if len(prompt_messages) == 0: - raise ValueError('At least one message is required') + raise ValueError("At least one message is required") if prompt_messages[0].role == PromptMessageRole.SYSTEM: if not prompt_messages[0].content: @@ -175,10 +189,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]: if isinstance(copy_prompt_message.content, list): # check if model is 'glm-4v' - if model not in ('glm-4v', 'glm-4v-plus'): + if model not in ("glm-4v", "glm-4v-plus"): # not support list message continue - # get image and + # get image and if not isinstance(copy_prompt_message, UserPromptMessage): # not support system message continue @@ -188,8 +202,11 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): # not support image message continue - if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER and \ - copy_prompt_message.role == PromptMessageRole.USER: + if ( + new_prompt_messages + and new_prompt_messages[-1].role == PromptMessageRole.USER + and copy_prompt_message.role == PromptMessageRole.USER + ): new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content else: if copy_prompt_message.role == PromptMessageRole.USER: @@ -208,77 +225,66 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): else: new_prompt_messages.append(copy_prompt_message) - if model == 'glm-4v' or model == 'glm-4v-plus': + if model == "glm-4v" or model == "glm-4v-plus": params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters) else: - params = { - 'model': model, - 'messages': [], - **model_parameters - } + params = {"model": model, "messages": [], **model_parameters} # glm model - if not model.startswith('chatglm'): - + if not model.startswith("chatglm"): for prompt_message in new_prompt_messages: if prompt_message.role == PromptMessageRole.TOOL: - params['messages'].append({ - 'role': 'tool', - 'content': prompt_message.content, - 'tool_call_id': prompt_message.tool_call_id - }) + params["messages"].append( + { + "role": "tool", + "content": prompt_message.content, + "tool_call_id": prompt_message.tool_call_id, + } + ) elif isinstance(prompt_message, AssistantPromptMessage): if prompt_message.tool_calls: - params['messages'].append({ - 'role': 'assistant', - 'content': prompt_message.content, - 'tool_calls': [ - { - 'id': tool_call.id, - 'type': tool_call.type, - 'function': { - 'name': tool_call.function.name, - 'arguments': tool_call.function.arguments + params["messages"].append( + { + "role": "assistant", + "content": prompt_message.content, + "tool_calls": [ + { + "id": tool_call.id, + "type": tool_call.type, + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, } - } for tool_call in prompt_message.tool_calls - ] - }) + for tool_call in prompt_message.tool_calls + ], + } + ) else: - params['messages'].append({ - 'role': 'assistant', - 'content': prompt_message.content - }) + params["messages"].append({"role": "assistant", "content": prompt_message.content}) else: - params['messages'].append({ - 'role': prompt_message.role.value, - 'content': prompt_message.content - }) + params["messages"].append( + {"role": prompt_message.role.value, "content": prompt_message.content} + ) else: # chatglm model for prompt_message in new_prompt_messages: # merge system message to user message - if prompt_message.role == PromptMessageRole.SYSTEM or \ - prompt_message.role == PromptMessageRole.TOOL or \ - prompt_message.role == PromptMessageRole.USER: - if len(params['messages']) > 0 and params['messages'][-1]['role'] == 'user': - params['messages'][-1]['content'] += "\n\n" + prompt_message.content + if ( + prompt_message.role == PromptMessageRole.SYSTEM + or prompt_message.role == PromptMessageRole.TOOL + or prompt_message.role == PromptMessageRole.USER + ): + if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user": + params["messages"][-1]["content"] += "\n\n" + prompt_message.content else: - params['messages'].append({ - 'role': 'user', - 'content': prompt_message.content - }) + params["messages"].append({"role": "user", "content": prompt_message.content}) else: - params['messages'].append({ - 'role': prompt_message.role.value, - 'content': prompt_message.content - }) + params["messages"].append( + {"role": prompt_message.role.value, "content": prompt_message.content} + ) if tools and len(tools) > 0: - params['tools'] = [ - { - 'type': 'function', - 'function': helper.dump_model(tool) - } for tool in tools - ] + params["tools"] = [{"type": "function", "function": helper.dump_model(tool)} for tool in tools] if stream: response = client.chat.completions.create(stream=stream, **params, **extra_model_kwargs) @@ -287,47 +293,41 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): response = client.chat.completions.create(**params, **extra_model_kwargs) return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages) - def _construct_glm_4v_parameter(self, model: str, prompt_messages: list[PromptMessage], - model_parameters: dict): + def _construct_glm_4v_parameter(self, model: str, prompt_messages: list[PromptMessage], model_parameters: dict): messages = [ - { - 'role': message.role.value, - 'content': self._construct_glm_4v_messages(message.content) - } + {"role": message.role.value, "content": self._construct_glm_4v_messages(message.content)} for message in prompt_messages ] - params = { - 'model': model, - 'messages': messages, - **model_parameters - } + params = {"model": model, "messages": messages, **model_parameters} return params def _construct_glm_4v_messages(self, prompt_message: Union[str, list[PromptMessageContent]]) -> list[dict]: if isinstance(prompt_message, str): - return [{'type': 'text', 'text': prompt_message}] + return [{"type": "text", "text": prompt_message}] return [ - {'type': 'image_url', 'image_url': {'url': self._remove_image_header(item.data)}} - if item.type == PromptMessageContentType.IMAGE else - {'type': 'text', 'text': item.data} - + {"type": "image_url", "image_url": {"url": self._remove_image_header(item.data)}} + if item.type == PromptMessageContentType.IMAGE + else {"type": "text", "text": item.data} for item in prompt_message ] def _remove_image_header(self, image: str) -> str: - if image.startswith('data:image'): - return image.split(',')[1] + if image.startswith("data:image"): + return image.split(",")[1] return image - def _handle_generate_response(self, model: str, - credentials: dict, - tools: Optional[list[PromptMessageTool]], - response: Completion, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, + model: str, + credentials: dict, + tools: Optional[list[PromptMessageTool]], + response: Completion, + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm response @@ -336,12 +336,12 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response """ - text = '' + text = "" assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = [] for choice in response.choices: if choice.message.tool_calls: for tool_call in choice.message.tool_calls: - if tool_call.type == 'function': + if tool_call.type == "function": assistant_tool_calls.append( AssistantPromptMessage.ToolCall( id=tool_call.id, @@ -349,11 +349,11 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=tool_call.function.name, arguments=tool_call.function.arguments, - ) + ), ) ) - text += choice.message.content or '' + text += choice.message.content or "" prompt_usage = response.usage.prompt_tokens completion_usage = response.usage.completion_tokens @@ -365,20 +365,20 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): result = LLMResult( model=model, prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=text, - tool_calls=assistant_tool_calls - ), + message=AssistantPromptMessage(content=text, tool_calls=assistant_tool_calls), usage=usage, ) return result - def _handle_generate_stream_response(self, model: str, - credentials: dict, - tools: Optional[list[PromptMessageTool]], - responses: Generator[ChatCompletionChunk, None, None], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + tools: Optional[list[PromptMessageTool]], + responses: Generator[ChatCompletionChunk, None, None], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -387,19 +387,19 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator result """ - full_assistant_content = '' + full_assistant_content = "" for chunk in responses: if len(chunk.choices) == 0: continue delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""): continue assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = [] for tool_call in delta.delta.tool_calls or []: - if tool_call.type == 'function': + if tool_call.type == "function": assistant_tool_calls.append( AssistantPromptMessage.ToolCall( id=tool_call.id, @@ -407,17 +407,16 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=tool_call.function.name, arguments=tool_call.function.arguments, - ) + ), ) ) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_tool_calls + content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_tool_calls ) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content if delta.delta.content else "" if delta.finish_reason is not None and chunk.usage is not None: completion_tokens = chunk.usage.completion_tokens @@ -429,24 +428,22 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): yield LLMResultChunk( model=chunk.model, prompt_messages=prompt_messages, - system_fingerprint='', + system_fingerprint="", delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage - ) + usage=usage, + ), ) else: yield LLMResultChunk( model=chunk.model, prompt_messages=prompt_messages, - system_fingerprint='', + system_fingerprint="", delta=LLMResultChunkDelta( - index=delta.index, - message=assistant_prompt_message, - finish_reason=delta.finish_reason - ) + index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason + ), ) def _convert_one_message_to_text(self, message: PromptMessage) -> str: @@ -473,18 +470,16 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): return message_text - def _convert_messages_to_prompt(self, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> str: + def _convert_messages_to_prompt( + self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> str: """ :param messages: List of PromptMessage to combine. :return: Combined string with necessary human_prompt and ai_prompt tags. """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) if tools and len(tools) > 0: text += "\n\nTools:" diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index 0f9fecfc72..ee20954381 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -14,9 +14,9 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): Model class for ZhipuAI text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -27,16 +27,14 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): :return: embeddings result """ credentials_kwargs = self._to_credential_kwargs(credentials) - client = ZhipuAI( - api_key=credentials_kwargs['api_key'] - ) + client = ZhipuAI(api_key=credentials_kwargs["api_key"]) embeddings, embedding_used_tokens = self.embed_documents(model, client, texts) return TextEmbeddingResult( embeddings=embeddings, usage=self._calc_response_usage(model, credentials_kwargs, embedding_used_tokens), - model=model + model=model, ) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: @@ -50,7 +48,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): """ if len(texts) == 0: return 0 - + total_num_tokens = 0 for text in texts: total_num_tokens += self._get_num_tokens_by_gpt2(text) @@ -68,15 +66,13 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): try: # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) - client = ZhipuAI( - api_key=credentials_kwargs['api_key'] - ) + client = ZhipuAI(api_key=credentials_kwargs["api_key"]) # call embedding model self.embed_documents( model=model, client=client, - texts=['ping'], + texts=["ping"], ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -100,7 +96,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): embedding_used_tokens += response.usage.total_tokens return [list(map(float, e)) for e in embeddings], embedding_used_tokens - + def embed_query(self, text: str) -> list[float]: """Call out to ZhipuAI's embedding endpoint. @@ -111,8 +107,8 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): Embeddings for the text. """ return self.embed_documents([text])[0] - - def _calc_response_usage(self, model: str,credentials: dict, tokens: int) -> EmbeddingUsage: + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -122,10 +118,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -136,7 +129,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai.py index c517d2dba5..e75aad6eb0 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai.py @@ -19,12 +19,9 @@ class ZhipuaiProvider(ModelProvider): try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='glm-4', - credentials=credentials - ) + model_instance.validate_credentials(model="glm-4", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py index 4dcd03f551..bf9b093cb3 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py @@ -1,4 +1,3 @@ - from .__version__ import __version__ from ._client import ZhipuAI from .core._errors import ( diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py index eb0ad332ca..659f38d7ff 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py @@ -1,2 +1 @@ - -__version__ = 'v2.0.1' \ No newline at end of file +__version__ = "v2.0.1" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py index 6588d1dd68..df9e506095 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py @@ -20,14 +20,14 @@ class ZhipuAI(HttpClient): api_key: str def __init__( - self, - *, - api_key: str | None = None, - base_url: str | httpx.URL | None = None, - timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, - max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES, - http_client: httpx.Client | None = None, - custom_headers: Mapping[str, str] | None = None + self, + *, + api_key: str | None = None, + base_url: str | httpx.URL | None = None, + timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, + max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES, + http_client: httpx.Client | None = None, + custom_headers: Mapping[str, str] | None = None, ) -> None: if api_key is None: raise ZhipuAIError("No api_key provided, please provide it through parameters or environment variables") @@ -38,6 +38,7 @@ class ZhipuAI(HttpClient): if base_url is None: base_url = "https://open.bigmodel.cn/api/paas/v4" from .__version__ import __version__ + super().__init__( version=__version__, base_url=base_url, @@ -58,9 +59,7 @@ class ZhipuAI(HttpClient): return {"Authorization": f"{_jwt_token.generate_token(api_key)}"} def __del__(self) -> None: - if (not hasattr(self, "_has_custom_http_client") - or not hasattr(self, "close") - or not hasattr(self, "_client")): + if not hasattr(self, "_has_custom_http_client") or not hasattr(self, "close") or not hasattr(self, "_client"): # if the '__init__' method raised an error, self would not have client attr return diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py index dab6dac5fe..1f80119739 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py @@ -17,25 +17,24 @@ class AsyncCompletions(BaseAPI): def __init__(self, client: ZhipuAI) -> None: super().__init__(client) - def create( - self, - *, - model: str, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - temperature: Optional[float] | NotGiven = NOT_GIVEN, - top_p: Optional[float] | NotGiven = NOT_GIVEN, - max_tokens: int | NotGiven = NOT_GIVEN, - seed: int | NotGiven = NOT_GIVEN, - messages: Union[str, list[str], list[int], list[list[int]], None], - stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, - tools: Optional[object] | NotGiven = NOT_GIVEN, - tool_choice: str | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + model: str, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + messages: Union[str, list[str], list[int], list[list[int]], None], + stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, + sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, + tools: Optional[object] | NotGiven = NOT_GIVEN, + tool_choice: str | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> AsyncTaskStatus: _cast_type = AsyncTaskStatus @@ -57,9 +56,7 @@ class AsyncCompletions(BaseAPI): "tools": tools, "tool_choice": tool_choice, }, - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), cast_type=_cast_type, enable_stream=False, ) @@ -71,16 +68,11 @@ class AsyncCompletions(BaseAPI): disable_strict_validation: Optional[bool] | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> Union[AsyncCompletion, AsyncTaskStatus]: - _cast_type = Union[AsyncCompletion,AsyncTaskStatus] + _cast_type = Union[AsyncCompletion, AsyncTaskStatus] if disable_strict_validation: _cast_type = object return self._get( path=f"/async-result/{id}", cast_type=_cast_type, - options=make_user_request_input( - extra_headers=extra_headers, - timeout=timeout - ) + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), ) - - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py index 5c4ed4d1ba..ec29f33864 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py @@ -20,24 +20,24 @@ class Completions(BaseAPI): super().__init__(client) def create( - self, - *, - model: str, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - temperature: Optional[float] | NotGiven = NOT_GIVEN, - top_p: Optional[float] | NotGiven = NOT_GIVEN, - max_tokens: int | NotGiven = NOT_GIVEN, - seed: int | NotGiven = NOT_GIVEN, - messages: Union[str, list[str], list[int], object, None], - stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, - tools: Optional[object] | NotGiven = NOT_GIVEN, - tool_choice: str | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + model: str, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + messages: Union[str, list[str], list[int], object, None], + stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, + sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, + tools: Optional[object] | NotGiven = NOT_GIVEN, + tool_choice: str | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> Completion | StreamResponse[ChatCompletionChunk]: _cast_type = Completion _stream_cls = StreamResponse[ChatCompletionChunk] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py index 35d54592fd..2308a20451 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py @@ -18,16 +18,16 @@ class Embeddings(BaseAPI): super().__init__(client) def create( - self, - *, - input: Union[str, list[str], list[int], list[list[int]]], - model: Union[str], - encoding_format: str | NotGiven = NOT_GIVEN, - user: str | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + input: Union[str, list[str], list[int], list[list[int]]], + model: Union[str], + encoding_format: str | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> EmbeddingsResponded: _cast_type = EmbeddingsResponded if disable_strict_validation: @@ -41,9 +41,7 @@ class Embeddings(BaseAPI): "user": user, "sensitive_word_check": sensitive_word_check, }, - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), cast_type=_cast_type, enable_stream=False, ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py index 5deb8d08f3..f2ac74bffa 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py @@ -17,17 +17,16 @@ __all__ = ["Files"] class Files(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( - self, - *, - file: FileTypes, - purpose: str, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + file: FileTypes, + purpose: str, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FileObject: if not is_file_content(file): prefix = f"Expected file input `{file!r}`" @@ -44,21 +43,19 @@ class Files(BaseAPI): "purpose": purpose, }, files=files, - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), cast_type=FileObject, ) def list( - self, - *, - purpose: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - after: str | NotGiven = NOT_GIVEN, - order: str | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + purpose: str | NotGiven = NOT_GIVEN, + limit: int | NotGiven = NOT_GIVEN, + after: str | NotGiven = NOT_GIVEN, + order: str | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ListOfFileObject: return self._get( "/files", diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py index dc54a9ca45..dc30bd33ed 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py @@ -13,4 +13,3 @@ class FineTuning(BaseAPI): def __init__(self, client: "ZhipuAI") -> None: super().__init__(client) self.jobs = Jobs(client) - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py index b860de192a..3d2e9208a1 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py @@ -16,21 +16,20 @@ __all__ = ["Jobs"] class Jobs(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( - self, - *, - model: str, - training_file: str, - hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN, - suffix: Optional[str] | NotGiven = NOT_GIVEN, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - validation_file: Optional[str] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + model: str, + training_file: str, + hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN, + suffix: Optional[str] | NotGiven = NOT_GIVEN, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + validation_file: Optional[str] | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FineTuningJob: return self._post( "/fine_tuning/jobs", @@ -42,34 +41,30 @@ class Jobs(BaseAPI): "validation_file": validation_file, "request_id": request_id, }, - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), cast_type=FineTuningJob, ) def retrieve( - self, - fine_tuning_job_id: str, - *, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + fine_tuning_job_id: str, + *, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FineTuningJob: return self._get( f"/fine_tuning/jobs/{fine_tuning_job_id}", - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), cast_type=FineTuningJob, ) def list( - self, - *, - after: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + after: str | NotGiven = NOT_GIVEN, + limit: int | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ListOfFineTuningJob: return self._get( "/fine_tuning/jobs", @@ -93,7 +88,6 @@ class Jobs(BaseAPI): extra_headers: Headers | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FineTuningJobEvent: - return self._get( f"/fine_tuning/jobs/{fine_tuning_job_id}/events", cast_type=FineTuningJobEvent, diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py index 8eae1216d0..2692b093af 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py @@ -18,21 +18,21 @@ class Images(BaseAPI): super().__init__(client) def generations( - self, - *, - prompt: str, - model: str | NotGiven = NOT_GIVEN, - n: Optional[int] | NotGiven = NOT_GIVEN, - quality: Optional[str] | NotGiven = NOT_GIVEN, - response_format: Optional[str] | NotGiven = NOT_GIVEN, - size: Optional[str] | NotGiven = NOT_GIVEN, - style: Optional[str] | NotGiven = NOT_GIVEN, - user: str | NotGiven = NOT_GIVEN, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + prompt: str, + model: str | NotGiven = NOT_GIVEN, + n: Optional[int] | NotGiven = NOT_GIVEN, + quality: Optional[str] | NotGiven = NOT_GIVEN, + response_format: Optional[str] | NotGiven = NOT_GIVEN, + size: Optional[str] | NotGiven = NOT_GIVEN, + style: Optional[str] | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ImagesResponded: _cast_type = ImagesResponded if disable_strict_validation: @@ -50,11 +50,7 @@ class Images(BaseAPI): "user": user, "request_id": request_id, }, - options=make_user_request_input( - extra_headers=extra_headers, - extra_body=extra_body, - timeout=timeout - ), + options=make_user_request_input(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), cast_type=_cast_type, enable_stream=False, ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py index a2a438b8f3..1027c1bc5b 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py @@ -17,7 +17,10 @@ __all__ = [ class ZhipuAIError(Exception): - def __init__(self, message: str, ) -> None: + def __init__( + self, + message: str, + ) -> None: super().__init__(message) @@ -31,24 +34,19 @@ class APIStatusError(Exception): self.status_code = response.status_code -class APIRequestFailedError(APIStatusError): - ... +class APIRequestFailedError(APIStatusError): ... -class APIAuthenticationError(APIStatusError): - ... +class APIAuthenticationError(APIStatusError): ... -class APIReachLimitError(APIStatusError): - ... +class APIReachLimitError(APIStatusError): ... -class APIInternalError(APIStatusError): - ... +class APIInternalError(APIStatusError): ... -class APIServerFlowExceedError(APIStatusError): - ... +class APIServerFlowExceedError(APIStatusError): ... class APIResponseError(Exception): @@ -67,16 +65,11 @@ class APIResponseValidationError(APIResponseError): status_code: int response: httpx.Response - def __init__( - self, - response: httpx.Response, - json_data: object | None, *, - message: str | None = None - ) -> None: + def __init__(self, response: httpx.Response, json_data: object | None, *, message: str | None = None) -> None: super().__init__( message=message or "Data returned by API invalid for expected schema.", request=response.request, - json_data=json_data + json_data=json_data, ) self.response = response self.status_code = response.status_code diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py index 65401f6c1c..48eeb37c41 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py @@ -48,13 +48,13 @@ class HttpClient: _default_stream_cls: type[StreamResponse[Any]] | None = None def __init__( - self, - *, - version: str, - base_url: URL, - timeout: Union[float, Timeout, None], - custom_httpx_client: httpx.Client | None = None, - custom_headers: Mapping[str, str] | None = None, + self, + *, + version: str, + base_url: URL, + timeout: Union[float, Timeout, None], + custom_httpx_client: httpx.Client | None = None, + custom_headers: Mapping[str, str] | None = None, ) -> None: if timeout is None or isinstance(timeout, NotGiven): if custom_httpx_client and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT: @@ -76,7 +76,6 @@ class HttpClient: self._custom_headers = custom_headers or {} def _prepare_url(self, url: str) -> URL: - sub_url = URL(url) if sub_url.is_relative_url: request_raw_url = self._base_url.raw_path + sub_url.raw_path.lstrip(b"/") @@ -86,16 +85,15 @@ class HttpClient: @property def _default_headers(self): - return \ - { - "Accept": "application/json", - "Content-Type": "application/json; charset=UTF-8", - "ZhipuAI-SDK-Ver": self._version, - "source_type": "zhipu-sdk-python", - "x-request-sdk": "zhipu-sdk-python", - **self._auth_headers, - **self._custom_headers, - } + return { + "Accept": "application/json", + "Content-Type": "application/json; charset=UTF-8", + "ZhipuAI-SDK-Ver": self._version, + "source_type": "zhipu-sdk-python", + "x-request-sdk": "zhipu-sdk-python", + **self._auth_headers, + **self._custom_headers, + } @property def _auth_headers(self): @@ -109,10 +107,7 @@ class HttpClient: return httpx_headers - def _prepare_request( - self, - request_param: ClientRequestParam - ) -> httpx.Request: + def _prepare_request(self, request_param: ClientRequestParam) -> httpx.Request: kwargs: dict[str, Any] = {} json_data = request_param.json_data headers = self._prepare_headers(request_param) @@ -164,7 +159,6 @@ class HttpClient: return [(key, str_data)] def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]: - items = flatten([self._object_to_formdata(k, v) for k, v in data.items()]) serialized: dict[str, object] = {} @@ -175,30 +169,25 @@ class HttpClient: return serialized def _parse_response( - self, - *, - cast_type: type[ResponseT], - response: httpx.Response, - enable_stream: bool, - request_param: ClientRequestParam, - stream_cls: type[StreamResponse[Any]] | None = None, + self, + *, + cast_type: type[ResponseT], + response: httpx.Response, + enable_stream: bool, + request_param: ClientRequestParam, + stream_cls: type[StreamResponse[Any]] | None = None, ) -> HttpResponse: - http_response = HttpResponse( - raw_response=response, - cast_type=cast_type, - client=self, - enable_stream=enable_stream, - stream_cls=stream_cls + raw_response=response, cast_type=cast_type, client=self, enable_stream=enable_stream, stream_cls=stream_cls ) return http_response.parse() def _process_response_data( - self, - *, - data: object, - cast_type: type[ResponseT], - response: httpx.Response, + self, + *, + data: object, + cast_type: type[ResponseT], + response: httpx.Response, ) -> ResponseT: if data is None: return cast(ResponseT, None) @@ -225,12 +214,12 @@ class HttpClient: @retry(stop=stop_after_attempt(ZHIPUAI_DEFAULT_MAX_RETRIES)) def request( - self, - *, - cast_type: type[ResponseT], - params: ClientRequestParam, - enable_stream: bool = False, - stream_cls: type[StreamResponse[Any]] | None = None, + self, + *, + cast_type: type[ResponseT], + params: ClientRequestParam, + enable_stream: bool = False, + stream_cls: type[StreamResponse[Any]] | None = None, ) -> ResponseT | StreamResponse: request = self._prepare_request(params) @@ -259,81 +248,79 @@ class HttpClient: ) def get( - self, - path: str, - *, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - enable_stream: bool = False, + self, + path: str, + *, + cast_type: type[ResponseT], + options: UserRequestInput = {}, + enable_stream: bool = False, ) -> ResponseT | StreamResponse: opts = ClientRequestParam.construct(method="get", url=path, **options) - return self.request( - cast_type=cast_type, params=opts, - enable_stream=enable_stream - ) + return self.request(cast_type=cast_type, params=opts, enable_stream=enable_stream) def post( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - files: RequestFiles | None = None, - enable_stream: bool = False, - stream_cls: type[StreamResponse[Any]] | None = None, + self, + path: str, + *, + body: Body | None = None, + cast_type: type[ResponseT], + options: UserRequestInput = {}, + files: RequestFiles | None = None, + enable_stream: bool = False, + stream_cls: type[StreamResponse[Any]] | None = None, ) -> ResponseT | StreamResponse: - opts = ClientRequestParam.construct(method="post", json_data=body, files=make_httpx_files(files), url=path, - **options) - - return self.request( - cast_type=cast_type, params=opts, - enable_stream=enable_stream, - stream_cls=stream_cls + opts = ClientRequestParam.construct( + method="post", json_data=body, files=make_httpx_files(files), url=path, **options ) + return self.request(cast_type=cast_type, params=opts, enable_stream=enable_stream, stream_cls=stream_cls) + def patch( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, + self, + path: str, + *, + body: Body | None = None, + cast_type: type[ResponseT], + options: UserRequestInput = {}, ) -> ResponseT: opts = ClientRequestParam.construct(method="patch", url=path, json_data=body, **options) return self.request( - cast_type=cast_type, params=opts, + cast_type=cast_type, + params=opts, ) def put( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - files: RequestFiles | None = None, + self, + path: str, + *, + body: Body | None = None, + cast_type: type[ResponseT], + options: UserRequestInput = {}, + files: RequestFiles | None = None, ) -> ResponseT | StreamResponse: - opts = ClientRequestParam.construct(method="put", url=path, json_data=body, files=make_httpx_files(files), - **options) + opts = ClientRequestParam.construct( + method="put", url=path, json_data=body, files=make_httpx_files(files), **options + ) return self.request( - cast_type=cast_type, params=opts, + cast_type=cast_type, + params=opts, ) def delete( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, + self, + path: str, + *, + body: Body | None = None, + cast_type: type[ResponseT], + options: UserRequestInput = {}, ) -> ResponseT | StreamResponse: opts = ClientRequestParam.construct(method="delete", url=path, json_data=body, **options) return self.request( - cast_type=cast_type, params=opts, + cast_type=cast_type, + params=opts, ) def _make_status_error(self, response) -> APIStatusError: @@ -355,11 +342,11 @@ class HttpClient: def make_user_request_input( - max_retries: int | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, - extra_headers: Headers = None, - extra_body: Body | None = None, - query: Query | None = None, + max_retries: int | None = None, + timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + extra_headers: Headers = None, + extra_body: Body | None = None, + query: Query | None = None, ) -> UserRequestInput: options: UserRequestInput = {} @@ -368,7 +355,7 @@ def make_user_request_input( if max_retries is not None: options["max_retries"] = max_retries if not isinstance(timeout, NotGiven): - options['timeout'] = timeout + options["timeout"] = timeout if query is not None: options["params"] = query if extra_body is not None: diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py index a3f49ba846..ac459151fc 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py @@ -35,17 +35,14 @@ class ClientRequestParam: @classmethod def construct( # type: ignore - cls, - _fields_set: set[str] | None = None, - **values: Unpack[UserRequestInput], - ) -> ClientRequestParam : - kwargs: dict[str, Any] = { - key: remove_notgiven_indict(value) for key, value in values.items() - } + cls, + _fields_set: set[str] | None = None, + **values: Unpack[UserRequestInput], + ) -> ClientRequestParam: + kwargs: dict[str, Any] = {key: remove_notgiven_indict(value) for key, value in values.items()} client = cls() client.__dict__.update(kwargs) return client model_construct = construct - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py index 2f831b6fc9..56e60a7934 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py @@ -26,13 +26,13 @@ class HttpResponse(Generic[R]): http_response: httpx.Response def __init__( - self, - *, - raw_response: httpx.Response, - cast_type: type[R], - client: HttpClient, - enable_stream: bool = False, - stream_cls: type[StreamResponse[Any]] | None = None, + self, + *, + raw_response: httpx.Response, + cast_type: type[R], + client: HttpClient, + enable_stream: bool = False, + stream_cls: type[StreamResponse[Any]] | None = None, ) -> None: self._cast_type = cast_type self._client = client @@ -52,8 +52,8 @@ class HttpResponse(Generic[R]): self._stream_cls( cast_type=cast(type, get_args(self._stream_cls)[0]), response=self.http_response, - client=self._client - ) + client=self._client, + ), ) return self._parsed cast_type = self._cast_type diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py index 66afbfd107..3566c6b332 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py @@ -16,16 +16,15 @@ if TYPE_CHECKING: class StreamResponse(Generic[ResponseT]): - response: httpx.Response _cast_type: type[ResponseT] def __init__( - self, - *, - cast_type: type[ResponseT], - response: httpx.Response, - client: HttpClient, + self, + *, + cast_type: type[ResponseT], + response: httpx.Response, + client: HttpClient, ) -> None: self.response = response self._cast_type = cast_type @@ -39,7 +38,6 @@ class StreamResponse(Generic[ResponseT]): yield from self._stream_chunks def __stream__(self) -> Iterator[ResponseT]: - sse_line_parser = SSELineParser() iterator = sse_line_parser.iter_lines(self.response.iter_lines()) @@ -63,11 +61,7 @@ class StreamResponse(Generic[ResponseT]): class Event: def __init__( - self, - event: str | None = None, - data: str | None = None, - id: str | None = None, - retry: int | None = None + self, event: str | None = None, data: str | None = None, id: str | None = None, retry: int | None = None ): self._event = event self._data = data @@ -76,21 +70,28 @@ class Event: def __repr__(self): data_len = len(self._data) if self._data else 0 - return f"Event(event={self._event}, data={self._data} ,data_length={data_len}, id={self._id}, retry={self._retry}" + return ( + f"Event(event={self._event}, data={self._data} ,data_length={data_len}, id={self._id}, retry={self._retry}" + ) @property - def event(self): return self._event + def event(self): + return self._event @property - def data(self): return self._data + def data(self): + return self._data - def json_data(self): return json.loads(self._data) + def json_data(self): + return json.loads(self._data) @property - def id(self): return self._id + def id(self): + return self._id @property - def retry(self): return self._retry + def retry(self): + return self._retry class SSELineParser: @@ -107,19 +108,11 @@ class SSELineParser: def iter_lines(self, lines: Iterator[str]) -> Iterator[Event]: for line in lines: - line = line.rstrip('\n') + line = line.rstrip("\n") if not line: - if self._event is None and \ - not self._data and \ - self._id is None and \ - self._retry is None: + if self._event is None and not self._data and self._id is None and self._retry is None: continue - sse_event = Event( - event=self._event, - data='\n'.join(self._data), - id=self._id, - retry=self._retry - ) + sse_event = Event(event=self._event, data="\n".join(self._data), id=self._id, retry=self._retry) self._event = None self._data = [] self._id = None @@ -134,7 +127,7 @@ class SSELineParser: field, _p, value = line.partition(":") - if value.startswith(' '): + if value.startswith(" "): value = value[1:] if field == "data": self._data.append(value) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py index f22f32d251..a0645b0916 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py @@ -20,4 +20,4 @@ class AsyncCompletion(BaseModel): model: Optional[str] = None task_status: str choices: list[CompletionChoice] - usage: CompletionUsage \ No newline at end of file + usage: CompletionUsage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py index b2a847c50c..4b3a929a2b 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py @@ -41,5 +41,3 @@ class Completion(BaseModel): request_id: Optional[str] = None id: Optional[str] = None usage: CompletionUsage - - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py index 917bda7576..75f76fe969 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py @@ -6,7 +6,6 @@ __all__ = ["FileObject"] class FileObject(BaseModel): - id: Optional[str] = None bytes: Optional[int] = None created_at: Optional[int] = None @@ -18,7 +17,6 @@ class FileObject(BaseModel): class ListOfFileObject(BaseModel): - object: Optional[str] = None data: list[FileObject] has_more: Optional[bool] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py index 71c00eaff0..1d3930286b 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py @@ -2,7 +2,7 @@ from typing import Optional, Union from pydantic import BaseModel -__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob" ] +__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob"] class Error(BaseModel): diff --git a/api/core/model_runtime/schema_validators/common_validator.py b/api/core/model_runtime/schema_validators/common_validator.py index fe705d6943..e4f3541475 100644 --- a/api/core/model_runtime/schema_validators/common_validator.py +++ b/api/core/model_runtime/schema_validators/common_validator.py @@ -4,9 +4,9 @@ from core.model_runtime.entities.provider_entities import CredentialFormSchema, class CommonValidator: - def _validate_and_filter_credential_form_schemas(self, - credential_form_schemas: list[CredentialFormSchema], - credentials: dict) -> dict: + def _validate_and_filter_credential_form_schemas( + self, credential_form_schemas: list[CredentialFormSchema], credentials: dict + ) -> dict: need_validate_credential_form_schema_map = {} for credential_form_schema in credential_form_schemas: if not credential_form_schema.show_on: @@ -36,8 +36,9 @@ class CommonValidator: return validated_credentials - def _validate_credential_form_schema(self, credential_form_schema: CredentialFormSchema, credentials: dict) \ - -> Optional[str]: + def _validate_credential_form_schema( + self, credential_form_schema: CredentialFormSchema, credentials: dict + ) -> Optional[str]: """ Validate credential form schema @@ -49,7 +50,7 @@ class CommonValidator: if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]: # If required is True, an exception is thrown if credential_form_schema.required: - raise ValueError(f'Variable {credential_form_schema.variable} is required') + raise ValueError(f"Variable {credential_form_schema.variable} is required") else: # Get the value of default if credential_form_schema.default: @@ -65,23 +66,25 @@ class CommonValidator: # If max_length=0, no validation is performed if credential_form_schema.max_length: if len(value) > credential_form_schema.max_length: - raise ValueError(f'Variable {credential_form_schema.variable} length should not greater than {credential_form_schema.max_length}') + raise ValueError( + f"Variable {credential_form_schema.variable} length should not greater than {credential_form_schema.max_length}" + ) # check the type of value if not isinstance(value, str): - raise ValueError(f'Variable {credential_form_schema.variable} should be string') + raise ValueError(f"Variable {credential_form_schema.variable} should be string") if credential_form_schema.type in [FormType.SELECT, FormType.RADIO]: # If the value is in options, no validation is performed if credential_form_schema.options: if value not in [option.value for option in credential_form_schema.options]: - raise ValueError(f'Variable {credential_form_schema.variable} is not in options') + raise ValueError(f"Variable {credential_form_schema.variable} is not in options") if credential_form_schema.type == FormType.SWITCH: # If the value is not in ['true', 'false'], an exception is thrown - if value.lower() not in ['true', 'false']: - raise ValueError(f'Variable {credential_form_schema.variable} should be true or false') + if value.lower() not in ["true", "false"]: + raise ValueError(f"Variable {credential_form_schema.variable} should be true or false") - value = True if value.lower() == 'true' else False + value = True if value.lower() == "true" else False return value diff --git a/api/core/model_runtime/schema_validators/model_credential_schema_validator.py b/api/core/model_runtime/schema_validators/model_credential_schema_validator.py index c4786fad5d..7d1644d134 100644 --- a/api/core/model_runtime/schema_validators/model_credential_schema_validator.py +++ b/api/core/model_runtime/schema_validators/model_credential_schema_validator.py @@ -4,7 +4,6 @@ from core.model_runtime.schema_validators.common_validator import CommonValidato class ModelCredentialSchemaValidator(CommonValidator): - def __init__(self, model_type: ModelType, model_credential_schema: ModelCredentialSchema): self.model_type = model_type self.model_credential_schema = model_credential_schema diff --git a/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py b/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py index c945016534..6dff2428ca 100644 --- a/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py +++ b/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py @@ -3,7 +3,6 @@ from core.model_runtime.schema_validators.common_validator import CommonValidato class ProviderCredentialSchemaValidator(CommonValidator): - def __init__(self, provider_credential_schema: ProviderCredentialSchema): self.provider_credential_schema = provider_credential_schema diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index 5078f00bfa..ec1bad5698 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -18,11 +18,10 @@ from pydantic_core import Url from pydantic_extra_types.color import Color -def _model_dump( - model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any -) -> Any: +def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: return model.model_dump(mode=mode, **kwargs) + # Taken from Pydantic v1 as is def isoformat(o: Union[datetime.date, datetime.time]) -> str: return o.isoformat() @@ -82,11 +81,9 @@ ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = { def generate_encoders_by_class_tuples( - type_encoder_map: dict[Any, Callable[[Any], Any]] + type_encoder_map: dict[Any, Callable[[Any], Any]], ) -> dict[Callable[[Any], Any], tuple[Any, ...]]: - encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict( - tuple - ) + encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(tuple) for type_, encoder in type_encoder_map.items(): encoders_by_class_tuples[encoder] += (type_,) return encoders_by_class_tuples @@ -149,17 +146,13 @@ def jsonable_encoder( if isinstance(obj, str | int | float | type(None)): return obj if isinstance(obj, Decimal): - return format(obj, 'f') + return format(obj, "f") if isinstance(obj, dict): encoded_dict = {} allowed_keys = set(obj.keys()) for key, value in obj.items(): if ( - ( - not sqlalchemy_safe - or (not isinstance(key, str)) - or (not key.startswith("_sa")) - ) + (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) and (value is not None or not exclude_none) and key in allowed_keys ): diff --git a/api/core/model_runtime/utils/helper.py b/api/core/model_runtime/utils/helper.py index c68a554471..2067092d80 100644 --- a/api/core/model_runtime/utils/helper.py +++ b/api/core/model_runtime/utils/helper.py @@ -3,7 +3,7 @@ from pydantic import BaseModel def dump_model(model: BaseModel) -> dict: - if hasattr(pydantic, 'model_dump'): + if hasattr(pydantic, "model_dump"): return pydantic.model_dump(model) else: return model.model_dump() diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index f96e2a1c21..094ad78636 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -44,32 +44,29 @@ class ApiModeration(Moderation): flagged = False preset_response = "" - if self.config['inputs_config']['enabled']: - params = ModerationInputParams( - app_id=self.app_id, - inputs=inputs, - query=query - ) + if self.config["inputs_config"]["enabled"]: + params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query) result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump()) return ModerationInputsResult(**result) - return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationInputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" - if self.config['outputs_config']['enabled']: - params = ModerationOutputParams( - app_id=self.app_id, - text=text - ) + if self.config["outputs_config"]["enabled"]: + params = ModerationOutputParams(app_id=self.app_id, text=text) result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump()) return ModerationOutputsResult(**result) - return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationOutputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict: extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id")) @@ -80,9 +77,10 @@ class ApiModeration(Moderation): @staticmethod def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: - extension = db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + extension = ( + db.session.query(APIBasedExtension) + .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .first() + ) return extension diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 757dd2ab46..4b91f20184 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -8,8 +8,8 @@ from core.extension.extensible import Extensible, ExtensionModule class ModerationAction(Enum): - DIRECT_OUTPUT = 'direct_output' - OVERRIDDEN = 'overridden' + DIRECT_OUTPUT = "direct_output" + OVERRIDDEN = "overridden" class ModerationInputsResult(BaseModel): @@ -31,6 +31,7 @@ class Moderation(Extensible, ABC): """ The base class of moderation. """ + module: ExtensionModule = ExtensionModule.MODERATION def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None: diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py index 46dfacbc9e..336c16eecf 100644 --- a/api/core/moderation/input_moderation.py +++ b/api/core/moderation/input_moderation.py @@ -13,13 +13,14 @@ logger = logging.getLogger(__name__) class InputModeration: def check( - self, app_id: str, + self, + app_id: str, tenant_id: str, app_config: AppConfig, inputs: dict, query: str, message_id: str, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> tuple[bool, dict, str]: """ Process sensitive_word_avoidance. @@ -39,10 +40,7 @@ class InputModeration: moderation_type = sensitive_word_avoidance_config.type moderation_factory = ModerationFactory( - name=moderation_type, - app_id=app_id, - tenant_id=tenant_id, - config=sensitive_word_avoidance_config.config + name=moderation_type, app_id=app_id, tenant_id=tenant_id, config=sensitive_word_avoidance_config.config ) with measure_time() as timer: @@ -55,7 +53,7 @@ class InputModeration: message_id=message_id, moderation_result=moderation_result, inputs=inputs, - timer=timer + timer=timer, ) ) diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index ca562ad987..17e48b8fbe 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -25,31 +25,35 @@ class KeywordsModeration(Moderation): flagged = False preset_response = "" - if self.config['inputs_config']['enabled']: - preset_response = self.config['inputs_config']['preset_response'] + if self.config["inputs_config"]["enabled"]: + preset_response = self.config["inputs_config"]["preset_response"] if query: - inputs['query__'] = query + inputs["query__"] = query # Filter out empty values - keywords_list = [keyword for keyword in self.config['keywords'].split('\n') if keyword] + keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword] flagged = self._is_violated(inputs, keywords_list) - return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationInputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" - if self.config['outputs_config']['enabled']: + if self.config["outputs_config"]["enabled"]: # Filter out empty values - keywords_list = [keyword for keyword in self.config['keywords'].split('\n') if keyword] + keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword] - flagged = self._is_violated({'text': text}, keywords_list) - preset_response = self.config['outputs_config']['preset_response'] + flagged = self._is_violated({"text": text}, keywords_list) + preset_response = self.config["outputs_config"]["preset_response"] - return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationOutputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def _is_violated(self, inputs: dict, keywords_list: list) -> bool: for value in inputs.values(): diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index fee51007eb..6465de23b9 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -21,37 +21,36 @@ class OpenAIModeration(Moderation): flagged = False preset_response = "" - if self.config['inputs_config']['enabled']: - preset_response = self.config['inputs_config']['preset_response'] + if self.config["inputs_config"]["enabled"]: + preset_response = self.config["inputs_config"]["preset_response"] if query: - inputs['query__'] = query + inputs["query__"] = query flagged = self._is_violated(inputs) - return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationInputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" - if self.config['outputs_config']['enabled']: - flagged = self._is_violated({'text': text}) - preset_response = self.config['outputs_config']['preset_response'] + if self.config["outputs_config"]["enabled"]: + flagged = self._is_violated({"text": text}) + preset_response = self.config["outputs_config"]["preset_response"] - return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationOutputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def _is_violated(self, inputs: dict): - text = '\n'.join(str(inputs.values())) + text = "\n".join(str(inputs.values())) model_manager = ModelManager() model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, - provider="openai", - model_type=ModelType.MODERATION, - model="text-moderation-stable" + tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="text-moderation-stable" ) - openai_moderation = model_instance.invoke_moderation( - text=text - ) + openai_moderation = model_instance.invoke_moderation(text=text) return openai_moderation diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index 69e28770c3..d8d794be18 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -29,7 +29,7 @@ class OutputModeration(BaseModel): thread: Optional[threading.Thread] = None thread_running: bool = True - buffer: str = '' + buffer: str = "" is_final_chunk: bool = False final_output: Optional[str] = None model_config = ConfigDict(arbitrary_types_allowed=True) @@ -50,11 +50,7 @@ class OutputModeration(BaseModel): self.buffer = completion self.is_final_chunk = True - result = self.moderation( - tenant_id=self.tenant_id, - app_id=self.app_id, - moderation_buffer=completion - ) + result = self.moderation(tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=completion) if not result or not result.flagged: return completion @@ -65,21 +61,19 @@ class OutputModeration(BaseModel): final_output = result.text if public_event: - self.queue_manager.publish( - QueueMessageReplaceEvent( - text=final_output - ), - PublishFrom.TASK_PIPELINE - ) + self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE) return final_output def start_thread(self) -> threading.Thread: buffer_size = dify_config.MODERATION_BUFFER_SIZE - thread = threading.Thread(target=self.worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'buffer_size': buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE - }) + thread = threading.Thread( + target=self.worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "buffer_size": buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE, + }, + ) thread.start() @@ -104,9 +98,7 @@ class OutputModeration(BaseModel): current_length = buffer_length result = self.moderation( - tenant_id=self.tenant_id, - app_id=self.app_id, - moderation_buffer=moderation_buffer + tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=moderation_buffer ) if not result or not result.flagged: @@ -116,16 +108,11 @@ class OutputModeration(BaseModel): final_output = result.preset_response self.final_output = final_output else: - final_output = result.text + self.buffer[len(moderation_buffer):] + final_output = result.text + self.buffer[len(moderation_buffer) :] # trigger replace event if self.thread_running: - self.queue_manager.publish( - QueueMessageReplaceEvent( - text=final_output - ), - PublishFrom.TASK_PIPELINE - ) + self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE) if result.action == ModerationAction.DIRECT_OUTPUT: break @@ -133,10 +120,7 @@ class OutputModeration(BaseModel): def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]: try: moderation_factory = ModerationFactory( - name=self.rule.type, - app_id=app_id, - tenant_id=tenant_id, - config=self.rule.config + name=self.rule.type, app_id=app_id, tenant_id=tenant_id, config=self.rule.config ) result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer) diff --git a/api/core/ops/base_trace_instance.py b/api/core/ops/base_trace_instance.py index c7af8e2963..f7b882fc71 100644 --- a/api/core/ops/base_trace_instance.py +++ b/api/core/ops/base_trace_instance.py @@ -23,4 +23,4 @@ class BaseTraceInstance(ABC): Abstract method to trace activities. Subclasses must implement specific tracing logic for activities. """ - ... \ No newline at end of file + ... diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index 221e6239ab..0ab2139a88 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -4,14 +4,15 @@ from pydantic import BaseModel, ValidationInfo, field_validator class TracingProviderEnum(Enum): - LANGFUSE = 'langfuse' - LANGSMITH = 'langsmith' + LANGFUSE = "langfuse" + LANGSMITH = "langsmith" class BaseTracingConfig(BaseModel): """ Base model class for tracing """ + ... @@ -19,16 +20,17 @@ class LangfuseConfig(BaseTracingConfig): """ Model class for Langfuse tracing config. """ + public_key: str secret_key: str - host: str = 'https://api.langfuse.com' + host: str = "https://api.langfuse.com" @field_validator("host") def set_value(cls, v, info: ValidationInfo): if v is None or v == "": - v = 'https://api.langfuse.com' - if not v.startswith('https://') and not v.startswith('http://'): - raise ValueError('host must start with https:// or http://') + v = "https://api.langfuse.com" + if not v.startswith("https://") and not v.startswith("http://"): + raise ValueError("host must start with https:// or http://") return v @@ -37,15 +39,16 @@ class LangSmithConfig(BaseTracingConfig): """ Model class for Langsmith tracing config. """ + api_key: str project: str - endpoint: str = 'https://api.smith.langchain.com' + endpoint: str = "https://api.smith.langchain.com" @field_validator("endpoint") def set_value(cls, v, info: ValidationInfo): if v is None or v == "": - v = 'https://api.smith.langchain.com' - if not v.startswith('https://'): - raise ValueError('endpoint must start with https://') + v = "https://api.smith.langchain.com" + if not v.startswith("https://"): + raise ValueError("endpoint must start with https://") return v diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index a1443f0691..a3ce27d5d4 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -23,6 +23,7 @@ class BaseTraceInfo(BaseModel): else: return "" + class WorkflowTraceInfo(BaseTraceInfo): workflow_data: Any conversation_id: Optional[str] = None @@ -98,23 +99,24 @@ class GenerateNameTraceInfo(BaseTraceInfo): conversation_id: Optional[str] = None tenant_id: str + trace_info_info_map = { - 'WorkflowTraceInfo': WorkflowTraceInfo, - 'MessageTraceInfo': MessageTraceInfo, - 'ModerationTraceInfo': ModerationTraceInfo, - 'SuggestedQuestionTraceInfo': SuggestedQuestionTraceInfo, - 'DatasetRetrievalTraceInfo': DatasetRetrievalTraceInfo, - 'ToolTraceInfo': ToolTraceInfo, - 'GenerateNameTraceInfo': GenerateNameTraceInfo, + "WorkflowTraceInfo": WorkflowTraceInfo, + "MessageTraceInfo": MessageTraceInfo, + "ModerationTraceInfo": ModerationTraceInfo, + "SuggestedQuestionTraceInfo": SuggestedQuestionTraceInfo, + "DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo, + "ToolTraceInfo": ToolTraceInfo, + "GenerateNameTraceInfo": GenerateNameTraceInfo, } class TraceTaskName(str, Enum): - CONVERSATION_TRACE = 'conversation' - WORKFLOW_TRACE = 'workflow' - MESSAGE_TRACE = 'message' - MODERATION_TRACE = 'moderation' - SUGGESTED_QUESTION_TRACE = 'suggested_question' - DATASET_RETRIEVAL_TRACE = 'dataset_retrieval' - TOOL_TRACE = 'tool' - GENERATE_NAME_TRACE = 'generate_conversation_name' + CONVERSATION_TRACE = "conversation" + WORKFLOW_TRACE = "workflow" + MESSAGE_TRACE = "message" + MODERATION_TRACE = "moderation" + SUGGESTED_QUESTION_TRACE = "suggested_question" + DATASET_RETRIEVAL_TRACE = "dataset_retrieval" + TOOL_TRACE = "tool" + GENERATE_NAME_TRACE = "generate_conversation_name" diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py index f3fc46d99a..8cbf162bf2 100644 --- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py @@ -35,38 +35,20 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): run_type: LangSmithRunType = Field(..., description="Type of the run") start_time: Optional[datetime | str] = Field(None, description="Start time of the run") end_time: Optional[datetime | str] = Field(None, description="End time of the run") - extra: Optional[dict[str, Any]] = Field( - None, description="Extra information of the run" - ) + extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run") error: Optional[str] = Field(None, description="Error message of the run") - serialized: Optional[dict[str, Any]] = Field( - None, description="Serialized data of the run" - ) + serialized: Optional[dict[str, Any]] = Field(None, description="Serialized data of the run") parent_run_id: Optional[str] = Field(None, description="Parent run ID") - events: Optional[list[dict[str, Any]]] = Field( - None, description="Events associated with the run" - ) + events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run") tags: Optional[list[str]] = Field(None, description="Tags associated with the run") - trace_id: Optional[str] = Field( - None, description="Trace ID associated with the run" - ) + trace_id: Optional[str] = Field(None, description="Trace ID associated with the run") dotted_order: Optional[str] = Field(None, description="Dotted order of the run") id: Optional[str] = Field(None, description="ID of the run") - session_id: Optional[str] = Field( - None, description="Session ID associated with the run" - ) - session_name: Optional[str] = Field( - None, description="Session name associated with the run" - ) - reference_example_id: Optional[str] = Field( - None, description="Reference example ID associated with the run" - ) - input_attachments: Optional[dict[str, Any]] = Field( - None, description="Input attachments of the run" - ) - output_attachments: Optional[dict[str, Any]] = Field( - None, description="Output attachments of the run" - ) + session_id: Optional[str] = Field(None, description="Session ID associated with the run") + session_name: Optional[str] = Field(None, description="Session name associated with the run") + reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run") + input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run") + output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") @field_validator("inputs", "outputs") def ensure_dict(cls, v, info: ValidationInfo): @@ -75,9 +57,9 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): if v == {} or v is None: return v usage_metadata = { - "input_tokens": values.get('input_tokens', 0), - "output_tokens": values.get('output_tokens', 0), - "total_tokens": values.get('total_tokens', 0), + "input_tokens": values.get("input_tokens", 0), + "output_tokens": values.get("output_tokens", 0), + "total_tokens": values.get("total_tokens", 0), } file_list = values.get("file_list", []) if isinstance(v, str): @@ -143,25 +125,15 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): class LangSmithRunUpdateModel(BaseModel): run_id: str = Field(..., description="ID of the run") - trace_id: Optional[str] = Field( - None, description="Trace ID associated with the run" - ) + trace_id: Optional[str] = Field(None, description="Trace ID associated with the run") dotted_order: Optional[str] = Field(None, description="Dotted order of the run") parent_run_id: Optional[str] = Field(None, description="Parent run ID") end_time: Optional[datetime | str] = Field(None, description="End time of the run") error: Optional[str] = Field(None, description="Error message of the run") inputs: Optional[dict[str, Any]] = Field(None, description="Inputs of the run") outputs: Optional[dict[str, Any]] = Field(None, description="Outputs of the run") - events: Optional[list[dict[str, Any]]] = Field( - None, description="Events associated with the run" - ) + events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run") tags: Optional[list[str]] = Field(None, description="Tags associated with the run") - extra: Optional[dict[str, Any]] = Field( - None, description="Extra information of the run" - ) - input_attachments: Optional[dict[str, Any]] = Field( - None, description="Input attachments of the run" - ) - output_attachments: Optional[dict[str, Any]] = Field( - None, description="Output attachments of the run" - ) + extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run") + input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run") + output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 9cbc805fe7..eea7bb3535 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -159,8 +159,8 @@ class LangSmithDataTrace(BaseTraceInstance): run_type = LangSmithRunType.llm metadata.update( { - 'ls_provider': process_data.get('model_provider', ''), - 'ls_model_name': process_data.get('model_name', ''), + "ls_provider": process_data.get("model_provider", ""), + "ls_model_name": process_data.get("model_name", ""), } ) elif node_type == "knowledge-retrieval": @@ -385,12 +385,10 @@ class LangSmithDataTrace(BaseTraceInstance): start_time=datetime.now(), ) - project_url = self.langsmith_client.get_run_url(run=run_data, - project_id=self.project_id, - project_name=self.project_name) - return project_url.split('/r/')[0] + project_url = self.langsmith_client.get_run_url( + run=run_data, project_id=self.project_id, project_name=self.project_name + ) + return project_url.split("/r/")[0] except Exception as e: logger.debug(f"LangSmith get run url failed: {str(e)}") raise ValueError(f"LangSmith get run url failed: {str(e)}") - - diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index aefab6ed16..d6156e479a 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -36,17 +36,17 @@ from tasks.ops_trace_task import process_trace_tasks provider_config_map = { TracingProviderEnum.LANGFUSE.value: { - 'config_class': LangfuseConfig, - 'secret_keys': ['public_key', 'secret_key'], - 'other_keys': ['host', 'project_key'], - 'trace_instance': LangFuseDataTrace + "config_class": LangfuseConfig, + "secret_keys": ["public_key", "secret_key"], + "other_keys": ["host", "project_key"], + "trace_instance": LangFuseDataTrace, }, TracingProviderEnum.LANGSMITH.value: { - 'config_class': LangSmithConfig, - 'secret_keys': ['api_key'], - 'other_keys': ['project', 'endpoint'], - 'trace_instance': LangSmithDataTrace - } + "config_class": LangSmithConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "endpoint"], + "trace_instance": LangSmithDataTrace, + }, } @@ -64,14 +64,17 @@ class OpsTraceManager: :return: encrypted tracing configuration """ # Get the configuration class and the keys that require encryption - config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys'] + config_class, secret_keys, other_keys = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["secret_keys"], + provider_config_map[tracing_provider]["other_keys"], + ) new_config = {} # Encrypt necessary keys for key in secret_keys: if key in tracing_config: - if '*' in tracing_config[key]: + if "*" in tracing_config[key]: # If the key contains '*', retain the original value from the current config new_config[key] = current_trace_config.get(key, tracing_config[key]) else: @@ -94,8 +97,11 @@ class OpsTraceManager: :param tracing_config: tracing config :return: """ - config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys'] + config_class, secret_keys, other_keys = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["secret_keys"], + provider_config_map[tracing_provider]["other_keys"], + ) new_config = {} for key in secret_keys: if key in tracing_config: @@ -114,8 +120,11 @@ class OpsTraceManager: :param decrypt_tracing_config: tracing config :return: """ - config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys'] + config_class, secret_keys, other_keys = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["secret_keys"], + provider_config_map[tracing_provider]["other_keys"], + ) new_config = {} for key in secret_keys: if key in decrypt_tracing_config: @@ -133,9 +142,11 @@ class OpsTraceManager: :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + trace_config_data: TraceAppConfig = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if not trace_config_data: return None @@ -164,21 +175,21 @@ class OpsTraceManager: if app_id is None: return None - app: App = db.session.query(App).filter( - App.id == app_id - ).first() + app: App = db.session.query(App).filter(App.id == app_id).first() app_ops_trace_config = json.loads(app.tracing) if app.tracing else None if app_ops_trace_config is not None: - tracing_provider = app_ops_trace_config.get('tracing_provider') + tracing_provider = app_ops_trace_config.get("tracing_provider") else: return None # decrypt_token decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider) - if app_ops_trace_config.get('enabled'): - trace_instance, config_class = provider_config_map[tracing_provider]['trace_instance'], \ - provider_config_map[tracing_provider]['config_class'] + if app_ops_trace_config.get("enabled"): + trace_instance, config_class = ( + provider_config_map[tracing_provider]["trace_instance"], + provider_config_map[tracing_provider]["config_class"], + ) tracing_instance = trace_instance(config_class(**decrypt_trace_config)) return tracing_instance @@ -192,9 +203,11 @@ class OpsTraceManager: conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() if conversation_data.app_model_config_id: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation_data.app_model_config_id - ).first() + app_model_config = ( + db.session.query(AppModelConfig) + .filter(AppModelConfig.id == conversation_data.app_model_config_id) + .first() + ) elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs: app_model_config = conversation_data.override_model_configs @@ -231,10 +244,7 @@ class OpsTraceManager: """ app: App = db.session.query(App).filter(App.id == app_id).first() if not app.tracing: - return { - "enabled": False, - "tracing_provider": None - } + return {"enabled": False, "tracing_provider": None} app_trace_config = json.loads(app.tracing) return app_trace_config @@ -246,8 +256,10 @@ class OpsTraceManager: :param tracing_provider: tracing provider :return: """ - config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['trace_instance'] + config_type, trace_instance = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["trace_instance"], + ) tracing_config = config_type(**tracing_config) return trace_instance(tracing_config).api_check() @@ -259,8 +271,10 @@ class OpsTraceManager: :param tracing_provider: tracing provider :return: """ - config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['trace_instance'] + config_type, trace_instance = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["trace_instance"], + ) tracing_config = config_type(**tracing_config) return trace_instance(tracing_config).get_project_key() @@ -272,8 +286,10 @@ class OpsTraceManager: :param tracing_provider: tracing provider :return: """ - config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['trace_instance'] + config_type, trace_instance = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["trace_instance"], + ) tracing_config = config_type(**tracing_config) return trace_instance(tracing_config).get_project_url() @@ -287,7 +303,7 @@ class TraceTask: conversation_id: Optional[str] = None, user_id: Optional[str] = None, timer: Optional[Any] = None, - **kwargs + **kwargs, ): self.trace_type = trace_type self.message_id = message_id @@ -310,9 +326,7 @@ class TraceTask: self.workflow_run, self.conversation_id, self.user_id ), TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id), - TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace( - self.message_id, self.timer, **self.kwargs - ), + TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(self.message_id, self.timer, **self.kwargs), TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace( self.message_id, self.timer, **self.kwargs ), @@ -337,12 +351,8 @@ class TraceTask: workflow_run_id = workflow_run.id workflow_run_elapsed_time = workflow_run.elapsed_time workflow_run_status = workflow_run.status - workflow_run_inputs = ( - json.loads(workflow_run.inputs) if workflow_run.inputs else {} - ) - workflow_run_outputs = ( - json.loads(workflow_run.outputs) if workflow_run.outputs else {} - ) + workflow_run_inputs = json.loads(workflow_run.inputs) if workflow_run.inputs else {} + workflow_run_outputs = json.loads(workflow_run.outputs) if workflow_run.outputs else {} workflow_run_version = workflow_run.version error = workflow_run.error if workflow_run.error else "" @@ -352,17 +362,18 @@ class TraceTask: query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" # get workflow_app_log_id - workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by( - tenant_id=tenant_id, - app_id=workflow_run.app_id, - workflow_run_id=workflow_run.id - ).first() + workflow_app_log_data = ( + db.session.query(WorkflowAppLog) + .filter_by(tenant_id=tenant_id, app_id=workflow_run.app_id, workflow_run_id=workflow_run.id) + .first() + ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None # get message_id - message_data = db.session.query(Message.id).filter_by( - conversation_id=conversation_id, - workflow_run_id=workflow_run_id - ).first() + message_data = ( + db.session.query(Message.id) + .filter_by(conversation_id=conversation_id, workflow_run_id=workflow_run_id) + .first() + ) message_id = str(message_data.id) if message_data else None metadata = { @@ -470,9 +481,9 @@ class TraceTask: # get workflow_app_log_id workflow_app_log_id = None if message_data.workflow_run_id: - workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by( - workflow_run_id=message_data.workflow_run_id - ).first() + workflow_app_log_data = ( + db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None moderation_trace_info = ModerationTraceInfo( @@ -510,9 +521,9 @@ class TraceTask: # get workflow_app_log_id workflow_app_log_id = None if message_data.workflow_run_id: - workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by( - workflow_run_id=message_data.workflow_run_id - ).first() + workflow_app_log_data = ( + db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None suggested_question_trace_info = SuggestedQuestionTraceInfo( @@ -569,9 +580,9 @@ class TraceTask: return dataset_retrieval_trace_info def tool_trace(self, message_id, timer, **kwargs): - tool_name = kwargs.get('tool_name') - tool_inputs = kwargs.get('tool_inputs') - tool_outputs = kwargs.get('tool_outputs') + tool_name = kwargs.get("tool_name") + tool_inputs = kwargs.get("tool_inputs") + tool_outputs = kwargs.get("tool_outputs") message_data = get_message_data(message_id) if not message_data: return {} @@ -586,11 +597,11 @@ class TraceTask: if tool_name in agent_thought.tools: created_time = agent_thought.created_at tool_meta_data = agent_thought.tool_meta.get(tool_name, {}) - tool_config = tool_meta_data.get('tool_config', {}) - time_cost = tool_meta_data.get('time_cost', 0) + tool_config = tool_meta_data.get("tool_config", {}) + time_cost = tool_meta_data.get("time_cost", 0) end_time = created_time + timedelta(seconds=time_cost) - error = tool_meta_data.get('error', "") - tool_parameters = tool_meta_data.get('tool_parameters', {}) + error = tool_meta_data.get("error", "") + tool_parameters = tool_meta_data.get("tool_parameters", {}) metadata = { "message_id": message_id, "tool_name": tool_name, @@ -715,9 +726,7 @@ class TraceQueueManager: def start_timer(self): global trace_manager_timer if trace_manager_timer is None or not trace_manager_timer.is_alive(): - trace_manager_timer = threading.Timer( - trace_manager_interval, self.run - ) + trace_manager_timer = threading.Timer(trace_manager_interval, self.run) trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}" trace_manager_timer.daemon = False trace_manager_timer.start() diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 3b2e04abb7..498685b342 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -20,19 +20,19 @@ def get_message_data(message_id): @contextmanager def measure_time(): - timing_info = {'start': datetime.now(), 'end': None} + timing_info = {"start": datetime.now(), "end": None} try: yield timing_info finally: - timing_info['end'] = datetime.now() + timing_info["end"] = datetime.now() def replace_text_with_content(data): if isinstance(data, dict): new_data = {} for key, value in data.items(): - if key == 'text': - new_data['content'] = value + if key == "text": + new_data["content"] = value else: new_data[key] = replace_text_with_content(value) return new_data diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 22420fea2c..ce8038d14e 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -22,18 +22,22 @@ class AdvancedPromptTransform(PromptTransform): """ Advanced Prompt Transform for Workflow LLM Node. """ + def __init__(self, with_variable_tmpl: bool = False) -> None: self.with_variable_tmpl = with_variable_tmpl - def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate], - inputs: dict, - query: str, - files: list[FileVar], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity, - query_prompt_template: Optional[str] = None) -> list[PromptMessage]: + def get_prompt( + self, + prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate], + inputs: dict, + query: str, + files: list[FileVar], + context: Optional[str], + memory_config: Optional[MemoryConfig], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + query_prompt_template: Optional[str] = None, + ) -> list[PromptMessage]: inputs = {key: str(value) for key, value in inputs.items()} prompt_messages = [] @@ -48,7 +52,7 @@ class AdvancedPromptTransform(PromptTransform): context=context, memory_config=memory_config, memory=memory, - model_config=model_config + model_config=model_config, ) elif model_mode == ModelMode.CHAT: prompt_messages = self._get_chat_model_prompt_messages( @@ -60,20 +64,22 @@ class AdvancedPromptTransform(PromptTransform): context=context, memory_config=memory_config, memory=memory, - model_config=model_config + model_config=model_config, ) return prompt_messages - def _get_completion_model_prompt_messages(self, - prompt_template: CompletionModelPromptTemplate, - inputs: dict, - query: Optional[str], - files: list[FileVar], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: + def _get_completion_model_prompt_messages( + self, + prompt_template: CompletionModelPromptTemplate, + inputs: dict, + query: Optional[str], + files: list[FileVar], + context: Optional[str], + memory_config: Optional[MemoryConfig], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> list[PromptMessage]: """ Get completion model prompt messages. """ @@ -81,7 +87,7 @@ class AdvancedPromptTransform(PromptTransform): prompt_messages = [] - if prompt_template.edition_type == 'basic' or not prompt_template.edition_type: + if prompt_template.edition_type == "basic" or not prompt_template.edition_type: prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} @@ -96,15 +102,13 @@ class AdvancedPromptTransform(PromptTransform): role_prefix=role_prefix, prompt_template=prompt_template, prompt_inputs=prompt_inputs, - model_config=model_config + model_config=model_config, ) if query: prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs) - prompt = prompt_template.format( - prompt_inputs - ) + prompt = prompt_template.format(prompt_inputs) else: prompt = raw_prompt prompt_inputs = inputs @@ -122,16 +126,18 @@ class AdvancedPromptTransform(PromptTransform): return prompt_messages - def _get_chat_model_prompt_messages(self, - prompt_template: list[ChatModelMessage], - inputs: dict, - query: Optional[str], - files: list[FileVar], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity, - query_prompt_template: Optional[str] = None) -> list[PromptMessage]: + def _get_chat_model_prompt_messages( + self, + prompt_template: list[ChatModelMessage], + inputs: dict, + query: Optional[str], + files: list[FileVar], + context: Optional[str], + memory_config: Optional[MemoryConfig], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + query_prompt_template: Optional[str] = None, + ) -> list[PromptMessage]: """ Get chat model prompt messages. """ @@ -142,22 +148,20 @@ class AdvancedPromptTransform(PromptTransform): for prompt_item in raw_prompt_list: raw_prompt = prompt_item.text - if prompt_item.edition_type == 'basic' or not prompt_item.edition_type: + if prompt_item.edition_type == "basic" or not prompt_item.edition_type: prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) - prompt = prompt_template.format( - prompt_inputs - ) - elif prompt_item.edition_type == 'jinja2': + prompt = prompt_template.format(prompt_inputs) + elif prompt_item.edition_type == "jinja2": prompt = raw_prompt prompt_inputs = inputs prompt = Jinja2Formatter.format(prompt, prompt_inputs) else: - raise ValueError(f'Invalid edition type: {prompt_item.edition_type}') + raise ValueError(f"Invalid edition type: {prompt_item.edition_type}") if prompt_item.role == PromptMessageRole.USER: prompt_messages.append(UserPromptMessage(content=prompt)) @@ -168,17 +172,14 @@ class AdvancedPromptTransform(PromptTransform): if query and query_prompt_template: prompt_template = PromptTemplateParser( - template=query_prompt_template, - with_variable_tmpl=self.with_variable_tmpl + template=query_prompt_template, with_variable_tmpl=self.with_variable_tmpl ) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - prompt_inputs['#sys.query#'] = query + prompt_inputs["#sys.query#"] = query prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) - query = prompt_template.format( - prompt_inputs - ) + query = prompt_template.format(prompt_inputs) if memory and memory_config: prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) @@ -203,7 +204,7 @@ class AdvancedPromptTransform(PromptTransform): last_message.content = prompt_message_contents else: - prompt_message_contents = [TextPromptMessageContent(data='')] # not for query + prompt_message_contents = [TextPromptMessageContent(data="")] # not for query for file in files: prompt_message_contents.append(file.prompt_message_content) @@ -220,38 +221,39 @@ class AdvancedPromptTransform(PromptTransform): return prompt_messages def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: - if '#context#' in prompt_template.variable_keys: + if "#context#" in prompt_template.variable_keys: if context: - prompt_inputs['#context#'] = context + prompt_inputs["#context#"] = context else: - prompt_inputs['#context#'] = '' + prompt_inputs["#context#"] = "" return prompt_inputs def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: - if '#query#' in prompt_template.variable_keys: + if "#query#" in prompt_template.variable_keys: if query: - prompt_inputs['#query#'] = query + prompt_inputs["#query#"] = query else: - prompt_inputs['#query#'] = '' + prompt_inputs["#query#"] = "" return prompt_inputs - def _set_histories_variable(self, memory: TokenBufferMemory, - memory_config: MemoryConfig, - raw_prompt: str, - role_prefix: MemoryConfig.RolePrefix, - prompt_template: PromptTemplateParser, - prompt_inputs: dict, - model_config: ModelConfigWithCredentialsEntity) -> dict: - if '#histories#' in prompt_template.variable_keys: + def _set_histories_variable( + self, + memory: TokenBufferMemory, + memory_config: MemoryConfig, + raw_prompt: str, + role_prefix: MemoryConfig.RolePrefix, + prompt_template: PromptTemplateParser, + prompt_inputs: dict, + model_config: ModelConfigWithCredentialsEntity, + ) -> dict: + if "#histories#" in prompt_template.variable_keys: if memory: - inputs = {'#histories#': '', **prompt_inputs} + inputs = {"#histories#": "", **prompt_inputs} prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - tmp_human_message = UserPromptMessage( - content=prompt_template.format(prompt_inputs) - ) + tmp_human_message = UserPromptMessage(content=prompt_template.format(prompt_inputs)) rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) @@ -260,10 +262,10 @@ class AdvancedPromptTransform(PromptTransform): memory_config=memory_config, max_token_limit=rest_tokens, human_prefix=role_prefix.user, - ai_prefix=role_prefix.assistant + ai_prefix=role_prefix.assistant, ) - prompt_inputs['#histories#'] = histories + prompt_inputs["#histories#"] = histories else: - prompt_inputs['#histories#'] = '' + prompt_inputs["#histories#"] = "" return prompt_inputs diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index af0075ea91..caa1793ea8 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -17,12 +17,14 @@ class AgentHistoryPromptTransform(PromptTransform): """ History Prompt Transform for Agent App """ - def __init__(self, - model_config: ModelConfigWithCredentialsEntity, - prompt_messages: list[PromptMessage], - history_messages: list[PromptMessage], - memory: Optional[TokenBufferMemory] = None, - ): + + def __init__( + self, + model_config: ModelConfigWithCredentialsEntity, + prompt_messages: list[PromptMessage], + history_messages: list[PromptMessage], + memory: Optional[TokenBufferMemory] = None, + ): self.model_config = model_config self.prompt_messages = prompt_messages self.history_messages = history_messages @@ -45,9 +47,7 @@ class AgentHistoryPromptTransform(PromptTransform): model_type_instance = cast(LargeLanguageModel, model_type_instance) curr_message_tokens = model_type_instance.get_num_tokens( - self.memory.model_instance.model, - self.memory.model_instance.credentials, - self.history_messages + self.memory.model_instance.model, self.memory.model_instance.credentials, self.history_messages ) if curr_message_tokens <= max_token_limit: return self.history_messages @@ -63,9 +63,7 @@ class AgentHistoryPromptTransform(PromptTransform): # a message is start with UserPromptMessage if isinstance(prompt_message, UserPromptMessage): curr_message_tokens = model_type_instance.get_num_tokens( - self.memory.model_instance.model, - self.memory.model_instance.credentials, - prompt_messages + self.memory.model_instance.model, self.memory.model_instance.credentials, prompt_messages ) # if current message token is overflow, drop all the prompts in current message and break if curr_message_tokens > max_token_limit: diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py index 61df69163c..c8e7b414df 100644 --- a/api/core/prompt/entities/advanced_prompt_entities.py +++ b/api/core/prompt/entities/advanced_prompt_entities.py @@ -9,27 +9,31 @@ class ChatModelMessage(BaseModel): """ Chat Message. """ + text: str role: PromptMessageRole - edition_type: Optional[Literal['basic', 'jinja2']] = None + edition_type: Optional[Literal["basic", "jinja2"]] = None class CompletionModelPromptTemplate(BaseModel): """ Completion Model Prompt Template. """ + text: str - edition_type: Optional[Literal['basic', 'jinja2']] = None + edition_type: Optional[Literal["basic", "jinja2"]] = None class MemoryConfig(BaseModel): """ Memory Config. """ + class RolePrefix(BaseModel): """ Role Prefix. """ + user: str assistant: str @@ -37,6 +41,7 @@ class MemoryConfig(BaseModel): """ Window Config. """ + enabled: bool size: Optional[int] = None diff --git a/api/core/prompt/prompt_templates/advanced_prompt_templates.py b/api/core/prompt/prompt_templates/advanced_prompt_templates.py index da40534d99..e4b3a61cb4 100644 --- a/api/core/prompt/prompt_templates/advanced_prompt_templates.py +++ b/api/core/prompt/prompt_templates/advanced_prompt_templates.py @@ -7,39 +7,18 @@ CHAT_APP_COMPLETION_PROMPT_CONFIG = { "prompt": { "text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside XML tags.\n\n\n{{#histories#}}\n\n\n\nHuman: {{#query#}}\n\nAssistant: " }, - "conversation_histories_role": { - "user_prefix": "Human", - "assistant_prefix": "Assistant" - } + "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, }, - "stop": ["Human:"] + "stop": ["Human:"], } -CHAT_APP_CHAT_PROMPT_CONFIG = { - "chat_prompt_config": { - "prompt": [{ - "role": "system", - "text": "{{#pre_prompt#}}" - }] - } -} +CHAT_APP_CHAT_PROMPT_CONFIG = {"chat_prompt_config": {"prompt": [{"role": "system", "text": "{{#pre_prompt#}}"}]}} -COMPLETION_APP_CHAT_PROMPT_CONFIG = { - "chat_prompt_config": { - "prompt": [{ - "role": "user", - "text": "{{#pre_prompt#}}" - }] - } -} +COMPLETION_APP_CHAT_PROMPT_CONFIG = {"chat_prompt_config": {"prompt": [{"role": "user", "text": "{{#pre_prompt#}}"}]}} COMPLETION_APP_COMPLETION_PROMPT_CONFIG = { - "completion_prompt_config": { - "prompt": { - "text": "{{#pre_prompt#}}" - } - }, - "stop": ["Human:"] + "completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}}, + "stop": ["Human:"], } BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = { @@ -47,37 +26,20 @@ BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = { "prompt": { "text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}" }, - "conversation_histories_role": { - "user_prefix": "用户", - "assistant_prefix": "助手" - } + "conversation_histories_role": {"user_prefix": "用户", "assistant_prefix": "助手"}, }, - "stop": ["用户:"] + "stop": ["用户:"], } -BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = { - "chat_prompt_config": { - "prompt": [{ - "role": "system", - "text": "{{#pre_prompt#}}" - }] - } +BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = { + "chat_prompt_config": {"prompt": [{"role": "system", "text": "{{#pre_prompt#}}"}]} } BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = { - "chat_prompt_config": { - "prompt": [{ - "role": "user", - "text": "{{#pre_prompt#}}" - }] - } + "chat_prompt_config": {"prompt": [{"role": "user", "text": "{{#pre_prompt#}}"}]} } BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = { - "completion_prompt_config": { - "prompt": { - "text": "{{#pre_prompt#}}" - } - }, - "stop": ["用户:"] + "completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}}, + "stop": ["用户:"], } diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index b86d3fa815..87acdb3c49 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -9,75 +9,78 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig class PromptTransform: - def _append_chat_histories(self, memory: TokenBufferMemory, - memory_config: MemoryConfig, - prompt_messages: list[PromptMessage], - model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: + def _append_chat_histories( + self, + memory: TokenBufferMemory, + memory_config: MemoryConfig, + prompt_messages: list[PromptMessage], + model_config: ModelConfigWithCredentialsEntity, + ) -> list[PromptMessage]: rest_tokens = self._calculate_rest_token(prompt_messages, model_config) histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens) prompt_messages.extend(histories) return prompt_messages - def _calculate_rest_token(self, prompt_messages: list[PromptMessage], - model_config: ModelConfigWithCredentialsEntity) -> int: + def _calculate_rest_token( + self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity + ) -> int: rest_tokens = 2000 model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) - curr_message_tokens = model_instance.get_llm_num_tokens( - prompt_messages - ) + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens rest_tokens = max(rest_tokens, 0) return rest_tokens - def _get_history_messages_from_memory(self, memory: TokenBufferMemory, - memory_config: MemoryConfig, - max_token_limit: int, - human_prefix: Optional[str] = None, - ai_prefix: Optional[str] = None) -> str: + def _get_history_messages_from_memory( + self, + memory: TokenBufferMemory, + memory_config: MemoryConfig, + max_token_limit: int, + human_prefix: Optional[str] = None, + ai_prefix: Optional[str] = None, + ) -> str: """Get memory messages.""" - kwargs = { - "max_token_limit": max_token_limit - } + kwargs = {"max_token_limit": max_token_limit} if human_prefix: - kwargs['human_prefix'] = human_prefix + kwargs["human_prefix"] = human_prefix if ai_prefix: - kwargs['ai_prefix'] = ai_prefix + kwargs["ai_prefix"] = ai_prefix if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0: - kwargs['message_limit'] = memory_config.window.size + kwargs["message_limit"] = memory_config.window.size - return memory.get_history_prompt_text( - **kwargs - ) + return memory.get_history_prompt_text(**kwargs) - def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory, - memory_config: MemoryConfig, - max_token_limit: int) -> list[PromptMessage]: + def _get_history_messages_list_from_memory( + self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int + ) -> list[PromptMessage]: """Get memory messages.""" return memory.get_history_prompt_messages( max_token_limit=max_token_limit, message_limit=memory_config.window.size - if (memory_config.window.enabled - and memory_config.window.size is not None - and memory_config.window.size > 0) - else None + if ( + memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0 + ) + else None, ) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index fd7ed0181b..13e5c5253e 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -22,11 +22,11 @@ if TYPE_CHECKING: class ModelMode(enum.Enum): - COMPLETION = 'completion' - CHAT = 'chat' + COMPLETION = "completion" + CHAT = "chat" @classmethod - def value_of(cls, value: str) -> 'ModelMode': + def value_of(cls, value: str) -> "ModelMode": """ Get value of given mode. @@ -36,7 +36,7 @@ class ModelMode(enum.Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") prompt_file_contents = {} @@ -47,16 +47,17 @@ class SimplePromptTransform(PromptTransform): Simple Prompt Transform for Chatbot App Basic Mode. """ - def get_prompt(self, - app_mode: AppMode, - prompt_template_entity: PromptTemplateEntity, - inputs: dict, - query: str, - files: list["FileVar"], - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) -> \ - tuple[list[PromptMessage], Optional[list[str]]]: + def get_prompt( + self, + app_mode: AppMode, + prompt_template_entity: PromptTemplateEntity, + inputs: dict, + query: str, + files: list["FileVar"], + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: inputs = {key: str(value) for key, value in inputs.items()} model_mode = ModelMode.value_of(model_config.mode) @@ -69,7 +70,7 @@ class SimplePromptTransform(PromptTransform): files=files, context=context, memory=memory, - model_config=model_config + model_config=model_config, ) else: prompt_messages, stops = self._get_completion_model_prompt_messages( @@ -80,19 +81,21 @@ class SimplePromptTransform(PromptTransform): files=files, context=context, memory=memory, - model_config=model_config + model_config=model_config, ) return prompt_messages, stops - def get_prompt_str_and_rules(self, app_mode: AppMode, - model_config: ModelConfigWithCredentialsEntity, - pre_prompt: str, - inputs: dict, - query: Optional[str] = None, - context: Optional[str] = None, - histories: Optional[str] = None, - ) -> tuple[str, dict]: + def get_prompt_str_and_rules( + self, + app_mode: AppMode, + model_config: ModelConfigWithCredentialsEntity, + pre_prompt: str, + inputs: dict, + query: Optional[str] = None, + context: Optional[str] = None, + histories: Optional[str] = None, + ) -> tuple[str, dict]: # get prompt template prompt_template_config = self.get_prompt_template( app_mode=app_mode, @@ -101,74 +104,75 @@ class SimplePromptTransform(PromptTransform): pre_prompt=pre_prompt, has_context=context is not None, query_in_prompt=query is not None, - with_memory_prompt=histories is not None + with_memory_prompt=histories is not None, ) - variables = {k: inputs[k] for k in prompt_template_config['custom_variable_keys'] if k in inputs} + variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs} - for v in prompt_template_config['special_variable_keys']: + for v in prompt_template_config["special_variable_keys"]: # support #context#, #query# and #histories# - if v == '#context#': - variables['#context#'] = context if context else '' - elif v == '#query#': - variables['#query#'] = query if query else '' - elif v == '#histories#': - variables['#histories#'] = histories if histories else '' + if v == "#context#": + variables["#context#"] = context if context else "" + elif v == "#query#": + variables["#query#"] = query if query else "" + elif v == "#histories#": + variables["#histories#"] = histories if histories else "" - prompt_template = prompt_template_config['prompt_template'] + prompt_template = prompt_template_config["prompt_template"] prompt = prompt_template.format(variables) - return prompt, prompt_template_config['prompt_rules'] + return prompt, prompt_template_config["prompt_rules"] - def get_prompt_template(self, app_mode: AppMode, - provider: str, - model: str, - pre_prompt: str, - has_context: bool, - query_in_prompt: bool, - with_memory_prompt: bool = False) -> dict: - prompt_rules = self._get_prompt_rule( - app_mode=app_mode, - provider=provider, - model=model - ) + def get_prompt_template( + self, + app_mode: AppMode, + provider: str, + model: str, + pre_prompt: str, + has_context: bool, + query_in_prompt: bool, + with_memory_prompt: bool = False, + ) -> dict: + prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model) custom_variable_keys = [] special_variable_keys = [] - prompt = '' - for order in prompt_rules['system_prompt_orders']: - if order == 'context_prompt' and has_context: - prompt += prompt_rules['context_prompt'] - special_variable_keys.append('#context#') - elif order == 'pre_prompt' and pre_prompt: - prompt += pre_prompt + '\n' + prompt = "" + for order in prompt_rules["system_prompt_orders"]: + if order == "context_prompt" and has_context: + prompt += prompt_rules["context_prompt"] + special_variable_keys.append("#context#") + elif order == "pre_prompt" and pre_prompt: + prompt += pre_prompt + "\n" pre_prompt_template = PromptTemplateParser(template=pre_prompt) custom_variable_keys = pre_prompt_template.variable_keys - elif order == 'histories_prompt' and with_memory_prompt: - prompt += prompt_rules['histories_prompt'] - special_variable_keys.append('#histories#') + elif order == "histories_prompt" and with_memory_prompt: + prompt += prompt_rules["histories_prompt"] + special_variable_keys.append("#histories#") if query_in_prompt: - prompt += prompt_rules.get('query_prompt', '{{#query#}}') - special_variable_keys.append('#query#') + prompt += prompt_rules.get("query_prompt", "{{#query#}}") + special_variable_keys.append("#query#") return { "prompt_template": PromptTemplateParser(template=prompt), "custom_variable_keys": custom_variable_keys, "special_variable_keys": special_variable_keys, - "prompt_rules": prompt_rules + "prompt_rules": prompt_rules, } - def _get_chat_model_prompt_messages(self, app_mode: AppMode, - pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - files: list["FileVar"], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def _get_chat_model_prompt_messages( + self, + app_mode: AppMode, + pre_prompt: str, + inputs: dict, + query: str, + context: Optional[str], + files: list["FileVar"], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: prompt_messages = [] # get prompt @@ -178,7 +182,7 @@ class SimplePromptTransform(PromptTransform): pre_prompt=pre_prompt, inputs=inputs, query=None, - context=context + context=context, ) if prompt and query: @@ -193,7 +197,7 @@ class SimplePromptTransform(PromptTransform): ) ), prompt_messages=prompt_messages, - model_config=model_config + model_config=model_config, ) if query: @@ -203,15 +207,17 @@ class SimplePromptTransform(PromptTransform): return prompt_messages, None - def _get_completion_model_prompt_messages(self, app_mode: AppMode, - pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - files: list["FileVar"], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def _get_completion_model_prompt_messages( + self, + app_mode: AppMode, + pre_prompt: str, + inputs: dict, + query: str, + context: Optional[str], + files: list["FileVar"], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: # get prompt prompt, prompt_rules = self.get_prompt_str_and_rules( app_mode=app_mode, @@ -219,13 +225,11 @@ class SimplePromptTransform(PromptTransform): pre_prompt=pre_prompt, inputs=inputs, query=query, - context=context + context=context, ) if memory: - tmp_human_message = UserPromptMessage( - content=prompt - ) + tmp_human_message = UserPromptMessage(content=prompt) rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) histories = self._get_history_messages_from_memory( @@ -236,8 +240,8 @@ class SimplePromptTransform(PromptTransform): ) ), max_token_limit=rest_tokens, - human_prefix=prompt_rules.get('human_prefix', 'Human'), - ai_prefix=prompt_rules.get('assistant_prefix', 'Assistant') + human_prefix=prompt_rules.get("human_prefix", "Human"), + ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"), ) # get prompt @@ -248,10 +252,10 @@ class SimplePromptTransform(PromptTransform): inputs=inputs, query=query, context=context, - histories=histories + histories=histories, ) - stops = prompt_rules.get('stops') + stops = prompt_rules.get("stops") if stops is not None and len(stops) == 0: stops = None @@ -277,22 +281,18 @@ class SimplePromptTransform(PromptTransform): :param model: model name :return: """ - prompt_file_name = self._prompt_file_name( - app_mode=app_mode, - provider=provider, - model=model - ) + prompt_file_name = self._prompt_file_name(app_mode=app_mode, provider=provider, model=model) # Check if the prompt file is already loaded if prompt_file_name in prompt_file_contents: return prompt_file_contents[prompt_file_name] # Get the absolute path of the subdirectory - prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'prompt_templates') - json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json') + prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates") + json_file_path = os.path.join(prompt_path, f"{prompt_file_name}.json") # Open the JSON file and read its content - with open(json_file_path, encoding='utf-8') as json_file: + with open(json_file_path, encoding="utf-8") as json_file: content = json.load(json_file) # Store the content of the prompt file @@ -303,21 +303,21 @@ class SimplePromptTransform(PromptTransform): def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str: # baichuan is_baichuan = False - if provider == 'baichuan': + if provider == "baichuan": is_baichuan = True else: baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"] - if provider in baichuan_supported_providers and 'baichuan' in model.lower(): + if provider in baichuan_supported_providers and "baichuan" in model.lower(): is_baichuan = True if is_baichuan: if app_mode == AppMode.COMPLETION: - return 'baichuan_completion' + return "baichuan_completion" else: - return 'baichuan_chat' + return "baichuan_chat" # common if app_mode == AppMode.COMPLETION: - return 'common_completion' + return "common_completion" else: - return 'common_chat' + return "common_chat" diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index befdceeda5..29494db221 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -25,26 +25,29 @@ class PromptMessageUtil: tool_calls = [] for prompt_message in prompt_messages: if prompt_message.role == PromptMessageRole.USER: - role = 'user' + role = "user" elif prompt_message.role == PromptMessageRole.ASSISTANT: - role = 'assistant' + role = "assistant" if isinstance(prompt_message, AssistantPromptMessage): - tool_calls = [{ - 'id': tool_call.id, - 'type': 'function', - 'function': { - 'name': tool_call.function.name, - 'arguments': tool_call.function.arguments, + tool_calls = [ + { + "id": tool_call.id, + "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, } - } for tool_call in prompt_message.tool_calls] + for tool_call in prompt_message.tool_calls + ] elif prompt_message.role == PromptMessageRole.SYSTEM: - role = 'system' + role = "system" elif prompt_message.role == PromptMessageRole.TOOL: - role = 'tool' + role = "tool" else: continue - text = '' + text = "" files = [] if isinstance(prompt_message.content, list): for content in prompt_message.content: @@ -53,27 +56,25 @@ class PromptMessageUtil: text += content.data else: content = cast(ImagePromptMessageContent, content) - files.append({ - "type": 'image', - "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], - "detail": content.detail.value - }) + files.append( + { + "type": "image", + "data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:], + "detail": content.detail.value, + } + ) else: text = prompt_message.content - prompt = { - "role": role, - "text": text, - "files": files - } - + prompt = {"role": role, "text": text, "files": files} + if tool_calls: - prompt['tool_calls'] = tool_calls + prompt["tool_calls"] = tool_calls prompts.append(prompt) else: prompt_message = prompt_messages[0] - text = '' + text = "" files = [] if isinstance(prompt_message.content, list): for content in prompt_message.content: @@ -82,21 +83,23 @@ class PromptMessageUtil: text += content.data else: content = cast(ImagePromptMessageContent, content) - files.append({ - "type": 'image', - "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], - "detail": content.detail.value - }) + files.append( + { + "type": "image", + "data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:], + "detail": content.detail.value, + } + ) else: text = prompt_message.content params = { - "role": 'user', + "role": "user", "text": text, } if files: - params['files'] = files + params["files"] = files prompts.append(params) diff --git a/api/core/prompt/utils/prompt_template_parser.py b/api/core/prompt/utils/prompt_template_parser.py index 3e68492df2..8111559675 100644 --- a/api/core/prompt/utils/prompt_template_parser.py +++ b/api/core/prompt/utils/prompt_template_parser.py @@ -38,8 +38,8 @@ class PromptTemplateParser: return value prompt = re.sub(self.regex, replacer, self.template) - return re.sub(r'<\|.*?\|>', '', prompt) + return re.sub(r"<\|.*?\|>", "", prompt) @classmethod def remove_template_variables(cls, text: str, with_variable_tmpl: bool = False): - return re.sub(WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX, r'{\1}', text) + return re.sub(WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX, r"{\1}", text) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 67eee2c294..3a1fe300df 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -90,8 +90,7 @@ class ProviderManager: # Initialize trial provider records if not exist provider_name_to_provider_records_dict = self._init_trial_provider_records( - tenant_id, - provider_name_to_provider_records_dict + tenant_id, provider_name_to_provider_records_dict ) # Get all provider model records of the workspace @@ -107,22 +106,20 @@ class ProviderManager: provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id) # Get All load balancing configs - provider_name_to_provider_load_balancing_model_configs_dict \ - = self._get_all_provider_load_balancing_configs(tenant_id) - - provider_configurations = ProviderConfigurations( - tenant_id=tenant_id + provider_name_to_provider_load_balancing_model_configs_dict = self._get_all_provider_load_balancing_configs( + tenant_id ) + provider_configurations = ProviderConfigurations(tenant_id=tenant_id) + # Construct ProviderConfiguration objects for each provider for provider_entity in provider_entities: - # handle include, exclude if is_filtered( - include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, - exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, - data=provider_entity, - name_func=lambda x: x.provider, + include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, + exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, + data=provider_entity, + name_func=lambda x: x.provider, ): continue @@ -132,18 +129,11 @@ class ProviderManager: # Convert to custom configuration custom_configuration = self._to_custom_configuration( - tenant_id, - provider_entity, - provider_records, - provider_model_records + tenant_id, provider_entity, provider_records, provider_model_records ) # Convert to system configuration - system_configuration = self._to_system_configuration( - tenant_id, - provider_entity, - provider_records - ) + system_configuration = self._to_system_configuration(tenant_id, provider_entity, provider_records) # Get preferred provider type preferred_provider_type_record = provider_name_to_preferred_model_provider_records_dict.get(provider_name) @@ -173,14 +163,15 @@ class ProviderManager: provider_model_settings = provider_name_to_provider_model_settings_dict.get(provider_name) # Get provider load balancing configs - provider_load_balancing_configs \ - = provider_name_to_provider_load_balancing_model_configs_dict.get(provider_name) + provider_load_balancing_configs = provider_name_to_provider_load_balancing_model_configs_dict.get( + provider_name + ) # Convert to model settings model_settings = self._to_model_settings( provider_entity=provider_entity, provider_model_settings=provider_model_settings, - load_balancing_model_configs=provider_load_balancing_configs + load_balancing_model_configs=provider_load_balancing_configs, ) provider_configuration = ProviderConfiguration( @@ -190,7 +181,7 @@ class ProviderManager: using_provider_type=using_provider_type, system_configuration=system_configuration, custom_configuration=custom_configuration, - model_settings=model_settings + model_settings=model_settings, ) provider_configurations[provider_name] = provider_configuration @@ -219,7 +210,7 @@ class ProviderManager: return ProviderModelBundle( configuration=provider_configuration, provider_instance=provider_instance, - model_type_instance=model_type_instance + model_type_instance=model_type_instance, ) def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[DefaultModelEntity]: @@ -231,11 +222,14 @@ class ProviderManager: :return: """ # Get the corresponding TenantDefaultModel record - default_model = db.session.query(TenantDefaultModel) \ + default_model = ( + db.session.query(TenantDefaultModel) .filter( - TenantDefaultModel.tenant_id == tenant_id, - TenantDefaultModel.model_type == model_type.to_origin_model_type() - ).first() + TenantDefaultModel.tenant_id == tenant_id, + TenantDefaultModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) # If it does not exist, get the first available provider model from get_configurations # and update the TenantDefaultModel record @@ -244,20 +238,18 @@ class ProviderManager: provider_configurations = self.get_configurations(tenant_id) # get available models from provider_configurations - available_models = provider_configurations.get_models( - model_type=model_type, - only_active=True - ) + available_models = provider_configurations.get_models(model_type=model_type, only_active=True) if available_models: - available_model = next((model for model in available_models if model.model == "gpt-4"), - available_models[0]) + available_model = next( + (model for model in available_models if model.model == "gpt-4"), available_models[0] + ) default_model = TenantDefaultModel( tenant_id=tenant_id, model_type=model_type.to_origin_model_type(), provider_name=available_model.provider.provider, - model_name=available_model.model + model_name=available_model.model, ) db.session.add(default_model) db.session.commit() @@ -276,8 +268,8 @@ class ProviderManager: label=provider_schema.label, icon_small=provider_schema.icon_small, icon_large=provider_schema.icon_large, - supported_model_types=provider_schema.supported_model_types - ) + supported_model_types=provider_schema.supported_model_types, + ), ) def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]: @@ -291,15 +283,13 @@ class ProviderManager: provider_configurations = self.get_configurations(tenant_id) # get available models from provider_configurations - all_models = provider_configurations.get_models( - model_type=model_type, - only_active=False - ) + all_models = provider_configurations.get_models(model_type=model_type, only_active=False) return all_models[0].provider.provider, all_models[0].model - def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \ - -> TenantDefaultModel: + def update_default_model_record( + self, tenant_id: str, model_type: ModelType, provider: str, model: str + ) -> TenantDefaultModel: """ Update default model record. @@ -314,10 +304,7 @@ class ProviderManager: raise ValueError(f"Provider {provider} does not exist.") # get available models from provider_configurations - available_models = provider_configurations.get_models( - model_type=model_type, - only_active=True - ) + available_models = provider_configurations.get_models(model_type=model_type, only_active=True) # check if the model is exist in available models model_names = [model.model for model in available_models] @@ -325,11 +312,14 @@ class ProviderManager: raise ValueError(f"Model {model} does not exist.") # Get the list of available models from get_configurations and check if it is LLM - default_model = db.session.query(TenantDefaultModel) \ + default_model = ( + db.session.query(TenantDefaultModel) .filter( - TenantDefaultModel.tenant_id == tenant_id, - TenantDefaultModel.model_type == model_type.to_origin_model_type() - ).first() + TenantDefaultModel.tenant_id == tenant_id, + TenantDefaultModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) # create or update TenantDefaultModel record if default_model: @@ -358,11 +348,7 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - providers = db.session.query(Provider) \ - .filter( - Provider.tenant_id == tenant_id, - Provider.is_valid == True - ).all() + providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all() provider_name_to_provider_records_dict = defaultdict(list) for provider in providers: @@ -379,11 +365,11 @@ class ProviderManager: :return: """ # Get all provider model records of the workspace - provider_models = db.session.query(ProviderModel) \ - .filter( - ProviderModel.tenant_id == tenant_id, - ProviderModel.is_valid == True - ).all() + provider_models = ( + db.session.query(ProviderModel) + .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) + .all() + ) provider_name_to_provider_model_records_dict = defaultdict(list) for provider_model in provider_models: @@ -399,10 +385,11 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - preferred_provider_types = db.session.query(TenantPreferredModelProvider) \ - .filter( - TenantPreferredModelProvider.tenant_id == tenant_id - ).all() + preferred_provider_types = ( + db.session.query(TenantPreferredModelProvider) + .filter(TenantPreferredModelProvider.tenant_id == tenant_id) + .all() + ) provider_name_to_preferred_provider_type_records_dict = { preferred_provider_type.provider_name: preferred_provider_type @@ -419,15 +406,17 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - provider_model_settings = db.session.query(ProviderModelSetting) \ - .filter( - ProviderModelSetting.tenant_id == tenant_id - ).all() + provider_model_settings = ( + db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all() + ) provider_name_to_provider_model_settings_dict = defaultdict(list) for provider_model_setting in provider_model_settings: - (provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name] - .append(provider_model_setting)) + ( + provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append( + provider_model_setting + ) + ) return provider_name_to_provider_model_settings_dict @@ -445,27 +434,30 @@ class ProviderManager: model_load_balancing_enabled = FeatureService.get_features(tenant_id).model_load_balancing_enabled redis_client.setex(cache_key, 120, str(model_load_balancing_enabled)) else: - cache_result = cache_result.decode('utf-8') - model_load_balancing_enabled = cache_result == 'True' + cache_result = cache_result.decode("utf-8") + model_load_balancing_enabled = cache_result == "True" if not model_load_balancing_enabled: return {} - provider_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ - .filter( - LoadBalancingModelConfig.tenant_id == tenant_id - ).all() + provider_load_balancing_configs = ( + db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all() + ) provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) for provider_load_balancing_config in provider_load_balancing_configs: - (provider_name_to_provider_load_balancing_model_configs_dict[provider_load_balancing_config.provider_name] - .append(provider_load_balancing_config)) + ( + provider_name_to_provider_load_balancing_model_configs_dict[ + provider_load_balancing_config.provider_name + ].append(provider_load_balancing_config) + ) return provider_name_to_provider_load_balancing_model_configs_dict @staticmethod - def _init_trial_provider_records(tenant_id: str, - provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]: + def _init_trial_provider_records( + tenant_id: str, provider_name_to_provider_records_dict: dict[str, list] + ) -> dict[str, list]: """ Initialize trial provider records if not exists. @@ -489,8 +481,9 @@ class ProviderManager: if provider_record.provider_type != ProviderType.SYSTEM.value: continue - provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] \ - = provider_record + provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( + provider_record + ) for quota in configuration.quotas: if quota.quota_type == ProviderQuotaType.TRIAL: @@ -504,19 +497,22 @@ class ProviderManager: quota_type=ProviderQuotaType.TRIAL.value, quota_limit=quota.quota_limit, quota_used=0, - is_valid=True + is_valid=True, ) db.session.add(provider_record) db.session.commit() except IntegrityError: db.session.rollback() - provider_record = db.session.query(Provider) \ + provider_record = ( + db.session.query(Provider) .filter( - Provider.tenant_id == tenant_id, - Provider.provider_name == provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == ProviderQuotaType.TRIAL.value - ).first() + Provider.tenant_id == tenant_id, + Provider.provider_name == provider_name, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == ProviderQuotaType.TRIAL.value, + ) + .first() + ) if provider_record and not provider_record.is_valid: provider_record.is_valid = True @@ -526,11 +522,13 @@ class ProviderManager: return provider_name_to_provider_records_dict - def _to_custom_configuration(self, - tenant_id: str, - provider_entity: ProviderEntity, - provider_records: list[Provider], - provider_model_records: list[ProviderModel]) -> CustomConfiguration: + def _to_custom_configuration( + self, + tenant_id: str, + provider_entity: ProviderEntity, + provider_records: list[Provider], + provider_model_records: list[ProviderModel], + ) -> CustomConfiguration: """ Convert to custom configuration. @@ -543,7 +541,8 @@ class ProviderManager: # Get provider credential secret variables provider_credential_secret_variables = self._extract_secret_variables( provider_entity.provider_credential_schema.credential_form_schemas - if provider_entity.provider_credential_schema else [] + if provider_entity.provider_credential_schema + else [] ) # Get custom provider record @@ -563,7 +562,7 @@ class ProviderManager: provider_credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, identity_id=custom_provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER + cache_type=ProviderCredentialsCacheType.PROVIDER, ) # Get cached provider credentials @@ -572,11 +571,11 @@ class ProviderManager: if not cached_provider_credentials: try: # fix origin data - if (custom_provider_record.encrypted_config - and not custom_provider_record.encrypted_config.startswith("{")): - provider_credentials = { - "openai_api_key": custom_provider_record.encrypted_config - } + if ( + custom_provider_record.encrypted_config + and not custom_provider_record.encrypted_config.startswith("{") + ): + provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} else: provider_credentials = json.loads(custom_provider_record.encrypted_config) except JSONDecodeError: @@ -590,28 +589,23 @@ class ProviderManager: if variable in provider_credentials: try: provider_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_credentials.get(variable), - self.decoding_rsa_key, - self.decoding_cipher_rsa + provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa ) except ValueError: pass # cache provider credentials - provider_credentials_cache.set( - credentials=provider_credentials - ) + provider_credentials_cache.set(credentials=provider_credentials) else: provider_credentials = cached_provider_credentials - custom_provider_configuration = CustomProviderConfiguration( - credentials=provider_credentials - ) + custom_provider_configuration = CustomProviderConfiguration(credentials=provider_credentials) # Get provider model credential secret variables model_credential_secret_variables = self._extract_secret_variables( provider_entity.model_credential_schema.credential_form_schemas - if provider_entity.model_credential_schema else [] + if provider_entity.model_credential_schema + else [] ) # Get custom provider model credentials @@ -621,9 +615,7 @@ class ProviderManager: continue provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=tenant_id, - identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL + tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL ) # Get cached provider model credentials @@ -645,15 +637,13 @@ class ProviderManager: provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_model_credentials.get(variable), self.decoding_rsa_key, - self.decoding_cipher_rsa + self.decoding_cipher_rsa, ) except ValueError: pass # cache provider model credentials - provider_model_credentials_cache.set( - credentials=provider_model_credentials - ) + provider_model_credentials_cache.set(credentials=provider_model_credentials) else: provider_model_credentials = cached_provider_model_credentials @@ -661,19 +651,15 @@ class ProviderManager: CustomModelConfiguration( model=provider_model_record.model_name, model_type=ModelType.value_of(provider_model_record.model_type), - credentials=provider_model_credentials + credentials=provider_model_credentials, ) ) - return CustomConfiguration( - provider=custom_provider_configuration, - models=custom_model_configurations - ) + return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations) - def _to_system_configuration(self, - tenant_id: str, - provider_entity: ProviderEntity, - provider_records: list[Provider]) -> SystemConfiguration: + def _to_system_configuration( + self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider] + ) -> SystemConfiguration: """ Convert to system configuration. @@ -685,11 +671,11 @@ class ProviderManager: # Get hosting configuration hosting_configuration = ext_hosting_provider.hosting_configuration - if provider_entity.provider not in hosting_configuration.provider_map \ - or not hosting_configuration.provider_map.get(provider_entity.provider).enabled: - return SystemConfiguration( - enabled=False - ) + if ( + provider_entity.provider not in hosting_configuration.provider_map + or not hosting_configuration.provider_map.get(provider_entity.provider).enabled + ): + return SystemConfiguration(enabled=False) provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider) @@ -699,8 +685,9 @@ class ProviderManager: if provider_record.provider_type != ProviderType.SYSTEM.value: continue - quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] \ - = provider_record + quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( + provider_record + ) quota_configurations = [] for provider_quota in provider_hosting_configuration.quotas: @@ -712,7 +699,7 @@ class ProviderManager: quota_used=0, quota_limit=0, is_valid=False, - restrict_models=provider_quota.restrict_models + restrict_models=provider_quota.restrict_models, ) else: continue @@ -724,16 +711,15 @@ class ProviderManager: quota_unit=provider_hosting_configuration.quota_unit, quota_used=provider_record.quota_used, quota_limit=provider_record.quota_limit, - is_valid=provider_record.quota_limit > provider_record.quota_used or provider_record.quota_limit == -1, - restrict_models=provider_quota.restrict_models + is_valid=provider_record.quota_limit > provider_record.quota_used + or provider_record.quota_limit == -1, + restrict_models=provider_quota.restrict_models, ) quota_configurations.append(quota_configuration) if len(quota_configurations) == 0: - return SystemConfiguration( - enabled=False - ) + return SystemConfiguration(enabled=False) current_quota_type = self._choice_current_using_quota_type(quota_configurations) @@ -745,7 +731,7 @@ class ProviderManager: provider_credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, identity_id=provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER + cache_type=ProviderCredentialsCacheType.PROVIDER, ) # Get cached provider credentials @@ -760,7 +746,8 @@ class ProviderManager: # Get provider credential secret variables provider_credential_secret_variables = self._extract_secret_variables( provider_entity.provider_credential_schema.credential_form_schemas - if provider_entity.provider_credential_schema else [] + if provider_entity.provider_credential_schema + else [] ) # Get decoding rsa key and cipher for decrypting credentials @@ -771,9 +758,7 @@ class ProviderManager: if variable in provider_credentials: try: provider_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_credentials.get(variable), - self.decoding_rsa_key, - self.decoding_cipher_rsa + provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa ) except ValueError: pass @@ -781,9 +766,7 @@ class ProviderManager: current_using_credentials = provider_credentials # cache provider credentials - provider_credentials_cache.set( - credentials=current_using_credentials - ) + provider_credentials_cache.set(credentials=current_using_credentials) else: current_using_credentials = cached_provider_credentials else: @@ -794,7 +777,7 @@ class ProviderManager: enabled=True, current_quota_type=current_quota_type, quota_configurations=quota_configurations, - credentials=current_using_credentials + credentials=current_using_credentials, ) @staticmethod @@ -809,8 +792,7 @@ class ProviderManager: """ # convert to dict quota_type_to_quota_configuration_dict = { - quota_configuration.quota_type: quota_configuration - for quota_configuration in quota_configurations + quota_configuration.quota_type: quota_configuration for quota_configuration in quota_configurations } last_quota_configuration = None @@ -823,7 +805,7 @@ class ProviderManager: if last_quota_configuration: return last_quota_configuration.quota_type - raise ValueError('No quota type available') + raise ValueError("No quota type available") @staticmethod def _extract_secret_variables(credential_form_schemas: list[CredentialFormSchema]) -> list[str]: @@ -840,10 +822,12 @@ class ProviderManager: return secret_input_form_variables - def _to_model_settings(self, provider_entity: ProviderEntity, - provider_model_settings: Optional[list[ProviderModelSetting]] = None, - load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None) \ - -> list[ModelSettings]: + def _to_model_settings( + self, + provider_entity: ProviderEntity, + provider_model_settings: Optional[list[ProviderModelSetting]] = None, + load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None, + ) -> list[ModelSettings]: """ Convert to model settings. :param provider_entity: provider entity @@ -854,7 +838,8 @@ class ProviderManager: # Get provider model credential secret variables model_credential_secret_variables = self._extract_secret_variables( provider_entity.model_credential_schema.credential_form_schemas - if provider_entity.model_credential_schema else [] + if provider_entity.model_credential_schema + else [] ) model_settings = [] @@ -865,24 +850,28 @@ class ProviderManager: load_balancing_configs = [] if provider_model_setting.load_balancing_enabled and load_balancing_model_configs: for load_balancing_model_config in load_balancing_model_configs: - if (load_balancing_model_config.model_name == provider_model_setting.model_name - and load_balancing_model_config.model_type == provider_model_setting.model_type): + if ( + load_balancing_model_config.model_name == provider_model_setting.model_name + and load_balancing_model_config.model_type == provider_model_setting.model_type + ): if not load_balancing_model_config.enabled: continue if not load_balancing_model_config.encrypted_config: if load_balancing_model_config.name == "__inherit__": - load_balancing_configs.append(ModelLoadBalancingConfiguration( - id=load_balancing_model_config.id, - name=load_balancing_model_config.name, - credentials={} - )) + load_balancing_configs.append( + ModelLoadBalancingConfiguration( + id=load_balancing_model_config.id, + name=load_balancing_model_config.name, + credentials={}, + ) + ) continue provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=load_balancing_model_config.tenant_id, identity_id=load_balancing_model_config.id, - cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, ) # Get cached provider model credentials @@ -897,7 +886,8 @@ class ProviderManager: # Get decoding rsa key and cipher for decrypting credentials if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding( - load_balancing_model_config.tenant_id) + load_balancing_model_config.tenant_id + ) for variable in model_credential_secret_variables: if variable in provider_model_credentials: @@ -905,30 +895,30 @@ class ProviderManager: provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_model_credentials.get(variable), self.decoding_rsa_key, - self.decoding_cipher_rsa + self.decoding_cipher_rsa, ) except ValueError: pass # cache provider model credentials - provider_model_credentials_cache.set( - credentials=provider_model_credentials - ) + provider_model_credentials_cache.set(credentials=provider_model_credentials) else: provider_model_credentials = cached_provider_model_credentials - load_balancing_configs.append(ModelLoadBalancingConfiguration( - id=load_balancing_model_config.id, - name=load_balancing_model_config.name, - credentials=provider_model_credentials - )) + load_balancing_configs.append( + ModelLoadBalancingConfiguration( + id=load_balancing_model_config.id, + name=load_balancing_model_config.name, + credentials=provider_model_credentials, + ) + ) model_settings.append( ModelSettings( model=provider_model_setting.model_name, model_type=ModelType.value_of(provider_model_setting.model_type), enabled=provider_model_setting.enabled, - load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [] + load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [], ) ) diff --git a/api/core/rag/cleaner/clean_processor.py b/api/core/rag/cleaner/clean_processor.py index eaad0e0f4c..3c6ab2e4cf 100644 --- a/api/core/rag/cleaner/clean_processor.py +++ b/api/core/rag/cleaner/clean_processor.py @@ -2,37 +2,35 @@ import re class CleanProcessor: - @classmethod def clean(cls, text: str, process_rule: dict) -> str: # default clean # remove invalid symbol - text = re.sub(r'<\|', '<', text) - text = re.sub(r'\|>', '>', text) - text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text) + text = re.sub(r"<\|", "<", text) + text = re.sub(r"\|>", ">", text) + text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]", "", text) # Unicode U+FFFE - text = re.sub('\uFFFE', '', text) + text = re.sub("\ufffe", "", text) - rules = process_rule['rules'] if process_rule else None - if 'pre_processing_rules' in rules: + rules = process_rule["rules"] if process_rule else None + if "pre_processing_rules" in rules: pre_processing_rules = rules["pre_processing_rules"] for pre_processing_rule in pre_processing_rules: if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True: # Remove extra spaces - pattern = r'\n{3,}' - text = re.sub(pattern, '\n\n', text) - pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}' - text = re.sub(pattern, ' ', text) + pattern = r"\n{3,}" + text = re.sub(pattern, "\n\n", text) + pattern = r"[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}" + text = re.sub(pattern, " ", text) elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True: # Remove email - pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)' - text = re.sub(pattern, '', text) + pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)" + text = re.sub(pattern, "", text) # Remove URL - pattern = r'https?://[^\s]+' - text = re.sub(pattern, '', text) + pattern = r"https?://[^\s]+" + text = re.sub(pattern, "", text) return text def filter_string(self, text): - return text diff --git a/api/core/rag/cleaner/cleaner_base.py b/api/core/rag/cleaner/cleaner_base.py index 523bd904f2..d3bc2f765e 100644 --- a/api/core/rag/cleaner/cleaner_base.py +++ b/api/core/rag/cleaner/cleaner_base.py @@ -1,12 +1,11 @@ """Abstract interface for document cleaner implementations.""" + from abc import ABC, abstractmethod class BaseCleaner(ABC): - """Interface for clean chunk content. - """ + """Interface for clean chunk content.""" @abstractmethod def clean(self, content: str): raise NotImplementedError - diff --git a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py index 6a0b8c9046..167a919e69 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py @@ -1,9 +1,9 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: """clean document content.""" from unstructured.cleaners.core import clean_extra_whitespace diff --git a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py index 6fc3a408da..9c682d29db 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py @@ -1,9 +1,9 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner): - def clean(self, content) -> str: """clean document content.""" import re diff --git a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py index 87dc2d49fa..0cdbb171e1 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py @@ -1,9 +1,9 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: """clean document content.""" from unstructured.cleaners.core import clean_non_ascii_chars diff --git a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py index 974a28fef1..9f42044a2d 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py @@ -1,11 +1,12 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: """Replaces unicode quote characters, such as the \x91 character in a string.""" from unstructured.cleaners.core import replace_unicode_quotes + return replace_unicode_quotes(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py index dfaf3a2787..32ae7217e8 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py @@ -1,9 +1,9 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredTranslateTextCleaner(BaseCleaner): - def clean(self, content) -> str: """clean document content.""" from unstructured.cleaners.translate import translate_text diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index ad9ee4f7cf..b1d6f93cff 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -12,17 +12,27 @@ from core.rag.rerank.weight_rerank import WeightRerankRunner class DataPostProcessor: - """Interface for data post-processing document. - """ + """Interface for data post-processing document.""" - def __init__(self, tenant_id: str, reranking_mode: str, - reranking_model: Optional[dict] = None, weights: Optional[dict] = None, - reorder_enabled: bool = False): + def __init__( + self, + tenant_id: str, + reranking_mode: str, + reranking_model: Optional[dict] = None, + weights: Optional[dict] = None, + reorder_enabled: bool = False, + ): self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights) self.reorder_runner = self._get_reorder_runner(reorder_enabled) - def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: + def invoke( + self, + query: str, + documents: list[Document], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> list[Document]: if self.rerank_runner: documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user) @@ -31,21 +41,26 @@ class DataPostProcessor: return documents - def _get_rerank_runner(self, reranking_mode: str, tenant_id: str, reranking_model: Optional[dict] = None, - weights: Optional[dict] = None) -> Optional[RerankModelRunner | WeightRerankRunner]: + def _get_rerank_runner( + self, + reranking_mode: str, + tenant_id: str, + reranking_model: Optional[dict] = None, + weights: Optional[dict] = None, + ) -> Optional[RerankModelRunner | WeightRerankRunner]: if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights: return WeightRerankRunner( tenant_id, Weights( vector_setting=VectorSetting( - vector_weight=weights['vector_setting']['vector_weight'], - embedding_provider_name=weights['vector_setting']['embedding_provider_name'], - embedding_model_name=weights['vector_setting']['embedding_model_name'], + vector_weight=weights["vector_setting"]["vector_weight"], + embedding_provider_name=weights["vector_setting"]["embedding_provider_name"], + embedding_model_name=weights["vector_setting"]["embedding_model_name"], ), keyword_setting=KeywordSetting( - keyword_weight=weights['keyword_setting']['keyword_weight'], - ) - ) + keyword_weight=weights["keyword_setting"]["keyword_weight"], + ), + ), ) elif reranking_mode == RerankMode.RERANKING_MODEL.value: if reranking_model: @@ -53,9 +68,9 @@ class DataPostProcessor: model_manager = ModelManager() rerank_model_instance = model_manager.get_model_instance( tenant_id=tenant_id, - provider=reranking_model['reranking_provider_name'], + provider=reranking_model["reranking_provider_name"], model_type=ModelType.RERANK, - model=reranking_model['reranking_model_name'] + model=reranking_model["reranking_model_name"], ) except InvokeAuthorizationError: return None @@ -67,5 +82,3 @@ class DataPostProcessor: if reorder_enabled: return ReorderRunner() return None - - diff --git a/api/core/rag/data_post_processor/reorder.py b/api/core/rag/data_post_processor/reorder.py index 71297588a4..a9a0885241 100644 --- a/api/core/rag/data_post_processor/reorder.py +++ b/api/core/rag/data_post_processor/reorder.py @@ -2,7 +2,6 @@ from core.rag.models.document import Document class ReorderRunner: - def run(self, documents: list[Document]) -> list[Document]: # Retrieve elements from odd indices (0, 2, 4, etc.) of the documents list odd_elements = documents[::2] diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index a3714c2fd3..3073100746 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -24,37 +24,42 @@ class Jieba(BaseKeyword): self._config = KeywordTableConfig() def create(self, texts: list[Document], **kwargs) -> BaseKeyword: - lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() for text in texts: - keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) - self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) + keywords = keyword_table_handler.extract_keywords( + text.page_content, self._config.max_keywords_per_chunk + ) + self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) + keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords)) self._save_dataset_keyword_table(keyword_table) return self def add_texts(self, texts: list[Document], **kwargs): - lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() - keywords_list = kwargs.get('keywords_list', None) + keywords_list = kwargs.get("keywords_list", None) for i in range(len(texts)): text = texts[i] if keywords_list: keywords = keywords_list[i] if not keywords: - keywords = keyword_table_handler.extract_keywords(text.page_content, - self._config.max_keywords_per_chunk) + keywords = keyword_table_handler.extract_keywords( + text.page_content, self._config.max_keywords_per_chunk + ) else: - keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) - self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) + keywords = keyword_table_handler.extract_keywords( + text.page_content, self._config.max_keywords_per_chunk + ) + self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) + keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords)) self._save_dataset_keyword_table(keyword_table) @@ -63,97 +68,91 @@ class Jieba(BaseKeyword): return id in set.union(*keyword_table.values()) def delete_by_ids(self, ids: list[str]) -> None: - lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): keyword_table = self._get_dataset_keyword_table() keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) self._save_dataset_keyword_table(keyword_table) - def search( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search(self, query: str, **kwargs: Any) -> list[Document]: keyword_table = self._get_dataset_keyword_table() - k = kwargs.get('top_k', 4) + k = kwargs.get("top_k", 4) sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k) documents = [] for chunk_index in sorted_chunk_indices: - segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == self.dataset.id, - DocumentSegment.index_node_id == chunk_index - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index) + .first() + ) if segment: - - documents.append(Document( - page_content=segment.content, - metadata={ - "doc_id": chunk_index, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - } - )) + documents.append( + Document( + page_content=segment.content, + metadata={ + "doc_id": chunk_index, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + ) return documents def delete(self) -> None: - lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): dataset_keyword_table = self.dataset.dataset_keyword_table if dataset_keyword_table: db.session.delete(dataset_keyword_table) db.session.commit() - if dataset_keyword_table.data_source_type != 'database': - file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt' + if dataset_keyword_table.data_source_type != "database": + file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt" storage.delete(file_key) def _save_dataset_keyword_table(self, keyword_table): keyword_table_dict = { - '__type__': 'keyword_table', - '__data__': { - "index_id": self.dataset.id, - "summary": None, - "table": keyword_table - } + "__type__": "keyword_table", + "__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table}, } dataset_keyword_table = self.dataset.dataset_keyword_table keyword_data_source_type = dataset_keyword_table.data_source_type - if keyword_data_source_type == 'database': + if keyword_data_source_type == "database": dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder) db.session.commit() else: - file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt' + file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt" if storage.exists(file_key): storage.delete(file_key) - storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode('utf-8')) + storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode("utf-8")) def _get_dataset_keyword_table(self) -> Optional[dict]: dataset_keyword_table = self.dataset.dataset_keyword_table if dataset_keyword_table: keyword_table_dict = dataset_keyword_table.keyword_table_dict if keyword_table_dict: - return keyword_table_dict['__data__']['table'] + return keyword_table_dict["__data__"]["table"] else: keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE dataset_keyword_table = DatasetKeywordTable( dataset_id=self.dataset.id, - keyword_table='', + keyword_table="", data_source_type=keyword_data_source_type, ) - if keyword_data_source_type == 'database': - dataset_keyword_table.keyword_table = json.dumps({ - '__type__': 'keyword_table', - '__data__': { - "index_id": self.dataset.id, - "summary": None, - "table": {} - } - }, cls=SetEncoder) + if keyword_data_source_type == "database": + dataset_keyword_table.keyword_table = json.dumps( + { + "__type__": "keyword_table", + "__data__": {"index_id": self.dataset.id, "summary": None, "table": {}}, + }, + cls=SetEncoder, + ) db.session.add(dataset_keyword_table) db.session.commit() @@ -174,9 +173,7 @@ class Jieba(BaseKeyword): keywords_to_delete = set() for keyword, node_idxs in keyword_table.items(): if node_idxs_to_delete.intersection(node_idxs): - keyword_table[keyword] = node_idxs.difference( - node_idxs_to_delete - ) + keyword_table[keyword] = node_idxs.difference(node_idxs_to_delete) if not keyword_table[keyword]: keywords_to_delete.add(keyword) @@ -202,13 +199,14 @@ class Jieba(BaseKeyword): reverse=True, ) - return sorted_chunk_indices[: k] + return sorted_chunk_indices[:k] def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]): - document_segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == dataset_id, - DocumentSegment.index_node_id == node_id - ).first() + document_segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id) + .first() + ) if document_segment: document_segment.keywords = keywords db.session.add(document_segment) @@ -224,14 +222,14 @@ class Jieba(BaseKeyword): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() for pre_segment_data in pre_segment_data_list: - segment = pre_segment_data['segment'] - if pre_segment_data['keywords']: - segment.keywords = pre_segment_data['keywords'] - keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, - pre_segment_data['keywords']) + segment = pre_segment_data["segment"] + if pre_segment_data["keywords"]: + segment.keywords = pre_segment_data["keywords"] + keyword_table = self._add_text_to_keyword_table( + keyword_table, segment.index_node_id, pre_segment_data["keywords"] + ) else: - keywords = keyword_table_handler.extract_keywords(segment.content, - self._config.max_keywords_per_chunk) + keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk) segment.keywords = list(keywords) keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords)) self._save_dataset_keyword_table(keyword_table) diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index ad669ef515..4b1ade8e3f 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -8,7 +8,6 @@ from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS class JiebaKeywordTableHandler: - def __init__(self): default_tfidf.stop_words = STOPWORDS @@ -30,4 +29,4 @@ class JiebaKeywordTableHandler: if len(sub_tokens) > 1: results.update({w for w in sub_tokens if w not in list(STOPWORDS)}) - return results \ No newline at end of file + return results diff --git a/api/core/rag/datasource/keyword/jieba/stopwords.py b/api/core/rag/datasource/keyword/jieba/stopwords.py index c616a15cf0..9abe78d6ef 100644 --- a/api/core/rag/datasource/keyword/jieba/stopwords.py +++ b/api/core/rag/datasource/keyword/jieba/stopwords.py @@ -1,90 +1,1380 @@ STOPWORDS = { - "during", "when", "but", "then", "further", "isn", "mustn't", "until", "own", "i", "couldn", "y", "only", "you've", - "ours", "who", "where", "ourselves", "has", "to", "was", "didn't", "themselves", "if", "against", "through", "her", - "an", "your", "can", "those", "didn", "about", "aren't", "shan't", "be", "not", "these", "again", "so", "t", - "theirs", "weren", "won't", "won", "itself", "just", "same", "while", "why", "doesn", "aren", "him", "haven", - "for", "you'll", "that", "we", "am", "d", "by", "having", "wasn't", "than", "weren't", "out", "from", "now", - "their", "too", "hadn", "o", "needn", "most", "it", "under", "needn't", "any", "some", "few", "ll", "hers", "which", - "m", "you're", "off", "other", "had", "she", "you'd", "do", "you", "does", "s", "will", "each", "wouldn't", "hasn't", - "such", "more", "whom", "she's", "my", "yours", "yourself", "of", "on", "very", "hadn't", "with", "yourselves", - "been", "ma", "them", "mightn't", "shan", "mustn", "they", "what", "both", "that'll", "how", "is", "he", "because", - "down", "haven't", "are", "no", "it's", "our", "being", "the", "or", "above", "myself", "once", "don't", "doesn't", - "as", "nor", "here", "herself", "hasn", "mightn", "have", "its", "all", "were", "ain", "this", "at", "after", - "over", "shouldn't", "into", "before", "don", "wouldn", "re", "couldn't", "wasn", "in", "should", "there", - "himself", "isn't", "should've", "doing", "ve", "shouldn", "a", "did", "and", "his", "between", "me", "up", "below", - "人民", "末##末", "啊", "阿", "哎", "哎呀", "哎哟", "唉", "俺", "俺们", "按", "按照", "吧", "吧哒", "把", "罢了", "被", "本", - "本着", "比", "比方", "比如", "鄙人", "彼", "彼此", "边", "别", "别的", "别说", "并", "并且", "不比", "不成", "不单", "不但", - "不独", "不管", "不光", "不过", "不仅", "不拘", "不论", "不怕", "不然", "不如", "不特", "不惟", "不问", "不只", "朝", "朝着", - "趁", "趁着", "乘", "冲", "除", "除此之外", "除非", "除了", "此", "此间", "此外", "从", "从而", "打", "待", "但", "但是", "当", - "当着", "到", "得", "的", "的话", "等", "等等", "地", "第", "叮咚", "对", "对于", "多", "多少", "而", "而况", "而且", "而是", - "而外", "而言", "而已", "尔后", "反过来", "反过来说", "反之", "非但", "非徒", "否则", "嘎", "嘎登", "该", "赶", "个", "各", - "各个", "各位", "各种", "各自", "给", "根据", "跟", "故", "故此", "固然", "关于", "管", "归", "果然", "果真", "过", "哈", - "哈哈", "呵", "和", "何", "何处", "何况", "何时", "嘿", "哼", "哼唷", "呼哧", "乎", "哗", "还是", "还有", "换句话说", "换言之", - "或", "或是", "或者", "极了", "及", "及其", "及至", "即", "即便", "即或", "即令", "即若", "即使", "几", "几时", "己", "既", - "既然", "既是", "继而", "加之", "假如", "假若", "假使", "鉴于", "将", "较", "较之", "叫", "接着", "结果", "借", "紧接着", - "进而", "尽", "尽管", "经", "经过", "就", "就是", "就是说", "据", "具体地说", "具体说来", "开始", "开外", "靠", "咳", "可", - "可见", "可是", "可以", "况且", "啦", "来", "来着", "离", "例如", "哩", "连", "连同", "两者", "了", "临", "另", "另外", - "另一方面", "论", "嘛", "吗", "慢说", "漫说", "冒", "么", "每", "每当", "们", "莫若", "某", "某个", "某些", "拿", "哪", - "哪边", "哪儿", "哪个", "哪里", "哪年", "哪怕", "哪天", "哪些", "哪样", "那", "那边", "那儿", "那个", "那会儿", "那里", "那么", - "那么些", "那么样", "那时", "那些", "那样", "乃", "乃至", "呢", "能", "你", "你们", "您", "宁", "宁可", "宁肯", "宁愿", "哦", - "呕", "啪达", "旁人", "呸", "凭", "凭借", "其", "其次", "其二", "其他", "其它", "其一", "其余", "其中", "起", "起见", "岂但", - "恰恰相反", "前后", "前者", "且", "然而", "然后", "然则", "让", "人家", "任", "任何", "任凭", "如", "如此", "如果", "如何", - "如其", "如若", "如上所述", "若", "若非", "若是", "啥", "上下", "尚且", "设若", "设使", "甚而", "甚么", "甚至", "省得", "时候", - "什么", "什么样", "使得", "是", "是的", "首先", "谁", "谁知", "顺", "顺着", "似的", "虽", "虽然", "虽说", "虽则", "随", "随着", - "所", "所以", "他", "他们", "他人", "它", "它们", "她", "她们", "倘", "倘或", "倘然", "倘若", "倘使", "腾", "替", "通过", "同", - "同时", "哇", "万一", "往", "望", "为", "为何", "为了", "为什么", "为着", "喂", "嗡嗡", "我", "我们", "呜", "呜呼", "乌乎", - "无论", "无宁", "毋宁", "嘻", "吓", "相对而言", "像", "向", "向着", "嘘", "呀", "焉", "沿", "沿着", "要", "要不", "要不然", - "要不是", "要么", "要是", "也", "也罢", "也好", "一", "一般", "一旦", "一方面", "一来", "一切", "一样", "一则", "依", "依照", - "矣", "以", "以便", "以及", "以免", "以至", "以至于", "以致", "抑或", "因", "因此", "因而", "因为", "哟", "用", "由", - "由此可见", "由于", "有", "有的", "有关", "有些", "又", "于", "于是", "于是乎", "与", "与此同时", "与否", "与其", "越是", - "云云", "哉", "再说", "再者", "在", "在下", "咱", "咱们", "则", "怎", "怎么", "怎么办", "怎么样", "怎样", "咋", "照", "照着", - "者", "这", "这边", "这儿", "这个", "这会儿", "这就是说", "这里", "这么", "这么点儿", "这么些", "这么样", "这时", "这些", "这样", - "正如", "吱", "之", "之类", "之所以", "之一", "只是", "只限", "只要", "只有", "至", "至于", "诸位", "着", "着呢", "自", "自从", - "自个儿", "自各儿", "自己", "自家", "自身", "综上所述", "总的来看", "总的来说", "总的说来", "总而言之", "总之", "纵", "纵令", - "纵然", "纵使", "遵照", "作为", "兮", "呃", "呗", "咚", "咦", "喏", "啐", "喔唷", "嗬", "嗯", "嗳", "~", "!", ".", ":", - "\"", "'", "(", ")", "*", "A", "白", "社会主义", "--", "..", ">>", " [", " ]", "", "<", ">", "/", "\\", "|", "-", "_", - "+", "=", "&", "^", "%", "#", "@", "`", ";", "$", "(", ")", "——", "—", "¥", "·", "...", "‘", "’", "〉", "〈", "…", - " ", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "二", - "三", "四", "五", "六", "七", "八", "九", "零", ">", "<", "@", "#", "$", "%", "︿", "&", "*", "+", "~", "|", "[", - "]", "{", "}", "啊哈", "啊呀", "啊哟", "挨次", "挨个", "挨家挨户", "挨门挨户", "挨门逐户", "挨着", "按理", "按期", "按时", - "按说", "暗地里", "暗中", "暗自", "昂然", "八成", "白白", "半", "梆", "保管", "保险", "饱", "背地里", "背靠背", "倍感", "倍加", - "本人", "本身", "甭", "比起", "比如说", "比照", "毕竟", "必", "必定", "必将", "必须", "便", "别人", "并非", "并肩", "并没", - "并没有", "并排", "并无", "勃然", "不", "不必", "不常", "不大", "不但...而且", "不得", "不得不", "不得了", "不得已", "不迭", - "不定", "不对", "不妨", "不管怎样", "不会", "不仅...而且", "不仅仅", "不仅仅是", "不经意", "不可开交", "不可抗拒", "不力", "不了", - "不料", "不满", "不免", "不能不", "不起", "不巧", "不然的话", "不日", "不少", "不胜", "不时", "不是", "不同", "不能", "不要", - "不外", "不外乎", "不下", "不限", "不消", "不已", "不亦乐乎", "不由得", "不再", "不择手段", "不怎么", "不曾", "不知不觉", "不止", - "不止一次", "不至于", "才", "才能", "策略地", "差不多", "差一点", "常", "常常", "常言道", "常言说", "常言说得好", "长此下去", - "长话短说", "长期以来", "长线", "敞开儿", "彻夜", "陈年", "趁便", "趁机", "趁热", "趁势", "趁早", "成年", "成年累月", "成心", - "乘机", "乘胜", "乘势", "乘隙", "乘虚", "诚然", "迟早", "充分", "充其极", "充其量", "抽冷子", "臭", "初", "出", "出来", "出去", - "除此", "除此而外", "除此以外", "除开", "除去", "除却", "除外", "处处", "川流不息", "传", "传说", "传闻", "串行", "纯", "纯粹", - "此后", "此中", "次第", "匆匆", "从不", "从此", "从此以后", "从古到今", "从古至今", "从今以后", "从宽", "从来", "从轻", "从速", - "从头", "从未", "从无到有", "从小", "从新", "从严", "从优", "从早到晚", "从中", "从重", "凑巧", "粗", "存心", "达旦", "打从", - "打开天窗说亮话", "大", "大不了", "大大", "大抵", "大都", "大多", "大凡", "大概", "大家", "大举", "大略", "大面儿上", "大事", - "大体", "大体上", "大约", "大张旗鼓", "大致", "呆呆地", "带", "殆", "待到", "单", "单纯", "单单", "但愿", "弹指之间", "当场", - "当儿", "当即", "当口儿", "当然", "当庭", "当头", "当下", "当真", "当中", "倒不如", "倒不如说", "倒是", "到处", "到底", "到了儿", - "到目前为止", "到头", "到头来", "得起", "得天独厚", "的确", "等到", "叮当", "顶多", "定", "动不动", "动辄", "陡然", "都", "独", - "独自", "断然", "顿时", "多次", "多多", "多多少少", "多多益善", "多亏", "多年来", "多年前", "而后", "而论", "而又", "尔等", - "二话不说", "二话没说", "反倒", "反倒是", "反而", "反手", "反之亦然", "反之则", "方", "方才", "方能", "放量", "非常", "非得", - "分期", "分期分批", "分头", "奋勇", "愤然", "风雨无阻", "逢", "弗", "甫", "嘎嘎", "该当", "概", "赶快", "赶早不赶晚", "敢", - "敢情", "敢于", "刚", "刚才", "刚好", "刚巧", "高低", "格外", "隔日", "隔夜", "个人", "各式", "更", "更加", "更进一步", "更为", - "公然", "共", "共总", "够瞧的", "姑且", "古来", "故而", "故意", "固", "怪", "怪不得", "惯常", "光", "光是", "归根到底", - "归根结底", "过于", "毫不", "毫无", "毫无保留地", "毫无例外", "好在", "何必", "何尝", "何妨", "何苦", "何乐而不为", "何须", - "何止", "很", "很多", "很少", "轰然", "后来", "呼啦", "忽地", "忽然", "互", "互相", "哗啦", "话说", "还", "恍然", "会", "豁然", - "活", "伙同", "或多或少", "或许", "基本", "基本上", "基于", "极", "极大", "极度", "极端", "极力", "极其", "极为", "急匆匆", - "即将", "即刻", "即是说", "几度", "几番", "几乎", "几经", "既...又", "继之", "加上", "加以", "间或", "简而言之", "简言之", - "简直", "见", "将才", "将近", "将要", "交口", "较比", "较为", "接连不断", "接下来", "皆可", "截然", "截至", "藉以", "借此", - "借以", "届时", "仅", "仅仅", "谨", "进来", "进去", "近", "近几年来", "近来", "近年来", "尽管如此", "尽可能", "尽快", "尽量", - "尽然", "尽如人意", "尽心竭力", "尽心尽力", "尽早", "精光", "经常", "竟", "竟然", "究竟", "就此", "就地", "就算", "居然", "局外", - "举凡", "据称", "据此", "据实", "据说", "据我所知", "据悉", "具体来说", "决不", "决非", "绝", "绝不", "绝顶", "绝对", "绝非", - "均", "喀", "看", "看来", "看起来", "看上去", "看样子", "可好", "可能", "恐怕", "快", "快要", "来不及", "来得及", "来讲", - "来看", "拦腰", "牢牢", "老", "老大", "老老实实", "老是", "累次", "累年", "理当", "理该", "理应", "历", "立", "立地", "立刻", - "立马", "立时", "联袂", "连连", "连日", "连日来", "连声", "连袂", "临到", "另方面", "另行", "另一个", "路经", "屡", "屡次", - "屡次三番", "屡屡", "缕缕", "率尔", "率然", "略", "略加", "略微", "略为", "论说", "马上", "蛮", "满", "没", "没有", "每逢", - "每每", "每时每刻", "猛然", "猛然间", "莫", "莫不", "莫非", "莫如", "默默地", "默然", "呐", "那末", "奈", "难道", "难得", "难怪", - "难说", "内", "年复一年", "凝神", "偶而", "偶尔", "怕", "砰", "碰巧", "譬如", "偏偏", "乒", "平素", "颇", "迫于", "扑通", - "其后", "其实", "奇", "齐", "起初", "起来", "起首", "起头", "起先", "岂", "岂非", "岂止", "迄", "恰逢", "恰好", "恰恰", "恰巧", - "恰如", "恰似", "千", "千万", "千万千万", "切", "切不可", "切莫", "切切", "切勿", "窃", "亲口", "亲身", "亲手", "亲眼", "亲自", - "顷", "顷刻", "顷刻间", "顷刻之间", "请勿", "穷年累月", "取道", "去", "权时", "全都", "全力", "全年", "全然", "全身心", "然", - "人人", "仍", "仍旧", "仍然", "日复一日", "日见", "日渐", "日益", "日臻", "如常", "如此等等", "如次", "如今", "如期", "如前所述", - "如上", "如下", "汝", "三番两次", "三番五次", "三天两头", "瑟瑟", "沙沙", "上", "上来", "上去", "一个", "月", "日", "\n" + "during", + "when", + "but", + "then", + "further", + "isn", + "mustn't", + "until", + "own", + "i", + "couldn", + "y", + "only", + "you've", + "ours", + "who", + "where", + "ourselves", + "has", + "to", + "was", + "didn't", + "themselves", + "if", + "against", + "through", + "her", + "an", + "your", + "can", + "those", + "didn", + "about", + "aren't", + "shan't", + "be", + "not", + "these", + "again", + "so", + "t", + "theirs", + "weren", + "won't", + "won", + "itself", + "just", + "same", + "while", + "why", + "doesn", + "aren", + "him", + "haven", + "for", + "you'll", + "that", + "we", + "am", + "d", + "by", + "having", + "wasn't", + "than", + "weren't", + "out", + "from", + "now", + "their", + "too", + "hadn", + "o", + "needn", + "most", + "it", + "under", + "needn't", + "any", + "some", + "few", + "ll", + "hers", + "which", + "m", + "you're", + "off", + "other", + "had", + "she", + "you'd", + "do", + "you", + "does", + "s", + "will", + "each", + "wouldn't", + "hasn't", + "such", + "more", + "whom", + "she's", + "my", + "yours", + "yourself", + "of", + "on", + "very", + "hadn't", + "with", + "yourselves", + "been", + "ma", + "them", + "mightn't", + "shan", + "mustn", + "they", + "what", + "both", + "that'll", + "how", + "is", + "he", + "because", + "down", + "haven't", + "are", + "no", + "it's", + "our", + "being", + "the", + "or", + "above", + "myself", + "once", + "don't", + "doesn't", + "as", + "nor", + "here", + "herself", + "hasn", + "mightn", + "have", + "its", + "all", + "were", + "ain", + "this", + "at", + "after", + "over", + "shouldn't", + "into", + "before", + "don", + "wouldn", + "re", + "couldn't", + "wasn", + "in", + "should", + "there", + "himself", + "isn't", + "should've", + "doing", + "ve", + "shouldn", + "a", + "did", + "and", + "his", + "between", + "me", + "up", + "below", + "人民", + "末##末", + "啊", + "阿", + "哎", + "哎呀", + "哎哟", + "唉", + "俺", + "俺们", + "按", + "按照", + "吧", + "吧哒", + "把", + "罢了", + "被", + "本", + "本着", + "比", + "比方", + "比如", + "鄙人", + "彼", + "彼此", + "边", + "别", + "别的", + "别说", + "并", + "并且", + "不比", + "不成", + "不单", + "不但", + "不独", + "不管", + "不光", + "不过", + "不仅", + "不拘", + "不论", + "不怕", + "不然", + "不如", + "不特", + "不惟", + "不问", + "不只", + "朝", + "朝着", + "趁", + "趁着", + "乘", + "冲", + "除", + "除此之外", + "除非", + "除了", + "此", + "此间", + "此外", + "从", + "从而", + "打", + "待", + "但", + "但是", + "当", + "当着", + "到", + "得", + "的", + "的话", + "等", + "等等", + "地", + "第", + "叮咚", + "对", + "对于", + "多", + "多少", + "而", + "而况", + "而且", + "而是", + "而外", + "而言", + "而已", + "尔后", + "反过来", + "反过来说", + "反之", + "非但", + "非徒", + "否则", + "嘎", + "嘎登", + "该", + "赶", + "个", + "各", + "各个", + "各位", + "各种", + "各自", + "给", + "根据", + "跟", + "故", + "故此", + "固然", + "关于", + "管", + "归", + "果然", + "果真", + "过", + "哈", + "哈哈", + "呵", + "和", + "何", + "何处", + "何况", + "何时", + "嘿", + "哼", + "哼唷", + "呼哧", + "乎", + "哗", + "还是", + "还有", + "换句话说", + "换言之", + "或", + "或是", + "或者", + "极了", + "及", + "及其", + "及至", + "即", + "即便", + "即或", + "即令", + "即若", + "即使", + "几", + "几时", + "己", + "既", + "既然", + "既是", + "继而", + "加之", + "假如", + "假若", + "假使", + "鉴于", + "将", + "较", + "较之", + "叫", + "接着", + "结果", + "借", + "紧接着", + "进而", + "尽", + "尽管", + "经", + "经过", + "就", + "就是", + "就是说", + "据", + "具体地说", + "具体说来", + "开始", + "开外", + "靠", + "咳", + "可", + "可见", + "可是", + "可以", + "况且", + "啦", + "来", + "来着", + "离", + "例如", + "哩", + "连", + "连同", + "两者", + "了", + "临", + "另", + "另外", + "另一方面", + "论", + "嘛", + "吗", + "慢说", + "漫说", + "冒", + "么", + "每", + "每当", + "们", + "莫若", + "某", + "某个", + "某些", + "拿", + "哪", + "哪边", + "哪儿", + "哪个", + "哪里", + "哪年", + "哪怕", + "哪天", + "哪些", + "哪样", + "那", + "那边", + "那儿", + "那个", + "那会儿", + "那里", + "那么", + "那么些", + "那么样", + "那时", + "那些", + "那样", + "乃", + "乃至", + "呢", + "能", + "你", + "你们", + "您", + "宁", + "宁可", + "宁肯", + "宁愿", + "哦", + "呕", + "啪达", + "旁人", + "呸", + "凭", + "凭借", + "其", + "其次", + "其二", + "其他", + "其它", + "其一", + "其余", + "其中", + "起", + "起见", + "岂但", + "恰恰相反", + "前后", + "前者", + "且", + "然而", + "然后", + "然则", + "让", + "人家", + "任", + "任何", + "任凭", + "如", + "如此", + "如果", + "如何", + "如其", + "如若", + "如上所述", + "若", + "若非", + "若是", + "啥", + "上下", + "尚且", + "设若", + "设使", + "甚而", + "甚么", + "甚至", + "省得", + "时候", + "什么", + "什么样", + "使得", + "是", + "是的", + "首先", + "谁", + "谁知", + "顺", + "顺着", + "似的", + "虽", + "虽然", + "虽说", + "虽则", + "随", + "随着", + "所", + "所以", + "他", + "他们", + "他人", + "它", + "它们", + "她", + "她们", + "倘", + "倘或", + "倘然", + "倘若", + "倘使", + "腾", + "替", + "通过", + "同", + "同时", + "哇", + "万一", + "往", + "望", + "为", + "为何", + "为了", + "为什么", + "为着", + "喂", + "嗡嗡", + "我", + "我们", + "呜", + "呜呼", + "乌乎", + "无论", + "无宁", + "毋宁", + "嘻", + "吓", + "相对而言", + "像", + "向", + "向着", + "嘘", + "呀", + "焉", + "沿", + "沿着", + "要", + "要不", + "要不然", + "要不是", + "要么", + "要是", + "也", + "也罢", + "也好", + "一", + "一般", + "一旦", + "一方面", + "一来", + "一切", + "一样", + "一则", + "依", + "依照", + "矣", + "以", + "以便", + "以及", + "以免", + "以至", + "以至于", + "以致", + "抑或", + "因", + "因此", + "因而", + "因为", + "哟", + "用", + "由", + "由此可见", + "由于", + "有", + "有的", + "有关", + "有些", + "又", + "于", + "于是", + "于是乎", + "与", + "与此同时", + "与否", + "与其", + "越是", + "云云", + "哉", + "再说", + "再者", + "在", + "在下", + "咱", + "咱们", + "则", + "怎", + "怎么", + "怎么办", + "怎么样", + "怎样", + "咋", + "照", + "照着", + "者", + "这", + "这边", + "这儿", + "这个", + "这会儿", + "这就是说", + "这里", + "这么", + "这么点儿", + "这么些", + "这么样", + "这时", + "这些", + "这样", + "正如", + "吱", + "之", + "之类", + "之所以", + "之一", + "只是", + "只限", + "只要", + "只有", + "至", + "至于", + "诸位", + "着", + "着呢", + "自", + "自从", + "自个儿", + "自各儿", + "自己", + "自家", + "自身", + "综上所述", + "总的来看", + "总的来说", + "总的说来", + "总而言之", + "总之", + "纵", + "纵令", + "纵然", + "纵使", + "遵照", + "作为", + "兮", + "呃", + "呗", + "咚", + "咦", + "喏", + "啐", + "喔唷", + "嗬", + "嗯", + "嗳", + "~", + "!", + ".", + ":", + '"', + "'", + "(", + ")", + "*", + "A", + "白", + "社会主义", + "--", + "..", + ">>", + " [", + " ]", + "", + "<", + ">", + "/", + "\\", + "|", + "-", + "_", + "+", + "=", + "&", + "^", + "%", + "#", + "@", + "`", + ";", + "$", + "(", + ")", + "——", + "—", + "¥", + "·", + "...", + "‘", + "’", + "〉", + "〈", + "…", + " ", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "二", + "三", + "四", + "五", + "六", + "七", + "八", + "九", + "零", + ">", + "<", + "@", + "#", + "$", + "%", + "︿", + "&", + "*", + "+", + "~", + "|", + "[", + "]", + "{", + "}", + "啊哈", + "啊呀", + "啊哟", + "挨次", + "挨个", + "挨家挨户", + "挨门挨户", + "挨门逐户", + "挨着", + "按理", + "按期", + "按时", + "按说", + "暗地里", + "暗中", + "暗自", + "昂然", + "八成", + "白白", + "半", + "梆", + "保管", + "保险", + "饱", + "背地里", + "背靠背", + "倍感", + "倍加", + "本人", + "本身", + "甭", + "比起", + "比如说", + "比照", + "毕竟", + "必", + "必定", + "必将", + "必须", + "便", + "别人", + "并非", + "并肩", + "并没", + "并没有", + "并排", + "并无", + "勃然", + "不", + "不必", + "不常", + "不大", + "不但...而且", + "不得", + "不得不", + "不得了", + "不得已", + "不迭", + "不定", + "不对", + "不妨", + "不管怎样", + "不会", + "不仅...而且", + "不仅仅", + "不仅仅是", + "不经意", + "不可开交", + "不可抗拒", + "不力", + "不了", + "不料", + "不满", + "不免", + "不能不", + "不起", + "不巧", + "不然的话", + "不日", + "不少", + "不胜", + "不时", + "不是", + "不同", + "不能", + "不要", + "不外", + "不外乎", + "不下", + "不限", + "不消", + "不已", + "不亦乐乎", + "不由得", + "不再", + "不择手段", + "不怎么", + "不曾", + "不知不觉", + "不止", + "不止一次", + "不至于", + "才", + "才能", + "策略地", + "差不多", + "差一点", + "常", + "常常", + "常言道", + "常言说", + "常言说得好", + "长此下去", + "长话短说", + "长期以来", + "长线", + "敞开儿", + "彻夜", + "陈年", + "趁便", + "趁机", + "趁热", + "趁势", + "趁早", + "成年", + "成年累月", + "成心", + "乘机", + "乘胜", + "乘势", + "乘隙", + "乘虚", + "诚然", + "迟早", + "充分", + "充其极", + "充其量", + "抽冷子", + "臭", + "初", + "出", + "出来", + "出去", + "除此", + "除此而外", + "除此以外", + "除开", + "除去", + "除却", + "除外", + "处处", + "川流不息", + "传", + "传说", + "传闻", + "串行", + "纯", + "纯粹", + "此后", + "此中", + "次第", + "匆匆", + "从不", + "从此", + "从此以后", + "从古到今", + "从古至今", + "从今以后", + "从宽", + "从来", + "从轻", + "从速", + "从头", + "从未", + "从无到有", + "从小", + "从新", + "从严", + "从优", + "从早到晚", + "从中", + "从重", + "凑巧", + "粗", + "存心", + "达旦", + "打从", + "打开天窗说亮话", + "大", + "大不了", + "大大", + "大抵", + "大都", + "大多", + "大凡", + "大概", + "大家", + "大举", + "大略", + "大面儿上", + "大事", + "大体", + "大体上", + "大约", + "大张旗鼓", + "大致", + "呆呆地", + "带", + "殆", + "待到", + "单", + "单纯", + "单单", + "但愿", + "弹指之间", + "当场", + "当儿", + "当即", + "当口儿", + "当然", + "当庭", + "当头", + "当下", + "当真", + "当中", + "倒不如", + "倒不如说", + "倒是", + "到处", + "到底", + "到了儿", + "到目前为止", + "到头", + "到头来", + "得起", + "得天独厚", + "的确", + "等到", + "叮当", + "顶多", + "定", + "动不动", + "动辄", + "陡然", + "都", + "独", + "独自", + "断然", + "顿时", + "多次", + "多多", + "多多少少", + "多多益善", + "多亏", + "多年来", + "多年前", + "而后", + "而论", + "而又", + "尔等", + "二话不说", + "二话没说", + "反倒", + "反倒是", + "反而", + "反手", + "反之亦然", + "反之则", + "方", + "方才", + "方能", + "放量", + "非常", + "非得", + "分期", + "分期分批", + "分头", + "奋勇", + "愤然", + "风雨无阻", + "逢", + "弗", + "甫", + "嘎嘎", + "该当", + "概", + "赶快", + "赶早不赶晚", + "敢", + "敢情", + "敢于", + "刚", + "刚才", + "刚好", + "刚巧", + "高低", + "格外", + "隔日", + "隔夜", + "个人", + "各式", + "更", + "更加", + "更进一步", + "更为", + "公然", + "共", + "共总", + "够瞧的", + "姑且", + "古来", + "故而", + "故意", + "固", + "怪", + "怪不得", + "惯常", + "光", + "光是", + "归根到底", + "归根结底", + "过于", + "毫不", + "毫无", + "毫无保留地", + "毫无例外", + "好在", + "何必", + "何尝", + "何妨", + "何苦", + "何乐而不为", + "何须", + "何止", + "很", + "很多", + "很少", + "轰然", + "后来", + "呼啦", + "忽地", + "忽然", + "互", + "互相", + "哗啦", + "话说", + "还", + "恍然", + "会", + "豁然", + "活", + "伙同", + "或多或少", + "或许", + "基本", + "基本上", + "基于", + "极", + "极大", + "极度", + "极端", + "极力", + "极其", + "极为", + "急匆匆", + "即将", + "即刻", + "即是说", + "几度", + "几番", + "几乎", + "几经", + "既...又", + "继之", + "加上", + "加以", + "间或", + "简而言之", + "简言之", + "简直", + "见", + "将才", + "将近", + "将要", + "交口", + "较比", + "较为", + "接连不断", + "接下来", + "皆可", + "截然", + "截至", + "藉以", + "借此", + "借以", + "届时", + "仅", + "仅仅", + "谨", + "进来", + "进去", + "近", + "近几年来", + "近来", + "近年来", + "尽管如此", + "尽可能", + "尽快", + "尽量", + "尽然", + "尽如人意", + "尽心竭力", + "尽心尽力", + "尽早", + "精光", + "经常", + "竟", + "竟然", + "究竟", + "就此", + "就地", + "就算", + "居然", + "局外", + "举凡", + "据称", + "据此", + "据实", + "据说", + "据我所知", + "据悉", + "具体来说", + "决不", + "决非", + "绝", + "绝不", + "绝顶", + "绝对", + "绝非", + "均", + "喀", + "看", + "看来", + "看起来", + "看上去", + "看样子", + "可好", + "可能", + "恐怕", + "快", + "快要", + "来不及", + "来得及", + "来讲", + "来看", + "拦腰", + "牢牢", + "老", + "老大", + "老老实实", + "老是", + "累次", + "累年", + "理当", + "理该", + "理应", + "历", + "立", + "立地", + "立刻", + "立马", + "立时", + "联袂", + "连连", + "连日", + "连日来", + "连声", + "连袂", + "临到", + "另方面", + "另行", + "另一个", + "路经", + "屡", + "屡次", + "屡次三番", + "屡屡", + "缕缕", + "率尔", + "率然", + "略", + "略加", + "略微", + "略为", + "论说", + "马上", + "蛮", + "满", + "没", + "没有", + "每逢", + "每每", + "每时每刻", + "猛然", + "猛然间", + "莫", + "莫不", + "莫非", + "莫如", + "默默地", + "默然", + "呐", + "那末", + "奈", + "难道", + "难得", + "难怪", + "难说", + "内", + "年复一年", + "凝神", + "偶而", + "偶尔", + "怕", + "砰", + "碰巧", + "譬如", + "偏偏", + "乒", + "平素", + "颇", + "迫于", + "扑通", + "其后", + "其实", + "奇", + "齐", + "起初", + "起来", + "起首", + "起头", + "起先", + "岂", + "岂非", + "岂止", + "迄", + "恰逢", + "恰好", + "恰恰", + "恰巧", + "恰如", + "恰似", + "千", + "千万", + "千万千万", + "切", + "切不可", + "切莫", + "切切", + "切勿", + "窃", + "亲口", + "亲身", + "亲手", + "亲眼", + "亲自", + "顷", + "顷刻", + "顷刻间", + "顷刻之间", + "请勿", + "穷年累月", + "取道", + "去", + "权时", + "全都", + "全力", + "全年", + "全然", + "全身心", + "然", + "人人", + "仍", + "仍旧", + "仍然", + "日复一日", + "日见", + "日渐", + "日益", + "日臻", + "如常", + "如此等等", + "如次", + "如今", + "如期", + "如前所述", + "如上", + "如下", + "汝", + "三番两次", + "三番五次", + "三天两头", + "瑟瑟", + "沙沙", + "上", + "上来", + "上去", + "一个", + "月", + "日", + "\n", } diff --git a/api/core/rag/datasource/keyword/keyword_base.py b/api/core/rag/datasource/keyword/keyword_base.py index b77c6562b2..27e4f383ad 100644 --- a/api/core/rag/datasource/keyword/keyword_base.py +++ b/api/core/rag/datasource/keyword/keyword_base.py @@ -8,7 +8,6 @@ from models.dataset import Dataset class BaseKeyword(ABC): - def __init__(self, dataset: Dataset): self.dataset = dataset @@ -31,15 +30,12 @@ class BaseKeyword(ABC): def delete(self) -> None: raise NotImplementedError - def search( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search(self, query: str, **kwargs: Any) -> list[Document]: raise NotImplementedError def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts[:]: - doc_id = text.metadata['doc_id'] + doc_id = text.metadata["doc_id"] exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: texts.remove(text) @@ -47,4 +43,4 @@ class BaseKeyword(ABC): return texts def _get_uuids(self, texts: list[Document]) -> list[str]: - return [text.metadata['doc_id'] for text in texts] + return [text.metadata["doc_id"] for text in texts] diff --git a/api/core/rag/datasource/keyword/keyword_factory.py b/api/core/rag/datasource/keyword/keyword_factory.py index 6ac610f82b..3c99f33be6 100644 --- a/api/core/rag/datasource/keyword/keyword_factory.py +++ b/api/core/rag/datasource/keyword/keyword_factory.py @@ -20,9 +20,7 @@ class Keyword: raise ValueError("Keyword store must be specified.") if keyword_type == "jieba": - return Jieba( - dataset=self._dataset - ) + return Jieba(dataset=self._dataset) else: raise ValueError(f"Keyword store {keyword_type} is not supported.") @@ -41,10 +39,7 @@ class Keyword: def delete(self) -> None: self._keyword_processor.delete() - def search( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search(self, query: str, **kwargs: Any) -> list[Document]: return self._keyword_processor.search(query, **kwargs) def __getattr__(self, name): diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 0dac9bfae6..afac1bf300 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -12,73 +12,83 @@ from extensions.ext_database import db from models.dataset import Dataset default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } class RetrievalService: - @classmethod - def retrieve(cls, retrieval_method: str, dataset_id: str, query: str, - top_k: int, score_threshold: Optional[float] = .0, - reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = 'reranking_model', - weights: Optional[dict] = None): - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + def retrieve( + cls, + retrieval_method: str, + dataset_id: str, + query: str, + top_k: int, + score_threshold: Optional[float] = 0.0, + reranking_model: Optional[dict] = None, + reranking_mode: Optional[str] = "reranking_model", + weights: Optional[dict] = None, + ): + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: return [] all_documents = [] threads = [] exceptions = [] # retrieval_model source with keyword - if retrieval_method == 'keyword_search': - keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'top_k': top_k, - 'all_documents': all_documents, - 'exceptions': exceptions, - }) + if retrieval_method == "keyword_search": + keyword_thread = threading.Thread( + target=RetrievalService.keyword_search, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "top_k": top_k, + "all_documents": all_documents, + "exceptions": exceptions, + }, + ) threads.append(keyword_thread) keyword_thread.start() # retrieval_model source with semantic if RetrievalMethod.is_support_semantic_search(retrieval_method): - embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'top_k': top_k, - 'score_threshold': score_threshold, - 'reranking_model': reranking_model, - 'all_documents': all_documents, - 'retrieval_method': retrieval_method, - 'exceptions': exceptions, - }) + embedding_thread = threading.Thread( + target=RetrievalService.embedding_search, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "top_k": top_k, + "score_threshold": score_threshold, + "reranking_model": reranking_model, + "all_documents": all_documents, + "retrieval_method": retrieval_method, + "exceptions": exceptions, + }, + ) threads.append(embedding_thread) embedding_thread.start() # retrieval source with full text if RetrievalMethod.is_support_fulltext_search(retrieval_method): - full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'retrieval_method': retrieval_method, - 'score_threshold': score_threshold, - 'top_k': top_k, - 'reranking_model': reranking_model, - 'all_documents': all_documents, - 'exceptions': exceptions, - }) + full_text_index_thread = threading.Thread( + target=RetrievalService.full_text_index_search, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "retrieval_method": retrieval_method, + "score_threshold": score_threshold, + "top_k": top_k, + "reranking_model": reranking_model, + "all_documents": all_documents, + "exceptions": exceptions, + }, + ) threads.append(full_text_index_thread) full_text_index_thread.start() @@ -86,110 +96,117 @@ class RetrievalService: thread.join() if exceptions: - exception_message = ';\n'.join(exceptions) + exception_message = ";\n".join(exceptions) raise Exception(exception_message) if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: - data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode, - reranking_model, weights, False) + data_post_processor = DataPostProcessor( + str(dataset.tenant_id), reranking_mode, reranking_model, weights, False + ) all_documents = data_post_processor.invoke( - query=query, - documents=all_documents, - score_threshold=score_threshold, - top_n=top_k + query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k ) return all_documents @classmethod - def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str, - top_k: int, all_documents: list, exceptions: list): + def keyword_search( + cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list + ): with flask_app.app_context(): try: - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() - keyword = Keyword( - dataset=dataset - ) + keyword = Keyword(dataset=dataset) - documents = keyword.search( - cls.escape_query_for_search(query), - top_k=top_k - ) + documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k) all_documents.extend(documents) except Exception as e: exceptions.append(str(e)) @classmethod - def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str, - top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], - all_documents: list, retrieval_method: str, exceptions: list): + def embedding_search( + cls, + flask_app: Flask, + dataset_id: str, + query: str, + top_k: int, + score_threshold: Optional[float], + reranking_model: Optional[dict], + all_documents: list, + retrieval_method: str, + exceptions: list, + ): with flask_app.app_context(): try: - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() - vector = Vector( - dataset=dataset - ) + vector = Vector(dataset=dataset) documents = vector.search_by_vector( cls.escape_query_for_search(query), - search_type='similarity_score_threshold', + search_type="similarity_score_threshold", top_k=top_k, score_threshold=score_threshold, - filter={ - 'group_id': [dataset.id] - } + filter={"group_id": [dataset.id]}, ) if documents: - if reranking_model and reranking_model.get('reranking_model_name') and reranking_model.get('reranking_provider_name') and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value: - data_post_processor = DataPostProcessor(str(dataset.tenant_id), - RerankMode.RERANKING_MODEL.value, - reranking_model, None, False) - all_documents.extend(data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents) - )) + if ( + reranking_model + and reranking_model.get("reranking_model_name") + and reranking_model.get("reranking_provider_name") + and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value + ): + data_post_processor = DataPostProcessor( + str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False + ) + all_documents.extend( + data_post_processor.invoke( + query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents) + ) + ) else: all_documents.extend(documents) except Exception as e: exceptions.append(str(e)) @classmethod - def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, - top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], - all_documents: list, retrieval_method: str, exceptions: list): + def full_text_index_search( + cls, + flask_app: Flask, + dataset_id: str, + query: str, + top_k: int, + score_threshold: Optional[float], + reranking_model: Optional[dict], + all_documents: list, + retrieval_method: str, + exceptions: list, + ): with flask_app.app_context(): try: - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() vector_processor = Vector( dataset=dataset, ) - documents = vector_processor.search_by_full_text( - cls.escape_query_for_search(query), - top_k=top_k - ) + documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k) if documents: - if reranking_model and reranking_model.get('reranking_model_name') and reranking_model.get('reranking_provider_name') and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value: - data_post_processor = DataPostProcessor(str(dataset.tenant_id), - RerankMode.RERANKING_MODEL.value, - reranking_model, None, False) - all_documents.extend(data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents) - )) + if ( + reranking_model + and reranking_model.get("reranking_model_name") + and reranking_model.get("reranking_provider_name") + and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value + ): + data_post_processor = DataPostProcessor( + str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False + ) + all_documents.extend( + data_post_processor.invoke( + query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents) + ) + ) else: all_documents.extend(documents) except Exception as e: @@ -197,4 +214,4 @@ class RetrievalService: @staticmethod def escape_query_for_search(query: str) -> str: - return query.replace('"', '\\"') \ No newline at end of file + return query.replace('"', '\\"') diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index b78e2a59b1..a9c0eefb78 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -29,6 +29,7 @@ class AnalyticdbConfig(BaseModel): namespace_password: str = (None,) metrics: str = ("cosine",) read_timeout: int = 60000 + def to_analyticdb_client_params(self): return { "access_key_id": self.access_key_id, @@ -37,6 +38,7 @@ class AnalyticdbConfig(BaseModel): "read_timeout": self.read_timeout, } + class AnalyticdbVector(BaseVector): _instance = None _init = False @@ -57,9 +59,7 @@ class AnalyticdbVector(BaseVector): except: raise ImportError(_import_err_msg) self.config = config - self._client_config = open_api_models.Config( - user_agent="dify", **config.to_analyticdb_client_params() - ) + self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params()) self._client = Client(self._client_config) self._initialize() AnalyticdbVector._init = True @@ -77,6 +77,7 @@ class AnalyticdbVector(BaseVector): def _initialize_vector_database(self) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + request = gpdb_20160503_models.InitVectorDatabaseRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -88,6 +89,7 @@ class AnalyticdbVector(BaseVector): def _create_namespace_if_not_exists(self) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from Tea.exceptions import TeaException + try: request = gpdb_20160503_models.DescribeNamespaceRequest( dbinstance_id=self.config.instance_id, @@ -109,13 +111,12 @@ class AnalyticdbVector(BaseVector): ) self._client.create_namespace(request) else: - raise ValueError( - f"failed to create namespace {self.config.namespace}: {e}" - ) + raise ValueError(f"failed to create namespace {self.config.namespace}: {e}") def _create_collection_if_not_exists(self, embedding_dimension: int): from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from Tea.exceptions import TeaException + cache_key = f"vector_indexing_{self._collection_name}" lock_name = f"{cache_key}_lock" with redis_client.lock(lock_name, timeout=20): @@ -149,9 +150,7 @@ class AnalyticdbVector(BaseVector): ) self._client.create_collection(request) else: - raise ValueError( - f"failed to create collection {self._collection_name}: {e}" - ) + raise ValueError(f"failed to create collection {self._collection_name}: {e}") redis_client.set(collection_exist_cache_key, 1, ex=3600) def get_type(self) -> str: @@ -162,10 +161,9 @@ class AnalyticdbVector(BaseVector): self._create_collection_if_not_exists(dimension) self.add_texts(texts, embeddings) - def add_texts( - self, documents: list[Document], embeddings: list[list[float]], **kwargs - ): + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = [] for doc, embedding in zip(documents, embeddings, strict=True): metadata = { @@ -191,6 +189,7 @@ class AnalyticdbVector(BaseVector): def text_exists(self, id: str) -> bool: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -202,13 +201,14 @@ class AnalyticdbVector(BaseVector): vector=None, content=None, top_k=1, - filter=f"ref_doc_id='{id}'" + filter=f"ref_doc_id='{id}'", ) response = self._client.query_collection_data(request) return len(response.body.matches.match) > 0 def delete_by_ids(self, ids: list[str]) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + ids_str = ",".join(f"'{id}'" for id in ids) ids_str = f"({ids_str})" request = gpdb_20160503_models.DeleteCollectionDataRequest( @@ -224,6 +224,7 @@ class AnalyticdbVector(BaseVector): def delete_by_metadata_field(self, key: str, value: str) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + request = gpdb_20160503_models.DeleteCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -235,15 +236,10 @@ class AnalyticdbVector(BaseVector): ) self._client.delete_collection_data(request) - def search_by_vector( - self, query_vector: list[float], **kwargs: Any - ) -> list[Document]: + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - score_threshold = ( - kwargs.get("score_threshold", 0.0) - if kwargs.get("score_threshold", 0.0) - else 0.0 - ) + + score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -270,11 +266,8 @@ class AnalyticdbVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - score_threshold = ( - kwargs.get("score_threshold", 0.0) - if kwargs.get("score_threshold", 0.0) - else 0.0 - ) + + score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -304,6 +297,7 @@ class AnalyticdbVector(BaseVector): def delete(self) -> None: try: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + request = gpdb_20160503_models.DeleteCollectionRequest( collection=self._collection_name, dbinstance_id=self.config.instance_id, @@ -315,19 +309,16 @@ class AnalyticdbVector(BaseVector): except Exception as e: raise e + class AnalyticdbVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings): if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict["vector_store"][ - "class_prefix" - ] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name) - ) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)) # handle optional params if dify_config.ANALYTICDB_KEY_ID is None: diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index 3629887b44..cb38cf94a9 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -27,21 +27,20 @@ class ChromaConfig(BaseModel): settings = Settings( # auth chroma_client_auth_provider=self.auth_provider, - chroma_client_auth_credentials=self.auth_credentials + chroma_client_auth_credentials=self.auth_credentials, ) return { - 'host': self.host, - 'port': self.port, - 'ssl': False, - 'tenant': self.tenant, - 'database': self.database, - 'settings': settings, + "host": self.host, + "port": self.port, + "ssl": False, + "tenant": self.tenant, + "database": self.database, + "settings": settings, } class ChromaVector(BaseVector): - def __init__(self, collection_name: str, config: ChromaConfig): super().__init__(collection_name) self._client_config = config @@ -58,9 +57,9 @@ class ChromaVector(BaseVector): self.add_texts(texts, embeddings, **kwargs) def create_collection(self, collection_name: str): - lock_name = 'vector_indexing_lock_{}'.format(collection_name) + lock_name = "vector_indexing_lock_{}".format(collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return self._client.get_or_create_collection(collection_name) @@ -76,7 +75,7 @@ class ChromaVector(BaseVector): def delete_by_metadata_field(self, key: str, value: str): collection = self._client.get_or_create_collection(self._collection_name) - collection.delete(where={key: {'$eq': value}}) + collection.delete(where={key: {"$eq": value}}) def delete(self): self._client.delete_collection(self._collection_name) @@ -93,26 +92,26 @@ class ChromaVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: collection = self._client.get_or_create_collection(self._collection_name) results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 - ids: list[str] = results['ids'][0] - documents: list[str] = results['documents'][0] - metadatas: dict[str, Any] = results['metadatas'][0] - distances: list[float] = results['distances'][0] + ids: list[str] = results["ids"][0] + documents: list[str] = results["documents"][0] + metadatas: dict[str, Any] = results["metadatas"][0] + distances: list[float] = results["distances"][0] docs = [] for index in range(len(ids)): distance = distances[index] metadata = metadatas[index] if distance >= score_threshold: - metadata['score'] = distance + metadata["score"] = distance doc = Document( page_content=documents[index], metadata=metadata, ) docs.append(doc) - # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) + # Sort the documents by score in descending order + docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -123,15 +122,12 @@ class ChromaVector(BaseVector): class ChromaVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - index_struct_dict = { - "type": VectorType.CHROMA, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.CHROMA, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) return ChromaVector( diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 233539756f..76c808f76e 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -26,15 +26,15 @@ class ElasticSearchConfig(BaseModel): username: str password: str - @model_validator(mode='before') + @model_validator(mode="before") def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values["host"]: raise ValueError("config HOST is required") - if not values['port']: + if not values["port"]: raise ValueError("config PORT is required") - if not values['username']: + if not values["username"]: raise ValueError("config USERNAME is required") - if not values['password']: + if not values["password"]: raise ValueError("config PASSWORD is required") return values @@ -50,10 +50,10 @@ class ElasticSearchVector(BaseVector): def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: try: parsed_url = urlparse(config.host) - if parsed_url.scheme in ['http', 'https']: - hosts = f'{config.host}:{config.port}' + if parsed_url.scheme in ["http", "https"]: + hosts = f"{config.host}:{config.port}" else: - hosts = f'http://{config.host}:{config.port}' + hosts = f"http://{config.host}:{config.port}" client = Elasticsearch( hosts=hosts, basic_auth=(config.username, config.password), @@ -68,25 +68,27 @@ class ElasticSearchVector(BaseVector): def _get_version(self) -> str: info = self._client.info() - return info['version']['number'] + return info["version"]["number"] def _check_version(self): - if self._version < '8.0.0': + if self._version < "8.0.0": raise ValueError("Elasticsearch vector database version must be greater than 8.0.0") def get_type(self) -> str: - return 'elasticsearch' + return "elasticsearch" def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): uuids = self._get_uuids(documents) for i in range(len(documents)): - self._client.index(index=self._collection_name, - id=uuids[i], - document={ - Field.CONTENT_KEY.value: documents[i].page_content, - Field.VECTOR.value: embeddings[i] if embeddings[i] else None, - Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {} - }) + self._client.index( + index=self._collection_name, + id=uuids[i], + document={ + Field.CONTENT_KEY.value: documents[i].page_content, + Field.VECTOR.value: embeddings[i] if embeddings[i] else None, + Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {}, + }, + ) self._client.indices.refresh(index=self._collection_name) return uuids @@ -98,15 +100,9 @@ class ElasticSearchVector(BaseVector): self._client.delete(index=self._collection_name, id=id) def delete_by_metadata_field(self, key: str, value: str) -> None: - query_str = { - 'query': { - 'match': { - f'metadata.{key}': f'{value}' - } - } - } + query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}} results = self._client.search(index=self._collection_name, body=query_str) - ids = [hit['_id'] for hit in results['hits']['hits']] + ids = [hit["_id"] for hit in results["hits"]["hits"]] if ids: self.delete_by_ids(ids) @@ -115,44 +111,44 @@ class ElasticSearchVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 10) - knn = { - "field": Field.VECTOR.value, - "query_vector": query_vector, - "k": top_k - } + knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k} results = self._client.search(index=self._collection_name, knn=knn, size=top_k) docs_and_scores = [] - for hit in results['hits']['hits']: + for hit in results["hits"]["hits"]: docs_and_scores.append( - (Document(page_content=hit['_source'][Field.CONTENT_KEY.value], - vector=hit['_source'][Field.VECTOR.value], - metadata=hit['_source'][Field.METADATA_KEY.value]), hit['_score'])) + ( + Document( + page_content=hit["_source"][Field.CONTENT_KEY.value], + vector=hit["_source"][Field.VECTOR.value], + metadata=hit["_source"][Field.METADATA_KEY.value], + ), + hit["_score"], + ) + ) docs = [] for doc, score in docs_and_scores: - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 if score > score_threshold: - doc.metadata['score'] = score + doc.metadata["score"] = score docs.append(doc) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - query_str = { - "match": { - Field.CONTENT_KEY.value: query - } - } + query_str = {"match": {Field.CONTENT_KEY.value: query}} results = self._client.search(index=self._collection_name, query=query_str) docs = [] - for hit in results['hits']['hits']: - docs.append(Document( - page_content=hit['_source'][Field.CONTENT_KEY.value], - vector=hit['_source'][Field.VECTOR.value], - metadata=hit['_source'][Field.METADATA_KEY.value], - )) + for hit in results["hits"]["hits"]: + docs.append( + Document( + page_content=hit["_source"][Field.CONTENT_KEY.value], + vector=hit["_source"][Field.VECTOR.value], + metadata=hit["_source"][Field.METADATA_KEY.value], + ) + ) return docs @@ -162,11 +158,11 @@ class ElasticSearchVector(BaseVector): self.add_texts(texts, embeddings, **kwargs) def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None ): - lock_name = f'vector_indexing_lock_{self._collection_name}' + lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = f'vector_indexing_{self._collection_name}' + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" if redis_client.get(collection_exist_cache_key): logger.info(f"Collection {self._collection_name} already exists.") return @@ -179,14 +175,14 @@ class ElasticSearchVector(BaseVector): Field.VECTOR.value: { # Make sure the dimension is correct here "type": "dense_vector", "dims": dim, - "similarity": "cosine" + "similarity": "cosine", }, Field.METADATA_KEY.value: { "type": "object", "properties": { "doc_id": {"type": "keyword"} # Map doc_id to keyword type - } - } + }, + }, } } self._client.indices.create(index=self._collection_name, mappings=mappings) @@ -197,22 +193,21 @@ class ElasticSearchVector(BaseVector): class ElasticSearchVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name)) config = current_app.config return ElasticSearchVector( index_name=collection_name, config=ElasticSearchConfig( - host=config.get('ELASTICSEARCH_HOST'), - port=config.get('ELASTICSEARCH_PORT'), - username=config.get('ELASTICSEARCH_USERNAME'), - password=config.get('ELASTICSEARCH_PASSWORD'), + host=config.get("ELASTICSEARCH_HOST"), + port=config.get("ELASTICSEARCH_PORT"), + username=config.get("ELASTICSEARCH_USERNAME"), + password=config.get("ELASTICSEARCH_PASSWORD"), ), - attributes=[] + attributes=[], ) diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index c1c73d1c0d..1d08046641 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -27,44 +27,39 @@ class MilvusConfig(BaseModel): batch_size: int = 100 database: str = "default" - @model_validator(mode='before') + @model_validator(mode="before") def validate_config(cls, values: dict) -> dict: - if not values.get('uri'): + if not values.get("uri"): raise ValueError("config MILVUS_URI is required") - if not values.get('user'): + if not values.get("user"): raise ValueError("config MILVUS_USER is required") - if not values.get('password'): + if not values.get("password"): raise ValueError("config MILVUS_PASSWORD is required") return values def to_milvus_params(self): return { - 'uri': self.uri, - 'token': self.token, - 'user': self.user, - 'password': self.password, - 'db_name': self.database, + "uri": self.uri, + "token": self.token, + "user": self.user, + "password": self.password, + "db_name": self.database, } class MilvusVector(BaseVector): - def __init__(self, collection_name: str, config: MilvusConfig): super().__init__(collection_name) self._client_config = config self._client = self._init_client(config) - self._consistency_level = 'Session' + self._consistency_level = "Session" self._fields = [] def get_type(self) -> str: return VectorType.MILVUS def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - index_params = { - 'metric_type': 'IP', - 'index_type': "HNSW", - 'params': {"M": 8, "efConstruction": 64} - } + index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}} metadatas = [d.metadata for d in texts] self.create_collection(embeddings, metadatas, index_params) self.add_texts(texts, embeddings) @@ -75,7 +70,7 @@ class MilvusVector(BaseVector): insert_dict = { Field.CONTENT_KEY.value: documents[i].page_content, Field.VECTOR.value: embeddings[i], - Field.METADATA_KEY.value: documents[i].metadata + Field.METADATA_KEY.value: documents[i].metadata, } insert_dict_list.append(insert_dict) # Total insert count @@ -84,22 +79,20 @@ class MilvusVector(BaseVector): pks: list[str] = [] for i in range(0, total_count, 1000): - batch_insert_list = insert_dict_list[i:i + 1000] + batch_insert_list = insert_dict_list[i : i + 1000] # Insert into the collection. try: ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list) pks.extend(ids) except MilvusException as e: - logger.error( - "Failed to insert batch starting at entity: %s/%s", i, total_count - ) + logger.error("Failed to insert batch starting at entity: %s/%s", i, total_count) raise e return pks def get_ids_by_metadata_field(self, key: str, value: str): - result = self._client.query(collection_name=self._collection_name, - filter=f'metadata["{key}"] == "{value}"', - output_fields=["id"]) + result = self._client.query( + collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"] + ) if result: return [item["id"] for item in result] else: @@ -107,17 +100,15 @@ class MilvusVector(BaseVector): def delete_by_metadata_field(self, key: str, value: str): if self._client.has_collection(self._collection_name): - ids = self.get_ids_by_metadata_field(key, value) if ids: self._client.delete(collection_name=self._collection_name, pks=ids) def delete_by_ids(self, ids: list[str]) -> None: if self._client.has_collection(self._collection_name): - - result = self._client.query(collection_name=self._collection_name, - filter=f'metadata["doc_id"] in {ids}', - output_fields=["id"]) + result = self._client.query( + collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"] + ) if result: ids = [item["id"] for item in result] self._client.delete(collection_name=self._collection_name, pks=ids) @@ -130,29 +121,28 @@ class MilvusVector(BaseVector): if not self._client.has_collection(self._collection_name): return False - result = self._client.query(collection_name=self._collection_name, - filter=f'metadata["doc_id"] == "{id}"', - output_fields=["id"]) + result = self._client.query( + collection_name=self._collection_name, filter=f'metadata["doc_id"] == "{id}"', output_fields=["id"] + ) return len(result) > 0 def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: - # Set search parameters. - results = self._client.search(collection_name=self._collection_name, - data=[query_vector], - limit=kwargs.get('top_k', 4), - output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], - ) + results = self._client.search( + collection_name=self._collection_name, + data=[query_vector], + limit=kwargs.get("top_k", 4), + output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], + ) # Organize results. docs = [] for result in results[0]: - metadata = result['entity'].get(Field.METADATA_KEY.value) - metadata['score'] = result['distance'] - score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 - if result['distance'] > score_threshold: - doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value), - metadata=metadata) + metadata = result["entity"].get(Field.METADATA_KEY.value) + metadata["score"] = result["distance"] + score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 + if result["distance"] > score_threshold: + doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata) docs.append(doc) return docs @@ -161,11 +151,11 @@ class MilvusVector(BaseVector): return [] def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None ): - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return # Grab the existing collection if it exists @@ -180,19 +170,11 @@ class MilvusVector(BaseVector): fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) # Create the text field - fields.append( - FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535) - ) + fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)) # Create the primary key field - fields.append( - FieldSchema( - Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True - ) - ) + fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True)) # Create the vector field, supports binary or float vectors - fields.append( - FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim) - ) + fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)) # Create the schema for the collection schema = CollectionSchema(fields) @@ -208,9 +190,12 @@ class MilvusVector(BaseVector): # Create the collection collection_name = self._collection_name - self._client.create_collection(collection_name=collection_name, - schema=schema, index_params=index_params_obj, - consistency_level=self._consistency_level) + self._client.create_collection( + collection_name=collection_name, + schema=schema, + index_params=index_params_obj, + consistency_level=self._consistency_level, + ) redis_client.set(collection_exist_cache_key, 1, ex=3600) def _init_client(self, config) -> MilvusClient: @@ -221,13 +206,12 @@ class MilvusVector(BaseVector): class MilvusVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.MILVUS, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MILVUS, collection_name)) return MilvusVector( collection_name=collection_name, @@ -237,5 +221,5 @@ class MilvusVectorFactory(AbstractVectorFactory): user=dify_config.MILVUS_USER, password=dify_config.MILVUS_PASSWORD, database=dify_config.MILVUS_DATABASE, - ) + ), ) diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index 05e75effef..90464ac42a 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -31,7 +31,6 @@ class SortOrder(Enum): class MyScaleVector(BaseVector): - def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"): super().__init__(collection_name) self._config = config @@ -80,7 +79,7 @@ class MyScaleVector(BaseVector): doc_id, self.escape_str(doc.page_content), embeddings[i], - json.dumps(doc.metadata) if doc.metadata else {} + json.dumps(doc.metadata) if doc.metadata else {}, ) values.append(str(row)) ids.append(doc_id) @@ -101,7 +100,8 @@ class MyScaleVector(BaseVector): def delete_by_ids(self, ids: list[str]) -> None: self._client.command( - f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}") + f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}" + ) def get_ids_by_metadata_field(self, key: str, value: str): rows = self._client.query( @@ -122,9 +122,12 @@ class MyScaleVector(BaseVector): def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) - score_threshold = kwargs.get('score_threshold') or 0.0 - where_str = f"WHERE dist < {1 - score_threshold}" if \ - self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else "" + score_threshold = kwargs.get("score_threshold") or 0.0 + where_str = ( + f"WHERE dist < {1 - score_threshold}" + if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 + else "" + ) sql = f""" SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name} {where_str} ORDER BY dist {order.value} LIMIT {top_k} @@ -133,7 +136,7 @@ class MyScaleVector(BaseVector): return [ Document( page_content=r["text"], - vector=r['vector'], + vector=r["vector"], metadata=r["metadata"], ) for r in self._client.query(sql).named_results() @@ -149,13 +152,12 @@ class MyScaleVector(BaseVector): class MyScaleVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MyScaleVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.MYSCALE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MYSCALE, collection_name)) return MyScaleVector( collection_name=collection_name, diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index c95d202173..ecd7e0271c 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -28,11 +28,11 @@ class OpenSearchConfig(BaseModel): password: Optional[str] = None secure: bool = False - @model_validator(mode='before') + @model_validator(mode="before") def validate_config(cls, values: dict) -> dict: - if not values.get('host'): + if not values.get("host"): raise ValueError("config OPENSEARCH_HOST is required") - if not values.get('port'): + if not values.get("port"): raise ValueError("config OPENSEARCH_PORT is required") return values @@ -44,19 +44,18 @@ class OpenSearchConfig(BaseModel): def to_opensearch_params(self) -> dict[str, Any]: params = { - 'hosts': [{'host': self.host, 'port': self.port}], - 'use_ssl': self.secure, - 'verify_certs': self.secure, + "hosts": [{"host": self.host, "port": self.port}], + "use_ssl": self.secure, + "verify_certs": self.secure, } if self.user and self.password: - params['http_auth'] = (self.user, self.password) + params["http_auth"] = (self.user, self.password) if self.secure: - params['ssl_context'] = self.create_ssl_context() + params["ssl_context"] = self.create_ssl_context() return params class OpenSearchVector(BaseVector): - def __init__(self, collection_name: str, config: OpenSearchConfig): super().__init__(collection_name) self._client_config = config @@ -81,7 +80,7 @@ class OpenSearchVector(BaseVector): Field.CONTENT_KEY.value: documents[i].page_content, Field.VECTOR.value: embeddings[i], # Make sure you pass an array here Field.METADATA_KEY.value: documents[i].metadata, - } + }, } actions.append(action) @@ -90,8 +89,8 @@ class OpenSearchVector(BaseVector): def get_ids_by_metadata_field(self, key: str, value: str): query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}} response = self._client.search(index=self._collection_name.lower(), body=query) - if response['hits']['hits']: - return [hit['_id'] for hit in response['hits']['hits']] + if response["hits"]["hits"]: + return [hit["_id"] for hit in response["hits"]["hits"]] else: return None @@ -110,7 +109,7 @@ class OpenSearchVector(BaseVector): actual_ids = [] for doc_id in ids: - es_ids = self.get_ids_by_metadata_field('doc_id', doc_id) + es_ids = self.get_ids_by_metadata_field("doc_id", doc_id) if es_ids: actual_ids.extend(es_ids) else: @@ -122,9 +121,9 @@ class OpenSearchVector(BaseVector): helpers.bulk(self._client, actions) except BulkIndexError as e: for error in e.errors: - delete_error = error.get('delete', {}) - status = delete_error.get('status') - doc_id = delete_error.get('_id') + delete_error = error.get("delete", {}) + status = delete_error.get("status") + doc_id = delete_error.get("_id") if status == 404: logger.warning(f"Document not found for deletion: {doc_id}") @@ -151,15 +150,8 @@ class OpenSearchVector(BaseVector): raise ValueError("All elements in query_vector should be floats") query = { - "size": kwargs.get('top_k', 4), - "query": { - "knn": { - Field.VECTOR.value: { - Field.VECTOR.value: query_vector, - "k": kwargs.get('top_k', 4) - } - } - } + "size": kwargs.get("top_k", 4), + "query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}}, } try: @@ -169,17 +161,17 @@ class OpenSearchVector(BaseVector): raise docs = [] - for hit in response['hits']['hits']: - metadata = hit['_source'].get(Field.METADATA_KEY.value, {}) + for hit in response["hits"]["hits"]: + metadata = hit["_source"].get(Field.METADATA_KEY.value, {}) # Make sure metadata is a dictionary if metadata is None: metadata = {} - metadata['score'] = hit['_score'] - score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 - if hit['_score'] > score_threshold: - doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata) + metadata["score"] = hit["_score"] + score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 + if hit["_score"] > score_threshold: + doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata) docs.append(doc) return docs @@ -190,32 +182,28 @@ class OpenSearchVector(BaseVector): response = self._client.search(index=self._collection_name.lower(), body=full_text_query) docs = [] - for hit in response['hits']['hits']: - metadata = hit['_source'].get(Field.METADATA_KEY.value) - vector = hit['_source'].get(Field.VECTOR.value) - page_content = hit['_source'].get(Field.CONTENT_KEY.value) + for hit in response["hits"]["hits"]: + metadata = hit["_source"].get(Field.METADATA_KEY.value) + vector = hit["_source"].get(Field.VECTOR.value) + page_content = hit["_source"].get(Field.CONTENT_KEY.value) doc = Document(page_content=page_content, vector=vector, metadata=metadata) docs.append(doc) return docs def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None ): - lock_name = f'vector_indexing_lock_{self._collection_name.lower()}' + lock_name = f"vector_indexing_lock_{self._collection_name.lower()}" with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = f'vector_indexing_{self._collection_name.lower()}' + collection_exist_cache_key = f"vector_indexing_{self._collection_name.lower()}" if redis_client.get(collection_exist_cache_key): logger.info(f"Collection {self._collection_name.lower()} already exists.") return if not self._client.indices.exists(index=self._collection_name.lower()): index_body = { - "settings": { - "index": { - "knn": True - } - }, + "settings": {"index": {"knn": True}}, "mappings": { "properties": { Field.CONTENT_KEY.value: {"type": "text"}, @@ -226,20 +214,17 @@ class OpenSearchVector(BaseVector): "name": "hnsw", "space_type": "l2", "engine": "faiss", - "parameters": { - "ef_construction": 64, - "m": 8 - } - } + "parameters": {"ef_construction": 64, "m": 8}, + }, }, Field.METADATA_KEY.value: { "type": "object", "properties": { "doc_id": {"type": "keyword"} # Map doc_id to keyword type - } - } + }, + }, } - } + }, } self._client.indices.create(index=self._collection_name.lower(), body=index_body) @@ -248,17 +233,14 @@ class OpenSearchVector(BaseVector): class OpenSearchVectorFactory(AbstractVectorFactory): - def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenSearchVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name)) - + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name)) open_search_config = OpenSearchConfig( host=dify_config.OPENSEARCH_HOST, @@ -268,7 +250,4 @@ class OpenSearchVectorFactory(AbstractVectorFactory): secure=dify_config.OPENSEARCH_SECURE, ) - return OpenSearchVector( - collection_name=collection_name, - config=open_search_config - ) + return OpenSearchVector(collection_name=collection_name, config=open_search_config) diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index aa2c6171c3..eb2e3e0a8c 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -31,7 +31,7 @@ class OracleVectorConfig(BaseModel): password: str database: str - @model_validator(mode='before') + @model_validator(mode="before") def validate_config(cls, values: dict) -> dict: if not values["host"]: raise ValueError("config ORACLE_HOST is required") @@ -103,9 +103,16 @@ class OracleVector(BaseVector): arraysize=cursor.arraysize, outconverter=self.numpy_converter_out, ) - def _create_connection_pool(self, config: OracleVectorConfig): - return oracledb.create_pool(user=config.user, password=config.password, dsn="{}:{}/{}".format(config.host, config.port, config.database), min=1, max=50, increment=1) + def _create_connection_pool(self, config: OracleVectorConfig): + return oracledb.create_pool( + user=config.user, + password=config.password, + dsn="{}:{}/{}".format(config.host, config.port, config.database), + min=1, + max=50, + increment=1, + ) @contextmanager def _get_cursor(self): @@ -136,13 +143,15 @@ class OracleVector(BaseVector): doc_id, doc.page_content, json.dumps(doc.metadata), - #array.array("f", embeddings[i]), + # array.array("f", embeddings[i]), numpy.array(embeddings[i]), ) ) - #print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)") + # print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)") with self._get_cursor() as cur: - cur.executemany(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values) + cur.executemany( + f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values + ) return pks def text_exists(self, id: str) -> bool: @@ -157,7 +166,8 @@ class OracleVector(BaseVector): for record in cur: docs.append(Document(page_content=record[1], metadata=record[0])) return docs - #def get_ids_by_metadata_field(self, key: str, value: str): + + # def get_ids_by_metadata_field(self, key: str, value: str): # with self._get_cursor() as cur: # cur.execute(f"SELECT id FROM {self.table_name} d WHERE d.meta.{key}='{value}'" ) # idss = [] @@ -184,7 +194,8 @@ class OracleVector(BaseVector): top_k = kwargs.get("top_k", 5) with self._get_cursor() as cur: cur.execute( - f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name} ORDER BY distance fetch first {top_k} rows only" ,[numpy.array(query_vector)] + f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name} ORDER BY distance fetch first {top_k} rows only", + [numpy.array(query_vector)], ) docs = [] score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 @@ -202,7 +213,7 @@ class OracleVector(BaseVector): score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 if len(query) > 0: # Check which language the query is in - zh_pattern = re.compile('[\u4e00-\u9fa5]+') + zh_pattern = re.compile("[\u4e00-\u9fa5]+") match = zh_pattern.search(query) entities = [] # match: query condition maybe is a chinese sentence, so using Jieba split,else using nltk split @@ -210,7 +221,15 @@ class OracleVector(BaseVector): words = pseg.cut(query) current_entity = "" for word, pos in words: - if pos == 'nr' or pos == 'Ng' or pos == 'eng' or pos == 'nz' or pos == 'n' or pos == 'ORG' or pos == 'v': # nr: 人名, ns: 地名, nt: 机构名 + if ( + pos == "nr" + or pos == "Ng" + or pos == "eng" + or pos == "nz" + or pos == "n" + or pos == "ORG" + or pos == "v" + ): # nr: 人名, ns: 地名, nt: 机构名 current_entity += word else: if current_entity: @@ -220,22 +239,22 @@ class OracleVector(BaseVector): entities.append(current_entity) else: try: - nltk.data.find('tokenizers/punkt') - nltk.data.find('corpora/stopwords') + nltk.data.find("tokenizers/punkt") + nltk.data.find("corpora/stopwords") except LookupError: - nltk.download('punkt') - nltk.download('stopwords') + nltk.download("punkt") + nltk.download("stopwords") print("run download") - e_str = re.sub(r'[^\w ]', '', query) + e_str = re.sub(r"[^\w ]", "", query) all_tokens = nltk.word_tokenize(e_str) - stop_words = stopwords.words('english') + stop_words = stopwords.words("english") for token in all_tokens: if token not in stop_words: entities.append(token) with self._get_cursor() as cur: cur.execute( f"select meta, text, embedding FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only", - [" ACCUM ".join(entities)] + [" ACCUM ".join(entities)], ) docs = [] for record in cur: @@ -273,8 +292,7 @@ class OracleVectorFactory(AbstractVectorFactory): else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.ORACLE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ORACLE, collection_name)) return OracleVector( collection_name=collection_name, diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index a48224070f..b778582e8a 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -31,27 +31,28 @@ class PgvectoRSConfig(BaseModel): password: str database: str - @model_validator(mode='before') + @model_validator(mode="before") def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values["host"]: raise ValueError("config PGVECTO_RS_HOST is required") - if not values['port']: + if not values["port"]: raise ValueError("config PGVECTO_RS_PORT is required") - if not values['user']: + if not values["user"]: raise ValueError("config PGVECTO_RS_USER is required") - if not values['password']: + if not values["password"]: raise ValueError("config PGVECTO_RS_PASSWORD is required") - if not values['database']: + if not values["database"]: raise ValueError("config PGVECTO_RS_DATABASE is required") return values class PGVectoRS(BaseVector): - def __init__(self, collection_name: str, config: PgvectoRSConfig, dim: int): super().__init__(collection_name) self._client_config = config - self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + self._url = ( + f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + ) self._client = create_engine(self._url) with Session(self._client) as session: session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors")) @@ -80,9 +81,9 @@ class PGVectoRS(BaseVector): self.add_texts(texts, embeddings) def create_collection(self, dimension: int): - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return index_name = f"{self._collection_name}_embedding_index" @@ -133,9 +134,7 @@ class PGVectoRS(BaseVector): def get_ids_by_metadata_field(self, key: str, value: str): result = None with Session(self._client) as session: - select_statement = sql_text( - f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; " - ) + select_statement = sql_text(f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; ") result = session.execute(select_statement).fetchall() if result: return [item[0] for item in result] @@ -143,12 +142,11 @@ class PGVectoRS(BaseVector): return None def delete_by_metadata_field(self, key: str, value: str): - ids = self.get_ids_by_metadata_field(key, value) if ids: with Session(self._client) as session: select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)") - session.execute(select_statement, {'ids': ids}) + session.execute(select_statement, {"ids": ids}) session.commit() def delete_by_ids(self, ids: list[str]) -> None: @@ -156,13 +154,13 @@ class PGVectoRS(BaseVector): select_statement = sql_text( f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = ANY (:doc_ids); " ) - result = session.execute(select_statement, {'doc_ids': ids}).fetchall() + result = session.execute(select_statement, {"doc_ids": ids}).fetchall() if result: ids = [item[0] for item in result] if ids: with Session(self._client) as session: select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)") - session.execute(select_statement, {'ids': ids}) + session.execute(select_statement, {"ids": ids}) session.commit() def delete(self) -> None: @@ -187,7 +185,7 @@ class PGVectoRS(BaseVector): query_vector, ).label("distance"), ) - .limit(kwargs.get('top_k', 2)) + .limit(kwargs.get("top_k", 2)) .order_by("distance") ) res = session.execute(stmt) @@ -198,11 +196,10 @@ class PGVectoRS(BaseVector): for record, dis in results: metadata = record.meta score = 1 - dis - metadata['score'] = score - score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 + metadata["score"] = score + score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 if score > score_threshold: - doc = Document(page_content=record.text, - metadata=metadata) + doc = Document(page_content=record.text, metadata=metadata) docs.append(doc) return docs @@ -225,13 +222,12 @@ class PGVectoRS(BaseVector): class PGVectoRSFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVectoRS: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) dim = len(embeddings.embed_query("pgvecto_rs")) return PGVectoRS( @@ -243,5 +239,5 @@ class PGVectoRSFactory(AbstractVectorFactory): password=dify_config.PGVECTO_RS_PASSWORD, database=dify_config.PGVECTO_RS_DATABASE, ), - dim=dim + dim=dim, ) diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index c9f2f35af0..b01cd91e07 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -24,7 +24,7 @@ class PGVectorConfig(BaseModel): password: str database: str - @model_validator(mode='before') + @model_validator(mode="before") def validate_config(cls, values: dict) -> dict: if not values["host"]: raise ValueError("config PGVECTOR_HOST is required") @@ -201,8 +201,7 @@ class PGVectorFactory(AbstractVectorFactory): else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name)) return PGVector( collection_name=collection_name, diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 297bff928e..83d561819c 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -48,28 +48,25 @@ class QdrantConfig(BaseModel): prefer_grpc: bool = False def to_qdrant_params(self): - if self.endpoint and self.endpoint.startswith('path:'): - path = self.endpoint.replace('path:', '') + if self.endpoint and self.endpoint.startswith("path:"): + path = self.endpoint.replace("path:", "") if not os.path.isabs(path): path = os.path.join(self.root_path, path) - return { - 'path': path - } + return {"path": path} else: return { - 'url': self.endpoint, - 'api_key': self.api_key, - 'timeout': self.timeout, - 'verify': self.endpoint.startswith('https'), - 'grpc_port': self.grpc_port, - 'prefer_grpc': self.prefer_grpc + "url": self.endpoint, + "api_key": self.api_key, + "timeout": self.timeout, + "verify": self.endpoint.startswith("https"), + "grpc_port": self.grpc_port, + "prefer_grpc": self.prefer_grpc, } class QdrantVector(BaseVector): - - def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = 'Cosine'): + def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"): super().__init__(collection_name) self._client_config = config self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params()) @@ -80,10 +77,7 @@ class QdrantVector(BaseVector): return VectorType.QDRANT def to_index_struct(self) -> dict: - return { - "type": self.get_type(), - "vector_store": {"class_prefix": self._collection_name} - } + return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): if texts: @@ -97,9 +91,9 @@ class QdrantVector(BaseVector): self.add_texts(texts, embeddings, **kwargs) def create_collection(self, collection_name: str, vector_size: int): - lock_name = 'vector_indexing_lock_{}'.format(collection_name) + lock_name = "vector_indexing_lock_{}".format(collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return collection_name = collection_name or uuid.uuid4().hex @@ -110,12 +104,19 @@ class QdrantVector(BaseVector): all_collection_name.append(collection.name) if collection_name not in all_collection_name: from qdrant_client.http import models as rest + vectors_config = rest.VectorParams( size=vector_size, distance=rest.Distance[self._distance_func], ) - hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, - max_indexing_threads=0, on_disk=False) + hnsw_config = HnswConfigDiff( + m=0, + payload_m=16, + ef_construct=100, + full_scan_threshold=10000, + max_indexing_threads=0, + on_disk=False, + ) self._client.recreate_collection( collection_name=collection_name, vectors_config=vectors_config, @@ -124,21 +125,24 @@ class QdrantVector(BaseVector): ) # create group_id payload index - self._client.create_payload_index(collection_name, Field.GROUP_KEY.value, - field_schema=PayloadSchemaType.KEYWORD) + self._client.create_payload_index( + collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD + ) # create doc_id payload index - self._client.create_payload_index(collection_name, Field.DOC_ID.value, - field_schema=PayloadSchemaType.KEYWORD) + self._client.create_payload_index( + collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD + ) # create full text index text_index_params = TextIndexParams( type=TextIndexType.TEXT, tokenizer=TokenizerType.MULTILINGUAL, min_token_len=2, max_token_len=20, - lowercase=True + lowercase=True, + ) + self._client.create_payload_index( + collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params ) - self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value, - field_schema=text_index_params) redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): @@ -147,26 +151,23 @@ class QdrantVector(BaseVector): metadatas = [d.metadata for d in documents] added_ids = [] - for batch_ids, points in self._generate_rest_batches( - texts, embeddings, metadatas, uuids, 64, self._group_id - ): - self._client.upsert( - collection_name=self._collection_name, points=points - ) + for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id): + self._client.upsert(collection_name=self._collection_name, points=points) added_ids.extend(batch_ids) return added_ids def _generate_rest_batches( - self, - texts: Iterable[str], - embeddings: list[list[float]], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, - batch_size: int = 64, - group_id: Optional[str] = None, + self, + texts: Iterable[str], + embeddings: list[list[float]], + metadatas: Optional[list[dict]] = None, + ids: Optional[Sequence[str]] = None, + batch_size: int = 64, + group_id: Optional[str] = None, ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]: from qdrant_client.http import models as rest + texts_iterator = iter(texts) embeddings_iterator = iter(embeddings) metadatas_iterator = iter(metadatas or []) @@ -203,13 +204,13 @@ class QdrantVector(BaseVector): @classmethod def _build_payloads( - cls, - texts: Iterable[str], - metadatas: Optional[list[dict]], - content_payload_key: str, - metadata_payload_key: str, - group_id: str, - group_payload_key: str + cls, + texts: Iterable[str], + metadatas: Optional[list[dict]], + content_payload_key: str, + metadata_payload_key: str, + group_id: str, + group_payload_key: str, ) -> list[dict]: payloads = [] for i, text in enumerate(texts): @@ -219,18 +220,11 @@ class QdrantVector(BaseVector): "calling .from_texts or .add_texts on Qdrant instance." ) metadata = metadatas[i] if metadatas is not None else None - payloads.append( - { - content_payload_key: text, - metadata_payload_key: metadata, - group_payload_key: group_id - } - ) + payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id}) return payloads def delete_by_metadata_field(self, key: str, value: str): - from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse @@ -248,9 +242,7 @@ class QdrantVector(BaseVector): self._client.delete( collection_name=self._collection_name, - points_selector=FilterSelector( - filter=filter - ), + points_selector=FilterSelector(filter=filter), ) except UnexpectedResponse as e: # Collection does not exist, so return @@ -275,9 +267,7 @@ class QdrantVector(BaseVector): ) self._client.delete( collection_name=self._collection_name, - points_selector=FilterSelector( - filter=filter - ), + points_selector=FilterSelector(filter=filter), ) except UnexpectedResponse as e: # Collection does not exist, so return @@ -288,7 +278,6 @@ class QdrantVector(BaseVector): raise e def delete_by_ids(self, ids: list[str]) -> None: - from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse @@ -304,9 +293,7 @@ class QdrantVector(BaseVector): ) self._client.delete( collection_name=self._collection_name, - points_selector=FilterSelector( - filter=filter - ), + points_selector=FilterSelector(filter=filter), ) except UnexpectedResponse as e: # Collection does not exist, so return @@ -324,15 +311,13 @@ class QdrantVector(BaseVector): all_collection_name.append(collection.name) if self._collection_name not in all_collection_name: return False - response = self._client.retrieve( - collection_name=self._collection_name, - ids=[id] - ) + response = self._client.retrieve(collection_name=self._collection_name, ids=[id]) return len(response) > 0 def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from qdrant_client.http import models + filter = models.Filter( must=[ models.FieldCondition( @@ -348,22 +333,22 @@ class QdrantVector(BaseVector): limit=kwargs.get("top_k", 4), with_payload=True, with_vectors=True, - score_threshold=kwargs.get("score_threshold", .0) + score_threshold=kwargs.get("score_threshold", 0.0), ) docs = [] for result in results: metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 if result.score > score_threshold: - metadata['score'] = result.score + metadata["score"] = result.score doc = Document( page_content=result.payload.get(Field.CONTENT_KEY.value), metadata=metadata, ) docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -372,6 +357,7 @@ class QdrantVector(BaseVector): List of documents most similar to the query text and distance for each. """ from qdrant_client.http import models + scroll_filter = models.Filter( must=[ models.FieldCondition( @@ -381,24 +367,21 @@ class QdrantVector(BaseVector): models.FieldCondition( key="page_content", match=models.MatchText(text=query), - ) + ), ] ) response = self._client.scroll( collection_name=self._collection_name, scroll_filter=scroll_filter, - limit=kwargs.get('top_k', 2), + limit=kwargs.get("top_k", 2), with_payload=True, - with_vectors=True - + with_vectors=True, ) results = response[0] documents = [] for result in results: if result: - document = self._document_from_scored_point( - result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value - ) + document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value) documents.append(document) return documents @@ -410,10 +393,10 @@ class QdrantVector(BaseVector): @classmethod def _document_from_scored_point( - cls, - scored_point: Any, - content_payload_key: str, - metadata_payload_key: str, + cls, + scored_point: Any, + content_payload_key: str, + metadata_payload_key: str, ) -> Document: return Document( page_content=scored_point.payload.get(content_payload_key), @@ -425,24 +408,25 @@ class QdrantVector(BaseVector): class QdrantVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector: if dataset.collection_binding_id: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \ - one_or_none() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter(DatasetCollectionBinding.id == dataset.collection_binding_id) + .one_or_none() + ) if dataset_collection_binding: collection_name = dataset_collection_binding.collection_name else: - raise ValueError('Dataset Collection Bindings is not exist!') + raise ValueError("Dataset Collection Bindings is not exist!") else: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) if not dataset.index_struct_dict: - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.QDRANT, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.QDRANT, collection_name)) config = current_app.config return QdrantVector( @@ -454,6 +438,6 @@ class QdrantVectorFactory(AbstractVectorFactory): root_path=config.root_path, timeout=dify_config.QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.QDRANT_GRPC_PORT, - prefer_grpc=dify_config.QDRANT_GRPC_ENABLED - ) + prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, + ), ) diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 63ad0682d7..d8e4ff628c 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -33,28 +33,29 @@ class RelytConfig(BaseModel): password: str database: str - @model_validator(mode='before') + @model_validator(mode="before") def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values["host"]: raise ValueError("config RELYT_HOST is required") - if not values['port']: + if not values["port"]: raise ValueError("config RELYT_PORT is required") - if not values['user']: + if not values["user"]: raise ValueError("config RELYT_USER is required") - if not values['password']: + if not values["password"]: raise ValueError("config RELYT_PASSWORD is required") - if not values['database']: + if not values["database"]: raise ValueError("config RELYT_DATABASE is required") return values class RelytVector(BaseVector): - def __init__(self, collection_name: str, config: RelytConfig, group_id: str): super().__init__(collection_name) self.embedding_dimension = 1536 self._client_config = config - self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + self._url = ( + f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + ) self.client = create_engine(self._url) self._fields = [] self._group_id = group_id @@ -70,9 +71,9 @@ class RelytVector(BaseVector): self.add_texts(texts, embeddings) def create_collection(self, dimension: int): - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return index_name = f"{self._collection_name}_embedding_index" @@ -110,7 +111,7 @@ class RelytVector(BaseVector): ids = [str(uuid.uuid1()) for _ in documents] metadatas = [d.metadata for d in documents] for metadata in metadatas: - metadata['group_id'] = self._group_id + metadata["group_id"] = self._group_id texts = [d.page_content for d in documents] # Define the table schema @@ -127,9 +128,7 @@ class RelytVector(BaseVector): chunks_table_data = [] with self.client.connect() as conn: with conn.begin(): - for document, metadata, chunk_id, embedding in zip( - texts, metadatas, ids, embeddings - ): + for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings): chunks_table_data.append( { "id": chunk_id, @@ -196,15 +195,13 @@ class RelytVector(BaseVector): return False def delete_by_metadata_field(self, key: str, value: str): - ids = self.get_ids_by_metadata_field(key, value) if ids: self.delete_by_uuids(ids) def delete_by_ids(self, ids: list[str]) -> None: - with Session(self.client) as session: - ids_str = ','.join(f"'{doc_id}'" for doc_id in ids) + ids_str = ",".join(f"'{doc_id}'" for doc_id in ids) select_statement = sql_text( f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """ ) @@ -228,38 +225,34 @@ class RelytVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: results = self.similarity_search_with_score_by_vector( - k=int(kwargs.get('top_k')), - embedding=query_vector, - filter=kwargs.get('filter') + k=int(kwargs.get("top_k")), embedding=query_vector, filter=kwargs.get("filter") ) # Organize results. docs = [] for document, score in results: - score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 + score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 if 1 - score > score_threshold: docs.append(document) return docs def similarity_search_with_score_by_vector( - self, - embedding: list[float], - k: int = 4, - filter: Optional[dict] = None, + self, + embedding: list[float], + k: int = 4, + filter: Optional[dict] = None, ) -> list[tuple[Document, float]]: # Add the filter if provided try: from sqlalchemy.engine import Row except ImportError: - raise ImportError( - "Could not import Row from sqlalchemy.engine. " - "Please 'pip install sqlalchemy>=1.4'." - ) + raise ImportError("Could not import Row from sqlalchemy.engine. " "Please 'pip install sqlalchemy>=1.4'.") filter_condition = "" if filter is not None: conditions = [ - f"metadata->>{key!r} in ({', '.join(map(repr, value))})" if len(value) > 1 + f"metadata->>{key!r} in ({', '.join(map(repr, value))})" + if len(value) > 1 else f"metadata->>{key!r} = {value[0]!r}" for key, value in filter.items() ] @@ -305,13 +298,12 @@ class RelytVector(BaseVector): class RelytVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.RELYT, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.RELYT, collection_name)) return RelytVector( collection_name=collection_name, @@ -322,5 +314,5 @@ class RelytVectorFactory(AbstractVectorFactory): password=dify_config.RELYT_PASSWORD, database=dify_config.RELYT_DATABASE, ), - group_id=dataset.id + group_id=dataset.id, ) diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 3325a1028e..ada0c5cf46 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -25,16 +25,11 @@ class TencentConfig(BaseModel): database: Optional[str] index_type: str = "HNSW" metric_type: str = "L2" - shard: int = 1, - replicas: int = 2, + shard: int = (1,) + replicas: int = (2,) def to_tencent_params(self): - return { - 'url': self.url, - 'username': self.username, - 'key': self.api_key, - 'timeout': self.timeout - } + return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout} class TencentVector(BaseVector): @@ -61,13 +56,10 @@ class TencentVector(BaseVector): return self._client.create_database(database_name=self._client_config.database) def get_type(self) -> str: - return 'tencent' + return "tencent" def to_index_struct(self) -> dict: - return { - "type": self.get_type(), - "vector_store": {"class_prefix": self._collection_name} - } + return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} def _has_collection(self) -> bool: collections = self._db.list_collections() @@ -77,9 +69,9 @@ class TencentVector(BaseVector): return False def _create_collection(self, dimension: int) -> None: - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return @@ -101,9 +93,7 @@ class TencentVector(BaseVector): raise ValueError("unsupported metric_type") params = vdb_index.HNSWParams(m=16, efconstruction=200) index = vdb_index.Index( - vdb_index.FilterIndex( - self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY - ), + vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY), vdb_index.VectorIndex( self.field_vector, dimension, @@ -111,12 +101,8 @@ class TencentVector(BaseVector): metric_type, params, ), - vdb_index.FilterIndex( - self.field_text, enum.FieldType.String, enum.IndexType.FILTER - ), - vdb_index.FilterIndex( - self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER - ), + vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER), + vdb_index.FilterIndex(self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER), ) self._db.create_collection( @@ -163,15 +149,14 @@ class TencentVector(BaseVector): self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value]))) def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: - - res = self._db.collection(self._collection_name).search(vectors=[query_vector], - params=document.HNSWSearchParams( - ef=kwargs.get("ef", 10)), - retrieve_vector=False, - limit=kwargs.get('top_k', 4), - timeout=self._client_config.timeout, - ) - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + res = self._db.collection(self._collection_name).search( + vectors=[query_vector], + params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)), + retrieve_vector=False, + limit=kwargs.get("top_k", 4), + timeout=self._client_config.timeout, + ) + score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 return self._get_search_res(res, score_threshold) def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -200,15 +185,13 @@ class TencentVector(BaseVector): class TencentVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TencentVector: - if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.TENCENT, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TENCENT, collection_name)) return TencentVector( collection_name=collection_name, @@ -220,5 +203,5 @@ class TencentVectorFactory(AbstractVectorFactory): database=dify_config.TENCENT_VECTOR_DB_DATABASE, shard=dify_config.TENCENT_VECTOR_DB_SHARD, replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS, - ) + ), ) diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index d3685c0991..0e4b3f67a1 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -28,47 +28,56 @@ class TiDBVectorConfig(BaseModel): database: str program_name: str - @model_validator(mode='before') + @model_validator(mode="before") def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values["host"]: raise ValueError("config TIDB_VECTOR_HOST is required") - if not values['port']: + if not values["port"]: raise ValueError("config TIDB_VECTOR_PORT is required") - if not values['user']: + if not values["user"]: raise ValueError("config TIDB_VECTOR_USER is required") - if not values['password']: + if not values["password"]: raise ValueError("config TIDB_VECTOR_PASSWORD is required") - if not values['database']: + if not values["database"]: raise ValueError("config TIDB_VECTOR_DATABASE is required") - if not values['program_name']: + if not values["program_name"]: raise ValueError("config APPLICATION_NAME is required") return values class TiDBVector(BaseVector): - def get_type(self) -> str: return VectorType.TIDB_VECTOR def _table(self, dim: int) -> Table: from tidb_vector.sqlalchemy import VectorType + return Table( self._collection_name, self._orm_base.metadata, - Column('id', String(36), primary_key=True, nullable=False), - Column("vector", VectorType(dim), nullable=False, comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})"), + Column("id", String(36), primary_key=True, nullable=False), + Column( + "vector", + VectorType(dim), + nullable=False, + comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})", + ), Column("text", TEXT, nullable=False), Column("meta", JSON, nullable=False), Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")), - Column("update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")), - extend_existing=True + Column( + "update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP") + ), + extend_existing=True, ) - def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = 'cosine'): + def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = "cosine"): super().__init__(collection_name) self._client_config = config - self._url = (f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?" - f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}") + self._url = ( + f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?" + f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}" + ) self._distance_func = distance_func.lower() self._engine = create_engine(self._url) self._orm_base = declarative_base() @@ -83,9 +92,9 @@ class TiDBVector(BaseVector): def _create_collection(self, dimension: int): logger.info("_create_collection, collection_name " + self._collection_name) - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return with Session(self._engine) as session: @@ -116,9 +125,7 @@ class TiDBVector(BaseVector): chunks_table_data = [] with self._engine.connect() as conn: with conn.begin(): - for id, text, meta, embedding in zip( - ids, texts, metas, embeddings - ): + for id, text, meta, embedding in zip(ids, texts, metas, embeddings): chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta}) # Execute the batch insert when the batch size is reached @@ -133,12 +140,12 @@ class TiDBVector(BaseVector): return ids def text_exists(self, id: str) -> bool: - result = self.get_ids_by_metadata_field('doc_id', id) + result = self.get_ids_by_metadata_field("doc_id", id) return bool(result) def delete_by_ids(self, ids: list[str]) -> None: with Session(self._engine) as session: - ids_str = ','.join(f"'{doc_id}'" for doc_id in ids) + ids_str = ",".join(f"'{doc_id}'" for doc_id in ids) select_statement = sql_text( f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.doc_id' in ({ids_str}); """ ) @@ -180,20 +187,22 @@ class TiDBVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 - filter = kwargs.get('filter') + filter = kwargs.get("filter") distance = 1 - score_threshold query_vector_str = ", ".join(format(x) for x in query_vector) query_vector_str = "[" + query_vector_str + "]" - logger.debug(f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}") + logger.debug( + f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}" + ) docs = [] - if self._distance_func == 'l2': - tidb_func = 'Vec_l2_distance' - elif self._distance_func == 'cosine': - tidb_func = 'Vec_Cosine_distance' + if self._distance_func == "l2": + tidb_func = "Vec_l2_distance" + elif self._distance_func == "cosine": + tidb_func = "Vec_Cosine_distance" else: - tidb_func = 'Vec_Cosine_distance' + tidb_func = "Vec_Cosine_distance" with Session(self._engine) as session: select_statement = sql_text( @@ -208,7 +217,7 @@ class TiDBVector(BaseVector): results = [(row[0], row[1], row[2]) for row in res] for meta, text, distance in results: metadata = json.loads(meta) - metadata['score'] = 1 - distance + metadata["score"] = 1 - distance docs.append(Document(page_content=text, metadata=metadata)) return docs @@ -224,15 +233,13 @@ class TiDBVector(BaseVector): class TiDBVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector: - if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name)) return TiDBVector( collection_name=collection_name, diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py index 3f70e8b608..fb80cdec87 100644 --- a/api/core/rag/datasource/vdb/vector_base.py +++ b/api/core/rag/datasource/vdb/vector_base.py @@ -7,7 +7,6 @@ from core.rag.models.document import Document class BaseVector(ABC): - def __init__(self, collection_name: str): self._collection_name = collection_name @@ -39,18 +38,11 @@ class BaseVector(ABC): raise NotImplementedError @abstractmethod - def search_by_vector( - self, - query_vector: list[float], - **kwargs: Any - ) -> list[Document]: + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: raise NotImplementedError @abstractmethod - def search_by_full_text( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: raise NotImplementedError def delete(self) -> None: @@ -58,7 +50,7 @@ class BaseVector(ABC): def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts[:]: - doc_id = text.metadata['doc_id'] + doc_id = text.metadata["doc_id"] exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: texts.remove(text) @@ -66,7 +58,7 @@ class BaseVector(ABC): return texts def _get_uuids(self, texts: list[Document]) -> list[str]: - return [text.metadata['doc_id'] for text in texts] + return [text.metadata["doc_id"] for text in texts] @property def collection_name(self): diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 627d7c3aeb..7d2db140df 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -20,17 +20,14 @@ class AbstractVectorFactory(ABC): @staticmethod def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> dict: - index_struct_dict = { - "type": vector_type, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}} return index_struct_dict class Vector: def __init__(self, dataset: Dataset, attributes: list = None): if attributes is None: - attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash', 'page'] + attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "page"] self._dataset = dataset self._embeddings = self._get_embeddings() self._attributes = attributes @@ -39,7 +36,7 @@ class Vector: def _init_vector(self) -> BaseVector: vector_type = dify_config.VECTOR_STORE if self._dataset.index_struct_dict: - vector_type = self._dataset.index_struct_dict['type'] + vector_type = self._dataset.index_struct_dict["type"] if not vector_type: raise ValueError("Vector store must be specified.") @@ -52,45 +49,59 @@ class Vector: match vector_type: case VectorType.CHROMA: from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory + return ChromaVectorFactory case VectorType.MILVUS: from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory + return MilvusVectorFactory case VectorType.MYSCALE: from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory + return MyScaleVectorFactory case VectorType.PGVECTOR: from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory + return PGVectorFactory case VectorType.PGVECTO_RS: from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory + return PGVectoRSFactory case VectorType.QDRANT: from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory + return QdrantVectorFactory case VectorType.RELYT: from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory + return RelytVectorFactory case VectorType.ELASTICSEARCH: from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory + return ElasticSearchVectorFactory case VectorType.TIDB_VECTOR: from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory + return TiDBVectorFactory case VectorType.WEAVIATE: from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory + return WeaviateVectorFactory case VectorType.TENCENT: from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory + return TencentVectorFactory case VectorType.ORACLE: from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory + return OracleVectorFactory case VectorType.OPENSEARCH: from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory + return OpenSearchVectorFactory case VectorType.ANALYTICDB: from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory + return AnalyticdbVectorFactory case _: raise ValueError(f"Vector store {vector_type} is not supported.") @@ -98,22 +109,14 @@ class Vector: def create(self, texts: list = None, **kwargs): if texts: embeddings = self._embeddings.embed_documents([document.page_content for document in texts]) - self._vector_processor.create( - texts=texts, - embeddings=embeddings, - **kwargs - ) + self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs) def add_texts(self, documents: list[Document], **kwargs): - if kwargs.get('duplicate_check', False): + if kwargs.get("duplicate_check", False): documents = self._filter_duplicate_texts(documents) embeddings = self._embeddings.embed_documents([document.page_content for document in documents]) - self._vector_processor.create( - texts=documents, - embeddings=embeddings, - **kwargs - ) + self._vector_processor.create(texts=documents, embeddings=embeddings, **kwargs) def text_exists(self, id: str) -> bool: return self._vector_processor.text_exists(id) @@ -124,24 +127,18 @@ class Vector: def delete_by_metadata_field(self, key: str, value: str) -> None: self._vector_processor.delete_by_metadata_field(key, value) - def search_by_vector( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]: query_vector = self._embeddings.embed_query(query) return self._vector_processor.search_by_vector(query_vector, **kwargs) - def search_by_full_text( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return self._vector_processor.search_by_full_text(query, **kwargs) def delete(self) -> None: self._vector_processor.delete() # delete collection redis cache if self._vector_processor.collection_name: - collection_exist_cache_key = 'vector_indexing_{}'.format(self._vector_processor.collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._vector_processor.collection_name) redis_client.delete(collection_exist_cache_key) def _get_embeddings(self) -> Embeddings: @@ -151,14 +148,13 @@ class Vector: tenant_id=self._dataset.tenant_id, provider=self._dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=self._dataset.embedding_model - + model=self._dataset.embedding_model, ) return CacheEmbedding(embedding_model) def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts[:]: - doc_id = text.metadata['doc_id'] + doc_id = text.metadata["doc_id"] exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: texts.remove(text) diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index 317ca6abc8..ba04ea879d 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -2,17 +2,17 @@ from enum import Enum class VectorType(str, Enum): - ANALYTICDB = 'analyticdb' - CHROMA = 'chroma' - MILVUS = 'milvus' - MYSCALE = 'myscale' - PGVECTOR = 'pgvector' - PGVECTO_RS = 'pgvecto-rs' - QDRANT = 'qdrant' - RELYT = 'relyt' - TIDB_VECTOR = 'tidb_vector' - WEAVIATE = 'weaviate' - OPENSEARCH = 'opensearch' - TENCENT = 'tencent' - ORACLE = 'oracle' - ELASTICSEARCH = 'elasticsearch' + ANALYTICDB = "analyticdb" + CHROMA = "chroma" + MILVUS = "milvus" + MYSCALE = "myscale" + PGVECTOR = "pgvector" + PGVECTO_RS = "pgvecto-rs" + QDRANT = "qdrant" + RELYT = "relyt" + TIDB_VECTOR = "tidb_vector" + WEAVIATE = "weaviate" + OPENSEARCH = "opensearch" + TENCENT = "tencent" + ORACLE = "oracle" + ELASTICSEARCH = "elasticsearch" diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 205fe850c3..750172b015 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -22,15 +22,14 @@ class WeaviateConfig(BaseModel): api_key: Optional[str] = None batch_size: int = 100 - @model_validator(mode='before') + @model_validator(mode="before") def validate_config(cls, values: dict) -> dict: - if not values['endpoint']: + if not values["endpoint"]: raise ValueError("config WEAVIATE_ENDPOINT is required") return values class WeaviateVector(BaseVector): - def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list): super().__init__(collection_name) self._client = self._init_client(config) @@ -43,10 +42,7 @@ class WeaviateVector(BaseVector): try: client = weaviate.Client( - url=config.endpoint, - auth_client_secret=auth_config, - timeout_config=(5, 60), - startup_period=None + url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None ) except requests.exceptions.ConnectionError: raise ConnectionError("Vector database connection error") @@ -68,10 +64,10 @@ class WeaviateVector(BaseVector): def get_collection_name(self, dataset: Dataset) -> str: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] - if not class_prefix.endswith('_Node'): + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + if not class_prefix.endswith("_Node"): # original class_prefix - class_prefix += '_Node' + class_prefix += "_Node" return class_prefix @@ -79,10 +75,7 @@ class WeaviateVector(BaseVector): return Dataset.gen_collection_name_by_id(dataset_id) def to_index_struct(self) -> dict: - return { - "type": self.get_type(), - "vector_store": {"class_prefix": self._collection_name} - } + return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): # create collection @@ -91,9 +84,9 @@ class WeaviateVector(BaseVector): self.add_texts(texts, embeddings) def _create_collection(self): - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return schema = self._default_schema(self._collection_name) @@ -129,17 +122,9 @@ class WeaviateVector(BaseVector): # check whether the index already exists schema = self._default_schema(self._collection_name) if self._client.schema.contains(schema): - where_filter = { - "operator": "Equal", - "path": [key], - "valueText": value - } + where_filter = {"operator": "Equal", "path": [key], "valueText": value} - self._client.batch.delete_objects( - class_name=self._collection_name, - where=where_filter, - output='minimal' - ) + self._client.batch.delete_objects(class_name=self._collection_name, where=where_filter, output="minimal") def delete(self): # check whether the index already exists @@ -154,11 +139,19 @@ class WeaviateVector(BaseVector): # check whether the index already exists if not self._client.schema.contains(schema): return False - result = self._client.query.get(collection_name).with_additional(["id"]).with_where({ - "path": ["doc_id"], - "operator": "Equal", - "valueText": id, - }).with_limit(1).do() + result = ( + self._client.query.get(collection_name) + .with_additional(["id"]) + .with_where( + { + "path": ["doc_id"], + "operator": "Equal", + "valueText": id, + } + ) + .with_limit(1) + .do() + ) if "errors" in result: raise ValueError(f"Error during query: {result['errors']}") @@ -211,13 +204,13 @@ class WeaviateVector(BaseVector): docs = [] for doc, score in docs_and_scores: - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 # check score threshold if score > score_threshold: - doc.metadata['score'] = score + doc.metadata["score"] = score docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -240,15 +233,15 @@ class WeaviateVector(BaseVector): if kwargs.get("where_filter"): query_obj = query_obj.with_where(kwargs.get("where_filter")) query_obj = query_obj.with_additional(["vector"]) - properties = ['text'] - result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do() + properties = ["text"] + result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 2)).do() if "errors" in result: raise ValueError(f"Error during query: {result['errors']}") docs = [] for res in result["data"]["Get"][collection_name]: text = res.pop(Field.TEXT_KEY.value) - additional = res.pop('_additional') - docs.append(Document(page_content=text, vector=additional['vector'], metadata=res)) + additional = res.pop("_additional") + docs.append(Document(page_content=text, vector=additional["vector"], metadata=res)) return docs def _default_schema(self, index_name: str) -> dict: @@ -271,20 +264,19 @@ class WeaviateVector(BaseVector): class WeaviateVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) return WeaviateVector( collection_name=collection_name, config=WeaviateConfig( endpoint=dify_config.WEAVIATE_ENDPOINT, api_key=dify_config.WEAVIATE_API_KEY, - batch_size=dify_config.WEAVIATE_BATCH_SIZE + batch_size=dify_config.WEAVIATE_BATCH_SIZE, ), - attributes=attributes + attributes=attributes, ) diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 96a15be742..0d4dff5b89 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -12,10 +12,10 @@ from models.dataset import Dataset, DocumentSegment class DatasetDocumentStore: def __init__( - self, - dataset: Dataset, - user_id: str, - document_id: Optional[str] = None, + self, + dataset: Dataset, + user_id: str, + document_id: Optional[str] = None, ): self._dataset = dataset self._user_id = user_id @@ -41,9 +41,9 @@ class DatasetDocumentStore: @property def docs(self) -> dict[str, Document]: - document_segments = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == self._dataset.id - ).all() + document_segments = ( + db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == self._dataset.id).all() + ) output = {} for document_segment in document_segments: @@ -55,48 +55,45 @@ class DatasetDocumentStore: "doc_hash": document_segment.index_node_hash, "document_id": document_segment.document_id, "dataset_id": document_segment.dataset_id, - } + }, ) return output - def add_documents( - self, docs: Sequence[Document], allow_update: bool = True - ) -> None: - max_position = db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document_id == self._document_id - ).scalar() + def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> None: + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .filter(DocumentSegment.document_id == self._document_id) + .scalar() + ) if max_position is None: max_position = 0 embedding_model = None - if self._dataset.indexing_technique == 'high_quality': + if self._dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=self._dataset.tenant_id, provider=self._dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=self._dataset.embedding_model + model=self._dataset.embedding_model, ) for doc in docs: if not isinstance(doc, Document): raise ValueError("doc must be a Document") - segment_document = self.get_document_segment(doc_id=doc.metadata['doc_id']) + segment_document = self.get_document_segment(doc_id=doc.metadata["doc_id"]) # NOTE: doc could already exist in the store, but we overwrite it if not allow_update and segment_document: raise ValueError( - f"doc_id {doc.metadata['doc_id']} already exists. " - "Set allow_update to True to overwrite." + f"doc_id {doc.metadata['doc_id']} already exists. " "Set allow_update to True to overwrite." ) # calc embedding use tokens if embedding_model: - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[doc.page_content] - ) + tokens = embedding_model.get_text_embedding_num_tokens(texts=[doc.page_content]) else: tokens = 0 @@ -107,8 +104,8 @@ class DatasetDocumentStore: tenant_id=self._dataset.tenant_id, dataset_id=self._dataset.id, document_id=self._document_id, - index_node_id=doc.metadata['doc_id'], - index_node_hash=doc.metadata['doc_hash'], + index_node_id=doc.metadata["doc_id"], + index_node_hash=doc.metadata["doc_hash"], position=max_position, content=doc.page_content, word_count=len(doc.page_content), @@ -116,15 +113,15 @@ class DatasetDocumentStore: enabled=False, created_by=self._user_id, ) - if doc.metadata.get('answer'): - segment_document.answer = doc.metadata.pop('answer', '') + if doc.metadata.get("answer"): + segment_document.answer = doc.metadata.pop("answer", "") db.session.add(segment_document) else: segment_document.content = doc.page_content - if doc.metadata.get('answer'): - segment_document.answer = doc.metadata.pop('answer', '') - segment_document.index_node_hash = doc.metadata['doc_hash'] + if doc.metadata.get("answer"): + segment_document.answer = doc.metadata.pop("answer", "") + segment_document.index_node_hash = doc.metadata["doc_hash"] segment_document.word_count = len(doc.page_content) segment_document.tokens = tokens @@ -135,9 +132,7 @@ class DatasetDocumentStore: result = self.get_document_segment(doc_id) return result is not None - def get_document( - self, doc_id: str, raise_error: bool = True - ) -> Optional[Document]: + def get_document(self, doc_id: str, raise_error: bool = True) -> Optional[Document]: document_segment = self.get_document_segment(doc_id) if document_segment is None: @@ -153,7 +148,7 @@ class DatasetDocumentStore: "doc_hash": document_segment.index_node_hash, "document_id": document_segment.document_id, "dataset_id": document_segment.dataset_id, - } + }, ) def delete_document(self, doc_id: str, raise_error: bool = True) -> None: @@ -188,9 +183,10 @@ class DatasetDocumentStore: return document_segment.index_node_hash def get_document_segment(self, doc_id: str) -> DocumentSegment: - document_segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == self._dataset.id, - DocumentSegment.index_node_id == doc_id - ).first() + document_segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) + .first() + ) return document_segment diff --git a/api/core/rag/extractor/blob/blob.py b/api/core/rag/extractor/blob/blob.py index abfdafcfa2..f4c7b4b5f7 100644 --- a/api/core/rag/extractor/blob/blob.py +++ b/api/core/rag/extractor/blob/blob.py @@ -4,6 +4,7 @@ The goal is to facilitate decoupling of content loading from content parsing cod In addition, content loading code should provide a lazy loading interface by default. """ + from __future__ import annotations import contextlib diff --git a/api/core/rag/extractor/csv_extractor.py b/api/core/rag/extractor/csv_extractor.py index 0470569f39..5b67403902 100644 --- a/api/core/rag/extractor/csv_extractor.py +++ b/api/core/rag/extractor/csv_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + import csv from typing import Optional @@ -18,12 +19,12 @@ class CSVExtractor(BaseExtractor): """ def __init__( - self, - file_path: str, - encoding: Optional[str] = None, - autodetect_encoding: bool = False, - source_column: Optional[str] = None, - csv_args: Optional[dict] = None, + self, + file_path: str, + encoding: Optional[str] = None, + autodetect_encoding: bool = False, + source_column: Optional[str] = None, + csv_args: Optional[dict] = None, ): """Initialize with file path.""" self._file_path = file_path @@ -57,7 +58,7 @@ class CSVExtractor(BaseExtractor): docs = [] try: # load csv file into pandas dataframe - df = pd.read_csv(csvfile, on_bad_lines='skip', **self.csv_args) + df = pd.read_csv(csvfile, on_bad_lines="skip", **self.csv_args) # check source column exists if self.source_column and self.source_column not in df.columns: @@ -67,7 +68,7 @@ class CSVExtractor(BaseExtractor): for i, row in df.iterrows(): content = ";".join(f"{col.strip()}: {str(row[col]).strip()}" for col in df.columns) - source = row[self.source_column] if self.source_column else '' + source = row[self.source_column] if self.source_column else "" metadata = {"source": source, "row": i} doc = Document(page_content=content, metadata=metadata) docs.append(doc) diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index 7479b1d97b..3692b5d19d 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -10,6 +10,7 @@ class NotionInfo(BaseModel): """ Notion import info. """ + notion_workspace_id: str notion_obj_id: str notion_page_type: str @@ -25,6 +26,7 @@ class WebsiteInfo(BaseModel): """ website import info. """ + provider: str job_id: str url: str @@ -43,6 +45,7 @@ class ExtractSetting(BaseModel): """ Model class for provider response. """ + datasource_type: str upload_file: Optional[UploadFile] = None notion_info: Optional[NotionInfo] = None diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index 526c66042c..fc33165719 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + import os from typing import Optional @@ -17,23 +18,18 @@ class ExcelExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - encoding: Optional[str] = None, - autodetect_encoding: bool = False - ): + def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False): """Initialize with file path.""" self._file_path = file_path self._encoding = encoding self._autodetect_encoding = autodetect_encoding def extract(self) -> list[Document]: - """ Load from Excel file in xls or xlsx format using Pandas and openpyxl.""" + """Load from Excel file in xls or xlsx format using Pandas and openpyxl.""" documents = [] file_extension = os.path.splitext(self._file_path)[-1].lower() - if file_extension == '.xlsx': + if file_extension == ".xlsx": wb = load_workbook(self._file_path, data_only=True) for sheet_name in wb.sheetnames: sheet = wb[sheet_name] @@ -44,35 +40,38 @@ class ExcelExtractor(BaseExtractor): continue df = pd.DataFrame(data, columns=cols) - df.dropna(how='all', inplace=True) + df.dropna(how="all", inplace=True) for index, row in df.iterrows(): page_content = [] for col_index, (k, v) in enumerate(row.items()): if pd.notna(v): - cell = sheet.cell(row=index + 2, - column=col_index + 1) # +2 to account for header and 1-based index + cell = sheet.cell( + row=index + 2, column=col_index + 1 + ) # +2 to account for header and 1-based index if cell.hyperlink: value = f"[{v}]({cell.hyperlink.target})" page_content.append(f'"{k}":"{value}"') else: page_content.append(f'"{k}":"{v}"') - documents.append(Document(page_content=';'.join(page_content), - metadata={'source': self._file_path})) + documents.append( + Document(page_content=";".join(page_content), metadata={"source": self._file_path}) + ) - elif file_extension == '.xls': - excel_file = pd.ExcelFile(self._file_path, engine='xlrd') + elif file_extension == ".xls": + excel_file = pd.ExcelFile(self._file_path, engine="xlrd") for sheet_name in excel_file.sheet_names: df = excel_file.parse(sheet_name=sheet_name) - df.dropna(how='all', inplace=True) + df.dropna(how="all", inplace=True) for _, row in df.iterrows(): page_content = [] for k, v in row.items(): if pd.notna(v): page_content.append(f'"{k}":"{v}"') - documents.append(Document(page_content=';'.join(page_content), - metadata={'source': self._file_path})) + documents.append( + Document(page_content=";".join(page_content), metadata={"source": self._file_path}) + ) else: raise ValueError(f"Unsupported file extension: {file_extension}") diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index f7a08135f5..a00b3cba53 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -29,61 +29,60 @@ from core.rag.models.document import Document from extensions.ext_storage import storage from models.model import UploadFile -SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain', 'application/json'] +SUPPORT_URL_CONTENT_TYPES = ["application/pdf", "text/plain", "application/json"] USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" class ExtractProcessor: @classmethod - def load_from_upload_file(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) \ - -> Union[list[Document], str]: + def load_from_upload_file( + cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False + ) -> Union[list[Document], str]: extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=upload_file, - document_model='text_model' + datasource_type="upload_file", upload_file=upload_file, document_model="text_model" ) if return_text: - delimiter = '\n' + delimiter = "\n" return delimiter.join([document.page_content for document in cls.extract(extract_setting, is_automatic)]) else: return cls.extract(extract_setting, is_automatic) @classmethod def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]: - response = ssrf_proxy.get(url, headers={ - "User-Agent": USER_AGENT - }) + response = ssrf_proxy.get(url, headers={"User-Agent": USER_AGENT}) with tempfile.TemporaryDirectory() as temp_dir: suffix = Path(url).suffix - if not suffix and suffix != '.': + if not suffix and suffix != ".": # get content-type - if response.headers.get('Content-Type'): - suffix = '.' + response.headers.get('Content-Type').split('/')[-1] + if response.headers.get("Content-Type"): + suffix = "." + response.headers.get("Content-Type").split("/")[-1] else: - content_disposition = response.headers.get('Content-Disposition') + content_disposition = response.headers.get("Content-Disposition") filename_match = re.search(r'filename="([^"]+)"', content_disposition) if filename_match: filename = unquote(filename_match.group(1)) - suffix = '.' + re.search(r'\.(\w+)$', filename).group(1) + suffix = "." + re.search(r"\.(\w+)$", filename).group(1) file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" - with open(file_path, 'wb') as file: + with open(file_path, "wb") as file: file.write(response.content) - extract_setting = ExtractSetting( - datasource_type="upload_file", - document_model='text_model' - ) + extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model") if return_text: - delimiter = '\n' - return delimiter.join([document.page_content for document in cls.extract( - extract_setting=extract_setting, file_path=file_path)]) + delimiter = "\n" + return delimiter.join( + [ + document.page_content + for document in cls.extract(extract_setting=extract_setting, file_path=file_path) + ] + ) else: return cls.extract(extract_setting=extract_setting, file_path=file_path) @classmethod - def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False, - file_path: str = None) -> list[Document]: + def extract( + cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str = None + ) -> list[Document]: if extract_setting.datasource_type == DatasourceType.FILE.value: with tempfile.TemporaryDirectory() as temp_dir: if not file_path: @@ -96,50 +95,56 @@ class ExtractProcessor: etl_type = dify_config.ETL_TYPE unstructured_api_url = dify_config.UNSTRUCTURED_API_URL unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY - if etl_type == 'Unstructured': - if file_extension == '.xlsx' or file_extension == '.xls': + if etl_type == "Unstructured": + if file_extension == ".xlsx" or file_extension == ".xls": extractor = ExcelExtractor(file_path) - elif file_extension == '.pdf': + elif file_extension == ".pdf": extractor = PdfExtractor(file_path) - elif file_extension in ['.md', '.markdown']: - extractor = UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic \ + elif file_extension in [".md", ".markdown"]: + extractor = ( + UnstructuredMarkdownExtractor(file_path, unstructured_api_url) + if is_automatic else MarkdownExtractor(file_path, autodetect_encoding=True) - elif file_extension in ['.htm', '.html']: + ) + elif file_extension in [".htm", ".html"]: extractor = HtmlExtractor(file_path) - elif file_extension in ['.docx']: + elif file_extension in [".docx"]: extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) - elif file_extension == '.csv': + elif file_extension == ".csv": extractor = CSVExtractor(file_path, autodetect_encoding=True) - elif file_extension == '.msg': + elif file_extension == ".msg": extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url) - elif file_extension == '.eml': + elif file_extension == ".eml": extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url) - elif file_extension == '.ppt': + elif file_extension == ".ppt": extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url, unstructured_api_key) - elif file_extension == '.pptx': + elif file_extension == ".pptx": extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url) - elif file_extension == '.xml': + elif file_extension == ".xml": extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url) - elif file_extension == 'epub': + elif file_extension == "epub": extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url) else: # txt - extractor = UnstructuredTextExtractor(file_path, unstructured_api_url) if is_automatic \ + extractor = ( + UnstructuredTextExtractor(file_path, unstructured_api_url) + if is_automatic else TextExtractor(file_path, autodetect_encoding=True) + ) else: - if file_extension == '.xlsx' or file_extension == '.xls': + if file_extension == ".xlsx" or file_extension == ".xls": extractor = ExcelExtractor(file_path) - elif file_extension == '.pdf': + elif file_extension == ".pdf": extractor = PdfExtractor(file_path) - elif file_extension in ['.md', '.markdown']: + elif file_extension in [".md", ".markdown"]: extractor = MarkdownExtractor(file_path, autodetect_encoding=True) - elif file_extension in ['.htm', '.html']: + elif file_extension in [".htm", ".html"]: extractor = HtmlExtractor(file_path) - elif file_extension in ['.docx']: + elif file_extension in [".docx"]: extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) - elif file_extension == '.csv': + elif file_extension == ".csv": extractor = CSVExtractor(file_path, autodetect_encoding=True) - elif file_extension == 'epub': + elif file_extension == "epub": extractor = UnstructuredEpubExtractor(file_path) else: # txt @@ -155,13 +160,13 @@ class ExtractProcessor: ) return extractor.extract() elif extract_setting.datasource_type == DatasourceType.WEBSITE.value: - if extract_setting.website_info.provider == 'firecrawl': + if extract_setting.website_info.provider == "firecrawl": extractor = FirecrawlWebExtractor( url=extract_setting.website_info.url, job_id=extract_setting.website_info.job_id, tenant_id=extract_setting.website_info.tenant_id, mode=extract_setting.website_info.mode, - only_main_content=extract_setting.website_info.only_main_content + only_main_content=extract_setting.website_info.only_main_content, ) return extractor.extract() else: diff --git a/api/core/rag/extractor/extractor_base.py b/api/core/rag/extractor/extractor_base.py index c490e59332..582eca94df 100644 --- a/api/core/rag/extractor/extractor_base.py +++ b/api/core/rag/extractor/extractor_base.py @@ -1,12 +1,11 @@ """Abstract interface for document loader implementations.""" + from abc import ABC, abstractmethod class BaseExtractor(ABC): - """Interface for extract files. - """ + """Interface for extract files.""" @abstractmethod def extract(self): raise NotImplementedError - diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 2b85ad9739..054ce5f4b2 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -9,108 +9,98 @@ from extensions.ext_storage import storage class FirecrawlApp: def __init__(self, api_key=None, base_url=None): self.api_key = api_key - self.base_url = base_url or 'https://api.firecrawl.dev' - if self.api_key is None and self.base_url == 'https://api.firecrawl.dev': - raise ValueError('No API key provided') + self.base_url = base_url or "https://api.firecrawl.dev" + if self.api_key is None and self.base_url == "https://api.firecrawl.dev": + raise ValueError("No API key provided") def scrape_url(self, url, params=None) -> dict: - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } - json_data = {'url': url} + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} + json_data = {"url": url} if params: json_data.update(params) - response = requests.post( - f'{self.base_url}/v0/scrape', - headers=headers, - json=json_data - ) + response = requests.post(f"{self.base_url}/v0/scrape", headers=headers, json=json_data) if response.status_code == 200: response = response.json() - if response['success'] == True: - data = response['data'] + if response["success"] == True: + data = response["data"] return { - 'title': data.get('metadata').get('title'), - 'description': data.get('metadata').get('description'), - 'source_url': data.get('metadata').get('sourceURL'), - 'markdown': data.get('markdown') + "title": data.get("metadata").get("title"), + "description": data.get("metadata").get("description"), + "source_url": data.get("metadata").get("sourceURL"), + "markdown": data.get("markdown"), } else: raise Exception(f'Failed to scrape URL. Error: {response["error"]}') elif response.status_code in [402, 409, 500]: - error_message = response.json().get('error', 'Unknown error occurred') - raise Exception(f'Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}') + error_message = response.json().get("error", "Unknown error occurred") + raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}") else: - raise Exception(f'Failed to scrape URL. Status code: {response.status_code}') + raise Exception(f"Failed to scrape URL. Status code: {response.status_code}") def crawl_url(self, url, params=None) -> str: headers = self._prepare_headers() - json_data = {'url': url} + json_data = {"url": url} if params: json_data.update(params) - response = self._post_request(f'{self.base_url}/v0/crawl', json_data, headers) + response = self._post_request(f"{self.base_url}/v0/crawl", json_data, headers) if response.status_code == 200: - job_id = response.json().get('jobId') + job_id = response.json().get("jobId") return job_id else: - self._handle_error(response, 'start crawl job') + self._handle_error(response, "start crawl job") def check_crawl_status(self, job_id) -> dict: headers = self._prepare_headers() - response = self._get_request(f'{self.base_url}/v0/crawl/status/{job_id}', headers) + response = self._get_request(f"{self.base_url}/v0/crawl/status/{job_id}", headers) if response.status_code == 200: crawl_status_response = response.json() - if crawl_status_response.get('status') == 'completed': - total = crawl_status_response.get('total', 0) + if crawl_status_response.get("status") == "completed": + total = crawl_status_response.get("total", 0) if total == 0: - raise Exception('Failed to check crawl status. Error: No page found') - data = crawl_status_response.get('data', []) + raise Exception("Failed to check crawl status. Error: No page found") + data = crawl_status_response.get("data", []) url_data_list = [] for item in data: - if isinstance(item, dict) and 'metadata' in item and 'markdown' in item: + if isinstance(item, dict) and "metadata" in item and "markdown" in item: url_data = { - 'title': item.get('metadata').get('title'), - 'description': item.get('metadata').get('description'), - 'source_url': item.get('metadata').get('sourceURL'), - 'markdown': item.get('markdown') + "title": item.get("metadata").get("title"), + "description": item.get("metadata").get("description"), + "source_url": item.get("metadata").get("sourceURL"), + "markdown": item.get("markdown"), } url_data_list.append(url_data) if url_data_list: - file_key = 'website_files/' + job_id + '.txt' + file_key = "website_files/" + job_id + ".txt" if storage.exists(file_key): storage.delete(file_key) - storage.save(file_key, json.dumps(url_data_list).encode('utf-8')) + storage.save(file_key, json.dumps(url_data_list).encode("utf-8")) return { - 'status': 'completed', - 'total': crawl_status_response.get('total'), - 'current': crawl_status_response.get('current'), - 'data': url_data_list + "status": "completed", + "total": crawl_status_response.get("total"), + "current": crawl_status_response.get("current"), + "data": url_data_list, } else: return { - 'status': crawl_status_response.get('status'), - 'total': crawl_status_response.get('total'), - 'current': crawl_status_response.get('current'), - 'data': [] + "status": crawl_status_response.get("status"), + "total": crawl_status_response.get("total"), + "current": crawl_status_response.get("current"), + "data": [], } else: - self._handle_error(response, 'check crawl status') + self._handle_error(response, "check crawl status") def _prepare_headers(self): - return { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } + return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5): for attempt in range(retries): response = requests.post(url, headers=headers, json=data) if response.status_code == 502: - time.sleep(backoff_factor * (2 ** attempt)) + time.sleep(backoff_factor * (2**attempt)) else: return response return response @@ -119,13 +109,11 @@ class FirecrawlApp: for attempt in range(retries): response = requests.get(url, headers=headers) if response.status_code == 502: - time.sleep(backoff_factor * (2 ** attempt)) + time.sleep(backoff_factor * (2**attempt)) else: return response return response def _handle_error(self, response, action): - error_message = response.json().get('error', 'Unknown error occurred') - raise Exception(f'Failed to {action}. Status code: {response.status_code}. Error: {error_message}') - - + error_message = response.json().get("error", "Unknown error occurred") + raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") diff --git a/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py b/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py index 8e2f107e5e..b33ce167c2 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py @@ -5,7 +5,7 @@ from services.website_service import WebsiteService class FirecrawlWebExtractor(BaseExtractor): """ - Crawl and scrape websites and return content in clean llm-ready markdown. + Crawl and scrape websites and return content in clean llm-ready markdown. Args: @@ -15,14 +15,7 @@ class FirecrawlWebExtractor(BaseExtractor): mode: The mode of operation. Defaults to 'scrape'. Options are 'crawl', 'scrape' and 'crawl_return_urls'. """ - def __init__( - self, - url: str, - job_id: str, - tenant_id: str, - mode: str = 'crawl', - only_main_content: bool = False - ): + def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = False): """Initialize with url, api_key, base_url and mode.""" self._url = url self.job_id = job_id @@ -33,28 +26,31 @@ class FirecrawlWebExtractor(BaseExtractor): def extract(self) -> list[Document]: """Extract content from the URL.""" documents = [] - if self.mode == 'crawl': - crawl_data = WebsiteService.get_crawl_url_data(self.job_id, 'firecrawl', self._url, self.tenant_id) + if self.mode == "crawl": + crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "firecrawl", self._url, self.tenant_id) if crawl_data is None: return [] - document = Document(page_content=crawl_data.get('markdown', ''), - metadata={ - 'source_url': crawl_data.get('source_url'), - 'description': crawl_data.get('description'), - 'title': crawl_data.get('title') - } - ) + document = Document( + page_content=crawl_data.get("markdown", ""), + metadata={ + "source_url": crawl_data.get("source_url"), + "description": crawl_data.get("description"), + "title": crawl_data.get("title"), + }, + ) documents.append(document) - elif self.mode == 'scrape': - scrape_data = WebsiteService.get_scrape_url_data('firecrawl', self._url, self.tenant_id, - self.only_main_content) + elif self.mode == "scrape": + scrape_data = WebsiteService.get_scrape_url_data( + "firecrawl", self._url, self.tenant_id, self.only_main_content + ) - document = Document(page_content=scrape_data.get('markdown', ''), - metadata={ - 'source_url': scrape_data.get('source_url'), - 'description': scrape_data.get('description'), - 'title': scrape_data.get('title') - } - ) + document = Document( + page_content=scrape_data.get("markdown", ""), + metadata={ + "source_url": scrape_data.get("source_url"), + "description": scrape_data.get("description"), + "title": scrape_data.get("title"), + }, + ) documents.append(document) return documents diff --git a/api/core/rag/extractor/helpers.py b/api/core/rag/extractor/helpers.py index 0c17a47b32..9a21d4272a 100644 --- a/api/core/rag/extractor/helpers.py +++ b/api/core/rag/extractor/helpers.py @@ -37,9 +37,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding try: encodings = future.result(timeout=timeout) except concurrent.futures.TimeoutError: - raise TimeoutError( - f"Timeout reached while detecting encoding for {file_path}" - ) + raise TimeoutError(f"Timeout reached while detecting encoding for {file_path}") if all(encoding["encoding"] is None for encoding in encodings): raise RuntimeError(f"Could not detect encoding for {file_path}") diff --git a/api/core/rag/extractor/html_extractor.py b/api/core/rag/extractor/html_extractor.py index ceb5306255..560c2d1d84 100644 --- a/api/core/rag/extractor/html_extractor.py +++ b/api/core/rag/extractor/html_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + from bs4 import BeautifulSoup from core.rag.extractor.extractor_base import BaseExtractor @@ -6,7 +7,6 @@ from core.rag.models.document import Document class HtmlExtractor(BaseExtractor): - """ Load html files. @@ -15,10 +15,7 @@ class HtmlExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str - ): + def __init__(self, file_path: str): """Initialize with file path.""" self._file_path = file_path @@ -27,8 +24,8 @@ class HtmlExtractor(BaseExtractor): def _load_as_text(self) -> str: with open(self._file_path, "rb") as fp: - soup = BeautifulSoup(fp, 'html.parser') + soup = BeautifulSoup(fp, "html.parser") text = soup.get_text() - text = text.strip() if text else '' + text = text.strip() if text else "" - return text \ No newline at end of file + return text diff --git a/api/core/rag/extractor/markdown_extractor.py b/api/core/rag/extractor/markdown_extractor.py index b24cf2e170..ca125ecf55 100644 --- a/api/core/rag/extractor/markdown_extractor.py +++ b/api/core/rag/extractor/markdown_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + import re from typing import Optional, cast @@ -16,12 +17,12 @@ class MarkdownExtractor(BaseExtractor): """ def __init__( - self, - file_path: str, - remove_hyperlinks: bool = False, - remove_images: bool = False, - encoding: Optional[str] = None, - autodetect_encoding: bool = True, + self, + file_path: str, + remove_hyperlinks: bool = False, + remove_images: bool = False, + encoding: Optional[str] = None, + autodetect_encoding: bool = True, ): """Initialize with file path.""" self._file_path = file_path @@ -78,13 +79,10 @@ class MarkdownExtractor(BaseExtractor): if current_header is not None: # pass linting, assert keys are defined markdown_tups = [ - (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) - for key, value in markdown_tups + (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) for key, value in markdown_tups ] else: - markdown_tups = [ - (key, re.sub("\n", "", value)) for key, value in markdown_tups - ] + markdown_tups = [(key, re.sub("\n", "", value)) for key, value in markdown_tups] return markdown_tups diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 7e839804c8..b02e30de62 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -21,22 +21,21 @@ RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}" RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}" # if user want split by headings, use the corresponding splitter HEADING_SPLITTER = { - 'heading_1': '# ', - 'heading_2': '## ', - 'heading_3': '### ', + "heading_1": "# ", + "heading_2": "## ", + "heading_3": "### ", } + class NotionExtractor(BaseExtractor): - def __init__( - self, - notion_workspace_id: str, - notion_obj_id: str, - notion_page_type: str, - tenant_id: str, - document_model: Optional[DocumentModel] = None, - notion_access_token: Optional[str] = None, - + self, + notion_workspace_id: str, + notion_obj_id: str, + notion_page_type: str, + tenant_id: str, + document_model: Optional[DocumentModel] = None, + notion_access_token: Optional[str] = None, ): self._notion_access_token = None self._document_model = document_model @@ -46,46 +45,38 @@ class NotionExtractor(BaseExtractor): if notion_access_token: self._notion_access_token = notion_access_token else: - self._notion_access_token = self._get_access_token(tenant_id, - self._notion_workspace_id) + self._notion_access_token = self._get_access_token(tenant_id, self._notion_workspace_id) if not self._notion_access_token: integration_token = dify_config.NOTION_INTEGRATION_TOKEN if integration_token is None: raise ValueError( - "Must specify `integration_token` or set environment " - "variable `NOTION_INTEGRATION_TOKEN`." + "Must specify `integration_token` or set environment " "variable `NOTION_INTEGRATION_TOKEN`." ) self._notion_access_token = integration_token def extract(self) -> list[Document]: - self.update_last_edited_time( - self._document_model - ) + self.update_last_edited_time(self._document_model) text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type) return text_docs - def _load_data_as_documents( - self, notion_obj_id: str, notion_page_type: str - ) -> list[Document]: + def _load_data_as_documents(self, notion_obj_id: str, notion_page_type: str) -> list[Document]: docs = [] - if notion_page_type == 'database': + if notion_page_type == "database": # get all the pages in the database page_text_documents = self._get_notion_database_data(notion_obj_id) docs.extend(page_text_documents) - elif notion_page_type == 'page': + elif notion_page_type == "page": page_text_list = self._get_notion_block_data(notion_obj_id) - docs.append(Document(page_content='\n'.join(page_text_list))) + docs.append(Document(page_content="\n".join(page_text_list))) else: raise ValueError("notion page type not supported") return docs - def _get_notion_database_data( - self, database_id: str, query_dict: dict[str, Any] = {} - ) -> list[Document]: + def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]: """Get all the pages from a Notion database.""" res = requests.post( DATABASE_URL_TMPL.format(database_id=database_id), @@ -100,50 +91,50 @@ class NotionExtractor(BaseExtractor): data = res.json() database_content = [] - if 'results' not in data or data["results"] is None: + if "results" not in data or data["results"] is None: return [] for result in data["results"]: - properties = result['properties'] + properties = result["properties"] data = {} for property_name, property_value in properties.items(): - type = property_value['type'] - if type == 'multi_select': + type = property_value["type"] + if type == "multi_select": value = [] multi_select_list = property_value[type] for multi_select in multi_select_list: - value.append(multi_select['name']) - elif type == 'rich_text' or type == 'title': + value.append(multi_select["name"]) + elif type == "rich_text" or type == "title": if len(property_value[type]) > 0: - value = property_value[type][0]['plain_text'] + value = property_value[type][0]["plain_text"] else: - value = '' - elif type == 'select' or type == 'status': + value = "" + elif type == "select" or type == "status": if property_value[type]: - value = property_value[type]['name'] + value = property_value[type]["name"] else: - value = '' + value = "" else: value = property_value[type] data[property_name] = value row_dict = {k: v for k, v in data.items() if v} - row_content = '' + row_content = "" for key, value in row_dict.items(): if isinstance(value, dict): value_dict = {k: v for k, v in value.items() if v} - value_content = ''.join(f'{k}:{v} ' for k, v in value_dict.items()) - row_content = row_content + f'{key}:{value_content}\n' + value_content = "".join(f"{k}:{v} " for k, v in value_dict.items()) + row_content = row_content + f"{key}:{value_content}\n" else: - row_content = row_content + f'{key}:{value}\n' + row_content = row_content + f"{key}:{value}\n" database_content.append(row_content) - return [Document(page_content='\n'.join(database_content))] + return [Document(page_content="\n".join(database_content))] def _get_notion_block_data(self, page_id: str) -> list[str]: result_lines_arr = [] start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=page_id) while True: - query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor} + query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} res = requests.request( "GET", block_url, @@ -152,14 +143,14 @@ class NotionExtractor(BaseExtractor): "Content-Type": "application/json", "Notion-Version": "2022-06-28", }, - params=query_dict + params=query_dict, ) data = res.json() for result in data["results"]: result_type = result["type"] result_obj = result[result_type] cur_result_text_arr = [] - if result_type == 'table': + if result_type == "table": result_block_id = result["id"] text = self._read_table_rows(result_block_id) text += "\n\n" @@ -175,17 +166,15 @@ class NotionExtractor(BaseExtractor): result_block_id = result["id"] has_children = result["has_children"] block_type = result["type"] - if has_children and block_type != 'child_page': - children_text = self._read_block( - result_block_id, num_tabs=1 - ) + if has_children and block_type != "child_page": + children_text = self._read_block(result_block_id, num_tabs=1) cur_result_text_arr.append(children_text) cur_result_text = "\n".join(cur_result_text_arr) if result_type in HEADING_SPLITTER: result_lines_arr.append(f"{HEADING_SPLITTER[result_type]}{cur_result_text}") else: - result_lines_arr.append(cur_result_text + '\n\n') + result_lines_arr.append(cur_result_text + "\n\n") if data["next_cursor"] is None: break @@ -199,7 +188,7 @@ class NotionExtractor(BaseExtractor): start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id) while True: - query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor} + query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} res = requests.request( "GET", @@ -209,16 +198,16 @@ class NotionExtractor(BaseExtractor): "Content-Type": "application/json", "Notion-Version": "2022-06-28", }, - params=query_dict + params=query_dict, ) data = res.json() - if 'results' not in data or data["results"] is None: + if "results" not in data or data["results"] is None: break for result in data["results"]: result_type = result["type"] result_obj = result[result_type] cur_result_text_arr = [] - if result_type == 'table': + if result_type == "table": result_block_id = result["id"] text = self._read_table_rows(result_block_id) result_lines_arr.append(text) @@ -233,17 +222,15 @@ class NotionExtractor(BaseExtractor): result_block_id = result["id"] has_children = result["has_children"] block_type = result["type"] - if has_children and block_type != 'child_page': - children_text = self._read_block( - result_block_id, num_tabs=num_tabs + 1 - ) + if has_children and block_type != "child_page": + children_text = self._read_block(result_block_id, num_tabs=num_tabs + 1) cur_result_text_arr.append(children_text) cur_result_text = "\n".join(cur_result_text_arr) if result_type in HEADING_SPLITTER: - result_lines_arr.append(f'{HEADING_SPLITTER[result_type]}{cur_result_text}') + result_lines_arr.append(f"{HEADING_SPLITTER[result_type]}{cur_result_text}") else: - result_lines_arr.append(cur_result_text + '\n\n') + result_lines_arr.append(cur_result_text + "\n\n") if data["next_cursor"] is None: break @@ -260,7 +247,7 @@ class NotionExtractor(BaseExtractor): start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id) while not done: - query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor} + query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} res = requests.request( "GET", @@ -270,28 +257,28 @@ class NotionExtractor(BaseExtractor): "Content-Type": "application/json", "Notion-Version": "2022-06-28", }, - params=query_dict + params=query_dict, ) data = res.json() # get table headers text table_header_cell_texts = [] - table_header_cells = data["results"][0]['table_row']['cells'] + table_header_cells = data["results"][0]["table_row"]["cells"] for table_header_cell in table_header_cells: if table_header_cell: for table_header_cell_text in table_header_cell: text = table_header_cell_text["text"]["content"] table_header_cell_texts.append(text) else: - table_header_cell_texts.append('') + table_header_cell_texts.append("") # Initialize Markdown table with headers markdown_table = "| " + " | ".join(table_header_cell_texts) + " |\n" - markdown_table += "| " + " | ".join(['---'] * len(table_header_cell_texts)) + " |\n" + markdown_table += "| " + " | ".join(["---"] * len(table_header_cell_texts)) + " |\n" # Process data to format each row in Markdown table format results = data["results"] for i in range(len(results) - 1): column_texts = [] - table_column_cells = data["results"][i + 1]['table_row']['cells'] + table_column_cells = data["results"][i + 1]["table_row"]["cells"] for j in range(len(table_column_cells)): if table_column_cells[j]: for table_column_cell_text in table_column_cells[j]: @@ -315,10 +302,8 @@ class NotionExtractor(BaseExtractor): last_edited_time = self.get_notion_last_edited_time() data_source_info = document_model.data_source_info_dict - data_source_info['last_edited_time'] = last_edited_time - update_params = { - DocumentModel.data_source_info: json.dumps(data_source_info) - } + data_source_info["last_edited_time"] = last_edited_time + update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)} DocumentModel.query.filter_by(id=document_model.id).update(update_params) db.session.commit() @@ -326,7 +311,7 @@ class NotionExtractor(BaseExtractor): def get_notion_last_edited_time(self) -> str: obj_id = self._notion_obj_id page_type = self._notion_page_type - if page_type == 'database': + if page_type == "database": retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id) else: retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id) @@ -341,7 +326,7 @@ class NotionExtractor(BaseExtractor): "Content-Type": "application/json", "Notion-Version": "2022-06-28", }, - json=query_dict + json=query_dict, ) data = res.json() @@ -352,14 +337,16 @@ class NotionExtractor(BaseExtractor): data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"', ) ).first() if not data_source_binding: - raise Exception(f'No notion data source binding found for tenant {tenant_id} ' - f'and notion workspace {notion_workspace_id}') + raise Exception( + f"No notion data source binding found for tenant {tenant_id} " + f"and notion workspace {notion_workspace_id}" + ) return data_source_binding.access_token diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 0864fec6c8..57cb9610ba 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + from collections.abc import Iterator from typing import Optional @@ -16,21 +17,17 @@ class PdfExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - file_cache_key: Optional[str] = None - ): + def __init__(self, file_path: str, file_cache_key: Optional[str] = None): """Initialize with file path.""" self._file_path = file_path self._file_cache_key = file_cache_key def extract(self) -> list[Document]: - plaintext_file_key = '' + plaintext_file_key = "" plaintext_file_exists = False if self._file_cache_key: try: - text = storage.load(self._file_cache_key).decode('utf-8') + text = storage.load(self._file_cache_key).decode("utf-8") plaintext_file_exists = True return [Document(page_content=text)] except FileNotFoundError: @@ -43,12 +40,12 @@ class PdfExtractor(BaseExtractor): # save plaintext file for caching if not plaintext_file_exists and plaintext_file_key: - storage.save(plaintext_file_key, text.encode('utf-8')) + storage.save(plaintext_file_key, text.encode("utf-8")) return documents def load( - self, + self, ) -> Iterator[Document]: """Lazy load given path as pages.""" blob = Blob.from_path(self._file_path) diff --git a/api/core/rag/extractor/text_extractor.py b/api/core/rag/extractor/text_extractor.py index ac5d0920cf..ed0ae41f51 100644 --- a/api/core/rag/extractor/text_extractor.py +++ b/api/core/rag/extractor/text_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor @@ -14,12 +15,7 @@ class TextExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - encoding: Optional[str] = None, - autodetect_encoding: bool = False - ): + def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False): """Initialize with file path.""" self._file_path = file_path self._encoding = encoding diff --git a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py index 0323b14a4a..a525c9e9e3 100644 --- a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py @@ -8,13 +8,12 @@ logger = logging.getLogger(__name__) class UnstructuredWordExtractor(BaseExtractor): - """Loader that uses unstructured to load word documents. - """ + """Loader that uses unstructured to load word documents.""" def __init__( - self, - file_path: str, - api_url: str, + self, + file_path: str, + api_url: str, ): """Initialize with file path.""" self._file_path = file_path @@ -24,9 +23,7 @@ class UnstructuredWordExtractor(BaseExtractor): from unstructured.__version__ import __version__ as __unstructured_version__ from unstructured.file_utils.filetype import FileType, detect_filetype - unstructured_version = tuple( - int(x) for x in __unstructured_version__.split(".") - ) + unstructured_version = tuple(int(x) for x in __unstructured_version__.split(".")) # check the file extension try: import magic # noqa: F401 @@ -53,6 +50,7 @@ class UnstructuredWordExtractor(BaseExtractor): elements = partition_docx(filename=self._file_path) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py index 2e704f187d..34c6811b67 100644 --- a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py @@ -26,6 +26,7 @@ class UnstructuredEmailExtractor(BaseExtractor): def extract(self) -> list[Document]: from unstructured.partition.email import partition_email + elements = partition_email(filename=self._file_path) # noinspection PyBroadException @@ -34,15 +35,16 @@ class UnstructuredEmailExtractor(BaseExtractor): element_text = element.text.strip() padding_needed = 4 - len(element_text) % 4 - element_text += '=' * padding_needed + element_text += "=" * padding_needed element_decode = base64.b64decode(element_text) - soup = BeautifulSoup(element_decode.decode('utf-8'), 'html.parser') + soup = BeautifulSoup(element_decode.decode("utf-8"), "html.parser") element.text = soup.get_text() except Exception: pass from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py index 44cf958ea2..fa50fa76b2 100644 --- a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py @@ -28,6 +28,7 @@ class UnstructuredEpubExtractor(BaseExtractor): elements = partition_epub(filename=self._file_path, xml_keep_tags=True) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py index 144b4e0c1d..fc3ff10693 100644 --- a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py @@ -38,6 +38,7 @@ class UnstructuredMarkdownExtractor(BaseExtractor): elements = partition_md(filename=self._file_path) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py index ad09b79eb0..8091e83e85 100644 --- a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py @@ -14,11 +14,7 @@ class UnstructuredMsgExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str - ): + def __init__(self, file_path: str, api_url: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url @@ -28,6 +24,7 @@ class UnstructuredMsgExtractor(BaseExtractor): elements = partition_msg(filename=self._file_path) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py index d354b593ed..b69394b3b1 100644 --- a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py @@ -14,12 +14,7 @@ class UnstructuredPPTExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str, - api_key: str - ): + def __init__(self, file_path: str, api_url: str, api_key: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py index 6fcbb5feb9..6ed4a0dfb3 100644 --- a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py @@ -14,11 +14,7 @@ class UnstructuredPPTXExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str - ): + def __init__(self, file_path: str, api_url: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_text_extractor.py b/api/core/rag/extractor/unstructured/unstructured_text_extractor.py index f4a4adbc16..22dfdd2075 100644 --- a/api/core/rag/extractor/unstructured/unstructured_text_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_text_extractor.py @@ -14,11 +14,7 @@ class UnstructuredTextExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str - ): + def __init__(self, file_path: str, api_url: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url @@ -28,6 +24,7 @@ class UnstructuredTextExtractor(BaseExtractor): elements = partition_text(filename=self._file_path) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py index 6aef8e0f7e..3bffc01fbf 100644 --- a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py @@ -14,11 +14,7 @@ class UnstructuredXmlExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str - ): + def __init__(self, file_path: str, api_url: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url @@ -28,6 +24,7 @@ class UnstructuredXmlExtractor(BaseExtractor): elements = partition_xml(filename=self._file_path, xml_keep_tags=True) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 15822867bb..2db00d161b 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + import datetime import logging import mimetypes @@ -21,6 +22,7 @@ from models.model import UploadFile logger = logging.getLogger(__name__) + class WordExtractor(BaseExtractor): """Load docx files. @@ -43,9 +45,7 @@ class WordExtractor(BaseExtractor): r = requests.get(self.file_path) if r.status_code != 200: - raise ValueError( - f"Check the url of your file; returned status code {r.status_code}" - ) + raise ValueError(f"Check the url of your file; returned status code {r.status_code}") self.web_path = self.file_path self.temp_file = tempfile.NamedTemporaryFile() @@ -60,11 +60,13 @@ class WordExtractor(BaseExtractor): def extract(self) -> list[Document]: """Load given path as single page.""" - content = self.parse_docx(self.file_path, 'storage') - return [Document( - page_content=content, - metadata={"source": self.file_path}, - )] + content = self.parse_docx(self.file_path, "storage") + return [ + Document( + page_content=content, + metadata={"source": self.file_path}, + ) + ] @staticmethod def _is_valid_url(url: str) -> bool: @@ -84,18 +86,18 @@ class WordExtractor(BaseExtractor): url = rel.reltype response = requests.get(url, stream=True) if response.status_code == 200: - image_ext = mimetypes.guess_extension(response.headers['Content-Type']) + image_ext = mimetypes.guess_extension(response.headers["Content-Type"]) file_uuid = str(uuid.uuid4()) - file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.' + image_ext + file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext mime_type, _ = mimetypes.guess_type(file_key) storage.save(file_key, response.content) else: continue else: - image_ext = rel.target_ref.split('.')[-1] + image_ext = rel.target_ref.split(".")[-1] # user uuid as file name file_uuid = str(uuid.uuid4()) - file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.' + image_ext + file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext mime_type, _ = mimetypes.guess_type(file_key) storage.save(file_key, rel.target_part.blob) @@ -112,12 +114,14 @@ class WordExtractor(BaseExtractor): created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), used=True, used_by=self.user_id, - used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), ) db.session.add(upload_file) db.session.commit() - image_map[rel.target_part] = f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)" + image_map[rel.target_part] = ( + f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)" + ) return image_map @@ -167,8 +171,8 @@ class WordExtractor(BaseExtractor): def _parse_cell_paragraph(self, paragraph, image_map): paragraph_content = [] for run in paragraph.runs: - if run.element.xpath('.//a:blip'): - for blip in run.element.xpath('.//a:blip'): + if run.element.xpath(".//a:blip"): + for blip in run.element.xpath(".//a:blip"): image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") if not image_id: continue @@ -184,16 +188,16 @@ class WordExtractor(BaseExtractor): def _parse_paragraph(self, paragraph, image_map): paragraph_content = [] for run in paragraph.runs: - if run.element.xpath('.//a:blip'): - for blip in run.element.xpath('.//a:blip'): - embed_id = blip.get('{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed') + if run.element.xpath(".//a:blip"): + for blip in run.element.xpath(".//a:blip"): + embed_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") if embed_id: rel_target = run.part.rels[embed_id].target_ref if rel_target in image_map: paragraph_content.append(image_map[rel_target]) if run.text.strip(): paragraph_content.append(run.text.strip()) - return ' '.join(paragraph_content) if paragraph_content else '' + return " ".join(paragraph_content) if paragraph_content else "" def parse_docx(self, docx_path, image_folder): doc = DocxDocument(docx_path) @@ -204,60 +208,59 @@ class WordExtractor(BaseExtractor): image_map = self._extract_images_from_docx(doc, image_folder) hyperlinks_url = None - url_pattern = re.compile(r'http://[^\s+]+//|https://[^\s+]+') + url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+") for para in doc.paragraphs: for run in para.runs: if run.text and hyperlinks_url: - result = f' [{run.text}]({hyperlinks_url}) ' + result = f" [{run.text}]({hyperlinks_url}) " run.text = result hyperlinks_url = None - if 'HYPERLINK' in run.element.xml: + if "HYPERLINK" in run.element.xml: try: xml = ET.XML(run.element.xml) x_child = [c for c in xml.iter() if c is not None] for x in x_child: if x_child is None: continue - if x.tag.endswith('instrText'): + if x.tag.endswith("instrText"): for i in url_pattern.findall(x.text): hyperlinks_url = str(i) except Exception as e: logger.error(e) - - - def parse_paragraph(paragraph): paragraph_content = [] for run in paragraph.runs: - if hasattr(run.element, 'tag') and isinstance(element.tag, str) and run.element.tag.endswith('r'): + if hasattr(run.element, "tag") and isinstance(element.tag, str) and run.element.tag.endswith("r"): drawing_elements = run.element.findall( - './/{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing') + ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing" + ) for drawing in drawing_elements: blip_elements = drawing.findall( - './/{http://schemas.openxmlformats.org/drawingml/2006/main}blip') + ".//{http://schemas.openxmlformats.org/drawingml/2006/main}blip" + ) for blip in blip_elements: embed_id = blip.get( - '{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed') + "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed" + ) if embed_id: image_part = doc.part.related_parts.get(embed_id) if image_part in image_map: paragraph_content.append(image_map[image_part]) if run.text.strip(): paragraph_content.append(run.text.strip()) - return ''.join(paragraph_content) if paragraph_content else '' + return "".join(paragraph_content) if paragraph_content else "" paragraphs = doc.paragraphs.copy() tables = doc.tables.copy() for element in doc.element.body: - if hasattr(element, 'tag'): - if isinstance(element.tag, str) and element.tag.endswith('p'): # paragraph + if hasattr(element, "tag"): + if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph para = paragraphs.pop(0) parsed_paragraph = parse_paragraph(para) if parsed_paragraph: content.append(parsed_paragraph) - elif isinstance(element.tag, str) and element.tag.endswith('tbl'): # table + elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table table = tables.pop(0) content.append(self._table_to_markdown(table, image_map)) - return '\n'.join(content) - + return "\n".join(content) diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 630387fe3a..be857bd122 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + from abc import ABC, abstractmethod from typing import Optional @@ -15,8 +16,7 @@ from models.dataset import Dataset, DatasetProcessRule class BaseIndexProcessor(ABC): - """Interface for extract files. - """ + """Interface for extract files.""" @abstractmethod def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: @@ -34,18 +34,24 @@ class BaseIndexProcessor(ABC): raise NotImplementedError @abstractmethod - def retrieve(self, retrieval_method: str, query: str, dataset: Dataset, top_k: int, - score_threshold: float, reranking_model: dict) -> list[Document]: + def retrieve( + self, + retrieval_method: str, + query: str, + dataset: Dataset, + top_k: int, + score_threshold: float, + reranking_model: dict, + ) -> list[Document]: raise NotImplementedError - def _get_splitter(self, processing_rule: dict, - embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: + def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: """ Get the NodeParser object according to the processing rule. """ - if processing_rule['mode'] == "custom": + if processing_rule["mode"] == "custom": # The user-defined segmentation rule - rules = processing_rule['rules'] + rules = processing_rule["rules"] segmentation = rules["segmentation"] max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: @@ -53,22 +59,22 @@ class BaseIndexProcessor(ABC): separator = segmentation["separator"] if separator: - separator = separator.replace('\\n', '\n') + separator = separator.replace("\\n", "\n") character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( chunk_size=segmentation["max_tokens"], - chunk_overlap=segmentation.get('chunk_overlap', 0) or 0, + chunk_overlap=segmentation.get("chunk_overlap", 0) or 0, fixed_separator=separator, separators=["\n\n", "。", ". ", " ", ""], - embedding_model_instance=embedding_model_instance + embedding_model_instance=embedding_model_instance, ) else: # Automatic segmentation character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( - chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], - chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'], + chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"], + chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"], separators=["\n\n", "。", ". ", " ", ""], - embedding_model_instance=embedding_model_instance + embedding_model_instance=embedding_model_instance, ) return character_splitter diff --git a/api/core/rag/index_processor/index_processor_factory.py b/api/core/rag/index_processor/index_processor_factory.py index df43a64910..9b855ece2c 100644 --- a/api/core/rag/index_processor/index_processor_factory.py +++ b/api/core/rag/index_processor/index_processor_factory.py @@ -7,8 +7,7 @@ from core.rag.index_processor.processor.qa_index_processor import QAIndexProcess class IndexProcessorFactory: - """IndexProcessorInit. - """ + """IndexProcessorInit.""" def __init__(self, index_type: str): self._index_type = index_type @@ -22,7 +21,6 @@ class IndexProcessorFactory: if self._index_type == IndexType.PARAGRAPH_INDEX.value: return ParagraphIndexProcessor() elif self._index_type == IndexType.QA_INDEX.value: - return QAIndexProcessor() else: raise ValueError(f"Index type {self._index_type} is not supported.") diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index bd7f6093bd..ed5712220f 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -1,4 +1,5 @@ """Paragraph index processor.""" + import uuid from typing import Optional @@ -15,33 +16,32 @@ from models.dataset import Dataset class ParagraphIndexProcessor(BaseIndexProcessor): - def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: - - text_docs = ExtractProcessor.extract(extract_setting=extract_setting, - is_automatic=kwargs.get('process_rule_mode') == "automatic") + text_docs = ExtractProcessor.extract( + extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" + ) return text_docs def transform(self, documents: list[Document], **kwargs) -> list[Document]: # Split the text documents into nodes. - splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'), - embedding_model_instance=kwargs.get('embedding_model_instance')) + splitter = self._get_splitter( + processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance") + ) all_documents = [] for document in documents: # document clean - document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule')) + document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule")) document.page_content = document_text # parse document to nodes document_nodes = splitter.split_documents([document]) split_documents = [] for document_node in document_nodes: - if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata['doc_id'] = doc_id - document_node.metadata['doc_hash'] = hash + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = document_node.page_content if page_content.startswith(".") or page_content.startswith("。"): @@ -55,7 +55,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): return all_documents def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) if with_keywords: @@ -63,7 +63,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): keyword.create(documents) def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) @@ -76,17 +76,29 @@ class ParagraphIndexProcessor(BaseIndexProcessor): else: keyword.delete() - def retrieve(self, retrieval_method: str, query: str, dataset: Dataset, top_k: int, - score_threshold: float, reranking_model: dict) -> list[Document]: + def retrieve( + self, + retrieval_method: str, + query: str, + dataset: Dataset, + top_k: int, + score_threshold: float, + reranking_model: dict, + ) -> list[Document]: # Set search parameters. - results = RetrievalService.retrieve(retrieval_method=retrieval_method, dataset_id=dataset.id, query=query, - top_k=top_k, score_threshold=score_threshold, - reranking_model=reranking_model) + results = RetrievalService.retrieve( + retrieval_method=retrieval_method, + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + ) # Organize results. docs = [] for result in results: metadata = result.metadata - metadata['score'] = result.score + metadata["score"] = result.score if result.score > score_threshold: doc = Document(page_content=result.page_content, metadata=metadata) docs.append(doc) diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index a44fd98036..1dbc473281 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -1,4 +1,5 @@ """Paragraph index processor.""" + import logging import re import threading @@ -23,33 +24,33 @@ from models.dataset import Dataset class QAIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: - - text_docs = ExtractProcessor.extract(extract_setting=extract_setting, - is_automatic=kwargs.get('process_rule_mode') == "automatic") + text_docs = ExtractProcessor.extract( + extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" + ) return text_docs def transform(self, documents: list[Document], **kwargs) -> list[Document]: - splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'), - embedding_model_instance=kwargs.get('embedding_model_instance')) + splitter = self._get_splitter( + processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance") + ) # Split the text documents into nodes. all_documents = [] all_qa_documents = [] for document in documents: # document clean - document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule')) + document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule")) document.page_content = document_text # parse document to nodes document_nodes = splitter.split_documents([document]) split_documents = [] for document_node in document_nodes: - if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata['doc_id'] = doc_id - document_node.metadata['doc_hash'] = hash + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = document_node.page_content if page_content.startswith(".") or page_content.startswith("。"): @@ -61,14 +62,18 @@ class QAIndexProcessor(BaseIndexProcessor): all_documents.extend(split_documents) for i in range(0, len(all_documents), 10): threads = [] - sub_documents = all_documents[i:i + 10] + sub_documents = all_documents[i : i + 10] for doc in sub_documents: - document_format_thread = threading.Thread(target=self._format_qa_document, kwargs={ - 'flask_app': current_app._get_current_object(), - 'tenant_id': kwargs.get('tenant_id'), - 'document_node': doc, - 'all_qa_documents': all_qa_documents, - 'document_language': kwargs.get('doc_language', 'English')}) + document_format_thread = threading.Thread( + target=self._format_qa_document, + kwargs={ + "flask_app": current_app._get_current_object(), + "tenant_id": kwargs.get("tenant_id"), + "document_node": doc, + "all_qa_documents": all_qa_documents, + "document_language": kwargs.get("doc_language", "English"), + }, + ) threads.append(document_format_thread) document_format_thread.start() for thread in threads: @@ -76,9 +81,8 @@ class QAIndexProcessor(BaseIndexProcessor): return all_qa_documents def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: - # check file type - if not file.filename.endswith('.csv'): + if not file.filename.endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") try: @@ -86,7 +90,7 @@ class QAIndexProcessor(BaseIndexProcessor): df = pd.read_csv(file) text_docs = [] for index, row in df.iterrows(): - data = Document(page_content=row[0], metadata={'answer': row[1]}) + data = Document(page_content=row[0], metadata={"answer": row[1]}) text_docs.append(data) if len(text_docs) == 0: raise ValueError("The CSV file is empty.") @@ -96,7 +100,7 @@ class QAIndexProcessor(BaseIndexProcessor): return text_docs def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) @@ -107,17 +111,29 @@ class QAIndexProcessor(BaseIndexProcessor): else: vector.delete() - def retrieve(self, retrieval_method: str, query: str, dataset: Dataset, top_k: int, - score_threshold: float, reranking_model: dict): + def retrieve( + self, + retrieval_method: str, + query: str, + dataset: Dataset, + top_k: int, + score_threshold: float, + reranking_model: dict, + ): # Set search parameters. - results = RetrievalService.retrieve(retrieval_method=retrieval_method, dataset_id=dataset.id, query=query, - top_k=top_k, score_threshold=score_threshold, - reranking_model=reranking_model) + results = RetrievalService.retrieve( + retrieval_method=retrieval_method, + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + ) # Organize results. docs = [] for result in results: metadata = result.metadata - metadata['score'] = result.score + metadata["score"] = result.score if result.score > score_threshold: doc = Document(page_content=result.page_content, metadata=metadata) docs.append(doc) @@ -134,12 +150,12 @@ class QAIndexProcessor(BaseIndexProcessor): document_qa_list = self._format_split_text(response) qa_documents = [] for result in document_qa_list: - qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy()) + qa_document = Document(page_content=result["question"], metadata=document_node.metadata.copy()) doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(result['question']) - qa_document.metadata['answer'] = result['answer'] - qa_document.metadata['doc_id'] = doc_id - qa_document.metadata['doc_hash'] = hash + hash = helper.generate_text_hash(result["question"]) + qa_document.metadata["answer"] = result["answer"] + qa_document.metadata["doc_id"] = doc_id + qa_document.metadata["doc_hash"] = hash qa_documents.append(qa_document) format_documents.extend(qa_documents) except Exception as e: @@ -151,10 +167,4 @@ class QAIndexProcessor(BaseIndexProcessor): regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" matches = re.findall(regex, text, re.UNICODE) - return [ - { - "question": q, - "answer": re.sub(r"\n\s*", "\n", a.strip()) - } - for q, a in matches if q and a - ] + return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a] diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 6f3c1c5d34..0ff1fdb81c 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -55,9 +55,7 @@ class BaseDocumentTransformer(ABC): """ @abstractmethod - def transform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Transform a list of documents. Args: @@ -68,9 +66,7 @@ class BaseDocumentTransformer(ABC): """ @abstractmethod - async def atransform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Asynchronously transform a list of documents. Args: diff --git a/api/core/rag/rerank/constants/rerank_mode.py b/api/core/rag/rerank/constants/rerank_mode.py index afbb9fd89d..d4894e3cc6 100644 --- a/api/core/rag/rerank/constants/rerank_mode.py +++ b/api/core/rag/rerank/constants/rerank_mode.py @@ -2,7 +2,5 @@ from enum import Enum class RerankMode(Enum): - - RERANKING_MODEL = 'reranking_model' - WEIGHTED_SCORE = 'weighted_score' - + RERANKING_MODEL = "reranking_model" + WEIGHTED_SCORE = "weighted_score" diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index d9067da288..6356ff87ab 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -8,8 +8,14 @@ class RerankModelRunner: def __init__(self, rerank_model_instance: ModelInstance) -> None: self.rerank_model_instance = rerank_model_instance - def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: + def run( + self, + query: str, + documents: list[Document], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> list[Document]: """ Run rerank model :param query: search query @@ -23,19 +29,15 @@ class RerankModelRunner: doc_id = [] unique_documents = [] for document in documents: - if document.metadata['doc_id'] not in doc_id: - doc_id.append(document.metadata['doc_id']) + if document.metadata["doc_id"] not in doc_id: + doc_id.append(document.metadata["doc_id"]) docs.append(document.page_content) unique_documents.append(document) documents = unique_documents rerank_result = self.rerank_model_instance.invoke_rerank( - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - user=user + query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user ) rerank_documents = [] @@ -45,12 +47,12 @@ class RerankModelRunner: rerank_document = Document( page_content=result.text, metadata={ - "doc_id": documents[result.index].metadata['doc_id'], - "doc_hash": documents[result.index].metadata['doc_hash'], - "document_id": documents[result.index].metadata['document_id'], - "dataset_id": documents[result.index].metadata['dataset_id'], - 'score': result.score - } + "doc_id": documents[result.index].metadata["doc_id"], + "doc_hash": documents[result.index].metadata["doc_hash"], + "document_id": documents[result.index].metadata["document_id"], + "dataset_id": documents[result.index].metadata["dataset_id"], + "score": result.score, + }, ) rerank_documents.append(rerank_document) diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index d8a7873982..4375079ee5 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -13,13 +13,18 @@ from core.rag.rerank.entity.weight import VectorSetting, Weights class WeightRerankRunner: - def __init__(self, tenant_id: str, weights: Weights) -> None: self.tenant_id = tenant_id self.weights = weights - def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: + def run( + self, + query: str, + documents: list[Document], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> list[Document]: """ Run rerank model :param query: search query @@ -34,8 +39,8 @@ class WeightRerankRunner: doc_id = [] unique_documents = [] for document in documents: - if document.metadata['doc_id'] not in doc_id: - doc_id.append(document.metadata['doc_id']) + if document.metadata["doc_id"] not in doc_id: + doc_id.append(document.metadata["doc_id"]) docs.append(document.page_content) unique_documents.append(document) @@ -47,13 +52,15 @@ class WeightRerankRunner: query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting) for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores): # format document - score = self.weights.vector_setting.vector_weight * query_vector_score + \ - self.weights.keyword_setting.keyword_weight * query_score + score = ( + self.weights.vector_setting.vector_weight * query_vector_score + + self.weights.keyword_setting.keyword_weight * query_score + ) if score_threshold and score < score_threshold: continue - document.metadata['score'] = score + document.metadata["score"] = score rerank_documents.append(document) - rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata['score'], reverse=True) + rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata["score"], reverse=True) return rerank_documents[:top_n] if top_n else rerank_documents def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]: @@ -70,7 +77,7 @@ class WeightRerankRunner: for document in documents: # get the document keywords document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) - document.metadata['keywords'] = document_keywords + document.metadata["keywords"] = document_keywords documents_keywords.append(document_keywords) # Counter query keywords(TF) @@ -132,8 +139,9 @@ class WeightRerankRunner: return similarities - def _calculate_cosine(self, tenant_id: str, query: str, documents: list[Document], - vector_setting: VectorSetting) -> list[float]: + def _calculate_cosine( + self, tenant_id: str, query: str, documents: list[Document], vector_setting: VectorSetting + ) -> list[float]: """ Calculate Cosine scores :param query: search query @@ -149,15 +157,14 @@ class WeightRerankRunner: tenant_id=tenant_id, provider=vector_setting.embedding_provider_name, model_type=ModelType.TEXT_EMBEDDING, - model=vector_setting.embedding_model_name - + model=vector_setting.embedding_model_name, ) cache_embedding = CacheEmbedding(embedding_model) query_vector = cache_embedding.embed_query(query) for document in documents: # calculate cosine similarity - if 'score' in document.metadata: - query_vector_scores.append(document.metadata['score']) + if "score" in document.metadata: + query_vector_scores.append(document.metadata["score"]) else: # transform to NumPy vec1 = np.array(query_vector) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index db01652f89..4948ec6ba8 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -32,14 +32,11 @@ from models.dataset import Dataset, DatasetQuery, DocumentSegment from models.dataset import Document as DatasetDocument default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } @@ -48,15 +45,18 @@ class DatasetRetrieval: self.application_generate_entity = application_generate_entity def retrieve( - self, app_id: str, user_id: str, tenant_id: str, - model_config: ModelConfigWithCredentialsEntity, - config: DatasetEntity, - query: str, - invoke_from: InvokeFrom, - show_retrieve_source: bool, - hit_callback: DatasetIndexToolCallbackHandler, - message_id: str, - memory: Optional[TokenBufferMemory] = None, + self, + app_id: str, + user_id: str, + tenant_id: str, + model_config: ModelConfigWithCredentialsEntity, + config: DatasetEntity, + query: str, + invoke_from: InvokeFrom, + show_retrieve_source: bool, + hit_callback: DatasetIndexToolCallbackHandler, + message_id: str, + memory: Optional[TokenBufferMemory] = None, ) -> Optional[str]: """ Retrieve dataset. @@ -84,16 +84,12 @@ class DatasetRetrieval: model_manager = ModelManager() model_instance = model_manager.get_model_instance( - tenant_id=tenant_id, - model_type=ModelType.LLM, - provider=model_config.provider, - model=model_config.model + tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model ) # get model schema model_schema = model_type_instance.get_model_schema( - model=model_config.model, - credentials=model_config.credentials + model=model_config.model, credentials=model_config.credentials ) if not model_schema: @@ -102,39 +98,46 @@ class DatasetRetrieval: planning_strategy = PlanningStrategy.REACT_ROUTER features = model_schema.features if features: - if ModelFeature.TOOL_CALL in features \ - or ModelFeature.MULTI_TOOL_CALL in features: + if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features: planning_strategy = PlanningStrategy.ROUTER available_datasets = [] for dataset_id in dataset_ids: # get dataset from dataset id - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() # pass if dataset is not available if not dataset: continue # pass if dataset is not available - if (dataset and dataset.available_document_count == 0 - and dataset.available_document_count == 0): + if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0: continue available_datasets.append(dataset) all_documents = [] - user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user' + user_from = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user" if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: all_documents = self.single_retrieve( - app_id, tenant_id, user_id, user_from, available_datasets, query, + app_id, + tenant_id, + user_id, + user_from, + available_datasets, + query, model_instance, - model_config, planning_strategy, message_id + model_config, + planning_strategy, + message_id, ) elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: all_documents = self.multiple_retrieve( - app_id, tenant_id, user_id, user_from, - available_datasets, query, retrieve_config.top_k, + app_id, + tenant_id, + user_id, + user_from, + available_datasets, + query, + retrieve_config.top_k, retrieve_config.score_threshold, retrieve_config.rerank_mode, retrieve_config.reranking_model, @@ -145,89 +148,89 @@ class DatasetRetrieval: document_score_list = {} for item in all_documents: - if item.metadata.get('score'): - document_score_list[item.metadata['doc_id']] = item.metadata['score'] + if item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] document_context_list = [] - index_node_ids = [document.metadata['doc_id'] for document in all_documents] + index_node_ids = [document.metadata["doc_id"] for document in all_documents] segments = DocumentSegment.query.filter( DocumentSegment.dataset_id.in_(dataset_ids), DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == 'completed', + DocumentSegment.status == "completed", DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids) + DocumentSegment.index_node_id.in_(index_node_ids), ).all() if segments: index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted(segments, - key=lambda segment: index_node_id_to_position.get(segment.index_node_id, - float('inf'))) + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: if segment.answer: - document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}') + document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") else: document_context_list.append(segment.get_sign_content()) if show_retrieve_source: context_list = [] resource_number = 1 for segment in sorted_segments: - dataset = Dataset.query.filter_by( - id=segment.dataset_id + dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + document = DatasetDocument.query.filter( + DatasetDocument.id == segment.document_id, + DatasetDocument.enabled == True, + DatasetDocument.archived == False, ).first() - document = DatasetDocument.query.filter(DatasetDocument.id == segment.document_id, - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).first() if dataset and document: source = { - 'position': resource_number, - 'dataset_id': dataset.id, - 'dataset_name': dataset.name, - 'document_id': document.id, - 'document_name': document.name, - 'data_source_type': document.data_source_type, - 'segment_id': segment.id, - 'retriever_from': invoke_from.to_source(), - 'score': document_score_list.get(segment.index_node_id, None) + "position": resource_number, + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": invoke_from.to_source(), + "score": document_score_list.get(segment.index_node_id, None), } - if invoke_from.to_source() == 'dev': - source['hit_count'] = segment.hit_count - source['word_count'] = segment.word_count - source['segment_position'] = segment.position - source['index_node_hash'] = segment.index_node_hash + if invoke_from.to_source() == "dev": + source["hit_count"] = segment.hit_count + source["word_count"] = segment.word_count + source["segment_position"] = segment.position + source["index_node_hash"] = segment.index_node_hash if segment.answer: - source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" else: - source['content'] = segment.content + source["content"] = segment.content context_list.append(source) resource_number += 1 if hit_callback: hit_callback.return_retriever_resource_info(context_list) return str("\n".join(document_context_list)) - return '' + return "" def single_retrieve( - self, app_id: str, - tenant_id: str, - user_id: str, - user_from: str, - available_datasets: list, - query: str, - model_instance: ModelInstance, - model_config: ModelConfigWithCredentialsEntity, - planning_strategy: PlanningStrategy, - message_id: Optional[str] = None, + self, + app_id: str, + tenant_id: str, + user_id: str, + user_from: str, + available_datasets: list, + query: str, + model_instance: ModelInstance, + model_config: ModelConfigWithCredentialsEntity, + planning_strategy: PlanningStrategy, + message_id: Optional[str] = None, ): tools = [] for dataset in available_datasets: description = dataset.description if not description: - description = 'useful for when you want to answer queries about the ' + dataset.name + description = "useful for when you want to answer queries about the " + dataset.name - description = description.replace('\n', '').replace('\r', '') + description = description.replace("\n", "").replace("\r", "") message_tool = PromptMessageTool( name=dataset.id, description=description, @@ -235,14 +238,15 @@ class DatasetRetrieval: "type": "object", "properties": {}, "required": [], - } + }, ) tools.append(message_tool) dataset_id = None if planning_strategy == PlanningStrategy.REACT_ROUTER: react_multi_dataset_router = ReactMultiDatasetRouter() - dataset_id = react_multi_dataset_router.invoke(query, tools, model_config, model_instance, - user_id, tenant_id) + dataset_id = react_multi_dataset_router.invoke( + query, tools, model_config, model_instance, user_id, tenant_id + ) elif planning_strategy == PlanningStrategy.ROUTER: function_call_router = FunctionCallMultiDatasetRouter() @@ -250,37 +254,37 @@ class DatasetRetrieval: if dataset_id: # get retrieval model config - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if dataset: - retrieval_model_config = dataset.retrieval_model \ - if dataset.retrieval_model else default_retrieval_model + retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model # get top k - top_k = retrieval_model_config['top_k'] + top_k = retrieval_model_config["top_k"] # get retrieval method if dataset.indexing_technique == "economy": - retrieval_method = 'keyword_search' + retrieval_method = "keyword_search" else: - retrieval_method = retrieval_model_config['search_method'] + retrieval_method = retrieval_model_config["search_method"] # get reranking model - reranking_model = retrieval_model_config['reranking_model'] \ - if retrieval_model_config['reranking_enable'] else None + reranking_model = ( + retrieval_model_config["reranking_model"] if retrieval_model_config["reranking_enable"] else None + ) # get score threshold - score_threshold = .0 + score_threshold = 0.0 score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") if score_threshold_enabled: score_threshold = retrieval_model_config.get("score_threshold") with measure_time() as timer: results = RetrievalService.retrieve( - retrieval_method=retrieval_method, dataset_id=dataset.id, + retrieval_method=retrieval_method, + dataset_id=dataset.id, query=query, - top_k=top_k, score_threshold=score_threshold, + top_k=top_k, + score_threshold=score_threshold, reranking_model=reranking_model, - reranking_mode=retrieval_model_config.get('reranking_mode', 'reranking_model'), - weights=retrieval_model_config.get('weights', None), + reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"), + weights=retrieval_model_config.get("weights", None), ) self._on_query(query, [dataset_id], app_id, user_from, user_id) @@ -291,20 +295,20 @@ class DatasetRetrieval: return [] def multiple_retrieve( - self, - app_id: str, - tenant_id: str, - user_id: str, - user_from: str, - available_datasets: list, - query: str, - top_k: int, - score_threshold: float, - reranking_mode: str, - reranking_model: Optional[dict] = None, - weights: Optional[dict] = None, - reranking_enable: bool = True, - message_id: Optional[str] = None, + self, + app_id: str, + tenant_id: str, + user_id: str, + user_from: str, + available_datasets: list, + query: str, + top_k: int, + score_threshold: float, + reranking_mode: str, + reranking_model: Optional[dict] = None, + weights: Optional[dict] = None, + reranking_enable: bool = True, + message_id: Optional[str] = None, ): threads = [] all_documents = [] @@ -312,13 +316,16 @@ class DatasetRetrieval: index_type = None for dataset in available_datasets: index_type = dataset.indexing_technique - retrieval_thread = threading.Thread(target=self._retriever, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset.id, - 'query': query, - 'top_k': top_k, - 'all_documents': all_documents, - }) + retrieval_thread = threading.Thread( + target=self._retriever, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset.id, + "query": query, + "top_k": top_k, + "all_documents": all_documents, + }, + ) threads.append(retrieval_thread) retrieval_thread.start() for thread in threads: @@ -327,16 +334,10 @@ class DatasetRetrieval: with measure_time() as timer: if reranking_enable: # do rerank for searched documents - data_post_processor = DataPostProcessor( - tenant_id, reranking_mode, - reranking_model, weights, False - ) + data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False) all_documents = data_post_processor.invoke( - query=query, - documents=all_documents, - score_threshold=score_threshold, - top_n=top_k + query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k ) else: if index_type == "economy": @@ -357,30 +358,26 @@ class DatasetRetrieval: """Handle retrieval end.""" for document in documents: query = db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == document.metadata['doc_id'] + DocumentSegment.index_node_id == document.metadata["doc_id"] ) # if 'dataset_id' in document.metadata: - if 'dataset_id' in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id']) + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) # add hit count to document segment - query.update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False - ) + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) db.session.commit() # get tracing instance - trace_manager: TraceQueueManager = self.application_generate_entity.trace_manager if self.application_generate_entity else None + trace_manager: TraceQueueManager = ( + self.application_generate_entity.trace_manager if self.application_generate_entity else None + ) if trace_manager: trace_manager.add_trace_task( TraceTask( - TraceTaskName.DATASET_RETRIEVAL_TRACE, - message_id=message_id, - documents=documents, - timer=timer + TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer ) ) @@ -395,10 +392,10 @@ class DatasetRetrieval: dataset_query = DatasetQuery( dataset_id=dataset_id, content=query, - source='app', + source="app", source_app_id=app_id, created_by_role=user_from, - created_by=user_id + created_by=user_id, ) dataset_queries.append(dataset_query) if dataset_queries: @@ -407,9 +404,7 @@ class DatasetRetrieval: def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): with flask_app.app_context(): - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: return [] @@ -419,38 +414,42 @@ class DatasetRetrieval: if dataset.indexing_technique == "economy": # use keyword table query - documents = RetrievalService.retrieve(retrieval_method='keyword_search', - dataset_id=dataset.id, - query=query, - top_k=top_k - ) + documents = RetrievalService.retrieve( + retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=top_k + ) if documents: all_documents.extend(documents) else: if top_k > 0: # retrieval source - documents = RetrievalService.retrieve(retrieval_method=retrieval_model['search_method'], - dataset_id=dataset.id, - query=query, - top_k=top_k, - score_threshold=retrieval_model.get('score_threshold', .0) - if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model.get('reranking_model', None) - if retrieval_model['reranking_enable'] else None, - reranking_mode=retrieval_model.get('reranking_mode') - if retrieval_model.get('reranking_mode') else 'reranking_model', - weights=retrieval_model.get('weights', None), - ) + documents = RetrievalService.retrieve( + retrieval_method=retrieval_model["search_method"], + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else None, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") + if retrieval_model.get("reranking_mode") + else "reranking_model", + weights=retrieval_model.get("weights", None), + ) all_documents.extend(documents) - def to_dataset_retriever_tool(self, tenant_id: str, - dataset_ids: list[str], - retrieve_config: DatasetRetrieveConfigEntity, - return_resource: bool, - invoke_from: InvokeFrom, - hit_callback: DatasetIndexToolCallbackHandler) \ - -> Optional[list[DatasetRetrieverBaseTool]]: + def to_dataset_retriever_tool( + self, + tenant_id: str, + dataset_ids: list[str], + retrieve_config: DatasetRetrieveConfigEntity, + return_resource: bool, + invoke_from: InvokeFrom, + hit_callback: DatasetIndexToolCallbackHandler, + ) -> Optional[list[DatasetRetrieverBaseTool]]: """ A dataset tool is a tool that can be used to retrieve information from a dataset :param tenant_id: tenant id @@ -464,18 +463,14 @@ class DatasetRetrieval: available_datasets = [] for dataset_id in dataset_ids: # get dataset from dataset id - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() # pass if dataset is not available if not dataset: continue # pass if dataset is not available - if (dataset and dataset.available_document_count == 0 - and dataset.available_document_count == 0): + if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0: continue available_datasets.append(dataset) @@ -483,22 +478,18 @@ class DatasetRetrieval: if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: # get retrieval model config default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } for dataset in available_datasets: - retrieval_model_config = dataset.retrieval_model \ - if dataset.retrieval_model else default_retrieval_model + retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model # get top k - top_k = retrieval_model_config['top_k'] + top_k = retrieval_model_config["top_k"] # get score threshold score_threshold = None @@ -512,7 +503,7 @@ class DatasetRetrieval: score_threshold=score_threshold, hit_callbacks=[hit_callback], return_resource=return_resource, - retriever_from=invoke_from.to_source() + retriever_from=invoke_from.to_source(), ) tools.append(tool) @@ -525,8 +516,8 @@ class DatasetRetrieval: hit_callbacks=[hit_callback], return_resource=return_resource, retriever_from=invoke_from.to_source(), - reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'), - reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name') + reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"), + reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"), ) tools.append(tool) @@ -547,7 +538,7 @@ class DatasetRetrieval: for document in documents: # get the document keywords document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) - document.metadata['keywords'] = document_keywords + document.metadata["keywords"] = document_keywords documents_keywords.append(document_keywords) # Counter query keywords(TF) @@ -606,21 +597,19 @@ class DatasetRetrieval: for document, score in zip(documents, similarities): # format document - document.metadata['score'] = score - documents = sorted(documents, key=lambda x: x.metadata['score'], reverse=True) + document.metadata["score"] = score + documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) return documents[:top_k] if top_k else documents - def calculate_vector_score(self, all_documents: list[Document], - top_k: int, score_threshold: float) -> list[Document]: + def calculate_vector_score( + self, all_documents: list[Document], top_k: int, score_threshold: float + ) -> list[Document]: filter_documents = [] for document in all_documents: - if score_threshold is None or document.metadata['score'] >= score_threshold: + if score_threshold is None or document.metadata["score"] >= score_threshold: filter_documents.append(document) if not filter_documents: return [] - filter_documents = sorted(filter_documents, key=lambda x: x.metadata['score'], reverse=True) + filter_documents = sorted(filter_documents, key=lambda x: x.metadata["score"], reverse=True) return filter_documents[:top_k] if top_k else filter_documents - - - diff --git a/api/core/rag/retrieval/output_parser/structured_chat.py b/api/core/rag/retrieval/output_parser/structured_chat.py index 60770bd4c6..7fc78bce83 100644 --- a/api/core/rag/retrieval/output_parser/structured_chat.py +++ b/api/core/rag/retrieval/output_parser/structured_chat.py @@ -16,9 +16,7 @@ class StructuredChatOutputParser: if response["action"] == "Final Answer": return ReactFinish({"output": response["action_input"]}, text) else: - return ReactAction( - response["action"], response.get("action_input", {}), text - ) + return ReactAction(response["action"], response.get("action_input", {}), text) else: return ReactFinish({"output": text}, text) except Exception as e: diff --git a/api/core/rag/retrieval/retrieval_methods.py b/api/core/rag/retrieval/retrieval_methods.py index 12aa28a51c..eaa00bca88 100644 --- a/api/core/rag/retrieval/retrieval_methods.py +++ b/api/core/rag/retrieval/retrieval_methods.py @@ -2,9 +2,9 @@ from enum import Enum class RetrievalMethod(Enum): - SEMANTIC_SEARCH = 'semantic_search' - FULL_TEXT_SEARCH = 'full_text_search' - HYBRID_SEARCH = 'hybrid_search' + SEMANTIC_SEARCH = "semantic_search" + FULL_TEXT_SEARCH = "full_text_search" + HYBRID_SEARCH = "hybrid_search" @staticmethod def is_support_semantic_search(retrieval_method: str) -> bool: diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index 84e53952ac..06147fe7b5 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -6,14 +6,12 @@ from core.model_runtime.entities.message_entities import PromptMessageTool, Syst class FunctionCallMultiDatasetRouter: - def invoke( - self, - query: str, - dataset_tools: list[PromptMessageTool], - model_config: ModelConfigWithCredentialsEntity, - model_instance: ModelInstance, - + self, + query: str, + dataset_tools: list[PromptMessageTool], + model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, ) -> Union[str, None]: """Given input, decided what to do. Returns: @@ -26,22 +24,18 @@ class FunctionCallMultiDatasetRouter: try: prompt_messages = [ - SystemPromptMessage(content='You are a helpful AI assistant.'), - UserPromptMessage(content=query) + SystemPromptMessage(content="You are a helpful AI assistant."), + UserPromptMessage(content=query), ] result = model_instance.invoke_llm( prompt_messages=prompt_messages, tools=dataset_tools, stream=False, - model_parameters={ - 'temperature': 0.2, - 'top_p': 0.3, - 'max_tokens': 1500 - } + model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500}, ) if result.message.tool_calls: # get retrieval model config return result.message.tool_calls[0].function.name return None except Exception as e: - return None \ No newline at end of file + return None diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index 92f24277c1..33841cac06 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -50,16 +50,14 @@ Action: class ReactMultiDatasetRouter: - def invoke( - self, - query: str, - dataset_tools: list[PromptMessageTool], - model_config: ModelConfigWithCredentialsEntity, - model_instance: ModelInstance, - user_id: str, - tenant_id: str - + self, + query: str, + dataset_tools: list[PromptMessageTool], + model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, + user_id: str, + tenant_id: str, ) -> Union[str, None]: """Given input, decided what to do. Returns: @@ -71,23 +69,28 @@ class ReactMultiDatasetRouter: return dataset_tools[0].name try: - return self._react_invoke(query=query, model_config=model_config, - model_instance=model_instance, - tools=dataset_tools, user_id=user_id, tenant_id=tenant_id) + return self._react_invoke( + query=query, + model_config=model_config, + model_instance=model_instance, + tools=dataset_tools, + user_id=user_id, + tenant_id=tenant_id, + ) except Exception as e: return None def _react_invoke( - self, - query: str, - model_config: ModelConfigWithCredentialsEntity, - model_instance: ModelInstance, - tools: Sequence[PromptMessageTool], - user_id: str, - tenant_id: str, - prefix: str = PREFIX, - suffix: str = SUFFIX, - format_instructions: str = FORMAT_INSTRUCTIONS, + self, + query: str, + model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, + tools: Sequence[PromptMessageTool], + user_id: str, + tenant_id: str, + prefix: str = PREFIX, + suffix: str = SUFFIX, + format_instructions: str = FORMAT_INSTRUCTIONS, ) -> Union[str, None]: if model_config.mode == "chat": prompt = self.create_chat_prompt( @@ -103,18 +106,18 @@ class ReactMultiDatasetRouter: prefix=prefix, format_instructions=format_instructions, ) - stop = ['Observation:'] + stop = ["Observation:"] # handle invoke result prompt_transform = AdvancedPromptTransform() prompt_messages = prompt_transform.get_prompt( prompt_template=prompt, inputs={}, - query='', + query="", files=[], - context='', + context="", memory_config=None, memory=None, - model_config=model_config + model_config=model_config, ) result_text, usage = self._invoke_llm( completion_param=model_config.parameters, @@ -122,7 +125,7 @@ class ReactMultiDatasetRouter: prompt_messages=prompt_messages, stop=stop, user_id=user_id, - tenant_id=tenant_id + tenant_id=tenant_id, ) output_parser = StructuredChatOutputParser() react_decision = output_parser.parse(result_text) @@ -130,17 +133,21 @@ class ReactMultiDatasetRouter: return react_decision.tool return None - def _invoke_llm(self, completion_param: dict, - model_instance: ModelInstance, - prompt_messages: list[PromptMessage], - stop: list[str], user_id: str, tenant_id: str - ) -> tuple[str, LLMUsage]: + def _invoke_llm( + self, + completion_param: dict, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + stop: list[str], + user_id: str, + tenant_id: str, + ) -> tuple[str, LLMUsage]: """ - Invoke large language model - :param model_instance: model instance - :param prompt_messages: prompt messages - :param stop: stop - :return: + Invoke large language model + :param model_instance: model instance + :param prompt_messages: prompt messages + :param stop: stop + :return: """ invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, @@ -151,9 +158,7 @@ class ReactMultiDatasetRouter: ) # handle invoke result - text, usage = self._handle_invoke_result( - invoke_result=invoke_result - ) + text, usage = self._handle_invoke_result(invoke_result=invoke_result) # deduct quota LLMNode.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) @@ -168,7 +173,7 @@ class ReactMultiDatasetRouter: """ model = None prompt_messages = [] - full_text = '' + full_text = "" usage = None for result in invoke_result: text = result.delta.message.content @@ -189,40 +194,35 @@ class ReactMultiDatasetRouter: return full_text, usage def create_chat_prompt( - self, - query: str, - tools: Sequence[PromptMessageTool], - prefix: str = PREFIX, - suffix: str = SUFFIX, - format_instructions: str = FORMAT_INSTRUCTIONS, + self, + query: str, + tools: Sequence[PromptMessageTool], + prefix: str = PREFIX, + suffix: str = SUFFIX, + format_instructions: str = FORMAT_INSTRUCTIONS, ) -> list[ChatModelMessage]: tool_strings = [] for tool in tools: tool_strings.append( - f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}") + f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}" + ) formatted_tools = "\n".join(tool_strings) unique_tool_names = {tool.name for tool in tools} tool_names = ", ".join('"' + name + '"' for name in unique_tool_names) format_instructions = format_instructions.format(tool_names=tool_names) template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix]) prompt_messages = [] - system_prompt_messages = ChatModelMessage( - role=PromptMessageRole.SYSTEM, - text=template - ) + system_prompt_messages = ChatModelMessage(role=PromptMessageRole.SYSTEM, text=template) prompt_messages.append(system_prompt_messages) - user_prompt_message = ChatModelMessage( - role=PromptMessageRole.USER, - text=query - ) + user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=query) prompt_messages.append(user_prompt_message) return prompt_messages def create_completion_prompt( - self, - tools: Sequence[PromptMessageTool], - prefix: str = PREFIX, - format_instructions: str = FORMAT_INSTRUCTIONS, + self, + tools: Sequence[PromptMessageTool], + prefix: str = PREFIX, + format_instructions: str = FORMAT_INSTRUCTIONS, ) -> CompletionModelPromptTemplate: """Create prompt in the style of the zero shot agent. diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 0c1cb57c7f..53032b34d5 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -1,4 +1,5 @@ """Functionality for splitting text.""" + from __future__ import annotations from typing import Any, Optional @@ -18,31 +19,29 @@ from core.rag.splitter.text_splitter import ( class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): """ - This class is used to implement from_gpt2_encoder, to prevent using of tiktoken + This class is used to implement from_gpt2_encoder, to prevent using of tiktoken """ @classmethod def from_encoder( - cls: type[TS], - embedding_model_instance: Optional[ModelInstance], - allowed_special: Union[Literal[all], Set[str]] = set(), - disallowed_special: Union[Literal[all], Collection[str]] = "all", - **kwargs: Any, + cls: type[TS], + embedding_model_instance: Optional[ModelInstance], + allowed_special: Union[Literal[all], Set[str]] = set(), + disallowed_special: Union[Literal[all], Collection[str]] = "all", + **kwargs: Any, ): def _token_encoder(text: str) -> int: if not text: return 0 if embedding_model_instance: - return embedding_model_instance.get_text_embedding_num_tokens( - texts=[text] - ) + return embedding_model_instance.get_text_embedding_num_tokens(texts=[text]) else: return GPT2Tokenizer.get_num_tokens(text) if issubclass(cls, TokenTextSplitter): extra_kwargs = { - "model_name": embedding_model_instance.model if embedding_model_instance else 'gpt2', + "model_name": embedding_model_instance.model if embedding_model_instance else "gpt2", "allowed_special": allowed_special, "disallowed_special": disallowed_special, } diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index f06f22a00e..97d0721304 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -22,9 +22,7 @@ logger = logging.getLogger(__name__) TS = TypeVar("TS", bound="TextSplitter") -def _split_text_with_regex( - text: str, separator: str, keep_separator: bool -) -> list[str]: +def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]: # Now that we have the separator, split the text if separator: if keep_separator: @@ -37,19 +35,19 @@ def _split_text_with_regex( splits = re.split(separator, text) else: splits = list(text) - return [s for s in splits if (s != "" and s != '\n')] + return [s for s in splits if (s != "" and s != "\n")] class TextSplitter(BaseDocumentTransformer, ABC): """Interface for splitting text into chunks.""" def __init__( - self, - chunk_size: int = 4000, - chunk_overlap: int = 200, - length_function: Callable[[str], int] = len, - keep_separator: bool = False, - add_start_index: bool = False, + self, + chunk_size: int = 4000, + chunk_overlap: int = 200, + length_function: Callable[[str], int] = len, + keep_separator: bool = False, + add_start_index: bool = False, ) -> None: """Create a new TextSplitter. @@ -62,8 +60,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): """ if chunk_overlap > chunk_size: raise ValueError( - f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " - f"({chunk_size}), should be smaller." + f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " f"({chunk_size}), should be smaller." ) self._chunk_size = chunk_size self._chunk_overlap = chunk_overlap @@ -75,9 +72,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): def split_text(self, text: str) -> list[str]: """Split text into multiple components.""" - def create_documents( - self, texts: list[str], metadatas: Optional[list[dict]] = None - ) -> list[Document]: + def create_documents(self, texts: list[str], metadatas: Optional[list[dict]] = None) -> list[Document]: """Create documents from a list of texts.""" _metadatas = metadatas or [{}] * len(texts) documents = [] @@ -119,14 +114,10 @@ class TextSplitter(BaseDocumentTransformer, ABC): index = 0 for d in splits: _len = lengths[index] - if ( - total + _len + (separator_len if len(current_doc) > 0 else 0) - > self._chunk_size - ): + if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size: if total > self._chunk_size: logger.warning( - f"Created a chunk of size {total}, " - f"which is longer than the specified {self._chunk_size}" + f"Created a chunk of size {total}, " f"which is longer than the specified {self._chunk_size}" ) if len(current_doc) > 0: doc = self._join_docs(current_doc, separator) @@ -136,13 +127,9 @@ class TextSplitter(BaseDocumentTransformer, ABC): # - we have a larger chunk than in the chunk overlap # - or if we still have any chunks and the length is long while total > self._chunk_overlap or ( - total + _len + (separator_len if len(current_doc) > 0 else 0) - > self._chunk_size - and total > 0 + total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0 ): - total -= self._length_function(current_doc[0]) + ( - separator_len if len(current_doc) > 1 else 0 - ) + total -= self._length_function(current_doc[0]) + (separator_len if len(current_doc) > 1 else 0) current_doc = current_doc[1:] current_doc.append(d) total += _len + (separator_len if len(current_doc) > 1 else 0) @@ -159,28 +146,25 @@ class TextSplitter(BaseDocumentTransformer, ABC): from transformers import PreTrainedTokenizerBase if not isinstance(tokenizer, PreTrainedTokenizerBase): - raise ValueError( - "Tokenizer received was not an instance of PreTrainedTokenizerBase" - ) + raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase") def _huggingface_tokenizer_length(text: str) -> int: return len(tokenizer.encode(text)) except ImportError: raise ValueError( - "Could not import transformers python package. " - "Please install it with `pip install transformers`." + "Could not import transformers python package. " "Please install it with `pip install transformers`." ) return cls(length_function=_huggingface_tokenizer_length, **kwargs) @classmethod def from_tiktoken_encoder( - cls: type[TS], - encoding_name: str = "gpt2", - model_name: Optional[str] = None, - allowed_special: Union[Literal["all"], Set[str]] = set(), - disallowed_special: Union[Literal["all"], Collection[str]] = "all", - **kwargs: Any, + cls: type[TS], + encoding_name: str = "gpt2", + model_name: Optional[str] = None, + allowed_special: Union[Literal["all"], Set[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + **kwargs: Any, ) -> TS: """Text splitter that uses tiktoken encoder to count length.""" try: @@ -217,15 +201,11 @@ class TextSplitter(BaseDocumentTransformer, ABC): return cls(length_function=_tiktoken_encoder, **kwargs) - def transform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Transform sequence of documents by splitting them.""" return self.split_documents(list(documents)) - async def atransform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Asynchronously transform a sequence of documents by splitting them.""" raise NotImplementedError @@ -267,9 +247,7 @@ class HeaderType(TypedDict): class MarkdownHeaderTextSplitter: """Splitting markdown files based on specified headers.""" - def __init__( - self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False - ): + def __init__(self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False): """Create a new MarkdownHeaderTextSplitter. Args: @@ -280,9 +258,7 @@ class MarkdownHeaderTextSplitter: self.return_each_line = return_each_line # Given the headers we want to split on, # (e.g., "#, ##, etc") order by length - self.headers_to_split_on = sorted( - headers_to_split_on, key=lambda split: len(split[0]), reverse=True - ) + self.headers_to_split_on = sorted(headers_to_split_on, key=lambda split: len(split[0]), reverse=True) def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]: """Combine lines with common metadata into chunks @@ -292,10 +268,7 @@ class MarkdownHeaderTextSplitter: aggregated_chunks: list[LineType] = [] for line in lines: - if ( - aggregated_chunks - and aggregated_chunks[-1]["metadata"] == line["metadata"] - ): + if aggregated_chunks and aggregated_chunks[-1]["metadata"] == line["metadata"]: # If the last line in the aggregated list # has the same metadata as the current line, # append the current content to the last lines's content @@ -304,10 +277,7 @@ class MarkdownHeaderTextSplitter: # Otherwise, append the current line to the aggregated list aggregated_chunks.append(line) - return [ - Document(page_content=chunk["content"], metadata=chunk["metadata"]) - for chunk in aggregated_chunks - ] + return [Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in aggregated_chunks] def split_text(self, text: str) -> list[Document]: """Split markdown file @@ -332,10 +302,9 @@ class MarkdownHeaderTextSplitter: for sep, name in self.headers_to_split_on: # Check if line starts with a header that we intend to split on if stripped_line.startswith(sep) and ( - # Header with no text OR header is followed by space - # Both are valid conditions that sep is being used a header - len(stripped_line) == len(sep) - or stripped_line[len(sep)] == " " + # Header with no text OR header is followed by space + # Both are valid conditions that sep is being used a header + len(stripped_line) == len(sep) or stripped_line[len(sep)] == " " ): # Ensure we are tracking the header as metadata if name is not None: @@ -343,10 +312,7 @@ class MarkdownHeaderTextSplitter: current_header_level = sep.count("#") # Pop out headers of lower or same level from the stack - while ( - header_stack - and header_stack[-1]["level"] >= current_header_level - ): + while header_stack and header_stack[-1]["level"] >= current_header_level: # We have encountered a new header # at the same or higher level popped_header = header_stack.pop() @@ -359,7 +325,7 @@ class MarkdownHeaderTextSplitter: header: HeaderType = { "level": current_header_level, "name": name, - "data": stripped_line[len(sep):].strip(), + "data": stripped_line[len(sep) :].strip(), } header_stack.append(header) # Update initial_metadata with the current header @@ -392,9 +358,7 @@ class MarkdownHeaderTextSplitter: current_metadata = initial_metadata.copy() if current_content: - lines_with_metadata.append( - {"content": "\n".join(current_content), "metadata": current_metadata} - ) + lines_with_metadata.append({"content": "\n".join(current_content), "metadata": current_metadata}) # lines_with_metadata has each line with associated header metadata # aggregate these into chunks based on common metadata @@ -402,8 +366,7 @@ class MarkdownHeaderTextSplitter: return self.aggregate_lines_to_chunks(lines_with_metadata) else: return [ - Document(page_content=chunk["content"], metadata=chunk["metadata"]) - for chunk in lines_with_metadata + Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in lines_with_metadata ] @@ -436,12 +399,12 @@ class TokenTextSplitter(TextSplitter): """Splitting text to tokens using model tokenizer.""" def __init__( - self, - encoding_name: str = "gpt2", - model_name: Optional[str] = None, - allowed_special: Union[Literal["all"], Set[str]] = set(), - disallowed_special: Union[Literal["all"], Collection[str]] = "all", - **kwargs: Any, + self, + encoding_name: str = "gpt2", + model_name: Optional[str] = None, + allowed_special: Union[Literal["all"], Set[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + **kwargs: Any, ) -> None: """Create a new TextSplitter.""" super().__init__(**kwargs) @@ -488,10 +451,10 @@ class RecursiveCharacterTextSplitter(TextSplitter): """ def __init__( - self, - separators: Optional[list[str]] = None, - keep_separator: bool = True, - **kwargs: Any, + self, + separators: Optional[list[str]] = None, + keep_separator: bool = True, + **kwargs: Any, ) -> None: """Create a new TextSplitter.""" super().__init__(keep_separator=keep_separator, **kwargs) @@ -508,7 +471,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): break if re.search(_s, text): separator = _s - new_separators = separators[i + 1:] + new_separators = separators[i + 1 :] break splits = _split_text_with_regex(text, separator, self._keep_separator) diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 2b01b8fd8e..b988a588e9 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -10,23 +10,23 @@ from core.tools.tool.tool import ToolParameter class UserTool(BaseModel): author: str - name: str # identifier - label: I18nObject # label + name: str # identifier + label: I18nObject # label description: I18nObject parameters: Optional[list[ToolParameter]] = None labels: list[str] = None -UserToolProviderTypeLiteral = Optional[Literal[ - 'builtin', 'api', 'workflow' -]] + +UserToolProviderTypeLiteral = Optional[Literal["builtin", "api", "workflow"]] + class UserToolProvider(BaseModel): id: str author: str - name: str # identifier + name: str # identifier description: I18nObject icon: str - label: I18nObject # label + label: I18nObject # label type: ToolProviderType masked_credentials: Optional[dict] = None original_credentials: Optional[dict] = None @@ -40,26 +40,27 @@ class UserToolProvider(BaseModel): # overwrite tool parameter types for temp fix tools = jsonable_encoder(self.tools) for tool in tools: - if tool.get('parameters'): - for parameter in tool.get('parameters'): - if parameter.get('type') == ToolParameter.ToolParameterType.FILE.value: - parameter['type'] = 'files' + if tool.get("parameters"): + for parameter in tool.get("parameters"): + if parameter.get("type") == ToolParameter.ToolParameterType.FILE.value: + parameter["type"] = "files" # ------------- return { - 'id': self.id, - 'author': self.author, - 'name': self.name, - 'description': self.description.to_dict(), - 'icon': self.icon, - 'label': self.label.to_dict(), - 'type': self.type.value, - 'team_credentials': self.masked_credentials, - 'is_team_authorization': self.is_team_authorization, - 'allow_delete': self.allow_delete, - 'tools': tools, - 'labels': self.labels, + "id": self.id, + "author": self.author, + "name": self.name, + "description": self.description.to_dict(), + "icon": self.icon, + "label": self.label.to_dict(), + "type": self.type.value, + "team_credentials": self.masked_credentials, + "is_team_authorization": self.is_team_authorization, + "allow_delete": self.allow_delete, + "tools": tools, + "labels": self.labels, } + class UserToolProviderCredentials(BaseModel): - credentials: dict[str, ToolProviderCredentials] \ No newline at end of file + credentials: dict[str, ToolProviderCredentials] diff --git a/api/core/tools/entities/common_entities.py b/api/core/tools/entities/common_entities.py index 55e31e8c35..37a926697b 100644 --- a/api/core/tools/entities/common_entities.py +++ b/api/core/tools/entities/common_entities.py @@ -7,6 +7,7 @@ class I18nObject(BaseModel): """ Model class for i18n object. """ + zh_Hans: Optional[str] = None pt_BR: Optional[str] = None en_US: str @@ -19,8 +20,4 @@ class I18nObject(BaseModel): self.pt_BR = self.en_US def to_dict(self) -> dict: - return { - 'zh_Hans': self.zh_Hans, - 'en_US': self.en_US, - 'pt_BR': self.pt_BR - } + return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR} diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py index d18d27fb02..da6201c5aa 100644 --- a/api/core/tools/entities/tool_bundle.py +++ b/api/core/tools/entities/tool_bundle.py @@ -9,6 +9,7 @@ class ApiToolBundle(BaseModel): """ This class is used to store the schema information of an api based tool. such as the url, the method, the parameters, etc. """ + # server_url server_url: str # method diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index e31dec55d2..02b8b35be7 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -7,27 +7,29 @@ from core.tools.entities.common_entities import I18nObject class ToolLabelEnum(Enum): - SEARCH = 'search' - IMAGE = 'image' - VIDEOS = 'videos' - WEATHER = 'weather' - FINANCE = 'finance' - DESIGN = 'design' - TRAVEL = 'travel' - SOCIAL = 'social' - NEWS = 'news' - MEDICAL = 'medical' - PRODUCTIVITY = 'productivity' - EDUCATION = 'education' - BUSINESS = 'business' - ENTERTAINMENT = 'entertainment' - UTILITIES = 'utilities' - OTHER = 'other' + SEARCH = "search" + IMAGE = "image" + VIDEOS = "videos" + WEATHER = "weather" + FINANCE = "finance" + DESIGN = "design" + TRAVEL = "travel" + SOCIAL = "social" + NEWS = "news" + MEDICAL = "medical" + PRODUCTIVITY = "productivity" + EDUCATION = "education" + BUSINESS = "business" + ENTERTAINMENT = "entertainment" + UTILITIES = "utilities" + OTHER = "other" + class ToolProviderType(Enum): """ - Enum class for tool provider + Enum class for tool provider """ + BUILT_IN = "builtin" WORKFLOW = "workflow" API = "api" @@ -35,7 +37,7 @@ class ToolProviderType(Enum): DATASET_RETRIEVAL = "dataset-retrieval" @classmethod - def value_of(cls, value: str) -> 'ToolProviderType': + def value_of(cls, value: str) -> "ToolProviderType": """ Get value of given mode. @@ -45,19 +47,21 @@ class ToolProviderType(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") + class ApiProviderSchemaType(Enum): """ Enum class for api provider schema type. """ + OPENAPI = "openapi" SWAGGER = "swagger" OPENAI_PLUGIN = "openai_plugin" OPENAI_ACTIONS = "openai_actions" @classmethod - def value_of(cls, value: str) -> 'ApiProviderSchemaType': + def value_of(cls, value: str) -> "ApiProviderSchemaType": """ Get value of given mode. @@ -67,17 +71,19 @@ class ApiProviderSchemaType(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") + class ApiProviderAuthType(Enum): """ Enum class for api provider auth type. """ + NONE = "none" API_KEY = "api_key" @classmethod - def value_of(cls, value: str) -> 'ApiProviderAuthType': + def value_of(cls, value: str) -> "ApiProviderAuthType": """ Get value of given mode. @@ -87,7 +93,8 @@ class ApiProviderAuthType(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") + class ToolInvokeMessage(BaseModel): class MessageType(Enum): @@ -105,19 +112,21 @@ class ToolInvokeMessage(BaseModel): """ message: str | bytes | dict | None = None meta: dict[str, Any] | None = None - save_as: str = '' + save_as: str = "" + class ToolInvokeMessageBinary(BaseModel): mimetype: str = Field(..., description="The mimetype of the binary") url: str = Field(..., description="The url of the binary") - save_as: str = '' + save_as: str = "" file_var: Optional[dict[str, Any]] = None + class ToolParameterOption(BaseModel): value: str = Field(..., description="The value of the option") label: I18nObject = Field(..., description="The label of the option") - @field_validator('value', mode='before') + @field_validator("value", mode="before") @classmethod def transform_id_to_str(cls, value) -> str: if not isinstance(value, str): @@ -136,9 +145,9 @@ class ToolParameter(BaseModel): FILE = "file" class ToolParameterForm(Enum): - SCHEMA = "schema" # should be set while adding tool - FORM = "form" # should be set before invoking tool - LLM = "llm" # will be set by LLM + SCHEMA = "schema" # should be set while adding tool + FORM = "form" # should be set before invoking tool + LLM = "llm" # will be set by LLM name: str = Field(..., description="The name of the parameter") label: I18nObject = Field(..., description="The label presented to the user") @@ -154,25 +163,32 @@ class ToolParameter(BaseModel): options: Optional[list[ToolParameterOption]] = None @classmethod - def get_simple_instance(cls, - name: str, llm_description: str, type: ToolParameterType, - required: bool, options: Optional[list[str]] = None) -> 'ToolParameter': + def get_simple_instance( + cls, + name: str, + llm_description: str, + type: ToolParameterType, + required: bool, + options: Optional[list[str]] = None, + ) -> "ToolParameter": """ - get a simple tool parameter + get a simple tool parameter - :param name: the name of the parameter - :param llm_description: the description presented to the LLM - :param type: the type of the parameter - :param required: if the parameter is required - :param options: the options of the parameter + :param name: the name of the parameter + :param llm_description: the description presented to the LLM + :param type: the type of the parameter + :param required: if the parameter is required + :param options: the options of the parameter """ # convert options to ToolParameterOption if options: - options = [ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options] + options = [ + ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options + ] return cls( name=name, - label=I18nObject(en_US='', zh_Hans=''), - human_description=I18nObject(en_US='', zh_Hans=''), + label=I18nObject(en_US="", zh_Hans=""), + human_description=I18nObject(en_US="", zh_Hans=""), type=type, form=cls.ToolParameterForm.LLM, llm_description=llm_description, @@ -180,18 +196,24 @@ class ToolParameter(BaseModel): options=options, ) + class ToolProviderIdentity(BaseModel): author: str = Field(..., description="The author of the tool") name: str = Field(..., description="The name of the tool") description: I18nObject = Field(..., description="The description of the tool") icon: str = Field(..., description="The icon of the tool") label: I18nObject = Field(..., description="The label of the tool") - tags: Optional[list[ToolLabelEnum]] = Field(default=[], description="The tags of the tool", ) + tags: Optional[list[ToolLabelEnum]] = Field( + default=[], + description="The tags of the tool", + ) + class ToolDescription(BaseModel): human: I18nObject = Field(..., description="The description presented to the user") llm: str = Field(..., description="The description presented to the LLM") + class ToolIdentity(BaseModel): author: str = Field(..., description="The author of the tool") name: str = Field(..., description="The name of the tool") @@ -199,10 +221,12 @@ class ToolIdentity(BaseModel): provider: str = Field(..., description="The provider of the tool") icon: Optional[str] = None + class ToolCredentialsOption(BaseModel): value: str = Field(..., description="The value of the option") label: I18nObject = Field(..., description="The label of the option") + class ToolProviderCredentials(BaseModel): class CredentialsType(Enum): SECRET_INPUT = "secret-input" @@ -221,7 +245,7 @@ class ToolProviderCredentials(BaseModel): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") @staticmethod def default(value: str) -> str: @@ -239,33 +263,38 @@ class ToolProviderCredentials(BaseModel): def to_dict(self) -> dict: return { - 'name': self.name, - 'type': self.type.value, - 'required': self.required, - 'default': self.default, - 'options': self.options, - 'help': self.help.to_dict() if self.help else None, - 'label': self.label.to_dict(), - 'url': self.url, - 'placeholder': self.placeholder.to_dict() if self.placeholder else None, + "name": self.name, + "type": self.type.value, + "required": self.required, + "default": self.default, + "options": self.options, + "help": self.help.to_dict() if self.help else None, + "label": self.label.to_dict(), + "url": self.url, + "placeholder": self.placeholder.to_dict() if self.placeholder else None, } + class ToolRuntimeVariableType(Enum): TEXT = "text" IMAGE = "image" + class ToolRuntimeVariable(BaseModel): type: ToolRuntimeVariableType = Field(..., description="The type of the variable") name: str = Field(..., description="The name of the variable") position: int = Field(..., description="The position of the variable") tool_name: str = Field(..., description="The name of the tool") + class ToolRuntimeTextVariable(ToolRuntimeVariable): value: str = Field(..., description="The value of the variable") + class ToolRuntimeImageVariable(ToolRuntimeVariable): value: str = Field(..., description="The path of the image") + class ToolRuntimeVariablePool(BaseModel): conversation_id: str = Field(..., description="The conversation id") user_id: str = Field(..., description="The user id") @@ -274,26 +303,26 @@ class ToolRuntimeVariablePool(BaseModel): pool: list[ToolRuntimeVariable] = Field(..., description="The pool of variables") def __init__(self, **data: Any): - pool = data.get('pool', []) + pool = data.get("pool", []) # convert pool into correct type for index, variable in enumerate(pool): - if variable['type'] == ToolRuntimeVariableType.TEXT.value: + if variable["type"] == ToolRuntimeVariableType.TEXT.value: pool[index] = ToolRuntimeTextVariable(**variable) - elif variable['type'] == ToolRuntimeVariableType.IMAGE.value: + elif variable["type"] == ToolRuntimeVariableType.IMAGE.value: pool[index] = ToolRuntimeImageVariable(**variable) super().__init__(**data) def dict(self) -> dict: return { - 'conversation_id': self.conversation_id, - 'user_id': self.user_id, - 'tenant_id': self.tenant_id, - 'pool': [variable.model_dump() for variable in self.pool], + "conversation_id": self.conversation_id, + "user_id": self.user_id, + "tenant_id": self.tenant_id, + "pool": [variable.model_dump() for variable in self.pool], } def set_text(self, tool_name: str, name: str, value: str) -> None: """ - set a text variable + set a text variable """ for variable in self.pool: if variable.name == name: @@ -314,10 +343,10 @@ class ToolRuntimeVariablePool(BaseModel): def set_file(self, tool_name: str, value: str, name: str = None) -> None: """ - set an image variable + set an image variable - :param tool_name: the name of the tool - :param value: the id of the file + :param tool_name: the name of the tool + :param value: the id of the file """ # check how many image variables are there image_variable_count = 0 @@ -345,22 +374,27 @@ class ToolRuntimeVariablePool(BaseModel): self.pool.append(variable) + class ModelToolPropertyKey(Enum): IMAGE_PARAMETER_NAME = "image_parameter_name" + class ModelToolConfiguration(BaseModel): """ Model tool configuration """ + type: str = Field(..., description="The type of the model tool") model: str = Field(..., description="The model") label: I18nObject = Field(..., description="The label of the model tool") properties: dict[ModelToolPropertyKey, Any] = Field(..., description="The properties of the model tool") + class ModelToolProviderConfiguration(BaseModel): """ Model tool provider configuration """ + provider: str = Field(..., description="The provider of the model tool") models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool") label: I18nObject = Field(..., description="The label of the model tool") @@ -370,27 +404,30 @@ class WorkflowToolParameterConfiguration(BaseModel): """ Workflow tool configuration """ + name: str = Field(..., description="The name of the parameter") description: str = Field(..., description="The description of the parameter") form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter") + class ToolInvokeMeta(BaseModel): """ Tool invoke meta """ + time_cost: float = Field(..., description="The time cost of the tool invoke") error: Optional[str] = None tool_config: Optional[dict] = None @classmethod - def empty(cls) -> 'ToolInvokeMeta': + def empty(cls) -> "ToolInvokeMeta": """ Get an empty instance of ToolInvokeMeta """ return cls(time_cost=0.0, error=None, tool_config={}) @classmethod - def error_instance(cls, error: str) -> 'ToolInvokeMeta': + def error_instance(cls, error: str) -> "ToolInvokeMeta": """ Get an instance of ToolInvokeMeta with error """ @@ -398,22 +435,26 @@ class ToolInvokeMeta(BaseModel): def to_dict(self) -> dict: return { - 'time_cost': self.time_cost, - 'error': self.error, - 'tool_config': self.tool_config, + "time_cost": self.time_cost, + "error": self.error, + "tool_config": self.tool_config, } + class ToolLabel(BaseModel): """ Tool label """ + name: str = Field(..., description="The name of the tool") label: I18nObject = Field(..., description="The label of the tool") icon: str = Field(..., description="The icon of the tool") + class ToolInvokeFrom(Enum): """ Enum class for tool invoke """ + WORKFLOW = "workflow" AGENT = "agent" diff --git a/api/core/tools/entities/values.py b/api/core/tools/entities/values.py index d0be5e9355..f9db190f91 100644 --- a/api/core/tools/entities/values.py +++ b/api/core/tools/entities/values.py @@ -2,73 +2,109 @@ from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolLabel, ToolLabelEnum ICONS = { - ToolLabelEnum.SEARCH: ''' + ToolLabelEnum.SEARCH: """ -''', - ToolLabelEnum.IMAGE: ''' +""", + ToolLabelEnum.IMAGE: """ -''', - ToolLabelEnum.VIDEOS: ''' +""", + ToolLabelEnum.VIDEOS: """ -''', - ToolLabelEnum.WEATHER: ''' +""", + ToolLabelEnum.WEATHER: """ -''', - ToolLabelEnum.FINANCE: ''' +""", + ToolLabelEnum.FINANCE: """ -''', - ToolLabelEnum.DESIGN: ''' +""", + ToolLabelEnum.DESIGN: """ -''', - ToolLabelEnum.TRAVEL: ''' +""", + ToolLabelEnum.TRAVEL: """ -''', - ToolLabelEnum.SOCIAL: ''' +""", + ToolLabelEnum.SOCIAL: """ -''', - ToolLabelEnum.NEWS: ''' +""", + ToolLabelEnum.NEWS: """ -''', - ToolLabelEnum.MEDICAL: ''' +""", + ToolLabelEnum.MEDICAL: """ -''', - ToolLabelEnum.PRODUCTIVITY: ''' +""", + ToolLabelEnum.PRODUCTIVITY: """ -''', - ToolLabelEnum.EDUCATION: ''' +""", + ToolLabelEnum.EDUCATION: """ -''', - ToolLabelEnum.BUSINESS: ''' +""", + ToolLabelEnum.BUSINESS: """ -''', - ToolLabelEnum.ENTERTAINMENT: ''' +""", + ToolLabelEnum.ENTERTAINMENT: """ -''', - ToolLabelEnum.UTILITIES: ''' +""", + ToolLabelEnum.UTILITIES: """ -''', - ToolLabelEnum.OTHER: ''' +""", + ToolLabelEnum.OTHER: """ -''' +""", } default_tool_label_dict = { - ToolLabelEnum.SEARCH: ToolLabel(name='search', label=I18nObject(en_US='Search', zh_Hans='搜索'), icon=ICONS[ToolLabelEnum.SEARCH]), - ToolLabelEnum.IMAGE: ToolLabel(name='image', label=I18nObject(en_US='Image', zh_Hans='图片'), icon=ICONS[ToolLabelEnum.IMAGE]), - ToolLabelEnum.VIDEOS: ToolLabel(name='videos', label=I18nObject(en_US='Videos', zh_Hans='视频'), icon=ICONS[ToolLabelEnum.VIDEOS]), - ToolLabelEnum.WEATHER: ToolLabel(name='weather', label=I18nObject(en_US='Weather', zh_Hans='天气'), icon=ICONS[ToolLabelEnum.WEATHER]), - ToolLabelEnum.FINANCE: ToolLabel(name='finance', label=I18nObject(en_US='Finance', zh_Hans='金融'), icon=ICONS[ToolLabelEnum.FINANCE]), - ToolLabelEnum.DESIGN: ToolLabel(name='design', label=I18nObject(en_US='Design', zh_Hans='设计'), icon=ICONS[ToolLabelEnum.DESIGN]), - ToolLabelEnum.TRAVEL: ToolLabel(name='travel', label=I18nObject(en_US='Travel', zh_Hans='旅行'), icon=ICONS[ToolLabelEnum.TRAVEL]), - ToolLabelEnum.SOCIAL: ToolLabel(name='social', label=I18nObject(en_US='Social', zh_Hans='社交'), icon=ICONS[ToolLabelEnum.SOCIAL]), - ToolLabelEnum.NEWS: ToolLabel(name='news', label=I18nObject(en_US='News', zh_Hans='新闻'), icon=ICONS[ToolLabelEnum.NEWS]), - ToolLabelEnum.MEDICAL: ToolLabel(name='medical', label=I18nObject(en_US='Medical', zh_Hans='医疗'), icon=ICONS[ToolLabelEnum.MEDICAL]), - ToolLabelEnum.PRODUCTIVITY: ToolLabel(name='productivity', label=I18nObject(en_US='Productivity', zh_Hans='生产力'), icon=ICONS[ToolLabelEnum.PRODUCTIVITY]), - ToolLabelEnum.EDUCATION: ToolLabel(name='education', label=I18nObject(en_US='Education', zh_Hans='教育'), icon=ICONS[ToolLabelEnum.EDUCATION]), - ToolLabelEnum.BUSINESS: ToolLabel(name='business', label=I18nObject(en_US='Business', zh_Hans='商业'), icon=ICONS[ToolLabelEnum.BUSINESS]), - ToolLabelEnum.ENTERTAINMENT: ToolLabel(name='entertainment', label=I18nObject(en_US='Entertainment', zh_Hans='娱乐'), icon=ICONS[ToolLabelEnum.ENTERTAINMENT]), - ToolLabelEnum.UTILITIES: ToolLabel(name='utilities', label=I18nObject(en_US='Utilities', zh_Hans='工具'), icon=ICONS[ToolLabelEnum.UTILITIES]), - ToolLabelEnum.OTHER: ToolLabel(name='other', label=I18nObject(en_US='Other', zh_Hans='其他'), icon=ICONS[ToolLabelEnum.OTHER]), + ToolLabelEnum.SEARCH: ToolLabel( + name="search", label=I18nObject(en_US="Search", zh_Hans="搜索"), icon=ICONS[ToolLabelEnum.SEARCH] + ), + ToolLabelEnum.IMAGE: ToolLabel( + name="image", label=I18nObject(en_US="Image", zh_Hans="图片"), icon=ICONS[ToolLabelEnum.IMAGE] + ), + ToolLabelEnum.VIDEOS: ToolLabel( + name="videos", label=I18nObject(en_US="Videos", zh_Hans="视频"), icon=ICONS[ToolLabelEnum.VIDEOS] + ), + ToolLabelEnum.WEATHER: ToolLabel( + name="weather", label=I18nObject(en_US="Weather", zh_Hans="天气"), icon=ICONS[ToolLabelEnum.WEATHER] + ), + ToolLabelEnum.FINANCE: ToolLabel( + name="finance", label=I18nObject(en_US="Finance", zh_Hans="金融"), icon=ICONS[ToolLabelEnum.FINANCE] + ), + ToolLabelEnum.DESIGN: ToolLabel( + name="design", label=I18nObject(en_US="Design", zh_Hans="设计"), icon=ICONS[ToolLabelEnum.DESIGN] + ), + ToolLabelEnum.TRAVEL: ToolLabel( + name="travel", label=I18nObject(en_US="Travel", zh_Hans="旅行"), icon=ICONS[ToolLabelEnum.TRAVEL] + ), + ToolLabelEnum.SOCIAL: ToolLabel( + name="social", label=I18nObject(en_US="Social", zh_Hans="社交"), icon=ICONS[ToolLabelEnum.SOCIAL] + ), + ToolLabelEnum.NEWS: ToolLabel( + name="news", label=I18nObject(en_US="News", zh_Hans="新闻"), icon=ICONS[ToolLabelEnum.NEWS] + ), + ToolLabelEnum.MEDICAL: ToolLabel( + name="medical", label=I18nObject(en_US="Medical", zh_Hans="医疗"), icon=ICONS[ToolLabelEnum.MEDICAL] + ), + ToolLabelEnum.PRODUCTIVITY: ToolLabel( + name="productivity", + label=I18nObject(en_US="Productivity", zh_Hans="生产力"), + icon=ICONS[ToolLabelEnum.PRODUCTIVITY], + ), + ToolLabelEnum.EDUCATION: ToolLabel( + name="education", label=I18nObject(en_US="Education", zh_Hans="教育"), icon=ICONS[ToolLabelEnum.EDUCATION] + ), + ToolLabelEnum.BUSINESS: ToolLabel( + name="business", label=I18nObject(en_US="Business", zh_Hans="商业"), icon=ICONS[ToolLabelEnum.BUSINESS] + ), + ToolLabelEnum.ENTERTAINMENT: ToolLabel( + name="entertainment", + label=I18nObject(en_US="Entertainment", zh_Hans="娱乐"), + icon=ICONS[ToolLabelEnum.ENTERTAINMENT], + ), + ToolLabelEnum.UTILITIES: ToolLabel( + name="utilities", label=I18nObject(en_US="Utilities", zh_Hans="工具"), icon=ICONS[ToolLabelEnum.UTILITIES] + ), + ToolLabelEnum.OTHER: ToolLabel( + name="other", label=I18nObject(en_US="Other", zh_Hans="其他"), icon=ICONS[ToolLabelEnum.OTHER] + ), } default_tool_labels = [v for k, v in default_tool_label_dict.items()] diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index 9fd8322db1..6febf137b0 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -4,23 +4,30 @@ from core.tools.entities.tool_entities import ToolInvokeMeta class ToolProviderNotFoundError(ValueError): pass + class ToolNotFoundError(ValueError): pass + class ToolParameterValidationError(ValueError): pass + class ToolProviderCredentialValidationError(ValueError): pass + class ToolNotSupportedError(ValueError): pass + class ToolInvokeError(ValueError): pass + class ToolApiSchemaError(ValueError): pass + class ToolEngineInvokeError(Exception): - meta: ToolInvokeMeta \ No newline at end of file + meta: ToolInvokeMeta diff --git a/api/core/tools/provider/api_tool_provider.py b/api/core/tools/provider/api_tool_provider.py index ae80ad2114..2e6018cffc 100644 --- a/api/core/tools/provider/api_tool_provider.py +++ b/api/core/tools/provider/api_tool_provider.py @@ -1,4 +1,3 @@ - from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( @@ -18,85 +17,69 @@ class ApiToolProviderController(ToolProviderController): provider_id: str @staticmethod - def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController': + def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController": credentials_schema = { - 'auth_type': ToolProviderCredentials( - name='auth_type', + "auth_type": ToolProviderCredentials( + name="auth_type", required=True, type=ToolProviderCredentials.CredentialsType.SELECT, options=[ - ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='无')), - ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key')) + ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")), + ToolCredentialsOption(value="api_key", label=I18nObject(en_US="api_key", zh_Hans="api_key")), ], - default='none', - help=I18nObject( - en_US='The auth type of the api provider', - zh_Hans='api provider 的认证类型' - ) + default="none", + help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"), ) } if auth_type == ApiProviderAuthType.API_KEY: credentials_schema = { **credentials_schema, - 'api_key_header': ToolProviderCredentials( - name='api_key_header', + "api_key_header": ToolProviderCredentials( + name="api_key_header", required=False, - default='api_key', + default="api_key", type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, - help=I18nObject( - en_US='The header name of the api key', - zh_Hans='携带 api key 的 header 名称' - ) + help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"), ), - 'api_key_value': ToolProviderCredentials( - name='api_key_value', + "api_key_value": ToolProviderCredentials( + name="api_key_value", required=True, type=ToolProviderCredentials.CredentialsType.SECRET_INPUT, - help=I18nObject( - en_US='The api key', - zh_Hans='api key的值' - ) + help=I18nObject(en_US="The api key", zh_Hans="api key的值"), ), - 'api_key_header_prefix': ToolProviderCredentials( - name='api_key_header_prefix', + "api_key_header_prefix": ToolProviderCredentials( + name="api_key_header_prefix", required=False, - default='basic', + default="basic", type=ToolProviderCredentials.CredentialsType.SELECT, - help=I18nObject( - en_US='The prefix of the api key header', - zh_Hans='api key header 的前缀' - ), + help=I18nObject(en_US="The prefix of the api key header", zh_Hans="api key header 的前缀"), options=[ - ToolCredentialsOption(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')), - ToolCredentialsOption(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')), - ToolCredentialsOption(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom')) - ] - ) + ToolCredentialsOption(value="basic", label=I18nObject(en_US="Basic", zh_Hans="Basic")), + ToolCredentialsOption(value="bearer", label=I18nObject(en_US="Bearer", zh_Hans="Bearer")), + ToolCredentialsOption(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")), + ], + ), } elif auth_type == ApiProviderAuthType.NONE: pass else: - raise ValueError(f'invalid auth type {auth_type}') + raise ValueError(f"invalid auth type {auth_type}") - user_name = db_provider.user.name if db_provider.user_id else '' + user_name = db_provider.user.name if db_provider.user_id else "" - return ApiToolProviderController(**{ - 'identity': { - 'author': user_name, - 'name': db_provider.name, - 'label': { - 'en_US': db_provider.name, - 'zh_Hans': db_provider.name + return ApiToolProviderController( + **{ + "identity": { + "author": user_name, + "name": db_provider.name, + "label": {"en_US": db_provider.name, "zh_Hans": db_provider.name}, + "description": {"en_US": db_provider.description, "zh_Hans": db_provider.description}, + "icon": db_provider.icon, }, - 'description': { - 'en_US': db_provider.description, - 'zh_Hans': db_provider.description - }, - 'icon': db_provider.icon, - }, - 'credentials_schema': credentials_schema, - 'provider_id': db_provider.id or '', - }) + "credentials_schema": credentials_schema, + "provider_id": db_provider.id or "", + } + ) @property def provider_type(self) -> ToolProviderType: @@ -104,39 +87,35 @@ class ApiToolProviderController(ToolProviderController): def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool: """ - parse tool bundle to tool + parse tool bundle to tool - :param tool_bundle: the tool bundle - :return: the tool + :param tool_bundle: the tool bundle + :return: the tool """ - return ApiTool(**{ - 'api_bundle': tool_bundle, - 'identity' : { - 'author': tool_bundle.author, - 'name': tool_bundle.operation_id, - 'label': { - 'en_US': tool_bundle.operation_id, - 'zh_Hans': tool_bundle.operation_id + return ApiTool( + **{ + "api_bundle": tool_bundle, + "identity": { + "author": tool_bundle.author, + "name": tool_bundle.operation_id, + "label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id}, + "icon": self.identity.icon, + "provider": self.provider_id, }, - 'icon': self.identity.icon, - 'provider': self.provider_id, - }, - 'description': { - 'human': { - 'en_US': tool_bundle.summary or '', - 'zh_Hans': tool_bundle.summary or '' + "description": { + "human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""}, + "llm": tool_bundle.summary or "", }, - 'llm': tool_bundle.summary or '' - }, - 'parameters' : tool_bundle.parameters if tool_bundle.parameters else [], - }) + "parameters": tool_bundle.parameters if tool_bundle.parameters else [], + } + ) def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]: """ - load bundled tools + load bundled tools - :param tools: the bundled tools - :return: the tools + :param tools: the bundled tools + :return: the tools """ self.tools = [self._parse_tool_bundle(tool) for tool in tools] @@ -144,22 +123,23 @@ class ApiToolProviderController(ToolProviderController): def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]: """ - fetch tools from database + fetch tools from database - :param user_id: the user id - :param tenant_id: the tenant id - :return: the tools + :param user_id: the user id + :param tenant_id: the tenant id + :return: the tools """ if self.tools is not None: return self.tools - + tools: list[Tool] = [] # get tenant api providers - db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == self.identity.name - ).all() + db_providers: list[ApiToolProvider] = ( + db.session.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.identity.name) + .all() + ) if db_providers and len(db_providers) != 0: for db_provider in db_providers: @@ -167,16 +147,16 @@ class ApiToolProviderController(ToolProviderController): assistant_tool = self._parse_tool_bundle(tool) assistant_tool.is_team_authorization = True tools.append(assistant_tool) - + self.tools = tools return tools - + def get_tool(self, tool_name: str) -> ApiTool: """ - get tool by name + get tool by name - :param tool_name: the name of the tool - :return: the tool + :param tool_name: the name of the tool + :return: the tool """ if self.tools is None: self.get_tools() @@ -185,4 +165,4 @@ class ApiToolProviderController(ToolProviderController): if tool.identity.name == tool_name: return tool - raise ValueError(f'tool {tool_name} not found') \ No newline at end of file + raise ValueError(f"tool {tool_name} not found") diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py index 2d472e0a93..01544d7e56 100644 --- a/api/core/tools/provider/app_tool_provider.py +++ b/api/core/tools/provider/app_tool_provider.py @@ -11,11 +11,12 @@ from models.tools import PublishedAppTool logger = logging.getLogger(__name__) + class AppToolProviderEntity(ToolProviderController): @property def provider_type(self) -> ToolProviderType: return ToolProviderType.APP - + def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None: pass @@ -23,9 +24,13 @@ class AppToolProviderEntity(ToolProviderController): pass def get_tools(self, user_id: str) -> list[Tool]: - db_tools: list[PublishedAppTool] = db.session.query(PublishedAppTool).filter( - PublishedAppTool.user_id == user_id, - ).all() + db_tools: list[PublishedAppTool] = ( + db.session.query(PublishedAppTool) + .filter( + PublishedAppTool.user_id == user_id, + ) + .all() + ) if not db_tools or len(db_tools) == 0: return [] @@ -34,23 +39,17 @@ class AppToolProviderEntity(ToolProviderController): for db_tool in db_tools: tool = { - 'identity': { - 'author': db_tool.author, - 'name': db_tool.tool_name, - 'label': { - 'en_US': db_tool.tool_name, - 'zh_Hans': db_tool.tool_name - }, - 'icon': '' + "identity": { + "author": db_tool.author, + "name": db_tool.tool_name, + "label": {"en_US": db_tool.tool_name, "zh_Hans": db_tool.tool_name}, + "icon": "", }, - 'description': { - 'human': { - 'en_US': db_tool.description_i18n.en_US, - 'zh_Hans': db_tool.description_i18n.zh_Hans - }, - 'llm': db_tool.llm_description + "description": { + "human": {"en_US": db_tool.description_i18n.en_US, "zh_Hans": db_tool.description_i18n.zh_Hans}, + "llm": db_tool.llm_description, }, - 'parameters': [] + "parameters": [], } # get app from db app: App = db_tool.app @@ -64,52 +63,41 @@ class AppToolProviderEntity(ToolProviderController): for input_form in user_input_form_list: # get type form_type = input_form.keys()[0] - default = input_form[form_type]['default'] - required = input_form[form_type]['required'] - label = input_form[form_type]['label'] - variable_name = input_form[form_type]['variable_name'] - options = input_form[form_type].get('options', []) - if form_type == 'paragraph' or form_type == 'text-input': - tool['parameters'].append(ToolParameter( - name=variable_name, - label=I18nObject( - en_US=label, - zh_Hans=label - ), - human_description=I18nObject( - en_US=label, - zh_Hans=label - ), - llm_description=label, - form=ToolParameter.ToolParameterForm.FORM, - type=ToolParameter.ToolParameterType.STRING, - required=required, - default=default - )) - elif form_type == 'select': - tool['parameters'].append(ToolParameter( - name=variable_name, - label=I18nObject( - en_US=label, - zh_Hans=label - ), - human_description=I18nObject( - en_US=label, - zh_Hans=label - ), - llm_description=label, - form=ToolParameter.ToolParameterForm.FORM, - type=ToolParameter.ToolParameterType.SELECT, - required=required, - default=default, - options=[ToolParameterOption( - value=option, - label=I18nObject( - en_US=option, - zh_Hans=option - ) - ) for option in options] - )) + default = input_form[form_type]["default"] + required = input_form[form_type]["required"] + label = input_form[form_type]["label"] + variable_name = input_form[form_type]["variable_name"] + options = input_form[form_type].get("options", []) + if form_type == "paragraph" or form_type == "text-input": + tool["parameters"].append( + ToolParameter( + name=variable_name, + label=I18nObject(en_US=label, zh_Hans=label), + human_description=I18nObject(en_US=label, zh_Hans=label), + llm_description=label, + form=ToolParameter.ToolParameterForm.FORM, + type=ToolParameter.ToolParameterType.STRING, + required=required, + default=default, + ) + ) + elif form_type == "select": + tool["parameters"].append( + ToolParameter( + name=variable_name, + label=I18nObject(en_US=label, zh_Hans=label), + human_description=I18nObject(en_US=label, zh_Hans=label), + llm_description=label, + form=ToolParameter.ToolParameterForm.FORM, + type=ToolParameter.ToolParameterType.SELECT, + required=required, + default=default, + options=[ + ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) + for option in options + ], + ) + ) tools.append(Tool(**tool)) - return tools \ No newline at end of file + return tools diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index 062668fc5b..5c10f72fda 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -10,7 +10,7 @@ class BuiltinToolProviderSort: @classmethod def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: if not cls._position: - cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..')) + cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), "..")) def name_func(provider: UserToolProvider) -> str: return provider.name diff --git a/api/core/tools/provider/builtin/aippt/aippt.py b/api/core/tools/provider/builtin/aippt/aippt.py index 25133c51df..e0cbbd2992 100644 --- a/api/core/tools/provider/builtin/aippt/aippt.py +++ b/api/core/tools/provider/builtin/aippt/aippt.py @@ -6,6 +6,6 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class AIPPTProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: - AIPPTGenerateTool._get_api_token(credentials, user_id='__dify_system__') + AIPPTGenerateTool._get_api_token(credentials, user_id="__dify_system__") except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/aippt/tools/aippt.py b/api/core/tools/provider/builtin/aippt/tools/aippt.py index 8d6883a3b1..7cee8f9f79 100644 --- a/api/core/tools/provider/builtin/aippt/tools/aippt.py +++ b/api/core/tools/provider/builtin/aippt/tools/aippt.py @@ -20,16 +20,16 @@ class AIPPTGenerateTool(BuiltinTool): A tool for generating a ppt """ - _api_base_url = URL('https://co.aippt.cn/api') + _api_base_url = URL("https://co.aippt.cn/api") _api_token_cache = {} - _api_token_cache_lock:Optional[Lock] = None + _api_token_cache_lock: Optional[Lock] = None _style_cache = {} - _style_cache_lock:Optional[Lock] = None + _style_cache_lock: Optional[Lock] = None _task = {} _task_type_map = { - 'auto': 1, - 'markdown': 7, + "auto": 1, + "markdown": 7, } def __init__(self, **kwargs: Any): @@ -48,65 +48,55 @@ class AIPPTGenerateTool(BuiltinTool): Returns: ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages. """ - title = tool_parameters.get('title', '') + title = tool_parameters.get("title", "") if not title: - return self.create_text_message('Please provide a title for the ppt') - - model = tool_parameters.get('model', 'aippt') + return self.create_text_message("Please provide a title for the ppt") + + model = tool_parameters.get("model", "aippt") if not model: - return self.create_text_message('Please provide a model for the ppt') - - outline = tool_parameters.get('outline', '') + return self.create_text_message("Please provide a model for the ppt") + + outline = tool_parameters.get("outline", "") # create task task_id = self._create_task( - type=self._task_type_map['auto' if not outline else 'markdown'], + type=self._task_type_map["auto" if not outline else "markdown"], title=title, content=outline, - user_id=user_id + user_id=user_id, ) # get suit - color = tool_parameters.get('color') - style = tool_parameters.get('style') + color = tool_parameters.get("color") + style = tool_parameters.get("style") - if color == '__default__': - color_id = '' + if color == "__default__": + color_id = "" else: - color_id = int(color.split('-')[1]) + color_id = int(color.split("-")[1]) - if style == '__default__': - style_id = '' + if style == "__default__": + style_id = "" else: - style_id = int(style.split('-')[1]) + style_id = int(style.split("-")[1]) suit_id = self._get_suit(style_id=style_id, colour_id=color_id) # generate outline if not outline: - self._generate_outline( - task_id=task_id, - model=model, - user_id=user_id - ) + self._generate_outline(task_id=task_id, model=model, user_id=user_id) # generate content - self._generate_content( - task_id=task_id, - model=model, - user_id=user_id - ) + self._generate_content(task_id=task_id, model=model, user_id=user_id) # generate ppt - _, ppt_url = self._generate_ppt( - task_id=task_id, - suit_id=suit_id, - user_id=user_id - ) + _, ppt_url = self._generate_ppt(task_id=task_id, suit_id=suit_id, user_id=user_id) - return self.create_text_message('''the ppt has been created successfully,''' - f'''the ppt url is {ppt_url}''' - '''please give the ppt url to user and direct user to download it.''') + return self.create_text_message( + """the ppt has been created successfully,""" + f"""the ppt url is {ppt_url}""" + """please give the ppt url to user and direct user to download it.""" + ) def _create_task(self, type: int, title: str, content: str, user_id: str) -> str: """ @@ -119,129 +109,121 @@ class AIPPTGenerateTool(BuiltinTool): :return: the task ID """ headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-channel": "", + "x-api-key": self.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), } response = post( - str(self._api_base_url / 'ai' / 'chat' / 'v2' / 'task'), + str(self._api_base_url / "ai" / "chat" / "v2" / "task"), headers=headers, - files={ - 'type': ('', str(type)), - 'title': ('', title), - 'content': ('', content) - } + files={"type": ("", str(type)), "title": ("", title), "content": ("", content)}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to create task: {response.get("msg")}') - return response.get('data', {}).get('id') - + return response.get("data", {}).get("id") + def _generate_outline(self, task_id: str, model: str, user_id: str) -> str: - api_url = self._api_base_url / 'ai' / 'chat' / 'outline' if model == 'aippt' else \ - self._api_base_url / 'ai' / 'chat' / 'wx' / 'outline' - api_url %= {'task_id': task_id} + api_url = ( + self._api_base_url / "ai" / "chat" / "outline" + if model == "aippt" + else self._api_base_url / "ai" / "chat" / "wx" / "outline" + ) + api_url %= {"task_id": task_id} headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-channel": "", + "x-api-key": self.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), } - response = requests_get( - url=api_url, - headers=headers, - stream=True, - timeout=(10, 60) - ) + response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60)) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - - outline = '' - for chunk in response.iter_lines(delimiter=b'\n\n'): + raise Exception(f"Failed to connect to aippt: {response.text}") + + outline = "" + for chunk in response.iter_lines(delimiter=b"\n\n"): if not chunk: continue - - event = '' - lines = chunk.decode('utf-8').split('\n') + + event = "" + lines = chunk.decode("utf-8").split("\n") for line in lines: - if line.startswith('event:'): + if line.startswith("event:"): event = line[6:] - elif line.startswith('data:'): + elif line.startswith("data:"): data = line[5:] - if event == 'message': + if event == "message": try: data = json_loads(data) - outline += data.get('content', '') + outline += data.get("content", "") except Exception as e: pass - elif event == 'close': + elif event == "close": break - elif event == 'error' or event == 'filter': - raise Exception(f'Failed to generate outline: {data}') - + elif event == "error" or event == "filter": + raise Exception(f"Failed to generate outline: {data}") + return outline - + def _generate_content(self, task_id: str, model: str, user_id: str) -> str: - api_url = self._api_base_url / 'ai' / 'chat' / 'content' if model == 'aippt' else \ - self._api_base_url / 'ai' / 'chat' / 'wx' / 'content' - api_url %= {'task_id': task_id} + api_url = ( + self._api_base_url / "ai" / "chat" / "content" + if model == "aippt" + else self._api_base_url / "ai" / "chat" / "wx" / "content" + ) + api_url %= {"task_id": task_id} headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-channel": "", + "x-api-key": self.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), } - response = requests_get( - url=api_url, - headers=headers, - stream=True, - timeout=(10, 60) - ) + response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60)) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - - if model == 'aippt': - content = '' - for chunk in response.iter_lines(delimiter=b'\n\n'): + raise Exception(f"Failed to connect to aippt: {response.text}") + + if model == "aippt": + content = "" + for chunk in response.iter_lines(delimiter=b"\n\n"): if not chunk: continue - - event = '' - lines = chunk.decode('utf-8').split('\n') + + event = "" + lines = chunk.decode("utf-8").split("\n") for line in lines: - if line.startswith('event:'): + if line.startswith("event:"): event = line[6:] - elif line.startswith('data:'): + elif line.startswith("data:"): data = line[5:] - if event == 'message': + if event == "message": try: data = json_loads(data) - content += data.get('content', '') + content += data.get("content", "") except Exception as e: pass - elif event == 'close': + elif event == "close": break - elif event == 'error' or event == 'filter': - raise Exception(f'Failed to generate content: {data}') - + elif event == "error" or event == "filter": + raise Exception(f"Failed to generate content: {data}") + return content - elif model == 'wenxin': + elif model == "wenxin": response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to generate content: {response.get("msg")}') - - return response.get('data', '') - - return '' + + return response.get("data", "") + + return "" def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]: """ @@ -252,83 +234,73 @@ class AIPPTGenerateTool(BuiltinTool): :return: the cover url of the ppt and the ppt url """ headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-channel": "", + "x-api-key": self.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), } response = post( - str(self._api_base_url / 'design' / 'v2' / 'save'), + str(self._api_base_url / "design" / "v2" / "save"), headers=headers, - data={ - 'task_id': task_id, - 'template_id': suit_id - } + data={"task_id": task_id, "template_id": suit_id}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to generate ppt: {response.get("msg")}') - - id = response.get('data', {}).get('id') - cover_url = response.get('data', {}).get('cover_url') + + id = response.get("data", {}).get("id") + cover_url = response.get("data", {}).get("cover_url") response = post( - str(self._api_base_url / 'download' / 'export' / 'file'), + str(self._api_base_url / "download" / "export" / "file"), headers=headers, - data={ - 'id': id, - 'format': 'ppt', - 'files_to_zip': False, - 'edit': True - } + data={"id": id, "format": "ppt", "files_to_zip": False, "edit": True}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to generate ppt: {response.get("msg")}') - - export_code = response.get('data') + + export_code = response.get("data") if not export_code: - raise Exception('Failed to generate ppt, the export code is empty') - + raise Exception("Failed to generate ppt, the export code is empty") + current_iteration = 0 while current_iteration < 50: # get ppt url response = post( - str(self._api_base_url / 'download' / 'export' / 'file' / 'result'), + str(self._api_base_url / "download" / "export" / "file" / "result"), headers=headers, - data={ - 'task_key': export_code - } + data={"task_key": export_code}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to generate ppt: {response.get("msg")}') - - if response.get('msg') == '导出中': + + if response.get("msg") == "导出中": current_iteration += 1 sleep(2) continue - - ppt_url = response.get('data', []) + + ppt_url = response.get("data", []) if len(ppt_url) == 0: - raise Exception('Failed to generate ppt, the ppt url is empty') - + raise Exception("Failed to generate ppt, the ppt url is empty") + return cover_url, ppt_url[0] - - raise Exception('Failed to generate ppt, the export is timeout') - + + raise Exception("Failed to generate ppt, the export is timeout") + @classmethod def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str: """ @@ -337,53 +309,43 @@ class AIPPTGenerateTool(BuiltinTool): :param credentials: the credentials :return: the API token """ - access_key = credentials['aippt_access_key'] - secret_key = credentials['aippt_secret_key'] + access_key = credentials["aippt_access_key"] + secret_key = credentials["aippt_secret_key"] - cache_key = f'{access_key}#@#{user_id}' + cache_key = f"{access_key}#@#{user_id}" with cls._api_token_cache_lock: # clear expired tokens now = time() for key in list(cls._api_token_cache.keys()): - if cls._api_token_cache[key]['expire'] < now: + if cls._api_token_cache[key]["expire"] < now: del cls._api_token_cache[key] if cache_key in cls._api_token_cache: - return cls._api_token_cache[cache_key]['token'] - + return cls._api_token_cache[cache_key]["token"] + # get token headers = { - 'x-api-key': access_key, - 'x-timestamp': str(int(now)), - 'x-signature': cls._calculate_sign(access_key, secret_key, int(now)) + "x-api-key": access_key, + "x-timestamp": str(int(now)), + "x-signature": cls._calculate_sign(access_key, secret_key, int(now)), } - param = { - 'uid': user_id, - 'channel': '' - } + param = {"uid": user_id, "channel": ""} - response = get( - str(cls._api_base_url / 'grant' / 'token'), - params=param, - headers=headers - ) + response = get(str(cls._api_base_url / "grant" / "token"), params=param, headers=headers) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') + raise Exception(f"Failed to connect to aippt: {response.text}") response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to connect to aippt: {response.get("msg")}') - - token = response.get('data', {}).get('token') - expire = response.get('data', {}).get('time_expire') + + token = response.get("data", {}).get("token") + expire = response.get("data", {}).get("time_expire") with cls._api_token_cache_lock: - cls._api_token_cache[cache_key] = { - 'token': token, - 'expire': now + expire - } + cls._api_token_cache[cache_key] = {"token": token, "expire": now + expire} return token @@ -391,11 +353,9 @@ class AIPPTGenerateTool(BuiltinTool): def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str: return b64encode( hmac_new( - key=secret_key.encode('utf-8'), - msg=f'GET@/api/grant/token/@{timestamp}'.encode(), - digestmod=sha1 + key=secret_key.encode("utf-8"), msg=f"GET@/api/grant/token/@{timestamp}".encode(), digestmod=sha1 ).digest() - ).decode('utf-8') + ).decode("utf-8") @classmethod def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]: @@ -408,47 +368,46 @@ class AIPPTGenerateTool(BuiltinTool): # clear expired styles now = time() for key in list(cls._style_cache.keys()): - if cls._style_cache[key]['expire'] < now: + if cls._style_cache[key]["expire"] < now: del cls._style_cache[key] key = f'{credentials["aippt_access_key"]}#@#{user_id}' if key in cls._style_cache: - return cls._style_cache[key]['colors'], cls._style_cache[key]['styles'] + return cls._style_cache[key]["colors"], cls._style_cache[key]["styles"] headers = { - 'x-channel': '', - 'x-api-key': credentials['aippt_access_key'], - 'x-token': cls._get_api_token(credentials=credentials, user_id=user_id) + "x-channel": "", + "x-api-key": credentials["aippt_access_key"], + "x-token": cls._get_api_token(credentials=credentials, user_id=user_id), } - response = get( - str(cls._api_base_url / 'template_component' / 'suit' / 'select'), - headers=headers - ) + response = get(str(cls._api_base_url / "template_component" / "suit" / "select"), headers=headers) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to connect to aippt: {response.get("msg")}') - - colors = [{ - 'id': f'id-{item.get("id")}', - 'name': item.get('name'), - 'en_name': item.get('en_name', item.get('name')), - } for item in response.get('data', {}).get('colour') or []] - styles = [{ - 'id': f'id-{item.get("id")}', - 'name': item.get('title'), - } for item in response.get('data', {}).get('suit_style') or []] + + colors = [ + { + "id": f'id-{item.get("id")}', + "name": item.get("name"), + "en_name": item.get("en_name", item.get("name")), + } + for item in response.get("data", {}).get("colour") or [] + ] + styles = [ + { + "id": f'id-{item.get("id")}', + "name": item.get("title"), + } + for item in response.get("data", {}).get("suit_style") or [] + ] with cls._style_cache_lock: - cls._style_cache[key] = { - 'colors': colors, - 'styles': styles, - 'expire': now + 60 * 60 - } + cls._style_cache[key] = {"colors": colors, "styles": styles, "expire": now + 60 * 60} return colors, styles @@ -459,44 +418,39 @@ class AIPPTGenerateTool(BuiltinTool): :param credentials: the credentials :return: Tuple[list[dict[id, color]], list[dict[id, style]] """ - if not self.runtime.credentials.get('aippt_access_key') or not self.runtime.credentials.get('aippt_secret_key'): - raise Exception('Please provide aippt credentials') + if not self.runtime.credentials.get("aippt_access_key") or not self.runtime.credentials.get("aippt_secret_key"): + raise Exception("Please provide aippt credentials") return self._get_styles(credentials=self.runtime.credentials, user_id=user_id) - + def _get_suit(self, style_id: int, colour_id: int) -> int: """ Get suit """ headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id='__dify_system__') + "x-channel": "", + "x-api-key": self.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id="__dify_system__"), } response = get( - str(self._api_base_url / 'template_component' / 'suit' / 'search'), + str(self._api_base_url / "template_component" / "suit" / "search"), headers=headers, - params={ - 'style_id': style_id, - 'colour_id': colour_id, - 'page': 1, - 'page_size': 1 - } + params={"style_id": style_id, "colour_id": colour_id, "page": 1, "page_size": 1}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to connect to aippt: {response.get("msg")}') - - if len(response.get('data', {}).get('list') or []) > 0: - return response.get('data', {}).get('list')[0].get('id') - - raise Exception('Failed to get suit, the suit does not exist, please check the style and color') - + + if len(response.get("data", {}).get("list") or []) > 0: + return response.get("data", {}).get("list")[0].get("id") + + raise Exception("Failed to get suit, the suit does not exist, please check the style and color") + def get_runtime_parameters(self) -> list[ToolParameter]: """ Get runtime parameters @@ -504,43 +458,40 @@ class AIPPTGenerateTool(BuiltinTool): Override this method to add runtime parameters to the tool. """ try: - colors, styles = self.get_styles(user_id='__dify_system__') + colors, styles = self.get_styles(user_id="__dify_system__") except Exception as e: - colors, styles = [ - {'id': '-1', 'name': '__default__', 'en_name': '__default__'} - ], [ - {'id': '-1', 'name': '__default__', 'en_name': '__default__'} - ] + colors, styles = ( + [{"id": "-1", "name": "__default__", "en_name": "__default__"}], + [{"id": "-1", "name": "__default__", "en_name": "__default__"}], + ) return [ ToolParameter( - name='color', - label=I18nObject(zh_Hans='颜色', en_US='Color'), - human_description=I18nObject(zh_Hans='颜色', en_US='Color'), + name="color", + label=I18nObject(zh_Hans="颜色", en_US="Color"), + human_description=I18nObject(zh_Hans="颜色", en_US="Color"), type=ToolParameter.ToolParameterType.SELECT, form=ToolParameter.ToolParameterForm.FORM, required=False, - default=colors[0]['id'], + default=colors[0]["id"], options=[ ToolParameterOption( - value=color['id'], - label=I18nObject(zh_Hans=color['name'], en_US=color['en_name']) - ) for color in colors - ] + value=color["id"], label=I18nObject(zh_Hans=color["name"], en_US=color["en_name"]) + ) + for color in colors + ], ), ToolParameter( - name='style', - label=I18nObject(zh_Hans='风格', en_US='Style'), - human_description=I18nObject(zh_Hans='风格', en_US='Style'), + name="style", + label=I18nObject(zh_Hans="风格", en_US="Style"), + human_description=I18nObject(zh_Hans="风格", en_US="Style"), type=ToolParameter.ToolParameterType.SELECT, form=ToolParameter.ToolParameterForm.FORM, required=False, - default=styles[0]['id'], + default=styles[0]["id"], options=[ - ToolParameterOption( - value=style['id'], - label=I18nObject(zh_Hans=style['name'], en_US=style['name']) - ) for style in styles - ] + ToolParameterOption(value=style["id"], label=I18nObject(zh_Hans=style["name"], en_US=style["name"])) + for style in styles + ], ), - ] \ No newline at end of file + ] diff --git a/api/core/tools/provider/builtin/alphavantage/alphavantage.py b/api/core/tools/provider/builtin/alphavantage/alphavantage.py index 01f2acfb5b..a84630e5aa 100644 --- a/api/core/tools/provider/builtin/alphavantage/alphavantage.py +++ b/api/core/tools/provider/builtin/alphavantage/alphavantage.py @@ -13,7 +13,7 @@ class AlphaVantageProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "code": "AAPL", # Apple Inc. }, diff --git a/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py index 5c379b746d..d06611acd0 100644 --- a/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py +++ b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py @@ -9,17 +9,16 @@ ALPHAVANTAGE_API_URL = "https://www.alphavantage.co/query" class QueryStockTool(BuiltinTool): - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - stock_code = tool_parameters.get('code', '') + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + stock_code = tool_parameters.get("code", "") if not stock_code: - return self.create_text_message('Please tell me your stock code') + return self.create_text_message("Please tell me your stock code") - if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): return self.create_text_message("Alpha Vantage API key is required.") params = { @@ -27,7 +26,7 @@ class QueryStockTool(BuiltinTool): "symbol": stock_code, "outputsize": "compact", "datatype": "json", - "apikey": self.runtime.credentials['api_key'] + "apikey": self.runtime.credentials["api_key"], } response = requests.get(url=ALPHAVANTAGE_API_URL, params=params) response.raise_for_status() @@ -35,15 +34,15 @@ class QueryStockTool(BuiltinTool): return self.create_json_message(result) def _handle_response(self, response: dict[str, Any]) -> dict[str, Any]: - result = response.get('Time Series (Daily)', {}) + result = response.get("Time Series (Daily)", {}) if not result: return {} stock_result = {} for k, v in result.items(): stock_result[k] = {} - stock_result[k]['open'] = v.get('1. open') - stock_result[k]['high'] = v.get('2. high') - stock_result[k]['low'] = v.get('3. low') - stock_result[k]['close'] = v.get('4. close') - stock_result[k]['volume'] = v.get('5. volume') + stock_result[k]["open"] = v.get("1. open") + stock_result[k]["high"] = v.get("2. high") + stock_result[k]["low"] = v.get("3. low") + stock_result[k]["close"] = v.get("4. close") + stock_result[k]["volume"] = v.get("5. volume") return stock_result diff --git a/api/core/tools/provider/builtin/arxiv/arxiv.py b/api/core/tools/provider/builtin/arxiv/arxiv.py index 707fc69be3..ebb2d1a8c4 100644 --- a/api/core/tools/provider/builtin/arxiv/arxiv.py +++ b/api/core/tools/provider/builtin/arxiv/arxiv.py @@ -11,11 +11,10 @@ class ArxivProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "John Doe", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py index ce28373880..98d82c233e 100644 --- a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py +++ b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py @@ -8,6 +8,8 @@ from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool logger = logging.getLogger(__name__) + + class ArxivAPIWrapper(BaseModel): """Wrapper around ArxivAPI. @@ -86,11 +88,13 @@ class ArxivAPIWrapper(BaseModel): class ArxivSearchInput(BaseModel): query: str = Field(..., description="Search query.") - + + class ArxivSearchTool(BuiltinTool): """ A tool for searching articles on Arxiv. """ + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: """ Invokes the Arxiv search tool with the given user ID and tool parameters. @@ -102,13 +106,13 @@ class ArxivSearchTool(BuiltinTool): Returns: ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages. """ - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input query') - + return self.create_text_message("Please input query") + arxiv = ArxivAPIWrapper() - + response = arxiv.run(query) - + return self.create_text_message(self.summary(user_id=user_id, content=response)) diff --git a/api/core/tools/provider/builtin/aws/aws.py b/api/core/tools/provider/builtin/aws/aws.py index 13ede96015..f81b5dbd27 100644 --- a/api/core/tools/provider/builtin/aws/aws.py +++ b/api/core/tools/provider/builtin/aws/aws.py @@ -11,15 +11,14 @@ class SageMakerProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ - "sagemaker_endpoint" : "", + "sagemaker_endpoint": "", "query": "misaka mikoto", - "candidate_texts" : "hello$$$hello world", - "topk" : 5, - "aws_region" : "" + "candidate_texts": "hello$$$hello world", + "topk": 5, + "aws_region": "", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py index 06fcf8a453..d6a65b1708 100644 --- a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py +++ b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py @@ -12,6 +12,7 @@ from core.tools.tool.builtin_tool import BuiltinTool logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + class GuardrailParameters(BaseModel): guardrail_id: str = Field(..., description="The identifier of the guardrail") guardrail_version: str = Field(..., description="The version of the guardrail") @@ -19,35 +20,35 @@ class GuardrailParameters(BaseModel): text: str = Field(..., description="The text to apply the guardrail to") aws_region: str = Field(..., description="AWS region for the Bedrock client") + class ApplyGuardrailTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the ApplyGuardrail tool """ try: # Validate and parse input parameters params = GuardrailParameters(**tool_parameters) - + # Initialize AWS client - bedrock_client = boto3.client('bedrock-runtime', region_name=params.aws_region) + bedrock_client = boto3.client("bedrock-runtime", region_name=params.aws_region) # Apply guardrail response = bedrock_client.apply_guardrail( guardrailIdentifier=params.guardrail_id, guardrailVersion=params.guardrail_version, source=params.source, - content=[{"text": {"text": params.text}}] + content=[{"text": {"text": params.text}}], ) - + logger.info(f"Raw response from AWS: {json.dumps(response, indent=2)}") # Check for empty response if not response: return self.create_text_message(text="Received empty response from AWS Bedrock.") - + # Process the result action = response.get("action", "No action specified") outputs = response.get("outputs", []) @@ -58,9 +59,11 @@ class ApplyGuardrailTool(BuiltinTool): formatted_assessments = [] for assessment in assessments: for policy_type, policy_data in assessment.items(): - if isinstance(policy_data, dict) and 'topics' in policy_data: - for topic in policy_data['topics']: - formatted_assessments.append(f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']}, Action: {topic['action']}") + if isinstance(policy_data, dict) and "topics" in policy_data: + for topic in policy_data["topics"]: + formatted_assessments.append( + f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']}, Action: {topic['action']}" + ) else: formatted_assessments.append(f"Policy: {policy_type}, Data: {policy_data}") @@ -68,19 +71,19 @@ class ApplyGuardrailTool(BuiltinTool): result += f"Output: {output}\n " if formatted_assessments: result += "Assessments:\n " + "\n ".join(formatted_assessments) + "\n " -# result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}" + # result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}" return self.create_text_message(text=result) except BotoCoreError as e: - error_message = f'AWS service error: {str(e)}' + error_message = f"AWS service error: {str(e)}" logger.error(error_message, exc_info=True) return self.create_text_message(text=error_message) except json.JSONDecodeError as e: - error_message = f'JSON parsing error: {str(e)}' + error_message = f"JSON parsing error: {str(e)}" logger.error(error_message, exc_info=True) return self.create_text_message(text=error_message) except Exception as e: - error_message = f'An unexpected error occurred: {str(e)}' + error_message = f"An unexpected error occurred: {str(e)}" logger.error(error_message, exc_info=True) - return self.create_text_message(text=error_message) \ No newline at end of file + return self.create_text_message(text=error_message) diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py index 005ba3deb5..48755753ac 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py +++ b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py @@ -11,78 +11,81 @@ class LambdaTranslateUtilsTool(BuiltinTool): lambda_client: Any = None def _invoke_lambda(self, text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name): - msg = { - "src_content":text_content, - "src_lang": src_lang, - "dest_lang":dest_lang, + msg = { + "src_content": text_content, + "src_lang": src_lang, + "dest_lang": dest_lang, "dictionary_id": dictionary_name, - "request_type" : request_type, - "model_id" : model_id + "request_type": request_type, + "model_id": model_id, } - invoke_response = self.lambda_client.invoke(FunctionName=lambda_name, - InvocationType='RequestResponse', - Payload=json.dumps(msg)) - response_body = invoke_response['Payload'] + invoke_response = self.lambda_client.invoke( + FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg) + ) + response_body = invoke_response["Payload"] response_str = response_body.read().decode("unicode_escape") return response_str - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ line = 0 try: if not self.lambda_client: - aws_region = tool_parameters.get('aws_region') + aws_region = tool_parameters.get("aws_region") if aws_region: self.lambda_client = boto3.client("lambda", region_name=aws_region) else: self.lambda_client = boto3.client("lambda") line = 1 - text_content = tool_parameters.get('text_content', '') + text_content = tool_parameters.get("text_content", "") if not text_content: - return self.create_text_message('Please input text_content') - + return self.create_text_message("Please input text_content") + line = 2 - src_lang = tool_parameters.get('src_lang', '') + src_lang = tool_parameters.get("src_lang", "") if not src_lang: - return self.create_text_message('Please input src_lang') - + return self.create_text_message("Please input src_lang") + line = 3 - dest_lang = tool_parameters.get('dest_lang', '') + dest_lang = tool_parameters.get("dest_lang", "") if not dest_lang: - return self.create_text_message('Please input dest_lang') - + return self.create_text_message("Please input dest_lang") + line = 4 - lambda_name = tool_parameters.get('lambda_name', '') + lambda_name = tool_parameters.get("lambda_name", "") if not lambda_name: - return self.create_text_message('Please input lambda_name') - + return self.create_text_message("Please input lambda_name") + line = 5 - request_type = tool_parameters.get('request_type', '') + request_type = tool_parameters.get("request_type", "") if not request_type: - return self.create_text_message('Please input request_type') - + return self.create_text_message("Please input request_type") + line = 6 - model_id = tool_parameters.get('model_id', '') + model_id = tool_parameters.get("model_id", "") if not model_id: - return self.create_text_message('Please input model_id') + return self.create_text_message("Please input model_id") line = 7 - dictionary_name = tool_parameters.get('dictionary_name', '') + dictionary_name = tool_parameters.get("dictionary_name", "") if not dictionary_name: - return self.create_text_message('Please input dictionary_name') - - result = self._invoke_lambda(text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name) + return self.create_text_message("Please input dictionary_name") + + result = self._invoke_lambda( + text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name + ) return self.create_text_message(text=result) except Exception as e: - return self.create_text_message(f'Exception {str(e)}, line : {line}') + return self.create_text_message(f"Exception {str(e)}, line : {line}") diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py index bb7f6840b8..f43f3b6fe0 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py +++ b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py @@ -18,54 +18,53 @@ class LambdaYamlToJsonTool(BuiltinTool): lambda_client: Any = None def _invoke_lambda(self, lambda_name: str, yaml_content: str) -> str: - msg = { - "body": yaml_content - } + msg = {"body": yaml_content} logger.info(json.dumps(msg)) - invoke_response = self.lambda_client.invoke(FunctionName=lambda_name, - InvocationType='RequestResponse', - Payload=json.dumps(msg)) - response_body = invoke_response['Payload'] + invoke_response = self.lambda_client.invoke( + FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg) + ) + response_body = invoke_response["Payload"] response_str = response_body.read().decode("utf-8") resp_json = json.loads(response_str) logger.info(resp_json) - if resp_json['statusCode'] != 200: + if resp_json["statusCode"] != 200: raise Exception(f"Invalid status code: {response_str}") - return resp_json['body'] + return resp_json["body"] - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ try: if not self.lambda_client: - aws_region = tool_parameters.get('aws_region') # todo: move aws_region out, and update client region + aws_region = tool_parameters.get("aws_region") # todo: move aws_region out, and update client region if aws_region: self.lambda_client = boto3.client("lambda", region_name=aws_region) else: self.lambda_client = boto3.client("lambda") - yaml_content = tool_parameters.get('yaml_content', '') + yaml_content = tool_parameters.get("yaml_content", "") if not yaml_content: - return self.create_text_message('Please input yaml_content') + return self.create_text_message("Please input yaml_content") - lambda_name = tool_parameters.get('lambda_name', '') + lambda_name = tool_parameters.get("lambda_name", "") if not lambda_name: - return self.create_text_message('Please input lambda_name') - logger.debug(f'{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}') - + return self.create_text_message("Please input lambda_name") + logger.debug(f"{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}") + result = self._invoke_lambda(lambda_name, yaml_content) logger.debug(result) - + return self.create_text_message(result) except Exception as e: - return self.create_text_message(f'Exception: {str(e)}') + return self.create_text_message(f"Exception: {str(e)}") - console_handler.flush() \ No newline at end of file + console_handler.flush() diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py index 2b3a3eaad6..3c35b65e66 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py @@ -9,37 +9,33 @@ from core.tools.tool.builtin_tool import BuiltinTool class SageMakerReRankTool(BuiltinTool): sagemaker_client: Any = None - sagemaker_endpoint:str = None - topk:int = None + sagemaker_endpoint: str = None + topk: int = None - def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str): - inputs = [query_input]*len(docs) + def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str): + inputs = [query_input] * len(docs) response_model = self.sagemaker_client.invoke_endpoint( EndpointName=rerank_endpoint, - Body=json.dumps( - { - "inputs": inputs, - "docs": docs - } - ), + Body=json.dumps({"inputs": inputs, "docs": docs}), ContentType="application/json", ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) - scores = json_obj['scores'] + scores = json_obj["scores"] return scores if isinstance(scores, list) else [scores] - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ line = 0 try: if not self.sagemaker_client: - aws_region = tool_parameters.get('aws_region') + aws_region = tool_parameters.get("aws_region") if aws_region: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) else: @@ -47,25 +43,25 @@ class SageMakerReRankTool(BuiltinTool): line = 1 if not self.sagemaker_endpoint: - self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint') + self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint") line = 2 if not self.topk: - self.topk = tool_parameters.get('topk', 5) + self.topk = tool_parameters.get("topk", 5) line = 3 - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input query') - + return self.create_text_message("Please input query") + line = 4 - candidate_texts = tool_parameters.get('candidate_texts') + candidate_texts = tool_parameters.get("candidate_texts") if not candidate_texts: - return self.create_text_message('Please input candidate_texts') - + return self.create_text_message("Please input candidate_texts") + line = 5 candidate_docs = json.loads(candidate_texts) - docs = [ item.get('content') for item in candidate_docs ] + docs = [item.get("content") for item in candidate_docs] line = 6 scores = self._sagemaker_rerank(query_input=query, docs=docs, rerank_endpoint=self.sagemaker_endpoint) @@ -75,10 +71,10 @@ class SageMakerReRankTool(BuiltinTool): candidate_docs[idx]["score"] = scores[idx] line = 8 - sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x['score'], reverse=True) + sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x["score"], reverse=True) line = 9 - return [ self.create_json_message(res) for res in sorted_candidate_docs[:self.topk] ] - + return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]] + except Exception as e: - return self.create_text_message(f'Exception {str(e)}, line : {line}') \ No newline at end of file + return self.create_text_message(f"Exception {str(e)}, line : {line}") diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py index a100e62230..bceeaab745 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py @@ -14,82 +14,88 @@ class TTSModelType(Enum): CloneVoice_CrossLingual = "CloneVoice_CrossLingual" InstructVoice = "InstructVoice" + class SageMakerTTSTool(BuiltinTool): sagemaker_client: Any = None - sagemaker_endpoint:str = None - s3_client : Any = None - comprehend_client : Any = None + sagemaker_endpoint: str = None + s3_client: Any = None + comprehend_client: Any = None - def _detect_lang_code(self, content:str, map_dict:dict=None): - map_dict = { - "zh" : "<|zh|>", - "en" : "<|en|>", - "ja" : "<|jp|>", - "zh-TW" : "<|yue|>", - "ko" : "<|ko|>" - } + def _detect_lang_code(self, content: str, map_dict: dict = None): + map_dict = {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"} response = self.comprehend_client.detect_dominant_language(Text=content) - language_code = response['Languages'][0]['LanguageCode'] - return map_dict.get(language_code, '<|zh|>') + language_code = response["Languages"][0]["LanguageCode"] + return map_dict.get(language_code, "<|zh|>") - def _build_tts_payload(self, model_type:str, content_text:str, model_role:str, prompt_text:str, prompt_audio:str, instruct_text:str): + def _build_tts_payload( + self, + model_type: str, + content_text: str, + model_role: str, + prompt_text: str, + prompt_audio: str, + instruct_text: str, + ): if model_type == TTSModelType.PresetVoice.value and model_role: - return { "tts_text" : content_text, "role" : model_role } + return {"tts_text": content_text, "role": model_role} if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio: - return { "tts_text" : content_text, "prompt_text": prompt_text, "prompt_audio" : prompt_audio } - if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio: + return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio} + if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio: lang_tag = self._detect_lang_code(content_text) - return { "tts_text" : f"{content_text}", "prompt_audio" : prompt_audio, "lang_tag" : lang_tag } - if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role: - return { "tts_text" : content_text, "role" : model_role, "instruct_text" : instruct_text } + return {"tts_text": f"{content_text}", "prompt_audio": prompt_audio, "lang_tag": lang_tag} + if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role: + return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text} raise RuntimeError(f"Invalid params for {model_type}") - def _invoke_sagemaker(self, payload:dict, endpoint:str): + def _invoke_sagemaker(self, payload: dict, endpoint: str): response_model = self.sagemaker_client.invoke_endpoint( EndpointName=endpoint, Body=json.dumps(payload), ContentType="application/json", ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) return json_obj - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ try: if not self.sagemaker_client: - aws_region = tool_parameters.get('aws_region') + aws_region = tool_parameters.get("aws_region") if aws_region: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) self.s3_client = boto3.client("s3", region_name=aws_region) - self.comprehend_client = boto3.client('comprehend', region_name=aws_region) + self.comprehend_client = boto3.client("comprehend", region_name=aws_region) else: self.sagemaker_client = boto3.client("sagemaker-runtime") self.s3_client = boto3.client("s3") - self.comprehend_client = boto3.client('comprehend') + self.comprehend_client = boto3.client("comprehend") if not self.sagemaker_endpoint: - self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint') + self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint") - tts_text = tool_parameters.get('tts_text') - tts_infer_type = tool_parameters.get('tts_infer_type') + tts_text = tool_parameters.get("tts_text") + tts_infer_type = tool_parameters.get("tts_infer_type") - voice = tool_parameters.get('voice') - mock_voice_audio = tool_parameters.get('mock_voice_audio') - mock_voice_text = tool_parameters.get('mock_voice_text') - voice_instruct_prompt = tool_parameters.get('voice_instruct_prompt') - payload = self._build_tts_payload(tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt) + voice = tool_parameters.get("voice") + mock_voice_audio = tool_parameters.get("mock_voice_audio") + mock_voice_text = tool_parameters.get("mock_voice_text") + voice_instruct_prompt = tool_parameters.get("voice_instruct_prompt") + payload = self._build_tts_payload( + tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt + ) result = self._invoke_sagemaker(payload, self.sagemaker_endpoint) - return self.create_text_message(text=result['s3_presign_url']) - + return self.create_text_message(text=result["s3_presign_url"]) + except Exception as e: - return self.create_text_message(f'Exception {str(e)}') \ No newline at end of file + return self.create_text_message(f"Exception {str(e)}") diff --git a/api/core/tools/provider/builtin/azuredalle/azuredalle.py b/api/core/tools/provider/builtin/azuredalle/azuredalle.py index 2981a54d3c..1fab0d03a2 100644 --- a/api/core/tools/provider/builtin/azuredalle/azuredalle.py +++ b/api/core/tools/provider/builtin/azuredalle/azuredalle.py @@ -13,12 +13,8 @@ class AzureDALLEProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "prompt": "cute girl, blue eyes, white hair, anime style", - "size": "square", - "n": 1 - }, + user_id="", + tool_parameters={"prompt": "cute girl, blue eyes, white hair, anime style", "size": "square", "n": 1}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py index 2ffdd38b72..09f30a59d6 100644 --- a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py @@ -9,47 +9,48 @@ from core.tools.tool.builtin_tool import BuiltinTool class DallE3Tool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ client = AzureOpenAI( - api_version=self.runtime.credentials['azure_openai_api_version'], - azure_endpoint=self.runtime.credentials['azure_openai_base_url'], - api_key=self.runtime.credentials['azure_openai_api_key'], + api_version=self.runtime.credentials["azure_openai_api_version"], + azure_endpoint=self.runtime.credentials["azure_openai_base_url"], + api_key=self.runtime.credentials["azure_openai_api_key"], ) SIZE_MAPPING = { - 'square': '1024x1024', - 'vertical': '1024x1792', - 'horizontal': '1792x1024', + "square": "1024x1024", + "vertical": "1024x1792", + "horizontal": "1792x1024", } # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') + return self.create_text_message("Please input prompt") # get size - size = SIZE_MAPPING[tool_parameters.get('size', 'square')] + size = SIZE_MAPPING[tool_parameters.get("size", "square")] # get n - n = tool_parameters.get('n', 1) + n = tool_parameters.get("n", 1) # get quality - quality = tool_parameters.get('quality', 'standard') - if quality not in ['standard', 'hd']: - return self.create_text_message('Invalid quality') + quality = tool_parameters.get("quality", "standard") + if quality not in ["standard", "hd"]: + return self.create_text_message("Invalid quality") # get style - style = tool_parameters.get('style', 'vivid') - if style not in ['natural', 'vivid']: - return self.create_text_message('Invalid style') + style = tool_parameters.get("style", "vivid") + if style not in ["natural", "vivid"]: + return self.create_text_message("Invalid style") # set extra body - seed_id = tool_parameters.get('seed_id', self._generate_random_id(8)) - extra_body = {'seed': seed_id} + seed_id = tool_parameters.get("seed_id", self._generate_random_id(8)) + extra_body = {"seed": seed_id} # call openapi dalle3 - model = self.runtime.credentials['azure_openai_api_model_name'] + model = self.runtime.credentials["azure_openai_api_model_name"] response = client.images.generate( prompt=prompt, model=model, @@ -58,21 +59,25 @@ class DallE3Tool(BuiltinTool): extra_body=extra_body, style=style, quality=quality, - response_format='b64_json' + response_format="b64_json", ) result = [] for image in response.data: - result.append(self.create_blob_message(blob=b64decode(image.b64_json), - meta={'mime_type': 'image/png'}, - save_as=self.VARIABLE_KEY.IMAGE.value)) - result.append(self.create_text_message(f'\nGenerate image source to Seed ID: {seed_id}')) + result.append( + self.create_blob_message( + blob=b64decode(image.b64_json), + meta={"mime_type": "image/png"}, + save_as=self.VARIABLE_KEY.IMAGE.value, + ) + ) + result.append(self.create_text_message(f"\nGenerate image source to Seed ID: {seed_id}")) return result @staticmethod def _generate_random_id(length=8): - characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' - random_id = ''.join(random.choices(characters, k=length)) + characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + random_id = "".join(random.choices(characters, k=length)) return random_id diff --git a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py index f85a5ed472..0d9613c0cf 100644 --- a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py +++ b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py @@ -8,142 +8,135 @@ from core.tools.tool.builtin_tool import BuiltinTool class BingSearchTool(BuiltinTool): - url: str = 'https://api.bing.microsoft.com/v7.0/search' + url: str = "https://api.bing.microsoft.com/v7.0/search" - def _invoke_bing(self, - user_id: str, - server_url: str, - subscription_key: str, query: str, limit: int, - result_type: str, market: str, lang: str, - filters: list[str]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke_bing( + self, + user_id: str, + server_url: str, + subscription_key: str, + query: str, + limit: int, + result_type: str, + market: str, + lang: str, + filters: list[str], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke bing search + invoke bing search """ - market_code = f'{lang}-{market}' - accept_language = f'{lang},{market_code};q=0.9' - headers = { - 'Ocp-Apim-Subscription-Key': subscription_key, - 'Accept-Language': accept_language - } + market_code = f"{lang}-{market}" + accept_language = f"{lang},{market_code};q=0.9" + headers = {"Ocp-Apim-Subscription-Key": subscription_key, "Accept-Language": accept_language} query = quote(query) server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}' response = get(server_url, headers=headers) if response.status_code != 200: - raise Exception(f'Error {response.status_code}: {response.text}') - - response = response.json() - search_results = response['webPages']['value'][:limit] if 'webPages' in response else [] - related_searches = response['relatedSearches']['value'] if 'relatedSearches' in response else [] - entities = response['entities']['value'] if 'entities' in response else [] - news = response['news']['value'] if 'news' in response else [] - computation = response['computation']['value'] if 'computation' in response else None + raise Exception(f"Error {response.status_code}: {response.text}") - if result_type == 'link': + response = response.json() + search_results = response["webPages"]["value"][:limit] if "webPages" in response else [] + related_searches = response["relatedSearches"]["value"] if "relatedSearches" in response else [] + entities = response["entities"]["value"] if "entities" in response else [] + news = response["news"]["value"] if "news" in response else [] + computation = response["computation"]["value"] if "computation" in response else None + + if result_type == "link": results = [] if search_results: for result in search_results: url = f': {result["url"]}' if "url" in result else "" - results.append(self.create_text_message( - text=f'{result["name"]}{url}' - )) - + results.append(self.create_text_message(text=f'{result["name"]}{url}')) if entities: for entity in entities: url = f': {entity["url"]}' if "url" in entity else "" - results.append(self.create_text_message( - text=f'{entity.get("name", "")}{url}' - )) + results.append(self.create_text_message(text=f'{entity.get("name", "")}{url}')) if news: for news_item in news: url = f': {news_item["url"]}' if "url" in news_item else "" - results.append(self.create_text_message( - text=f'{news_item.get("name", "")}{url}' - )) + results.append(self.create_text_message(text=f'{news_item.get("name", "")}{url}')) if related_searches: for related in related_searches: url = f': {related["displayText"]}' if "displayText" in related else "" - results.append(self.create_text_message( - text=f'{related.get("displayText", "")}{url}' - )) - + results.append(self.create_text_message(text=f'{related.get("displayText", "")}{url}')) + return results else: # construct text - text = '' + text = "" if search_results: for i, result in enumerate(search_results): text += f'{i+1}: {result.get("name", "")} - {result.get("snippet", "")}\n' - if computation and 'expression' in computation and 'value' in computation: - text += '\nComputation:\n' + if computation and "expression" in computation and "value" in computation: + text += "\nComputation:\n" text += f'{computation["expression"]} = {computation["value"]}\n' if entities: - text += '\nEntities:\n' + text += "\nEntities:\n" for entity in entities: url = f'- {entity["url"]}' if "url" in entity else "" text += f'{entity.get("name", "")}{url}\n' if news: - text += '\nNews:\n' + text += "\nNews:\n" for news_item in news: url = f'- {news_item["url"]}' if "url" in news_item else "" text += f'{news_item.get("name", "")}{url}\n' if related_searches: - text += '\n\nRelated Searches:\n' + text += "\n\nRelated Searches:\n" for related in related_searches: url = f'- {related["webSearchUrl"]}' if "webSearchUrl" in related else "" text += f'{related.get("displayText", "")}{url}\n' return self.create_text_message(text=self.summary(user_id=user_id, content=text)) - def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None: - key = credentials.get('subscription_key') + key = credentials.get("subscription_key") if not key: - raise Exception('subscription_key is required') - - server_url = credentials.get('server_url') + raise Exception("subscription_key is required") + + server_url = credentials.get("server_url") if not server_url: server_url = self.url - query = tool_parameters.get('query') + query = tool_parameters.get("query") if not query: - raise Exception('query is required') - - limit = min(tool_parameters.get('limit', 5), 10) - result_type = tool_parameters.get('result_type', 'text') or 'text' + raise Exception("query is required") - market = tool_parameters.get('market', 'US') - lang = tool_parameters.get('language', 'en') + limit = min(tool_parameters.get("limit", 5), 10) + result_type = tool_parameters.get("result_type", "text") or "text" + + market = tool_parameters.get("market", "US") + lang = tool_parameters.get("language", "en") filter = [] - if credentials.get('allow_entities', False): - filter.append('Entities') + if credentials.get("allow_entities", False): + filter.append("Entities") - if credentials.get('allow_computation', False): - filter.append('Computation') + if credentials.get("allow_computation", False): + filter.append("Computation") - if credentials.get('allow_news', False): - filter.append('News') + if credentials.get("allow_news", False): + filter.append("News") - if credentials.get('allow_related_searches', False): - filter.append('RelatedSearches') + if credentials.get("allow_related_searches", False): + filter.append("RelatedSearches") - if credentials.get('allow_web_pages', False): - filter.append('WebPages') + if credentials.get("allow_web_pages", False): + filter.append("WebPages") if not filter: - raise Exception('At least one filter is required') - + raise Exception("At least one filter is required") + self._invoke_bing( - user_id='test', + user_id="test", server_url=server_url, subscription_key=key, query=query, @@ -151,50 +144,51 @@ class BingSearchTool(BuiltinTool): result_type=result_type, market=market, lang=lang, - filters=filter + filters=filter, ) - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - key = self.runtime.credentials.get('subscription_key', None) + key = self.runtime.credentials.get("subscription_key", None) if not key: - raise Exception('subscription_key is required') - - server_url = self.runtime.credentials.get('server_url', None) + raise Exception("subscription_key is required") + + server_url = self.runtime.credentials.get("server_url", None) if not server_url: server_url = self.url - - query = tool_parameters.get('query') + + query = tool_parameters.get("query") if not query: - raise Exception('query is required') - - limit = min(tool_parameters.get('limit', 5), 10) - result_type = tool_parameters.get('result_type', 'text') or 'text' - - market = tool_parameters.get('market', 'US') - lang = tool_parameters.get('language', 'en') + raise Exception("query is required") + + limit = min(tool_parameters.get("limit", 5), 10) + result_type = tool_parameters.get("result_type", "text") or "text" + + market = tool_parameters.get("market", "US") + lang = tool_parameters.get("language", "en") filter = [] - if tool_parameters.get('enable_computation', False): - filter.append('Computation') - if tool_parameters.get('enable_entities', False): - filter.append('Entities') - if tool_parameters.get('enable_news', False): - filter.append('News') - if tool_parameters.get('enable_related_search', False): - filter.append('RelatedSearches') - if tool_parameters.get('enable_webpages', False): - filter.append('WebPages') + if tool_parameters.get("enable_computation", False): + filter.append("Computation") + if tool_parameters.get("enable_entities", False): + filter.append("Entities") + if tool_parameters.get("enable_news", False): + filter.append("News") + if tool_parameters.get("enable_related_search", False): + filter.append("RelatedSearches") + if tool_parameters.get("enable_webpages", False): + filter.append("WebPages") if not filter: - raise Exception('At least one filter is required') - + raise Exception("At least one filter is required") + return self._invoke_bing( user_id=user_id, server_url=server_url, @@ -204,5 +198,5 @@ class BingSearchTool(BuiltinTool): result_type=result_type, market=market, lang=lang, - filters=filter - ) \ No newline at end of file + filters=filter, + ) diff --git a/api/core/tools/provider/builtin/brave/brave.py b/api/core/tools/provider/builtin/brave/brave.py index e5eada80ee..c24ee67334 100644 --- a/api/core/tools/provider/builtin/brave/brave.py +++ b/api/core/tools/provider/builtin/brave/brave.py @@ -13,11 +13,10 @@ class BraveProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "Sachin Tendulkar", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/brave/tools/brave_search.py b/api/core/tools/provider/builtin/brave/tools/brave_search.py index 21cbf2c7da..94a4d92844 100644 --- a/api/core/tools/provider/builtin/brave/tools/brave_search.py +++ b/api/core/tools/provider/builtin/brave/tools/brave_search.py @@ -37,7 +37,7 @@ class BraveSearchWrapper(BaseModel): for item in web_search_results ] return json.dumps(final_results) - + def _search_request(self, query: str) -> list[dict]: headers = { "X-Subscription-Token": self.api_key, @@ -55,6 +55,7 @@ class BraveSearchWrapper(BaseModel): return response.json().get("web", {}).get("results", []) + class BraveSearch(BaseModel): """Tool that queries the BraveSearch.""" @@ -67,9 +68,7 @@ class BraveSearch(BaseModel): search_wrapper: BraveSearchWrapper @classmethod - def from_api_key( - cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any - ) -> "BraveSearch": + def from_api_key(cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any) -> "BraveSearch": """Create a tool from an api key. Args: @@ -90,6 +89,7 @@ class BraveSearch(BaseModel): """Use the tool.""" return self.search_wrapper.run(query) + class BraveSearchTool(BuiltinTool): """ Tool for performing a search using Brave search engine. @@ -106,12 +106,12 @@ class BraveSearchTool(BuiltinTool): Returns: ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation. """ - query = tool_parameters.get('query', '') - count = tool_parameters.get('count', 3) - api_key = self.runtime.credentials['brave_search_api_key'] + query = tool_parameters.get("query", "") + count = tool_parameters.get("count", 3) + api_key = self.runtime.credentials["brave_search_api_key"] if not query: - return self.create_text_message('Please input query') + return self.create_text_message("Please input query") tool = BraveSearch.from_api_key(api_key=api_key, search_kwargs={"count": count}) @@ -121,4 +121,3 @@ class BraveSearchTool(BuiltinTool): return self.create_text_message(f"No results found for '{query}' in Tavily") else: return self.create_text_message(text=results) - diff --git a/api/core/tools/provider/builtin/chart/chart.py b/api/core/tools/provider/builtin/chart/chart.py index 0865bc700a..8a24d33428 100644 --- a/api/core/tools/provider/builtin/chart/chart.py +++ b/api/core/tools/provider/builtin/chart/chart.py @@ -7,16 +7,34 @@ from core.tools.provider.builtin.chart.tools.line import LinearChartTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController # use a business theme -plt.style.use('seaborn-v0_8-darkgrid') -plt.rcParams['axes.unicode_minus'] = False +plt.style.use("seaborn-v0_8-darkgrid") +plt.rcParams["axes.unicode_minus"] = False + def init_fonts(): fonts = findSystemFonts() popular_unicode_fonts = [ - 'Arial Unicode MS', 'DejaVu Sans', 'DejaVu Sans Mono', 'DejaVu Serif', 'FreeMono', 'FreeSans', 'FreeSerif', - 'Liberation Mono', 'Liberation Sans', 'Liberation Serif', 'Noto Mono', 'Noto Sans', 'Noto Serif', 'Open Sans', - 'Roboto', 'Source Code Pro', 'Source Sans Pro', 'Source Serif Pro', 'Ubuntu', 'Ubuntu Mono' + "Arial Unicode MS", + "DejaVu Sans", + "DejaVu Sans Mono", + "DejaVu Serif", + "FreeMono", + "FreeSans", + "FreeSerif", + "Liberation Mono", + "Liberation Sans", + "Liberation Serif", + "Noto Mono", + "Noto Sans", + "Noto Serif", + "Open Sans", + "Roboto", + "Source Code Pro", + "Source Sans Pro", + "Source Serif Pro", + "Ubuntu", + "Ubuntu Mono", ] supported_fonts = [] @@ -25,21 +43,23 @@ def init_fonts(): try: font = TTFont(font_path) # get family name - family_name = font['name'].getName(1, 3, 1).toUnicode() + family_name = font["name"].getName(1, 3, 1).toUnicode() if family_name in popular_unicode_fonts: supported_fonts.append(family_name) except: pass - plt.rcParams['font.family'] = 'sans-serif' + plt.rcParams["font.family"] = "sans-serif" # sort by order of popular_unicode_fonts for font in popular_unicode_fonts: if font in supported_fonts: - plt.rcParams['font.sans-serif'] = font + plt.rcParams["font.sans-serif"] = font break - + + init_fonts() + class ChartProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: @@ -48,11 +68,10 @@ class ChartProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "data": "1,3,5,7,9,2,4,6,8,10", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/chart/tools/bar.py b/api/core/tools/provider/builtin/chart/tools/bar.py index 749ec761c6..3a47c0cfc0 100644 --- a/api/core/tools/provider/builtin/chart/tools/bar.py +++ b/api/core/tools/provider/builtin/chart/tools/bar.py @@ -8,12 +8,13 @@ from core.tools.tool.builtin_tool import BuiltinTool class BarChartTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - data = tool_parameters.get('data', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + data = tool_parameters.get("data", "") if not data: - return self.create_text_message('Please input data') - data = data.split(';') + return self.create_text_message("Please input data") + data = data.split(";") # if all data is int, convert to int if all(i.isdigit() for i in data): @@ -21,29 +22,27 @@ class BarChartTool(BuiltinTool): else: data = [float(i) for i in data] - axis = tool_parameters.get('x_axis') or None + axis = tool_parameters.get("x_axis") or None if axis: - axis = axis.split(';') + axis = axis.split(";") if len(axis) != len(data): axis = None flg, ax = plt.subplots(figsize=(10, 8)) if axis: - axis = [label[:10] + '...' if len(label) > 10 else label for label in axis] - ax.set_xticklabels(axis, rotation=45, ha='right') + axis = [label[:10] + "..." if len(label) > 10 else label for label in axis] + ax.set_xticklabels(axis, rotation=45, ha="right") ax.bar(axis, data) else: ax.bar(range(len(data)), data) buf = io.BytesIO() - flg.savefig(buf, format='png') + flg.savefig(buf, format="png") buf.seek(0) plt.close(flg) return [ - self.create_text_message('the bar chart is saved as an image.'), - self.create_blob_message(blob=buf.read(), - meta={'mime_type': 'image/png'}) + self.create_text_message("the bar chart is saved as an image."), + self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}), ] - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/chart/tools/line.py b/api/core/tools/provider/builtin/chart/tools/line.py index 608bd6623c..39e8caac7e 100644 --- a/api/core/tools/provider/builtin/chart/tools/line.py +++ b/api/core/tools/provider/builtin/chart/tools/line.py @@ -8,18 +8,19 @@ from core.tools.tool.builtin_tool import BuiltinTool class LinearChartTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - data = tool_parameters.get('data', '') + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + data = tool_parameters.get("data", "") if not data: - return self.create_text_message('Please input data') - data = data.split(';') + return self.create_text_message("Please input data") + data = data.split(";") - axis = tool_parameters.get('x_axis') or None + axis = tool_parameters.get("x_axis") or None if axis: - axis = axis.split(';') + axis = axis.split(";") if len(axis) != len(data): axis = None @@ -32,20 +33,18 @@ class LinearChartTool(BuiltinTool): flg, ax = plt.subplots(figsize=(10, 8)) if axis: - axis = [label[:10] + '...' if len(label) > 10 else label for label in axis] - ax.set_xticklabels(axis, rotation=45, ha='right') + axis = [label[:10] + "..." if len(label) > 10 else label for label in axis] + ax.set_xticklabels(axis, rotation=45, ha="right") ax.plot(axis, data) else: ax.plot(data) buf = io.BytesIO() - flg.savefig(buf, format='png') + flg.savefig(buf, format="png") buf.seek(0) plt.close(flg) return [ - self.create_text_message('the linear chart is saved as an image.'), - self.create_blob_message(blob=buf.read(), - meta={'mime_type': 'image/png'}) + self.create_text_message("the linear chart is saved as an image."), + self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}), ] - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/chart/tools/pie.py b/api/core/tools/provider/builtin/chart/tools/pie.py index 4c551229e9..2c3b8a733e 100644 --- a/api/core/tools/provider/builtin/chart/tools/pie.py +++ b/api/core/tools/provider/builtin/chart/tools/pie.py @@ -8,15 +8,16 @@ from core.tools.tool.builtin_tool import BuiltinTool class PieChartTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - data = tool_parameters.get('data', '') + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + data = tool_parameters.get("data", "") if not data: - return self.create_text_message('Please input data') - data = data.split(';') - categories = tool_parameters.get('categories') or None + return self.create_text_message("Please input data") + data = data.split(";") + categories = tool_parameters.get("categories") or None # if all data is int, convert to int if all(i.isdigit() for i in data): @@ -27,7 +28,7 @@ class PieChartTool(BuiltinTool): flg, ax = plt.subplots() if categories: - categories = categories.split(';') + categories = categories.split(";") if len(categories) != len(data): categories = None @@ -37,12 +38,11 @@ class PieChartTool(BuiltinTool): ax.pie(data) buf = io.BytesIO() - flg.savefig(buf, format='png') + flg.savefig(buf, format="png") buf.seek(0) plt.close(flg) return [ - self.create_text_message('the pie chart is saved as an image.'), - self.create_blob_message(blob=buf.read(), - meta={'mime_type': 'image/png'}) - ] \ No newline at end of file + self.create_text_message("the pie chart is saved as an image."), + self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}), + ] diff --git a/api/core/tools/provider/builtin/code/tools/simple_code.py b/api/core/tools/provider/builtin/code/tools/simple_code.py index 37645bf0d0..017fe548f7 100644 --- a/api/core/tools/provider/builtin/code/tools/simple_code.py +++ b/api/core/tools/provider/builtin/code/tools/simple_code.py @@ -8,15 +8,15 @@ from core.tools.tool.builtin_tool import BuiltinTool class SimpleCode(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: """ - invoke simple code + invoke simple code """ - language = tool_parameters.get('language', CodeLanguage.PYTHON3) - code = tool_parameters.get('code', '') + language = tool_parameters.get("language", CodeLanguage.PYTHON3) + code = tool_parameters.get("code", "") if language not in [CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]: - raise ValueError(f'Only python3 and javascript are supported, not {language}') - - result = CodeExecutor.execute_code(language, '', code) + raise ValueError(f"Only python3 and javascript are supported, not {language}") - return self.create_text_message(result) \ No newline at end of file + result = CodeExecutor.execute_code(language, "", code) + + return self.create_text_message(result) diff --git a/api/core/tools/provider/builtin/cogview/cogview.py b/api/core/tools/provider/builtin/cogview/cogview.py index 801817ec06..6941ce8649 100644 --- a/api/core/tools/provider/builtin/cogview/cogview.py +++ b/api/core/tools/provider/builtin/cogview/cogview.py @@ -1,4 +1,5 @@ -""" Provide the input parameters type for the cogview provider class """ +"""Provide the input parameters type for the cogview provider class""" + from typing import Any from core.tools.errors import ToolProviderCredentialValidationError @@ -7,7 +8,8 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class COGVIEWProvider(BuiltinToolProviderController): - """ cogview provider """ + """cogview provider""" + def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: CogView3Tool().fork_tool_runtime( @@ -15,13 +17,12 @@ class COGVIEWProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "prompt": "一个城市在水晶瓶中欢快生活的场景,水彩画风格,展现出微观与珠宝般的美丽。", "size": "square", - "n": 1 + "n": 1, }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) from e - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.py b/api/core/tools/provider/builtin/cogview/tools/cogview3.py index 89ffcf3347..9776bd7dd1 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogview3.py +++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.py @@ -7,43 +7,42 @@ from core.tools.tool.builtin_tool import BuiltinTool class CogView3Tool(BuiltinTool): - """ CogView3 Tool """ + """CogView3 Tool""" - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke CogView3 tool """ client = ZhipuAI( - base_url=self.runtime.credentials['zhipuai_base_url'], - api_key=self.runtime.credentials['zhipuai_api_key'], + base_url=self.runtime.credentials["zhipuai_base_url"], + api_key=self.runtime.credentials["zhipuai_api_key"], ) size_mapping = { - 'square': '1024x1024', - 'vertical': '1024x1792', - 'horizontal': '1792x1024', + "square": "1024x1024", + "vertical": "1024x1792", + "horizontal": "1792x1024", } # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') + return self.create_text_message("Please input prompt") # get size - size = size_mapping[tool_parameters.get('size', 'square')] + size = size_mapping[tool_parameters.get("size", "square")] # get n - n = tool_parameters.get('n', 1) + n = tool_parameters.get("n", 1) # get quality - quality = tool_parameters.get('quality', 'standard') - if quality not in ['standard', 'hd']: - return self.create_text_message('Invalid quality') + quality = tool_parameters.get("quality", "standard") + if quality not in ["standard", "hd"]: + return self.create_text_message("Invalid quality") # get style - style = tool_parameters.get('style', 'vivid') - if style not in ['natural', 'vivid']: - return self.create_text_message('Invalid style') + style = tool_parameters.get("style", "vivid") + if style not in ["natural", "vivid"]: + return self.create_text_message("Invalid style") # set extra body - seed_id = tool_parameters.get('seed_id', self._generate_random_id(8)) - extra_body = {'seed': seed_id} + seed_id = tool_parameters.get("seed_id", self._generate_random_id(8)) + extra_body = {"seed": seed_id} response = client.images.generations( prompt=prompt, model="cogview-3", @@ -52,18 +51,22 @@ class CogView3Tool(BuiltinTool): extra_body=extra_body, style=style, quality=quality, - response_format='b64_json' + response_format="b64_json", ) result = [] for image in response.data: result.append(self.create_image_message(image=image.url)) - result.append(self.create_json_message({ - "url": image.url, - })) + result.append( + self.create_json_message( + { + "url": image.url, + } + ) + ) return result @staticmethod def _generate_random_id(length=8): - characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' - random_id = ''.join(random.choices(characters, k=length)) + characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + random_id = "".join(random.choices(characters, k=length)) return random_id diff --git a/api/core/tools/provider/builtin/crossref/crossref.py b/api/core/tools/provider/builtin/crossref/crossref.py index 404e483e0d..8ba3c1b48a 100644 --- a/api/core/tools/provider/builtin/crossref/crossref.py +++ b/api/core/tools/provider/builtin/crossref/crossref.py @@ -11,9 +11,9 @@ class CrossRefProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ - "doi": '10.1007/s00894-022-05373-8', + "doi": "10.1007/s00894-022-05373-8", }, ) except Exception as e: diff --git a/api/core/tools/provider/builtin/crossref/tools/query_doi.py b/api/core/tools/provider/builtin/crossref/tools/query_doi.py index a43c0989e4..746139dd69 100644 --- a/api/core/tools/provider/builtin/crossref/tools/query_doi.py +++ b/api/core/tools/provider/builtin/crossref/tools/query_doi.py @@ -11,15 +11,18 @@ class CrossRefQueryDOITool(BuiltinTool): """ Tool for querying the metadata of a publication using its DOI. """ - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - doi = tool_parameters.get('doi') + + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + doi = tool_parameters.get("doi") if not doi: - raise ToolParameterValidationError('doi is required.') + raise ToolParameterValidationError("doi is required.") # doc: https://github.com/CrossRef/rest-api-doc url = f"https://api.crossref.org/works/{doi}" response = requests.get(url) response.raise_for_status() response = response.json() - message = response.get('message', {}) + message = response.get("message", {}) return self.create_json_message(message) diff --git a/api/core/tools/provider/builtin/crossref/tools/query_title.py b/api/core/tools/provider/builtin/crossref/tools/query_title.py index 946aa6dc94..e245238183 100644 --- a/api/core/tools/provider/builtin/crossref/tools/query_title.py +++ b/api/core/tools/provider/builtin/crossref/tools/query_title.py @@ -12,16 +12,16 @@ def convert_time_str_to_seconds(time_str: str) -> int: Convert a time string to seconds. example: 1s -> 1, 1m30s -> 90, 1h30m -> 5400, 1h30m30s -> 5430 """ - time_str = time_str.lower().strip().replace(' ', '') + time_str = time_str.lower().strip().replace(" ", "") seconds = 0 - if 'h' in time_str: - hours, time_str = time_str.split('h') + if "h" in time_str: + hours, time_str = time_str.split("h") seconds += int(hours) * 3600 - if 'm' in time_str: - minutes, time_str = time_str.split('m') + if "m" in time_str: + minutes, time_str = time_str.split("m") seconds += int(minutes) * 60 - if 's' in time_str: - seconds += int(time_str.replace('s', '')) + if "s" in time_str: + seconds += int(time_str.replace("s", "")) return seconds @@ -30,6 +30,7 @@ class CrossRefQueryTitleAPI: Tool for querying the metadata of a publication using its title. Crossref API doc: https://github.com/CrossRef/rest-api-doc """ + query_url_template: str = "https://api.crossref.org/works?query.bibliographic={query}&rows={rows}&offset={offset}&sort={sort}&order={order}&mailto={mailto}" rate_limit: int = 50 rate_interval: float = 1 @@ -38,7 +39,15 @@ class CrossRefQueryTitleAPI: def __init__(self, mailto: str): self.mailto = mailto - def _query(self, query: str, rows: int = 5, offset: int = 0, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]: + def _query( + self, + query: str, + rows: int = 5, + offset: int = 0, + sort: str = "relevance", + order: str = "desc", + fuzzy_query: bool = False, + ) -> list[dict]: """ Query the metadata of a publication using its title. :param query: the title of the publication @@ -47,33 +56,37 @@ class CrossRefQueryTitleAPI: :param order: the sort order :param fuzzy_query: whether to return all items that match the query """ - url = self.query_url_template.format(query=query, rows=rows, offset=offset, sort=sort, order=order, mailto=self.mailto) + url = self.query_url_template.format( + query=query, rows=rows, offset=offset, sort=sort, order=order, mailto=self.mailto + ) response = requests.get(url) response.raise_for_status() - rate_limit = int(response.headers['x-ratelimit-limit']) + rate_limit = int(response.headers["x-ratelimit-limit"]) # convert time string to seconds - rate_interval = convert_time_str_to_seconds(response.headers['x-ratelimit-interval']) + rate_interval = convert_time_str_to_seconds(response.headers["x-ratelimit-interval"]) self.rate_limit = rate_limit self.rate_interval = rate_interval response = response.json() - if response['status'] != 'ok': + if response["status"] != "ok": return [] - message = response['message'] + message = response["message"] if fuzzy_query: # fuzzy query return all items - return message['items'] + return message["items"] else: - for paper in message['items']: - title = paper['title'][0] + for paper in message["items"]: + title = paper["title"][0] if title.lower() != query.lower(): continue return [paper] return [] - def query(self, query: str, rows: int = 5, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]: + def query( + self, query: str, rows: int = 5, sort: str = "relevance", order: str = "desc", fuzzy_query: bool = False + ) -> list[dict]: """ Query the metadata of a publication using its title. :param query: the title of the publication @@ -89,7 +102,14 @@ class CrossRefQueryTitleAPI: results = [] for i in range(query_times): - result = self._query(query, rows=self.rate_limit, offset=i * self.rate_limit, sort=sort, order=order, fuzzy_query=fuzzy_query) + result = self._query( + query, + rows=self.rate_limit, + offset=i * self.rate_limit, + sort=sort, + order=order, + fuzzy_query=fuzzy_query, + ) if fuzzy_query: results.extend(result) else: @@ -107,13 +127,16 @@ class CrossRefQueryTitleTool(BuiltinTool): """ Tool for querying the metadata of a publication using its title. """ - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - query = tool_parameters.get('query') - fuzzy_query = tool_parameters.get('fuzzy_query', False) - rows = tool_parameters.get('rows', 3) - sort = tool_parameters.get('sort', 'relevance') - order = tool_parameters.get('order', 'desc') - mailto = self.runtime.credentials['mailto'] + + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + query = tool_parameters.get("query") + fuzzy_query = tool_parameters.get("fuzzy_query", False) + rows = tool_parameters.get("rows", 3) + sort = tool_parameters.get("sort", "relevance") + order = tool_parameters.get("order", "desc") + mailto = self.runtime.credentials["mailto"] result = CrossRefQueryTitleAPI(mailto).query(query, rows, sort, order, fuzzy_query) diff --git a/api/core/tools/provider/builtin/dalle/dalle.py b/api/core/tools/provider/builtin/dalle/dalle.py index 1c8019364d..5bd16e49e8 100644 --- a/api/core/tools/provider/builtin/dalle/dalle.py +++ b/api/core/tools/provider/builtin/dalle/dalle.py @@ -13,13 +13,8 @@ class DALLEProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "prompt": "cute girl, blue eyes, white hair, anime style", - "size": "small", - "n": 1 - }, + user_id="", + tool_parameters={"prompt": "cute girl, blue eyes, white hair, anime style", "size": "small", "n": 1}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle2.py b/api/core/tools/provider/builtin/dalle/tools/dalle2.py index 9e9f32d429..ac7e394911 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle2.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle2.py @@ -9,59 +9,58 @@ from core.tools.tool.builtin_tool import BuiltinTool class DallE2Tool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - openai_organization = self.runtime.credentials.get('openai_organization_id', None) + openai_organization = self.runtime.credentials.get("openai_organization_id", None) if not openai_organization: openai_organization = None - openai_base_url = self.runtime.credentials.get('openai_base_url', None) + openai_base_url = self.runtime.credentials.get("openai_base_url", None) if not openai_base_url: openai_base_url = None else: - openai_base_url = str(URL(openai_base_url) / 'v1') + openai_base_url = str(URL(openai_base_url) / "v1") client = OpenAI( - api_key=self.runtime.credentials['openai_api_key'], + api_key=self.runtime.credentials["openai_api_key"], base_url=openai_base_url, - organization=openai_organization + organization=openai_organization, ) SIZE_MAPPING = { - 'small': '256x256', - 'medium': '512x512', - 'large': '1024x1024', + "small": "256x256", + "medium": "512x512", + "large": "1024x1024", } # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') - + return self.create_text_message("Please input prompt") + # get size - size = SIZE_MAPPING[tool_parameters.get('size', 'large')] + size = SIZE_MAPPING[tool_parameters.get("size", "large")] # get n - n = tool_parameters.get('n', 1) + n = tool_parameters.get("n", 1) # call openapi dalle2 - response = client.images.generate( - prompt=prompt, - model='dall-e-2', - size=size, - n=n, - response_format='b64_json' - ) + response = client.images.generate(prompt=prompt, model="dall-e-2", size=size, n=n, response_format="b64_json") result = [] for image in response.data: - result.append(self.create_blob_message(blob=b64decode(image.b64_json), - meta={ 'mime_type': 'image/png' }, - save_as=self.VARIABLE_KEY.IMAGE.value)) + result.append( + self.create_blob_message( + blob=b64decode(image.b64_json), + meta={"mime_type": "image/png"}, + save_as=self.VARIABLE_KEY.IMAGE.value, + ) + ) return result diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.py b/api/core/tools/provider/builtin/dalle/tools/dalle3.py index 4f5033dd7f..2d62cf608f 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.py @@ -10,69 +10,64 @@ from core.tools.tool.builtin_tool import BuiltinTool class DallE3Tool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - openai_organization = self.runtime.credentials.get('openai_organization_id', None) + openai_organization = self.runtime.credentials.get("openai_organization_id", None) if not openai_organization: openai_organization = None - openai_base_url = self.runtime.credentials.get('openai_base_url', None) + openai_base_url = self.runtime.credentials.get("openai_base_url", None) if not openai_base_url: openai_base_url = None else: - openai_base_url = str(URL(openai_base_url) / 'v1') + openai_base_url = str(URL(openai_base_url) / "v1") client = OpenAI( - api_key=self.runtime.credentials['openai_api_key'], + api_key=self.runtime.credentials["openai_api_key"], base_url=openai_base_url, - organization=openai_organization + organization=openai_organization, ) SIZE_MAPPING = { - 'square': '1024x1024', - 'vertical': '1024x1792', - 'horizontal': '1792x1024', + "square": "1024x1024", + "vertical": "1024x1792", + "horizontal": "1792x1024", } # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') + return self.create_text_message("Please input prompt") # get size - size = SIZE_MAPPING[tool_parameters.get('size', 'square')] + size = SIZE_MAPPING[tool_parameters.get("size", "square")] # get n - n = tool_parameters.get('n', 1) + n = tool_parameters.get("n", 1) # get quality - quality = tool_parameters.get('quality', 'standard') - if quality not in ['standard', 'hd']: - return self.create_text_message('Invalid quality') + quality = tool_parameters.get("quality", "standard") + if quality not in ["standard", "hd"]: + return self.create_text_message("Invalid quality") # get style - style = tool_parameters.get('style', 'vivid') - if style not in ['natural', 'vivid']: - return self.create_text_message('Invalid style') + style = tool_parameters.get("style", "vivid") + if style not in ["natural", "vivid"]: + return self.create_text_message("Invalid style") # call openapi dalle3 response = client.images.generate( - prompt=prompt, - model='dall-e-3', - size=size, - n=n, - style=style, - quality=quality, - response_format='b64_json' + prompt=prompt, model="dall-e-3", size=size, n=n, style=style, quality=quality, response_format="b64_json" ) result = [] for image in response.data: mime_type, blob_image = DallE3Tool._decode_image(image.b64_json) - blob_message = self.create_blob_message(blob=blob_image, - meta={'mime_type': mime_type}, - save_as=self.VARIABLE_KEY.IMAGE.value) + blob_message = self.create_blob_message( + blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VARIABLE_KEY.IMAGE.value + ) result.append(blob_message) return result @@ -86,7 +81,7 @@ class DallE3Tool(BuiltinTool): :return: A tuple containing the MIME type and the decoded image bytes """ if DallE3Tool._is_plain_base64(base64_image): - return 'image/png', base64.b64decode(base64_image) + return "image/png", base64.b64decode(base64_image) else: return DallE3Tool._extract_mime_and_data(base64_image) @@ -98,7 +93,7 @@ class DallE3Tool(BuiltinTool): :param encoded_str: Base64 encoded image string :return: True if the string is plain base64, False otherwise """ - return not encoded_str.startswith('data:image') + return not encoded_str.startswith("data:image") @staticmethod def _extract_mime_and_data(encoded_str: str) -> tuple[str, bytes]: @@ -108,13 +103,13 @@ class DallE3Tool(BuiltinTool): :param encoded_str: Base64 encoded image string with MIME type prefix :return: A tuple containing the MIME type and the decoded image bytes """ - mime_type = encoded_str.split(';')[0].split(':')[1] - image_data_base64 = encoded_str.split(',')[1] + mime_type = encoded_str.split(";")[0].split(":")[1] + image_data_base64 = encoded_str.split(",")[1] decoded_data = base64.b64decode(image_data_base64) return mime_type, decoded_data @staticmethod def _generate_random_id(length=8): - characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' - random_id = ''.join(random.choices(characters, k=length)) + characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + random_id = "".join(random.choices(characters, k=length)) return random_id diff --git a/api/core/tools/provider/builtin/devdocs/devdocs.py b/api/core/tools/provider/builtin/devdocs/devdocs.py index 95d7939d0d..446c1e5489 100644 --- a/api/core/tools/provider/builtin/devdocs/devdocs.py +++ b/api/core/tools/provider/builtin/devdocs/devdocs.py @@ -11,7 +11,7 @@ class DevDocsProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "doc": "python~3.12", "topic": "library/code", @@ -19,4 +19,3 @@ class DevDocsProvider(BuiltinToolProviderController): ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/devdocs/tools/searchDevDocs.py b/api/core/tools/provider/builtin/devdocs/tools/searchDevDocs.py index 1a244c5db3..e1effd066c 100644 --- a/api/core/tools/provider/builtin/devdocs/tools/searchDevDocs.py +++ b/api/core/tools/provider/builtin/devdocs/tools/searchDevDocs.py @@ -13,7 +13,9 @@ class SearchDevDocsInput(BaseModel): class SearchDevDocsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invokes the DevDocs search tool with the given user ID and tool parameters. @@ -24,13 +26,13 @@ class SearchDevDocsTool(BuiltinTool): Returns: ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages. """ - doc = tool_parameters.get('doc', '') - topic = tool_parameters.get('topic', '') + doc = tool_parameters.get("doc", "") + topic = tool_parameters.get("topic", "") if not doc: - return self.create_text_message('Please provide the documentation name.') + return self.create_text_message("Please provide the documentation name.") if not topic: - return self.create_text_message('Please provide the topic path.') + return self.create_text_message("Please provide the topic path.") url = f"https://documents.devdocs.io/{doc}/{topic}.html" response = requests.get(url) @@ -39,4 +41,6 @@ class SearchDevDocsTool(BuiltinTool): content = response.text return self.create_text_message(self.summary(user_id=user_id, content=content)) else: - return self.create_text_message(f"Failed to retrieve the documentation. Status code: {response.status_code}") \ No newline at end of file + return self.create_text_message( + f"Failed to retrieve the documentation. Status code: {response.status_code}" + ) diff --git a/api/core/tools/provider/builtin/did/did.py b/api/core/tools/provider/builtin/did/did.py index b4bf172131..5af78794f6 100644 --- a/api/core/tools/provider/builtin/did/did.py +++ b/api/core/tools/provider/builtin/did/did.py @@ -7,15 +7,12 @@ class DIDProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: # Example validation using the D-ID talks tool - TalksTool().fork_tool_runtime( - runtime={"credentials": credentials} - ).invoke( - user_id='', + TalksTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke( + user_id="", tool_parameters={ "source_url": "https://www.d-id.com/wp-content/uploads/2023/11/Hero-image-1.png", "text_input": "Hello, welcome to use D-ID tool in Dify", - } + }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/did/did_appx.py b/api/core/tools/provider/builtin/did/did_appx.py index 964e82b729..4cad12e4ee 100644 --- a/api/core/tools/provider/builtin/did/did_appx.py +++ b/api/core/tools/provider/builtin/did/did_appx.py @@ -12,14 +12,14 @@ logger = logging.getLogger(__name__) class DIDApp: def __init__(self, api_key: str | None = None, base_url: str | None = None): self.api_key = api_key - self.base_url = base_url or 'https://api.d-id.com' + self.base_url = base_url or "https://api.d-id.com" if not self.api_key: - raise ValueError('API key is required') + raise ValueError("API key is required") def _prepare_headers(self, idempotency_key: str | None = None): - headers = {'Content-Type': 'application/json', 'Authorization': f'Basic {self.api_key}'} + headers = {"Content-Type": "application/json", "Authorization": f"Basic {self.api_key}"} if idempotency_key: - headers['Idempotency-Key'] = idempotency_key + headers["Idempotency-Key"] = idempotency_key return headers def _request( @@ -44,44 +44,44 @@ class DIDApp: return None def talks(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs): - endpoint = f'{self.base_url}/talks' + endpoint = f"{self.base_url}/talks" headers = self._prepare_headers(idempotency_key) - data = kwargs['params'] - logger.debug(f'Send request to {endpoint=} body={data}') - response = self._request('POST', endpoint, data, headers) + data = kwargs["params"] + logger.debug(f"Send request to {endpoint=} body={data}") + response = self._request("POST", endpoint, data, headers) if response is None: - raise HTTPError('Failed to initiate D-ID talks after multiple retries') - id: str = response['id'] + raise HTTPError("Failed to initiate D-ID talks after multiple retries") + id: str = response["id"] if wait: - return self._monitor_job_status(id=id, target='talks', poll_interval=poll_interval) + return self._monitor_job_status(id=id, target="talks", poll_interval=poll_interval) return id def animations(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs): - endpoint = f'{self.base_url}/animations' + endpoint = f"{self.base_url}/animations" headers = self._prepare_headers(idempotency_key) - data = kwargs['params'] - logger.debug(f'Send request to {endpoint=} body={data}') - response = self._request('POST', endpoint, data, headers) + data = kwargs["params"] + logger.debug(f"Send request to {endpoint=} body={data}") + response = self._request("POST", endpoint, data, headers) if response is None: - raise HTTPError('Failed to initiate D-ID talks after multiple retries') - id: str = response['id'] + raise HTTPError("Failed to initiate D-ID talks after multiple retries") + id: str = response["id"] if wait: - return self._monitor_job_status(target='animations', id=id, poll_interval=poll_interval) + return self._monitor_job_status(target="animations", id=id, poll_interval=poll_interval) return id def check_did_status(self, target: str, id: str): - endpoint = f'{self.base_url}/{target}/{id}' + endpoint = f"{self.base_url}/{target}/{id}" headers = self._prepare_headers() - response = self._request('GET', endpoint, headers=headers) + response = self._request("GET", endpoint, headers=headers) if response is None: - raise HTTPError(f'Failed to check status for talks {id} after multiple retries') + raise HTTPError(f"Failed to check status for talks {id} after multiple retries") return response def _monitor_job_status(self, target: str, id: str, poll_interval: int): while True: status = self.check_did_status(target=target, id=id) - if status['status'] == 'done': + if status["status"] == "done": return status - elif status['status'] == 'error' or status['status'] == 'rejected': + elif status["status"] == "error" or status["status"] == "rejected": raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error",{}).get("description")}') time.sleep(poll_interval) diff --git a/api/core/tools/provider/builtin/did/tools/animations.py b/api/core/tools/provider/builtin/did/tools/animations.py index e1d9de603f..bc9d17e40d 100644 --- a/api/core/tools/provider/builtin/did/tools/animations.py +++ b/api/core/tools/provider/builtin/did/tools/animations.py @@ -10,33 +10,33 @@ class AnimationsTool(BuiltinTool): def _invoke( self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - app = DIDApp(api_key=self.runtime.credentials['did_api_key'], base_url=self.runtime.credentials['base_url']) + app = DIDApp(api_key=self.runtime.credentials["did_api_key"], base_url=self.runtime.credentials["base_url"]) - driver_expressions_str = tool_parameters.get('driver_expressions') + driver_expressions_str = tool_parameters.get("driver_expressions") driver_expressions = json.loads(driver_expressions_str) if driver_expressions_str else None config = { - 'stitch': tool_parameters.get('stitch', True), - 'mute': tool_parameters.get('mute'), - 'result_format': tool_parameters.get('result_format') or 'mp4', + "stitch": tool_parameters.get("stitch", True), + "mute": tool_parameters.get("mute"), + "result_format": tool_parameters.get("result_format") or "mp4", } - config = {k: v for k, v in config.items() if v is not None and v != ''} + config = {k: v for k, v in config.items() if v is not None and v != ""} options = { - 'source_url': tool_parameters['source_url'], - 'driver_url': tool_parameters.get('driver_url'), - 'config': config, + "source_url": tool_parameters["source_url"], + "driver_url": tool_parameters.get("driver_url"), + "config": config, } - options = {k: v for k, v in options.items() if v is not None and v != ''} + options = {k: v for k, v in options.items() if v is not None and v != ""} - if not options.get('source_url'): - raise ValueError('Source URL is required') + if not options.get("source_url"): + raise ValueError("Source URL is required") - if config.get('logo_url'): - if not config.get('logo_x'): - raise ValueError('Logo X position is required when logo URL is provided') - if not config.get('logo_y'): - raise ValueError('Logo Y position is required when logo URL is provided') + if config.get("logo_url"): + if not config.get("logo_x"): + raise ValueError("Logo X position is required when logo URL is provided") + if not config.get("logo_y"): + raise ValueError("Logo Y position is required when logo URL is provided") animations_result = app.animations(params=options, wait=True) @@ -44,6 +44,6 @@ class AnimationsTool(BuiltinTool): animations_result = json.dumps(animations_result, ensure_ascii=False, indent=4) if not animations_result: - return self.create_text_message('D-ID animations request failed.') + return self.create_text_message("D-ID animations request failed.") return self.create_text_message(animations_result) diff --git a/api/core/tools/provider/builtin/did/tools/talks.py b/api/core/tools/provider/builtin/did/tools/talks.py index 06b2c4cb2f..d6f0c7ff17 100644 --- a/api/core/tools/provider/builtin/did/tools/talks.py +++ b/api/core/tools/provider/builtin/did/tools/talks.py @@ -10,49 +10,49 @@ class TalksTool(BuiltinTool): def _invoke( self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - app = DIDApp(api_key=self.runtime.credentials['did_api_key'], base_url=self.runtime.credentials['base_url']) + app = DIDApp(api_key=self.runtime.credentials["did_api_key"], base_url=self.runtime.credentials["base_url"]) - driver_expressions_str = tool_parameters.get('driver_expressions') + driver_expressions_str = tool_parameters.get("driver_expressions") driver_expressions = json.loads(driver_expressions_str) if driver_expressions_str else None script = { - 'type': tool_parameters.get('script_type') or 'text', - 'input': tool_parameters.get('text_input'), - 'audio_url': tool_parameters.get('audio_url'), - 'reduce_noise': tool_parameters.get('audio_reduce_noise', False), + "type": tool_parameters.get("script_type") or "text", + "input": tool_parameters.get("text_input"), + "audio_url": tool_parameters.get("audio_url"), + "reduce_noise": tool_parameters.get("audio_reduce_noise", False), } - script = {k: v for k, v in script.items() if v is not None and v != ''} + script = {k: v for k, v in script.items() if v is not None and v != ""} config = { - 'stitch': tool_parameters.get('stitch', True), - 'sharpen': tool_parameters.get('sharpen'), - 'fluent': tool_parameters.get('fluent'), - 'result_format': tool_parameters.get('result_format') or 'mp4', - 'pad_audio': tool_parameters.get('pad_audio'), - 'driver_expressions': driver_expressions, + "stitch": tool_parameters.get("stitch", True), + "sharpen": tool_parameters.get("sharpen"), + "fluent": tool_parameters.get("fluent"), + "result_format": tool_parameters.get("result_format") or "mp4", + "pad_audio": tool_parameters.get("pad_audio"), + "driver_expressions": driver_expressions, } - config = {k: v for k, v in config.items() if v is not None and v != ''} + config = {k: v for k, v in config.items() if v is not None and v != ""} options = { - 'source_url': tool_parameters['source_url'], - 'driver_url': tool_parameters.get('driver_url'), - 'script': script, - 'config': config, + "source_url": tool_parameters["source_url"], + "driver_url": tool_parameters.get("driver_url"), + "script": script, + "config": config, } - options = {k: v for k, v in options.items() if v is not None and v != ''} + options = {k: v for k, v in options.items() if v is not None and v != ""} - if not options.get('source_url'): - raise ValueError('Source URL is required') + if not options.get("source_url"): + raise ValueError("Source URL is required") - if script.get('type') == 'audio': - script.pop('input', None) - if not script.get('audio_url'): - raise ValueError('Audio URL is required for audio script type') + if script.get("type") == "audio": + script.pop("input", None) + if not script.get("audio_url"): + raise ValueError("Audio URL is required for audio script type") - if script.get('type') == 'text': - script.pop('audio_url', None) - script.pop('reduce_noise', None) - if not script.get('input'): - raise ValueError('Text input is required for text script type') + if script.get("type") == "text": + script.pop("audio_url", None) + script.pop("reduce_noise", None) + if not script.get("input"): + raise ValueError("Text input is required for text script type") talks_result = app.talks(params=options, wait=True) @@ -60,6 +60,6 @@ class TalksTool(BuiltinTool): talks_result = json.dumps(talks_result, ensure_ascii=False, indent=4) if not talks_result: - return self.create_text_message('D-ID talks request failed.') + return self.create_text_message("D-ID talks request failed.") return self.create_text_message(talks_result) diff --git a/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py b/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py index c247c3bd6b..f33ad5be59 100644 --- a/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py +++ b/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py @@ -13,38 +13,43 @@ from core.tools.tool.builtin_tool import BuiltinTool class DingTalkGroupBotTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools - Dingtalk custom group robot API docs: - https://open.dingtalk.com/document/orgapp/custom-robot-access + invoke tools + Dingtalk custom group robot API docs: + https://open.dingtalk.com/document/orgapp/custom-robot-access """ - content = tool_parameters.get('content') + content = tool_parameters.get("content") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - access_token = tool_parameters.get('access_token') + access_token = tool_parameters.get("access_token") if not access_token: - return self.create_text_message('Invalid parameter access_token. ' - 'Regarding information about security details,' - 'please refer to the DingTalk docs:' - 'https://open.dingtalk.com/document/robots/customize-robot-security-settings') + return self.create_text_message( + "Invalid parameter access_token. " + "Regarding information about security details," + "please refer to the DingTalk docs:" + "https://open.dingtalk.com/document/robots/customize-robot-security-settings" + ) - sign_secret = tool_parameters.get('sign_secret') + sign_secret = tool_parameters.get("sign_secret") if not sign_secret: - return self.create_text_message('Invalid parameter sign_secret. ' - 'Regarding information about security details,' - 'please refer to the DingTalk docs:' - 'https://open.dingtalk.com/document/robots/customize-robot-security-settings') + return self.create_text_message( + "Invalid parameter sign_secret. " + "Regarding information about security details," + "please refer to the DingTalk docs:" + "https://open.dingtalk.com/document/robots/customize-robot-security-settings" + ) - msgtype = 'text' - api_url = 'https://oapi.dingtalk.com/robot/send' + msgtype = "text" + api_url = "https://oapi.dingtalk.com/robot/send" headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = { - 'access_token': access_token, + "access_token": access_token, } self._apply_security_mechanism(params, sign_secret) @@ -53,7 +58,7 @@ class DingTalkGroupBotTool(BuiltinTool): "msgtype": msgtype, "text": { "content": content, - } + }, } try: @@ -62,7 +67,8 @@ class DingTalkGroupBotTool(BuiltinTool): return self.create_text_message("Text message sent successfully") else: return self.create_text_message( - f"Failed to send the text message, status code: {res.status_code}, response: {res.text}") + f"Failed to send the text message, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to send message to group chat bot. {}".format(e)) @@ -70,14 +76,14 @@ class DingTalkGroupBotTool(BuiltinTool): def _apply_security_mechanism(params: dict[str, Any], sign_secret: str): try: timestamp = str(round(time.time() * 1000)) - secret_enc = sign_secret.encode('utf-8') - string_to_sign = f'{timestamp}\n{sign_secret}' - string_to_sign_enc = string_to_sign.encode('utf-8') + secret_enc = sign_secret.encode("utf-8") + string_to_sign = f"{timestamp}\n{sign_secret}" + string_to_sign_enc = string_to_sign.encode("utf-8") hmac_code = hmac.new(secret_enc, string_to_sign_enc, digestmod=hashlib.sha256).digest() sign = urllib.parse.quote_plus(base64.b64encode(hmac_code)) - params['timestamp'] = timestamp - params['sign'] = sign + params["timestamp"] = timestamp + params["sign"] = sign except Exception: msg = "Failed to apply security mechanism to the request." logging.exception(msg) diff --git a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py index 2292e89fa6..8269167127 100644 --- a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py +++ b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py @@ -11,11 +11,10 @@ class DuckDuckGoProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "John Doe", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.py index 878b0d8645..8bdd638f4a 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.py @@ -13,8 +13,8 @@ class DuckDuckGoAITool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: query_dict = { - "keywords": tool_parameters.get('query'), - "model": tool_parameters.get('model'), + "keywords": tool_parameters.get("query"), + "model": tool_parameters.get("model"), } response = DDGS().chat(**query_dict) return self.create_text_message(text=response) diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py index bca53f6b4b..396570248a 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py @@ -14,18 +14,17 @@ class DuckDuckGoImageSearchTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]: query_dict = { - "keywords": tool_parameters.get('query'), - "timelimit": tool_parameters.get('timelimit'), - "size": tool_parameters.get('size'), - "max_results": tool_parameters.get('max_results'), + "keywords": tool_parameters.get("query"), + "timelimit": tool_parameters.get("timelimit"), + "size": tool_parameters.get("size"), + "max_results": tool_parameters.get("max_results"), } response = DDGS().images(**query_dict) result = [] for res in response: - res['transfer_method'] = FileTransferMethod.REMOTE_URL - msg = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=res.get('image'), - save_as='', - meta=res) + res["transfer_method"] = FileTransferMethod.REMOTE_URL + msg = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=res.get("image"), save_as="", meta=res + ) result.append(msg) return result diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py index dfaeb734d8..cbd65d2e77 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py @@ -21,10 +21,11 @@ class DuckDuckGoSearchTool(BuiltinTool): """ Tool for performing a search using DuckDuckGo search engine. """ + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: - query = tool_parameters.get('query') - max_results = tool_parameters.get('max_results', 5) - require_summary = tool_parameters.get('require_summary', False) + query = tool_parameters.get("query") + max_results = tool_parameters.get("max_results", 5) + require_summary = tool_parameters.get("require_summary", False) response = DDGS().text(query, max_results=max_results) if require_summary: results = "\n".join([res.get("body") for res in response]) @@ -34,7 +35,11 @@ class DuckDuckGoSearchTool(BuiltinTool): def summary_results(self, user_id: str, content: str, query: str) -> str: prompt = SUMMARY_PROMPT.format(query=query, content=content) - summary = self.invoke_model(user_id=user_id, prompt_messages=[ - SystemPromptMessage(content=prompt), - ], stop=[]) + summary = self.invoke_model( + user_id=user_id, + prompt_messages=[ + SystemPromptMessage(content=prompt), + ], + stop=[], + ) return summary.message.content diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_translate.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_translate.py index 9822b37cf0..396ce21b18 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_translate.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_translate.py @@ -13,8 +13,8 @@ class DuckDuckGoTranslateTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: query_dict = { - "keywords": tool_parameters.get('query'), - "to": tool_parameters.get('translate_to'), + "keywords": tool_parameters.get("query"), + "to": tool_parameters.get("translate_to"), } - response = DDGS().translate(**query_dict)[0].get('translated', 'Unable to translate!') + response = DDGS().translate(**query_dict)[0].get("translated", "Unable to translate!") return self.create_text_message(text=response) diff --git a/api/core/tools/provider/builtin/feishu/tools/feishu_group_bot.py b/api/core/tools/provider/builtin/feishu/tools/feishu_group_bot.py index e8ab02f55e..e82da8ca53 100644 --- a/api/core/tools/provider/builtin/feishu/tools/feishu_group_bot.py +++ b/api/core/tools/provider/builtin/feishu/tools/feishu_group_bot.py @@ -8,35 +8,35 @@ from core.tools.utils.uuid_utils import is_valid_uuid class FeishuGroupBotTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools - API document: https://open.feishu.cn/document/client-docs/bot-v3/add-custom-bot + invoke tools + API document: https://open.feishu.cn/document/client-docs/bot-v3/add-custom-bot """ url = "https://open.feishu.cn/open-apis/bot/v2/hook" - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - hook_key = tool_parameters.get('hook_key', '') + hook_key = tool_parameters.get("hook_key", "") if not is_valid_uuid(hook_key): - return self.create_text_message( - f'Invalid parameter hook_key ${hook_key}, not a valid UUID') + return self.create_text_message(f"Invalid parameter hook_key ${hook_key}, not a valid UUID") - msg_type = 'text' - api_url = f'{url}/{hook_key}' + msg_type = "text" + api_url = f"{url}/{hook_key}" headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = {} payload = { "msg_type": msg_type, "content": { "text": content, - } + }, } try: @@ -45,6 +45,7 @@ class FeishuGroupBotTool(BuiltinTool): return self.create_text_message("Text message sent successfully") else: return self.create_text_message( - f"Failed to send the text message, status code: {res.status_code}, response: {res.text}") + f"Failed to send the text message, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: - return self.create_text_message("Failed to send message to group chat bot. {}".format(e)) \ No newline at end of file + return self.create_text_message("Failed to send message to group chat bot. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/feishu_base.py b/api/core/tools/provider/builtin/feishu_base/feishu_base.py index febb769ff8..04056af53b 100644 --- a/api/core/tools/provider/builtin/feishu_base/feishu_base.py +++ b/api/core/tools/provider/builtin/feishu_base/feishu_base.py @@ -5,4 +5,4 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class FeishuBaseProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: GetTenantAccessTokenTool() - pass \ No newline at end of file + pass diff --git a/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py b/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py index be43b43ce4..4a605fbffe 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py @@ -8,45 +8,49 @@ from core.tools.tool.builtin_tool import BuiltinTool class AddBaseRecordTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_id = tool_parameters.get('table_id', '') + table_id = tool_parameters.get("table_id", "") if not table_id: - return self.create_text_message('Invalid parameter table_id') + return self.create_text_message("Invalid parameter table_id") - fields = tool_parameters.get('fields', '') + fields = tool_parameters.get("fields", "") if not fields: - return self.create_text_message('Invalid parameter fields') + return self.create_text_message("Invalid parameter fields") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "fields": json.loads(fields) - } + payload = {"fields": json.loads(fields)} try: - res = httpx.post(url.format(app_token=app_token, table_id=table_id), headers=headers, params=params, - json=payload, timeout=30) + res = httpx.post( + url.format(app_token=app_token, table_id=table_id), + headers=headers, + params=params, + json=payload, + timeout=30, + ) res_json = res.json() if res.is_success: return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to add base record, status code: {res.status_code}, response: {res.text}") + f"Failed to add base record, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to add base record. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_base.py b/api/core/tools/provider/builtin/feishu_base/tools/create_base.py index 639644e7f0..6b755e2007 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/create_base.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/create_base.py @@ -8,28 +8,25 @@ from core.tools.tool.builtin_tool import BuiltinTool class CreateBaseTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - name = tool_parameters.get('name', '') - folder_token = tool_parameters.get('folder_token', '') + name = tool_parameters.get("name", "") + folder_token = tool_parameters.get("folder_token", "") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "name": name, - "folder_token": folder_token - } + payload = {"name": name, "folder_token": folder_token} try: res = httpx.post(url, headers=headers, params=params, json=payload, timeout=30) @@ -38,6 +35,7 @@ class CreateBaseTool(BuiltinTool): return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to create base, status code: {res.status_code}, response: {res.text}") + f"Failed to create base, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to create base. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py b/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py index e9062e8730..b05d700113 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py @@ -8,37 +8,32 @@ from core.tools.tool.builtin_tool import BuiltinTool class CreateBaseTableTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - name = tool_parameters.get('name', '') + name = tool_parameters.get("name", "") - fields = tool_parameters.get('fields', '') + fields = tool_parameters.get("fields", "") if not fields: - return self.create_text_message('Invalid parameter fields') + return self.create_text_message("Invalid parameter fields") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "table": { - "name": name, - "fields": json.loads(fields) - } - } + payload = {"table": {"name": name, "fields": json.loads(fields)}} try: res = httpx.post(url.format(app_token=app_token), headers=headers, params=params, json=payload, timeout=30) @@ -47,6 +42,7 @@ class CreateBaseTableTool(BuiltinTool): return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to create base table, status code: {res.status_code}, response: {res.text}") + f"Failed to create base table, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to create base table. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py index aa13aad6fa..862eb2171b 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py @@ -8,45 +8,49 @@ from core.tools.tool.builtin_tool import BuiltinTool class DeleteBaseRecordsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/batch_delete" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_id = tool_parameters.get('table_id', '') + table_id = tool_parameters.get("table_id", "") if not table_id: - return self.create_text_message('Invalid parameter table_id') + return self.create_text_message("Invalid parameter table_id") - record_ids = tool_parameters.get('record_ids', '') + record_ids = tool_parameters.get("record_ids", "") if not record_ids: - return self.create_text_message('Invalid parameter record_ids') + return self.create_text_message("Invalid parameter record_ids") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "records": json.loads(record_ids) - } + payload = {"records": json.loads(record_ids)} try: - res = httpx.post(url.format(app_token=app_token, table_id=table_id), headers=headers, params=params, - json=payload, timeout=30) + res = httpx.post( + url.format(app_token=app_token, table_id=table_id), + headers=headers, + params=params, + json=payload, + timeout=30, + ) res_json = res.json() if res.is_success: return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to delete base records, status code: {res.status_code}, response: {res.text}") + f"Failed to delete base records, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to delete base records. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py index c4280ebc21..f512186303 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py @@ -8,32 +8,30 @@ from core.tools.tool.builtin_tool import BuiltinTool class DeleteBaseTablesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/batch_delete" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_ids = tool_parameters.get('table_ids', '') + table_ids = tool_parameters.get("table_ids", "") if not table_ids: - return self.create_text_message('Invalid parameter table_ids') + return self.create_text_message("Invalid parameter table_ids") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "table_ids": json.loads(table_ids) - } + payload = {"table_ids": json.loads(table_ids)} try: res = httpx.post(url.format(app_token=app_token), headers=headers, params=params, json=payload, timeout=30) @@ -42,6 +40,7 @@ class DeleteBaseTablesTool(BuiltinTool): return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to delete base tables, status code: {res.status_code}, response: {res.text}") + f"Failed to delete base tables, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to delete base tables. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py b/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py index de70f2ed93..f664bbeed0 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py @@ -8,22 +8,22 @@ from core.tools.tool.builtin_tool import BuiltinTool class GetBaseInfoTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } try: @@ -33,6 +33,7 @@ class GetBaseInfoTool(BuiltinTool): return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to get base info, status code: {res.status_code}, response: {res.text}") + f"Failed to get base info, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to get base info. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py b/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py index 88507bda60..2ea61d0068 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py @@ -8,27 +8,24 @@ from core.tools.tool.builtin_tool import BuiltinTool class GetTenantAccessTokenTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal" - app_id = tool_parameters.get('app_id', '') + app_id = tool_parameters.get("app_id", "") if not app_id: - return self.create_text_message('Invalid parameter app_id') + return self.create_text_message("Invalid parameter app_id") - app_secret = tool_parameters.get('app_secret', '') + app_secret = tool_parameters.get("app_secret", "") if not app_secret: - return self.create_text_message('Invalid parameter app_secret') + return self.create_text_message("Invalid parameter app_secret") headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = {} - payload = { - "app_id": app_id, - "app_secret": app_secret - } + payload = {"app_id": app_id, "app_secret": app_secret} """ { @@ -45,6 +42,7 @@ class GetTenantAccessTokenTool(BuiltinTool): return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to get tenant access token, status code: {res.status_code}, response: {res.text}") + f"Failed to get tenant access token, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to get tenant access token. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py b/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py index 2a4229f137..e579d02f69 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py @@ -8,31 +8,31 @@ from core.tools.tool.builtin_tool import BuiltinTool class ListBaseRecordsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/search" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_id = tool_parameters.get('table_id', '') + table_id = tool_parameters.get("table_id", "") if not table_id: - return self.create_text_message('Invalid parameter table_id') + return self.create_text_message("Invalid parameter table_id") - page_token = tool_parameters.get('page_token', '') - page_size = tool_parameters.get('page_size', '') - sort_condition = tool_parameters.get('sort_condition', '') - filter_condition = tool_parameters.get('filter_condition', '') + page_token = tool_parameters.get("page_token", "") + page_size = tool_parameters.get("page_size", "") + sort_condition = tool_parameters.get("sort_condition", "") + filter_condition = tool_parameters.get("filter_condition", "") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = { @@ -40,22 +40,26 @@ class ListBaseRecordsTool(BuiltinTool): "page_size": page_size, } - payload = { - "automatic_fields": True - } + payload = {"automatic_fields": True} if sort_condition: payload["sort"] = json.loads(sort_condition) if filter_condition: payload["filter"] = json.loads(filter_condition) try: - res = httpx.post(url.format(app_token=app_token, table_id=table_id), headers=headers, params=params, - json=payload, timeout=30) + res = httpx.post( + url.format(app_token=app_token, table_id=table_id), + headers=headers, + params=params, + json=payload, + timeout=30, + ) res_json = res.json() if res.is_success: return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to list base records, status code: {res.status_code}, response: {res.text}") + f"Failed to list base records, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to list base records. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py b/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py index 6d82490eb3..4ec9a476bc 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py @@ -8,25 +8,25 @@ from core.tools.tool.builtin_tool import BuiltinTool class ListBaseTablesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - page_token = tool_parameters.get('page_token', '') - page_size = tool_parameters.get('page_size', '') + page_token = tool_parameters.get("page_token", "") + page_size = tool_parameters.get("page_size", "") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = { @@ -41,6 +41,7 @@ class ListBaseTablesTool(BuiltinTool): return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to list base tables, status code: {res.status_code}, response: {res.text}") + f"Failed to list base tables, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to list base tables. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py b/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py index bb4bd6c3a6..fb818f8380 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py @@ -8,40 +8,42 @@ from core.tools.tool.builtin_tool import BuiltinTool class ReadBaseRecordTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/{record_id}" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_id = tool_parameters.get('table_id', '') + table_id = tool_parameters.get("table_id", "") if not table_id: - return self.create_text_message('Invalid parameter table_id') + return self.create_text_message("Invalid parameter table_id") - record_id = tool_parameters.get('record_id', '') + record_id = tool_parameters.get("record_id", "") if not record_id: - return self.create_text_message('Invalid parameter record_id') + return self.create_text_message("Invalid parameter record_id") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } try: - res = httpx.get(url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers, - timeout=30) + res = httpx.get( + url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers, timeout=30 + ) res_json = res.json() if res.is_success: return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to read base record, status code: {res.status_code}, response: {res.text}") + f"Failed to read base record, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to read base record. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py b/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py index 6551053ce2..6d7e33f3ff 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py @@ -8,49 +8,53 @@ from core.tools.tool.builtin_tool import BuiltinTool class UpdateBaseRecordTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/{record_id}" - access_token = tool_parameters.get('Authorization', '') + access_token = tool_parameters.get("Authorization", "") if not access_token: - return self.create_text_message('Invalid parameter access_token') + return self.create_text_message("Invalid parameter access_token") - app_token = tool_parameters.get('app_token', '') + app_token = tool_parameters.get("app_token", "") if not app_token: - return self.create_text_message('Invalid parameter app_token') + return self.create_text_message("Invalid parameter app_token") - table_id = tool_parameters.get('table_id', '') + table_id = tool_parameters.get("table_id", "") if not table_id: - return self.create_text_message('Invalid parameter table_id') + return self.create_text_message("Invalid parameter table_id") - record_id = tool_parameters.get('record_id', '') + record_id = tool_parameters.get("record_id", "") if not record_id: - return self.create_text_message('Invalid parameter record_id') + return self.create_text_message("Invalid parameter record_id") - fields = tool_parameters.get('fields', '') + fields = tool_parameters.get("fields", "") if not fields: - return self.create_text_message('Invalid parameter fields') + return self.create_text_message("Invalid parameter fields") headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", } params = {} - payload = { - "fields": json.loads(fields) - } + payload = {"fields": json.loads(fields)} try: - res = httpx.put(url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers, - params=params, json=payload, timeout=30) + res = httpx.put( + url.format(app_token=app_token, table_id=table_id, record_id=record_id), + headers=headers, + params=params, + json=payload, + timeout=30, + ) res_json = res.json() if res.is_success: return self.create_text_message(text=json.dumps(res_json)) else: return self.create_text_message( - f"Failed to update base record, status code: {res.status_code}, response: {res.text}") + f"Failed to update base record, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to update base record. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_document/feishu_document.py b/api/core/tools/provider/builtin/feishu_document/feishu_document.py index c4f8f26e2c..b0a1e393eb 100644 --- a/api/core/tools/provider/builtin/feishu_document/feishu_document.py +++ b/api/core/tools/provider/builtin/feishu_document/feishu_document.py @@ -5,11 +5,11 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class FeishuDocumentProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: - app_id = credentials.get('app_id') - app_secret = credentials.get('app_secret') + app_id = credentials.get("app_id") + app_secret = credentials.get("app_secret") if not app_id or not app_secret: raise ToolProviderCredentialValidationError("app_id and app_secret is required") try: assert FeishuRequest(app_id, app_secret).tenant_access_token is not None except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/create_document.py b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py index 0ff82e621b..090a0828e8 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/create_document.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py @@ -7,13 +7,13 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class CreateDocumentTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - title = tool_parameters.get('title') - content = tool_parameters.get('content') - folder_token = tool_parameters.get('folder_token') + title = tool_parameters.get("title") + content = tool_parameters.get("content") + folder_token = tool_parameters.get("folder_token") res = client.create_document(title, content, folder_token) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py b/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py index 16ef90908b..83073e0822 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/get_document_raw_content.py @@ -7,11 +7,11 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class GetDocumentRawContentTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - document_id = tool_parameters.get('document_id') + document_id = tool_parameters.get("document_id") res = client.get_document_raw_content(document_id) - return self.create_json_message(res) \ No newline at end of file + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py b/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py index 97d17bdb04..8c0c4a3c97 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/list_document_block.py @@ -7,13 +7,13 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class ListDocumentBlockTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - document_id = tool_parameters.get('document_id') - page_size = tool_parameters.get('page_size', 500) - page_token = tool_parameters.get('page_token', '') + document_id = tool_parameters.get("document_id") + page_size = tool_parameters.get("page_size", 500) + page_token = tool_parameters.get("page_token", "") res = client.list_document_block(document_id, page_token, page_size) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/write_document.py b/api/core/tools/provider/builtin/feishu_document/tools/write_document.py index 914a44dce6..6061250e48 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/write_document.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/write_document.py @@ -7,13 +7,13 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class CreateDocumentTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - document_id = tool_parameters.get('document_id') - content = tool_parameters.get('content') - position = tool_parameters.get('position') + document_id = tool_parameters.get("document_id") + content = tool_parameters.get("content") + position = tool_parameters.get("position") res = client.write_document(document_id, content, position) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_message/feishu_message.py b/api/core/tools/provider/builtin/feishu_message/feishu_message.py index 6d7fed330c..7b3adb9293 100644 --- a/api/core/tools/provider/builtin/feishu_message/feishu_message.py +++ b/api/core/tools/provider/builtin/feishu_message/feishu_message.py @@ -5,11 +5,11 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class FeishuMessageProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: - app_id = credentials.get('app_id') - app_secret = credentials.get('app_secret') + app_id = credentials.get("app_id") + app_secret = credentials.get("app_secret") if not app_id or not app_secret: raise ToolProviderCredentialValidationError("app_id and app_secret is required") try: assert FeishuRequest(app_id, app_secret).tenant_access_token is not None except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py b/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py index 74f6866ba3..1dd315d0e2 100644 --- a/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py +++ b/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py @@ -7,14 +7,14 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class SendBotMessageTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - receive_id_type = tool_parameters.get('receive_id_type') - receive_id = tool_parameters.get('receive_id') - msg_type = tool_parameters.get('msg_type') - content = tool_parameters.get('content') + receive_id_type = tool_parameters.get("receive_id_type") + receive_id = tool_parameters.get("receive_id") + msg_type = tool_parameters.get("msg_type") + content = tool_parameters.get("content") res = client.send_bot_message(receive_id_type, receive_id, msg_type, content) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py b/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py index 7159f59ffa..44e70e0a15 100644 --- a/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py +++ b/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py @@ -6,14 +6,14 @@ from core.tools.utils.feishu_api_utils import FeishuRequest class SendWebhookMessageTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) ->ToolInvokeMessage: - app_id = self.runtime.credentials.get('app_id') - app_secret = self.runtime.credentials.get('app_secret') + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") client = FeishuRequest(app_id, app_secret) - webhook = tool_parameters.get('webhook') - msg_type = tool_parameters.get('msg_type') - content = tool_parameters.get('content') + webhook = tool_parameters.get("webhook") + msg_type = tool_parameters.get("msg_type") + content = tool_parameters.get("content") res = client.send_webhook_message(webhook, msg_type, content) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/firecrawl/firecrawl.py b/api/core/tools/provider/builtin/firecrawl/firecrawl.py index 24dc35759d..01455d7206 100644 --- a/api/core/tools/provider/builtin/firecrawl/firecrawl.py +++ b/api/core/tools/provider/builtin/firecrawl/firecrawl.py @@ -7,15 +7,8 @@ class FirecrawlProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: # Example validation using the ScrapeTool, only scraping title for minimize content - ScrapeTool().fork_tool_runtime( - runtime={"credentials": credentials} - ).invoke( - user_id='', - tool_parameters={ - "url": "https://google.com", - "onlyIncludeTags": 'title' - } + ScrapeTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke( + user_id="", tool_parameters={"url": "https://google.com", "onlyIncludeTags": "title"} ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py b/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py index 3b3f78731b..a0e4cdf933 100644 --- a/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py +++ b/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py @@ -13,27 +13,24 @@ logger = logging.getLogger(__name__) class FirecrawlApp: def __init__(self, api_key: str | None = None, base_url: str | None = None): self.api_key = api_key - self.base_url = base_url or 'https://api.firecrawl.dev' + self.base_url = base_url or "https://api.firecrawl.dev" if not self.api_key: raise ValueError("API key is required") def _prepare_headers(self, idempotency_key: str | None = None): - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} if idempotency_key: - headers['Idempotency-Key'] = idempotency_key + headers["Idempotency-Key"] = idempotency_key return headers def _request( - self, - method: str, - url: str, - data: Mapping[str, Any] | None = None, - headers: Mapping[str, str] | None = None, - retries: int = 3, - backoff_factor: float = 0.3, + self, + method: str, + url: str, + data: Mapping[str, Any] | None = None, + headers: Mapping[str, str] | None = None, + retries: int = 3, + backoff_factor: float = 0.3, ) -> Mapping[str, Any] | None: if not headers: headers = self._prepare_headers() @@ -44,54 +41,54 @@ class FirecrawlApp: return response.json() except requests.exceptions.RequestException as e: if i < retries - 1: - time.sleep(backoff_factor * (2 ** i)) + time.sleep(backoff_factor * (2**i)) else: raise return None def scrape_url(self, url: str, **kwargs): - endpoint = f'{self.base_url}/v0/scrape' - data = {'url': url, **kwargs} + endpoint = f"{self.base_url}/v0/scrape" + data = {"url": url, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") - response = self._request('POST', endpoint, data) + response = self._request("POST", endpoint, data) if response is None: raise HTTPError("Failed to scrape URL after multiple retries") return response def search(self, query: str, **kwargs): - endpoint = f'{self.base_url}/v0/search' - data = {'query': query, **kwargs} + endpoint = f"{self.base_url}/v0/search" + data = {"query": query, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") - response = self._request('POST', endpoint, data) + response = self._request("POST", endpoint, data) if response is None: raise HTTPError("Failed to perform search after multiple retries") return response def crawl_url( - self, url: str, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs + self, url: str, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs ): - endpoint = f'{self.base_url}/v0/crawl' + endpoint = f"{self.base_url}/v0/crawl" headers = self._prepare_headers(idempotency_key) - data = {'url': url, **kwargs} + data = {"url": url, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") - response = self._request('POST', endpoint, data, headers) + response = self._request("POST", endpoint, data, headers) if response is None: raise HTTPError("Failed to initiate crawl after multiple retries") - job_id: str = response['jobId'] + job_id: str = response["jobId"] if wait: return self._monitor_job_status(job_id=job_id, poll_interval=poll_interval) return response def check_crawl_status(self, job_id: str): - endpoint = f'{self.base_url}/v0/crawl/status/{job_id}' - response = self._request('GET', endpoint) + endpoint = f"{self.base_url}/v0/crawl/status/{job_id}" + response = self._request("GET", endpoint) if response is None: raise HTTPError(f"Failed to check status for job {job_id} after multiple retries") return response def cancel_crawl_job(self, job_id: str): - endpoint = f'{self.base_url}/v0/crawl/cancel/{job_id}' - response = self._request('DELETE', endpoint) + endpoint = f"{self.base_url}/v0/crawl/cancel/{job_id}" + response = self._request("DELETE", endpoint) if response is None: raise HTTPError(f"Failed to cancel job {job_id} after multiple retries") return response @@ -99,9 +96,9 @@ class FirecrawlApp: def _monitor_job_status(self, job_id: str, poll_interval: int): while True: status = self.check_crawl_status(job_id) - if status['status'] == 'completed': + if status["status"] == "completed": return status - elif status['status'] == 'failed': + elif status["status"] == "failed": raise HTTPError(f'Job {job_id} failed: {status["error"]}') time.sleep(poll_interval) @@ -109,7 +106,7 @@ class FirecrawlApp: def get_array_params(tool_parameters: dict[str, Any], key): param = tool_parameters.get(key) if param: - return param.split(',') + return param.split(",") def get_json_params(tool_parameters: dict[str, Any], key): diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py index 08c40a4064..94717cbbfb 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py @@ -11,38 +11,36 @@ class CrawlTool(BuiltinTool): the crawlerOptions and pageOptions comes from doc here: https://docs.firecrawl.dev/api-reference/endpoint/crawl """ - app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], - base_url=self.runtime.credentials['base_url']) + app = FirecrawlApp( + api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] + ) crawlerOptions = {} pageOptions = {} - wait_for_results = tool_parameters.get('wait_for_results', True) + wait_for_results = tool_parameters.get("wait_for_results", True) - crawlerOptions['excludes'] = get_array_params(tool_parameters, 'excludes') - crawlerOptions['includes'] = get_array_params(tool_parameters, 'includes') - crawlerOptions['returnOnlyUrls'] = tool_parameters.get('returnOnlyUrls', False) - crawlerOptions['maxDepth'] = tool_parameters.get('maxDepth') - crawlerOptions['mode'] = tool_parameters.get('mode') - crawlerOptions['ignoreSitemap'] = tool_parameters.get('ignoreSitemap', False) - crawlerOptions['limit'] = tool_parameters.get('limit', 5) - crawlerOptions['allowBackwardCrawling'] = tool_parameters.get('allowBackwardCrawling', False) - crawlerOptions['allowExternalContentLinks'] = tool_parameters.get('allowExternalContentLinks', False) + crawlerOptions["excludes"] = get_array_params(tool_parameters, "excludes") + crawlerOptions["includes"] = get_array_params(tool_parameters, "includes") + crawlerOptions["returnOnlyUrls"] = tool_parameters.get("returnOnlyUrls", False) + crawlerOptions["maxDepth"] = tool_parameters.get("maxDepth") + crawlerOptions["mode"] = tool_parameters.get("mode") + crawlerOptions["ignoreSitemap"] = tool_parameters.get("ignoreSitemap", False) + crawlerOptions["limit"] = tool_parameters.get("limit", 5) + crawlerOptions["allowBackwardCrawling"] = tool_parameters.get("allowBackwardCrawling", False) + crawlerOptions["allowExternalContentLinks"] = tool_parameters.get("allowExternalContentLinks", False) - pageOptions['headers'] = get_json_params(tool_parameters, 'headers') - pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False) - pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False) - pageOptions['onlyIncludeTags'] = get_array_params(tool_parameters, 'onlyIncludeTags') - pageOptions['removeTags'] = get_array_params(tool_parameters, 'removeTags') - pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False) - pageOptions['replaceAllPathsWithAbsolutePaths'] = tool_parameters.get('replaceAllPathsWithAbsolutePaths', False) - pageOptions['screenshot'] = tool_parameters.get('screenshot', False) - pageOptions['waitFor'] = tool_parameters.get('waitFor', 0) + pageOptions["headers"] = get_json_params(tool_parameters, "headers") + pageOptions["includeHtml"] = tool_parameters.get("includeHtml", False) + pageOptions["includeRawHtml"] = tool_parameters.get("includeRawHtml", False) + pageOptions["onlyIncludeTags"] = get_array_params(tool_parameters, "onlyIncludeTags") + pageOptions["removeTags"] = get_array_params(tool_parameters, "removeTags") + pageOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", False) + pageOptions["replaceAllPathsWithAbsolutePaths"] = tool_parameters.get("replaceAllPathsWithAbsolutePaths", False) + pageOptions["screenshot"] = tool_parameters.get("screenshot", False) + pageOptions["waitFor"] = tool_parameters.get("waitFor", 0) crawl_result = app.crawl_url( - url=tool_parameters['url'], - wait=wait_for_results, - crawlerOptions=crawlerOptions, - pageOptions=pageOptions + url=tool_parameters["url"], wait=wait_for_results, crawlerOptions=crawlerOptions, pageOptions=pageOptions ) return self.create_json_message(crawl_result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py b/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py index fa6c1f87ee..0d2486c7ca 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py @@ -7,14 +7,15 @@ from core.tools.tool.builtin_tool import BuiltinTool class CrawlJobTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], - base_url=self.runtime.credentials['base_url']) - operation = tool_parameters.get('operation', 'get') - if operation == 'get': - result = app.check_crawl_status(job_id=tool_parameters['job_id']) - elif operation == 'cancel': - result = app.cancel_crawl_job(job_id=tool_parameters['job_id']) + app = FirecrawlApp( + api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] + ) + operation = tool_parameters.get("operation", "get") + if operation == "get": + result = app.check_crawl_status(job_id=tool_parameters["job_id"]) + elif operation == "cancel": + result = app.cancel_crawl_job(job_id=tool_parameters["job_id"]) else: - raise ValueError(f'Invalid operation: {operation}') + raise ValueError(f"Invalid operation: {operation}") return self.create_json_message(result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py index 91412da548..962570bf73 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py @@ -6,34 +6,34 @@ from core.tools.tool.builtin_tool import BuiltinTool class ScrapeTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: """ the pageOptions and extractorOptions comes from doc here: https://docs.firecrawl.dev/api-reference/endpoint/scrape """ - app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], - base_url=self.runtime.credentials['base_url']) + app = FirecrawlApp( + api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] + ) pageOptions = {} extractorOptions = {} - pageOptions['headers'] = get_json_params(tool_parameters, 'headers') - pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False) - pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False) - pageOptions['onlyIncludeTags'] = get_array_params(tool_parameters, 'onlyIncludeTags') - pageOptions['removeTags'] = get_array_params(tool_parameters, 'removeTags') - pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False) - pageOptions['replaceAllPathsWithAbsolutePaths'] = tool_parameters.get('replaceAllPathsWithAbsolutePaths', False) - pageOptions['screenshot'] = tool_parameters.get('screenshot', False) - pageOptions['waitFor'] = tool_parameters.get('waitFor', 0) + pageOptions["headers"] = get_json_params(tool_parameters, "headers") + pageOptions["includeHtml"] = tool_parameters.get("includeHtml", False) + pageOptions["includeRawHtml"] = tool_parameters.get("includeRawHtml", False) + pageOptions["onlyIncludeTags"] = get_array_params(tool_parameters, "onlyIncludeTags") + pageOptions["removeTags"] = get_array_params(tool_parameters, "removeTags") + pageOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", False) + pageOptions["replaceAllPathsWithAbsolutePaths"] = tool_parameters.get("replaceAllPathsWithAbsolutePaths", False) + pageOptions["screenshot"] = tool_parameters.get("screenshot", False) + pageOptions["waitFor"] = tool_parameters.get("waitFor", 0) - extractorOptions['mode'] = tool_parameters.get('mode', '') - extractorOptions['extractionPrompt'] = tool_parameters.get('extractionPrompt', '') - extractorOptions['extractionSchema'] = get_json_params(tool_parameters, 'extractionSchema') + extractorOptions["mode"] = tool_parameters.get("mode", "") + extractorOptions["extractionPrompt"] = tool_parameters.get("extractionPrompt", "") + extractorOptions["extractionSchema"] = get_json_params(tool_parameters, "extractionSchema") - crawl_result = app.scrape_url(url=tool_parameters['url'], - pageOptions=pageOptions, - extractorOptions=extractorOptions) + crawl_result = app.scrape_url( + url=tool_parameters["url"], pageOptions=pageOptions, extractorOptions=extractorOptions + ) return self.create_json_message(crawl_result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/search.py b/api/core/tools/provider/builtin/firecrawl/tools/search.py index e2b2ac6b4d..f077e7d8ea 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/search.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/search.py @@ -11,18 +11,17 @@ class SearchTool(BuiltinTool): the pageOptions and searchOptions comes from doc here: https://docs.firecrawl.dev/api-reference/endpoint/search """ - app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], - base_url=self.runtime.credentials['base_url']) + app = FirecrawlApp( + api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] + ) pageOptions = {} - pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False) - pageOptions['fetchPageContent'] = tool_parameters.get('fetchPageContent', True) - pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False) - pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False) - searchOptions = {'limit': tool_parameters.get('limit')} + pageOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", False) + pageOptions["fetchPageContent"] = tool_parameters.get("fetchPageContent", True) + pageOptions["includeHtml"] = tool_parameters.get("includeHtml", False) + pageOptions["includeRawHtml"] = tool_parameters.get("includeRawHtml", False) + searchOptions = {"limit": tool_parameters.get("limit")} search_result = app.search( - query=tool_parameters['keyword'], - pageOptions=pageOptions, - searchOptions=searchOptions + query=tool_parameters["keyword"], pageOptions=pageOptions, searchOptions=searchOptions ) return self.create_json_message(search_result) diff --git a/api/core/tools/provider/builtin/gaode/gaode.py b/api/core/tools/provider/builtin/gaode/gaode.py index b55d93e07b..a3e50da001 100644 --- a/api/core/tools/provider/builtin/gaode/gaode.py +++ b/api/core/tools/provider/builtin/gaode/gaode.py @@ -9,17 +9,19 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class GaodeProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: - if 'api_key' not in credentials or not credentials.get('api_key'): + if "api_key" not in credentials or not credentials.get("api_key"): raise ToolProviderCredentialValidationError("Gaode API key is required.") try: - response = requests.get(url="https://restapi.amap.com/v3/geocode/geo?address={address}&key={apikey}" - "".format(address=urllib.parse.quote('广东省广州市天河区广州塔'), - apikey=credentials.get('api_key'))) - if response.status_code == 200 and (response.json()).get('info') == 'OK': + response = requests.get( + url="https://restapi.amap.com/v3/geocode/geo?address={address}&key={apikey}" "".format( + address=urllib.parse.quote("广东省广州市天河区广州塔"), apikey=credentials.get("api_key") + ) + ) + if response.status_code == 200 and (response.json()).get("info") == "OK": pass else: - raise ToolProviderCredentialValidationError((response.json()).get('info')) + raise ToolProviderCredentialValidationError((response.json()).get("info")) except Exception as e: raise ToolProviderCredentialValidationError("Gaode API Key is invalid. {}".format(e)) except Exception as e: diff --git a/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py b/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py index efd11cedce..843504eefd 100644 --- a/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py +++ b/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py @@ -8,50 +8,57 @@ from core.tools.tool.builtin_tool import BuiltinTool class GaodeRepositoriesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - city = tool_parameters.get('city', '') + city = tool_parameters.get("city", "") if not city: - return self.create_text_message('Please tell me your city') + return self.create_text_message("Please tell me your city") - if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): return self.create_text_message("Gaode API key is required.") try: s = requests.session() - api_domain = 'https://restapi.amap.com/v3' - city_response = s.request(method='GET', headers={"Content-Type": "application/json; charset=utf-8"}, - url="{url}/config/district?keywords={keywords}" - "&subdistrict=0&extensions=base&key={apikey}" - "".format(url=api_domain, keywords=city, - apikey=self.runtime.credentials.get('api_key'))) + api_domain = "https://restapi.amap.com/v3" + city_response = s.request( + method="GET", + headers={"Content-Type": "application/json; charset=utf-8"}, + url="{url}/config/district?keywords={keywords}" "&subdistrict=0&extensions=base&key={apikey}" "".format( + url=api_domain, keywords=city, apikey=self.runtime.credentials.get("api_key") + ), + ) City_data = city_response.json() - if city_response.status_code == 200 and City_data.get('info') == 'OK': - if len(City_data.get('districts')) > 0: - CityCode = City_data['districts'][0]['adcode'] - weatherInfo_response = s.request(method='GET', - url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json" - "".format(url=api_domain, citycode=CityCode, - apikey=self.runtime.credentials.get('api_key'))) + if city_response.status_code == 200 and City_data.get("info") == "OK": + if len(City_data.get("districts")) > 0: + CityCode = City_data["districts"][0]["adcode"] + weatherInfo_response = s.request( + method="GET", + url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json" + "".format(url=api_domain, citycode=CityCode, apikey=self.runtime.credentials.get("api_key")), + ) weatherInfo_data = weatherInfo_response.json() - if weatherInfo_response.status_code == 200 and weatherInfo_data.get('info') == 'OK': + if weatherInfo_response.status_code == 200 and weatherInfo_data.get("info") == "OK": contents = [] - if len(weatherInfo_data.get('forecasts')) > 0: - for item in weatherInfo_data['forecasts'][0]['casts']: + if len(weatherInfo_data.get("forecasts")) > 0: + for item in weatherInfo_data["forecasts"][0]["casts"]: content = {} - content['date'] = item.get('date') - content['week'] = item.get('week') - content['dayweather'] = item.get('dayweather') - content['daytemp_float'] = item.get('daytemp_float') - content['daywind'] = item.get('daywind') - content['nightweather'] = item.get('nightweather') - content['nighttemp_float'] = item.get('nighttemp_float') + content["date"] = item.get("date") + content["week"] = item.get("week") + content["dayweather"] = item.get("dayweather") + content["daytemp_float"] = item.get("daytemp_float") + content["daywind"] = item.get("daywind") + content["nightweather"] = item.get("nightweather") + content["nighttemp_float"] = item.get("nighttemp_float") contents.append(content) s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False))) + return self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False)) + ) s.close() - return self.create_text_message(f'No weather information for {city} was found.') + return self.create_text_message(f"No weather information for {city} was found.") except Exception as e: return self.create_text_message("Gaode API Key and Api Version is invalid. {}".format(e)) diff --git a/api/core/tools/provider/builtin/getimgai/getimgai.py b/api/core/tools/provider/builtin/getimgai/getimgai.py index c81d5fa333..bbd07d120f 100644 --- a/api/core/tools/provider/builtin/getimgai/getimgai.py +++ b/api/core/tools/provider/builtin/getimgai/getimgai.py @@ -7,16 +7,13 @@ class GetImgAIProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: # Example validation using the text2image tool - Text2ImageTool().fork_tool_runtime( - runtime={"credentials": credentials} - ).invoke( - user_id='', + Text2ImageTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke( + user_id="", tool_parameters={ "prompt": "A fire egg", "response_format": "url", "style": "photorealism", - } + }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/getimgai/getimgai_appx.py b/api/core/tools/provider/builtin/getimgai/getimgai_appx.py index e28c57649c..0e95a5f654 100644 --- a/api/core/tools/provider/builtin/getimgai/getimgai_appx.py +++ b/api/core/tools/provider/builtin/getimgai/getimgai_appx.py @@ -8,18 +8,16 @@ from requests.exceptions import HTTPError logger = logging.getLogger(__name__) + class GetImgAIApp: def __init__(self, api_key: str | None = None, base_url: str | None = None): self.api_key = api_key - self.base_url = base_url or 'https://api.getimg.ai/v1' + self.base_url = base_url or "https://api.getimg.ai/v1" if not self.api_key: raise ValueError("API key is required") def _prepare_headers(self): - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} return headers def _request( @@ -38,22 +36,20 @@ class GetImgAIApp: return response.json() except requests.exceptions.RequestException as e: if i < retries - 1 and isinstance(e, HTTPError) and e.response.status_code >= 500: - time.sleep(backoff_factor * (2 ** i)) + time.sleep(backoff_factor * (2**i)) else: raise return None - def text2image( - self, mode: str, **kwargs - ): - data = kwargs['params'] - if not data.get('prompt'): + def text2image(self, mode: str, **kwargs): + data = kwargs["params"] + if not data.get("prompt"): raise ValueError("Prompt is required") - endpoint = f'{self.base_url}/{mode}/text-to-image' + endpoint = f"{self.base_url}/{mode}/text-to-image" headers = self._prepare_headers() logger.debug(f"Send request to {endpoint=} body={data}") - response = self._request('POST', endpoint, data, headers) + response = self._request("POST", endpoint, data, headers) if response is None: raise HTTPError("Failed to initiate getimg.ai after multiple retries") return response diff --git a/api/core/tools/provider/builtin/getimgai/tools/text2image.py b/api/core/tools/provider/builtin/getimgai/tools/text2image.py index dad7314479..c556749552 100644 --- a/api/core/tools/provider/builtin/getimgai/tools/text2image.py +++ b/api/core/tools/provider/builtin/getimgai/tools/text2image.py @@ -7,28 +7,28 @@ from core.tools.tool.builtin_tool import BuiltinTool class Text2ImageTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - app = GetImgAIApp(api_key=self.runtime.credentials['getimg_api_key'], base_url=self.runtime.credentials['base_url']) + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + app = GetImgAIApp( + api_key=self.runtime.credentials["getimg_api_key"], base_url=self.runtime.credentials["base_url"] + ) options = { - 'style': tool_parameters.get('style'), - 'prompt': tool_parameters.get('prompt'), - 'aspect_ratio': tool_parameters.get('aspect_ratio'), - 'output_format': tool_parameters.get('output_format', 'jpeg'), - 'response_format': tool_parameters.get('response_format', 'url'), - 'width': tool_parameters.get('width'), - 'height': tool_parameters.get('height'), - 'steps': tool_parameters.get('steps'), - 'negative_prompt': tool_parameters.get('negative_prompt'), - 'prompt_2': tool_parameters.get('prompt_2'), + "style": tool_parameters.get("style"), + "prompt": tool_parameters.get("prompt"), + "aspect_ratio": tool_parameters.get("aspect_ratio"), + "output_format": tool_parameters.get("output_format", "jpeg"), + "response_format": tool_parameters.get("response_format", "url"), + "width": tool_parameters.get("width"), + "height": tool_parameters.get("height"), + "steps": tool_parameters.get("steps"), + "negative_prompt": tool_parameters.get("negative_prompt"), + "prompt_2": tool_parameters.get("prompt_2"), } options = {k: v for k, v in options.items() if v} - text2image_result = app.text2image( - mode=tool_parameters.get('mode', 'essential-v2'), - params=options, - wait=True - ) + text2image_result = app.text2image(mode=tool_parameters.get("mode", "essential-v2"), params=options, wait=True) if not isinstance(text2image_result, str): text2image_result = json.dumps(text2image_result, ensure_ascii=False, indent=4) diff --git a/api/core/tools/provider/builtin/github/github.py b/api/core/tools/provider/builtin/github/github.py index b19f0896f8..87a34ac3e8 100644 --- a/api/core/tools/provider/builtin/github/github.py +++ b/api/core/tools/provider/builtin/github/github.py @@ -7,25 +7,25 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class GithubProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: - if 'access_tokens' not in credentials or not credentials.get('access_tokens'): + if "access_tokens" not in credentials or not credentials.get("access_tokens"): raise ToolProviderCredentialValidationError("Github API Access Tokens is required.") - if 'api_version' not in credentials or not credentials.get('api_version'): - api_version = '2022-11-28' + if "api_version" not in credentials or not credentials.get("api_version"): + api_version = "2022-11-28" else: - api_version = credentials.get('api_version') + api_version = credentials.get("api_version") try: headers = { "Content-Type": "application/vnd.github+json", "Authorization": f"Bearer {credentials.get('access_tokens')}", - "X-GitHub-Api-Version": api_version + "X-GitHub-Api-Version": api_version, } response = requests.get( - url="https://api.github.com/search/users?q={account}".format(account='charli117'), - headers=headers) + url="https://api.github.com/search/users?q={account}".format(account="charli117"), headers=headers + ) if response.status_code != 200: - raise ToolProviderCredentialValidationError((response.json()).get('message')) + raise ToolProviderCredentialValidationError((response.json()).get("message")) except Exception as e: raise ToolProviderCredentialValidationError("Github API Key and Api Version is invalid. {}".format(e)) except Exception as e: diff --git a/api/core/tools/provider/builtin/github/tools/github_repositories.py b/api/core/tools/provider/builtin/github/tools/github_repositories.py index 305bf08ce8..3eab8bf8dc 100644 --- a/api/core/tools/provider/builtin/github/tools/github_repositories.py +++ b/api/core/tools/provider/builtin/github/tools/github_repositories.py @@ -10,53 +10,61 @@ from core.tools.tool.builtin_tool import BuiltinTool class GithubRepositoriesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - top_n = tool_parameters.get('top_n', 5) - query = tool_parameters.get('query', '') + top_n = tool_parameters.get("top_n", 5) + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input symbol') + return self.create_text_message("Please input symbol") - if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'): + if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"): return self.create_text_message("Github API Access Tokens is required.") - if 'api_version' not in self.runtime.credentials or not self.runtime.credentials.get('api_version'): - api_version = '2022-11-28' + if "api_version" not in self.runtime.credentials or not self.runtime.credentials.get("api_version"): + api_version = "2022-11-28" else: - api_version = self.runtime.credentials.get('api_version') + api_version = self.runtime.credentials.get("api_version") try: headers = { "Content-Type": "application/vnd.github+json", "Authorization": f"Bearer {self.runtime.credentials.get('access_tokens')}", - "X-GitHub-Api-Version": api_version + "X-GitHub-Api-Version": api_version, } s = requests.session() - api_domain = 'https://api.github.com' - response = s.request(method='GET', headers=headers, - url=f"{api_domain}/search/repositories?" - f"q={quote(query)}&sort=stars&per_page={top_n}&order=desc") + api_domain = "https://api.github.com" + response = s.request( + method="GET", + headers=headers, + url=f"{api_domain}/search/repositories?" f"q={quote(query)}&sort=stars&per_page={top_n}&order=desc", + ) response_data = response.json() - if response.status_code == 200 and isinstance(response_data.get('items'), list): + if response.status_code == 200 and isinstance(response_data.get("items"), list): contents = [] - if len(response_data.get('items')) > 0: - for item in response_data.get('items'): + if len(response_data.get("items")) > 0: + for item in response_data.get("items"): content = {} - updated_at_object = datetime.strptime(item['updated_at'], "%Y-%m-%dT%H:%M:%SZ") - content['owner'] = item['owner']['login'] - content['name'] = item['name'] - content['description'] = item['description'][:100] + '...' if len(item['description']) > 100 else item['description'] - content['url'] = item['html_url'] - content['star'] = item['watchers'] - content['forks'] = item['forks'] - content['updated'] = updated_at_object.strftime("%Y-%m-%d") + updated_at_object = datetime.strptime(item["updated_at"], "%Y-%m-%dT%H:%M:%SZ") + content["owner"] = item["owner"]["login"] + content["name"] = item["name"] + content["description"] = ( + item["description"][:100] + "..." if len(item["description"]) > 100 else item["description"] + ) + content["url"] = item["html_url"] + content["star"] = item["watchers"] + content["forks"] = item["forks"] + content["updated"] = updated_at_object.strftime("%Y-%m-%d") contents.append(content) s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False))) + return self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False)) + ) else: - return self.create_text_message(f'No items related to {query} were found.') + return self.create_text_message(f"No items related to {query} were found.") else: - return self.create_text_message((response.json()).get('message')) + return self.create_text_message((response.json()).get("message")) except Exception as e: return self.create_text_message("Github API Key and Api Version is invalid. {}".format(e)) diff --git a/api/core/tools/provider/builtin/gitlab/gitlab.py b/api/core/tools/provider/builtin/gitlab/gitlab.py index 0c13ec662a..9bd4a0bd52 100644 --- a/api/core/tools/provider/builtin/gitlab/gitlab.py +++ b/api/core/tools/provider/builtin/gitlab/gitlab.py @@ -9,13 +9,13 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class GitlabProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - if 'access_tokens' not in credentials or not credentials.get('access_tokens'): + if "access_tokens" not in credentials or not credentials.get("access_tokens"): raise ToolProviderCredentialValidationError("Gitlab Access Tokens is required.") - - if 'site_url' not in credentials or not credentials.get('site_url'): - site_url = 'https://gitlab.com' + + if "site_url" not in credentials or not credentials.get("site_url"): + site_url = "https://gitlab.com" else: - site_url = credentials.get('site_url') + site_url = credentials.get("site_url") try: headers = { @@ -23,12 +23,10 @@ class GitlabProvider(BuiltinToolProviderController): "Authorization": f"Bearer {credentials.get('access_tokens')}", } - response = requests.get( - url= f"{site_url}/api/v4/user", - headers=headers) + response = requests.get(url=f"{site_url}/api/v4/user", headers=headers) if response.status_code != 200: - raise ToolProviderCredentialValidationError((response.json()).get('message')) + raise ToolProviderCredentialValidationError((response.json()).get("message")) except Exception as e: raise ToolProviderCredentialValidationError("Gitlab Access Tokens is invalid. {}".format(e)) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py index 0824eb3a26..dceb37db49 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py @@ -9,39 +9,47 @@ from core.tools.tool.builtin_tool import BuiltinTool class GitlabCommitsTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - project = tool_parameters.get('project', '') - employee = tool_parameters.get('employee', '') - start_time = tool_parameters.get('start_time', '') - end_time = tool_parameters.get('end_time', '') - change_type = tool_parameters.get('change_type', 'all') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + project = tool_parameters.get("project", "") + employee = tool_parameters.get("employee", "") + start_time = tool_parameters.get("start_time", "") + end_time = tool_parameters.get("end_time", "") + change_type = tool_parameters.get("change_type", "all") if not project: - return self.create_text_message('Project is required') + return self.create_text_message("Project is required") if not start_time: start_time = (datetime.utcnow() - timedelta(days=1)).isoformat() if not end_time: end_time = datetime.utcnow().isoformat() - access_token = self.runtime.credentials.get('access_tokens') - site_url = self.runtime.credentials.get('site_url') + access_token = self.runtime.credentials.get("access_tokens") + site_url = self.runtime.credentials.get("site_url") - if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'): + if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"): return self.create_text_message("Gitlab API Access Tokens is required.") - if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'): - site_url = 'https://gitlab.com' - + if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"): + site_url = "https://gitlab.com" + # Get commit content result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time, change_type) return [self.create_json_message(item) for item in result] - - def fetch(self,user_id: str, site_url: str, access_token: str, project: str, employee: str = None, start_time: str = '', end_time: str = '', change_type: str = '') -> list[dict[str, Any]]: + + def fetch( + self, + user_id: str, + site_url: str, + access_token: str, + project: str, + employee: str = None, + start_time: str = "", + end_time: str = "", + change_type: str = "", + ) -> list[dict[str, Any]]: domain = site_url headers = {"PRIVATE-TOKEN": access_token} results = [] @@ -53,59 +61,66 @@ class GitlabCommitsTool(BuiltinTool): response.raise_for_status() projects = response.json() - filtered_projects = [p for p in projects if project == "*" or p['name'] == project] + filtered_projects = [p for p in projects if project == "*" or p["name"] == project] for project in filtered_projects: - project_id = project['id'] - project_name = project['name'] + project_id = project["id"] + project_name = project["name"] print(f"Project: {project_name}") # Get all of project commits commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits" - params = { - 'since': start_time, - 'until': end_time - } + params = {"since": start_time, "until": end_time} if employee: - params['author'] = employee + params["author"] = employee commits_response = requests.get(commits_url, headers=headers, params=params) commits_response.raise_for_status() commits = commits_response.json() for commit in commits: - commit_sha = commit['id'] - author_name = commit['author_name'] + commit_sha = commit["id"] + author_name = commit["author_name"] diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff" diff_response = requests.get(diff_url, headers=headers) diff_response.raise_for_status() diffs = diff_response.json() - + for diff in diffs: # Calculate code lines of changed - added_lines = diff['diff'].count('\n+') - removed_lines = diff['diff'].count('\n-') + added_lines = diff["diff"].count("\n+") + removed_lines = diff["diff"].count("\n-") total_changes = added_lines + removed_lines if change_type == "new": if added_lines > 1: - final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')]) - results.append({ - "commit_sha": commit_sha, - "author_name": author_name, - "diff": final_code - }) + final_code = "".join( + [ + line[1:] + for line in diff["diff"].split("\n") + if line.startswith("+") and not line.startswith("+++") + ] + ) + results.append( + {"commit_sha": commit_sha, "author_name": author_name, "diff": final_code} + ) else: if total_changes > 1: - final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if (line.startswith('+') or line.startswith('-')) and not line.startswith('+++') and not line.startswith('---')]) + final_code = "".join( + [ + line[1:] + for line in diff["diff"].split("\n") + if (line.startswith("+") or line.startswith("-")) + and not line.startswith("+++") + and not line.startswith("---") + ] + ) final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code - results.append({ - "commit_sha": commit_sha, - "author_name": author_name, - "diff": final_code_escaped - }) + results.append( + {"commit_sha": commit_sha, "author_name": author_name, "diff": final_code_escaped} + ) except requests.RequestException as e: print(f"Error fetching data from GitLab: {e}") - - return results \ No newline at end of file + + return results diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py index 7fa1d0d112..4a42b0fd73 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py @@ -7,32 +7,29 @@ from core.tools.tool.builtin_tool import BuiltinTool class GitlabFilesTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - project = tool_parameters.get('project', '') - branch = tool_parameters.get('branch', '') - path = tool_parameters.get('path', '') - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + project = tool_parameters.get("project", "") + branch = tool_parameters.get("branch", "") + path = tool_parameters.get("path", "") if not project: - return self.create_text_message('Project is required') + return self.create_text_message("Project is required") if not branch: - return self.create_text_message('Branch is required') + return self.create_text_message("Branch is required") if not path: - return self.create_text_message('Path is required') + return self.create_text_message("Path is required") - access_token = self.runtime.credentials.get('access_tokens') - site_url = self.runtime.credentials.get('site_url') + access_token = self.runtime.credentials.get("access_tokens") + site_url = self.runtime.credentials.get("site_url") - if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'): + if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"): return self.create_text_message("Gitlab API Access Tokens is required.") - if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'): - site_url = 'https://gitlab.com' - + if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"): + site_url = "https://gitlab.com" + # Get project ID from project name project_id = self.get_project_id(site_url, access_token, project) if not project_id: @@ -42,9 +39,9 @@ class GitlabFilesTool(BuiltinTool): result = self.fetch(user_id, project_id, site_url, access_token, branch, path) return [self.create_json_message(item) for item in result] - + def extract_project_name_and_path(self, path: str) -> tuple[str, str]: - parts = path.split('/', 1) + parts = path.split("/", 1) if len(parts) < 2: return None, None return parts[0], parts[1] @@ -57,13 +54,15 @@ class GitlabFilesTool(BuiltinTool): response.raise_for_status() projects = response.json() for project in projects: - if project['name'] == project_name: - return project['id'] + if project["name"] == project_name: + return project["id"] except requests.RequestException as e: print(f"Error fetching project ID from GitLab: {e}") return None - - def fetch(self,user_id: str, project_id: str, site_url: str, access_token: str, branch: str, path: str = None) -> list[dict[str, Any]]: + + def fetch( + self, user_id: str, project_id: str, site_url: str, access_token: str, branch: str, path: str = None + ) -> list[dict[str, Any]]: domain = site_url headers = {"PRIVATE-TOKEN": access_token} results = [] @@ -76,20 +75,16 @@ class GitlabFilesTool(BuiltinTool): items = response.json() for item in items: - item_path = item['path'] - if item['type'] == 'tree': # It's a directory + item_path = item["path"] + if item["type"] == "tree": # It's a directory results.extend(self.fetch(project_id, site_url, access_token, branch, item_path)) else: # It's a file file_url = f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}" file_response = requests.get(file_url, headers=headers) file_response.raise_for_status() file_content = file_response.text - results.append({ - "path": item_path, - "branch": branch, - "content": file_content - }) + results.append({"path": item_path, "branch": branch, "content": file_content}) except requests.RequestException as e: print(f"Error fetching data from GitLab: {e}") - - return results \ No newline at end of file + + return results diff --git a/api/core/tools/provider/builtin/google/google.py b/api/core/tools/provider/builtin/google/google.py index 8f4b9a4a4e..6b5395f9d3 100644 --- a/api/core/tools/provider/builtin/google/google.py +++ b/api/core/tools/provider/builtin/google/google.py @@ -13,12 +13,8 @@ class GoogleProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "query": "test", - "result_type": "link" - }, + user_id="", + tool_parameters={"query": "test", "result_type": "link"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/google/tools/google_search.py b/api/core/tools/provider/builtin/google/tools/google_search.py index 09d0326fb4..a9f65925d8 100644 --- a/api/core/tools/provider/builtin/google/tools/google_search.py +++ b/api/core/tools/provider/builtin/google/tools/google_search.py @@ -9,7 +9,6 @@ SERP_API_URL = "https://serpapi.com/search" class GoogleSearchTool(BuiltinTool): - def _parse_response(self, response: dict) -> dict: result = {} if "knowledge_graph" in response: @@ -17,25 +16,23 @@ class GoogleSearchTool(BuiltinTool): result["description"] = response["knowledge_graph"].get("description", "") if "organic_results" in response: result["organic_results"] = [ - { - "title": item.get("title", ""), - "link": item.get("link", ""), - "snippet": item.get("snippet", "") - } + {"title": item.get("title", ""), "link": item.get("link", ""), "snippet": item.get("snippet", "")} for item in response["organic_results"] ] return result - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: params = { - "api_key": self.runtime.credentials['serpapi_api_key'], - "q": tool_parameters['query'], + "api_key": self.runtime.credentials["serpapi_api_key"], + "q": tool_parameters["query"], "engine": "google", "google_domain": "google.com", "gl": "us", - "hl": "en" + "hl": "en", } response = requests.get(url=SERP_API_URL, params=params) response.raise_for_status() diff --git a/api/core/tools/provider/builtin/google_translate/google_translate.py b/api/core/tools/provider/builtin/google_translate/google_translate.py index f6e1d65834..ea53aa4eeb 100644 --- a/api/core/tools/provider/builtin/google_translate/google_translate.py +++ b/api/core/tools/provider/builtin/google_translate/google_translate.py @@ -8,10 +8,6 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class JsonExtractProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - GoogleTranslate().invoke(user_id='', - tool_parameters={ - "content": "这是一段测试文本", - "dest": "en" - }) + GoogleTranslate().invoke(user_id="", tool_parameters={"content": "这是一段测试文本", "dest": "en"}) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/google_translate/tools/translate.py b/api/core/tools/provider/builtin/google_translate/tools/translate.py index 4314182b06..5d57b5fabf 100644 --- a/api/core/tools/provider/builtin/google_translate/tools/translate.py +++ b/api/core/tools/provider/builtin/google_translate/tools/translate.py @@ -7,46 +7,40 @@ from core.tools.tool.builtin_tool import BuiltinTool class GoogleTranslate(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - dest = tool_parameters.get('dest', '') + dest = tool_parameters.get("dest", "") if not dest: - return self.create_text_message('Invalid parameter destination language') + return self.create_text_message("Invalid parameter destination language") try: result = self._translate(content, dest) return self.create_text_message(str(result)) except Exception: - return self.create_text_message('Translation service error, please check the network') + return self.create_text_message("Translation service error, please check the network") def _translate(self, content: str, dest: str) -> str: try: url = "https://translate.googleapis.com/translate_a/single" - params = { - "client": "gtx", - "sl": "auto", - "tl": dest, - "dt": "t", - "q": content - } + params = {"client": "gtx", "sl": "auto", "tl": dest, "dt": "t", "q": content} headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" } - response_json = requests.get( - url, params=params, headers=headers).json() + response_json = requests.get(url, params=params, headers=headers).json() result = response_json[0] - translated_text = ''.join([item[0] for item in result if item[0]]) + translated_text = "".join([item[0] for item in result if item[0]]) return str(translated_text) except Exception as e: return str(e) diff --git a/api/core/tools/provider/builtin/hap/hap.py b/api/core/tools/provider/builtin/hap/hap.py index e0a48e05a5..cbdf950465 100644 --- a/api/core/tools/provider/builtin/hap/hap.py +++ b/api/core/tools/provider/builtin/hap/hap.py @@ -5,4 +5,4 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class HapProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: - pass \ No newline at end of file + pass diff --git a/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py b/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py index 0e101dc67d..f2288ed81c 100644 --- a/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py +++ b/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py @@ -8,41 +8,40 @@ from core.tools.tool.builtin_tool import BuiltinTool class AddWorksheetRecordTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - record_data = tool_parameters.get('record_data', '') + return self.create_text_message("Invalid parameter Worksheet ID") + record_data = tool_parameters.get("record_data", "") if not record_data: - return self.create_text_message('Invalid parameter Record Row Data') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Record Row Data") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: host = f"{host[:-1] if host.endswith('/') else host}/api" url = f"{host}/v2/open/worksheet/addRow" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id} try: - payload['controls'] = json.loads(record_data) + payload["controls"] = json.loads(record_data) res = httpx.post(url, headers=headers, json=payload, timeout=60) res.raise_for_status() res_json = res.json() - if res_json.get('error_code') != 1: + if res_json.get("error_code") != 1: return self.create_text_message(f"Failed to add the new record. {res_json['error_msg']}") return self.create_text_message(f"New record added successfully. The record ID is {res_json['data']}.") except httpx.RequestError as e: diff --git a/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py b/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py index ba25952c9f..1df5f6d5cf 100644 --- a/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py +++ b/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py @@ -7,43 +7,42 @@ from core.tools.tool.builtin_tool import BuiltinTool class DeleteWorksheetRecordTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - row_id = tool_parameters.get('row_id', '') + return self.create_text_message("Invalid parameter Worksheet ID") + row_id = tool_parameters.get("row_id", "") if not row_id: - return self.create_text_message('Invalid parameter Record Row ID') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Record Row ID") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: host = f"{host[:-1] if host.endswith('/') else host}/api" url = f"{host}/v2/open/worksheet/deleteRow" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id, "rowId": row_id} try: res = httpx.post(url, headers=headers, json=payload, timeout=30) res.raise_for_status() res_json = res.json() - if res_json.get('error_code') != 1: + if res_json.get("error_code") != 1: return self.create_text_message(f"Failed to delete the record. {res_json['error_msg']}") return self.create_text_message("Successfully deleted the record.") except httpx.RequestError as e: return self.create_text_message(f"Failed to delete the record, request error: {e}") except Exception as e: - return self.create_text_message(f"Failed to delete the record, unexpected error: {e}") \ No newline at end of file + return self.create_text_message(f"Failed to delete the record, unexpected error: {e}") diff --git a/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py b/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py index 2c46d9dd4e..69cf8aa740 100644 --- a/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py +++ b/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py @@ -8,43 +8,42 @@ from core.tools.tool.builtin_tool import BuiltinTool class GetWorksheetFieldsTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Worksheet ID") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: host = f"{host[:-1] if host.endswith('/') else host}/api" url = f"{host}/v2/open/worksheet/getWorksheetInfo" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id} try: res = httpx.post(url, headers=headers, json=payload, timeout=60) res.raise_for_status() res_json = res.json() - if res_json.get('error_code') != 1: + if res_json.get("error_code") != 1: return self.create_text_message(f"Failed to get the worksheet information. {res_json['error_msg']}") - - fields_json, fields_table = self.get_controls(res_json['data']['controls']) - result_type = tool_parameters.get('result_type', 'table') + + fields_json, fields_table = self.get_controls(res_json["data"]["controls"]) + result_type = tool_parameters.get("result_type", "table") return self.create_text_message( - text=json.dumps(fields_json, ensure_ascii=False) if result_type == 'json' else fields_table + text=json.dumps(fields_json, ensure_ascii=False) if result_type == "json" else fields_table ) except httpx.RequestError as e: return self.create_text_message(f"Failed to get the worksheet information, request error: {e}") @@ -88,61 +87,65 @@ class GetWorksheetFieldsTool(BuiltinTool): 50: "Text", 51: "Query Record", } - return field_type_map.get(field_type_id, '') + return field_type_map.get(field_type_id, "") def get_controls(self, controls: list) -> dict: fields = [] - fields_list = ['|fieldId|fieldName|fieldType|fieldTypeId|description|options|','|'+'---|'*6] + fields_list = ["|fieldId|fieldName|fieldType|fieldTypeId|description|options|", "|" + "---|" * 6] for control in controls: - if control['type'] in self._get_ignore_types(): + if control["type"] in self._get_ignore_types(): continue - field_type_id = control['type'] - field_type = self.get_field_type_by_id(control['type']) + field_type_id = control["type"] + field_type = self.get_field_type_by_id(control["type"]) if field_type_id == 30: - source_type = control['sourceControl']['type'] + source_type = control["sourceControl"]["type"] if source_type in self._get_ignore_types(): continue else: field_type_id = source_type field_type = self.get_field_type_by_id(source_type) field = { - 'id': control['controlId'], - 'name': control['controlName'], - 'type': field_type, - 'typeId': field_type_id, - 'description': control['remark'].replace('\n', ' ').replace('\t', ' '), - 'options': self._extract_options(control), + "id": control["controlId"], + "name": control["controlName"], + "type": field_type, + "typeId": field_type_id, + "description": control["remark"].replace("\n", " ").replace("\t", " "), + "options": self._extract_options(control), } fields.append(field) - fields_list.append(f"|{field['id']}|{field['name']}|{field['type']}|{field['typeId']}|{field['description']}|{field['options'] if field['options'] else ''}|") + fields_list.append( + f"|{field['id']}|{field['name']}|{field['type']}|{field['typeId']}|{field['description']}|{field['options'] if field['options'] else ''}|" + ) - fields.append({ - 'id': 'ctime', - 'name': 'Created Time', - 'type': self.get_field_type_by_id(16), - 'typeId': 16, - 'description': '', - 'options': [] - }) + fields.append( + { + "id": "ctime", + "name": "Created Time", + "type": self.get_field_type_by_id(16), + "typeId": 16, + "description": "", + "options": [], + } + ) fields_list.append("|ctime|Created Time|Date|16|||") - return fields, '\n'.join(fields_list) + return fields, "\n".join(fields_list) def _extract_options(self, control: dict) -> list: options = [] - if control['type'] in [9, 10, 11]: - options.extend([{"key": opt['key'], "value": opt['value']} for opt in control.get('options', [])]) - elif control['type'] in [28, 36]: - itemnames = control['advancedSetting'].get('itemnames') - if itemnames and itemnames.startswith('[{'): + if control["type"] in [9, 10, 11]: + options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])]) + elif control["type"] in [28, 36]: + itemnames = control["advancedSetting"].get("itemnames") + if itemnames and itemnames.startswith("[{"): try: options = json.loads(itemnames) except json.JSONDecodeError: pass - elif control['type'] == 30: - source_type = control['sourceControl']['type'] + elif control["type"] == 30: + source_type = control["sourceControl"]["type"] if source_type not in self._get_ignore_types(): - options.extend([{"key": opt['key'], "value": opt['value']} for opt in control.get('options', [])]) + options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])]) return options - + def _get_ignore_types(self): - return {14, 21, 22, 34, 42, 43, 45, 47, 49, 10010} \ No newline at end of file + return {14, 21, 22, 34, 42, 43, 45, 47, 49, 10010} diff --git a/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py b/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py index 6bf1caa65e..6b831f3145 100644 --- a/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py +++ b/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py @@ -8,64 +8,66 @@ from core.tools.tool.builtin_tool import BuiltinTool class GetWorksheetPivotDataTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - x_column_fields = tool_parameters.get('x_column_fields', '') - if not x_column_fields or not x_column_fields.startswith('['): - return self.create_text_message('Invalid parameter Column Fields') - y_row_fields = tool_parameters.get('y_row_fields', '') - if y_row_fields and not y_row_fields.strip().startswith('['): - return self.create_text_message('Invalid parameter Row Fields') + return self.create_text_message("Invalid parameter Worksheet ID") + x_column_fields = tool_parameters.get("x_column_fields", "") + if not x_column_fields or not x_column_fields.startswith("["): + return self.create_text_message("Invalid parameter Column Fields") + y_row_fields = tool_parameters.get("y_row_fields", "") + if y_row_fields and not y_row_fields.strip().startswith("["): + return self.create_text_message("Invalid parameter Row Fields") elif not y_row_fields: - y_row_fields = '[]' - value_fields = tool_parameters.get('value_fields', '') - if not value_fields or not value_fields.strip().startswith('['): - return self.create_text_message('Invalid parameter Value Fields') - - host = tool_parameters.get('host', '') + y_row_fields = "[]" + value_fields = tool_parameters.get("value_fields", "") + if not value_fields or not value_fields.strip().startswith("["): + return self.create_text_message("Invalid parameter Value Fields") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: host = f"{host[:-1] if host.endswith('/') else host}/api" url = f"{host}/report/getPivotData" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id, "options": {"showTotal": True}} try: x_column_fields = json.loads(x_column_fields) - payload['columns'] = x_column_fields + payload["columns"] = x_column_fields y_row_fields = json.loads(y_row_fields) - if y_row_fields: payload['rows'] = y_row_fields + if y_row_fields: + payload["rows"] = y_row_fields value_fields = json.loads(value_fields) - payload['values'] = value_fields - sort_fields = tool_parameters.get('sort_fields', '') - if not sort_fields: sort_fields = '[]' + payload["values"] = value_fields + sort_fields = tool_parameters.get("sort_fields", "") + if not sort_fields: + sort_fields = "[]" sort_fields = json.loads(sort_fields) - if sort_fields: payload['options']['sort'] = sort_fields + if sort_fields: + payload["options"]["sort"] = sort_fields res = httpx.post(url, headers=headers, json=payload, timeout=60) res.raise_for_status() res_json = res.json() - if res_json.get('status') != 1: + if res_json.get("status") != 1: return self.create_text_message(f"Failed to get the worksheet pivot data. {res_json['msg']}") - - pivot_json = self.generate_pivot_json(res_json['data']) - pivot_table = self.generate_pivot_table(res_json['data']) - result_type = tool_parameters.get('result_type', '') - text = pivot_table if result_type == 'table' else json.dumps(pivot_json, ensure_ascii=False) + + pivot_json = self.generate_pivot_json(res_json["data"]) + pivot_table = self.generate_pivot_table(res_json["data"]) + result_type = tool_parameters.get("result_type", "") + text = pivot_table if result_type == "table" else json.dumps(pivot_json, ensure_ascii=False) return self.create_text_message(text) except httpx.RequestError as e: return self.create_text_message(f"Failed to get the worksheet pivot data, request error: {e}") @@ -75,27 +77,31 @@ class GetWorksheetPivotDataTool(BuiltinTool): return self.create_text_message(f"Failed to get the worksheet pivot data, unexpected error: {e}") def generate_pivot_table(self, data: dict[str, Any]) -> str: - columns = data['metadata']['columns'] - rows = data['metadata']['rows'] - values = data['metadata']['values'] + columns = data["metadata"]["columns"] + rows = data["metadata"]["rows"] + values = data["metadata"]["values"] - rows_data = data['data'] + rows_data = data["data"] - header = ([row['displayName'] for row in rows] if rows else []) + [column['displayName'] for column in columns] + [value['displayName'] for value in values] - line = (['---'] * len(rows) if rows else []) + ['---'] * len(columns) + ['--:'] * len(values) + header = ( + ([row["displayName"] for row in rows] if rows else []) + + [column["displayName"] for column in columns] + + [value["displayName"] for value in values] + ) + line = (["---"] * len(rows) if rows else []) + ["---"] * len(columns) + ["--:"] * len(values) table = [header, line] for row in rows_data: - row_data = [self.replace_pipe(row['rows'][r['controlId']]) for r in rows] if rows else [] - row_data.extend([self.replace_pipe(row['columns'][column['controlId']]) for column in columns]) - row_data.extend([self.replace_pipe(str(row['values'][value['controlId']])) for value in values]) + row_data = [self.replace_pipe(row["rows"][r["controlId"]]) for r in rows] if rows else [] + row_data.extend([self.replace_pipe(row["columns"][column["controlId"]]) for column in columns]) + row_data.extend([self.replace_pipe(str(row["values"][value["controlId"]])) for value in values]) table.append(row_data) - return '\n'.join([('|'+'|'.join(row) +'|') for row in table]) - + return "\n".join([("|" + "|".join(row) + "|") for row in table]) + def replace_pipe(self, text: str) -> str: - return text.replace('|', '▏').replace('\n', ' ') - + return text.replace("|", "▏").replace("\n", " ") + def generate_pivot_json(self, data: dict[str, Any]) -> dict: fields = { "x-axis": [ @@ -103,13 +109,14 @@ class GetWorksheetPivotDataTool(BuiltinTool): for column in data["metadata"]["columns"] ], "y-axis": [ - {"fieldId": row["controlId"], "fieldName": row["displayName"]} - for row in data["metadata"]["rows"] - ] if data["metadata"]["rows"] else [], + {"fieldId": row["controlId"], "fieldName": row["displayName"]} for row in data["metadata"]["rows"] + ] + if data["metadata"]["rows"] + else [], "values": [ {"fieldId": value["controlId"], "fieldName": value["displayName"]} for value in data["metadata"]["values"] - ] + ], } # fields = ([ # {"fieldId": row["controlId"], "fieldName": row["displayName"]} @@ -127,4 +134,4 @@ class GetWorksheetPivotDataTool(BuiltinTool): row_data.update(row["columns"]) row_data.update(row["values"]) rows.append(row_data) - return {"fields": fields, "rows": rows, "summary": data["metadata"]["totalRow"]} \ No newline at end of file + return {"fields": fields, "rows": rows, "summary": data["metadata"]["totalRow"]} diff --git a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py index dddc041cc1..7e9f70f8e5 100644 --- a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py +++ b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py @@ -9,152 +9,173 @@ from core.tools.tool.builtin_tool import BuiltinTool class ListWorksheetRecordsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') + return self.create_text_message("Invalid parameter App Key") - sign = tool_parameters.get('sign', '') + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') + return self.create_text_message("Invalid parameter Sign") - worksheet_id = tool_parameters.get('worksheet_id', '') + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') + return self.create_text_message("Invalid parameter Worksheet ID") - host = tool_parameters.get('host', '') + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not (host.startswith("http://") or host.startswith("https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: host = f"{host[:-1] if host.endswith('/') else host}/api" - + url_fields = f"{host}/v2/open/worksheet/getWorksheetInfo" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id} - field_ids = tool_parameters.get('field_ids', '') + field_ids = tool_parameters.get("field_ids", "") try: res = httpx.post(url_fields, headers=headers, json=payload, timeout=30) res_json = res.json() if res.is_success: - if res_json['error_code'] != 1: - return self.create_text_message("Failed to get the worksheet information. {}".format(res_json['error_msg'])) + if res_json["error_code"] != 1: + return self.create_text_message( + "Failed to get the worksheet information. {}".format(res_json["error_msg"]) + ) else: - worksheet_name = res_json['data']['name'] - fields, schema, table_header = self.get_schema(res_json['data']['controls'], field_ids) + worksheet_name = res_json["data"]["name"] + fields, schema, table_header = self.get_schema(res_json["data"]["controls"], field_ids) else: return self.create_text_message( - f"Failed to get the worksheet information, status code: {res.status_code}, response: {res.text}") + f"Failed to get the worksheet information, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: - return self.create_text_message("Failed to get the worksheet information, something went wrong: {}".format(e)) + return self.create_text_message( + "Failed to get the worksheet information, something went wrong: {}".format(e) + ) if field_ids: - payload['controls'] = [v.strip() for v in field_ids.split(',')] if field_ids else [] - filters = tool_parameters.get('filters', '') + payload["controls"] = [v.strip() for v in field_ids.split(",")] if field_ids else [] + filters = tool_parameters.get("filters", "") if filters: - payload['filters'] = json.loads(filters) - sort_id = tool_parameters.get('sort_id', '') - sort_is_asc = tool_parameters.get('sort_is_asc', False) + payload["filters"] = json.loads(filters) + sort_id = tool_parameters.get("sort_id", "") + sort_is_asc = tool_parameters.get("sort_is_asc", False) if sort_id: - payload['sortId'] = sort_id - payload['isAsc'] = sort_is_asc - limit = tool_parameters.get('limit', 50) - payload['pageSize'] = limit - page_index = tool_parameters.get('page_index', 1) - payload['pageIndex'] = page_index - payload['useControlId'] = True - payload['listType'] = 1 + payload["sortId"] = sort_id + payload["isAsc"] = sort_is_asc + limit = tool_parameters.get("limit", 50) + payload["pageSize"] = limit + page_index = tool_parameters.get("page_index", 1) + payload["pageIndex"] = page_index + payload["useControlId"] = True + payload["listType"] = 1 url = f"{host}/v2/open/worksheet/getFilterRows" try: res = httpx.post(url, headers=headers, json=payload, timeout=90) res_json = res.json() if res.is_success: - if res_json['error_code'] != 1: - return self.create_text_message("Failed to get the records. {}".format(res_json['error_msg'])) + if res_json["error_code"] != 1: + return self.create_text_message("Failed to get the records. {}".format(res_json["error_msg"])) else: result = { "fields": fields, "rows": [], "total": res_json.get("data", {}).get("total"), - "payload": {key: payload[key] for key in ['worksheetId', 'controls', 'filters', 'sortId', 'isAsc', 'pageSize', 'pageIndex'] if key in payload} + "payload": { + key: payload[key] + for key in [ + "worksheetId", + "controls", + "filters", + "sortId", + "isAsc", + "pageSize", + "pageIndex", + ] + if key in payload + }, } rows = res_json.get("data", {}).get("rows", []) - result_type = tool_parameters.get('result_type', '') - if not result_type: result_type = 'table' - if result_type == 'json': + result_type = tool_parameters.get("result_type", "") + if not result_type: + result_type = "table" + if result_type == "json": for row in rows: - result['rows'].append(self.get_row_field_value(row, schema)) + result["rows"].append(self.get_row_field_value(row, schema)) return self.create_text_message(json.dumps(result, ensure_ascii=False)) else: result_text = f"Found {result['total']} rows in worksheet \"{worksheet_name}\"." - if result['total'] > 0: + if result["total"] > 0: result_text += f" The following are {result['total'] if result['total'] < limit else limit} pieces of data presented in a table format:\n\n{table_header}" for row in rows: result_values = [] for f in fields: - result_values.append(self.handle_value_type(row[f['fieldId']], schema[f['fieldId']])) - result_text += '\n|'+'|'.join(result_values)+'|' + result_values.append( + self.handle_value_type(row[f["fieldId"]], schema[f["fieldId"]]) + ) + result_text += "\n|" + "|".join(result_values) + "|" return self.create_text_message(result_text) else: return self.create_text_message( - f"Failed to get the records, status code: {res.status_code}, response: {res.text}") + f"Failed to get the records, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to get the records, something went wrong: {}".format(e)) - def get_row_field_value(self, row: dict, schema: dict): row_value = {"rowid": row["rowid"]} for field in schema: row_value[field] = self.handle_value_type(row[field], schema[field]) return row_value - - def get_schema(self, controls: list, fieldids: str): - allow_fields = {v.strip() for v in fieldids.split(',')} if fieldids else set() + def get_schema(self, controls: list, fieldids: str): + allow_fields = {v.strip() for v in fieldids.split(",")} if fieldids else set() fields = [] schema = {} field_names = [] for control in controls: control_type_id = self.get_real_type_id(control) - if (control_type_id in self._get_ignore_types()) or (allow_fields and not control['controlId'] in allow_fields): + if (control_type_id in self._get_ignore_types()) or ( + allow_fields and not control["controlId"] in allow_fields + ): continue else: - fields.append({'fieldId': control['controlId'], 'fieldName': control['controlName']}) - schema[control['controlId']] = {'typeId': control_type_id, 'options': self.set_option(control)} - field_names.append(control['controlName']) - if (not allow_fields or ('ctime' in allow_fields)): - fields.append({'fieldId': 'ctime', 'fieldName': 'Created Time'}) - schema['ctime'] = {'typeId': 16, 'options': {}} + fields.append({"fieldId": control["controlId"], "fieldName": control["controlName"]}) + schema[control["controlId"]] = {"typeId": control_type_id, "options": self.set_option(control)} + field_names.append(control["controlName"]) + if not allow_fields or ("ctime" in allow_fields): + fields.append({"fieldId": "ctime", "fieldName": "Created Time"}) + schema["ctime"] = {"typeId": 16, "options": {}} field_names.append("Created Time") - fields.append({'fieldId':'rowid', 'fieldName': 'Record Row ID'}) - schema['rowid'] = {'typeId': 2, 'options': {}} + fields.append({"fieldId": "rowid", "fieldName": "Record Row ID"}) + schema["rowid"] = {"typeId": 2, "options": {}} field_names.append("Record Row ID") - return fields, schema, '|'+'|'.join(field_names)+'|\n|'+'---|'*len(field_names) - + return fields, schema, "|" + "|".join(field_names) + "|\n|" + "---|" * len(field_names) + def get_real_type_id(self, control: dict) -> int: - return control['sourceControlType'] if control['type'] == 30 else control['type'] - + return control["sourceControlType"] if control["type"] == 30 else control["type"] + def set_option(self, control: dict) -> dict: options = {} - if control.get('options'): - options = {option['key']: option['value'] for option in control['options']} - elif control.get('advancedSetting', {}).get('itemnames'): + if control.get("options"): + options = {option["key"]: option["value"] for option in control["options"]} + elif control.get("advancedSetting", {}).get("itemnames"): try: - itemnames = json.loads(control['advancedSetting']['itemnames']) - options = {item['key']: item['value'] for item in itemnames} + itemnames = json.loads(control["advancedSetting"]["itemnames"]) + options = {item["key"]: item["value"] for item in itemnames} except json.JSONDecodeError: pass return options def _get_ignore_types(self): return {14, 21, 22, 34, 42, 43, 45, 47, 49, 10010} - + def handle_value_type(self, value, field): type_id = field.get("typeId") if type_id == 10: @@ -167,33 +188,33 @@ class ListWorksheetRecordsTool(BuiltinTool): value = self.parse_cascade_or_associated(field, value) elif type_id == 40: value = self.parse_location(value) - return self.rich_text_to_plain_text(value) if value else '' + return self.rich_text_to_plain_text(value) if value else "" def process_value(self, value): if isinstance(value, str): - if value.startswith("[{\"accountId\""): + if value.startswith('[{"accountId"'): value = json.loads(value) - value = ', '.join([item['fullname'] for item in value]) - elif value.startswith("[{\"departmentId\""): + value = ", ".join([item["fullname"] for item in value]) + elif value.startswith('[{"departmentId"'): value = json.loads(value) - value = '、'.join([item['departmentName'] for item in value]) - elif value.startswith("[{\"organizeId\""): + value = "、".join([item["departmentName"] for item in value]) + elif value.startswith('[{"organizeId"'): value = json.loads(value) - value = '、'.join([item['organizeName'] for item in value]) - elif value.startswith("[{\"file_id\""): - value = '' - elif value == '[]': - value = '' - elif hasattr(value, 'accountId'): - value = value['fullname'] + value = "、".join([item["organizeName"] for item in value]) + elif value.startswith('[{"file_id"'): + value = "" + elif value == "[]": + value = "" + elif hasattr(value, "accountId"): + value = value["fullname"] return value def parse_cascade_or_associated(self, field, value): - if (field['typeId'] == 35 and value.startswith('[')) or (field['typeId'] == 29 and value.startswith('[{')): + if (field["typeId"] == 35 and value.startswith("[")) or (field["typeId"] == 29 and value.startswith("[{")): value = json.loads(value) - value = value[0]['name'] if len(value) > 0 else '' + value = value[0]["name"] if len(value) > 0 else "" else: - value = '' + value = "" return value def parse_location(self, value): @@ -205,5 +226,5 @@ class ListWorksheetRecordsTool(BuiltinTool): return value def rich_text_to_plain_text(self, rich_text): - text = re.sub(r'<[^>]+>', '', rich_text) if '<' in rich_text else rich_text - return text.replace("|", "▏").replace("\n", " ") \ No newline at end of file + text = re.sub(r"<[^>]+>", "", rich_text) if "<" in rich_text else rich_text + return text.replace("|", "▏").replace("\n", " ") diff --git a/api/core/tools/provider/builtin/hap/tools/list_worksheets.py b/api/core/tools/provider/builtin/hap/tools/list_worksheets.py index 960cbd10ac..b4193f00bf 100644 --- a/api/core/tools/provider/builtin/hap/tools/list_worksheets.py +++ b/api/core/tools/provider/builtin/hap/tools/list_worksheets.py @@ -8,75 +8,76 @@ from core.tools.tool.builtin_tool import BuiltinTool class ListWorksheetsTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Sign") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not (host.startswith("http://") or host.startswith("https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: host = f"{host[:-1] if host.endswith('/') else host}/api" url = f"{host}/v1/open/app/get" - result_type = tool_parameters.get('result_type', '') + result_type = tool_parameters.get("result_type", "") if not result_type: - result_type = 'table' + result_type = "table" - headers = { 'Content-Type': 'application/json' } - params = { "appKey": appkey, "sign": sign, } + headers = {"Content-Type": "application/json"} + params = { + "appKey": appkey, + "sign": sign, + } try: res = httpx.get(url, headers=headers, params=params, timeout=30) res_json = res.json() if res.is_success: - if res_json['error_code'] != 1: - return self.create_text_message("Failed to access the application. {}".format(res_json['error_msg'])) + if res_json["error_code"] != 1: + return self.create_text_message( + "Failed to access the application. {}".format(res_json["error_msg"]) + ) else: - if result_type == 'json': + if result_type == "json": worksheets = [] - for section in res_json['data']['sections']: + for section in res_json["data"]["sections"]: worksheets.extend(self._extract_worksheets(section, result_type)) return self.create_text_message(text=json.dumps(worksheets, ensure_ascii=False)) else: - worksheets = '|worksheetId|worksheetName|description|\n|---|---|---|' - for section in res_json['data']['sections']: + worksheets = "|worksheetId|worksheetName|description|\n|---|---|---|" + for section in res_json["data"]["sections"]: worksheets += self._extract_worksheets(section, result_type) return self.create_text_message(worksheets) else: return self.create_text_message( - f"Failed to list worksheets, status code: {res.status_code}, response: {res.text}") + f"Failed to list worksheets, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to list worksheets, something went wrong: {}".format(e)) def _extract_worksheets(self, section, type): items = [] - tables = '' - for item in section.get('items', []): - if item.get('type') == 0 and (not 'notes' in item or item.get('notes') != 'NO'): - if type == 'json': - filtered_item = { - 'id': item['id'], - 'name': item['name'], - 'notes': item.get('notes', '') - } + tables = "" + for item in section.get("items", []): + if item.get("type") == 0 and (not "notes" in item or item.get("notes") != "NO"): + if type == "json": + filtered_item = {"id": item["id"], "name": item["name"], "notes": item.get("notes", "")} items.append(filtered_item) else: tables += f"\n|{item['id']}|{item['name']}|{item.get('notes', '')}|" - for child_section in section.get('childSections', []): - if type == 'json': - items.extend(self._extract_worksheets(child_section, 'json')) + for child_section in section.get("childSections", []): + if type == "json": + items.extend(self._extract_worksheets(child_section, "json")) else: - tables += self._extract_worksheets(child_section, 'table') - - return items if type == 'json' else tables \ No newline at end of file + tables += self._extract_worksheets(child_section, "table") + + return items if type == "json" else tables diff --git a/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py b/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py index 6ca1b98d90..32abb18f9a 100644 --- a/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py +++ b/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py @@ -8,44 +8,43 @@ from core.tools.tool.builtin_tool import BuiltinTool class UpdateWorksheetRecordTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - row_id = tool_parameters.get('row_id', '') + return self.create_text_message("Invalid parameter Worksheet ID") + row_id = tool_parameters.get("row_id", "") if not row_id: - return self.create_text_message('Invalid parameter Record Row ID') - record_data = tool_parameters.get('record_data', '') + return self.create_text_message("Invalid parameter Record Row ID") + record_data = tool_parameters.get("record_data", "") if not record_data: - return self.create_text_message('Invalid parameter Record Row Data') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Record Row Data") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: host = f"{host[:-1] if host.endswith('/') else host}/api" url = f"{host}/v2/open/worksheet/editRow" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id, "rowId": row_id} try: - payload['controls'] = json.loads(record_data) + payload["controls"] = json.loads(record_data) res = httpx.post(url, headers=headers, json=payload, timeout=60) res.raise_for_status() res_json = res.json() - if res_json.get('error_code') != 1: + if res_json.get("error_code") != 1: return self.create_text_message(f"Failed to update the record. {res_json['error_msg']}") return self.create_text_message("Record updated successfully.") except httpx.RequestError as e: diff --git a/api/core/tools/provider/builtin/jina/jina.py b/api/core/tools/provider/builtin/jina/jina.py index 12e5058cdc..154e15db01 100644 --- a/api/core/tools/provider/builtin/jina/jina.py +++ b/api/core/tools/provider/builtin/jina/jina.py @@ -10,27 +10,29 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class GoogleProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - if credentials['api_key'] is None: - credentials['api_key'] = '' + if credentials["api_key"] is None: + credentials["api_key"] = "" else: - result = JinaReaderTool().fork_tool_runtime( - runtime={ - "credentials": credentials, - } - ).invoke( - user_id='', - tool_parameters={ - "url": "https://example.com", - }, - )[0] + result = ( + JinaReaderTool() + .fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ) + .invoke( + user_id="", + tool_parameters={ + "url": "https://example.com", + }, + )[0] + ) message = json.loads(result.message) - if message['code'] != 200: - raise ToolProviderCredentialValidationError(message['message']) + if message["code"] != 200: + raise ToolProviderCredentialValidationError(message["message"]) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - + def _get_tool_labels(self) -> list[ToolLabelEnum]: - return [ - ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY - ] \ No newline at end of file + return [ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY] diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.py b/api/core/tools/provider/builtin/jina/tools/jina_reader.py index cee46cee23..0dd55c6529 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_reader.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.py @@ -9,26 +9,25 @@ from core.tools.tool.builtin_tool import BuiltinTool class JinaReaderTool(BuiltinTool): - _jina_reader_endpoint = 'https://r.jina.ai/' + _jina_reader_endpoint = "https://r.jina.ai/" - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - url = tool_parameters['url'] + url = tool_parameters["url"] - headers = { - 'Accept': 'application/json' - } + headers = {"Accept": "application/json"} - if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'): - headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key') + if "api_key" in self.runtime.credentials and self.runtime.credentials.get("api_key"): + headers["Authorization"] = "Bearer " + self.runtime.credentials.get("api_key") - request_params = tool_parameters.get('request_params') - if request_params is not None and request_params != '': + request_params = tool_parameters.get("request_params") + if request_params is not None and request_params != "": try: request_params = json.loads(request_params) if not isinstance(request_params, dict): @@ -36,40 +35,40 @@ class JinaReaderTool(BuiltinTool): except (json.JSONDecodeError, ValueError) as e: raise ValueError(f"Invalid request_params: {e}") - target_selector = tool_parameters.get('target_selector') - if target_selector is not None and target_selector != '': - headers['X-Target-Selector'] = target_selector + target_selector = tool_parameters.get("target_selector") + if target_selector is not None and target_selector != "": + headers["X-Target-Selector"] = target_selector - wait_for_selector = tool_parameters.get('wait_for_selector') - if wait_for_selector is not None and wait_for_selector != '': - headers['X-Wait-For-Selector'] = wait_for_selector + wait_for_selector = tool_parameters.get("wait_for_selector") + if wait_for_selector is not None and wait_for_selector != "": + headers["X-Wait-For-Selector"] = wait_for_selector - if tool_parameters.get('image_caption', False): - headers['X-With-Generated-Alt'] = 'true' + if tool_parameters.get("image_caption", False): + headers["X-With-Generated-Alt"] = "true" - if tool_parameters.get('gather_all_links_at_the_end', False): - headers['X-With-Links-Summary'] = 'true' + if tool_parameters.get("gather_all_links_at_the_end", False): + headers["X-With-Links-Summary"] = "true" - if tool_parameters.get('gather_all_images_at_the_end', False): - headers['X-With-Images-Summary'] = 'true' + if tool_parameters.get("gather_all_images_at_the_end", False): + headers["X-With-Images-Summary"] = "true" - proxy_server = tool_parameters.get('proxy_server') - if proxy_server is not None and proxy_server != '': - headers['X-Proxy-Url'] = proxy_server + proxy_server = tool_parameters.get("proxy_server") + if proxy_server is not None and proxy_server != "": + headers["X-Proxy-Url"] = proxy_server - if tool_parameters.get('no_cache', False): - headers['X-No-Cache'] = 'true' + if tool_parameters.get("no_cache", False): + headers["X-No-Cache"] = "true" - max_retries = tool_parameters.get('max_retries', 3) + max_retries = tool_parameters.get("max_retries", 3) response = ssrf_proxy.get( str(URL(self._jina_reader_endpoint + url)), headers=headers, params=request_params, timeout=(10, 60), - max_retries=max_retries + max_retries=max_retries, ) - if tool_parameters.get('summary', False): + if tool_parameters.get("summary", False): return self.create_text_message(self.summary(user_id, response.text)) return self.create_text_message(response.text) diff --git a/api/core/tools/provider/builtin/jina/tools/jina_search.py b/api/core/tools/provider/builtin/jina/tools/jina_search.py index d4a81cd096..30af6de783 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_search.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_search.py @@ -8,44 +8,39 @@ from core.tools.tool.builtin_tool import BuiltinTool class JinaSearchTool(BuiltinTool): - _jina_search_endpoint = 'https://s.jina.ai/' + _jina_search_endpoint = "https://s.jina.ai/" def _invoke( self, user_id: str, tool_parameters: dict[str, Any], ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - query = tool_parameters['query'] + query = tool_parameters["query"] - headers = { - 'Accept': 'application/json' - } + headers = {"Accept": "application/json"} - if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'): - headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key') + if "api_key" in self.runtime.credentials and self.runtime.credentials.get("api_key"): + headers["Authorization"] = "Bearer " + self.runtime.credentials.get("api_key") - if tool_parameters.get('image_caption', False): - headers['X-With-Generated-Alt'] = 'true' + if tool_parameters.get("image_caption", False): + headers["X-With-Generated-Alt"] = "true" - if tool_parameters.get('gather_all_links_at_the_end', False): - headers['X-With-Links-Summary'] = 'true' + if tool_parameters.get("gather_all_links_at_the_end", False): + headers["X-With-Links-Summary"] = "true" - if tool_parameters.get('gather_all_images_at_the_end', False): - headers['X-With-Images-Summary'] = 'true' + if tool_parameters.get("gather_all_images_at_the_end", False): + headers["X-With-Images-Summary"] = "true" - proxy_server = tool_parameters.get('proxy_server') - if proxy_server is not None and proxy_server != '': - headers['X-Proxy-Url'] = proxy_server + proxy_server = tool_parameters.get("proxy_server") + if proxy_server is not None and proxy_server != "": + headers["X-Proxy-Url"] = proxy_server - if tool_parameters.get('no_cache', False): - headers['X-No-Cache'] = 'true' + if tool_parameters.get("no_cache", False): + headers["X-No-Cache"] = "true" - max_retries = tool_parameters.get('max_retries', 3) + max_retries = tool_parameters.get("max_retries", 3) response = ssrf_proxy.get( - str(URL(self._jina_search_endpoint + query)), - headers=headers, - timeout=(10, 60), - max_retries=max_retries + str(URL(self._jina_search_endpoint + query)), headers=headers, timeout=(10, 60), max_retries=max_retries ) return self.create_text_message(response.text) diff --git a/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py b/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py index 0d018e3ca2..06dabcc9c2 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py @@ -6,33 +6,29 @@ from core.tools.tool.builtin_tool import BuiltinTool class JinaTokenizerTool(BuiltinTool): - _jina_tokenizer_endpoint = 'https://tokenize.jina.ai/' + _jina_tokenizer_endpoint = "https://tokenize.jina.ai/" def _invoke( self, user_id: str, tool_parameters: dict[str, Any], ) -> ToolInvokeMessage: - content = tool_parameters['content'] - body = { - "content": content - } + content = tool_parameters["content"] + body = {"content": content} - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'): - headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key') + if "api_key" in self.runtime.credentials and self.runtime.credentials.get("api_key"): + headers["Authorization"] = "Bearer " + self.runtime.credentials.get("api_key") - if tool_parameters.get('return_chunks', False): - body['return_chunks'] = True - - if tool_parameters.get('return_tokens', False): - body['return_tokens'] = True - - if tokenizer := tool_parameters.get('tokenizer'): - body['tokenizer'] = tokenizer + if tool_parameters.get("return_chunks", False): + body["return_chunks"] = True + + if tool_parameters.get("return_tokens", False): + body["return_tokens"] = True + + if tokenizer := tool_parameters.get("tokenizer"): + body["tokenizer"] = tokenizer response = ssrf_proxy.post( self._jina_tokenizer_endpoint, diff --git a/api/core/tools/provider/builtin/json_process/json_process.py b/api/core/tools/provider/builtin/json_process/json_process.py index f6eed3c628..10746210b5 100644 --- a/api/core/tools/provider/builtin/json_process/json_process.py +++ b/api/core/tools/provider/builtin/json_process/json_process.py @@ -8,10 +8,9 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class JsonExtractProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - JSONParseTool().invoke(user_id='', - tool_parameters={ - 'content': '{"name": "John", "age": 30, "city": "New York"}', - 'json_filter': '$.name' - }) + JSONParseTool().invoke( + user_id="", + tool_parameters={"content": '{"name": "John", "age": 30, "city": "New York"}', "json_filter": "$.name"}, + ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/json_process/tools/delete.py b/api/core/tools/provider/builtin/json_process/tools/delete.py index 1b49cfe2f3..fcab3d71a9 100644 --- a/api/core/tools/provider/builtin/json_process/tools/delete.py +++ b/api/core/tools/provider/builtin/json_process/tools/delete.py @@ -8,34 +8,35 @@ from core.tools.tool.builtin_tool import BuiltinTool class JSONDeleteTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the JSON delete tool """ # Get content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # Get query - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Invalid parameter query') + return self.create_text_message("Invalid parameter query") - ensure_ascii = tool_parameters.get('ensure_ascii', True) + ensure_ascii = tool_parameters.get("ensure_ascii", True) try: result = self._delete(content, query, ensure_ascii) return self.create_text_message(str(result)) except Exception as e: - return self.create_text_message(f'Failed to delete JSON content: {str(e)}') + return self.create_text_message(f"Failed to delete JSON content: {str(e)}") def _delete(self, origin_json: str, query: str, ensure_ascii: bool) -> str: try: input_data = json.loads(origin_json) - expr = parse('$.' + query.lstrip('$.')) # Ensure query path starts with $ + expr = parse("$." + query.lstrip("$.")) # Ensure query path starts with $ matches = expr.find(input_data) diff --git a/api/core/tools/provider/builtin/json_process/tools/insert.py b/api/core/tools/provider/builtin/json_process/tools/insert.py index 48d1bdcab4..793c74e5f9 100644 --- a/api/core/tools/provider/builtin/json_process/tools/insert.py +++ b/api/core/tools/provider/builtin/json_process/tools/insert.py @@ -8,46 +8,49 @@ from core.tools.tool.builtin_tool import BuiltinTool class JSONParseTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # get query - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Invalid parameter query') + return self.create_text_message("Invalid parameter query") # get new value - new_value = tool_parameters.get('new_value', '') + new_value = tool_parameters.get("new_value", "") if not new_value: - return self.create_text_message('Invalid parameter new_value') + return self.create_text_message("Invalid parameter new_value") # get insert position - index = tool_parameters.get('index') + index = tool_parameters.get("index") # get create path - create_path = tool_parameters.get('create_path', False) + create_path = tool_parameters.get("create_path", False) # get value decode. # if true, it will be decoded to an dict - value_decode = tool_parameters.get('value_decode', False) + value_decode = tool_parameters.get("value_decode", False) - ensure_ascii = tool_parameters.get('ensure_ascii', True) + ensure_ascii = tool_parameters.get("ensure_ascii", True) try: result = self._insert(content, query, new_value, ensure_ascii, value_decode, index, create_path) return self.create_text_message(str(result)) except Exception: - return self.create_text_message('Failed to insert JSON content') + return self.create_text_message("Failed to insert JSON content") - def _insert(self, origin_json, query, new_value, ensure_ascii: bool, value_decode: bool, index=None, create_path=False): + def _insert( + self, origin_json, query, new_value, ensure_ascii: bool, value_decode: bool, index=None, create_path=False + ): try: input_data = json.loads(origin_json) expr = parse(query) @@ -61,13 +64,13 @@ class JSONParseTool(BuiltinTool): if not matches and create_path: # create new path - path_parts = query.strip('$').strip('.').split('.') + path_parts = query.strip("$").strip(".").split(".") current = input_data for i, part in enumerate(path_parts): - if '[' in part and ']' in part: + if "[" in part and "]" in part: # process array index - array_name, index = part.split('[') - index = int(index.rstrip(']')) + array_name, index = part.split("[") + index = int(index.rstrip("]")) if array_name not in current: current[array_name] = [] while len(current[array_name]) <= index: diff --git a/api/core/tools/provider/builtin/json_process/tools/parse.py b/api/core/tools/provider/builtin/json_process/tools/parse.py index ecd39113ae..37cae40153 100644 --- a/api/core/tools/provider/builtin/json_process/tools/parse.py +++ b/api/core/tools/provider/builtin/json_process/tools/parse.py @@ -8,29 +8,30 @@ from core.tools.tool.builtin_tool import BuiltinTool class JSONParseTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # get json filter - json_filter = tool_parameters.get('json_filter', '') + json_filter = tool_parameters.get("json_filter", "") if not json_filter: - return self.create_text_message('Invalid parameter json_filter') + return self.create_text_message("Invalid parameter json_filter") - ensure_ascii = tool_parameters.get('ensure_ascii', True) + ensure_ascii = tool_parameters.get("ensure_ascii", True) try: result = self._extract(content, json_filter, ensure_ascii) return self.create_text_message(str(result)) except Exception: - return self.create_text_message('Failed to extract JSON content') + return self.create_text_message("Failed to extract JSON content") # Extract data from JSON content def _extract(self, content: str, json_filter: str, ensure_ascii: bool) -> str: diff --git a/api/core/tools/provider/builtin/json_process/tools/replace.py b/api/core/tools/provider/builtin/json_process/tools/replace.py index b19198aa93..383825c2d0 100644 --- a/api/core/tools/provider/builtin/json_process/tools/replace.py +++ b/api/core/tools/provider/builtin/json_process/tools/replace.py @@ -8,55 +8,60 @@ from core.tools.tool.builtin_tool import BuiltinTool class JSONReplaceTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # get query - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Invalid parameter query') + return self.create_text_message("Invalid parameter query") # get replace value - replace_value = tool_parameters.get('replace_value', '') + replace_value = tool_parameters.get("replace_value", "") if not replace_value: - return self.create_text_message('Invalid parameter replace_value') + return self.create_text_message("Invalid parameter replace_value") # get replace model - replace_model = tool_parameters.get('replace_model', '') + replace_model = tool_parameters.get("replace_model", "") if not replace_model: - return self.create_text_message('Invalid parameter replace_model') + return self.create_text_message("Invalid parameter replace_model") # get value decode. # if true, it will be decoded to an dict - value_decode = tool_parameters.get('value_decode', False) + value_decode = tool_parameters.get("value_decode", False) - ensure_ascii = tool_parameters.get('ensure_ascii', True) + ensure_ascii = tool_parameters.get("ensure_ascii", True) try: - if replace_model == 'pattern': + if replace_model == "pattern": # get replace pattern - replace_pattern = tool_parameters.get('replace_pattern', '') + replace_pattern = tool_parameters.get("replace_pattern", "") if not replace_pattern: - return self.create_text_message('Invalid parameter replace_pattern') - result = self._replace_pattern(content, query, replace_pattern, replace_value, ensure_ascii, value_decode) - elif replace_model == 'key': + return self.create_text_message("Invalid parameter replace_pattern") + result = self._replace_pattern( + content, query, replace_pattern, replace_value, ensure_ascii, value_decode + ) + elif replace_model == "key": result = self._replace_key(content, query, replace_value, ensure_ascii) - elif replace_model == 'value': + elif replace_model == "value": result = self._replace_value(content, query, replace_value, ensure_ascii, value_decode) return self.create_text_message(str(result)) except Exception: - return self.create_text_message('Failed to replace JSON content') + return self.create_text_message("Failed to replace JSON content") # Replace pattern - def _replace_pattern(self, content: str, query: str, replace_pattern: str, replace_value: str, ensure_ascii: bool, value_decode: bool) -> str: + def _replace_pattern( + self, content: str, query: str, replace_pattern: str, replace_value: str, ensure_ascii: bool, value_decode: bool + ) -> str: try: input_data = json.loads(content) expr = parse(query) @@ -102,7 +107,9 @@ class JSONReplaceTool(BuiltinTool): return str(e) # Replace value - def _replace_value(self, content: str, query: str, replace_value: str, ensure_ascii: bool, value_decode: bool) -> str: + def _replace_value( + self, content: str, query: str, replace_value: str, ensure_ascii: bool, value_decode: bool + ) -> str: try: input_data = json.loads(content) expr = parse(query) diff --git a/api/core/tools/provider/builtin/judge0ce/judge0ce.py b/api/core/tools/provider/builtin/judge0ce/judge0ce.py index bac6576797..50db74dd9e 100644 --- a/api/core/tools/provider/builtin/judge0ce/judge0ce.py +++ b/api/core/tools/provider/builtin/judge0ce/judge0ce.py @@ -13,7 +13,7 @@ class Judge0CEProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "source_code": "print('hello world')", "language_id": 71, @@ -21,4 +21,3 @@ class Judge0CEProvider(BuiltinToolProviderController): ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/judge0ce/tools/executeCode.py b/api/core/tools/provider/builtin/judge0ce/tools/executeCode.py index 6031687c03..b8d654ff63 100644 --- a/api/core/tools/provider/builtin/judge0ce/tools/executeCode.py +++ b/api/core/tools/provider/builtin/judge0ce/tools/executeCode.py @@ -9,11 +9,13 @@ from core.tools.tool.builtin_tool import BuiltinTool class ExecuteCodeTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ invoke tools """ - api_key = self.runtime.credentials['X-RapidAPI-Key'] + api_key = self.runtime.credentials["X-RapidAPI-Key"] url = "https://judge0-ce.p.rapidapi.com/submissions" @@ -22,15 +24,15 @@ class ExecuteCodeTool(BuiltinTool): headers = { "Content-Type": "application/json", "X-RapidAPI-Key": api_key, - "X-RapidAPI-Host": "judge0-ce.p.rapidapi.com" + "X-RapidAPI-Host": "judge0-ce.p.rapidapi.com", } payload = { - "language_id": tool_parameters['language_id'], - "source_code": tool_parameters['source_code'], - "stdin": tool_parameters.get('stdin', ''), - "expected_output": tool_parameters.get('expected_output', ''), - "additional_files": tool_parameters.get('additional_files', ''), + "language_id": tool_parameters["language_id"], + "source_code": tool_parameters["source_code"], + "stdin": tool_parameters.get("stdin", ""), + "expected_output": tool_parameters.get("expected_output", ""), + "additional_files": tool_parameters.get("additional_files", ""), } response = post(url, data=json.dumps(payload), headers=headers, params=querystring) @@ -38,22 +40,22 @@ class ExecuteCodeTool(BuiltinTool): if response.status_code != 201: raise Exception(response.text) - token = response.json()['token'] + token = response.json()["token"] url = f"https://judge0-ce.p.rapidapi.com/submissions/{token}" - headers = { - "X-RapidAPI-Key": api_key - } - + headers = {"X-RapidAPI-Key": api_key} + response = requests.get(url, headers=headers) if response.status_code == 200: result = response.json() - return self.create_text_message(text=f"stdout: {result.get('stdout', '')}\n" - f"stderr: {result.get('stderr', '')}\n" - f"compile_output: {result.get('compile_output', '')}\n" - f"message: {result.get('message', '')}\n" - f"status: {result['status']['description']}\n" - f"time: {result.get('time', '')} seconds\n" - f"memory: {result.get('memory', '')} bytes") + return self.create_text_message( + text=f"stdout: {result.get('stdout', '')}\n" + f"stderr: {result.get('stderr', '')}\n" + f"compile_output: {result.get('compile_output', '')}\n" + f"message: {result.get('message', '')}\n" + f"status: {result['status']['description']}\n" + f"time: {result.get('time', '')} seconds\n" + f"memory: {result.get('memory', '')} bytes" + ) else: - return self.create_text_message(text=f"Error retrieving submission details: {response.text}") \ No newline at end of file + return self.create_text_message(text=f"Error retrieving submission details: {response.text}") diff --git a/api/core/tools/provider/builtin/maths/maths.py b/api/core/tools/provider/builtin/maths/maths.py index 7226a5c168..d4b449ec87 100644 --- a/api/core/tools/provider/builtin/maths/maths.py +++ b/api/core/tools/provider/builtin/maths/maths.py @@ -9,9 +9,9 @@ class MathsProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: EvaluateExpressionTool().invoke( - user_id='', + user_id="", tool_parameters={ - 'expression': '1+(2+3)*4', + "expression": "1+(2+3)*4", }, ) except Exception as e: diff --git a/api/core/tools/provider/builtin/maths/tools/eval_expression.py b/api/core/tools/provider/builtin/maths/tools/eval_expression.py index bf73ed6918..0c5b5e41cb 100644 --- a/api/core/tools/provider/builtin/maths/tools/eval_expression.py +++ b/api/core/tools/provider/builtin/maths/tools/eval_expression.py @@ -8,22 +8,23 @@ from core.tools.tool.builtin_tool import BuiltinTool class EvaluateExpressionTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get expression - expression = tool_parameters.get('expression', '').strip() + expression = tool_parameters.get("expression", "").strip() if not expression: - return self.create_text_message('Invalid expression') + return self.create_text_message("Invalid expression") try: result = ne.evaluate(expression) result_str = str(result) except Exception as e: - logging.exception(f'Error evaluating expression: {expression}') - return self.create_text_message(f'Invalid expression: {expression}, error: {str(e)}') - return self.create_text_message(f'The result of the expression "{expression}" is {result_str}') \ No newline at end of file + logging.exception(f"Error evaluating expression: {expression}") + return self.create_text_message(f"Invalid expression: {expression}, error: {str(e)}") + return self.create_text_message(f'The result of the expression "{expression}" is {result_str}') diff --git a/api/core/tools/provider/builtin/nominatim/nominatim.py b/api/core/tools/provider/builtin/nominatim/nominatim.py index b6f29b5feb..5a24bed750 100644 --- a/api/core/tools/provider/builtin/nominatim/nominatim.py +++ b/api/core/tools/provider/builtin/nominatim/nominatim.py @@ -8,16 +8,20 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class NominatimProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - result = NominatimSearchTool().fork_tool_runtime( - runtime={ - "credentials": credentials, - } - ).invoke( - user_id='', - tool_parameters={ - 'query': 'London', - 'limit': 1, - }, + result = ( + NominatimSearchTool() + .fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ) + .invoke( + user_id="", + tool_parameters={ + "query": "London", + "limit": 1, + }, + ) ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/nominatim/tools/nominatim_lookup.py b/api/core/tools/provider/builtin/nominatim/tools/nominatim_lookup.py index e21ce14f54..ffa8ad0fcc 100644 --- a/api/core/tools/provider/builtin/nominatim/tools/nominatim_lookup.py +++ b/api/core/tools/provider/builtin/nominatim/tools/nominatim_lookup.py @@ -8,40 +8,33 @@ from core.tools.tool.builtin_tool import BuiltinTool class NominatimLookupTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - osm_ids = tool_parameters.get('osm_ids', '') - - if not osm_ids: - return self.create_text_message('Please provide OSM IDs') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + osm_ids = tool_parameters.get("osm_ids", "") - params = { - 'osm_ids': osm_ids, - 'format': 'json', - 'addressdetails': 1 - } - - return self._make_request(user_id, 'lookup', params) + if not osm_ids: + return self.create_text_message("Please provide OSM IDs") + + params = {"osm_ids": osm_ids, "format": "json", "addressdetails": 1} + + return self._make_request(user_id, "lookup", params) def _make_request(self, user_id: str, endpoint: str, params: dict) -> ToolInvokeMessage: - base_url = self.runtime.credentials.get('base_url', 'https://nominatim.openstreetmap.org') - + base_url = self.runtime.credentials.get("base_url", "https://nominatim.openstreetmap.org") + try: - headers = { - "User-Agent": "DifyNominatimTool/1.0" - } + headers = {"User-Agent": "DifyNominatimTool/1.0"} s = requests.session() - response = s.request( - method='GET', - headers=headers, - url=f"{base_url}/{endpoint}", - params=params - ) + response = s.request(method="GET", headers=headers, url=f"{base_url}/{endpoint}", params=params) response_data = response.json() - + if response.status_code == 200: s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False))) + return self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False)) + ) else: return self.create_text_message(f"Error: {response.status_code} - {response.text}") except Exception as e: - return self.create_text_message(f"An error occurred: {str(e)}") \ No newline at end of file + return self.create_text_message(f"An error occurred: {str(e)}") diff --git a/api/core/tools/provider/builtin/nominatim/tools/nominatim_reverse.py b/api/core/tools/provider/builtin/nominatim/tools/nominatim_reverse.py index 438d5219e9..f46691e1a3 100644 --- a/api/core/tools/provider/builtin/nominatim/tools/nominatim_reverse.py +++ b/api/core/tools/provider/builtin/nominatim/tools/nominatim_reverse.py @@ -8,42 +8,34 @@ from core.tools.tool.builtin_tool import BuiltinTool class NominatimReverseTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - lat = tool_parameters.get('lat') - lon = tool_parameters.get('lon') - - if lat is None or lon is None: - return self.create_text_message('Please provide both latitude and longitude') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + lat = tool_parameters.get("lat") + lon = tool_parameters.get("lon") - params = { - 'lat': lat, - 'lon': lon, - 'format': 'json', - 'addressdetails': 1 - } - - return self._make_request(user_id, 'reverse', params) + if lat is None or lon is None: + return self.create_text_message("Please provide both latitude and longitude") + + params = {"lat": lat, "lon": lon, "format": "json", "addressdetails": 1} + + return self._make_request(user_id, "reverse", params) def _make_request(self, user_id: str, endpoint: str, params: dict) -> ToolInvokeMessage: - base_url = self.runtime.credentials.get('base_url', 'https://nominatim.openstreetmap.org') - + base_url = self.runtime.credentials.get("base_url", "https://nominatim.openstreetmap.org") + try: - headers = { - "User-Agent": "DifyNominatimTool/1.0" - } + headers = {"User-Agent": "DifyNominatimTool/1.0"} s = requests.session() - response = s.request( - method='GET', - headers=headers, - url=f"{base_url}/{endpoint}", - params=params - ) + response = s.request(method="GET", headers=headers, url=f"{base_url}/{endpoint}", params=params) response_data = response.json() - + if response.status_code == 200: s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False))) + return self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False)) + ) else: return self.create_text_message(f"Error: {response.status_code} - {response.text}") except Exception as e: - return self.create_text_message(f"An error occurred: {str(e)}") \ No newline at end of file + return self.create_text_message(f"An error occurred: {str(e)}") diff --git a/api/core/tools/provider/builtin/nominatim/tools/nominatim_search.py b/api/core/tools/provider/builtin/nominatim/tools/nominatim_search.py index 983cbc0e34..34851d86dc 100644 --- a/api/core/tools/provider/builtin/nominatim/tools/nominatim_search.py +++ b/api/core/tools/provider/builtin/nominatim/tools/nominatim_search.py @@ -8,42 +8,34 @@ from core.tools.tool.builtin_tool import BuiltinTool class NominatimSearchTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - query = tool_parameters.get('query', '') - limit = tool_parameters.get('limit', 10) - - if not query: - return self.create_text_message('Please input a search query') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + query = tool_parameters.get("query", "") + limit = tool_parameters.get("limit", 10) - params = { - 'q': query, - 'format': 'json', - 'limit': limit, - 'addressdetails': 1 - } - - return self._make_request(user_id, 'search', params) + if not query: + return self.create_text_message("Please input a search query") + + params = {"q": query, "format": "json", "limit": limit, "addressdetails": 1} + + return self._make_request(user_id, "search", params) def _make_request(self, user_id: str, endpoint: str, params: dict) -> ToolInvokeMessage: - base_url = self.runtime.credentials.get('base_url', 'https://nominatim.openstreetmap.org') - + base_url = self.runtime.credentials.get("base_url", "https://nominatim.openstreetmap.org") + try: - headers = { - "User-Agent": "DifyNominatimTool/1.0" - } + headers = {"User-Agent": "DifyNominatimTool/1.0"} s = requests.session() - response = s.request( - method='GET', - headers=headers, - url=f"{base_url}/{endpoint}", - params=params - ) + response = s.request(method="GET", headers=headers, url=f"{base_url}/{endpoint}", params=params) response_data = response.json() - + if response.status_code == 200: s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False))) + return self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False)) + ) else: return self.create_text_message(f"Error: {response.status_code} - {response.text}") except Exception as e: - return self.create_text_message(f"An error occurred: {str(e)}") \ No newline at end of file + return self.create_text_message(f"An error occurred: {str(e)}") diff --git a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py index b753be4791..762e158459 100644 --- a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py +++ b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py @@ -12,10 +12,10 @@ class NovitaAiToolBase: if not loras_str: return [] - loras_ori_list = lora_str.strip().split(';') + loras_ori_list = lora_str.strip().split(";") result_list = [] for lora_str in loras_ori_list: - lora_info = lora_str.strip().split(',') + lora_info = lora_str.strip().split(",") lora = Txt2ImgV3LoRA( model_name=lora_info[0].strip(), strength=float(lora_info[1]), @@ -28,43 +28,39 @@ class NovitaAiToolBase: if not embeddings_str: return [] - embeddings_ori_list = embeddings_str.strip().split(';') + embeddings_ori_list = embeddings_str.strip().split(";") result_list = [] for embedding_str in embeddings_ori_list: - embedding = Txt2ImgV3Embedding( - model_name=embedding_str.strip() - ) + embedding = Txt2ImgV3Embedding(model_name=embedding_str.strip()) result_list.append(embedding) return result_list def _extract_hires_fix(self, hires_fix_str: str): - hires_fix_info = hires_fix_str.strip().split(',') - if 'upscaler' in hires_fix_info: + hires_fix_info = hires_fix_str.strip().split(",") + if "upscaler" in hires_fix_info: hires_fix = Txt2ImgV3HiresFix( target_width=int(hires_fix_info[0]), target_height=int(hires_fix_info[1]), strength=float(hires_fix_info[2]), - upscaler=hires_fix_info[3].strip() + upscaler=hires_fix_info[3].strip(), ) else: hires_fix = Txt2ImgV3HiresFix( target_width=int(hires_fix_info[0]), target_height=int(hires_fix_info[1]), - strength=float(hires_fix_info[2]) + strength=float(hires_fix_info[2]), ) return hires_fix def _extract_refiner(self, switch_at: str): - refiner = Txt2ImgV3Refiner( - switch_at=float(switch_at) - ) + refiner = Txt2ImgV3Refiner(switch_at=float(switch_at)) return refiner def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool: """ - is hit nsfw + is hit nsfw """ if image.nsfw_detection_result is None: return False diff --git a/api/core/tools/provider/builtin/novitaai/novitaai.py b/api/core/tools/provider/builtin/novitaai/novitaai.py index 1e7d9757c3..d5e32eff29 100644 --- a/api/core/tools/provider/builtin/novitaai/novitaai.py +++ b/api/core/tools/provider/builtin/novitaai/novitaai.py @@ -8,23 +8,27 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class NovitaAIProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - result = NovitaAiTxt2ImgTool().fork_tool_runtime( - runtime={ - "credentials": credentials, - } - ).invoke( - user_id='', - tool_parameters={ - 'model_name': 'cinenautXLATRUE_cinenautV10_392434.safetensors', - 'prompt': 'a futuristic city with flying cars', - 'negative_prompt': '', - 'width': 128, - 'height': 128, - 'image_num': 1, - 'guidance_scale': 7.5, - 'seed': -1, - 'steps': 1, - }, + result = ( + NovitaAiTxt2ImgTool() + .fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ) + .invoke( + user_id="", + tool_parameters={ + "model_name": "cinenautXLATRUE_cinenautV10_392434.safetensors", + "prompt": "a futuristic city with flying cars", + "negative_prompt": "", + "width": 128, + "height": 128, + "image_num": 1, + "guidance_scale": 7.5, + "seed": -1, + "steps": 1, + }, + ) ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py index e63c891957..f76587bea1 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py @@ -12,17 +12,18 @@ from core.tools.tool.builtin_tool import BuiltinTool class NovitaAiCreateTileTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): raise ToolProviderCredentialValidationError("Novita AI API Key is required.") - api_key = self.runtime.credentials.get('api_key') + api_key = self.runtime.credentials.get("api_key") client = NovitaClient(api_key=api_key) param = self._process_parameters(tool_parameters) @@ -30,21 +31,23 @@ class NovitaAiCreateTileTool(BuiltinTool): results = [] results.append( - self.create_blob_message(blob=b64decode(client_result.image_file), - meta={'mime_type': f'image/{client_result.image_type}'}, - save_as=self.VARIABLE_KEY.IMAGE.value) + self.create_blob_message( + blob=b64decode(client_result.image_file), + meta={"mime_type": f"image/{client_result.image_type}"}, + save_as=self.VARIABLE_KEY.IMAGE.value, + ) ) return results def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ - process parameters + process parameters """ res_parameters = deepcopy(parameters) # delete none and empty - keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ''] + keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ""] for k in keys_to_delete: del res_parameters[k] diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py index ec2927675e..fe105f70a7 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py @@ -12,127 +12,137 @@ from core.tools.tool.builtin_tool import BuiltinTool class NovitaAiModelQueryTool(BuiltinTool): - _model_query_endpoint = 'https://api.novita.ai/v3/model' + _model_query_endpoint = "https://api.novita.ai/v3/model" - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): raise ToolProviderCredentialValidationError("Novita AI API Key is required.") - api_key = self.runtime.credentials.get('api_key') - headers = { - 'Content-Type': 'application/json', - 'Authorization': "Bearer " + api_key - } + api_key = self.runtime.credentials.get("api_key") + headers = {"Content-Type": "application/json", "Authorization": "Bearer " + api_key} params = self._process_parameters(tool_parameters) - result_type = params.get('result_type') - del params['result_type'] + result_type = params.get("result_type") + del params["result_type"] models_data = self._query_models( models_data=[], headers=headers, params=params, - recursive=False if result_type == 'first sd_name' or result_type == 'first name sd_name pair' else True + recursive=False if result_type == "first sd_name" or result_type == "first name sd_name pair" else True, ) - result_str = '' - if result_type == 'first sd_name': - result_str = models_data[0]['sd_name_in_api'] if len(models_data) > 0 else '' - elif result_type == 'first name sd_name pair': - result_str = json.dumps({'name': models_data[0]['name'], 'sd_name': models_data[0]['sd_name_in_api']}) if len(models_data) > 0 else '' - elif result_type == 'sd_name array': - sd_name_array = [model['sd_name_in_api'] for model in models_data] if len(models_data) > 0 else [] + result_str = "" + if result_type == "first sd_name": + result_str = models_data[0]["sd_name_in_api"] if len(models_data) > 0 else "" + elif result_type == "first name sd_name pair": + result_str = ( + json.dumps({"name": models_data[0]["name"], "sd_name": models_data[0]["sd_name_in_api"]}) + if len(models_data) > 0 + else "" + ) + elif result_type == "sd_name array": + sd_name_array = [model["sd_name_in_api"] for model in models_data] if len(models_data) > 0 else [] result_str = json.dumps(sd_name_array) - elif result_type == 'name array': - name_array = [model['name'] for model in models_data] if len(models_data) > 0 else [] + elif result_type == "name array": + name_array = [model["name"] for model in models_data] if len(models_data) > 0 else [] result_str = json.dumps(name_array) - elif result_type == 'name sd_name pair array': - name_sd_name_pair_array = [{'name': model['name'], 'sd_name': model['sd_name_in_api']} - for model in models_data] if len(models_data) > 0 else [] + elif result_type == "name sd_name pair array": + name_sd_name_pair_array = ( + [{"name": model["name"], "sd_name": model["sd_name_in_api"]} for model in models_data] + if len(models_data) > 0 + else [] + ) result_str = json.dumps(name_sd_name_pair_array) - elif result_type == 'whole info array': + elif result_type == "whole info array": result_str = json.dumps(models_data) else: raise NotImplementedError return self.create_text_message(result_str) - def _query_models(self, models_data: list, headers: dict[str, Any], - params: dict[str, Any], pagination_cursor: str = '', recursive: bool = True) -> list: + def _query_models( + self, + models_data: list, + headers: dict[str, Any], + params: dict[str, Any], + pagination_cursor: str = "", + recursive: bool = True, + ) -> list: """ - query models + query models """ inside_params = deepcopy(params) - if pagination_cursor != '': - inside_params['pagination.cursor'] = pagination_cursor + if pagination_cursor != "": + inside_params["pagination.cursor"] = pagination_cursor response = ssrf_proxy.get( - url=str(URL(self._model_query_endpoint)), - headers=headers, - params=params, - timeout=(10, 60) + url=str(URL(self._model_query_endpoint)), headers=headers, params=params, timeout=(10, 60) ) res_data = response.json() - models_data.extend(res_data['models']) + models_data.extend(res_data["models"]) - res_data_len = len(res_data['models']) - if res_data_len == 0 or res_data_len < int(params['pagination.limit']) or recursive is False: + res_data_len = len(res_data["models"]) + if res_data_len == 0 or res_data_len < int(params["pagination.limit"]) or recursive is False: # deduplicate df = DataFrame.from_dict(models_data) - df_unique = df.drop_duplicates(subset=['id']) - models_data = df_unique.to_dict('records') + df_unique = df.drop_duplicates(subset=["id"]) + models_data = df_unique.to_dict("records") return models_data return self._query_models( models_data=models_data, headers=headers, params=inside_params, - pagination_cursor=res_data['pagination']['next_cursor'] + pagination_cursor=res_data["pagination"]["next_cursor"], ) def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ - process parameters + process parameters """ process_parameters = deepcopy(parameters) res_parameters = {} # delete none or empty - keys_to_delete = [k for k, v in process_parameters.items() if v is None or v == ''] + keys_to_delete = [k for k, v in process_parameters.items() if v is None or v == ""] for k in keys_to_delete: del process_parameters[k] - if 'query' in process_parameters and process_parameters.get('query') != 'unspecified': - res_parameters['filter.query'] = process_parameters['query'] + if "query" in process_parameters and process_parameters.get("query") != "unspecified": + res_parameters["filter.query"] = process_parameters["query"] - if 'visibility' in process_parameters and process_parameters.get('visibility') != 'unspecified': - res_parameters['filter.visibility'] = process_parameters['visibility'] + if "visibility" in process_parameters and process_parameters.get("visibility") != "unspecified": + res_parameters["filter.visibility"] = process_parameters["visibility"] - if 'source' in process_parameters and process_parameters.get('source') != 'unspecified': - res_parameters['filter.source'] = process_parameters['source'] + if "source" in process_parameters and process_parameters.get("source") != "unspecified": + res_parameters["filter.source"] = process_parameters["source"] - if 'type' in process_parameters and process_parameters.get('type') != 'unspecified': - res_parameters['filter.types'] = process_parameters['type'] + if "type" in process_parameters and process_parameters.get("type") != "unspecified": + res_parameters["filter.types"] = process_parameters["type"] - if 'is_sdxl' in process_parameters: - if process_parameters['is_sdxl'] == 'true': - res_parameters['filter.is_sdxl'] = True - elif process_parameters['is_sdxl'] == 'false': - res_parameters['filter.is_sdxl'] = False + if "is_sdxl" in process_parameters: + if process_parameters["is_sdxl"] == "true": + res_parameters["filter.is_sdxl"] = True + elif process_parameters["is_sdxl"] == "false": + res_parameters["filter.is_sdxl"] = False - res_parameters['result_type'] = process_parameters.get('result_type', 'first sd_name') + res_parameters["result_type"] = process_parameters.get("result_type", "first sd_name") - res_parameters['pagination.limit'] = 1 \ - if res_parameters.get('result_type') == 'first sd_name' \ - or res_parameters.get('result_type') == 'first name sd_name pair'\ + res_parameters["pagination.limit"] = ( + 1 + if res_parameters.get("result_type") == "first sd_name" + or res_parameters.get("result_type") == "first name sd_name pair" else 100 + ) return res_parameters diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py index 5fef3d2da7..9632c163cf 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py @@ -13,17 +13,18 @@ from core.tools.tool.builtin_tool import BuiltinTool class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): raise ToolProviderCredentialValidationError("Novita AI API Key is required.") - api_key = self.runtime.credentials.get('api_key') + api_key = self.runtime.credentials.get("api_key") client = NovitaClient(api_key=api_key) param = self._process_parameters(tool_parameters) @@ -32,56 +33,58 @@ class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase): results = [] for image_encoded, image in zip(client_result.images_encoded, client_result.images): if self._is_hit_nsfw_detection(image, 0.8): - results = self.create_text_message(text='NSFW detected!') + results = self.create_text_message(text="NSFW detected!") break results.append( - self.create_blob_message(blob=b64decode(image_encoded), - meta={'mime_type': f'image/{image.image_type}'}, - save_as=self.VARIABLE_KEY.IMAGE.value) + self.create_blob_message( + blob=b64decode(image_encoded), + meta={"mime_type": f"image/{image.image_type}"}, + save_as=self.VARIABLE_KEY.IMAGE.value, + ) ) return results def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ - process parameters + process parameters """ res_parameters = deepcopy(parameters) # delete none and empty - keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ''] + keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ""] for k in keys_to_delete: del res_parameters[k] - if 'clip_skip' in res_parameters and res_parameters.get('clip_skip') == 0: - del res_parameters['clip_skip'] + if "clip_skip" in res_parameters and res_parameters.get("clip_skip") == 0: + del res_parameters["clip_skip"] - if 'refiner_switch_at' in res_parameters and res_parameters.get('refiner_switch_at') == 0: - del res_parameters['refiner_switch_at'] + if "refiner_switch_at" in res_parameters and res_parameters.get("refiner_switch_at") == 0: + del res_parameters["refiner_switch_at"] - if 'enabled_enterprise_plan' in res_parameters: - res_parameters['enterprise_plan'] = {'enabled': res_parameters['enabled_enterprise_plan']} - del res_parameters['enabled_enterprise_plan'] + if "enabled_enterprise_plan" in res_parameters: + res_parameters["enterprise_plan"] = {"enabled": res_parameters["enabled_enterprise_plan"]} + del res_parameters["enabled_enterprise_plan"] - if 'nsfw_detection_level' in res_parameters: - res_parameters['nsfw_detection_level'] = int(res_parameters['nsfw_detection_level']) + if "nsfw_detection_level" in res_parameters: + res_parameters["nsfw_detection_level"] = int(res_parameters["nsfw_detection_level"]) # process loras - if 'loras' in res_parameters: - res_parameters['loras'] = self._extract_loras(res_parameters.get('loras')) + if "loras" in res_parameters: + res_parameters["loras"] = self._extract_loras(res_parameters.get("loras")) # process embeddings - if 'embeddings' in res_parameters: - res_parameters['embeddings'] = self._extract_embeddings(res_parameters.get('embeddings')) + if "embeddings" in res_parameters: + res_parameters["embeddings"] = self._extract_embeddings(res_parameters.get("embeddings")) # process hires_fix - if 'hires_fix' in res_parameters: - res_parameters['hires_fix'] = self._extract_hires_fix(res_parameters.get('hires_fix')) + if "hires_fix" in res_parameters: + res_parameters["hires_fix"] = self._extract_hires_fix(res_parameters.get("hires_fix")) # process refiner - if 'refiner_switch_at' in res_parameters: - res_parameters['refiner'] = self._extract_refiner(res_parameters.get('refiner_switch_at')) - del res_parameters['refiner_switch_at'] + if "refiner_switch_at" in res_parameters: + res_parameters["refiner"] = self._extract_refiner(res_parameters.get("refiner_switch_at")) + del res_parameters["refiner_switch_at"] return res_parameters diff --git a/api/core/tools/provider/builtin/onebot/onebot.py b/api/core/tools/provider/builtin/onebot/onebot.py index 42f321e919..b8e5ed24d6 100644 --- a/api/core/tools/provider/builtin/onebot/onebot.py +++ b/api/core/tools/provider/builtin/onebot/onebot.py @@ -5,8 +5,6 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class OneBotProvider(BuiltinToolProviderController): - def _validate_credentials(self, credentials: dict[str, Any]) -> None: - if not credentials.get("ob11_http_url"): - raise ToolProviderCredentialValidationError('OneBot HTTP URL is required.') + raise ToolProviderCredentialValidationError("OneBot HTTP URL is required.") diff --git a/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py index 2a1a9f86de..9c95bbc2ae 100644 --- a/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py +++ b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py @@ -11,54 +11,29 @@ class SendGroupMsg(BuiltinTool): """OneBot v11 Tool: Send Group Message""" def _invoke( - self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: # Get parameters - send_group_id = tool_parameters.get('group_id', '') - - message = tool_parameters.get('message', '') + send_group_id = tool_parameters.get("group_id", "") + + message = tool_parameters.get("message", "") if not message: - return self.create_json_message( - { - 'error': 'Message is empty.' - } - ) - - auto_escape = tool_parameters.get('auto_escape', False) + return self.create_json_message({"error": "Message is empty."}) + + auto_escape = tool_parameters.get("auto_escape", False) try: - url = URL(self.runtime.credentials['ob11_http_url']) / 'send_group_msg' + url = URL(self.runtime.credentials["ob11_http_url"]) / "send_group_msg" resp = requests.post( url, - json={ - 'group_id': send_group_id, - 'message': message, - 'auto_escape': auto_escape - }, - headers={ - 'Authorization': 'Bearer ' + self.runtime.credentials['access_token'] - } + json={"group_id": send_group_id, "message": message, "auto_escape": auto_escape}, + headers={"Authorization": "Bearer " + self.runtime.credentials["access_token"]}, ) if resp.status_code != 200: - return self.create_json_message( - { - 'error': f'Failed to send group message: {resp.text}' - } - ) + return self.create_json_message({"error": f"Failed to send group message: {resp.text}"}) - return self.create_json_message( - { - 'response': resp.json() - } - ) + return self.create_json_message({"response": resp.json()}) except Exception as e: - return self.create_json_message( - { - 'error': f'Failed to send group message: {e}' - } - ) + return self.create_json_message({"error": f"Failed to send group message: {e}"}) diff --git a/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py index 8ef4d72ab6..1174c7f07d 100644 --- a/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py +++ b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py @@ -11,54 +11,29 @@ class SendPrivateMsg(BuiltinTool): """OneBot v11 Tool: Send Private Message""" def _invoke( - self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: # Get parameters - send_user_id = tool_parameters.get('user_id', '') - - message = tool_parameters.get('message', '') + send_user_id = tool_parameters.get("user_id", "") + + message = tool_parameters.get("message", "") if not message: - return self.create_json_message( - { - 'error': 'Message is empty.' - } - ) - - auto_escape = tool_parameters.get('auto_escape', False) + return self.create_json_message({"error": "Message is empty."}) + + auto_escape = tool_parameters.get("auto_escape", False) try: - url = URL(self.runtime.credentials['ob11_http_url']) / 'send_private_msg' + url = URL(self.runtime.credentials["ob11_http_url"]) / "send_private_msg" resp = requests.post( url, - json={ - 'user_id': send_user_id, - 'message': message, - 'auto_escape': auto_escape - }, - headers={ - 'Authorization': 'Bearer ' + self.runtime.credentials['access_token'] - } + json={"user_id": send_user_id, "message": message, "auto_escape": auto_escape}, + headers={"Authorization": "Bearer " + self.runtime.credentials["access_token"]}, ) if resp.status_code != 200: - return self.create_json_message( - { - 'error': f'Failed to send private message: {resp.text}' - } - ) - - return self.create_json_message( - { - 'response': resp.json() - } - ) + return self.create_json_message({"error": f"Failed to send private message: {resp.text}"}) + + return self.create_json_message({"response": resp.json()}) except Exception as e: - return self.create_json_message( - { - 'error': f'Failed to send private message: {e}' - } - ) \ No newline at end of file + return self.create_json_message({"error": f"Failed to send private message: {e}"}) diff --git a/api/core/tools/provider/builtin/openweather/openweather.py b/api/core/tools/provider/builtin/openweather/openweather.py index a2827177a3..9e40249aba 100644 --- a/api/core/tools/provider/builtin/openweather/openweather.py +++ b/api/core/tools/provider/builtin/openweather/openweather.py @@ -5,7 +5,6 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl def query_weather(city="Beijing", units="metric", language="zh_cn", api_key=None): - url = "https://api.openweathermap.org/data/2.5/weather" params = {"q": city, "appid": api_key, "units": units, "lang": language} @@ -16,21 +15,15 @@ class OpenweatherProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: if "api_key" not in credentials or not credentials.get("api_key"): - raise ToolProviderCredentialValidationError( - "Open weather API key is required." - ) + raise ToolProviderCredentialValidationError("Open weather API key is required.") apikey = credentials.get("api_key") try: response = query_weather(api_key=apikey) if response.status_code == 200: pass else: - raise ToolProviderCredentialValidationError( - (response.json()).get("info") - ) + raise ToolProviderCredentialValidationError((response.json()).get("info")) except Exception as e: - raise ToolProviderCredentialValidationError( - "Open weather API Key is invalid. {}".format(e) - ) + raise ToolProviderCredentialValidationError("Open weather API Key is invalid. {}".format(e)) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/openweather/tools/weather.py b/api/core/tools/provider/builtin/openweather/tools/weather.py index d6c49a230f..ed4ec487fa 100644 --- a/api/core/tools/provider/builtin/openweather/tools/weather.py +++ b/api/core/tools/provider/builtin/openweather/tools/weather.py @@ -17,10 +17,7 @@ class OpenweatherTool(BuiltinTool): city = tool_parameters.get("city", "") if not city: return self.create_text_message("Please tell me your city") - if ( - "api_key" not in self.runtime.credentials - or not self.runtime.credentials.get("api_key") - ): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): return self.create_text_message("OpenWeather API key is required.") units = tool_parameters.get("units", "metric") @@ -39,12 +36,9 @@ class OpenweatherTool(BuiltinTool): response = requests.get(url, params=params) if response.status_code == 200: - data = response.json() return self.create_text_message( - self.summary( - user_id=user_id, content=json.dumps(data, ensure_ascii=False) - ) + self.summary(user_id=user_id, content=json.dumps(data, ensure_ascii=False)) ) else: error_message = { @@ -55,6 +49,4 @@ class OpenweatherTool(BuiltinTool): return json.dumps(error_message) except Exception as e: - return self.create_text_message( - "Openweather API Key is invalid. {}".format(e) - ) + return self.create_text_message("Openweather API Key is invalid. {}".format(e)) diff --git a/api/core/tools/provider/builtin/perplexity/perplexity.py b/api/core/tools/provider/builtin/perplexity/perplexity.py index ff91edf18d..80518853fb 100644 --- a/api/core/tools/provider/builtin/perplexity/perplexity.py +++ b/api/core/tools/provider/builtin/perplexity/perplexity.py @@ -11,34 +11,26 @@ class PerplexityProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: headers = { "Authorization": f"Bearer {credentials.get('perplexity_api_key')}", - "Content-Type": "application/json" + "Content-Type": "application/json", } - + payload = { "model": "llama-3.1-sonar-small-128k-online", "messages": [ - { - "role": "system", - "content": "You are a helpful assistant." - }, - { - "role": "user", - "content": "Hello" - } + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}, ], "max_tokens": 5, "temperature": 0.1, "top_p": 0.9, - "stream": False + "stream": False, } try: response = requests.post(PERPLEXITY_API_URL, json=payload, headers=headers) response.raise_for_status() except requests.RequestException as e: - raise ToolProviderCredentialValidationError( - f"Failed to validate Perplexity API key: {str(e)}" - ) + raise ToolProviderCredentialValidationError(f"Failed to validate Perplexity API key: {str(e)}") if response.status_code != 200: raise ToolProviderCredentialValidationError( diff --git a/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py index 5b1a263f9b..5ed4b9ca99 100644 --- a/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py +++ b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py @@ -8,65 +8,60 @@ from core.tools.tool.builtin_tool import BuiltinTool PERPLEXITY_API_URL = "https://api.perplexity.ai/chat/completions" + class PerplexityAITool(BuiltinTool): def _parse_response(self, response: dict) -> dict: """Parse the response from Perplexity AI API""" - if 'choices' in response and len(response['choices']) > 0: - message = response['choices'][0]['message'] + if "choices" in response and len(response["choices"]) > 0: + message = response["choices"][0]["message"] return { - 'content': message.get('content', ''), - 'role': message.get('role', ''), - 'citations': response.get('citations', []) + "content": message.get("content", ""), + "role": message.get("role", ""), + "citations": response.get("citations", []), } else: - return {'content': 'Unable to get a valid response', 'role': 'assistant', 'citations': []} + return {"content": "Unable to get a valid response", "role": "assistant", "citations": []} - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: headers = { "Authorization": f"Bearer {self.runtime.credentials['perplexity_api_key']}", - "Content-Type": "application/json" + "Content-Type": "application/json", } - + payload = { - "model": tool_parameters.get('model', 'llama-3.1-sonar-small-128k-online'), + "model": tool_parameters.get("model", "llama-3.1-sonar-small-128k-online"), "messages": [ - { - "role": "system", - "content": "Be precise and concise." - }, - { - "role": "user", - "content": tool_parameters['query'] - } + {"role": "system", "content": "Be precise and concise."}, + {"role": "user", "content": tool_parameters["query"]}, ], - "max_tokens": tool_parameters.get('max_tokens', 4096), - "temperature": tool_parameters.get('temperature', 0.7), - "top_p": tool_parameters.get('top_p', 1), - "top_k": tool_parameters.get('top_k', 5), - "presence_penalty": tool_parameters.get('presence_penalty', 0), - "frequency_penalty": tool_parameters.get('frequency_penalty', 1), - "stream": False + "max_tokens": tool_parameters.get("max_tokens", 4096), + "temperature": tool_parameters.get("temperature", 0.7), + "top_p": tool_parameters.get("top_p", 1), + "top_k": tool_parameters.get("top_k", 5), + "presence_penalty": tool_parameters.get("presence_penalty", 0), + "frequency_penalty": tool_parameters.get("frequency_penalty", 1), + "stream": False, } - - if 'search_recency_filter' in tool_parameters: - payload['search_recency_filter'] = tool_parameters['search_recency_filter'] - if 'return_citations' in tool_parameters: - payload['return_citations'] = tool_parameters['return_citations'] - if 'search_domain_filter' in tool_parameters: - if isinstance(tool_parameters['search_domain_filter'], str): - payload['search_domain_filter'] = [tool_parameters['search_domain_filter']] - elif isinstance(tool_parameters['search_domain_filter'], list): - payload['search_domain_filter'] = tool_parameters['search_domain_filter'] - + + if "search_recency_filter" in tool_parameters: + payload["search_recency_filter"] = tool_parameters["search_recency_filter"] + if "return_citations" in tool_parameters: + payload["return_citations"] = tool_parameters["return_citations"] + if "search_domain_filter" in tool_parameters: + if isinstance(tool_parameters["search_domain_filter"], str): + payload["search_domain_filter"] = [tool_parameters["search_domain_filter"]] + elif isinstance(tool_parameters["search_domain_filter"], list): + payload["search_domain_filter"] = tool_parameters["search_domain_filter"] response = requests.post(url=PERPLEXITY_API_URL, json=payload, headers=headers) response.raise_for_status() valuable_res = self._parse_response(response.json()) - + return [ self.create_json_message(valuable_res), - self.create_text_message(json.dumps(valuable_res, ensure_ascii=False, indent=2)) + self.create_text_message(json.dumps(valuable_res, ensure_ascii=False, indent=2)), ] diff --git a/api/core/tools/provider/builtin/pubmed/pubmed.py b/api/core/tools/provider/builtin/pubmed/pubmed.py index 05cd171b87..ea3a477c30 100644 --- a/api/core/tools/provider/builtin/pubmed/pubmed.py +++ b/api/core/tools/provider/builtin/pubmed/pubmed.py @@ -11,11 +11,10 @@ class PubMedProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "John Doe", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py b/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py index 58811d65e6..fedfdbd859 100644 --- a/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py +++ b/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py @@ -51,17 +51,12 @@ class PubMedAPIWrapper(BaseModel): try: # Retrieve the top-k results for the query docs = [ - f"Published: {result['pub_date']}\nTitle: {result['title']}\n" - f"Summary: {result['summary']}" + f"Published: {result['pub_date']}\nTitle: {result['title']}\n" f"Summary: {result['summary']}" for result in self.load(query[: self.ARXIV_MAX_QUERY_LENGTH]) ] # Join the results and limit the character count - return ( - "\n\n".join(docs)[:self.doc_content_chars_max] - if docs - else "No good PubMed Result was found" - ) + return "\n\n".join(docs)[: self.doc_content_chars_max] if docs else "No good PubMed Result was found" except Exception as ex: return f"PubMed exception: {ex}" @@ -91,13 +86,7 @@ class PubMedAPIWrapper(BaseModel): return articles def retrieve_article(self, uid: str, webenv: str) -> dict: - url = ( - self.base_url_efetch - + "db=pubmed&retmode=xml&id=" - + uid - + "&webenv=" - + webenv - ) + url = self.base_url_efetch + "db=pubmed&retmode=xml&id=" + uid + "&webenv=" + webenv retry = 0 while True: @@ -108,10 +97,7 @@ class PubMedAPIWrapper(BaseModel): if e.code == 429 and retry < self.max_retry: # Too Many Requests error # wait for an exponentially increasing amount of time - print( - f"Too Many Requests, " - f"waiting for {self.sleep_time:.2f} seconds..." - ) + print(f"Too Many Requests, " f"waiting for {self.sleep_time:.2f} seconds...") time.sleep(self.sleep_time) self.sleep_time *= 2 retry += 1 @@ -125,27 +111,21 @@ class PubMedAPIWrapper(BaseModel): if "" in xml_text and "" in xml_text: start_tag = "" end_tag = "" - title = xml_text[ - xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag) - ] + title = xml_text[xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)] # Get abstract abstract = "" if "" in xml_text and "" in xml_text: start_tag = "" end_tag = "" - abstract = xml_text[ - xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag) - ] + abstract = xml_text[xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)] # Get publication date pub_date = "" if "" in xml_text and "" in xml_text: start_tag = "" end_tag = "" - pub_date = xml_text[ - xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag) - ] + pub_date = xml_text[xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)] # Return article as dictionary article = { @@ -182,6 +162,7 @@ class PubmedQueryRun(BaseModel): class PubMedInput(BaseModel): query: str = Field(..., description="Search query.") + class PubMedSearchTool(BuiltinTool): """ Tool for performing a search using PubMed search engine. @@ -198,14 +179,13 @@ class PubMedSearchTool(BuiltinTool): Returns: ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation. """ - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input query') + return self.create_text_message("Please input query") tool = PubmedQueryRun(args_schema=PubMedInput) result = tool._run(query) return self.create_text_message(self.summary(user_id=user_id, content=result)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/qrcode/qrcode.py b/api/core/tools/provider/builtin/qrcode/qrcode.py index 9fa7d01265..8466b9a26b 100644 --- a/api/core/tools/provider/builtin/qrcode/qrcode.py +++ b/api/core/tools/provider/builtin/qrcode/qrcode.py @@ -8,9 +8,6 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class QRCodeProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - QRCodeGeneratorTool().invoke(user_id='', - tool_parameters={ - 'content': 'Dify 123 😊' - }) + QRCodeGeneratorTool().invoke(user_id="", tool_parameters={"content": "Dify 123 😊"}) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py index 5eede98f5e..8aefc65131 100644 --- a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py +++ b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py @@ -13,43 +13,44 @@ from core.tools.tool.builtin_tool import BuiltinTool class QRCodeGeneratorTool(BuiltinTool): error_correction_levels: dict[str, int] = { - 'L': ERROR_CORRECT_L, # <=7% - 'M': ERROR_CORRECT_M, # <=15% - 'Q': ERROR_CORRECT_Q, # <=25% - 'H': ERROR_CORRECT_H, # <=30% + "L": ERROR_CORRECT_L, # <=7% + "M": ERROR_CORRECT_M, # <=15% + "Q": ERROR_CORRECT_Q, # <=25% + "H": ERROR_CORRECT_H, # <=30% } - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get text content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # get border size - border = tool_parameters.get('border', 0) + border = tool_parameters.get("border", 0) if border < 0 or border > 100: - return self.create_text_message('Invalid parameter border') + return self.create_text_message("Invalid parameter border") # get error_correction - error_correction = tool_parameters.get('error_correction', '') + error_correction = tool_parameters.get("error_correction", "") if error_correction not in self.error_correction_levels.keys(): - return self.create_text_message('Invalid parameter error_correction') + return self.create_text_message("Invalid parameter error_correction") try: image = self._generate_qrcode(content, border, error_correction) image_bytes = self._image_to_byte_array(image) - return self.create_blob_message(blob=image_bytes, - meta={'mime_type': 'image/png'}, - save_as=self.VARIABLE_KEY.IMAGE.value) + return self.create_blob_message( + blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value + ) except Exception: - logging.exception(f'Failed to generate QR code for content: {content}') - return self.create_text_message('Failed to generate QR code') + logging.exception(f"Failed to generate QR code for content: {content}") + return self.create_text_message("Failed to generate QR code") def _generate_qrcode(self, content: str, border: int, error_correction: str) -> BaseImage: qr = QRCode( diff --git a/api/core/tools/provider/builtin/regex/regex.py b/api/core/tools/provider/builtin/regex/regex.py index d38ae1b292..c498105979 100644 --- a/api/core/tools/provider/builtin/regex/regex.py +++ b/api/core/tools/provider/builtin/regex/regex.py @@ -9,10 +9,10 @@ class RegexProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: RegexExpressionTool().invoke( - user_id='', + user_id="", tool_parameters={ - 'content': '1+(2+3)*4', - 'expression': r'(\d+)', + "content": "1+(2+3)*4", + "expression": r"(\d+)", }, ) except Exception as e: diff --git a/api/core/tools/provider/builtin/regex/tools/regex_extract.py b/api/core/tools/provider/builtin/regex/tools/regex_extract.py index 5d8f013d0d..786b469404 100644 --- a/api/core/tools/provider/builtin/regex/tools/regex_extract.py +++ b/api/core/tools/provider/builtin/regex/tools/regex_extract.py @@ -6,22 +6,23 @@ from core.tools.tool.builtin_tool import BuiltinTool class RegexExpressionTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get expression - content = tool_parameters.get('content', '').strip() + content = tool_parameters.get("content", "").strip() if not content: - return self.create_text_message('Invalid content') - expression = tool_parameters.get('expression', '').strip() + return self.create_text_message("Invalid content") + expression = tool_parameters.get("expression", "").strip() if not expression: - return self.create_text_message('Invalid expression') + return self.create_text_message("Invalid expression") try: result = re.findall(expression, content) return self.create_text_message(str(result)) except Exception as e: - return self.create_text_message(f'Failed to extract result, error: {str(e)}') \ No newline at end of file + return self.create_text_message(f"Failed to extract result, error: {str(e)}") diff --git a/api/core/tools/provider/builtin/searchapi/searchapi.py b/api/core/tools/provider/builtin/searchapi/searchapi.py index 6fa4f05acd..109bba8b2d 100644 --- a/api/core/tools/provider/builtin/searchapi/searchapi.py +++ b/api/core/tools/provider/builtin/searchapi/searchapi.py @@ -13,11 +13,8 @@ class SearchAPIProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "query": "SearchApi dify", - "result_type": "link" - }, + user_id="", + tool_parameters={"query": "SearchApi dify", "result_type": "link"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/searchapi/tools/google.py b/api/core/tools/provider/builtin/searchapi/tools/google.py index dd780aeadc..d632304a46 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google.py @@ -7,6 +7,7 @@ from core.tools.tool.builtin_tool import BuiltinTool SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + class SearchAPI: """ SearchAPI tool provider. @@ -80,25 +81,29 @@ class SearchAPI: toret = "No good search result found" return toret + class GoogleTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the SearchApi tool. """ - query = tool_parameters['query'] - result_type = tool_parameters['result_type'] + query = tool_parameters["query"] + result_type = tool_parameters["result_type"] num = tool_parameters.get("num", 10) google_domain = tool_parameters.get("google_domain", "google.com") gl = tool_parameters.get("gl", "us") hl = tool_parameters.get("hl", "en") location = tool_parameters.get("location") - api_key = self.runtime.credentials['searchapi_api_key'] - result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location) + api_key = self.runtime.credentials["searchapi_api_key"] + result = SearchAPI(api_key).run( + query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location + ) - if result_type == 'text': + if result_type == "text": return self.create_text_message(text=result) return self.create_link_message(link=result) diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py index 81c67c51a9..1544061c08 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py @@ -7,6 +7,7 @@ from core.tools.tool.builtin_tool import BuiltinTool SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + class SearchAPI: """ SearchAPI tool provider. @@ -50,7 +51,16 @@ class SearchAPI: if type == "text": if "jobs" in res.keys() and "title" in res["jobs"][0].keys(): for item in res["jobs"]: - toret += "title: " + item["title"] + "\n" + "company_name: " + item["company_name"] + "content: " + item["description"] + "\n" + toret += ( + "title: " + + item["title"] + + "\n" + + "company_name: " + + item["company_name"] + + "content: " + + item["description"] + + "\n" + ) if toret == "": toret = "No good search result found" @@ -62,16 +72,18 @@ class SearchAPI: toret = "No good search result found" return toret + class GoogleJobsTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the SearchApi tool. """ - query = tool_parameters['query'] - result_type = tool_parameters['result_type'] + query = tool_parameters["query"] + result_type = tool_parameters["result_type"] is_remote = tool_parameters.get("is_remote") google_domain = tool_parameters.get("google_domain", "google.com") gl = tool_parameters.get("gl", "us") @@ -80,9 +92,11 @@ class GoogleJobsTool(BuiltinTool): ltype = 1 if is_remote else None - api_key = self.runtime.credentials['searchapi_api_key'] - result = SearchAPI(api_key).run(query, result_type=result_type, google_domain=google_domain, gl=gl, hl=hl, location=location, ltype=ltype) + api_key = self.runtime.credentials["searchapi_api_key"] + result = SearchAPI(api_key).run( + query, result_type=result_type, google_domain=google_domain, gl=gl, hl=hl, location=location, ltype=ltype + ) - if result_type == 'text': + if result_type == "text": return self.create_text_message(text=result) return self.create_link_message(link=result) diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_news.py b/api/core/tools/provider/builtin/searchapi/tools/google_news.py index 5d2657dddd..95a7aad736 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_news.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_news.py @@ -7,6 +7,7 @@ from core.tools.tool.builtin_tool import BuiltinTool SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + class SearchAPI: """ SearchAPI tool provider. @@ -68,25 +69,29 @@ class SearchAPI: toret = "No good search result found" return toret + class GoogleNewsTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the SearchApi tool. """ - query = tool_parameters['query'] - result_type = tool_parameters['result_type'] + query = tool_parameters["query"] + result_type = tool_parameters["result_type"] num = tool_parameters.get("num", 10) google_domain = tool_parameters.get("google_domain", "google.com") gl = tool_parameters.get("gl", "us") hl = tool_parameters.get("hl", "en") location = tool_parameters.get("location") - api_key = self.runtime.credentials['searchapi_api_key'] - result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location) + api_key = self.runtime.credentials["searchapi_api_key"] + result = SearchAPI(api_key).run( + query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location + ) - if result_type == 'text': + if result_type == "text": return self.create_text_message(text=result) return self.create_link_message(link=result) diff --git a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py index 6345b33801..88def504fc 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py +++ b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py @@ -7,6 +7,7 @@ from core.tools.tool.builtin_tool import BuiltinTool SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + class SearchAPI: """ SearchAPI tool provider. @@ -55,18 +56,20 @@ class SearchAPI: return toret + class YoutubeTranscriptsTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the SearchApi tool. """ - video_id = tool_parameters['video_id'] - language = tool_parameters.get('language', "en") + video_id = tool_parameters["video_id"] + language = tool_parameters.get("language", "en") - api_key = self.runtime.credentials['searchapi_api_key'] + api_key = self.runtime.credentials["searchapi_api_key"] result = SearchAPI(api_key).run(video_id, language=language) return self.create_text_message(text=result) diff --git a/api/core/tools/provider/builtin/searxng/searxng.py b/api/core/tools/provider/builtin/searxng/searxng.py index ab354003e6..b7bbcc60b1 100644 --- a/api/core/tools/provider/builtin/searxng/searxng.py +++ b/api/core/tools/provider/builtin/searxng/searxng.py @@ -13,12 +13,8 @@ class SearXNGProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "query": "SearXNG", - "limit": 1, - "search_type": "general" - }, + user_id="", + tool_parameters={"query": "SearXNG", "limit": 1, "search_type": "general"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/searxng/tools/searxng_search.py b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py index dc835a8e8c..c5e339a108 100644 --- a/api/core/tools/provider/builtin/searxng/tools/searxng_search.py +++ b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py @@ -23,18 +23,21 @@ class SearXNGSearchTool(BuiltinTool): ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation. """ - host = self.runtime.credentials.get('searxng_base_url') + host = self.runtime.credentials.get("searxng_base_url") if not host: - raise Exception('SearXNG api is required') + raise Exception("SearXNG api is required") - response = requests.get(host, params={ - "q": tool_parameters.get('query'), - "format": "json", - "categories": tool_parameters.get('search_type', 'general') - }) + response = requests.get( + host, + params={ + "q": tool_parameters.get("query"), + "format": "json", + "categories": tool_parameters.get("search_type", "general"), + }, + ) if response.status_code != 200: - raise Exception(f'Error {response.status_code}: {response.text}') + raise Exception(f"Error {response.status_code}: {response.text}") res = response.json().get("results", []) if not res: diff --git a/api/core/tools/provider/builtin/serper/serper.py b/api/core/tools/provider/builtin/serper/serper.py index 2a42109373..cb1d090a9d 100644 --- a/api/core/tools/provider/builtin/serper/serper.py +++ b/api/core/tools/provider/builtin/serper/serper.py @@ -13,11 +13,8 @@ class SerperProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "query": "test", - "result_type": "link" - }, + user_id="", + tool_parameters={"query": "test", "result_type": "link"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/serper/tools/serper_search.py b/api/core/tools/provider/builtin/serper/tools/serper_search.py index 24facaf4ec..7baebbf958 100644 --- a/api/core/tools/provider/builtin/serper/tools/serper_search.py +++ b/api/core/tools/provider/builtin/serper/tools/serper_search.py @@ -9,7 +9,6 @@ SERPER_API_URL = "https://google.serper.dev/search" class SerperSearchTool(BuiltinTool): - def _parse_response(self, response: dict) -> dict: result = {} if "knowledgeGraph" in response: @@ -17,28 +16,19 @@ class SerperSearchTool(BuiltinTool): result["description"] = response["knowledgeGraph"].get("description", "") if "organic" in response: result["organic"] = [ - { - "title": item.get("title", ""), - "link": item.get("link", ""), - "snippet": item.get("snippet", "") - } + {"title": item.get("title", ""), "link": item.get("link", ""), "snippet": item.get("snippet", "")} for item in response["organic"] ] return result - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - params = { - "q": tool_parameters['query'], - "gl": "us", - "hl": "en" - } - headers = { - 'X-API-KEY': self.runtime.credentials['serperapi_api_key'], - 'Content-Type': 'application/json' - } - response = requests.get(url=SERPER_API_URL, params=params,headers=headers) + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + params = {"q": tool_parameters["query"], "gl": "us", "hl": "en"} + headers = {"X-API-KEY": self.runtime.credentials["serperapi_api_key"], "Content-Type": "application/json"} + response = requests.get(url=SERPER_API_URL, params=params, headers=headers) response.raise_for_status() valuable_res = self._parse_response(response.json()) return self.create_json_message(valuable_res) diff --git a/api/core/tools/provider/builtin/siliconflow/siliconflow.py b/api/core/tools/provider/builtin/siliconflow/siliconflow.py index 0df78280df..37a0b0755b 100644 --- a/api/core/tools/provider/builtin/siliconflow/siliconflow.py +++ b/api/core/tools/provider/builtin/siliconflow/siliconflow.py @@ -14,6 +14,4 @@ class SiliconflowProvider(BuiltinToolProviderController): response = requests.get(url, headers=headers) if response.status_code != 200: - raise ToolProviderCredentialValidationError( - "SiliconFlow API key is invalid" - ) + raise ToolProviderCredentialValidationError("SiliconFlow API key is invalid") diff --git a/api/core/tools/provider/builtin/siliconflow/tools/flux.py b/api/core/tools/provider/builtin/siliconflow/tools/flux.py index ed9f4be574..5fa9926484 100644 --- a/api/core/tools/provider/builtin/siliconflow/tools/flux.py +++ b/api/core/tools/provider/builtin/siliconflow/tools/flux.py @@ -5,17 +5,13 @@ import requests from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -FLUX_URL = ( - "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image" -) +FLUX_URL = "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image" class FluxTool(BuiltinTool): - def _invoke( self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - headers = { "accept": "application/json", "content-type": "application/json", @@ -36,9 +32,5 @@ class FluxTool(BuiltinTool): res = response.json() result = [self.create_json_message(res)] for image in res.get("images", []): - result.append( - self.create_image_message( - image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value - ) - ) + result.append(self.create_image_message(image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value)) return result diff --git a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py index e8134a6565..e7c3c28d7b 100644 --- a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py @@ -12,11 +12,9 @@ SDURL = { class StableDiffusionTool(BuiltinTool): - def _invoke( self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - headers = { "accept": "application/json", "content-type": "application/json", @@ -43,9 +41,5 @@ class StableDiffusionTool(BuiltinTool): res = response.json() result = [self.create_json_message(res)] for image in res.get("images", []): - result.append( - self.create_image_message( - image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value - ) - ) + result.append(self.create_image_message(image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value)) return result diff --git a/api/core/tools/provider/builtin/slack/tools/slack_webhook.py b/api/core/tools/provider/builtin/slack/tools/slack_webhook.py index f47557f2ef..85e0de7675 100644 --- a/api/core/tools/provider/builtin/slack/tools/slack_webhook.py +++ b/api/core/tools/provider/builtin/slack/tools/slack_webhook.py @@ -7,25 +7,27 @@ from core.tools.tool.builtin_tool import BuiltinTool class SlackWebhookTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - Incoming Webhooks - API Document: https://api.slack.com/messaging/webhooks + Incoming Webhooks + API Document: https://api.slack.com/messaging/webhooks """ - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - webhook_url = tool_parameters.get('webhook_url', '') + webhook_url = tool_parameters.get("webhook_url", "") - if not webhook_url.startswith('https://hooks.slack.com/'): + if not webhook_url.startswith("https://hooks.slack.com/"): return self.create_text_message( - f'Invalid parameter webhook_url ${webhook_url}, not a valid Slack webhook URL') + f"Invalid parameter webhook_url ${webhook_url}, not a valid Slack webhook URL" + ) headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = {} payload = { @@ -38,6 +40,7 @@ class SlackWebhookTool(BuiltinTool): return self.create_text_message("Text message was sent successfully") else: return self.create_text_message( - f"Failed to send the text message, status code: {res.status_code}, response: {res.text}") + f"Failed to send the text message, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: - return self.create_text_message("Failed to send message through webhook. {}".format(e)) \ No newline at end of file + return self.create_text_message("Failed to send message through webhook. {}".format(e)) diff --git a/api/core/tools/provider/builtin/spark/spark.py b/api/core/tools/provider/builtin/spark/spark.py index cb8e69a59f..e0b1a58a3f 100644 --- a/api/core/tools/provider/builtin/spark/spark.py +++ b/api/core/tools/provider/builtin/spark/spark.py @@ -29,12 +29,8 @@ class SparkProvider(BuiltinToolProviderController): # 0 success, pass else: - raise ToolProviderCredentialValidationError( - "image generate error, code:{}".format(code) - ) + raise ToolProviderCredentialValidationError("image generate error, code:{}".format(code)) except Exception as e: - raise ToolProviderCredentialValidationError( - "APPID APISecret APIKey is invalid. {}".format(e) - ) + raise ToolProviderCredentialValidationError("APPID APISecret APIKey is invalid. {}".format(e)) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py index c7b0de014f..a6f5570af2 100644 --- a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py +++ b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py @@ -47,26 +47,25 @@ def parse_url(request_url): u = Url(host, path, schema) return u + def assemble_ws_auth_url(request_url, method="GET", api_key="", api_secret=""): u = parse_url(request_url) host = u.host path = u.path now = datetime.now() date = format_date_time(mktime(now.timetuple())) - signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format( - host, date, method, path - ) + signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format(host, date, method, path) signature_sha = hmac.new( api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256, ).digest() signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8") - authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha}"' - - authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode( - encoding="utf-8" + authorization_origin = ( + f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha}"' ) + + authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8") values = {"host": host, "date": date, "authorization": authorization} return request_url + "?" + urlencode(values) @@ -75,9 +74,7 @@ def assemble_ws_auth_url(request_url, method="GET", api_key="", api_secret=""): def get_body(appid, text): body = { "header": {"app_id": appid, "uid": "123456789"}, - "parameter": { - "chat": {"domain": "general", "temperature": 0.5, "max_tokens": 4096} - }, + "parameter": {"chat": {"domain": "general", "temperature": 0.5, "max_tokens": 4096}}, "payload": {"message": {"text": [{"role": "user", "content": text}]}}, } return body @@ -85,13 +82,9 @@ def get_body(appid, text): def spark_response(text, appid, apikey, apisecret): host = "http://spark-api.cn-huabei-1.xf-yun.com/v2.1/tti" - url = assemble_ws_auth_url( - host, method="POST", api_key=apikey, api_secret=apisecret - ) + url = assemble_ws_auth_url(host, method="POST", api_key=apikey, api_secret=apisecret) content = get_body(appid, text) - response = requests.post( - url, json=content, headers={"content-type": "application/json"} - ).text + response = requests.post(url, json=content, headers={"content-type": "application/json"}).text return response @@ -105,19 +98,11 @@ class SparkImgGeneratorTool(BuiltinTool): invoke tools """ - if "APPID" not in self.runtime.credentials or not self.runtime.credentials.get( - "APPID" - ): + if "APPID" not in self.runtime.credentials or not self.runtime.credentials.get("APPID"): return self.create_text_message("APPID is required.") - if ( - "APISecret" not in self.runtime.credentials - or not self.runtime.credentials.get("APISecret") - ): + if "APISecret" not in self.runtime.credentials or not self.runtime.credentials.get("APISecret"): return self.create_text_message("APISecret is required.") - if ( - "APIKey" not in self.runtime.credentials - or not self.runtime.credentials.get("APIKey") - ): + if "APIKey" not in self.runtime.credentials or not self.runtime.credentials.get("APIKey"): return self.create_text_message("APIKey is required.") prompt = tool_parameters.get("prompt", "") diff --git a/api/core/tools/provider/builtin/spider/spider.py b/api/core/tools/provider/builtin/spider/spider.py index 5bcc56a724..5959555318 100644 --- a/api/core/tools/provider/builtin/spider/spider.py +++ b/api/core/tools/provider/builtin/spider/spider.py @@ -8,13 +8,13 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class SpiderProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - app = Spider(api_key=credentials['spider_api_key']) - app.scrape_url(url='https://spider.cloud') + app = Spider(api_key=credentials["spider_api_key"]) + app.scrape_url(url="https://spider.cloud") except AttributeError as e: # Handle cases where NoneType is not iterable, which might indicate API issues - if 'NoneType' in str(e) and 'not iterable' in str(e): - raise ToolProviderCredentialValidationError('API is currently down, try again in 15 minutes', str(e)) + if "NoneType" in str(e) and "not iterable" in str(e): + raise ToolProviderCredentialValidationError("API is currently down, try again in 15 minutes", str(e)) else: - raise ToolProviderCredentialValidationError('An unexpected error occurred.', str(e)) + raise ToolProviderCredentialValidationError("An unexpected error occurred.", str(e)) except Exception as e: - raise ToolProviderCredentialValidationError('An unexpected error occurred.', str(e)) + raise ToolProviderCredentialValidationError("An unexpected error occurred.", str(e)) diff --git a/api/core/tools/provider/builtin/spider/spiderApp.py b/api/core/tools/provider/builtin/spider/spiderApp.py index f0ed64867a..3972e560c4 100644 --- a/api/core/tools/provider/builtin/spider/spiderApp.py +++ b/api/core/tools/provider/builtin/spider/spiderApp.py @@ -65,9 +65,7 @@ class Spider: :return: The JSON response or the raw response stream if stream is True. """ headers = self._prepare_headers(content_type) - response = self._post_request( - f"https://api.spider.cloud/v1/{endpoint}", data, headers, stream - ) + response = self._post_request(f"https://api.spider.cloud/v1/{endpoint}", data, headers, stream) if stream: return response @@ -76,9 +74,7 @@ class Spider: else: self._handle_error(response, f"post to {endpoint}") - def api_get( - self, endpoint: str, stream: bool, content_type: str = "application/json" - ): + def api_get(self, endpoint: str, stream: bool, content_type: str = "application/json"): """ Send a GET request to the specified endpoint. @@ -86,9 +82,7 @@ class Spider: :return: The JSON decoded response. """ headers = self._prepare_headers(content_type) - response = self._get_request( - f"https://api.spider.cloud/v1/{endpoint}", headers, stream - ) + response = self._get_request(f"https://api.spider.cloud/v1/{endpoint}", headers, stream) if response.status_code == 200: return response.json() else: @@ -120,14 +114,12 @@ class Spider: # Add { "return_format": "markdown" } to the params if not already present if "return_format" not in params: - params["return_format"] = "markdown" + params["return_format"] = "markdown" # Set limit to 1 params["limit"] = 1 - return self.api_post( - "crawl", {"url": url, **(params or {})}, stream, content_type - ) + return self.api_post("crawl", {"url": url, **(params or {})}, stream, content_type) def crawl_url( self, @@ -150,9 +142,7 @@ class Spider: if "return_format" not in params: params["return_format"] = "markdown" - return self.api_post( - "crawl", {"url": url, **(params or {})}, stream, content_type - ) + return self.api_post("crawl", {"url": url, **(params or {})}, stream, content_type) def links( self, @@ -168,9 +158,7 @@ class Spider: :param params: Optional parameters for the link retrieval request. :return: JSON response containing the links. """ - return self.api_post( - "links", {"url": url, **(params or {})}, stream, content_type - ) + return self.api_post("links", {"url": url, **(params or {})}, stream, content_type) def extract_contacts( self, @@ -207,9 +195,7 @@ class Spider: :param params: Optional parameters to guide the labeling process. :return: JSON response with labeled data. """ - return self.api_post( - "pipeline/label", {"url": url, **(params or {})}, stream, content_type - ) + return self.api_post("pipeline/label", {"url": url, **(params or {})}, stream, content_type) def _prepare_headers(self, content_type: str = "application/json"): return { @@ -230,10 +216,6 @@ class Spider: def _handle_error(self, response, action): if response.status_code in [402, 409, 500]: error_message = response.json().get("error", "Unknown error occurred") - raise Exception( - f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}" - ) + raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") else: - raise Exception( - f"Unexpected error occurred while trying to {action}. Status code: {response.status_code}" - ) + raise Exception(f"Unexpected error occurred while trying to {action}. Status code: {response.status_code}") diff --git a/api/core/tools/provider/builtin/spider/tools/scraper_crawler.py b/api/core/tools/provider/builtin/spider/tools/scraper_crawler.py index 40736cd402..20d2daef55 100644 --- a/api/core/tools/provider/builtin/spider/tools/scraper_crawler.py +++ b/api/core/tools/provider/builtin/spider/tools/scraper_crawler.py @@ -6,41 +6,43 @@ from core.tools.tool.builtin_tool import BuiltinTool class ScrapeTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: # initialize the app object with the api key - app = Spider(api_key=self.runtime.credentials['spider_api_key']) + app = Spider(api_key=self.runtime.credentials["spider_api_key"]) + + url = tool_parameters["url"] + mode = tool_parameters["mode"] - url = tool_parameters['url'] - mode = tool_parameters['mode'] - options = { - 'limit': tool_parameters.get('limit', 0), - 'depth': tool_parameters.get('depth', 0), - 'blacklist': tool_parameters.get('blacklist', '').split(',') if tool_parameters.get('blacklist') else [], - 'whitelist': tool_parameters.get('whitelist', '').split(',') if tool_parameters.get('whitelist') else [], - 'readability': tool_parameters.get('readability', False), + "limit": tool_parameters.get("limit", 0), + "depth": tool_parameters.get("depth", 0), + "blacklist": tool_parameters.get("blacklist", "").split(",") if tool_parameters.get("blacklist") else [], + "whitelist": tool_parameters.get("whitelist", "").split(",") if tool_parameters.get("whitelist") else [], + "readability": tool_parameters.get("readability", False), } result = "" try: - if mode == 'scrape': + if mode == "scrape": scrape_result = app.scrape_url( - url=url, + url=url, params=options, ) for i in scrape_result: - result += "URL: " + i.get('url', '') + "\n" - result += "CONTENT: " + i.get('content', '') + "\n\n" - elif mode == 'crawl': + result += "URL: " + i.get("url", "") + "\n" + result += "CONTENT: " + i.get("content", "") + "\n\n" + elif mode == "crawl": crawl_result = app.crawl_url( - url=tool_parameters['url'], + url=tool_parameters["url"], params=options, ) for i in crawl_result: - result += "URL: " + i.get('url', '') + "\n" - result += "CONTENT: " + i.get('content', '') + "\n\n" + result += "URL: " + i.get("url", "") + "\n" + result += "CONTENT: " + i.get("content", "") + "\n\n" except Exception as e: return self.create_text_message("An error occurred", str(e)) diff --git a/api/core/tools/provider/builtin/stability/stability.py b/api/core/tools/provider/builtin/stability/stability.py index b31d786178..f09d81ac27 100644 --- a/api/core/tools/provider/builtin/stability/stability.py +++ b/api/core/tools/provider/builtin/stability/stability.py @@ -8,6 +8,7 @@ class StabilityToolProvider(BuiltinToolProviderController, BaseStabilityAuthoriz """ This class is responsible for providing the stability tool. """ + def _validate_credentials(self, credentials: dict[str, Any]) -> None: """ This method is responsible for validating the credentials. diff --git a/api/core/tools/provider/builtin/stability/tools/base.py b/api/core/tools/provider/builtin/stability/tools/base.py index a4788fd869..c3b7edbefa 100644 --- a/api/core/tools/provider/builtin/stability/tools/base.py +++ b/api/core/tools/provider/builtin/stability/tools/base.py @@ -9,26 +9,23 @@ class BaseStabilityAuthorization: """ This method is responsible for validating the credentials. """ - api_key = credentials.get('api_key', '') + api_key = credentials.get("api_key", "") if not api_key: - raise ToolProviderCredentialValidationError('API key is required.') - + raise ToolProviderCredentialValidationError("API key is required.") + response = requests.get( - URL('https://api.stability.ai') / 'v1' / 'user' / 'account', + URL("https://api.stability.ai") / "v1" / "user" / "account", headers=self.generate_authorization_headers(credentials), - timeout=(5, 30) + timeout=(5, 30), ) if not response.ok: - raise ToolProviderCredentialValidationError('Invalid API key.') + raise ToolProviderCredentialValidationError("Invalid API key.") return True - + def generate_authorization_headers(self, credentials: dict) -> dict[str, str]: """ This method is responsible for generating the authorization headers. """ - return { - 'Authorization': f'Bearer {credentials.get("api_key", "")}' - } - \ No newline at end of file + return {"Authorization": f'Bearer {credentials.get("api_key", "")}'} diff --git a/api/core/tools/provider/builtin/stability/tools/text2image.py b/api/core/tools/provider/builtin/stability/tools/text2image.py index 41236f7b43..c33e3bd78f 100644 --- a/api/core/tools/provider/builtin/stability/tools/text2image.py +++ b/api/core/tools/provider/builtin/stability/tools/text2image.py @@ -11,10 +11,11 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization): """ This class is responsible for providing the stable diffusion tool. """ + model_endpoint_map: dict[str, str] = { - 'sd3': 'https://api.stability.ai/v2beta/stable-image/generate/sd3', - 'sd3-turbo': 'https://api.stability.ai/v2beta/stable-image/generate/sd3', - 'core': 'https://api.stability.ai/v2beta/stable-image/generate/core', + "sd3": "https://api.stability.ai/v2beta/stable-image/generate/sd3", + "sd3-turbo": "https://api.stability.ai/v2beta/stable-image/generate/sd3", + "core": "https://api.stability.ai/v2beta/stable-image/generate/core", } def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: @@ -22,39 +23,34 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization): Invoke the tool. """ payload = { - 'prompt': tool_parameters.get('prompt', ''), - 'aspect_ratio': tool_parameters.get('aspect_ratio', '16:9') or tool_parameters.get('aspect_radio', '16:9'), - 'mode': 'text-to-image', - 'seed': tool_parameters.get('seed', 0), - 'output_format': 'png', + "prompt": tool_parameters.get("prompt", ""), + "aspect_ratio": tool_parameters.get("aspect_ratio", "16:9") or tool_parameters.get("aspect_radio", "16:9"), + "mode": "text-to-image", + "seed": tool_parameters.get("seed", 0), + "output_format": "png", } - model = tool_parameters.get('model', 'core') + model = tool_parameters.get("model", "core") - if model in ['sd3', 'sd3-turbo']: - payload['model'] = tool_parameters.get('model') + if model in ["sd3", "sd3-turbo"]: + payload["model"] = tool_parameters.get("model") - if not model == 'sd3-turbo': - payload['negative_prompt'] = tool_parameters.get('negative_prompt', '') + if not model == "sd3-turbo": + payload["negative_prompt"] = tool_parameters.get("negative_prompt", "") response = post( - self.model_endpoint_map[tool_parameters.get('model', 'core')], + self.model_endpoint_map[tool_parameters.get("model", "core")], headers={ - 'accept': 'image/*', + "accept": "image/*", **self.generate_authorization_headers(self.runtime.credentials), }, - files={ - key: (None, str(value)) for key, value in payload.items() - }, - timeout=(5, 30) + files={key: (None, str(value)) for key, value in payload.items()}, + timeout=(5, 30), ) if not response.status_code == 200: raise Exception(response.text) - + return self.create_blob_message( - blob=response.content, meta={ - 'mime_type': 'image/png' - }, - save_as=self.VARIABLE_KEY.IMAGE.value + blob=response.content, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value ) diff --git a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py index 317d705f7c..abaa297cf3 100644 --- a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py @@ -15,4 +15,3 @@ class StableDiffusionProvider(BuiltinToolProviderController): ).validate_models() except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py index 4be9207d66..c31e178067 100644 --- a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py @@ -18,19 +18,17 @@ DRAW_TEXT_OPTIONS = { # Prompts "prompt": "", "negative_prompt": "", - # "styles": [], - # Seeds + # "styles": [], + # Seeds "seed": -1, "subseed": -1, "subseed_strength": 0, "seed_resize_from_h": -1, "seed_resize_from_w": -1, - # Samplers "sampler_name": "DPM++ 2M", # "scheduler": "", # "sampler_index": "Automatic", - # Latent Space Options "batch_size": 1, "n_iter": 1, @@ -42,9 +40,9 @@ DRAW_TEXT_OPTIONS = { # "tiling": True, "do_not_save_samples": False, "do_not_save_grid": False, - # "eta": 0, - # "denoising_strength": 0.75, - # "s_min_uncond": 0, + # "eta": 0, + # "denoising_strength": 0.75, + # "s_min_uncond": 0, # "s_churn": 0, # "s_tmax": 0, # "s_tmin": 0, @@ -73,7 +71,6 @@ DRAW_TEXT_OPTIONS = { "hr_negative_prompt": "", # Task Options # "force_task_id": "", - # Script Options # "script_name": "", "script_args": [], @@ -82,131 +79,130 @@ DRAW_TEXT_OPTIONS = { "save_images": False, "alwayson_scripts": {}, # "infotext": "", - } class StableDiffusionTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # base url - base_url = self.runtime.credentials.get('base_url', None) + base_url = self.runtime.credentials.get("base_url", None) if not base_url: - return self.create_text_message('Please input base_url') + return self.create_text_message("Please input base_url") - if tool_parameters.get('model'): - self.runtime.credentials['model'] = tool_parameters['model'] + if tool_parameters.get("model"): + self.runtime.credentials["model"] = tool_parameters["model"] - model = self.runtime.credentials.get('model', None) + model = self.runtime.credentials.get("model", None) if not model: - return self.create_text_message('Please input model') - + return self.create_text_message("Please input model") + # set model try: - url = str(URL(base_url) / 'sdapi' / 'v1' / 'options') - response = post(url, data=json.dumps({ - 'sd_model_checkpoint': model - })) + url = str(URL(base_url) / "sdapi" / "v1" / "options") + response = post(url, data=json.dumps({"sd_model_checkpoint": model})) if response.status_code != 200: - raise ToolProviderCredentialValidationError('Failed to set model, please tell user to set model') + raise ToolProviderCredentialValidationError("Failed to set model, please tell user to set model") except Exception as e: - raise ToolProviderCredentialValidationError('Failed to set model, please tell user to set model') + raise ToolProviderCredentialValidationError("Failed to set model, please tell user to set model") # get image id and image variable - image_id = tool_parameters.get('image_id', '') + image_id = tool_parameters.get("image_id", "") image_variable = self.get_default_image_variable() # Return text2img if there's no image ID or no image variable if not image_id or not image_variable: - return self.text2img(base_url=base_url,tool_parameters=tool_parameters) + return self.text2img(base_url=base_url, tool_parameters=tool_parameters) # Proceed with image-to-image generation - return self.img2img(base_url=base_url,tool_parameters=tool_parameters) + return self.img2img(base_url=base_url, tool_parameters=tool_parameters) def validate_models(self) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - validate models + validate models """ try: - base_url = self.runtime.credentials.get('base_url', None) + base_url = self.runtime.credentials.get("base_url", None) if not base_url: - raise ToolProviderCredentialValidationError('Please input base_url') - model = self.runtime.credentials.get('model', None) + raise ToolProviderCredentialValidationError("Please input base_url") + model = self.runtime.credentials.get("model", None) if not model: - raise ToolProviderCredentialValidationError('Please input model') + raise ToolProviderCredentialValidationError("Please input model") - api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models') + api_url = str(URL(base_url) / "sdapi" / "v1" / "sd-models") response = get(url=api_url, timeout=10) if response.status_code == 404: # try draw a picture self._invoke( - user_id='test', + user_id="test", tool_parameters={ - 'prompt': 'a cat', - 'width': 1024, - 'height': 1024, - 'steps': 1, - 'lora': '', - } + "prompt": "a cat", + "width": 1024, + "height": 1024, + "steps": 1, + "lora": "", + }, ) elif response.status_code != 200: - raise ToolProviderCredentialValidationError('Failed to get models') + raise ToolProviderCredentialValidationError("Failed to get models") else: - models = [d['model_name'] for d in response.json()] + models = [d["model_name"] for d in response.json()] if len([d for d in models if d == model]) > 0: return self.create_text_message(json.dumps(models)) else: - raise ToolProviderCredentialValidationError(f'model {model} does not exist') + raise ToolProviderCredentialValidationError(f"model {model} does not exist") except Exception as e: - raise ToolProviderCredentialValidationError(f'Failed to get models, {e}') + raise ToolProviderCredentialValidationError(f"Failed to get models, {e}") def get_sd_models(self) -> list[str]: """ - get sd models + get sd models """ try: - base_url = self.runtime.credentials.get('base_url', None) + base_url = self.runtime.credentials.get("base_url", None) if not base_url: return [] - api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models') + api_url = str(URL(base_url) / "sdapi" / "v1" / "sd-models") response = get(url=api_url, timeout=(2, 10)) if response.status_code != 200: return [] else: - return [d['model_name'] for d in response.json()] - except Exception as e: - return [] - - def get_sample_methods(self) -> list[str]: - """ - get sample method - """ - try: - base_url = self.runtime.credentials.get('base_url', None) - if not base_url: - return [] - api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'samplers') - response = get(url=api_url, timeout=(2, 10)) - if response.status_code != 200: - return [] - else: - return [d['name'] for d in response.json()] + return [d["model_name"] for d in response.json()] except Exception as e: return [] - def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def get_sample_methods(self) -> list[str]: """ - generate image + get sample method + """ + try: + base_url = self.runtime.credentials.get("base_url", None) + if not base_url: + return [] + api_url = str(URL(base_url) / "sdapi" / "v1" / "samplers") + response = get(url=api_url, timeout=(2, 10)) + if response.status_code != 200: + return [] + else: + return [d["name"] for d in response.json()] + except Exception as e: + return [] + + def img2img( + self, base_url: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + generate image """ # Fetch the binary data of the image image_variable = self.get_default_image_variable() image_binary = self.get_variable_file(image_variable.name) if not image_binary: - return self.create_text_message('Image not found, please request user to generate image firstly.') + return self.create_text_message("Image not found, please request user to generate image firstly.") # Convert image to RGB and save as PNG try: @@ -220,14 +216,14 @@ class StableDiffusionTool(BuiltinTool): # copy draw options draw_options = deepcopy(DRAW_TEXT_OPTIONS) # set image options - model = tool_parameters.get('model', '') + model = tool_parameters.get("model", "") draw_options_image = { - "init_images": [b64encode(image_binary).decode('utf-8')], + "init_images": [b64encode(image_binary).decode("utf-8")], "denoising_strength": 0.9, "restore_faces": False, "script_args": [], "override_settings": {"sd_model_checkpoint": model}, - "resize_mode":0, + "resize_mode": 0, "image_cfg_scale": 0, # "mask": None, "mask_blur_x": 4, @@ -247,136 +243,142 @@ class StableDiffusionTool(BuiltinTool): draw_options.update(tool_parameters) # get prompt lora model - prompt = tool_parameters.get('prompt', '') - lora = tool_parameters.get('lora', '') - model = tool_parameters.get('model', '') + prompt = tool_parameters.get("prompt", "") + lora = tool_parameters.get("lora", "") + model = tool_parameters.get("model", "") if lora: - draw_options['prompt'] = f'{lora},{prompt}' + draw_options["prompt"] = f"{lora},{prompt}" else: - draw_options['prompt'] = prompt + draw_options["prompt"] = prompt try: - url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img') + url = str(URL(base_url) / "sdapi" / "v1" / "img2img") response = post(url, data=json.dumps(draw_options), timeout=120) if response.status_code != 200: - return self.create_text_message('Failed to generate image') - - image = response.json()['images'][0] + return self.create_text_message("Failed to generate image") + + image = response.json()["images"][0] + + return self.create_blob_message( + blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value + ) - return self.create_blob_message(blob=b64decode(image), - meta={ 'mime_type': 'image/png' }, - save_as=self.VARIABLE_KEY.IMAGE.value) - except Exception as e: - return self.create_text_message('Failed to generate image') + return self.create_text_message("Failed to generate image") - def text2img(self, base_url: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def text2img( + self, base_url: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - generate image + generate image """ # copy draw options draw_options = deepcopy(DRAW_TEXT_OPTIONS) draw_options.update(tool_parameters) # get prompt lora model - prompt = tool_parameters.get('prompt', '') - lora = tool_parameters.get('lora', '') - model = tool_parameters.get('model', '') + prompt = tool_parameters.get("prompt", "") + lora = tool_parameters.get("lora", "") + model = tool_parameters.get("model", "") if lora: - draw_options['prompt'] = f'{lora},{prompt}' + draw_options["prompt"] = f"{lora},{prompt}" else: - draw_options['prompt'] = prompt - draw_options['override_settings']['sd_model_checkpoint'] = model + draw_options["prompt"] = prompt + draw_options["override_settings"]["sd_model_checkpoint"] = model - try: - url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img') + url = str(URL(base_url) / "sdapi" / "v1" / "txt2img") response = post(url, data=json.dumps(draw_options), timeout=120) if response.status_code != 200: - return self.create_text_message('Failed to generate image') - - image = response.json()['images'][0] + return self.create_text_message("Failed to generate image") + + image = response.json()["images"][0] + + return self.create_blob_message( + blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value + ) - return self.create_blob_message(blob=b64decode(image), - meta={ 'mime_type': 'image/png' }, - save_as=self.VARIABLE_KEY.IMAGE.value) - except Exception as e: - return self.create_text_message('Failed to generate image') + return self.create_text_message("Failed to generate image") def get_runtime_parameters(self) -> list[ToolParameter]: parameters = [ - ToolParameter(name='prompt', - label=I18nObject(en_US='Prompt', zh_Hans='Prompt'), - human_description=I18nObject( - en_US='Image prompt, you can check the official documentation of Stable Diffusion', - zh_Hans='图像提示词,您可以查看 Stable Diffusion 的官方文档', - ), - type=ToolParameter.ToolParameterType.STRING, - form=ToolParameter.ToolParameterForm.LLM, - llm_description='Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.', - required=True), + ToolParameter( + name="prompt", + label=I18nObject(en_US="Prompt", zh_Hans="Prompt"), + human_description=I18nObject( + en_US="Image prompt, you can check the official documentation of Stable Diffusion", + zh_Hans="图像提示词,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.", + required=True, + ), ] if len(self.list_default_image_variables()) != 0: parameters.append( - ToolParameter(name='image_id', - label=I18nObject(en_US='image_id', zh_Hans='image_id'), - human_description=I18nObject( - en_US='Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.', - zh_Hans='您想要生成的图像的图像 ID,如果您想要基于默认图像生成图像,则可以将此字段留空。', - ), - type=ToolParameter.ToolParameterType.STRING, - form=ToolParameter.ToolParameterForm.LLM, - llm_description='Image id of the original image, you can leave this field empty if you want to generate a new image.', - required=True, - options=[ToolParameterOption( - value=i.name, - label=I18nObject(en_US=i.name, zh_Hans=i.name) - ) for i in self.list_default_image_variables()]) + ToolParameter( + name="image_id", + label=I18nObject(en_US="image_id", zh_Hans="image_id"), + human_description=I18nObject( + en_US="Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.", + zh_Hans="您想要生成的图像的图像 ID,如果您想要基于默认图像生成图像,则可以将此字段留空。", + ), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="Image id of the original image, you can leave this field empty if you want to generate a new image.", + required=True, + options=[ + ToolParameterOption(value=i.name, label=I18nObject(en_US=i.name, zh_Hans=i.name)) + for i in self.list_default_image_variables() + ], + ) ) - + if self.runtime.credentials: try: models = self.get_sd_models() if len(models) != 0: parameters.append( - ToolParameter(name='model', - label=I18nObject(en_US='Model', zh_Hans='Model'), - human_description=I18nObject( - en_US='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion', - zh_Hans='Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档', - ), - type=ToolParameter.ToolParameterType.SELECT, - form=ToolParameter.ToolParameterForm.FORM, - llm_description='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion', - required=True, - default=models[0], - options=[ToolParameterOption( - value=i, - label=I18nObject(en_US=i, zh_Hans=i) - ) for i in models]) + ToolParameter( + name="model", + label=I18nObject(en_US="Model", zh_Hans="Model"), + human_description=I18nObject( + en_US="Model of Stable Diffusion, you can check the official documentation of Stable Diffusion", + zh_Hans="Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Model of Stable Diffusion, you can check the official documentation of Stable Diffusion", + required=True, + default=models[0], + options=[ + ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in models + ], + ) ) - + except: pass - + sample_methods = self.get_sample_methods() if len(sample_methods) != 0: parameters.append( - ToolParameter(name='sampler_name', - label=I18nObject(en_US='Sampling method', zh_Hans='Sampling method'), - human_description=I18nObject( - en_US='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion', - zh_Hans='Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档', - ), - type=ToolParameter.ToolParameterType.SELECT, - form=ToolParameter.ToolParameterForm.FORM, - llm_description='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion', - required=True, - default=sample_methods[0], - options=[ToolParameterOption( - value=i, - label=I18nObject(en_US=i, zh_Hans=i) - ) for i in sample_methods]) + ToolParameter( + name="sampler_name", + label=I18nObject(en_US="Sampling method", zh_Hans="Sampling method"), + human_description=I18nObject( + en_US="Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion", + zh_Hans="Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion", + required=True, + default=sample_methods[0], + options=[ + ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in sample_methods + ], ) + ) return parameters diff --git a/api/core/tools/provider/builtin/stackexchange/stackexchange.py b/api/core/tools/provider/builtin/stackexchange/stackexchange.py index de64c84997..9680c633cc 100644 --- a/api/core/tools/provider/builtin/stackexchange/stackexchange.py +++ b/api/core/tools/provider/builtin/stackexchange/stackexchange.py @@ -11,16 +11,15 @@ class StackExchangeProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "intitle": "Test", - "sort": "relevance", + "sort": "relevance", "order": "desc", "site": "stackoverflow", "accepted": True, - "pagesize": 1 + "pagesize": 1, }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stackexchange/tools/fetchAnsByStackExQuesID.py b/api/core/tools/provider/builtin/stackexchange/tools/fetchAnsByStackExQuesID.py index f8e1710844..5345320095 100644 --- a/api/core/tools/provider/builtin/stackexchange/tools/fetchAnsByStackExQuesID.py +++ b/api/core/tools/provider/builtin/stackexchange/tools/fetchAnsByStackExQuesID.py @@ -17,7 +17,9 @@ class FetchAnsByStackExQuesIDInput(BaseModel): class FetchAnsByStackExQuesIDTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: input = FetchAnsByStackExQuesIDInput(**tool_parameters) params = { @@ -26,7 +28,7 @@ class FetchAnsByStackExQuesIDTool(BuiltinTool): "order": input.order, "sort": input.sort, "pagesize": input.pagesize, - "page": input.page + "page": input.page, } response = requests.get(f"https://api.stackexchange.com/2.3/questions/{input.id}/answers", params=params) @@ -34,4 +36,4 @@ class FetchAnsByStackExQuesIDTool(BuiltinTool): if response.status_code == 200: return self.create_text_message(self.summary(user_id=user_id, content=response.text)) else: - return self.create_text_message(f"API request failed with status code {response.status_code}") \ No newline at end of file + return self.create_text_message(f"API request failed with status code {response.status_code}") diff --git a/api/core/tools/provider/builtin/stackexchange/tools/searchStackExQuestions.py b/api/core/tools/provider/builtin/stackexchange/tools/searchStackExQuestions.py index 8436433c32..4a25a808ad 100644 --- a/api/core/tools/provider/builtin/stackexchange/tools/searchStackExQuestions.py +++ b/api/core/tools/provider/builtin/stackexchange/tools/searchStackExQuestions.py @@ -9,26 +9,28 @@ from core.tools.tool.builtin_tool import BuiltinTool class SearchStackExQuestionsInput(BaseModel): intitle: str = Field(..., description="The search query.") - sort: str = Field(..., description="The sort order - relevance, activity, votes, creation.") + sort: str = Field(..., description="The sort order - relevance, activity, votes, creation.") order: str = Field(..., description="asc or desc") site: str = Field(..., description="The Stack Exchange site.") tagged: str = Field(None, description="Semicolon-separated tags to include.") nottagged: str = Field(None, description="Semicolon-separated tags to exclude.") - accepted: bool = Field(..., description="true for only accepted answers, false otherwise") + accepted: bool = Field(..., description="true for only accepted answers, false otherwise") pagesize: int = Field(..., description="Number of results per page") class SearchStackExQuestionsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: input = SearchStackExQuestionsInput(**tool_parameters) params = { "intitle": input.intitle, "sort": input.sort, - "order": input.order, + "order": input.order, "site": input.site, "accepted": input.accepted, - "pagesize": input.pagesize + "pagesize": input.pagesize, } if input.tagged: params["tagged"] = input.tagged @@ -40,4 +42,4 @@ class SearchStackExQuestionsTool(BuiltinTool): if response.status_code == 200: return self.create_text_message(self.summary(user_id=user_id, content=response.text)) else: - return self.create_text_message(f"API request failed with status code {response.status_code}") \ No newline at end of file + return self.create_text_message(f"API request failed with status code {response.status_code}") diff --git a/api/core/tools/provider/builtin/stepfun/stepfun.py b/api/core/tools/provider/builtin/stepfun/stepfun.py index e809b04546..b24f730c95 100644 --- a/api/core/tools/provider/builtin/stepfun/stepfun.py +++ b/api/core/tools/provider/builtin/stepfun/stepfun.py @@ -13,13 +13,12 @@ class StepfunProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "prompt": "cute girl, blue eyes, white hair, anime style", "size": "1024x1024", - "n": 1 + "n": 1, }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stepfun/tools/image.py b/api/core/tools/provider/builtin/stepfun/tools/image.py index c571f54675..0b92b122bf 100644 --- a/api/core/tools/provider/builtin/stepfun/tools/image.py +++ b/api/core/tools/provider/builtin/stepfun/tools/image.py @@ -9,61 +9,67 @@ from core.tools.tool.builtin_tool import BuiltinTool class StepfunTool(BuiltinTool): - """ Stepfun Image Generation Tool """ - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """Stepfun Image Generation Tool""" + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - base_url = self.runtime.credentials.get('stepfun_base_url', 'https://api.stepfun.com') - base_url = str(URL(base_url) / 'v1') + base_url = self.runtime.credentials.get("stepfun_base_url", "https://api.stepfun.com") + base_url = str(URL(base_url) / "v1") client = OpenAI( - api_key=self.runtime.credentials['stepfun_api_key'], + api_key=self.runtime.credentials["stepfun_api_key"], base_url=base_url, ) extra_body = {} - model = tool_parameters.get('model', 'step-1x-medium') + model = tool_parameters.get("model", "step-1x-medium") if not model: - return self.create_text_message('Please input model name') + return self.create_text_message("Please input model name") # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') + return self.create_text_message("Please input prompt") - seed = tool_parameters.get('seed', 0) + seed = tool_parameters.get("seed", 0) if seed > 0: - extra_body['seed'] = seed - steps = tool_parameters.get('steps', 0) + extra_body["seed"] = seed + steps = tool_parameters.get("steps", 0) if steps > 0: - extra_body['steps'] = steps - negative_prompt = tool_parameters.get('negative_prompt', '') + extra_body["steps"] = steps + negative_prompt = tool_parameters.get("negative_prompt", "") if negative_prompt: - extra_body['negative_prompt'] = negative_prompt + extra_body["negative_prompt"] = negative_prompt # call openapi stepfun model response = client.images.generate( prompt=prompt, model=model, - size=tool_parameters.get('size', '1024x1024'), - n=tool_parameters.get('n', 1), - extra_body= extra_body + size=tool_parameters.get("size", "1024x1024"), + n=tool_parameters.get("n", 1), + extra_body=extra_body, ) print(response) result = [] for image in response.data: result.append(self.create_image_message(image=image.url)) - result.append(self.create_json_message({ - "url": image.url, - })) + result.append( + self.create_json_message( + { + "url": image.url, + } + ) + ) return result @staticmethod def _generate_random_id(length=8): - characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' - random_id = ''.join(random.choices(characters, k=length)) + characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + random_id = "".join(random.choices(characters, k=length)) return random_id diff --git a/api/core/tools/provider/builtin/tavily/tavily.py b/api/core/tools/provider/builtin/tavily/tavily.py index e376d99d6b..a702b0a74e 100644 --- a/api/core/tools/provider/builtin/tavily/tavily.py +++ b/api/core/tools/provider/builtin/tavily/tavily.py @@ -13,7 +13,7 @@ class TavilyProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "Sachin Tendulkar", "search_depth": "basic", @@ -22,9 +22,8 @@ class TavilyProvider(BuiltinToolProviderController): "include_raw_content": False, "max_results": 5, "include_domains": "", - "exclude_domains": "" + "exclude_domains": "", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/tavily/tools/tavily_search.py b/api/core/tools/provider/builtin/tavily/tools/tavily_search.py index 0200df3c8a..ca6d8633e4 100644 --- a/api/core/tools/provider/builtin/tavily/tools/tavily_search.py +++ b/api/core/tools/provider/builtin/tavily/tools/tavily_search.py @@ -36,15 +36,23 @@ class TavilySearch: """ params["api_key"] = self.api_key - if 'exclude_domains' in params and isinstance(params['exclude_domains'], str) and params['exclude_domains'] != 'None': - params['exclude_domains'] = params['exclude_domains'].split() + if ( + "exclude_domains" in params + and isinstance(params["exclude_domains"], str) + and params["exclude_domains"] != "None" + ): + params["exclude_domains"] = params["exclude_domains"].split() else: - params['exclude_domains'] = [] - if 'include_domains' in params and isinstance(params['include_domains'], str) and params['include_domains'] != 'None': - params['include_domains'] = params['include_domains'].split() + params["exclude_domains"] = [] + if ( + "include_domains" in params + and isinstance(params["include_domains"], str) + and params["include_domains"] != "None" + ): + params["include_domains"] = params["include_domains"].split() else: - params['include_domains'] = [] - + params["include_domains"] = [] + response = requests.post(f"{TAVILY_API_URL}/search", json=params) response.raise_for_status() return response.json() @@ -91,9 +99,7 @@ class TavilySearchTool(BuiltinTool): A tool for searching Tavily using a given query. """ - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> ToolInvokeMessage | list[ToolInvokeMessage]: + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: """ Invokes the Tavily search tool with the given user ID and tool parameters. @@ -115,4 +121,4 @@ class TavilySearchTool(BuiltinTool): if not results: return self.create_text_message(f"No results found for '{query}' in Tavily") else: - return self.create_text_message(text=results) \ No newline at end of file + return self.create_text_message(text=results) diff --git a/api/core/tools/provider/builtin/tianditu/tianditu.py b/api/core/tools/provider/builtin/tianditu/tianditu.py index 1f96be06b0..cb7d7bd8bb 100644 --- a/api/core/tools/provider/builtin/tianditu/tianditu.py +++ b/api/core/tools/provider/builtin/tianditu/tianditu.py @@ -12,10 +12,12 @@ class TiandituProvider(BuiltinToolProviderController): runtime={ "credentials": credentials, } - ).invoke(user_id='', - tool_parameters={ - 'content': '北京', - 'specify': '156110000', - }) + ).invoke( + user_id="", + tool_parameters={ + "content": "北京", + "specify": "156110000", + }, + ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/tianditu/tools/geocoder.py b/api/core/tools/provider/builtin/tianditu/tools/geocoder.py index 484a3768c8..690a0aed6f 100644 --- a/api/core/tools/provider/builtin/tianditu/tools/geocoder.py +++ b/api/core/tools/provider/builtin/tianditu/tools/geocoder.py @@ -8,26 +8,26 @@ from core.tools.tool.builtin_tool import BuiltinTool class GeocoderTool(BuiltinTool): - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - base_url = 'http://api.tianditu.gov.cn/geocoder' - - keyword = tool_parameters.get('keyword', '') + base_url = "http://api.tianditu.gov.cn/geocoder" + + keyword = tool_parameters.get("keyword", "") if not keyword: - return self.create_text_message('Invalid parameter keyword') - - tk = self.runtime.credentials['tianditu_api_key'] - + return self.create_text_message("Invalid parameter keyword") + + tk = self.runtime.credentials["tianditu_api_key"] + params = { - 'keyWord': keyword, + "keyWord": keyword, } - - result = requests.get(base_url + '?ds=' + json.dumps(params, ensure_ascii=False) + '&tk=' + tk).json() + + result = requests.get(base_url + "?ds=" + json.dumps(params, ensure_ascii=False) + "&tk=" + tk).json() return self.create_json_message(result) diff --git a/api/core/tools/provider/builtin/tianditu/tools/poisearch.py b/api/core/tools/provider/builtin/tianditu/tools/poisearch.py index 08a5b8ef42..798dd94d33 100644 --- a/api/core/tools/provider/builtin/tianditu/tools/poisearch.py +++ b/api/core/tools/provider/builtin/tianditu/tools/poisearch.py @@ -8,38 +8,51 @@ from core.tools.tool.builtin_tool import BuiltinTool class PoiSearchTool(BuiltinTool): - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - geocoder_base_url = 'http://api.tianditu.gov.cn/geocoder' - base_url = 'http://api.tianditu.gov.cn/v2/search' - - keyword = tool_parameters.get('keyword', '') + geocoder_base_url = "http://api.tianditu.gov.cn/geocoder" + base_url = "http://api.tianditu.gov.cn/v2/search" + + keyword = tool_parameters.get("keyword", "") if not keyword: - return self.create_text_message('Invalid parameter keyword') - - baseAddress = tool_parameters.get('baseAddress', '') + return self.create_text_message("Invalid parameter keyword") + + baseAddress = tool_parameters.get("baseAddress", "") if not baseAddress: - return self.create_text_message('Invalid parameter baseAddress') - - tk = self.runtime.credentials['tianditu_api_key'] - - base_coords = requests.get(geocoder_base_url + '?ds=' + json.dumps({'keyWord': baseAddress,}, ensure_ascii=False) + '&tk=' + tk).json() - + return self.create_text_message("Invalid parameter baseAddress") + + tk = self.runtime.credentials["tianditu_api_key"] + + base_coords = requests.get( + geocoder_base_url + + "?ds=" + + json.dumps( + { + "keyWord": baseAddress, + }, + ensure_ascii=False, + ) + + "&tk=" + + tk + ).json() + params = { - 'keyWord': keyword, - 'queryRadius': 5000, - 'queryType': 3, - 'pointLonlat': base_coords['location']['lon'] + ',' + base_coords['location']['lat'], - 'start': 0, - 'count': 100, + "keyWord": keyword, + "queryRadius": 5000, + "queryType": 3, + "pointLonlat": base_coords["location"]["lon"] + "," + base_coords["location"]["lat"], + "start": 0, + "count": 100, } - - result = requests.get(base_url + '?postStr=' + json.dumps(params, ensure_ascii=False) + '&type=query&tk=' + tk).json() + + result = requests.get( + base_url + "?postStr=" + json.dumps(params, ensure_ascii=False) + "&type=query&tk=" + tk + ).json() return self.create_json_message(result) diff --git a/api/core/tools/provider/builtin/tianditu/tools/staticmap.py b/api/core/tools/provider/builtin/tianditu/tools/staticmap.py index ecac4404ca..93803d7937 100644 --- a/api/core/tools/provider/builtin/tianditu/tools/staticmap.py +++ b/api/core/tools/provider/builtin/tianditu/tools/staticmap.py @@ -8,29 +8,42 @@ from core.tools.tool.builtin_tool import BuiltinTool class PoiSearchTool(BuiltinTool): - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - - geocoder_base_url = 'http://api.tianditu.gov.cn/geocoder' - base_url = 'http://api.tianditu.gov.cn/staticimage' - - keyword = tool_parameters.get('keyword', '') - if not keyword: - return self.create_text_message('Invalid parameter keyword') - - tk = self.runtime.credentials['tianditu_api_key'] - - keyword_coords = requests.get(geocoder_base_url + '?ds=' + json.dumps({'keyWord': keyword,}, ensure_ascii=False) + '&tk=' + tk).json() - coords = keyword_coords['location']['lon'] + ',' + keyword_coords['location']['lat'] - - result = requests.get(base_url + '?center=' + coords + '&markers=' + coords + '&width=400&height=300&zoom=14&tk=' + tk).content - return self.create_blob_message(blob=result, - meta={'mime_type': 'image/png'}, - save_as=self.VARIABLE_KEY.IMAGE.value) + geocoder_base_url = "http://api.tianditu.gov.cn/geocoder" + base_url = "http://api.tianditu.gov.cn/staticimage" + + keyword = tool_parameters.get("keyword", "") + if not keyword: + return self.create_text_message("Invalid parameter keyword") + + tk = self.runtime.credentials["tianditu_api_key"] + + keyword_coords = requests.get( + geocoder_base_url + + "?ds=" + + json.dumps( + { + "keyWord": keyword, + }, + ensure_ascii=False, + ) + + "&tk=" + + tk + ).json() + coords = keyword_coords["location"]["lon"] + "," + keyword_coords["location"]["lat"] + + result = requests.get( + base_url + "?center=" + coords + "&markers=" + coords + "&width=400&height=300&zoom=14&tk=" + tk + ).content + + return self.create_blob_message( + blob=result, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value + ) diff --git a/api/core/tools/provider/builtin/time/time.py b/api/core/tools/provider/builtin/time/time.py index 833ae194ef..e4df8d616c 100644 --- a/api/core/tools/provider/builtin/time/time.py +++ b/api/core/tools/provider/builtin/time/time.py @@ -9,9 +9,8 @@ class WikiPediaProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: CurrentTimeTool().invoke( - user_id='', + user_id="", tool_parameters={}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/time/tools/current_time.py b/api/core/tools/provider/builtin/time/tools/current_time.py index 90c01665e6..cc38739c16 100644 --- a/api/core/tools/provider/builtin/time/tools/current_time.py +++ b/api/core/tools/provider/builtin/time/tools/current_time.py @@ -8,21 +8,22 @@ from core.tools.tool.builtin_tool import BuiltinTool class CurrentTimeTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get timezone - tz = tool_parameters.get('timezone', 'UTC') - fm = tool_parameters.get('format') or '%Y-%m-%d %H:%M:%S %Z' - if tz == 'UTC': - return self.create_text_message(f'{datetime.now(timezone.utc).strftime(fm)}') - + tz = tool_parameters.get("timezone", "UTC") + fm = tool_parameters.get("format") or "%Y-%m-%d %H:%M:%S %Z" + if tz == "UTC": + return self.create_text_message(f"{datetime.now(timezone.utc).strftime(fm)}") + try: tz = pytz_timezone(tz) except: - return self.create_text_message(f'Invalid timezone: {tz}') - return self.create_text_message(f'{datetime.now(tz).strftime(fm)}') \ No newline at end of file + return self.create_text_message(f"Invalid timezone: {tz}") + return self.create_text_message(f"{datetime.now(tz).strftime(fm)}") diff --git a/api/core/tools/provider/builtin/time/tools/weekday.py b/api/core/tools/provider/builtin/time/tools/weekday.py index 4461cb5a32..b327e54e17 100644 --- a/api/core/tools/provider/builtin/time/tools/weekday.py +++ b/api/core/tools/provider/builtin/time/tools/weekday.py @@ -7,25 +7,26 @@ from core.tools.tool.builtin_tool import BuiltinTool class WeekdayTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - Calculate the day of the week for a given date + Calculate the day of the week for a given date """ - year = tool_parameters.get('year') - month = tool_parameters.get('month') - day = tool_parameters.get('day') + year = tool_parameters.get("year") + month = tool_parameters.get("month") + day = tool_parameters.get("day") date_obj = self.convert_datetime(year, month, day) if not date_obj: - return self.create_text_message(f'Invalid date: Year {year}, Month {month}, Day {day}.') + return self.create_text_message(f"Invalid date: Year {year}, Month {month}, Day {day}.") weekday_name = calendar.day_name[date_obj.weekday()] month_name = calendar.month_name[month] readable_date = f"{month_name} {date_obj.day}, {date_obj.year}" - return self.create_text_message(f'{readable_date} is {weekday_name}.') + return self.create_text_message(f"{readable_date} is {weekday_name}.") @staticmethod def convert_datetime(year, month, day) -> datetime | None: diff --git a/api/core/tools/provider/builtin/trello/tools/create_board.py b/api/core/tools/provider/builtin/trello/tools/create_board.py index 2655602afa..5a61d22157 100644 --- a/api/core/tools/provider/builtin/trello/tools/create_board.py +++ b/api/core/tools/provider/builtin/trello/tools/create_board.py @@ -22,19 +22,15 @@ class CreateBoardTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_name = tool_parameters.get('name') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_name = tool_parameters.get("name") if not (api_key and token and board_name): return self.create_text_message("Missing required parameters: API key, token, or board name.") url = "https://api.trello.com/1/boards/" - query_params = { - 'name': board_name, - 'key': api_key, - 'token': token - } + query_params = {"name": board_name, "key": api_key, "token": token} try: response = requests.post(url, params=query_params) @@ -43,5 +39,6 @@ class CreateBoardTool(BuiltinTool): return self.create_text_message("Failed to create board") board = response.json() - return self.create_text_message(text=f"Board created successfully! Board name: {board['name']}, ID: {board['id']}") - + return self.create_text_message( + text=f"Board created successfully! Board name: {board['name']}, ID: {board['id']}" + ) diff --git a/api/core/tools/provider/builtin/trello/tools/create_list_on_board.py b/api/core/tools/provider/builtin/trello/tools/create_list_on_board.py index f5b156cb44..26f12864c3 100644 --- a/api/core/tools/provider/builtin/trello/tools/create_list_on_board.py +++ b/api/core/tools/provider/builtin/trello/tools/create_list_on_board.py @@ -22,20 +22,16 @@ class CreateListOnBoardTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('id') - list_name = tool_parameters.get('name') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("id") + list_name = tool_parameters.get("name") if not (api_key and token and board_id and list_name): return self.create_text_message("Missing required parameters: API key, token, board ID, or list name.") url = f"https://api.trello.com/1/boards/{board_id}/lists" - params = { - 'name': list_name, - 'key': api_key, - 'token': token - } + params = {"name": list_name, "key": api_key, "token": token} try: response = requests.post(url, params=params) @@ -44,5 +40,6 @@ class CreateListOnBoardTool(BuiltinTool): return self.create_text_message("Failed to create list") new_list = response.json() - return self.create_text_message(text=f"List '{new_list['name']}' created successfully with Id {new_list['id']} on board {board_id}.") - + return self.create_text_message( + text=f"List '{new_list['name']}' created successfully with Id {new_list['id']} on board {board_id}." + ) diff --git a/api/core/tools/provider/builtin/trello/tools/create_new_card_on_board.py b/api/core/tools/provider/builtin/trello/tools/create_new_card_on_board.py index 74b73b40e5..dfc013a6b8 100644 --- a/api/core/tools/provider/builtin/trello/tools/create_new_card_on_board.py +++ b/api/core/tools/provider/builtin/trello/tools/create_new_card_on_board.py @@ -22,15 +22,15 @@ class CreateNewCardOnBoardTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") # Ensure required parameters are present - if 'name' not in tool_parameters or 'idList' not in tool_parameters: + if "name" not in tool_parameters or "idList" not in tool_parameters: return self.create_text_message("Missing required parameters: name or idList.") url = "https://api.trello.com/1/cards" - params = {**tool_parameters, 'key': api_key, 'token': token} + params = {**tool_parameters, "key": api_key, "token": token} try: response = requests.post(url, params=params) @@ -39,5 +39,6 @@ class CreateNewCardOnBoardTool(BuiltinTool): except requests.exceptions.RequestException as e: return self.create_text_message("Failed to create card") - return self.create_text_message(text=f"New card '{new_card['name']}' created successfully with ID {new_card['id']}.") - + return self.create_text_message( + text=f"New card '{new_card['name']}' created successfully with ID {new_card['id']}." + ) diff --git a/api/core/tools/provider/builtin/trello/tools/delete_board.py b/api/core/tools/provider/builtin/trello/tools/delete_board.py index 29df3fda2d..9dbd8f78d5 100644 --- a/api/core/tools/provider/builtin/trello/tools/delete_board.py +++ b/api/core/tools/provider/builtin/trello/tools/delete_board.py @@ -22,9 +22,9 @@ class DeleteBoardTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -38,4 +38,3 @@ class DeleteBoardTool(BuiltinTool): return self.create_text_message("Failed to delete board") return self.create_text_message(text=f"Board with ID {board_id} deleted successfully.") - diff --git a/api/core/tools/provider/builtin/trello/tools/delete_card.py b/api/core/tools/provider/builtin/trello/tools/delete_card.py index 2ced5f6c14..960c3055fe 100644 --- a/api/core/tools/provider/builtin/trello/tools/delete_card.py +++ b/api/core/tools/provider/builtin/trello/tools/delete_card.py @@ -22,9 +22,9 @@ class DeleteCardByIdTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - card_id = tool_parameters.get('id') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + card_id = tool_parameters.get("id") if not (api_key and token and card_id): return self.create_text_message("Missing required parameters: API key, token, or card ID.") @@ -38,4 +38,3 @@ class DeleteCardByIdTool(BuiltinTool): return self.create_text_message("Failed to delete card") return self.create_text_message(text=f"Card with ID {card_id} has been successfully deleted.") - diff --git a/api/core/tools/provider/builtin/trello/tools/fetch_all_boards.py b/api/core/tools/provider/builtin/trello/tools/fetch_all_boards.py index f9d554c6fb..0c5ed9ea85 100644 --- a/api/core/tools/provider/builtin/trello/tools/fetch_all_boards.py +++ b/api/core/tools/provider/builtin/trello/tools/fetch_all_boards.py @@ -28,9 +28,7 @@ class FetchAllBoardsTool(BuiltinTool): token = self.runtime.credentials.get("trello_api_token") if not (api_key and token): - return self.create_text_message( - "Missing Trello API key or token in credentials." - ) + return self.create_text_message("Missing Trello API key or token in credentials.") # Including board filter in the request if provided board_filter = tool_parameters.get("boards", "open") @@ -48,7 +46,5 @@ class FetchAllBoardsTool(BuiltinTool): return self.create_text_message("No boards found in Trello.") # Creating a string with both board names and IDs - boards_info = ", ".join( - [f"{board['name']} (ID: {board['id']})" for board in boards] - ) + boards_info = ", ".join([f"{board['name']} (ID: {board['id']})" for board in boards]) return self.create_text_message(text=f"Boards: {boards_info}") diff --git a/api/core/tools/provider/builtin/trello/tools/get_board_actions.py b/api/core/tools/provider/builtin/trello/tools/get_board_actions.py index 5678d8f8d7..03510f1964 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_board_actions.py +++ b/api/core/tools/provider/builtin/trello/tools/get_board_actions.py @@ -22,9 +22,9 @@ class GetBoardActionsTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -38,6 +38,7 @@ class GetBoardActionsTool(BuiltinTool): except requests.exceptions.RequestException as e: return self.create_text_message("Failed to retrieve board actions") - actions_summary = "\n".join([f"{action['type']}: {action.get('data', {}).get('text', 'No details available')}" for action in actions]) + actions_summary = "\n".join( + [f"{action['type']}: {action.get('data', {}).get('text', 'No details available')}" for action in actions] + ) return self.create_text_message(text=f"Actions for Board ID {board_id}:\n{actions_summary}") - diff --git a/api/core/tools/provider/builtin/trello/tools/get_board_by_id.py b/api/core/tools/provider/builtin/trello/tools/get_board_by_id.py index ee6cb065e5..5b41b128d0 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_board_by_id.py +++ b/api/core/tools/provider/builtin/trello/tools/get_board_by_id.py @@ -22,9 +22,9 @@ class GetBoardByIdTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -63,4 +63,3 @@ class GetBoardByIdTool(BuiltinTool): f"Background Color: {board['prefs']['backgroundColor']}" ) return details - diff --git a/api/core/tools/provider/builtin/trello/tools/get_board_cards.py b/api/core/tools/provider/builtin/trello/tools/get_board_cards.py index 1abb688750..e3bed2e6e6 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_board_cards.py +++ b/api/core/tools/provider/builtin/trello/tools/get_board_cards.py @@ -22,9 +22,9 @@ class GetBoardCardsTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -40,4 +40,3 @@ class GetBoardCardsTool(BuiltinTool): cards_summary = "\n".join([f"{card['name']} (ID: {card['id']})" for card in cards]) return self.create_text_message(text=f"Cards for Board ID {board_id}:\n{cards_summary}") - diff --git a/api/core/tools/provider/builtin/trello/tools/get_filterd_board_cards.py b/api/core/tools/provider/builtin/trello/tools/get_filterd_board_cards.py index 375ead5b1d..4d8854747c 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_filterd_board_cards.py +++ b/api/core/tools/provider/builtin/trello/tools/get_filterd_board_cards.py @@ -22,10 +22,10 @@ class GetFilteredBoardCardsTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') - filter = tool_parameters.get('filter') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") + filter = tool_parameters.get("filter") if not (api_key and token and board_id and filter): return self.create_text_message("Missing required parameters: API key, token, board ID, or filter.") @@ -40,5 +40,6 @@ class GetFilteredBoardCardsTool(BuiltinTool): return self.create_text_message("Failed to retrieve filtered cards") card_details = "\n".join([f"{card['name']} (ID: {card['id']})" for card in filtered_cards]) - return self.create_text_message(text=f"Filtered Cards for Board ID {board_id} with Filter '{filter}':\n{card_details}") - + return self.create_text_message( + text=f"Filtered Cards for Board ID {board_id} with Filter '{filter}':\n{card_details}" + ) diff --git a/api/core/tools/provider/builtin/trello/tools/get_lists_on_board.py b/api/core/tools/provider/builtin/trello/tools/get_lists_on_board.py index 7b9b9cf24b..ca8aa9c2d5 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_lists_on_board.py +++ b/api/core/tools/provider/builtin/trello/tools/get_lists_on_board.py @@ -22,9 +22,9 @@ class GetListsFromBoardTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -40,4 +40,3 @@ class GetListsFromBoardTool(BuiltinTool): lists_info = "\n".join([f"{list['name']} (ID: {list['id']})" for list in lists]) return self.create_text_message(text=f"Lists on Board ID {board_id}:\n{lists_info}") - diff --git a/api/core/tools/provider/builtin/trello/tools/update_board.py b/api/core/tools/provider/builtin/trello/tools/update_board.py index 7ad6ac2e64..62681eea6b 100644 --- a/api/core/tools/provider/builtin/trello/tools/update_board.py +++ b/api/core/tools/provider/builtin/trello/tools/update_board.py @@ -22,9 +22,9 @@ class UpdateBoardByIdTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.pop('boardId', None) + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.pop("boardId", None) if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -33,8 +33,8 @@ class UpdateBoardByIdTool(BuiltinTool): # Removing parameters not intended for update action or with None value params = {k: v for k, v in tool_parameters.items() if v is not None} - params['key'] = api_key - params['token'] = token + params["key"] = api_key + params["token"] = token try: response = requests.put(url, params=params) @@ -44,4 +44,3 @@ class UpdateBoardByIdTool(BuiltinTool): updated_board = response.json() return self.create_text_message(text=f"Board '{updated_board['name']}' updated successfully.") - diff --git a/api/core/tools/provider/builtin/trello/tools/update_card.py b/api/core/tools/provider/builtin/trello/tools/update_card.py index 417344350c..26113f1229 100644 --- a/api/core/tools/provider/builtin/trello/tools/update_card.py +++ b/api/core/tools/provider/builtin/trello/tools/update_card.py @@ -22,17 +22,17 @@ class UpdateCardByIdTool(BuiltinTool): Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - card_id = tool_parameters.get('id') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + card_id = tool_parameters.get("id") if not (api_key and token and card_id): return self.create_text_message("Missing required parameters: API key, token, or card ID.") # Constructing the URL and the payload for the PUT request url = f"https://api.trello.com/1/cards/{card_id}" - params = {k: v for k, v in tool_parameters.items() if v is not None and k != 'id'} - params.update({'key': api_key, 'token': token}) + params = {k: v for k, v in tool_parameters.items() if v is not None and k != "id"} + params.update({"key": api_key, "token": token}) try: response = requests.put(url, params=params) diff --git a/api/core/tools/provider/builtin/trello/trello.py b/api/core/tools/provider/builtin/trello/trello.py index 84ecd20803..e0dca50ec9 100644 --- a/api/core/tools/provider/builtin/trello/trello.py +++ b/api/core/tools/provider/builtin/trello/trello.py @@ -9,17 +9,17 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class TrelloProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: """Validate Trello API credentials by making a test API call. - + Args: credentials (dict[str, Any]): The Trello API credentials to validate. - + Raises: ToolProviderCredentialValidationError: If the credentials are invalid. """ api_key = credentials.get("trello_api_key") token = credentials.get("trello_api_token") url = f"https://api.trello.com/1/members/me?key={api_key}&token={token}" - + try: response = requests.get(url) response.raise_for_status() # Raises an HTTPError for bad responses @@ -32,4 +32,3 @@ class TrelloProvider(BuiltinToolProviderController): except requests.exceptions.RequestException as e: # Handle other exceptions, such as connection errors raise ToolProviderCredentialValidationError("Error validating Trello credentials") - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/twilio/tools/send_message.py b/api/core/tools/provider/builtin/twilio/tools/send_message.py index 1c52589956..822d0c0ebd 100644 --- a/api/core/tools/provider/builtin/twilio/tools/send_message.py +++ b/api/core/tools/provider/builtin/twilio/tools/send_message.py @@ -32,17 +32,14 @@ class TwilioAPIWrapper(BaseModel): must be empty. """ - @field_validator('client', mode='before') + @field_validator("client", mode="before") @classmethod def set_validator(cls, values: dict) -> dict: """Validate that api key and python package exists in environment.""" try: from twilio.rest import Client except ImportError: - raise ImportError( - "Could not import twilio python package. " - "Please install it with `pip install twilio`." - ) + raise ImportError("Could not import twilio python package. " "Please install it with `pip install twilio`.") account_sid = values.get("account_sid") auth_token = values.get("auth_token") values["from_number"] = values.get("from_number") @@ -91,9 +88,7 @@ class SendMessageTool(BuiltinTool): if to_number.startswith("whatsapp:"): from_number = f"whatsapp: {from_number}" - twilio = TwilioAPIWrapper( - account_sid=account_sid, auth_token=auth_token, from_number=from_number - ) + twilio = TwilioAPIWrapper(account_sid=account_sid, auth_token=auth_token, from_number=from_number) # Sending the message through Twilio result = twilio.run(message, to_number) diff --git a/api/core/tools/provider/builtin/twilio/twilio.py b/api/core/tools/provider/builtin/twilio/twilio.py index 06f276053a..b1d100aad9 100644 --- a/api/core/tools/provider/builtin/twilio/twilio.py +++ b/api/core/tools/provider/builtin/twilio/twilio.py @@ -14,7 +14,7 @@ class TwilioProvider(BuiltinToolProviderController): account_sid = credentials["account_sid"] auth_token = credentials["auth_token"] from_number = credentials["from_number"] - + # Initialize twilio client client = Client(account_sid, auth_token) @@ -27,4 +27,3 @@ class TwilioProvider(BuiltinToolProviderController): raise ToolProviderCredentialValidationError(f"Missing required credential: {e}") from e except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/vanna/vanna.py b/api/core/tools/provider/builtin/vanna/vanna.py index ab1fd71df5..84724e921a 100644 --- a/api/core/tools/provider/builtin/vanna/vanna.py +++ b/api/core/tools/provider/builtin/vanna/vanna.py @@ -13,13 +13,13 @@ class VannaProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "model": "chinook", "db_type": "SQLite", "url": "https://vanna.ai/Chinook.sqlite", - "query": "What are the top 10 customers by sales?" + "query": "What are the top 10 customers by sales?", }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/vectorizer/tools/test_data.py b/api/core/tools/provider/builtin/vectorizer/tools/test_data.py index 1506ac0c9d..8e1b097776 100644 --- a/api/core/tools/provider/builtin/vectorizer/tools/test_data.py +++ b/api/core/tools/provider/builtin/vectorizer/tools/test_data.py @@ -1 +1 @@ -VECTORIZER_ICON_PNG = 'iVBORw0KGgoAAAANSUhEUgAAAGAAAABgCAYAAADimHc4AAAACXBIWXMAACxLAAAsSwGlPZapAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAboSURBVHgB7Z09bBxFFMffRoAvcQqbguBUxu4wCUikMCZ0TmQK4NLQJCJOlQIkokgEGhQ7NCFIKEhQuIqNnIaGMxRY2GVwmlggDHS+pIHELmIXMTEULPP3eeXz7e7szO7MvE1ufpKV03nuNn7/mfcxH7tEHo/H42lXgqwG1bGw65+/aTQM6K0gpJdCoi7ypCIMui5s9Qv9R1OVTqrVxoL1jPbpvH4hrIp/rnmj5+YOhTQ++1kwmdZgT9ovRi6EF4Xhv/XGL0Sv6OLXYMu0BokjYOSDcBQfJI8xhKFP/HAlqCW8v5vqubBr8yn6maCexxiIDR376LnWmBBzQZtPEvx+L3mMAleOZKb1/XgM2EOnyWMFZJKt78UEQKpJHisk2TYmgM967JFk2z3kYcULwIwXgBkvADNeAGa8AMw8Qcwc6N55/eAh0cYmGaOzQtR/kOhQX+M6+/c23r+3RlT/i2ipTrSyRqw4F+CwMMbgANHQwG7jRywLw/wqDDNzI79xYPjqa2L262jjtYzaT0QT3xEbsck4MXUakgWOvUx08liy0ZPYEKNhel4Y6AZpgR7/8Tvq1wEQ+sMJN6Nh9kqwy+bWYwAM8elZovNv6xmlU7iLs280RNO9ls51os/h/8eBVQEig8Dt5OXUsNrno2tluZw0cI3qUXKONQHy9sYkVHqnjntLA2LnFTAv1gSA+zBhfIDvkfVO/B4xRgWZn4fbe2WAnGJFAAxn03+I7PtUXdzE90Sjl4ne+6L4d5nCigAyYyHPn7tFdPN30uJwX/qI6jtISkQZFVLdhd9SrtNPTrFSB6QZBAaYntsptpAyfvk+KYOCamVR/XrNtLqepduiFnkh3g4iIw6YLAhlOJmKwB9zaarhApr/MPREjAZVisSU1s/KYsGzhmKXClYEWLm/8xpV7btXhcv5I7lt2vtJFA3q/T07r1HopdG5l5xhxQVdn28YFn8kBJCBOZmiPHio1m5QuJzlu9ntXApgZwSsNYJslvGjtjrfm8Sq4neceFUtz3dZCzwW09Gqo2hreuPN7HZRnNqa1BP1x8lhczVNK+zT0TqkjYAF4e7Okxoo2PZX5K4IrhNpb/P8FTK2S1+TcUq1HpBFmquJYo1qEYU6RVarJE0c2ooL7C5IRwBZ5nJ9joyRtk5hA3YBdHqWzG1gBKgE/bzMaK5LqMIugKrbUDHu59/YWVRBsWhrsYZdANV5HBUXYGNlC9dFBW8LdgH6FQVYUnQvkQgm3NH8YuO7bM4LsWZBfT3qRY9OxRyJgJRz+Ij+FDPEQ1C3GVMiWAVQ7f31u/ncytxi4wdZTbRGgdcHnpYLD/FcwSrAoOKizfKfVAiIF4kBMPK+Opfe1iWsMUB1BJh2BRgBabSNAOiFqkXYbcNFUF9P+u82FGdWTcEmgGrvh0FUppB1kC073muXEaDq/21kIjLxV9tFAC7/n5X6tkUM0PH/dcP+P0v41fvkFBYBVHs/MD0CDmVsOzEdb7JgEYDT/8uq4rpj44NSjwDTc/CyzV1gxbH7Ac4F0PH/S4ZHAOaFZLiY+2nFuQA6/t9kQMTCz1CG66tbWvWS4VwAVf9vugAbel6efqrsYbKBcwFeVNz8ajobyTppw2F84FQAnfl/kwER6wJZcWdBc7e2KZwKoOP/TVakWb0f7md+kVhwOwI0BDCFyq42rt4PSiuAiRGAEXdK4ZQlV+8HTgVwefwHvR7nhbOA0FwBGDgTIM/Z3SLXUj2hOW1wR10eSrs7Ou9eTB3jo/dzuh/gTABdn35c8dhpM3BxOmeTuXs/cDoCdDY4qe7l32pbaZxL1jF+GXo/cLotBcWVTiZU3T7RMn8rHiijW9FgauP4Ef1TLdhHWgacCgAj6tYCqGKjU/DNbqxIkMYZNs7MpxmnLuhmwYJna1dbdzHjY42hDL4/wqkA6HWuDkAngRH0iYVjRkVwnoZO/0gsuLwpkw7OBcAtwlwvfESHxctmfMBSiOG0oStj4HCF7T3+RWARwIU7QK/HbWlqls52mYJtezqMj3v34C5VOveFy8Ll4QoTsJ8Txp0RsW8/Os2im2LCtSC1RIqLw3RldTVplOKkPEYDhMAPqttnune2rzTv5Y+WKdEem2ixkWqZYSeDSUp3qwIYNOrR7cBjcbOORxkvADNeAGa8AMx4AZjxAjATf5Ab0Tp5rJBk2/iD3PAwYo8Vkmyb9CjDGfLYIaCp1rdiAnT8S5PeDVkgoDuVCsWeJxwToHZ163m3Z8hjloDGk54vn5gFbT/5eZw8phifvZz8XPlA9qmRj8JRCumi+OkljzbbrvxM0qPMm9rIqY6FXZubVBUinMbzcP3jbuXA6Mh2kMx07KPJJLfj8Xg8Hg/4H+KfFYb2WM4MAAAAAElFTkSuQmCC' \ No newline at end of file +VECTORIZER_ICON_PNG = "iVBORw0KGgoAAAANSUhEUgAAAGAAAABgCAYAAADimHc4AAAACXBIWXMAACxLAAAsSwGlPZapAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAboSURBVHgB7Z09bBxFFMffRoAvcQqbguBUxu4wCUikMCZ0TmQK4NLQJCJOlQIkokgEGhQ7NCFIKEhQuIqNnIaGMxRY2GVwmlggDHS+pIHELmIXMTEULPP3eeXz7e7szO7MvE1ufpKV03nuNn7/mfcxH7tEHo/H42lXgqwG1bGw65+/aTQM6K0gpJdCoi7ypCIMui5s9Qv9R1OVTqrVxoL1jPbpvH4hrIp/rnmj5+YOhTQ++1kwmdZgT9ovRi6EF4Xhv/XGL0Sv6OLXYMu0BokjYOSDcBQfJI8xhKFP/HAlqCW8v5vqubBr8yn6maCexxiIDR376LnWmBBzQZtPEvx+L3mMAleOZKb1/XgM2EOnyWMFZJKt78UEQKpJHisk2TYmgM967JFk2z3kYcULwIwXgBkvADNeAGa8AMw8Qcwc6N55/eAh0cYmGaOzQtR/kOhQX+M6+/c23r+3RlT/i2ipTrSyRqw4F+CwMMbgANHQwG7jRywLw/wqDDNzI79xYPjqa2L262jjtYzaT0QT3xEbsck4MXUakgWOvUx08liy0ZPYEKNhel4Y6AZpgR7/8Tvq1wEQ+sMJN6Nh9kqwy+bWYwAM8elZovNv6xmlU7iLs280RNO9ls51os/h/8eBVQEig8Dt5OXUsNrno2tluZw0cI3qUXKONQHy9sYkVHqnjntLA2LnFTAv1gSA+zBhfIDvkfVO/B4xRgWZn4fbe2WAnGJFAAxn03+I7PtUXdzE90Sjl4ne+6L4d5nCigAyYyHPn7tFdPN30uJwX/qI6jtISkQZFVLdhd9SrtNPTrFSB6QZBAaYntsptpAyfvk+KYOCamVR/XrNtLqepduiFnkh3g4iIw6YLAhlOJmKwB9zaarhApr/MPREjAZVisSU1s/KYsGzhmKXClYEWLm/8xpV7btXhcv5I7lt2vtJFA3q/T07r1HopdG5l5xhxQVdn28YFn8kBJCBOZmiPHio1m5QuJzlu9ntXApgZwSsNYJslvGjtjrfm8Sq4neceFUtz3dZCzwW09Gqo2hreuPN7HZRnNqa1BP1x8lhczVNK+zT0TqkjYAF4e7Okxoo2PZX5K4IrhNpb/P8FTK2S1+TcUq1HpBFmquJYo1qEYU6RVarJE0c2ooL7C5IRwBZ5nJ9joyRtk5hA3YBdHqWzG1gBKgE/bzMaK5LqMIugKrbUDHu59/YWVRBsWhrsYZdANV5HBUXYGNlC9dFBW8LdgH6FQVYUnQvkQgm3NH8YuO7bM4LsWZBfT3qRY9OxRyJgJRz+Ij+FDPEQ1C3GVMiWAVQ7f31u/ncytxi4wdZTbRGgdcHnpYLD/FcwSrAoOKizfKfVAiIF4kBMPK+Opfe1iWsMUB1BJh2BRgBabSNAOiFqkXYbcNFUF9P+u82FGdWTcEmgGrvh0FUppB1kC073muXEaDq/21kIjLxV9tFAC7/n5X6tkUM0PH/dcP+P0v41fvkFBYBVHs/MD0CDmVsOzEdb7JgEYDT/8uq4rpj44NSjwDTc/CyzV1gxbH7Ac4F0PH/S4ZHAOaFZLiY+2nFuQA6/t9kQMTCz1CG66tbWvWS4VwAVf9vugAbel6efqrsYbKBcwFeVNz8ajobyTppw2F84FQAnfl/kwER6wJZcWdBc7e2KZwKoOP/TVakWb0f7md+kVhwOwI0BDCFyq42rt4PSiuAiRGAEXdK4ZQlV+8HTgVwefwHvR7nhbOA0FwBGDgTIM/Z3SLXUj2hOW1wR10eSrs7Ou9eTB3jo/dzuh/gTABdn35c8dhpM3BxOmeTuXs/cDoCdDY4qe7l32pbaZxL1jF+GXo/cLotBcWVTiZU3T7RMn8rHiijW9FgauP4Ef1TLdhHWgacCgAj6tYCqGKjU/DNbqxIkMYZNs7MpxmnLuhmwYJna1dbdzHjY42hDL4/wqkA6HWuDkAngRH0iYVjRkVwnoZO/0gsuLwpkw7OBcAtwlwvfESHxctmfMBSiOG0oStj4HCF7T3+RWARwIU7QK/HbWlqls52mYJtezqMj3v34C5VOveFy8Ll4QoTsJ8Txp0RsW8/Os2im2LCtSC1RIqLw3RldTVplOKkPEYDhMAPqttnune2rzTv5Y+WKdEem2ixkWqZYSeDSUp3qwIYNOrR7cBjcbOORxkvADNeAGa8AMx4AZjxAjATf5Ab0Tp5rJBk2/iD3PAwYo8Vkmyb9CjDGfLYIaCp1rdiAnT8S5PeDVkgoDuVCsWeJxwToHZ163m3Z8hjloDGk54vn5gFbT/5eZw8phifvZz8XPlA9qmRj8JRCumi+OkljzbbrvxM0qPMm9rIqY6FXZubVBUinMbzcP3jbuXA6Mh2kMx07KPJJLfj8Xg8Hg/4H+KfFYb2WM4MAAAAAElFTkSuQmCC" diff --git a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py index c6ec198034..3ba4996be1 100644 --- a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py @@ -10,65 +10,60 @@ from core.tools.tool.builtin_tool import BuiltinTool class VectorizerTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - api_key_name = self.runtime.credentials.get('api_key_name', None) - api_key_value = self.runtime.credentials.get('api_key_value', None) - mode = tool_parameters.get('mode', 'test') - if mode == 'production': - mode = 'preview' + api_key_name = self.runtime.credentials.get("api_key_name", None) + api_key_value = self.runtime.credentials.get("api_key_value", None) + mode = tool_parameters.get("mode", "test") + if mode == "production": + mode = "preview" if not api_key_name or not api_key_value: - raise ToolProviderCredentialValidationError('Please input api key name and value') + raise ToolProviderCredentialValidationError("Please input api key name and value") - image_id = tool_parameters.get('image_id', '') + image_id = tool_parameters.get("image_id", "") if not image_id: - return self.create_text_message('Please input image id') - - if image_id.startswith('__test_'): + return self.create_text_message("Please input image id") + + if image_id.startswith("__test_"): image_binary = b64decode(VECTORIZER_ICON_PNG) else: image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE) if not image_binary: - return self.create_text_message('Image not found, please request user to generate image firstly.') + return self.create_text_message("Image not found, please request user to generate image firstly.") response = post( - 'https://vectorizer.ai/api/v1/vectorize', - files={ - 'image': image_binary - }, - data={ - 'mode': mode - } if mode == 'test' else {}, - auth=(api_key_name, api_key_value), - timeout=30 + "https://vectorizer.ai/api/v1/vectorize", + files={"image": image_binary}, + data={"mode": mode} if mode == "test" else {}, + auth=(api_key_name, api_key_value), + timeout=30, ) if response.status_code != 200: raise Exception(response.text) - + return [ - self.create_text_message('the vectorized svg is saved as an image.'), - self.create_blob_message(blob=response.content, - meta={'mime_type': 'image/svg+xml'}) + self.create_text_message("the vectorized svg is saved as an image."), + self.create_blob_message(blob=response.content, meta={"mime_type": "image/svg+xml"}), ] - + def get_runtime_parameters(self) -> list[ToolParameter]: """ override the runtime parameters """ return [ ToolParameter.get_simple_instance( - name='image_id', - llm_description=f'the image id that you want to vectorize, \ + name="image_id", + llm_description=f"the image id that you want to vectorize, \ and the image id should be specified in \ - {[i.name for i in self.list_default_image_variables()]}', + {[i.name for i in self.list_default_image_variables()]}", type=ToolParameter.ToolParameterType.SELECT, required=True, - options=[i.name for i in self.list_default_image_variables()] + options=[i.name for i in self.list_default_image_variables()], ) ] - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/vectorizer/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/vectorizer.py index 3f89a83500..3b868572f9 100644 --- a/api/core/tools/provider/builtin/vectorizer/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/vectorizer.py @@ -13,12 +13,8 @@ class VectorizerProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "mode": "test", - "image_id": "__test_123" - }, + user_id="", + tool_parameters={"mode": "test", "image_id": "__test_123"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/webscraper/tools/webscraper.py b/api/core/tools/provider/builtin/webscraper/tools/webscraper.py index 3d098e6768..12670b4b8b 100644 --- a/api/core/tools/provider/builtin/webscraper/tools/webscraper.py +++ b/api/core/tools/provider/builtin/webscraper/tools/webscraper.py @@ -6,23 +6,24 @@ from core.tools.tool.builtin_tool import BuiltinTool class WebscraperTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ try: - url = tool_parameters.get('url', '') - user_agent = tool_parameters.get('user_agent', '') + url = tool_parameters.get("url", "") + user_agent = tool_parameters.get("user_agent", "") if not url: - return self.create_text_message('Please input url') + return self.create_text_message("Please input url") # get webpage result = self.get_url(url, user_agent=user_agent) - if tool_parameters.get('generate_summary'): + if tool_parameters.get("generate_summary"): # summarize and return return self.create_text_message(self.summary(user_id=user_id, content=result)) else: diff --git a/api/core/tools/provider/builtin/webscraper/webscraper.py b/api/core/tools/provider/builtin/webscraper/webscraper.py index 1e60fdb293..3c51393ac6 100644 --- a/api/core/tools/provider/builtin/webscraper/webscraper.py +++ b/api/core/tools/provider/builtin/webscraper/webscraper.py @@ -13,12 +13,11 @@ class WebscraperProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ - 'url': 'https://www.google.com', - 'user_agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ' + "url": "https://www.google.com", + "user_agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/websearch/tools/job_search.py b/api/core/tools/provider/builtin/websearch/tools/job_search.py index 9128305922..293f4f6329 100644 --- a/api/core/tools/provider/builtin/websearch/tools/job_search.py +++ b/api/core/tools/provider/builtin/websearch/tools/job_search.py @@ -50,14 +50,16 @@ class SerplyApi: for job in jobs[:10]: try: string.append( - "\n".join([ - f"Position: {job['position']}", - f"Employer: {job['employer']}", - f"Location: {job['location']}", - f"Link: {job['link']}", - f"""Highest: {", ".join(list(job["highlights"]))}""", - "---", - ]) + "\n".join( + [ + f"Position: {job['position']}", + f"Employer: {job['employer']}", + f"Location: {job['location']}", + f"Link: {job['link']}", + f"""Highest: {", ".join(list(job["highlights"]))}""", + "---", + ] + ) ) except KeyError: continue diff --git a/api/core/tools/provider/builtin/websearch/tools/news_search.py b/api/core/tools/provider/builtin/websearch/tools/news_search.py index e9c0744f05..9b5482fe18 100644 --- a/api/core/tools/provider/builtin/websearch/tools/news_search.py +++ b/api/core/tools/provider/builtin/websearch/tools/news_search.py @@ -53,13 +53,15 @@ class SerplyApi: r = requests.get(entry["link"]) final_link = r.history[-1].headers["Location"] string.append( - "\n".join([ - f"Title: {entry['title']}", - f"Link: {final_link}", - f"Source: {entry['source']['title']}", - f"Published: {entry['published']}", - "---", - ]) + "\n".join( + [ + f"Title: {entry['title']}", + f"Link: {final_link}", + f"Source: {entry['source']['title']}", + f"Published: {entry['published']}", + "---", + ] + ) ) except KeyError: continue diff --git a/api/core/tools/provider/builtin/websearch/tools/scholar_search.py b/api/core/tools/provider/builtin/websearch/tools/scholar_search.py index 0030a03c06..798d059b51 100644 --- a/api/core/tools/provider/builtin/websearch/tools/scholar_search.py +++ b/api/core/tools/provider/builtin/websearch/tools/scholar_search.py @@ -55,14 +55,16 @@ class SerplyApi: link = article["link"] authors = [author["name"] for author in article["author"]["authors"]] string.append( - "\n".join([ - f"Title: {article['title']}", - f"Link: {link}", - f"Description: {article['description']}", - f"Cite: {article['cite']}", - f"Authors: {', '.join(authors)}", - "---", - ]) + "\n".join( + [ + f"Title: {article['title']}", + f"Link: {link}", + f"Description: {article['description']}", + f"Cite: {article['cite']}", + f"Authors: {', '.join(authors)}", + "---", + ] + ) ) except KeyError: continue diff --git a/api/core/tools/provider/builtin/websearch/tools/web_search.py b/api/core/tools/provider/builtin/websearch/tools/web_search.py index 4f57c27caf..fe363ac7a4 100644 --- a/api/core/tools/provider/builtin/websearch/tools/web_search.py +++ b/api/core/tools/provider/builtin/websearch/tools/web_search.py @@ -49,12 +49,14 @@ class SerplyApi: for result in results: try: string.append( - "\n".join([ - f"Title: {result['title']}", - f"Link: {result['link']}", - f"Description: {result['description'].strip()}", - "---", - ]) + "\n".join( + [ + f"Title: {result['title']}", + f"Link: {result['link']}", + f"Description: {result['description'].strip()}", + "---", + ] + ) ) except KeyError: continue diff --git a/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py b/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py index fb44b70f4e..545d9f4f8d 100644 --- a/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py +++ b/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py @@ -8,41 +8,41 @@ from core.tools.utils.uuid_utils import is_valid_uuid class WecomGroupBotTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - hook_key = tool_parameters.get('hook_key', '') + hook_key = tool_parameters.get("hook_key", "") if not is_valid_uuid(hook_key): - return self.create_text_message( - f'Invalid parameter hook_key ${hook_key}, not a valid UUID') + return self.create_text_message(f"Invalid parameter hook_key ${hook_key}, not a valid UUID") - message_type = tool_parameters.get('message_type', 'text') - if message_type == 'markdown': + message_type = tool_parameters.get("message_type", "text") + if message_type == "markdown": payload = { - "msgtype": 'markdown', + "msgtype": "markdown", "markdown": { "content": content, - } + }, } else: payload = { - "msgtype": 'text', + "msgtype": "text", "text": { "content": content, - } + }, } - api_url = 'https://qyapi.weixin.qq.com/cgi-bin/webhook/send' + api_url = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send" headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = { - 'key': hook_key, + "key": hook_key, } try: @@ -51,6 +51,7 @@ class WecomGroupBotTool(BuiltinTool): return self.create_text_message("Text message sent successfully") else: return self.create_text_message( - f"Failed to send the text message, status code: {res.status_code}, response: {res.text}") + f"Failed to send the text message, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to send message to group chat bot. {}".format(e)) diff --git a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py index 0796cd2392..67efcf0954 100644 --- a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py +++ b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py @@ -83,7 +83,6 @@ class WikipediaQueryRun: class WikiPediaSearchTool(BuiltinTool): - def _invoke( self, user_id: str, diff --git a/api/core/tools/provider/builtin/wikipedia/wikipedia.py b/api/core/tools/provider/builtin/wikipedia/wikipedia.py index f8038714a5..178bf7b0ce 100644 --- a/api/core/tools/provider/builtin/wikipedia/wikipedia.py +++ b/api/core/tools/provider/builtin/wikipedia/wikipedia.py @@ -11,11 +11,10 @@ class WikiPediaProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "misaka mikoto", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py index 8cb9c10ddf..9dc5bed824 100644 --- a/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py +++ b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py @@ -8,29 +8,24 @@ from core.tools.tool.builtin_tool import BuiltinTool class WolframAlphaTool(BuiltinTool): - _base_url = 'https://api.wolframalpha.com/v2/query' + _base_url = "https://api.wolframalpha.com/v2/query" - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input query') - appid = self.runtime.credentials.get('appid', '') + return self.create_text_message("Please input query") + appid = self.runtime.credentials.get("appid", "") if not appid: - raise ToolProviderCredentialValidationError('Please input appid') - - params = { - 'appid': appid, - 'input': query, - 'includepodid': 'Result', - 'format': 'plaintext', - 'output': 'json' - } + raise ToolProviderCredentialValidationError("Please input appid") + + params = {"appid": appid, "input": query, "includepodid": "Result", "format": "plaintext", "output": "json"} finished = False result = None @@ -45,34 +40,33 @@ class WolframAlphaTool(BuiltinTool): response_data = response.json() except Exception as e: raise ToolInvokeError(str(e)) - - if 'success' not in response_data['queryresult'] or response_data['queryresult']['success'] != True: - query_result = response_data.get('queryresult', {}) - if query_result.get('error'): - if 'msg' in query_result['error']: - if query_result['error']['msg'] == 'Invalid appid': - raise ToolProviderCredentialValidationError('Invalid appid') - raise ToolInvokeError('Failed to invoke tool') - - if 'didyoumeans' in response_data['queryresult']: - # get the most likely interpretation - query = '' - max_score = 0 - for didyoumean in response_data['queryresult']['didyoumeans']: - if float(didyoumean['score']) > max_score: - query = didyoumean['val'] - max_score = float(didyoumean['score']) - params['input'] = query + if "success" not in response_data["queryresult"] or response_data["queryresult"]["success"] != True: + query_result = response_data.get("queryresult", {}) + if query_result.get("error"): + if "msg" in query_result["error"]: + if query_result["error"]["msg"] == "Invalid appid": + raise ToolProviderCredentialValidationError("Invalid appid") + raise ToolInvokeError("Failed to invoke tool") + + if "didyoumeans" in response_data["queryresult"]: + # get the most likely interpretation + query = "" + max_score = 0 + for didyoumean in response_data["queryresult"]["didyoumeans"]: + if float(didyoumean["score"]) > max_score: + query = didyoumean["val"] + max_score = float(didyoumean["score"]) + + params["input"] = query else: finished = True - if 'souces' in response_data['queryresult']: - return self.create_link_message(response_data['queryresult']['sources']['url']) - elif 'pods' in response_data['queryresult']: - result = response_data['queryresult']['pods'][0]['subpods'][0]['plaintext'] + if "souces" in response_data["queryresult"]: + return self.create_link_message(response_data["queryresult"]["sources"]["url"]) + elif "pods" in response_data["queryresult"]: + result = response_data["queryresult"]["pods"][0]["subpods"][0]["plaintext"] if not finished or not result: - return self.create_text_message('No result found') + return self.create_text_message("No result found") return self.create_text_message(result) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py index ef1aac7ff2..7be288b538 100644 --- a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py +++ b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py @@ -13,11 +13,10 @@ class GoogleProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "1+2+....+111", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/yahoo/tools/analytics.py b/api/core/tools/provider/builtin/yahoo/tools/analytics.py index cf511ea894..f044fbe540 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/analytics.py +++ b/api/core/tools/provider/builtin/yahoo/tools/analytics.py @@ -10,27 +10,28 @@ from core.tools.tool.builtin_tool import BuiltinTool class YahooFinanceAnalyticsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - symbol = tool_parameters.get('symbol', '') + symbol = tool_parameters.get("symbol", "") if not symbol: - return self.create_text_message('Please input symbol') - + return self.create_text_message("Please input symbol") + time_range = [None, None] - start_date = tool_parameters.get('start_date', '') + start_date = tool_parameters.get("start_date", "") if start_date: time_range[0] = start_date else: - time_range[0] = '1800-01-01' + time_range[0] = "1800-01-01" - end_date = tool_parameters.get('end_date', '') + end_date = tool_parameters.get("end_date", "") if end_date: time_range[1] = end_date else: - time_range[1] = datetime.now().strftime('%Y-%m-%d') + time_range[1] = datetime.now().strftime("%Y-%m-%d") stock_data = download(symbol, start=time_range[0], end=time_range[1]) max_segments = min(15, len(stock_data)) @@ -41,30 +42,29 @@ class YahooFinanceAnalyticsTool(BuiltinTool): end_idx = (i + 1) * rows_per_segment if i < max_segments - 1 else len(stock_data) segment_data = stock_data.iloc[start_idx:end_idx] segment_summary = { - 'Start Date': segment_data.index[0], - 'End Date': segment_data.index[-1], - 'Average Close': segment_data['Close'].mean(), - 'Average Volume': segment_data['Volume'].mean(), - 'Average Open': segment_data['Open'].mean(), - 'Average High': segment_data['High'].mean(), - 'Average Low': segment_data['Low'].mean(), - 'Average Adj Close': segment_data['Adj Close'].mean(), - 'Max Close': segment_data['Close'].max(), - 'Min Close': segment_data['Close'].min(), - 'Max Volume': segment_data['Volume'].max(), - 'Min Volume': segment_data['Volume'].min(), - 'Max Open': segment_data['Open'].max(), - 'Min Open': segment_data['Open'].min(), - 'Max High': segment_data['High'].max(), - 'Min High': segment_data['High'].min(), + "Start Date": segment_data.index[0], + "End Date": segment_data.index[-1], + "Average Close": segment_data["Close"].mean(), + "Average Volume": segment_data["Volume"].mean(), + "Average Open": segment_data["Open"].mean(), + "Average High": segment_data["High"].mean(), + "Average Low": segment_data["Low"].mean(), + "Average Adj Close": segment_data["Adj Close"].mean(), + "Max Close": segment_data["Close"].max(), + "Min Close": segment_data["Close"].min(), + "Max Volume": segment_data["Volume"].max(), + "Min Volume": segment_data["Volume"].min(), + "Max Open": segment_data["Open"].max(), + "Min Open": segment_data["Open"].min(), + "Max High": segment_data["High"].max(), + "Min High": segment_data["High"].min(), } - + summary_data.append(segment_summary) summary_df = pd.DataFrame(summary_data) - + try: return self.create_text_message(str(summary_df.to_dict())) except (HTTPError, ReadTimeout): - return self.create_text_message('There is a internet connection problem. Please try again later.') - \ No newline at end of file + return self.create_text_message("There is a internet connection problem. Please try again later.") diff --git a/api/core/tools/provider/builtin/yahoo/tools/news.py b/api/core/tools/provider/builtin/yahoo/tools/news.py index 4f2922ef3e..ff820430f9 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/news.py +++ b/api/core/tools/provider/builtin/yahoo/tools/news.py @@ -8,40 +8,39 @@ from core.tools.tool.builtin_tool import BuiltinTool class YahooFinanceSearchTickerTool(BuiltinTool): - def _invoke(self,user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - ''' - invoke tools - ''' - - query = tool_parameters.get('symbol', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + + query = tool_parameters.get("symbol", "") if not query: - return self.create_text_message('Please input symbol') - + return self.create_text_message("Please input symbol") + try: return self.run(ticker=query, user_id=user_id) except (HTTPError, ReadTimeout): - return self.create_text_message('There is a internet connection problem. Please try again later.') + return self.create_text_message("There is a internet connection problem. Please try again later.") def run(self, ticker: str, user_id: str) -> ToolInvokeMessage: company = yfinance.Ticker(ticker) try: if company.isin is None: - return self.create_text_message(f'Company ticker {ticker} not found.') + return self.create_text_message(f"Company ticker {ticker} not found.") except (HTTPError, ReadTimeout, ConnectionError): - return self.create_text_message(f'Company ticker {ticker} not found.') + return self.create_text_message(f"Company ticker {ticker} not found.") links = [] try: - links = [n['link'] for n in company.news if n['type'] == 'STORY'] + links = [n["link"] for n in company.news if n["type"] == "STORY"] except (HTTPError, ReadTimeout, ConnectionError): if not links: - return self.create_text_message(f'There is nothing about {ticker} ticker') + return self.create_text_message(f"There is nothing about {ticker} ticker") if not links: - return self.create_text_message(f'No news found for company that searched with {ticker} ticker.') - - result = '\n\n'.join([ - self.get_url(link) for link in links - ]) + return self.create_text_message(f"No news found for company that searched with {ticker} ticker.") + + result = "\n\n".join([self.get_url(link) for link in links]) return self.create_text_message(self.summary(user_id=user_id, content=result)) diff --git a/api/core/tools/provider/builtin/yahoo/tools/ticker.py b/api/core/tools/provider/builtin/yahoo/tools/ticker.py index 262fff3b25..dfc7e46047 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/ticker.py +++ b/api/core/tools/provider/builtin/yahoo/tools/ticker.py @@ -8,19 +8,20 @@ from core.tools.tool.builtin_tool import BuiltinTool class YahooFinanceSearchTickerTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - query = tool_parameters.get('symbol', '') + query = tool_parameters.get("symbol", "") if not query: - return self.create_text_message('Please input symbol') - + return self.create_text_message("Please input symbol") + try: return self.create_text_message(self.run(ticker=query)) except (HTTPError, ReadTimeout): - return self.create_text_message('There is a internet connection problem. Please try again later.') - + return self.create_text_message("There is a internet connection problem. Please try again later.") + def run(self, ticker: str) -> str: - return str(Ticker(ticker).info) \ No newline at end of file + return str(Ticker(ticker).info) diff --git a/api/core/tools/provider/builtin/yahoo/yahoo.py b/api/core/tools/provider/builtin/yahoo/yahoo.py index 96dbc6c3d0..8d82084e76 100644 --- a/api/core/tools/provider/builtin/yahoo/yahoo.py +++ b/api/core/tools/provider/builtin/yahoo/yahoo.py @@ -11,11 +11,10 @@ class YahooFinanceProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "ticker": "MSFT", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/youtube/tools/videos.py b/api/core/tools/provider/builtin/youtube/tools/videos.py index 7a9b9fce4a..95dec2eac9 100644 --- a/api/core/tools/provider/builtin/youtube/tools/videos.py +++ b/api/core/tools/provider/builtin/youtube/tools/videos.py @@ -8,60 +8,67 @@ from core.tools.tool.builtin_tool import BuiltinTool class YoutubeVideosAnalyticsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - channel = tool_parameters.get('channel', '') + channel = tool_parameters.get("channel", "") if not channel: - return self.create_text_message('Please input symbol') - + return self.create_text_message("Please input symbol") + time_range = [None, None] - start_date = tool_parameters.get('start_date', '') + start_date = tool_parameters.get("start_date", "") if start_date: time_range[0] = start_date else: - time_range[0] = '1800-01-01' + time_range[0] = "1800-01-01" - end_date = tool_parameters.get('end_date', '') + end_date = tool_parameters.get("end_date", "") if end_date: time_range[1] = end_date else: - time_range[1] = datetime.now().strftime('%Y-%m-%d') + time_range[1] = datetime.now().strftime("%Y-%m-%d") - if 'google_api_key' not in self.runtime.credentials or not self.runtime.credentials['google_api_key']: - return self.create_text_message('Please input api key') + if "google_api_key" not in self.runtime.credentials or not self.runtime.credentials["google_api_key"]: + return self.create_text_message("Please input api key") - youtube = build('youtube', 'v3', developerKey=self.runtime.credentials['google_api_key']) + youtube = build("youtube", "v3", developerKey=self.runtime.credentials["google_api_key"]) # try to get channel id - search_results = youtube.search().list(q=channel, type='channel', order='relevance', part='id').execute() - channel_id = search_results['items'][0]['id']['channelId'] + search_results = youtube.search().list(q=channel, type="channel", order="relevance", part="id").execute() + channel_id = search_results["items"][0]["id"]["channelId"] start_date, end_date = time_range - start_date = datetime.strptime(start_date, '%Y-%m-%d').strftime('%Y-%m-%dT%H:%M:%SZ') - end_date = datetime.strptime(end_date, '%Y-%m-%d').strftime('%Y-%m-%dT%H:%M:%SZ') + start_date = datetime.strptime(start_date, "%Y-%m-%d").strftime("%Y-%m-%dT%H:%M:%SZ") + end_date = datetime.strptime(end_date, "%Y-%m-%d").strftime("%Y-%m-%dT%H:%M:%SZ") # get videos - time_range_videos = youtube.search().list( - part='snippet', channelId=channel_id, order='date', type='video', - publishedAfter=start_date, - publishedBefore=end_date - ).execute() + time_range_videos = ( + youtube.search() + .list( + part="snippet", + channelId=channel_id, + order="date", + type="video", + publishedAfter=start_date, + publishedBefore=end_date, + ) + .execute() + ) def extract_video_data(video_list): data = [] - for video in video_list['items']: - video_id = video['id']['videoId'] - video_info = youtube.videos().list(part='snippet,statistics', id=video_id).execute() - title = video_info['items'][0]['snippet']['title'] - views = video_info['items'][0]['statistics']['viewCount'] - data.append({'Title': title, 'Views': views}) + for video in video_list["items"]: + video_id = video["id"]["videoId"] + video_info = youtube.videos().list(part="snippet,statistics", id=video_id).execute() + title = video_info["items"][0]["snippet"]["title"] + views = video_info["items"][0]["statistics"]["viewCount"] + data.append({"Title": title, "Views": views}) return data summary = extract_video_data(time_range_videos) - + return self.create_text_message(str(summary)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/youtube/youtube.py b/api/core/tools/provider/builtin/youtube/youtube.py index 83a4fccb32..aad876491c 100644 --- a/api/core/tools/provider/builtin/youtube/youtube.py +++ b/api/core/tools/provider/builtin/youtube/youtube.py @@ -11,7 +11,7 @@ class YahooFinanceProvider(BuiltinToolProviderController): "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "channel": "TOKYO GIRLS COLLECTION", "start_date": "2020-01-01", @@ -20,4 +20,3 @@ class YahooFinanceProvider(BuiltinToolProviderController): ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index bcf41c90ed..6b64dd1b4e 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -22,34 +22,36 @@ class BuiltinToolProviderController(ToolProviderController): if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP: super().__init__(**data) return - + # load provider yaml - provider = self.__class__.__module__.split('.')[-1] - yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml') + provider = self.__class__.__module__.split(".")[-1] + yaml_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, f"{provider}.yaml") try: provider_yaml = load_yaml_file(yaml_path, ignore_error=False) except Exception as e: - raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}') + raise ToolProviderNotFoundError(f"can not load provider yaml for {provider}: {e}") - if 'credentials_for_provider' in provider_yaml and provider_yaml['credentials_for_provider'] is not None: + if "credentials_for_provider" in provider_yaml and provider_yaml["credentials_for_provider"] is not None: # set credentials name - for credential_name in provider_yaml['credentials_for_provider']: - provider_yaml['credentials_for_provider'][credential_name]['name'] = credential_name + for credential_name in provider_yaml["credentials_for_provider"]: + provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name - super().__init__(**{ - 'identity': provider_yaml['identity'], - 'credentials_schema': provider_yaml.get('credentials_for_provider', None), - }) + super().__init__( + **{ + "identity": provider_yaml["identity"], + "credentials_schema": provider_yaml.get("credentials_for_provider", None), + } + ) def _get_builtin_tools(self) -> list[Tool]: """ - returns a list of tools that the provider can provide + returns a list of tools that the provider can provide - :return: list of tools + :return: list of tools """ if self.tools: return self.tools - + provider = self.identity.name tool_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, "tools") # get all the yaml files in the tool path @@ -62,155 +64,161 @@ class BuiltinToolProviderController(ToolProviderController): # get tool class, import the module assistant_tool_class = load_single_subclass_from_source( - module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}', - script_path=path.join(path.dirname(path.realpath(__file__)), - 'builtin', provider, 'tools', f'{tool_name}.py'), - parent_type=BuiltinTool) + module_name=f"core.tools.provider.builtin.{provider}.tools.{tool_name}", + script_path=path.join( + path.dirname(path.realpath(__file__)), "builtin", provider, "tools", f"{tool_name}.py" + ), + parent_type=BuiltinTool, + ) tool["identity"]["provider"] = provider tools.append(assistant_tool_class(**tool)) self.tools = tools return tools - + def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: """ - returns the credentials schema of the provider + returns the credentials schema of the provider - :return: the credentials schema + :return: the credentials schema """ if not self.credentials_schema: return {} - + return self.credentials_schema.copy() def get_tools(self) -> list[Tool]: """ - returns a list of tools that the provider can provide + returns a list of tools that the provider can provide - :return: list of tools + :return: list of tools """ return self._get_builtin_tools() - + def get_tool(self, tool_name: str) -> Tool: """ - returns the tool that the provider can provide + returns the tool that the provider can provide """ return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) def get_parameters(self, tool_name: str) -> list[ToolParameter]: """ - returns the parameters of the tool + returns the parameters of the tool - :param tool_name: the name of the tool, defined in `get_tools` - :return: list of parameters + :param tool_name: the name of the tool, defined in `get_tools` + :return: list of parameters """ tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) if tool is None: - raise ToolNotFoundError(f'tool {tool_name} not found') + raise ToolNotFoundError(f"tool {tool_name} not found") return tool.parameters @property def need_credentials(self) -> bool: """ - returns whether the provider needs credentials + returns whether the provider needs credentials - :return: whether the provider needs credentials + :return: whether the provider needs credentials """ - return self.credentials_schema is not None and \ - len(self.credentials_schema) != 0 + return self.credentials_schema is not None and len(self.credentials_schema) != 0 @property def provider_type(self) -> ToolProviderType: """ - returns the type of the provider + returns the type of the provider - :return: type of the provider + :return: type of the provider """ return ToolProviderType.BUILT_IN @property def tool_labels(self) -> list[str]: """ - returns the labels of the provider + returns the labels of the provider - :return: labels of the provider + :return: labels of the provider """ label_enums = self._get_tool_labels() return [default_tool_label_dict[label].name for label in label_enums] def _get_tool_labels(self) -> list[ToolLabelEnum]: """ - returns the labels of the provider + returns the labels of the provider """ return self.identity.tags or [] def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None: """ - validate the parameters of the tool and set the default value if needed + validate the parameters of the tool and set the default value if needed - :param tool_name: the name of the tool, defined in `get_tools` - :param tool_parameters: the parameters of the tool + :param tool_name: the name of the tool, defined in `get_tools` + :param tool_parameters: the parameters of the tool """ tool_parameters_schema = self.get_parameters(tool_name) - + tool_parameters_need_to_validate: dict[str, ToolParameter] = {} for parameter in tool_parameters_schema: tool_parameters_need_to_validate[parameter.name] = parameter for parameter in tool_parameters: if parameter not in tool_parameters_need_to_validate: - raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}') - + raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}") + # check type parameter_schema = tool_parameters_need_to_validate[parameter] if parameter_schema.type == ToolParameter.ToolParameterType.STRING: if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f'parameter {parameter} should be string') - + raise ToolParameterValidationError(f"parameter {parameter} should be string") + elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER: if not isinstance(tool_parameters[parameter], int | float): - raise ToolParameterValidationError(f'parameter {parameter} should be number') - + raise ToolParameterValidationError(f"parameter {parameter} should be number") + if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min: - raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}') - + raise ToolParameterValidationError( + f"parameter {parameter} should be greater than {parameter_schema.min}" + ) + if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max: - raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}') - + raise ToolParameterValidationError( + f"parameter {parameter} should be less than {parameter_schema.max}" + ) + elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN: if not isinstance(tool_parameters[parameter], bool): - raise ToolParameterValidationError(f'parameter {parameter} should be boolean') - + raise ToolParameterValidationError(f"parameter {parameter} should be boolean") + elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT: if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f'parameter {parameter} should be string') - + raise ToolParameterValidationError(f"parameter {parameter} should be string") + options = parameter_schema.options if not isinstance(options, list): - raise ToolParameterValidationError(f'parameter {parameter} options should be list') - + raise ToolParameterValidationError(f"parameter {parameter} options should be list") + if tool_parameters[parameter] not in [x.value for x in options]: - raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}') - + raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}") + tool_parameters_need_to_validate.pop(parameter) for parameter in tool_parameters_need_to_validate: parameter_schema = tool_parameters_need_to_validate[parameter] if parameter_schema.required: - raise ToolParameterValidationError(f'parameter {parameter} is required') - + raise ToolParameterValidationError(f"parameter {parameter} is required") + # the parameter is not set currently, set the default value if needed if parameter_schema.default is not None: - default_value = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default, - parameter_schema.type) + default_value = ToolParameterConverter.cast_parameter_by_type( + parameter_schema.default, parameter_schema.type + ) tool_parameters[parameter] = default_value - + def validate_credentials(self, credentials: dict[str, Any]) -> None: """ - validate the credentials of the provider + validate the credentials of the provider - :param tool_name: the name of the tool, defined in `get_tools` - :param credentials: the credentials of the tool + :param tool_name: the name of the tool, defined in `get_tools` + :param credentials: the credentials of the tool """ # validate credentials format self.validate_credentials_format(credentials) @@ -221,9 +229,9 @@ class BuiltinToolProviderController(ToolProviderController): @abstractmethod def _validate_credentials(self, credentials: dict[str, Any]) -> None: """ - validate the credentials of the provider + validate the credentials of the provider - :param tool_name: the name of the tool, defined in `get_tools` - :param credentials: the credentials of the tool + :param tool_name: the name of the tool, defined in `get_tools` + :param credentials: the credentials of the tool """ pass diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py index ef1ace9c7c..f4008eedce 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -21,162 +21,174 @@ class ToolProviderController(BaseModel, ABC): def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: """ - returns the credentials schema of the provider + returns the credentials schema of the provider - :return: the credentials schema + :return: the credentials schema """ return self.credentials_schema.copy() - + @abstractmethod def get_tools(self) -> list[Tool]: """ - returns a list of tools that the provider can provide + returns a list of tools that the provider can provide - :return: list of tools + :return: list of tools """ pass @abstractmethod def get_tool(self, tool_name: str) -> Tool: """ - returns a tool that the provider can provide + returns a tool that the provider can provide - :return: tool + :return: tool """ pass def get_parameters(self, tool_name: str) -> list[ToolParameter]: """ - returns the parameters of the tool + returns the parameters of the tool - :param tool_name: the name of the tool, defined in `get_tools` - :return: list of parameters + :param tool_name: the name of the tool, defined in `get_tools` + :return: list of parameters """ tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) if tool is None: - raise ToolNotFoundError(f'tool {tool_name} not found') + raise ToolNotFoundError(f"tool {tool_name} not found") return tool.parameters @property def provider_type(self) -> ToolProviderType: """ - returns the type of the provider + returns the type of the provider - :return: type of the provider + :return: type of the provider """ return ToolProviderType.BUILT_IN def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None: """ - validate the parameters of the tool and set the default value if needed + validate the parameters of the tool and set the default value if needed - :param tool_name: the name of the tool, defined in `get_tools` - :param tool_parameters: the parameters of the tool + :param tool_name: the name of the tool, defined in `get_tools` + :param tool_parameters: the parameters of the tool """ tool_parameters_schema = self.get_parameters(tool_name) - + tool_parameters_need_to_validate: dict[str, ToolParameter] = {} for parameter in tool_parameters_schema: tool_parameters_need_to_validate[parameter.name] = parameter for parameter in tool_parameters: if parameter not in tool_parameters_need_to_validate: - raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}') - + raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}") + # check type parameter_schema = tool_parameters_need_to_validate[parameter] if parameter_schema.type == ToolParameter.ToolParameterType.STRING: if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f'parameter {parameter} should be string') - + raise ToolParameterValidationError(f"parameter {parameter} should be string") + elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER: if not isinstance(tool_parameters[parameter], int | float): - raise ToolParameterValidationError(f'parameter {parameter} should be number') - + raise ToolParameterValidationError(f"parameter {parameter} should be number") + if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min: - raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}') - + raise ToolParameterValidationError( + f"parameter {parameter} should be greater than {parameter_schema.min}" + ) + if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max: - raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}') - + raise ToolParameterValidationError( + f"parameter {parameter} should be less than {parameter_schema.max}" + ) + elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN: if not isinstance(tool_parameters[parameter], bool): - raise ToolParameterValidationError(f'parameter {parameter} should be boolean') - + raise ToolParameterValidationError(f"parameter {parameter} should be boolean") + elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT: if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f'parameter {parameter} should be string') - + raise ToolParameterValidationError(f"parameter {parameter} should be string") + options = parameter_schema.options if not isinstance(options, list): - raise ToolParameterValidationError(f'parameter {parameter} options should be list') - + raise ToolParameterValidationError(f"parameter {parameter} options should be list") + if tool_parameters[parameter] not in [x.value for x in options]: - raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}') - + raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}") + tool_parameters_need_to_validate.pop(parameter) for parameter in tool_parameters_need_to_validate: parameter_schema = tool_parameters_need_to_validate[parameter] if parameter_schema.required: - raise ToolParameterValidationError(f'parameter {parameter} is required') - + raise ToolParameterValidationError(f"parameter {parameter} is required") + # the parameter is not set currently, set the default value if needed if parameter_schema.default is not None: - tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default, - parameter_schema.type) + tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type( + parameter_schema.default, parameter_schema.type + ) def validate_credentials_format(self, credentials: dict[str, Any]) -> None: """ - validate the format of the credentials of the provider and set the default value if needed + validate the format of the credentials of the provider and set the default value if needed - :param credentials: the credentials of the tool + :param credentials: the credentials of the tool """ credentials_schema = self.credentials_schema if credentials_schema is None: return - + credentials_need_to_validate: dict[str, ToolProviderCredentials] = {} for credential_name in credentials_schema: credentials_need_to_validate[credential_name] = credentials_schema[credential_name] for credential_name in credentials: if credential_name not in credentials_need_to_validate: - raise ToolProviderCredentialValidationError(f'credential {credential_name} not found in provider {self.identity.name}') - + raise ToolProviderCredentialValidationError( + f"credential {credential_name} not found in provider {self.identity.name}" + ) + # check type credential_schema = credentials_need_to_validate[credential_name] - if credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \ - credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT: + if ( + credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT + or credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT + ): if not isinstance(credentials[credential_name], str): - raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string') - + raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") + elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT: if not isinstance(credentials[credential_name], str): - raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string') - + raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") + options = credential_schema.options if not isinstance(options, list): - raise ToolProviderCredentialValidationError(f'credential {credential_name} options should be list') - + raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list") + if credentials[credential_name] not in [x.value for x in options]: - raise ToolProviderCredentialValidationError(f'credential {credential_name} should be one of {options}') - + raise ToolProviderCredentialValidationError( + f"credential {credential_name} should be one of {options}" + ) + credentials_need_to_validate.pop(credential_name) for credential_name in credentials_need_to_validate: credential_schema = credentials_need_to_validate[credential_name] if credential_schema.required: - raise ToolProviderCredentialValidationError(f'credential {credential_name} is required') - + raise ToolProviderCredentialValidationError(f"credential {credential_name} is required") + # the credential is not set currently, set the default value if needed if credential_schema.default is not None: default_value = credential_schema.default # parse default value into the correct type - if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \ - credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT or \ - credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT: + if ( + credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT + or credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT + or credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT + ): default_value = str(default_value) credentials[credential_name] = default_value - \ No newline at end of file diff --git a/api/core/tools/provider/workflow_tool_provider.py b/api/core/tools/provider/workflow_tool_provider.py index f14abac767..25eaf6a66a 100644 --- a/api/core/tools/provider/workflow_tool_provider.py +++ b/api/core/tools/provider/workflow_tool_provider.py @@ -30,29 +30,25 @@ class WorkflowToolProviderController(ToolProviderController): provider_id: str @classmethod - def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderController': + def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController": app = db_provider.app if not app: - raise ValueError('app not found') + raise ValueError("app not found") - controller = WorkflowToolProviderController(**{ - 'identity': { - 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '', - 'name': db_provider.label, - 'label': { - 'en_US': db_provider.label, - 'zh_Hans': db_provider.label + controller = WorkflowToolProviderController( + **{ + "identity": { + "author": db_provider.user.name if db_provider.user_id and db_provider.user else "", + "name": db_provider.label, + "label": {"en_US": db_provider.label, "zh_Hans": db_provider.label}, + "description": {"en_US": db_provider.description, "zh_Hans": db_provider.description}, + "icon": db_provider.icon, }, - 'description': { - 'en_US': db_provider.description, - 'zh_Hans': db_provider.description - }, - 'icon': db_provider.icon, - }, - 'credentials_schema': {}, - 'provider_id': db_provider.id or '', - }) + "credentials_schema": {}, + "provider_id": db_provider.id or "", + } + ) # init tools @@ -66,25 +62,23 @@ class WorkflowToolProviderController(ToolProviderController): def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool: """ - get db provider tool - :param db_provider: the db provider - :param app: the app - :return: the tool + get db provider tool + :param db_provider: the db provider + :param app: the app + :return: the tool """ - workflow: Workflow = db.session.query(Workflow).filter( - Workflow.app_id == db_provider.app_id, - Workflow.version == db_provider.version - ).first() + workflow: Workflow = ( + db.session.query(Workflow) + .filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version) + .first() + ) if not workflow: - raise ValueError('workflow not found') + raise ValueError("workflow not found") # fetch start node graph: dict = workflow.graph_dict features_dict: dict = workflow.features_dict - features = WorkflowAppConfigManager.convert_features( - config_dict=features_dict, - app_mode=AppMode.WORKFLOW - ) + features = WorkflowAppConfigManager.convert_features(config_dict=features_dict, app_mode=AppMode.WORKFLOW) parameters = db_provider.parameter_configurations variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph) @@ -101,51 +95,34 @@ class WorkflowToolProviderController(ToolProviderController): parameter_type = None options = None if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING: - raise ValueError(f'unsupported variable type {variable.type}') + raise ValueError(f"unsupported variable type {variable.type}") parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type] if variable.type == VariableEntityType.SELECT and variable.options: options = [ - ToolParameterOption( - value=option, - label=I18nObject( - en_US=option, - zh_Hans=option - ) - ) for option in variable.options + ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) + for option in variable.options ] workflow_tool_parameters.append( ToolParameter( name=parameter.name, - label=I18nObject( - en_US=variable.label, - zh_Hans=variable.label - ), - human_description=I18nObject( - en_US=parameter.description, - zh_Hans=parameter.description - ), + label=I18nObject(en_US=variable.label, zh_Hans=variable.label), + human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description), type=parameter_type, form=parameter.form, llm_description=parameter.description, required=variable.required, options=options, - default=variable.default + default=variable.default, ) ) elif features.file_upload: workflow_tool_parameters.append( ToolParameter( name=parameter.name, - label=I18nObject( - en_US=parameter.name, - zh_Hans=parameter.name - ), - human_description=I18nObject( - en_US=parameter.description, - zh_Hans=parameter.description - ), + label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name), + human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description), type=ToolParameter.ToolParameterType.FILE, llm_description=parameter.description, required=False, @@ -153,53 +130,51 @@ class WorkflowToolProviderController(ToolProviderController): ) ) else: - raise ValueError('variable not found') + raise ValueError("variable not found") return WorkflowTool( identity=ToolIdentity( - author=user.name if user else '', + author=user.name if user else "", name=db_provider.name, - label=I18nObject( - en_US=db_provider.label, - zh_Hans=db_provider.label - ), + label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label), provider=self.provider_id, icon=db_provider.icon, ), description=ToolDescription( - human=I18nObject( - en_US=db_provider.description, - zh_Hans=db_provider.description - ), + human=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description), llm=db_provider.description, ), parameters=workflow_tool_parameters, is_team_authorization=True, workflow_app_id=app.id, workflow_entities={ - 'app': app, - 'workflow': workflow, + "app": app, + "workflow": workflow, }, version=db_provider.version, workflow_call_depth=0, - label=db_provider.label + label=db_provider.label, ) def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]: """ - fetch tools from database + fetch tools from database - :param user_id: the user id - :param tenant_id: the tenant id - :return: the tools + :param user_id: the user id + :param tenant_id: the tenant id + :return: the tools """ if self.tools is not None: return self.tools - db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.app_id == self.provider_id, - ).first() + db_providers: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.app_id == self.provider_id, + ) + .first() + ) if not db_providers: return [] @@ -210,10 +185,10 @@ class WorkflowToolProviderController(ToolProviderController): def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: """ - get tool by name + get tool by name - :param tool_name: the name of the tool - :return: the tool + :param tool_name: the name of the tool + :return: the tool """ if self.tools is None: return None diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index 38f10032e2..bf336b48f3 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -12,8 +12,8 @@ from core.tools.errors import ToolInvokeError, ToolParameterValidationError, Too from core.tools.tool.tool import Tool API_TOOL_DEFAULT_TIMEOUT = ( - int(getenv('API_TOOL_DEFAULT_CONNECT_TIMEOUT', '10')), - int(getenv('API_TOOL_DEFAULT_READ_TIMEOUT', '60')) + int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")), + int(getenv("API_TOOL_DEFAULT_READ_TIMEOUT", "60")), ) @@ -24,31 +24,32 @@ class ApiTool(Tool): Api tool """ - def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool': + def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": """ - fork a new tool with meta data + fork a new tool with meta data - :param meta: the meta data of a tool call processing, tenant_id is required - :return: the new tool + :param meta: the meta data of a tool call processing, tenant_id is required + :return: the new tool """ return self.__class__( identity=self.identity.model_copy() if self.identity else None, parameters=self.parameters.copy() if self.parameters else None, description=self.description.model_copy() if self.description else None, api_bundle=self.api_bundle.model_copy() if self.api_bundle else None, - runtime=Tool.Runtime(**runtime) + runtime=Tool.Runtime(**runtime), ) - def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], - format_only: bool = False) -> str: + def validate_credentials( + self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False + ) -> str: """ - validate the credentials for Api tool + validate the credentials for Api tool """ - # assemble validate request and request parameters + # assemble validate request and request parameters headers = self.assembling_request(parameters) if format_only: - return '' + return "" response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters) # validate response @@ -61,30 +62,30 @@ class ApiTool(Tool): headers = {} credentials = self.runtime.credentials or {} - if 'auth_type' not in credentials: - raise ToolProviderCredentialValidationError('Missing auth_type') + if "auth_type" not in credentials: + raise ToolProviderCredentialValidationError("Missing auth_type") - if credentials['auth_type'] == 'api_key': - api_key_header = 'api_key' + if credentials["auth_type"] == "api_key": + api_key_header = "api_key" - if 'api_key_header' in credentials: - api_key_header = credentials['api_key_header'] + if "api_key_header" in credentials: + api_key_header = credentials["api_key_header"] - if 'api_key_value' not in credentials: - raise ToolProviderCredentialValidationError('Missing api_key_value') - elif not isinstance(credentials['api_key_value'], str): - raise ToolProviderCredentialValidationError('api_key_value must be a string') + if "api_key_value" not in credentials: + raise ToolProviderCredentialValidationError("Missing api_key_value") + elif not isinstance(credentials["api_key_value"], str): + raise ToolProviderCredentialValidationError("api_key_value must be a string") - if 'api_key_header_prefix' in credentials: - api_key_header_prefix = credentials['api_key_header_prefix'] - if api_key_header_prefix == 'basic' and credentials['api_key_value']: - credentials['api_key_value'] = f'Basic {credentials["api_key_value"]}' - elif api_key_header_prefix == 'bearer' and credentials['api_key_value']: - credentials['api_key_value'] = f'Bearer {credentials["api_key_value"]}' - elif api_key_header_prefix == 'custom': + if "api_key_header_prefix" in credentials: + api_key_header_prefix = credentials["api_key_header_prefix"] + if api_key_header_prefix == "basic" and credentials["api_key_value"]: + credentials["api_key_value"] = f'Basic {credentials["api_key_value"]}' + elif api_key_header_prefix == "bearer" and credentials["api_key_value"]: + credentials["api_key_value"] = f'Bearer {credentials["api_key_value"]}' + elif api_key_header_prefix == "custom": pass - headers[api_key_header] = credentials['api_key_value'] + headers[api_key_header] = credentials["api_key_value"] needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required] for parameter in needed_parameters: @@ -98,13 +99,13 @@ class ApiTool(Tool): def validate_and_parse_response(self, response: httpx.Response) -> str: """ - validate the response + validate the response """ if isinstance(response, httpx.Response): if response.status_code >= 400: raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}") if not response.content: - return 'Empty response from the tool, please check your parameters and try again.' + return "Empty response from the tool, please check your parameters and try again." try: response = response.json() try: @@ -114,21 +115,22 @@ class ApiTool(Tool): except Exception as e: return response.text else: - raise ValueError(f'Invalid response type {type(response)}') + raise ValueError(f"Invalid response type {type(response)}") @staticmethod def get_parameter_value(parameter, parameters): - if parameter['name'] in parameters: - return parameters[parameter['name']] - elif parameter.get('required', False): + if parameter["name"] in parameters: + return parameters[parameter["name"]] + elif parameter.get("required", False): raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}") else: - return (parameter.get('schema', {}) or {}).get('default', '') + return (parameter.get("schema", {}) or {}).get("default", "") - def do_http_request(self, url: str, method: str, headers: dict[str, Any], - parameters: dict[str, Any]) -> httpx.Response: + def do_http_request( + self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any] + ) -> httpx.Response: """ - do http request depending on api bundle + do http request depending on api bundle """ method = method.lower() @@ -138,29 +140,30 @@ class ApiTool(Tool): cookies = {} # check parameters - for parameter in self.api_bundle.openapi.get('parameters', []): + for parameter in self.api_bundle.openapi.get("parameters", []): value = self.get_parameter_value(parameter, parameters) - if parameter['in'] == 'path': - path_params[parameter['name']] = value + if parameter["in"] == "path": + path_params[parameter["name"]] = value - elif parameter['in'] == 'query': - if value !='': params[parameter['name']] = value + elif parameter["in"] == "query": + if value != "": + params[parameter["name"]] = value - elif parameter['in'] == 'cookie': - cookies[parameter['name']] = value + elif parameter["in"] == "cookie": + cookies[parameter["name"]] = value - elif parameter['in'] == 'header': - headers[parameter['name']] = value + elif parameter["in"] == "header": + headers[parameter["name"]] = value # check if there is a request body and handle it - if 'requestBody' in self.api_bundle.openapi and self.api_bundle.openapi['requestBody'] is not None: + if "requestBody" in self.api_bundle.openapi and self.api_bundle.openapi["requestBody"] is not None: # handle json request body - if 'content' in self.api_bundle.openapi['requestBody']: - for content_type in self.api_bundle.openapi['requestBody']['content']: - headers['Content-Type'] = content_type - body_schema = self.api_bundle.openapi['requestBody']['content'][content_type]['schema'] - required = body_schema.get('required', []) - properties = body_schema.get('properties', {}) + if "content" in self.api_bundle.openapi["requestBody"]: + for content_type in self.api_bundle.openapi["requestBody"]["content"]: + headers["Content-Type"] = content_type + body_schema = self.api_bundle.openapi["requestBody"]["content"][content_type]["schema"] + required = body_schema.get("required", []) + properties = body_schema.get("properties", {}) for name, property in properties.items(): if name in parameters: # convert type @@ -169,63 +172,71 @@ class ApiTool(Tool): raise ToolParameterValidationError( f"Missing required parameter {name} in operation {self.api_bundle.operation_id}" ) - elif 'default' in property: - body[name] = property['default'] + elif "default" in property: + body[name] = property["default"] else: body[name] = None break # replace path parameters for name, value in path_params.items(): - url = url.replace(f'{{{name}}}', f'{value}') + url = url.replace(f"{{{name}}}", f"{value}") # parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored - if 'Content-Type' in headers: - if headers['Content-Type'] == 'application/json': + if "Content-Type" in headers: + if headers["Content-Type"] == "application/json": body = json.dumps(body) - elif headers['Content-Type'] == 'application/x-www-form-urlencoded': + elif headers["Content-Type"] == "application/x-www-form-urlencoded": body = urlencode(body) else: body = body - if method in ('get', 'head', 'post', 'put', 'delete', 'patch'): - response = getattr(ssrf_proxy, method)(url, params=params, headers=headers, cookies=cookies, data=body, - timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True) + if method in ("get", "head", "post", "put", "delete", "patch"): + response = getattr(ssrf_proxy, method)( + url, + params=params, + headers=headers, + cookies=cookies, + data=body, + timeout=API_TOOL_DEFAULT_TIMEOUT, + follow_redirects=True, + ) return response else: - raise ValueError(f'Invalid http method {self.method}') + raise ValueError(f"Invalid http method {self.method}") - def _convert_body_property_any_of(self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], - max_recursive=10) -> Any: + def _convert_body_property_any_of( + self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10 + ) -> Any: if max_recursive <= 0: raise Exception("Max recursion depth reached") for option in any_of or []: try: - if 'type' in option: + if "type" in option: # Attempt to convert the value based on the type. - if option['type'] == 'integer' or option['type'] == 'int': + if option["type"] == "integer" or option["type"] == "int": return int(value) - elif option['type'] == 'number': - if '.' in str(value): + elif option["type"] == "number": + if "." in str(value): return float(value) else: return int(value) - elif option['type'] == 'string': + elif option["type"] == "string": return str(value) - elif option['type'] == 'boolean': - if str(value).lower() in ['true', '1']: + elif option["type"] == "boolean": + if str(value).lower() in ["true", "1"]: return True - elif str(value).lower() in ['false', '0']: + elif str(value).lower() in ["false", "0"]: return False else: continue # Not a boolean, try next option - elif option['type'] == 'null' and not value: + elif option["type"] == "null" and not value: return None else: continue # Unsupported type, try next option - elif 'anyOf' in option and isinstance(option['anyOf'], list): + elif "anyOf" in option and isinstance(option["anyOf"], list): # Recursive call to handle nested anyOf - return self._convert_body_property_any_of(property, value, option['anyOf'], max_recursive - 1) + return self._convert_body_property_any_of(property, value, option["anyOf"], max_recursive - 1) except ValueError: continue # Conversion failed, try next option # If no option succeeded, you might want to return the value as is or raise an error @@ -233,23 +244,23 @@ class ApiTool(Tool): def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> Any: try: - if 'type' in property: - if property['type'] == 'integer' or property['type'] == 'int': + if "type" in property: + if property["type"] == "integer" or property["type"] == "int": return int(value) - elif property['type'] == 'number': + elif property["type"] == "number": # check if it is a float - if '.' in str(value): + if "." in str(value): return float(value) else: return int(value) - elif property['type'] == 'string': + elif property["type"] == "string": return str(value) - elif property['type'] == 'boolean': + elif property["type"] == "boolean": return bool(value) - elif property['type'] == 'null': + elif property["type"] == "null": if value is None: return None - elif property['type'] == 'object' or property['type'] == 'array': + elif property["type"] == "object" or property["type"] == "array": if isinstance(value, str): try: # an array str like '[1,2]' also can convert to list [1,2] through json.loads @@ -264,8 +275,8 @@ class ApiTool(Tool): return value else: raise ValueError(f"Invalid type {property['type']} for property {property}") - elif 'anyOf' in property and isinstance(property['anyOf'], list): - return self._convert_body_property_any_of(property, value, property['anyOf']) + elif "anyOf" in property and isinstance(property["anyOf"], list): + return self._convert_body_property_any_of(property, value, property["anyOf"]) except ValueError as e: return value diff --git a/api/core/tools/tool/builtin_tool.py b/api/core/tools/tool/builtin_tool.py index ad7a88838b..8edaf7c0e6 100644 --- a/api/core/tools/tool/builtin_tool.py +++ b/api/core/tools/tool/builtin_tool.py @@ -1,4 +1,3 @@ - from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage from core.tools.entities.tool_entities import ToolProviderType @@ -16,40 +15,38 @@ Please summarize the text you got. class BuiltinTool(Tool): """ - Builtin tool + Builtin tool - :param meta: the meta data of a tool call processing + :param meta: the meta data of a tool call processing """ - def invoke_model( - self, user_id: str, prompt_messages: list[PromptMessage], stop: list[str] - ) -> LLMResult: + def invoke_model(self, user_id: str, prompt_messages: list[PromptMessage], stop: list[str]) -> LLMResult: """ - invoke model + invoke model - :param model_config: the model config - :param prompt_messages: the prompt messages - :param stop: the stop words - :return: the model result + :param model_config: the model config + :param prompt_messages: the prompt messages + :param stop: the stop words + :return: the model result """ # invoke model return ModelInvocationUtils.invoke( user_id=user_id, tenant_id=self.runtime.tenant_id, - tool_type='builtin', + tool_type="builtin", tool_name=self.identity.name, prompt_messages=prompt_messages, ) - + def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.BUILT_IN - + def get_max_tokens(self) -> int: """ - get max tokens + get max tokens - :param model_config: the model config - :return: the max tokens + :param model_config: the model config + :return: the max tokens """ return ModelInvocationUtils.get_max_llm_context_tokens( tenant_id=self.runtime.tenant_id, @@ -57,39 +54,34 @@ class BuiltinTool(Tool): def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int: """ - get prompt tokens + get prompt tokens - :param prompt_messages: the prompt messages - :return: the tokens + :param prompt_messages: the prompt messages + :return: the tokens """ - return ModelInvocationUtils.calculate_tokens( - tenant_id=self.runtime.tenant_id, - prompt_messages=prompt_messages - ) + return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages) def summary(self, user_id: str, content: str) -> str: max_tokens = self.get_max_tokens() - if self.get_prompt_tokens(prompt_messages=[ - UserPromptMessage(content=content) - ]) < max_tokens * 0.6: + if self.get_prompt_tokens(prompt_messages=[UserPromptMessage(content=content)]) < max_tokens * 0.6: return content - + def get_prompt_tokens(content: str) -> int: - return self.get_prompt_tokens(prompt_messages=[ - SystemPromptMessage(content=_SUMMARY_PROMPT), - UserPromptMessage(content=content) - ]) - + return self.get_prompt_tokens( + prompt_messages=[SystemPromptMessage(content=_SUMMARY_PROMPT), UserPromptMessage(content=content)] + ) + def summarize(content: str) -> str: - summary = self.invoke_model(user_id=user_id, prompt_messages=[ - SystemPromptMessage(content=_SUMMARY_PROMPT), - UserPromptMessage(content=content) - ], stop=[]) + summary = self.invoke_model( + user_id=user_id, + prompt_messages=[SystemPromptMessage(content=_SUMMARY_PROMPT), UserPromptMessage(content=content)], + stop=[], + ) return summary.message.content - lines = content.split('\n') + lines = content.split("\n") new_lines = [] # split long line into multiple lines for i in range(len(lines)): @@ -100,8 +92,8 @@ class BuiltinTool(Tool): new_lines.append(line) elif get_prompt_tokens(line) > max_tokens * 0.7: while get_prompt_tokens(line) > max_tokens * 0.7: - new_lines.append(line[:int(max_tokens * 0.5)]) - line = line[int(max_tokens * 0.5):] + new_lines.append(line[: int(max_tokens * 0.5)]) + line = line[int(max_tokens * 0.5) :] new_lines.append(line) else: new_lines.append(line) @@ -125,17 +117,15 @@ class BuiltinTool(Tool): summary = summarize(message) summaries.append(summary) - result = '\n'.join(summaries) + result = "\n".join(summaries) - if self.get_prompt_tokens(prompt_messages=[ - UserPromptMessage(content=result) - ]) > max_tokens * 0.7: + if self.get_prompt_tokens(prompt_messages=[UserPromptMessage(content=result)]) > max_tokens * 0.7: return self.summary(user_id=user_id, content=result) - + return result - + def get_url(self, url: str, user_agent: str = None) -> str: """ - get url + get url """ - return get_url(url, user_agent=user_agent) \ No newline at end of file + return get_url(url, user_agent=user_agent) diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index d6ecc9257b..e76af6fe70 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -14,14 +14,11 @@ from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } @@ -31,6 +28,7 @@ class DatasetMultiRetrieverToolInput(BaseModel): class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): """Tool for querying multi dataset.""" + name: str = "dataset_" args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput description: str = "dataset multi retriever and rerank. " @@ -38,27 +36,26 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): reranking_provider_name: str reranking_model_name: str - @classmethod def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs): return cls( - name=f"dataset_{tenant_id.replace('-', '_')}", - tenant_id=tenant_id, - dataset_ids=dataset_ids, - **kwargs + name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs ) def _run(self, query: str) -> str: threads = [] all_documents = [] for dataset_id in self.dataset_ids: - retrieval_thread = threading.Thread(target=self._retriever, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'all_documents': all_documents, - 'hit_callbacks': self.hit_callbacks - }) + retrieval_thread = threading.Thread( + target=self._retriever, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "all_documents": all_documents, + "hit_callbacks": self.hit_callbacks, + }, + ) threads.append(retrieval_thread) retrieval_thread.start() for thread in threads: @@ -69,7 +66,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): tenant_id=self.tenant_id, provider=self.reranking_provider_name, model_type=ModelType.RERANK, - model=self.reranking_model_name + model=self.reranking_model_name, ) rerank_runner = RerankModelRunner(rerank_model_instance) @@ -80,62 +77,61 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): document_score_list = {} for item in all_documents: - if item.metadata.get('score'): - document_score_list[item.metadata['doc_id']] = item.metadata['score'] + if item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] document_context_list = [] - index_node_ids = [document.metadata['doc_id'] for document in all_documents] + index_node_ids = [document.metadata["doc_id"] for document in all_documents] segments = DocumentSegment.query.filter( DocumentSegment.dataset_id.in_(self.dataset_ids), DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == 'completed', + DocumentSegment.status == "completed", DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids) + DocumentSegment.index_node_id.in_(index_node_ids), ).all() if segments: index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted(segments, - key=lambda segment: index_node_id_to_position.get(segment.index_node_id, - float('inf'))) + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: if segment.answer: - document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}') + document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") else: document_context_list.append(segment.get_sign_content()) if self.return_resource: context_list = [] resource_number = 1 for segment in sorted_segments: - dataset = Dataset.query.filter_by( - id=segment.dataset_id + dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + document = Document.query.filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, ).first() - document = Document.query.filter(Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() if dataset and document: source = { - 'position': resource_number, - 'dataset_id': dataset.id, - 'dataset_name': dataset.name, - 'document_id': document.id, - 'document_name': document.name, - 'data_source_type': document.data_source_type, - 'segment_id': segment.id, - 'retriever_from': self.retriever_from, - 'score': document_score_list.get(segment.index_node_id, None) + "position": resource_number, + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": self.retriever_from, + "score": document_score_list.get(segment.index_node_id, None), } - if self.retriever_from == 'dev': - source['hit_count'] = segment.hit_count - source['word_count'] = segment.word_count - source['segment_position'] = segment.position - source['index_node_hash'] = segment.index_node_hash + if self.retriever_from == "dev": + source["hit_count"] = segment.hit_count + source["word_count"] = segment.word_count + source["segment_position"] = segment.position + source["index_node_hash"] = segment.index_node_hash if segment.answer: - source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" else: - source['content'] = segment.content + source["content"] = segment.content context_list.append(source) resource_number += 1 @@ -144,13 +140,18 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): return str("\n".join(document_context_list)) - def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: list, - hit_callbacks: list[DatasetIndexToolCallbackHandler]): + def _retriever( + self, + flask_app: Flask, + dataset_id: str, + query: str, + all_documents: list, + hit_callbacks: list[DatasetIndexToolCallbackHandler], + ): with flask_app.app_context(): - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == self.tenant_id, - Dataset.id == dataset_id - ).first() + dataset = ( + db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first() + ) if not dataset: return [] @@ -163,27 +164,29 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): if dataset.indexing_technique == "economy": # use keyword table query - documents = RetrievalService.retrieve(retrieval_method='keyword_search', - dataset_id=dataset.id, - query=query, - top_k=self.top_k - ) + documents = RetrievalService.retrieve( + retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k + ) if documents: all_documents.extend(documents) else: if self.top_k > 0: # retrieval source - documents = RetrievalService.retrieve(retrieval_method=retrieval_model['search_method'], - dataset_id=dataset.id, - query=query, - top_k=self.top_k, - score_threshold=retrieval_model.get('score_threshold', .0) - if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model.get('reranking_model', None) - if retrieval_model['reranking_enable'] else None, - reranking_mode=retrieval_model.get('reranking_mode') - if retrieval_model.get('reranking_mode') else 'reranking_model', - weights=retrieval_model.get('weights', None), - ) + documents = RetrievalService.retrieve( + retrieval_method=retrieval_model["search_method"], + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else None, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") + if retrieval_model.get("reranking_mode") + else "reranking_model", + weights=retrieval_model.get("weights", None), + ) - all_documents.extend(documents) \ No newline at end of file + all_documents.extend(documents) diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py index 62e97a0230..dad8c77357 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py @@ -9,6 +9,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa class DatasetRetrieverBaseTool(BaseModel, ABC): """Tool for querying a Dataset.""" + name: str = "dataset" description: str = "use this to retrieve a dataset. " tenant_id: str diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index 220e4baa85..f61458278e 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -1,4 +1,3 @@ - from pydantic import BaseModel, Field from core.rag.datasource.retrieval_service import RetrievalService @@ -8,15 +7,12 @@ from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'reranking_mode': 'reranking_model', - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "reranking_mode": "reranking_model", + "top_k": 2, + "score_threshold_enabled": False, } @@ -26,35 +22,34 @@ class DatasetRetrieverToolInput(BaseModel): class DatasetRetrieverTool(DatasetRetrieverBaseTool): """Tool for querying a Dataset.""" + name: str = "dataset" args_schema: type[BaseModel] = DatasetRetrieverToolInput description: str = "use this to retrieve a dataset. " dataset_id: str - @classmethod def from_dataset(cls, dataset: Dataset, **kwargs): description = dataset.description if not description: - description = 'useful for when you want to answer queries about the ' + dataset.name + description = "useful for when you want to answer queries about the " + dataset.name - description = description.replace('\n', '').replace('\r', '') + description = description.replace("\n", "").replace("\r", "") return cls( name=f"dataset_{dataset.id.replace('-', '_')}", tenant_id=dataset.tenant_id, dataset_id=dataset.id, description=description, - **kwargs + **kwargs, ) def _run(self, query: str) -> str: - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == self.tenant_id, - Dataset.id == self.dataset_id - ).first() + dataset = ( + db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first() + ) if not dataset: - return '' + return "" for hit_callback in self.hit_callbacks: hit_callback.on_query(query, dataset.id) @@ -63,27 +58,29 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model if dataset.indexing_technique == "economy": # use keyword table query - documents = RetrievalService.retrieve(retrieval_method='keyword_search', - dataset_id=dataset.id, - query=query, - top_k=self.top_k - ) + documents = RetrievalService.retrieve( + retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k + ) return str("\n".join([document.page_content for document in documents])) else: if self.top_k > 0: # retrieval source - documents = RetrievalService.retrieve(retrieval_method=retrieval_model.get('search_method', 'semantic_search'), - dataset_id=dataset.id, - query=query, - top_k=self.top_k, - score_threshold=retrieval_model.get('score_threshold', .0) - if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model.get('reranking_model', None) - if retrieval_model['reranking_enable'] else None, - reranking_mode=retrieval_model.get('reranking_mode') - if retrieval_model.get('reranking_mode') else 'reranking_model', - weights=retrieval_model.get('weights', None), - ) + documents = RetrievalService.retrieve( + retrieval_method=retrieval_model.get("search_method", "semantic_search"), + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else None, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") + if retrieval_model.get("reranking_mode") + else "reranking_model", + weights=retrieval_model.get("weights", None), + ) else: documents = [] @@ -92,25 +89,26 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): document_score_list = {} if dataset.indexing_technique != "economy": for item in documents: - if item.metadata.get('score'): - document_score_list[item.metadata['doc_id']] = item.metadata['score'] + if item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] document_context_list = [] - index_node_ids = [document.metadata['doc_id'] for document in documents] - segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id, - DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids) - ).all() + index_node_ids = [document.metadata["doc_id"] for document in documents] + segments = DocumentSegment.query.filter( + DocumentSegment.dataset_id == self.dataset_id, + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids), + ).all() if segments: index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted(segments, - key=lambda segment: index_node_id_to_position.get(segment.index_node_id, - float('inf'))) + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: if segment.answer: - document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}') + document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") else: document_context_list.append(segment.get_sign_content()) if self.return_resource: @@ -118,36 +116,36 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): resource_number = 1 for segment in sorted_segments: context = {} - document = Document.query.filter(Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() + document = Document.query.filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ).first() if dataset and document: source = { - 'position': resource_number, - 'dataset_id': dataset.id, - 'dataset_name': dataset.name, - 'document_id': document.id, - 'document_name': document.name, - 'data_source_type': document.data_source_type, - 'segment_id': segment.id, - 'retriever_from': self.retriever_from, - 'score': document_score_list.get(segment.index_node_id, None) - + "position": resource_number, + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": self.retriever_from, + "score": document_score_list.get(segment.index_node_id, None), } - if self.retriever_from == 'dev': - source['hit_count'] = segment.hit_count - source['word_count'] = segment.word_count - source['segment_position'] = segment.position - source['index_node_hash'] = segment.index_node_hash + if self.retriever_from == "dev": + source["hit_count"] = segment.hit_count + source["word_count"] = segment.word_count + source["segment_position"] = segment.position + source["index_node_hash"] = segment.index_node_hash if segment.answer: - source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" else: - source['content'] = segment.content + source["content"] = segment.content context_list.append(source) resource_number += 1 for hit_callback in self.hit_callbacks: hit_callback.return_retriever_resource_info(context_list) - return str("\n".join(document_context_list)) \ No newline at end of file + return str("\n".join(document_context_list)) diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py index b5698ad230..3c9295c493 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -20,13 +20,14 @@ class DatasetRetrieverTool(Tool): retrieval_tool: DatasetRetrieverBaseTool @staticmethod - def get_dataset_tools(tenant_id: str, - dataset_ids: list[str], - retrieve_config: DatasetRetrieveConfigEntity, - return_resource: bool, - invoke_from: InvokeFrom, - hit_callback: DatasetIndexToolCallbackHandler - ) -> list['DatasetRetrieverTool']: + def get_dataset_tools( + tenant_id: str, + dataset_ids: list[str], + retrieve_config: DatasetRetrieveConfigEntity, + return_resource: bool, + invoke_from: InvokeFrom, + hit_callback: DatasetIndexToolCallbackHandler, + ) -> list["DatasetRetrieverTool"]: """ get dataset tool """ @@ -48,7 +49,7 @@ class DatasetRetrieverTool(Tool): retrieve_config=retrieve_config, return_resource=return_resource, invoke_from=invoke_from, - hit_callback=hit_callback + hit_callback=hit_callback, ) # restore retrieve strategy retrieve_config.retrieve_strategy = original_retriever_mode @@ -58,13 +59,13 @@ class DatasetRetrieverTool(Tool): for retrieval_tool in retrieval_tools: tool = DatasetRetrieverTool( retrieval_tool=retrieval_tool, - identity=ToolIdentity(provider='', author='', name=retrieval_tool.name, label=I18nObject(en_US='', zh_Hans='')), + identity=ToolIdentity( + provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="") + ), parameters=[], is_team_authorization=True, - description=ToolDescription( - human=I18nObject(en_US='', zh_Hans=''), - llm=retrieval_tool.description), - runtime=DatasetRetrieverTool.Runtime() + description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description), + runtime=DatasetRetrieverTool.Runtime(), ) tools.append(tool) @@ -73,16 +74,18 @@ class DatasetRetrieverTool(Tool): def get_runtime_parameters(self) -> list[ToolParameter]: return [ - ToolParameter(name='query', - label=I18nObject(en_US='', zh_Hans=''), - human_description=I18nObject(en_US='', zh_Hans=''), - type=ToolParameter.ToolParameterType.STRING, - form=ToolParameter.ToolParameterForm.LLM, - llm_description='Query for the dataset to be used to retrieve the dataset.', - required=True, - default=''), + ToolParameter( + name="query", + label=I18nObject(en_US="", zh_Hans=""), + human_description=I18nObject(en_US="", zh_Hans=""), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="Query for the dataset to be used to retrieve the dataset.", + required=True, + default="", + ), ] - + def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.DATASET_RETRIEVAL @@ -90,9 +93,9 @@ class DatasetRetrieverTool(Tool): """ invoke dataset retriever tool """ - query = tool_parameters.get('query') + query = tool_parameters.get("query") if not query: - return self.create_text_message(text='please input query') + return self.create_text_message(text="please input query") # invoke dataset retriever tool result = self.retrieval_tool._run(query=query) diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index d990131b5f..ac3dc84db4 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -35,15 +35,16 @@ class Tool(BaseModel, ABC): # pydantic configs model_config = ConfigDict(protected_namespaces=()) - @field_validator('parameters', mode='before') + @field_validator("parameters", mode="before") @classmethod def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]: return v or [] class Runtime(BaseModel): """ - Meta data of a tool call processing + Meta data of a tool call processing """ + def __init__(self, **data: Any): super().__init__(**data) if not self.runtime_parameters: @@ -63,14 +64,14 @@ class Tool(BaseModel, ABC): super().__init__(**data) class VARIABLE_KEY(Enum): - IMAGE = 'image' + IMAGE = "image" - def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool': + def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": """ - fork a new tool with meta data + fork a new tool with meta data - :param meta: the meta data of a tool call processing, tenant_id is required - :return: the new tool + :param meta: the meta data of a tool call processing, tenant_id is required + :return: the new tool """ return self.__class__( identity=self.identity.model_copy() if self.identity else None, @@ -82,22 +83,22 @@ class Tool(BaseModel, ABC): @abstractmethod def tool_provider_type(self) -> ToolProviderType: """ - get the tool provider type + get the tool provider type - :return: the tool provider type + :return: the tool provider type """ def load_variables(self, variables: ToolRuntimeVariablePool): """ - load variables from database + load variables from database - :param conversation_id: the conversation id + :param conversation_id: the conversation id """ self.variables = variables def set_image_variable(self, variable_name: str, image_key: str) -> None: """ - set an image variable + set an image variable """ if not self.variables: return @@ -106,7 +107,7 @@ class Tool(BaseModel, ABC): def set_text_variable(self, variable_name: str, text: str) -> None: """ - set a text variable + set a text variable """ if not self.variables: return @@ -115,10 +116,10 @@ class Tool(BaseModel, ABC): def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]: """ - get a variable + get a variable - :param name: the name of the variable - :return: the variable + :param name: the name of the variable + :return: the variable """ if not self.variables: return None @@ -134,9 +135,9 @@ class Tool(BaseModel, ABC): def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]: """ - get the default image variable + get the default image variable - :return: the image variable + :return: the image variable """ if not self.variables: return None @@ -145,10 +146,10 @@ class Tool(BaseModel, ABC): def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: """ - get a variable file + get a variable file - :param name: the name of the variable - :return: the variable file + :param name: the name of the variable + :return: the variable file """ variable = self.get_variable(name) if not variable: @@ -167,9 +168,9 @@ class Tool(BaseModel, ABC): def list_variables(self) -> list[ToolRuntimeVariable]: """ - list all variables + list all variables - :return: the variables + :return: the variables """ if not self.variables: return [] @@ -178,9 +179,9 @@ class Tool(BaseModel, ABC): def list_default_image_variables(self) -> list[ToolRuntimeVariable]: """ - list all image variables + list all image variables - :return: the image variables + :return: the image variables """ if not self.variables: return [] @@ -220,38 +221,42 @@ class Tool(BaseModel, ABC): result = deepcopy(tool_parameters) for parameter in self.parameters or []: if parameter.name in tool_parameters: - result[parameter.name] = ToolParameterConverter.cast_parameter_by_type(tool_parameters[parameter.name], parameter.type) + result[parameter.name] = ToolParameterConverter.cast_parameter_by_type( + tool_parameters[parameter.name], parameter.type + ) return result @abstractmethod - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: pass def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: """ - validate the credentials + validate the credentials - :param credentials: the credentials - :param parameters: the parameters + :param credentials: the credentials + :param parameters: the parameters """ pass def get_runtime_parameters(self) -> list[ToolParameter]: """ - get the runtime parameters + get the runtime parameters - interface for developer to dynamic change the parameters of a tool depends on the variables pool + interface for developer to dynamic change the parameters of a tool depends on the variables pool - :return: the runtime parameters + :return: the runtime parameters """ return self.parameters or [] def get_all_runtime_parameters(self) -> list[ToolParameter]: """ - get all runtime parameters + get all runtime parameters - :return: all runtime parameters + :return: all runtime parameters """ parameters = self.parameters or [] parameters = parameters.copy() @@ -281,67 +286,49 @@ class Tool(BaseModel, ABC): return parameters - def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage: + def create_image_message(self, image: str, save_as: str = "") -> ToolInvokeMessage: """ - create an image message + create an image message - :param image: the url of the image - :return: the image message + :param image: the url of the image + :return: the image message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, - message=image, - save_as=save_as) + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as) def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage: - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR, - message='', - meta={ - 'file_var': file_var - }, - save_as='') - - def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage: - """ - create a link message - - :param link: the url of the link - :return: the link message - """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, - message=link, - save_as=save_as) - - def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage: - """ - create a text message - - :param text: the text - :return: the text message - """ return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.TEXT, - message=text, - save_as=save_as + type=ToolInvokeMessage.MessageType.FILE_VAR, message="", meta={"file_var": file_var}, save_as="" ) - def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage: + def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage: """ - create a blob message + create a link message - :param blob: the blob - :return: the blob message + :param link: the url of the link + :return: the link message """ - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.BLOB, - message=blob, meta=meta, - save_as=save_as - ) + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, message=link, save_as=save_as) + + def create_text_message(self, text: str, save_as: str = "") -> ToolInvokeMessage: + """ + create a text message + + :param text: the text + :return: the text message + """ + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT, message=text, save_as=save_as) + + def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = "") -> ToolInvokeMessage: + """ + create a blob message + + :param blob: the blob + :return: the blob message + """ + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.BLOB, message=blob, meta=meta, save_as=save_as) def create_json_message(self, object: dict) -> ToolInvokeMessage: """ - create a json message + create a json message """ - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.JSON, - message=object - ) + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=object) diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py index 15e915628e..ad0c7fc631 100644 --- a/api/core/tools/tool/workflow_tool.py +++ b/api/core/tools/tool/workflow_tool.py @@ -13,6 +13,7 @@ from models.workflow import Workflow logger = logging.getLogger(__name__) + class WorkflowTool(Tool): workflow_app_id: str version: str @@ -25,11 +26,12 @@ class WorkflowTool(Tool): """ Workflow tool. """ + def tool_provider_type(self) -> ToolProviderType: """ - get the tool provider type + get the tool provider type - :return: the tool provider type + :return: the tool provider type """ return ToolProviderType.WORKFLOW @@ -37,7 +39,7 @@ class WorkflowTool(Tool): self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke the tool + invoke the tool """ app = self._get_app(app_id=self.workflow_app_id) workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version) @@ -46,33 +48,31 @@ class WorkflowTool(Tool): tool_parameters, files = self._transform_args(tool_parameters) from core.app.apps.workflow.app_generator import WorkflowAppGenerator + generator = WorkflowAppGenerator() result = generator.generate( - app_model=app, - workflow=workflow, - user=self._get_user(user_id), - args={ - 'inputs': tool_parameters, - 'files': files - }, + app_model=app, + workflow=workflow, + user=self._get_user(user_id), + args={"inputs": tool_parameters, "files": files}, invoke_from=self.runtime.invoke_from, stream=False, call_depth=self.workflow_call_depth + 1, - workflow_thread_pool_id=self.thread_pool_id + workflow_thread_pool_id=self.thread_pool_id, ) - data = result.get('data', {}) + data = result.get("data", {}) + + if data.get("error"): + raise Exception(data.get("error")) - if data.get('error'): - raise Exception(data.get('error')) - result = [] - outputs = data.get('outputs', {}) + outputs = data.get("outputs", {}) outputs, files = self._extract_files(outputs) for file in files: result.append(self.create_file_var_message(file)) - + result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False))) result.append(self.create_json_message(outputs)) @@ -80,7 +80,7 @@ class WorkflowTool(Tool): def _get_user(self, user_id: str) -> Union[EndUser, Account]: """ - get the user by user id + get the user by user id """ user = db.session.query(EndUser).filter(EndUser.id == user_id).first() @@ -88,16 +88,16 @@ class WorkflowTool(Tool): user = db.session.query(Account).filter(Account.id == user_id).first() if not user: - raise ValueError('user not found') + raise ValueError("user not found") return user - def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'WorkflowTool': + def fork_tool_runtime(self, runtime: dict[str, Any]) -> "WorkflowTool": """ - fork a new tool with meta data + fork a new tool with meta data - :param meta: the meta data of a tool call processing, tenant_id is required - :return: the new tool + :param meta: the meta data of a tool call processing, tenant_id is required + :return: the new tool """ return self.__class__( identity=deepcopy(self.identity), @@ -108,45 +108,44 @@ class WorkflowTool(Tool): workflow_entities=self.workflow_entities, workflow_call_depth=self.workflow_call_depth, version=self.version, - label=self.label + label=self.label, ) - + def _get_workflow(self, app_id: str, version: str) -> Workflow: """ - get the workflow by app id and version + get the workflow by app id and version """ if not version: - workflow = db.session.query(Workflow).filter( - Workflow.app_id == app_id, - Workflow.version != 'draft' - ).order_by(Workflow.created_at.desc()).first() + workflow = ( + db.session.query(Workflow) + .filter(Workflow.app_id == app_id, Workflow.version != "draft") + .order_by(Workflow.created_at.desc()) + .first() + ) else: - workflow = db.session.query(Workflow).filter( - Workflow.app_id == app_id, - Workflow.version == version - ).first() + workflow = db.session.query(Workflow).filter(Workflow.app_id == app_id, Workflow.version == version).first() if not workflow: - raise ValueError('workflow not found or not published') + raise ValueError("workflow not found or not published") return workflow - + def _get_app(self, app_id: str) -> App: """ - get the app by app id + get the app by app id """ app = db.session.query(App).filter(App.id == app_id).first() if not app: - raise ValueError('app not found') + raise ValueError("app not found") return app - + def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]: """ - transform the tool parameters + transform the tool parameters - :param tool_parameters: the tool parameters - :return: tool_parameters, files + :param tool_parameters: the tool parameters + :return: tool_parameters, files """ parameter_rules = self.get_all_runtime_parameters() parameters_result = {} @@ -159,15 +158,15 @@ class WorkflowTool(Tool): file_var_list = [FileVar(**f) for f in file] for file_var in file_var_list: file_dict = { - 'transfer_method': file_var.transfer_method.value, - 'type': file_var.type.value, + "transfer_method": file_var.transfer_method.value, + "type": file_var.type.value, } if file_var.transfer_method == FileTransferMethod.TOOL_FILE: - file_dict['tool_file_id'] = file_var.related_id + file_dict["tool_file_id"] = file_var.related_id elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict['upload_file_id'] = file_var.related_id + file_dict["upload_file_id"] = file_var.related_id elif file_var.transfer_method == FileTransferMethod.REMOTE_URL: - file_dict['url'] = file_var.preview_url + file_dict["url"] = file_var.preview_url files.append(file_dict) except Exception as e: @@ -176,13 +175,13 @@ class WorkflowTool(Tool): parameters_result[parameter.name] = tool_parameters.get(parameter.name) return parameters_result, files - + def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]: """ - extract files from the result + extract files from the result - :param result: the result - :return: the result, files + :param result: the result + :return: the result, files """ files = [] result = {} @@ -190,7 +189,7 @@ class WorkflowTool(Tool): if isinstance(value, list): has_file = False for item in value: - if isinstance(item, dict) and item.get('__variant') == 'FileVar': + if isinstance(item, dict) and item.get("__variant") == "FileVar": try: files.append(FileVar(**item)) has_file = True @@ -201,4 +200,4 @@ class WorkflowTool(Tool): result[key] = value - return result, files \ No newline at end of file + return result, files diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 6c0e906628..9a6a49d8f4 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -33,12 +33,17 @@ class ToolEngine: """ Tool runtime engine take care of the tool executions. """ + @staticmethod def agent_invoke( - tool: Tool, tool_parameters: Union[str, dict], - user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom, + tool: Tool, + tool_parameters: Union[str, dict], + user_id: str, + tenant_id: str, + message: Message, + invoke_from: InvokeFrom, agent_tool_callback: DifyAgentCallbackHandler, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]: """ Agent invokes the tool with the given arguments. @@ -47,40 +52,30 @@ class ToolEngine: if isinstance(tool_parameters, str): # check if this tool has only one parameter parameters = [ - parameter for parameter in tool.get_runtime_parameters() or [] + parameter + for parameter in tool.get_runtime_parameters() or [] if parameter.form == ToolParameter.ToolParameterForm.LLM ] if parameters and len(parameters) == 1: - tool_parameters = { - parameters[0].name: tool_parameters - } + tool_parameters = {parameters[0].name: tool_parameters} else: raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") # invoke the tool try: # hit the callback handler - agent_tool_callback.on_tool_start( - tool_name=tool.identity.name, - tool_inputs=tool_parameters - ) + agent_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) meta, response = ToolEngine._invoke(tool, tool_parameters, user_id) response = ToolFileMessageTransformer.transform_tool_invoke_messages( - messages=response, - user_id=user_id, - tenant_id=tenant_id, - conversation_id=message.conversation_id + messages=response, user_id=user_id, tenant_id=tenant_id, conversation_id=message.conversation_id ) # extract binary data from tool invoke message binary_files = ToolEngine._extract_tool_response_binary(response) # create message file message_files = ToolEngine._create_message_files( - tool_messages=binary_files, - agent_message=message, - invoke_from=invoke_from, - user_id=user_id + tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id ) plain_text = ToolEngine._convert_tool_response_to_str(response) @@ -91,7 +86,7 @@ class ToolEngine: tool_inputs=tool_parameters, tool_outputs=plain_text, message_id=message.id, - trace_manager=trace_manager + trace_manager=trace_manager, ) # transform tool invoke message to get LLM friendly message @@ -99,14 +94,10 @@ class ToolEngine: except ToolProviderCredentialValidationError as e: error_response = "Please check your tool provider credentials" agent_tool_callback.on_tool_error(e) - except ( - ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError - ) as e: + except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e: error_response = f"there is not a tool named {tool.identity.name}" agent_tool_callback.on_tool_error(e) - except ( - ToolParameterValidationError - ) as e: + except ToolParameterValidationError as e: error_response = f"tool parameters validation error: {e}, please check your tool parameters" agent_tool_callback.on_tool_error(e) except ToolInvokeError as e: @@ -124,21 +115,20 @@ class ToolEngine: return error_response, [], ToolInvokeMeta.error_instance(error_response) @staticmethod - def workflow_invoke(tool: Tool, tool_parameters: Mapping[str, Any], - user_id: str, - workflow_tool_callback: DifyWorkflowCallbackHandler, - workflow_call_depth: int, - thread_pool_id: Optional[str] = None - ) -> list[ToolInvokeMessage]: + def workflow_invoke( + tool: Tool, + tool_parameters: Mapping[str, Any], + user_id: str, + workflow_tool_callback: DifyWorkflowCallbackHandler, + workflow_call_depth: int, + thread_pool_id: Optional[str] = None, + ) -> list[ToolInvokeMessage]: """ Workflow invokes the tool with the given arguments. """ try: # hit the callback handler - workflow_tool_callback.on_tool_start( - tool_name=tool.identity.name, - tool_inputs=tool_parameters - ) + workflow_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) if isinstance(tool, WorkflowTool): tool.workflow_call_depth = workflow_call_depth + 1 @@ -159,21 +149,24 @@ class ToolEngine: except Exception as e: workflow_tool_callback.on_tool_error(e) raise e - + @staticmethod - def _invoke(tool: Tool, tool_parameters: dict, user_id: str) \ - -> tuple[ToolInvokeMeta, list[ToolInvokeMessage]]: + def _invoke(tool: Tool, tool_parameters: dict, user_id: str) -> tuple[ToolInvokeMeta, list[ToolInvokeMessage]]: """ Invoke the tool with the given arguments. """ started_at = datetime.now(timezone.utc) - meta = ToolInvokeMeta(time_cost=0.0, error=None, tool_config={ - 'tool_name': tool.identity.name, - 'tool_provider': tool.identity.provider, - 'tool_provider_type': tool.tool_provider_type().value, - 'tool_parameters': deepcopy(tool.runtime.runtime_parameters), - 'tool_icon': tool.identity.icon - }) + meta = ToolInvokeMeta( + time_cost=0.0, + error=None, + tool_config={ + "tool_name": tool.identity.name, + "tool_provider": tool.identity.provider, + "tool_provider_type": tool.tool_provider_type().value, + "tool_parameters": deepcopy(tool.runtime.runtime_parameters), + "tool_icon": tool.identity.icon, + }, + ) try: response = tool.invoke(user_id, tool_parameters) except Exception as e: @@ -184,20 +177,22 @@ class ToolEngine: meta.time_cost = (ended_at - started_at).total_seconds() return meta, response - + @staticmethod def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str: """ Handle tool response """ - result = '' + result = "" for response in tool_response: if response.type == ToolInvokeMessage.MessageType.TEXT: result += response.message elif response.type == ToolInvokeMessage.MessageType.LINK: result += f"result link: {response.message}. please tell user to check it." - elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ - response.type == ToolInvokeMessage.MessageType.IMAGE: + elif ( + response.type == ToolInvokeMessage.MessageType.IMAGE_LINK + or response.type == ToolInvokeMessage.MessageType.IMAGE + ): result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now." elif response.type == ToolInvokeMessage.MessageType.JSON: result += f"tool response: {json.dumps(response.message, ensure_ascii=False)}." @@ -205,7 +200,7 @@ class ToolEngine: result += f"tool response: {response.message}." return result - + @staticmethod def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]: """ @@ -214,52 +209,59 @@ class ToolEngine: result = [] for response in tool_response: - if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ - response.type == ToolInvokeMessage.MessageType.IMAGE: + if ( + response.type == ToolInvokeMessage.MessageType.IMAGE_LINK + or response.type == ToolInvokeMessage.MessageType.IMAGE + ): mimetype = None - if response.meta.get('mime_type'): - mimetype = response.meta.get('mime_type') + if response.meta.get("mime_type"): + mimetype = response.meta.get("mime_type") else: try: url = URL(response.message) extension = url.suffix - guess_type_result, _ = guess_type(f'a{extension}') + guess_type_result, _ = guess_type(f"a{extension}") if guess_type_result: mimetype = guess_type_result except Exception: pass - + if not mimetype: - mimetype = 'image/jpeg' - - result.append(ToolInvokeMessageBinary( - mimetype=response.meta.get('mime_type', 'image/jpeg'), - url=response.message, - save_as=response.save_as, - )) - elif response.type == ToolInvokeMessage.MessageType.BLOB: - result.append(ToolInvokeMessageBinary( - mimetype=response.meta.get('mime_type', 'octet/stream'), - url=response.message, - save_as=response.save_as, - )) - elif response.type == ToolInvokeMessage.MessageType.LINK: - # check if there is a mime type in meta - if response.meta and 'mime_type' in response.meta: - result.append(ToolInvokeMessageBinary( - mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream', + mimetype = "image/jpeg" + + result.append( + ToolInvokeMessageBinary( + mimetype=response.meta.get("mime_type", "image/jpeg"), url=response.message, save_as=response.save_as, - )) + ) + ) + elif response.type == ToolInvokeMessage.MessageType.BLOB: + result.append( + ToolInvokeMessageBinary( + mimetype=response.meta.get("mime_type", "octet/stream"), + url=response.message, + save_as=response.save_as, + ) + ) + elif response.type == ToolInvokeMessage.MessageType.LINK: + # check if there is a mime type in meta + if response.meta and "mime_type" in response.meta: + result.append( + ToolInvokeMessageBinary( + mimetype=response.meta.get("mime_type", "octet/stream") + if response.meta + else "octet/stream", + url=response.message, + save_as=response.save_as, + ) + ) return result - + @staticmethod def _create_message_files( - tool_messages: list[ToolInvokeMessageBinary], - agent_message: Message, - invoke_from: InvokeFrom, - user_id: str + tool_messages: list[ToolInvokeMessageBinary], agent_message: Message, invoke_from: InvokeFrom, user_id: str ) -> list[tuple[Any, str]]: """ Create message file @@ -270,29 +272,29 @@ class ToolEngine: result = [] for message in tool_messages: - file_type = 'bin' - if 'image' in message.mimetype: - file_type = 'image' - elif 'video' in message.mimetype: - file_type = 'video' - elif 'audio' in message.mimetype: - file_type = 'audio' - elif 'text' in message.mimetype: - file_type = 'text' - elif 'pdf' in message.mimetype: - file_type = 'pdf' - elif 'zip' in message.mimetype: - file_type = 'archive' + file_type = "bin" + if "image" in message.mimetype: + file_type = "image" + elif "video" in message.mimetype: + file_type = "video" + elif "audio" in message.mimetype: + file_type = "audio" + elif "text" in message.mimetype: + file_type = "text" + elif "pdf" in message.mimetype: + file_type = "pdf" + elif "zip" in message.mimetype: + file_type = "archive" # ... message_file = MessageFile( message_id=agent_message.id, type=file_type, transfer_method=FileTransferMethod.TOOL_FILE.value, - belongs_to='assistant', + belongs_to="assistant", url=message.url, upload_file_id=None, - created_by_role=('account'if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'), + created_by_role=("account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"), created_by=user_id, ) @@ -300,11 +302,8 @@ class ToolEngine: db.session.commit() db.session.refresh(message_file) - result.append(( - message_file.id, - message.save_as - )) + result.append((message_file.id, message.save_as)) db.session.close() - return result \ No newline at end of file + return result diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index f9f7c7d78a..ad3b9c7328 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -27,24 +27,24 @@ class ToolFileManager: sign file to get a temporary url """ base_url = dify_config.FILES_URL - file_preview_url = f'{base_url}/files/tools/{tool_file_id}{extension}' + file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}" timestamp = str(int(time.time())) nonce = os.urandom(16).hex() - data_to_sign = f'file-preview|{tool_file_id}|{timestamp}|{nonce}' - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b'' + data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() - return f'{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}' + return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" @staticmethod def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool: """ verify signature """ - data_to_sign = f'file-preview|{file_id}|{timestamp}|{nonce}' - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b'' + data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() @@ -62,9 +62,9 @@ class ToolFileManager: """ create file """ - extension = guess_extension(mimetype) or '.bin' + extension = guess_extension(mimetype) or ".bin" unique_name = uuid4().hex - filename = f'tools/{tenant_id}/{unique_name}{extension}' + filename = f"tools/{tenant_id}/{unique_name}{extension}" storage.save(filename, file_binary) tool_file = ToolFile( @@ -90,10 +90,10 @@ class ToolFileManager: response = get(file_url) response.raise_for_status() blob = response.content - mimetype = guess_type(file_url)[0] or 'octet/stream' - extension = guess_extension(mimetype) or '.bin' + mimetype = guess_type(file_url)[0] or "octet/stream" + extension = guess_extension(mimetype) or ".bin" unique_name = uuid4().hex - filename = f'tools/{tenant_id}/{unique_name}{extension}' + filename = f"tools/{tenant_id}/{unique_name}{extension}" storage.save(filename, blob) tool_file = ToolFile( @@ -166,13 +166,12 @@ class ToolFileManager: # Check if message_file is not None if message_file is not None: # get tool file id - tool_file_id = message_file.url.split('/')[-1] + tool_file_id = message_file.url.split("/")[-1] # trim extension - tool_file_id = tool_file_id.split('.')[0] + tool_file_id = tool_file_id.split(".")[0] else: tool_file_id = None - tool_file: ToolFile = ( db.session.query(ToolFile) .filter( @@ -216,4 +215,4 @@ class ToolFileManager: # init tool_file_parser from core.file.tool_file_parser import tool_file_manager -tool_file_manager['manager'] = ToolFileManager +tool_file_manager["manager"] = ToolFileManager diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 97788a7a07..2a5a2944ef 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -15,7 +15,7 @@ class ToolLabelManager: """ tool_labels = [label for label in tool_labels if label in default_tool_label_name_list] return list(set(tool_labels)) - + @classmethod def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]): """ @@ -26,20 +26,20 @@ class ToolLabelManager: if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): provider_id = controller.provider_id else: - raise ValueError('Unsupported tool type') + raise ValueError("Unsupported tool type") # delete old labels - db.session.query(ToolLabelBinding).filter( - ToolLabelBinding.tool_id == provider_id - ).delete() + db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete() # insert new labels for label in labels: - db.session.add(ToolLabelBinding( - tool_id=provider_id, - tool_type=controller.provider_type.value, - label_name=label, - )) + db.session.add( + ToolLabelBinding( + tool_id=provider_id, + tool_type=controller.provider_type.value, + label_name=label, + ) + ) db.session.commit() @@ -53,12 +53,16 @@ class ToolLabelManager: elif isinstance(controller, BuiltinToolProviderController): return controller.tool_labels else: - raise ValueError('Unsupported tool type') + raise ValueError("Unsupported tool type") - labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding.label_name).filter( - ToolLabelBinding.tool_id == provider_id, - ToolLabelBinding.tool_type == controller.provider_type.value, - ).all() + labels: list[ToolLabelBinding] = ( + db.session.query(ToolLabelBinding.label_name) + .filter( + ToolLabelBinding.tool_id == provider_id, + ToolLabelBinding.tool_type == controller.provider_type.value, + ) + .all() + ) return [label.label_name for label in labels] @@ -75,22 +79,20 @@ class ToolLabelManager: """ if not tool_providers: return {} - + for controller in tool_providers: if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): - raise ValueError('Unsupported tool type') - + raise ValueError("Unsupported tool type") + provider_ids = [controller.provider_id for controller in tool_providers] - labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding).filter( - ToolLabelBinding.tool_id.in_(provider_ids) - ).all() + labels: list[ToolLabelBinding] = ( + db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all() + ) - tool_labels = { - label.tool_id: [] for label in labels - } + tool_labels = {label.tool_id: [] for label in labels} for label in labels: tool_labels[label.tool_id].append(label.label_name) - return tool_labels \ No newline at end of file + return tool_labels diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 4778d79ed9..a3303797e1 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -41,29 +41,29 @@ class ToolManager: @classmethod def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController: """ - get the builtin provider + get the builtin provider - :param provider: the name of the provider - :return: the provider + :param provider: the name of the provider + :return: the provider """ if len(cls._builtin_providers) == 0: # init the builtin providers cls.load_builtin_providers_cache() if provider not in cls._builtin_providers: - raise ToolProviderNotFoundError(f'builtin provider {provider} not found') + raise ToolProviderNotFoundError(f"builtin provider {provider} not found") return cls._builtin_providers[provider] @classmethod def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool: """ - get the builtin tool + get the builtin tool - :param provider: the name of the provider - :param tool_name: the name of the tool + :param provider: the name of the provider + :param tool_name: the name of the tool - :return: the provider, the tool + :return: the provider, the tool """ provider_controller = cls.get_builtin_provider(provider) tool = provider_controller.get_tool(tool_name) @@ -71,67 +71,76 @@ class ToolManager: return tool @classmethod - def get_tool(cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \ - -> Union[BuiltinTool, ApiTool]: + def get_tool( + cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None + ) -> Union[BuiltinTool, ApiTool]: """ - get the tool + get the tool - :param provider_type: the type of the provider - :param provider_name: the name of the provider - :param tool_name: the name of the tool + :param provider_type: the type of the provider + :param provider_name: the name of the provider + :param tool_name: the name of the tool - :return: the tool + :return: the tool """ - if provider_type == 'builtin': + if provider_type == "builtin": return cls.get_builtin_tool(provider_id, tool_name) - elif provider_type == 'api': + elif provider_type == "api": if tenant_id is None: - raise ValueError('tenant id is required for api provider') + raise ValueError("tenant id is required for api provider") api_provider, _ = cls.get_api_provider_controller(tenant_id, provider_id) return api_provider.get_tool(tool_name) - elif provider_type == 'app': - raise NotImplementedError('app provider not implemented') + elif provider_type == "app": + raise NotImplementedError("app provider not implemented") else: - raise ToolProviderNotFoundError(f'provider type {provider_type} not found') + raise ToolProviderNotFoundError(f"provider type {provider_type} not found") @classmethod - def get_tool_runtime(cls, provider_type: str, - provider_id: str, - tool_name: str, - tenant_id: str, - invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \ - -> Union[BuiltinTool, ApiTool]: + def get_tool_runtime( + cls, + provider_type: str, + provider_id: str, + tool_name: str, + tenant_id: str, + invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, + tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, + ) -> Union[BuiltinTool, ApiTool]: """ - get the tool runtime + get the tool runtime - :param provider_type: the type of the provider - :param provider_name: the name of the provider - :param tool_name: the name of the tool + :param provider_type: the type of the provider + :param provider_name: the name of the provider + :param tool_name: the name of the tool - :return: the tool + :return: the tool """ - if provider_type == 'builtin': + if provider_type == "builtin": builtin_tool = cls.get_builtin_tool(provider_id, tool_name) # check if the builtin tool need credentials provider_controller = cls.get_builtin_provider(provider_id) if not provider_controller.need_credentials: - return builtin_tool.fork_tool_runtime(runtime={ - 'tenant_id': tenant_id, - 'credentials': {}, - 'invoke_from': invoke_from, - 'tool_invoke_from': tool_invoke_from, - }) + return builtin_tool.fork_tool_runtime( + runtime={ + "tenant_id": tenant_id, + "credentials": {}, + "invoke_from": invoke_from, + "tool_invoke_from": tool_invoke_from, + } + ) # get credentials - builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_id, - ).first() + builtin_provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_id, + ) + .first() + ) if builtin_provider is None: - raise ToolProviderNotFoundError(f'builtin provider {provider_id} not found') + raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") # decrypt the credentials credentials = builtin_provider.credentials @@ -140,17 +149,19 @@ class ToolManager: decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) - return builtin_tool.fork_tool_runtime(runtime={ - 'tenant_id': tenant_id, - 'credentials': decrypted_credentials, - 'runtime_parameters': {}, - 'invoke_from': invoke_from, - 'tool_invoke_from': tool_invoke_from, - }) + return builtin_tool.fork_tool_runtime( + runtime={ + "tenant_id": tenant_id, + "credentials": decrypted_credentials, + "runtime_parameters": {}, + "invoke_from": invoke_from, + "tool_invoke_from": tool_invoke_from, + } + ) - elif provider_type == 'api': + elif provider_type == "api": if tenant_id is None: - raise ValueError('tenant id is required for api provider') + raise ValueError("tenant id is required for api provider") api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) @@ -158,40 +169,43 @@ class ToolManager: tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) - return api_provider.get_tool(tool_name).fork_tool_runtime(runtime={ - 'tenant_id': tenant_id, - 'credentials': decrypted_credentials, - 'invoke_from': invoke_from, - 'tool_invoke_from': tool_invoke_from, - }) - elif provider_type == 'workflow': - workflow_provider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == provider_id - ).first() - - if workflow_provider is None: - raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found') - - controller = ToolTransformService.workflow_provider_to_controller( - db_provider=workflow_provider + return api_provider.get_tool(tool_name).fork_tool_runtime( + runtime={ + "tenant_id": tenant_id, + "credentials": decrypted_credentials, + "invoke_from": invoke_from, + "tool_invoke_from": tool_invoke_from, + } + ) + elif provider_type == "workflow": + workflow_provider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) + .first() ) - return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={ - 'tenant_id': tenant_id, - 'credentials': {}, - 'invoke_from': invoke_from, - 'tool_invoke_from': tool_invoke_from, - }) - elif provider_type == 'app': - raise NotImplementedError('app provider not implemented') + if workflow_provider is None: + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") + + controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) + + return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( + runtime={ + "tenant_id": tenant_id, + "credentials": {}, + "invoke_from": invoke_from, + "tool_invoke_from": tool_invoke_from, + } + ) + elif provider_type == "app": + raise NotImplementedError("app provider not implemented") else: - raise ToolProviderNotFoundError(f'provider type {provider_type} not found') + raise ToolProviderNotFoundError(f"provider type {provider_type} not found") @classmethod def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]: """ - init runtime parameter + init runtime parameter """ parameter_value = parameters.get(parameter_rule.name) if not parameter_value and parameter_value != 0: @@ -205,14 +219,17 @@ class ToolManager: options = [x.value for x in parameter_rule.options] if parameter_value is not None and parameter_value not in options: raise ValueError( - f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}") + f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}" + ) return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type) @classmethod - def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: + def get_agent_tool_runtime( + cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER + ) -> Tool: """ - get the agent tool runtime + get the agent tool runtime """ tool_entity = cls.get_tool_runtime( provider_type=agent_tool.provider_type, @@ -220,7 +237,7 @@ class ToolManager: tool_name=agent_tool.tool_name, tenant_id=tenant_id, invoke_from=invoke_from, - tool_invoke_from=ToolInvokeFrom.AGENT + tool_invoke_from=ToolInvokeFrom.AGENT, ) runtime_parameters = {} parameters = tool_entity.get_all_runtime_parameters() @@ -240,7 +257,7 @@ class ToolManager: tool_runtime=tool_entity, provider_name=agent_tool.provider_id, provider_type=agent_tool.provider_type, - identity_id=f'AGENT.{app_id}' + identity_id=f"AGENT.{app_id}", ) runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) @@ -248,9 +265,16 @@ class ToolManager: return tool_entity @classmethod - def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: "ToolEntity", invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: + def get_workflow_tool_runtime( + cls, + tenant_id: str, + app_id: str, + node_id: str, + workflow_tool: "ToolEntity", + invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, + ) -> Tool: """ - get the workflow tool runtime + get the workflow tool runtime """ tool_entity = cls.get_tool_runtime( provider_type=workflow_tool.provider_type, @@ -258,7 +282,7 @@ class ToolManager: tool_name=workflow_tool.tool_name, tenant_id=tenant_id, invoke_from=invoke_from, - tool_invoke_from=ToolInvokeFrom.WORKFLOW + tool_invoke_from=ToolInvokeFrom.WORKFLOW, ) runtime_parameters = {} parameters = tool_entity.get_all_runtime_parameters() @@ -275,7 +299,7 @@ class ToolManager: tool_runtime=tool_entity, provider_name=workflow_tool.provider_id, provider_type=workflow_tool.provider_type, - identity_id=f'WORKFLOW.{app_id}.{node_id}' + identity_id=f"WORKFLOW.{app_id}.{node_id}", ) if runtime_parameters: @@ -287,24 +311,30 @@ class ToolManager: @classmethod def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]: """ - get the absolute path of the icon of the builtin provider + get the absolute path of the icon of the builtin provider - :param provider: the name of the provider + :param provider: the name of the provider - :return: the absolute path of the icon, the mime type of the icon + :return: the absolute path of the icon, the mime type of the icon """ # get provider provider_controller = cls.get_builtin_provider(provider) - absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets', - provider_controller.identity.icon) + absolute_path = path.join( + path.dirname(path.realpath(__file__)), + "provider", + "builtin", + provider, + "_assets", + provider_controller.identity.icon, + ) # check if the icon exists if not path.exists(absolute_path): - raise ToolProviderNotFoundError(f'builtin provider {provider} icon not found') + raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found") # get the mime type mime_type, _ = mimetypes.guess_type(absolute_path) - mime_type = mime_type or 'application/octet-stream' + mime_type = mime_type or "application/octet-stream" return absolute_path, mime_type @@ -325,23 +355,25 @@ class ToolManager: @classmethod def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]: """ - list all the builtin providers + list all the builtin providers """ - for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')): - if provider.startswith('__'): + for provider in listdir(path.join(path.dirname(path.realpath(__file__)), "provider", "builtin")): + if provider.startswith("__"): continue - if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider)): - if provider.startswith('__'): + if path.isdir(path.join(path.dirname(path.realpath(__file__)), "provider", "builtin", provider)): + if provider.startswith("__"): continue # init provider try: provider_class = load_single_subclass_from_source( - module_name=f'core.tools.provider.builtin.{provider}.{provider}', - script_path=path.join(path.dirname(path.realpath(__file__)), - 'provider', 'builtin', provider, f'{provider}.py'), - parent_type=BuiltinToolProviderController) + module_name=f"core.tools.provider.builtin.{provider}.{provider}", + script_path=path.join( + path.dirname(path.realpath(__file__)), "provider", "builtin", provider, f"{provider}.py" + ), + parent_type=BuiltinToolProviderController, + ) provider: BuiltinToolProviderController = provider_class() cls._builtin_providers[provider.identity.name] = provider for tool in provider.get_tools(): @@ -349,7 +381,7 @@ class ToolManager: yield provider except Exception as e: - logger.error(f'load builtin provider {provider} error: {e}') + logger.error(f"load builtin provider {provider} error: {e}") continue # set builtin providers loaded cls._builtin_providers_loaded = True @@ -367,11 +399,11 @@ class ToolManager: @classmethod def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]: """ - get the tool label + get the tool label - :param tool_name: the name of the tool + :param tool_name: the name of the tool - :return: the label of the tool + :return: the label of the tool """ if len(cls._builtin_tools_labels) == 0: # init the builtin providers @@ -383,75 +415,78 @@ class ToolManager: return cls._builtin_tools_labels[tool_name] @classmethod - def user_list_providers(cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral) -> list[UserToolProvider]: + def user_list_providers( + cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral + ) -> list[UserToolProvider]: result_providers: dict[str, UserToolProvider] = {} filters = [] if not typ: - filters.extend(['builtin', 'api', 'workflow']) + filters.extend(["builtin", "api", "workflow"]) else: filters.append(typ) - if 'builtin' in filters: - + if "builtin" in filters: # get builtin providers builtin_providers = cls.list_builtin_providers() # get db builtin providers - db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \ - filter(BuiltinToolProvider.tenant_id == tenant_id).all() + db_builtin_providers: list[BuiltinToolProvider] = ( + db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() + ) find_db_builtin_provider = lambda provider: next( - (x for x in db_builtin_providers if x.provider == provider), - None + (x for x in db_builtin_providers if x.provider == provider), None ) # append builtin providers for provider in builtin_providers: # handle include, exclude if is_filtered( - include_set=dify_config.POSITION_TOOL_INCLUDES_SET, - exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, - data=provider, - name_func=lambda x: x.identity.name + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, + data=provider, + name_func=lambda x: x.identity.name, ): continue user_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider, db_provider=find_db_builtin_provider(provider.identity.name), - decrypt_credentials=False + decrypt_credentials=False, ) result_providers[provider.identity.name] = user_provider # get db api providers - if 'api' in filters: - db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \ - filter(ApiToolProvider.tenant_id == tenant_id).all() + if "api" in filters: + db_api_providers: list[ApiToolProvider] = ( + db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() + ) - api_provider_controllers = [{ - 'provider': provider, - 'controller': ToolTransformService.api_provider_to_controller(provider) - } for provider in db_api_providers] + api_provider_controllers = [ + {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} + for provider in db_api_providers + ] # get labels - labels = ToolLabelManager.get_tools_labels([x['controller'] for x in api_provider_controllers]) + labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers]) for api_provider_controller in api_provider_controllers: user_provider = ToolTransformService.api_provider_to_user_provider( - provider_controller=api_provider_controller['controller'], - db_provider=api_provider_controller['provider'], + provider_controller=api_provider_controller["controller"], + db_provider=api_provider_controller["provider"], decrypt_credentials=False, - labels=labels.get(api_provider_controller['controller'].provider_id, []) + labels=labels.get(api_provider_controller["controller"].provider_id, []), ) - result_providers[f'api_provider.{user_provider.name}'] = user_provider + result_providers[f"api_provider.{user_provider.name}"] = user_provider - if 'workflow' in filters: + if "workflow" in filters: # get workflow providers - workflow_providers: list[WorkflowToolProvider] = db.session.query(WorkflowToolProvider). \ - filter(WorkflowToolProvider.tenant_id == tenant_id).all() + workflow_providers: list[WorkflowToolProvider] = ( + db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() + ) workflow_provider_controllers = [] for provider in workflow_providers: @@ -470,32 +505,36 @@ class ToolManager: provider_controller=provider_controller, labels=labels.get(provider_controller.provider_id, []), ) - result_providers[f'workflow_provider.{user_provider.name}'] = user_provider + result_providers[f"workflow_provider.{user_provider.name}"] = user_provider return BuiltinToolProviderSort.sort(list(result_providers.values())) @classmethod - def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[ - ApiToolProviderController, dict[str, Any]]: + def get_api_provider_controller( + cls, tenant_id: str, provider_id: str + ) -> tuple[ApiToolProviderController, dict[str, Any]]: """ - get the api provider + get the api provider - :param provider_name: the name of the provider + :param provider_name: the name of the provider - :return: the provider controller, the credentials + :return: the provider controller, the credentials """ - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.id == provider_id, - ApiToolProvider.tenant_id == tenant_id, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.id == provider_id, + ApiToolProvider.tenant_id == tenant_id, + ) + .first() + ) if provider is None: - raise ToolProviderNotFoundError(f'api provider {provider_id} not found') + raise ToolProviderNotFoundError(f"api provider {provider_id} not found") controller = ApiToolProviderController.from_db( provider, - ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else - ApiProviderAuthType.NONE + ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, ) controller.load_bundled_tools(provider.tools) @@ -504,18 +543,22 @@ class ToolManager: @classmethod def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: """ - get api provider + get api provider """ """ get tool provider """ - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider, + ) + .first() + ) if provider is None: - raise ValueError(f'you have not added provider {provider}') + raise ValueError(f"you have not added provider {provider}") try: credentials = json.loads(provider.credentials_str) or {} @@ -524,7 +567,7 @@ class ToolManager: # package tool provider controller controller = ApiToolProviderController.from_db( - provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE + provider, ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE ) # init tool configuration tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) @@ -535,62 +578,62 @@ class ToolManager: try: icon = json.loads(provider.icon) except: - icon = { - "background": "#252525", - "content": "\ud83d\ude01" - } + icon = {"background": "#252525", "content": "\ud83d\ude01"} # add tool labels labels = ToolLabelManager.get_tool_labels(controller) - return jsonable_encoder({ - 'schema_type': provider.schema_type, - 'schema': provider.schema, - 'tools': provider.tools, - 'icon': icon, - 'description': provider.description, - 'credentials': masked_credentials, - 'privacy_policy': provider.privacy_policy, - 'custom_disclaimer': provider.custom_disclaimer, - 'labels': labels, - }) + return jsonable_encoder( + { + "schema_type": provider.schema_type, + "schema": provider.schema, + "tools": provider.tools, + "icon": icon, + "description": provider.description, + "credentials": masked_credentials, + "privacy_policy": provider.privacy_policy, + "custom_disclaimer": provider.custom_disclaimer, + "labels": labels, + } + ) @classmethod def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]: """ - get the tool icon + get the tool icon - :param tenant_id: the id of the tenant - :param provider_type: the type of the provider - :param provider_id: the id of the provider - :return: + :param tenant_id: the id of the tenant + :param provider_type: the type of the provider + :param provider_id: the id of the provider + :return: """ provider_type = provider_type provider_id = provider_id - if provider_type == 'builtin': - return (dify_config.CONSOLE_API_URL - + "/console/api/workspaces/current/tool-provider/builtin/" - + provider_id - + "/icon") - elif provider_type == 'api': + if provider_type == "builtin": + return ( + dify_config.CONSOLE_API_URL + + "/console/api/workspaces/current/tool-provider/builtin/" + + provider_id + + "/icon" + ) + elif provider_type == "api": try: - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.id == provider_id - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) + .first() + ) return json.loads(provider.icon) except: - return { - "background": "#252525", - "content": "\ud83d\ude01" - } - elif provider_type == 'workflow': - provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == provider_id - ).first() + return {"background": "#252525", "content": "\ud83d\ude01"} + elif provider_type == "workflow": + provider: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) + .first() + ) if provider is None: - raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found') + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") return json.loads(provider.icon) else: diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index b213879e96..83600d21c1 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -56,12 +56,13 @@ class ToolConfigurationManager(BaseModel): if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT: if field_name in credentials: if len(credentials[field_name]) > 6: - credentials[field_name] = \ - credentials[field_name][:2] + \ - '*' * (len(credentials[field_name]) - 4) + \ - credentials[field_name][-2:] + credentials[field_name] = ( + credentials[field_name][:2] + + "*" * (len(credentials[field_name]) - 4) + + credentials[field_name][-2:] + ) else: - credentials[field_name] = '*' * len(credentials[field_name]) + credentials[field_name] = "*" * len(credentials[field_name]) return credentials @@ -72,9 +73,9 @@ class ToolConfigurationManager(BaseModel): return a deep copy of credentials with decrypted values """ cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}', - cache_type=ToolProviderCredentialsCacheType.PROVIDER + tenant_id=self.tenant_id, + identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}", + cache_type=ToolProviderCredentialsCacheType.PROVIDER, ) cached_credentials = cache.get() if cached_credentials: @@ -95,16 +96,18 @@ class ToolConfigurationManager(BaseModel): def delete_tool_credentials_cache(self): cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}', - cache_type=ToolProviderCredentialsCacheType.PROVIDER + tenant_id=self.tenant_id, + identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}", + cache_type=ToolProviderCredentialsCacheType.PROVIDER, ) cache.delete() + class ToolParameterConfigurationManager(BaseModel): """ Tool parameter configuration manager """ + tenant_id: str tool_runtime: Tool provider_name: str @@ -152,15 +155,19 @@ class ToolParameterConfigurationManager(BaseModel): current_parameters = self._merge_parameters() for parameter in current_parameters: - if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): if parameter.name in parameters: if len(parameters[parameter.name]) > 6: - parameters[parameter.name] = \ - parameters[parameter.name][:2] + \ - '*' * (len(parameters[parameter.name]) - 4) + \ - parameters[parameter.name][-2:] + parameters[parameter.name] = ( + parameters[parameter.name][:2] + + "*" * (len(parameters[parameter.name]) - 4) + + parameters[parameter.name][-2:] + ) else: - parameters[parameter.name] = '*' * len(parameters[parameter.name]) + parameters[parameter.name] = "*" * len(parameters[parameter.name]) return parameters @@ -176,7 +183,10 @@ class ToolParameterConfigurationManager(BaseModel): parameters = self._deep_copy(parameters) for parameter in current_parameters: - if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): if parameter.name in parameters: encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name]) parameters[parameter.name] = encrypted @@ -191,10 +201,10 @@ class ToolParameterConfigurationManager(BaseModel): """ cache = ToolParameterCache( tenant_id=self.tenant_id, - provider=f'{self.provider_type}.{self.provider_name}', + provider=f"{self.provider_type}.{self.provider_name}", tool_name=self.tool_runtime.identity.name, cache_type=ToolParameterCacheType.PARAMETER, - identity_id=self.identity_id + identity_id=self.identity_id, ) cached_parameters = cache.get() if cached_parameters: @@ -205,7 +215,10 @@ class ToolParameterConfigurationManager(BaseModel): has_secret_input = False for parameter in current_parameters: - if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): if parameter.name in parameters: try: has_secret_input = True @@ -221,9 +234,9 @@ class ToolParameterConfigurationManager(BaseModel): def delete_tool_parameters_cache(self): cache = ToolParameterCache( tenant_id=self.tenant_id, - provider=f'{self.provider_type}.{self.provider_name}', + provider=f"{self.provider_type}.{self.provider_name}", tool_name=self.tool_runtime.identity.name, cache_type=ToolParameterCacheType.PARAMETER, - identity_id=self.identity_id + identity_id=self.identity_id, ) cache.delete() diff --git a/api/core/tools/utils/feishu_api_utils.py b/api/core/tools/utils/feishu_api_utils.py index e6b288868f..7bb026a383 100644 --- a/api/core/tools/utils/feishu_api_utils.py +++ b/api/core/tools/utils/feishu_api_utils.py @@ -17,8 +17,9 @@ class FeishuRequest: redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token")) return res.get("tenant_access_token") - def _send_request(self, url: str, method: str = "post", require_token: bool = True, payload: dict = None, - params: dict = None): + def _send_request( + self, url: str, method: str = "post", require_token: bool = True, payload: dict = None, params: dict = None + ): headers = { "Content-Type": "application/json", "user-agent": "Dify", @@ -42,10 +43,7 @@ class FeishuRequest: } """ url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/access_token/get_tenant_access_token" - payload = { - "app_id": app_id, - "app_secret": app_secret - } + payload = {"app_id": app_id, "app_secret": app_secret} res = self._send_request(url, require_token=False, payload=payload) return res @@ -76,11 +74,7 @@ class FeishuRequest: def write_document(self, document_id: str, content: str, position: str = "start") -> dict: url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/write_document" - payload = { - "document_id": document_id, - "content": content, - "position": position - } + payload = {"document_id": document_id, "content": content, "position": position} res = self._send_request(url, payload=payload) return res.get("data") diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 23e7c0c243..c4983ebc65 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -10,10 +10,9 @@ logger = logging.getLogger(__name__) class ToolFileMessageTransformer: @classmethod - def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage], - user_id: str, - tenant_id: str, - conversation_id: str) -> list[ToolInvokeMessage]: + def transform_tool_invoke_messages( + cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str + ) -> list[ToolInvokeMessage]: """ Transform tool message and handle file download """ @@ -28,78 +27,88 @@ class ToolFileMessageTransformer: # try to download image try: file = ToolFileManager.create_file_by_url( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=conversation_id, - file_url=message.message + user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_url=message.message ) url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}' - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) + result.append( + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + ) + ) except Exception as e: logger.exception(e) - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.TEXT, - message=f"Failed to download image: {message.message}, you can try to download it yourself.", - meta=message.meta.copy() if message.meta is not None else {}, - save_as=message.save_as, - )) + result.append( + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.TEXT, + message=f"Failed to download image: {message.message}, you can try to download it yourself.", + meta=message.meta.copy() if message.meta is not None else {}, + save_as=message.save_as, + ) + ) elif message.type == ToolInvokeMessage.MessageType.BLOB: # get mime type and save blob to storage - mimetype = message.meta.get('mime_type', 'octet/stream') + mimetype = message.meta.get("mime_type", "octet/stream") # if message is str, encode it to bytes if isinstance(message.message, str): - message.message = message.message.encode('utf-8') + message.message = message.message.encode("utf-8") file = ToolFileManager.create_file_by_raw( - user_id=user_id, tenant_id=tenant_id, + user_id=user_id, + tenant_id=tenant_id, conversation_id=conversation_id, file_binary=message.message, - mimetype=mimetype + mimetype=mimetype, ) url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype)) # check if file is image - if 'image' in mimetype: - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) + if "image" in mimetype: + result.append( + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + ) + ) else: - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) + result.append( + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + ) + ) elif message.type == ToolInvokeMessage.MessageType.FILE_VAR: - file_var = message.meta.get('file_var') + file_var = message.meta.get("file_var") if file_var: if file_var.transfer_method == FileTransferMethod.TOOL_FILE: url = cls.get_tool_file_url(file_var.related_id, file_var.extension) if file_var.type == FileType.IMAGE: - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) + result.append( + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + ) + ) else: - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) + result.append( + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, + message=url, + save_as=message.save_as, + meta=message.meta.copy() if message.meta is not None else {}, + ) + ) else: result.append(message) diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 9e8ef47823..4e226810d6 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -1,7 +1,7 @@ """ - For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc. +For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc. - Therefore, a model manager is needed to list/invoke/validate models. +Therefore, a model manager is needed to list/invoke/validate models. """ import json @@ -27,52 +27,49 @@ from models.tools import ToolModelInvoke class InvokeModelError(Exception): pass + class ModelInvocationUtils: @staticmethod def get_max_llm_context_tokens( tenant_id: str, ) -> int: """ - get max llm context tokens of the model + get max llm context tokens of the model """ model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, model_type=ModelType.LLM, + tenant_id=tenant_id, + model_type=ModelType.LLM, ) if not model_instance: - raise InvokeModelError('Model not found') - + raise InvokeModelError("Model not found") + llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) if not schema: - raise InvokeModelError('No model schema found') + raise InvokeModelError("No model schema found") max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) if max_tokens is None: return 2048 - + return max_tokens @staticmethod - def calculate_tokens( - tenant_id: str, - prompt_messages: list[PromptMessage] - ) -> int: + def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int: """ - calculate tokens from prompt messages and model parameters + calculate tokens from prompt messages and model parameters """ # get model instance model_manager = ModelManager() - model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, model_type=ModelType.LLM - ) + model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM) if not model_instance: - raise InvokeModelError('Model not found') - + raise InvokeModelError("Model not found") + # get tokens tokens = model_instance.get_llm_num_tokens(prompt_messages) @@ -80,9 +77,7 @@ class ModelInvocationUtils: @staticmethod def invoke( - user_id: str, tenant_id: str, - tool_type: str, tool_name: str, - prompt_messages: list[PromptMessage] + user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage] ) -> LLMResult: """ invoke model with parameters in user's own context @@ -103,15 +98,16 @@ class ModelInvocationUtils: model_manager = ModelManager() # get model instance model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, model_type=ModelType.LLM, + tenant_id=tenant_id, + model_type=ModelType.LLM, ) # get prompt tokens prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) model_parameters = { - 'temperature': 0.8, - 'top_p': 0.8, + "temperature": 0.8, + "top_p": 0.8, } # create tool model invoke @@ -123,14 +119,14 @@ class ModelInvocationUtils: tool_name=tool_name, model_parameters=json.dumps(model_parameters), prompt_messages=json.dumps(jsonable_encoder(prompt_messages)), - model_response='', + model_response="", prompt_tokens=prompt_tokens, answer_tokens=0, answer_unit_price=0, answer_price_unit=0, provider_response_latency=0, total_price=0, - currency='USD', + currency="USD", ) db.session.add(tool_model_invoke) @@ -140,20 +136,24 @@ class ModelInvocationUtils: response: LLMResult = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=model_parameters, - tools=[], stop=[], stream=False, user=user_id, callbacks=[] + tools=[], + stop=[], + stream=False, + user=user_id, + callbacks=[], ) except InvokeRateLimitError as e: - raise InvokeModelError(f'Invoke rate limit error: {e}') + raise InvokeModelError(f"Invoke rate limit error: {e}") except InvokeBadRequestError as e: - raise InvokeModelError(f'Invoke bad request error: {e}') + raise InvokeModelError(f"Invoke bad request error: {e}") except InvokeConnectionError as e: - raise InvokeModelError(f'Invoke connection error: {e}') + raise InvokeModelError(f"Invoke connection error: {e}") except InvokeAuthorizationError as e: - raise InvokeModelError('Invoke authorization error') + raise InvokeModelError("Invoke authorization error") except InvokeServerUnavailableError as e: - raise InvokeModelError(f'Invoke server unavailable error: {e}') + raise InvokeModelError(f"Invoke server unavailable error: {e}") except Exception as e: - raise InvokeModelError(f'Invoke error: {e}') + raise InvokeModelError(f"Invoke error: {e}") # update tool model invoke tool_model_invoke.model_response = response.message.content diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index f711f7c9f3..654c9acaf9 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -1,4 +1,3 @@ - import re import uuid from json import dumps as json_dumps @@ -16,54 +15,56 @@ from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolPro class ApiBasedToolSchemaParser: @staticmethod - def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]: + def parse_openapi_to_tool_bundle( + openapi: dict, extra_info: dict = None, warning: dict = None + ) -> list[ApiToolBundle]: warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} # set description to extra_info - extra_info['description'] = openapi['info'].get('description', '') + extra_info["description"] = openapi["info"].get("description", "") - if len(openapi['servers']) == 0: - raise ToolProviderNotFoundError('No server found in the openapi yaml.') + if len(openapi["servers"]) == 0: + raise ToolProviderNotFoundError("No server found in the openapi yaml.") - server_url = openapi['servers'][0]['url'] + server_url = openapi["servers"][0]["url"] # list all interfaces interfaces = [] - for path, path_item in openapi['paths'].items(): - methods = ['get', 'post', 'put', 'delete', 'patch', 'head', 'options', 'trace'] + for path, path_item in openapi["paths"].items(): + methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"] for method in methods: if method in path_item: - interfaces.append({ - 'path': path, - 'method': method, - 'operation': path_item[method], - }) + interfaces.append( + { + "path": path, + "method": method, + "operation": path_item[method], + } + ) # get all parameters bundles = [] for interface in interfaces: # convert parameters parameters = [] - if 'parameters' in interface['operation']: - for parameter in interface['operation']['parameters']: + if "parameters" in interface["operation"]: + for parameter in interface["operation"]["parameters"]: tool_parameter = ToolParameter( - name=parameter['name'], - label=I18nObject( - en_US=parameter['name'], - zh_Hans=parameter['name'] - ), + name=parameter["name"], + label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]), human_description=I18nObject( - en_US=parameter.get('description', ''), - zh_Hans=parameter.get('description', '') + en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "") ), type=ToolParameter.ToolParameterType.STRING, - required=parameter.get('required', False), + required=parameter.get("required", False), form=ToolParameter.ToolParameterForm.LLM, - llm_description=parameter.get('description'), - default=parameter['schema']['default'] if 'schema' in parameter and 'default' in parameter['schema'] else None, + llm_description=parameter.get("description"), + default=parameter["schema"]["default"] + if "schema" in parameter and "default" in parameter["schema"] + else None, ) - + # check if there is a type typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter) if typ: @@ -72,44 +73,40 @@ class ApiBasedToolSchemaParser: parameters.append(tool_parameter) # create tool bundle # check if there is a request body - if 'requestBody' in interface['operation']: - request_body = interface['operation']['requestBody'] - if 'content' in request_body: - for content_type, content in request_body['content'].items(): + if "requestBody" in interface["operation"]: + request_body = interface["operation"]["requestBody"] + if "content" in request_body: + for content_type, content in request_body["content"].items(): # if there is a reference, get the reference and overwrite the content - if 'schema' not in content: + if "schema" not in content: continue - if '$ref' in content['schema']: + if "$ref" in content["schema"]: # get the reference root = openapi - reference = content['schema']['$ref'].split('/')[1:] + reference = content["schema"]["$ref"].split("/")[1:] for ref in reference: root = root[ref] # overwrite the content - interface['operation']['requestBody']['content'][content_type]['schema'] = root + interface["operation"]["requestBody"]["content"][content_type]["schema"] = root # parse body parameters - if 'schema' in interface['operation']['requestBody']['content'][content_type]: - body_schema = interface['operation']['requestBody']['content'][content_type]['schema'] - required = body_schema.get('required', []) - properties = body_schema.get('properties', {}) + if "schema" in interface["operation"]["requestBody"]["content"][content_type]: + body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"] + required = body_schema.get("required", []) + properties = body_schema.get("properties", {}) for name, property in properties.items(): tool = ToolParameter( name=name, - label=I18nObject( - en_US=name, - zh_Hans=name - ), + label=I18nObject(en_US=name, zh_Hans=name), human_description=I18nObject( - en_US=property.get('description', ''), - zh_Hans=property.get('description', '') + en_US=property.get("description", ""), zh_Hans=property.get("description", "") ), type=ToolParameter.ToolParameterType.STRING, required=name in required, form=ToolParameter.ToolParameterForm.LLM, - llm_description=property.get('description', ''), - default=property.get('default', None), + llm_description=property.get("description", ""), + default=property.get("default", None), ) # check if there is a type @@ -127,172 +124,176 @@ class ApiBasedToolSchemaParser: parameters_count[parameter.name] += 1 for name, count in parameters_count.items(): if count > 1: - warning['duplicated_parameter'] = f'Parameter {name} is duplicated.' + warning["duplicated_parameter"] = f"Parameter {name} is duplicated." # check if there is a operation id, use $path_$method as operation id if not - if 'operationId' not in interface['operation']: + if "operationId" not in interface["operation"]: # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ - path = interface['path'] - if interface['path'].startswith('/'): - path = interface['path'][1:] + path = interface["path"] + if interface["path"].startswith("/"): + path = interface["path"][1:] # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ - path = re.sub(r'[^a-zA-Z0-9_-]', '', path) + path = re.sub(r"[^a-zA-Z0-9_-]", "", path) if not path: path = str(uuid.uuid4()) - - interface['operation']['operationId'] = f'{path}_{interface["method"]}' - bundles.append(ApiToolBundle( - server_url=server_url + interface['path'], - method=interface['method'], - summary=interface['operation']['description'] if 'description' in interface['operation'] else - interface['operation'].get('summary', None), - operation_id=interface['operation']['operationId'], - parameters=parameters, - author='', - icon=None, - openapi=interface['operation'], - )) + interface["operation"]["operationId"] = f'{path}_{interface["method"]}' + + bundles.append( + ApiToolBundle( + server_url=server_url + interface["path"], + method=interface["method"], + summary=interface["operation"]["description"] + if "description" in interface["operation"] + else interface["operation"].get("summary", None), + operation_id=interface["operation"]["operationId"], + parameters=parameters, + author="", + icon=None, + openapi=interface["operation"], + ) + ) return bundles - + @staticmethod def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType: parameter = parameter or {} typ = None - if 'type' in parameter: - typ = parameter['type'] - elif 'schema' in parameter and 'type' in parameter['schema']: - typ = parameter['schema']['type'] - - if typ == 'integer' or typ == 'number': + if "type" in parameter: + typ = parameter["type"] + elif "schema" in parameter and "type" in parameter["schema"]: + typ = parameter["schema"]["type"] + + if typ == "integer" or typ == "number": return ToolParameter.ToolParameterType.NUMBER - elif typ == 'boolean': + elif typ == "boolean": return ToolParameter.ToolParameterType.BOOLEAN - elif typ == 'string': + elif typ == "string": return ToolParameter.ToolParameterType.STRING @staticmethod - def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]: + def parse_openapi_yaml_to_tool_bundle( + yaml: str, extra_info: dict = None, warning: dict = None + ) -> list[ApiToolBundle]: """ - parse openapi yaml to tool bundle + parse openapi yaml to tool bundle - :param yaml: the yaml string - :return: the tool bundle + :param yaml: the yaml string + :return: the tool bundle """ warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} openapi: dict = safe_load(yaml) if openapi is None: - raise ToolApiSchemaError('Invalid openapi yaml.') + raise ToolApiSchemaError("Invalid openapi yaml.") return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning) - + @staticmethod def parse_swagger_to_openapi(swagger: dict, extra_info: dict = None, warning: dict = None) -> dict: """ - parse swagger to openapi + parse swagger to openapi - :param swagger: the swagger dict - :return: the openapi dict + :param swagger: the swagger dict + :return: the openapi dict """ # convert swagger to openapi - info = swagger.get('info', { - 'title': 'Swagger', - 'description': 'Swagger', - 'version': '1.0.0' - }) + info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"}) - servers = swagger.get('servers', []) + servers = swagger.get("servers", []) if len(servers) == 0: - raise ToolApiSchemaError('No server found in the swagger yaml.') + raise ToolApiSchemaError("No server found in the swagger yaml.") openapi = { - 'openapi': '3.0.0', - 'info': { - 'title': info.get('title', 'Swagger'), - 'description': info.get('description', 'Swagger'), - 'version': info.get('version', '1.0.0') + "openapi": "3.0.0", + "info": { + "title": info.get("title", "Swagger"), + "description": info.get("description", "Swagger"), + "version": info.get("version", "1.0.0"), }, - 'servers': swagger['servers'], - 'paths': {}, - 'components': { - 'schemas': {} - } + "servers": swagger["servers"], + "paths": {}, + "components": {"schemas": {}}, } # check paths - if 'paths' not in swagger or len(swagger['paths']) == 0: - raise ToolApiSchemaError('No paths found in the swagger yaml.') + if "paths" not in swagger or len(swagger["paths"]) == 0: + raise ToolApiSchemaError("No paths found in the swagger yaml.") # convert paths - for path, path_item in swagger['paths'].items(): - openapi['paths'][path] = {} + for path, path_item in swagger["paths"].items(): + openapi["paths"][path] = {} for method, operation in path_item.items(): - if 'operationId' not in operation: - raise ToolApiSchemaError(f'No operationId found in operation {method} {path}.') - - if ('summary' not in operation or len(operation['summary']) == 0) and \ - ('description' not in operation or len(operation['description']) == 0): - warning['missing_summary'] = f'No summary or description found in operation {method} {path}.' - - openapi['paths'][path][method] = { - 'operationId': operation['operationId'], - 'summary': operation.get('summary', ''), - 'description': operation.get('description', ''), - 'parameters': operation.get('parameters', []), - 'responses': operation.get('responses', {}), + if "operationId" not in operation: + raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.") + + if ("summary" not in operation or len(operation["summary"]) == 0) and ( + "description" not in operation or len(operation["description"]) == 0 + ): + warning["missing_summary"] = f"No summary or description found in operation {method} {path}." + + openapi["paths"][path][method] = { + "operationId": operation["operationId"], + "summary": operation.get("summary", ""), + "description": operation.get("description", ""), + "parameters": operation.get("parameters", []), + "responses": operation.get("responses", {}), } - if 'requestBody' in operation: - openapi['paths'][path][method]['requestBody'] = operation['requestBody'] + if "requestBody" in operation: + openapi["paths"][path][method]["requestBody"] = operation["requestBody"] # convert definitions - for name, definition in swagger['definitions'].items(): - openapi['components']['schemas'][name] = definition + for name, definition in swagger["definitions"].items(): + openapi["components"]["schemas"][name] = definition return openapi @staticmethod - def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]: + def parse_openai_plugin_json_to_tool_bundle( + json: str, extra_info: dict = None, warning: dict = None + ) -> list[ApiToolBundle]: """ - parse openapi plugin yaml to tool bundle + parse openapi plugin yaml to tool bundle - :param json: the json string - :return: the tool bundle + :param json: the json string + :return: the tool bundle """ warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} try: openai_plugin = json_loads(json) - api = openai_plugin['api'] - api_url = api['url'] - api_type = api['type'] + api = openai_plugin["api"] + api_url = api["url"] + api_type = api["type"] except: - raise ToolProviderNotFoundError('Invalid openai plugin json.') - - if api_type != 'openapi': - raise ToolNotSupportedError('Only openapi is supported now.') - + raise ToolProviderNotFoundError("Invalid openai plugin json.") + + if api_type != "openapi": + raise ToolNotSupportedError("Only openapi is supported now.") + # get openapi yaml - response = get(api_url, headers={ - 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ' - }, timeout=5) + response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5) if response.status_code != 200: - raise ToolProviderNotFoundError('cannot get openapi yaml from url.') - - return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning) - - @staticmethod - def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiToolBundle], str]: - """ - auto parse to tool bundle + raise ToolProviderNotFoundError("cannot get openapi yaml from url.") - :param content: the content - :return: tools bundle, schema_type + return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle( + response.text, extra_info=extra_info, warning=warning + ) + + @staticmethod + def auto_parse_to_tool_bundle( + content: str, extra_info: dict = None, warning: dict = None + ) -> tuple[list[ApiToolBundle], str]: + """ + auto parse to tool bundle + + :param content: the content + :return: tools bundle, schema_type """ warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} @@ -301,7 +302,7 @@ class ApiBasedToolSchemaParser: loaded_content = None json_error = None yaml_error = None - + try: loaded_content = json_loads(content) except JSONDecodeError as e: @@ -313,34 +314,46 @@ class ApiBasedToolSchemaParser: except YAMLError as e: yaml_error = e if loaded_content is None: - raise ToolApiSchemaError(f'Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)}, yaml error: {str(yaml_error)}') + raise ToolApiSchemaError( + f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)}, yaml error: {str(yaml_error)}" + ) swagger_error = None openapi_error = None openapi_plugin_error = None schema_type = None - + try: - openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(loaded_content, extra_info=extra_info, warning=warning) + openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle( + loaded_content, extra_info=extra_info, warning=warning + ) schema_type = ApiProviderSchemaType.OPENAPI.value return openapi, schema_type except ToolApiSchemaError as e: openapi_error = e - + # openai parse error, fallback to swagger try: - converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(loaded_content, extra_info=extra_info, warning=warning) + converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi( + loaded_content, extra_info=extra_info, warning=warning + ) schema_type = ApiProviderSchemaType.SWAGGER.value - return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(converted_swagger, extra_info=extra_info, warning=warning), schema_type + return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle( + converted_swagger, extra_info=extra_info, warning=warning + ), schema_type except ToolApiSchemaError as e: swagger_error = e - + # swagger parse error, fallback to openai plugin try: - openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(json_dumps(loaded_content), extra_info=extra_info, warning=warning) + openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( + json_dumps(loaded_content), extra_info=extra_info, warning=warning + ) return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value except ToolNotSupportedError as e: # maybe it's not plugin at all openapi_plugin_error = e - raise ToolApiSchemaError(f'Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)}, openapi plugin error: {str(openapi_plugin_error)}') + raise ToolApiSchemaError( + f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)}, openapi plugin error: {str(openapi_plugin_error)}" + ) diff --git a/api/core/tools/utils/tool_parameter_converter.py b/api/core/tools/utils/tool_parameter_converter.py index 6f88eeaa0a..6f7610651c 100644 --- a/api/core/tools/utils/tool_parameter_converter.py +++ b/api/core/tools/utils/tool_parameter_converter.py @@ -7,16 +7,18 @@ class ToolParameterConverter: @staticmethod def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str: match parameter_type: - case ToolParameter.ToolParameterType.STRING \ - | ToolParameter.ToolParameterType.SECRET_INPUT \ - | ToolParameter.ToolParameterType.SELECT: - return 'string' + case ( + ToolParameter.ToolParameterType.STRING + | ToolParameter.ToolParameterType.SECRET_INPUT + | ToolParameter.ToolParameterType.SELECT + ): + return "string" case ToolParameter.ToolParameterType.BOOLEAN: - return 'boolean' + return "boolean" case ToolParameter.ToolParameterType.NUMBER: - return 'number' + return "number" case _: raise ValueError(f"Unsupported parameter type {parameter_type}") @@ -26,11 +28,13 @@ class ToolParameterConverter: # convert tool parameter config to correct type try: match parameter_type: - case ToolParameter.ToolParameterType.STRING \ - | ToolParameter.ToolParameterType.SECRET_INPUT \ - | ToolParameter.ToolParameterType.SELECT: + case ( + ToolParameter.ToolParameterType.STRING + | ToolParameter.ToolParameterType.SECRET_INPUT + | ToolParameter.ToolParameterType.SELECT + ): if value is None: - return '' + return "" else: return value if isinstance(value, str) else str(value) @@ -41,9 +45,9 @@ class ToolParameterConverter: # Allowed YAML boolean value strings: https://yaml.org/type/bool.html # and also '0' for False and '1' for True match value.lower(): - case 'true' | 'yes' | 'y' | '1': + case "true" | "yes" | "y" | "1": return True - case 'false' | 'no' | 'n' | '0': + case "false" | "no" | "n" | "0": return False case _: return bool(value) @@ -53,8 +57,8 @@ class ToolParameterConverter: case ToolParameter.ToolParameterType.NUMBER: if isinstance(value, int) | isinstance(value, float): return value - elif isinstance(value, str) and value != '': - if '.' in value: + elif isinstance(value, str) and value != "": + if "." in value: return float(value) else: return int(value) diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index 150941924d..3639b5fff7 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -32,7 +32,7 @@ TEXT: def page_result(text: str, cursor: int, max_length: int) -> str: """Page through `text` and return a substring of `max_length` characters starting from `cursor`.""" - return text[cursor: cursor + max_length] + return text[cursor : cursor + max_length] def get_url(url: str, user_agent: str = None) -> str: @@ -49,15 +49,15 @@ def get_url(url: str, user_agent: str = None) -> str: if response.status_code == 200: # check content-type - content_type = response.headers.get('Content-Type') + content_type = response.headers.get("Content-Type") if content_type: - main_content_type = response.headers.get('Content-Type').split(';')[0].strip() + main_content_type = response.headers.get("Content-Type").split(";")[0].strip() else: - content_disposition = response.headers.get('Content-Disposition', '') + content_disposition = response.headers.get("Content-Disposition", "") filename_match = re.search(r'filename="([^"]+)"', content_disposition) if filename_match: filename = unquote(filename_match.group(1)) - extension = re.search(r'\.(\w+)$', filename) + extension = re.search(r"\.(\w+)$", filename) if extension: main_content_type = mimetypes.guess_type(filename)[0] @@ -78,7 +78,7 @@ def get_url(url: str, user_agent: str = None) -> str: # Detect encoding using chardet detected_encoding = chardet.detect(response.content) - encoding = detected_encoding['encoding'] + encoding = detected_encoding["encoding"] if encoding: try: content = response.content.decode(encoding) @@ -89,29 +89,29 @@ def get_url(url: str, user_agent: str = None) -> str: a = extract_using_readabilipy(content) - if not a['plain_text'] or not a['plain_text'].strip(): - return '' + if not a["plain_text"] or not a["plain_text"].strip(): + return "" res = FULL_TEMPLATE.format( - title=a['title'], - authors=a['byline'], - publish_date=a['date'], + title=a["title"], + authors=a["byline"], + publish_date=a["date"], top_image="", - text=a['plain_text'] if a['plain_text'] else "", + text=a["plain_text"] if a["plain_text"] else "", ) return res def extract_using_readabilipy(html): - with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html: + with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html: f_html.write(html) f_html.close() html_path = f_html.name # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file article_json_path = html_path + ".json" - jsdir = os.path.join(find_module_path('readabilipy'), 'javascript') + jsdir = os.path.join(find_module_path("readabilipy"), "javascript") with chdir(jsdir): subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path]) @@ -129,7 +129,7 @@ def extract_using_readabilipy(html): "date": None, "content": None, "plain_content": None, - "plain_text": None + "plain_text": None, } # Populate article fields from readability fields where present if input_json: @@ -145,7 +145,7 @@ def extract_using_readabilipy(html): article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"]) if input_json.get("textContent"): article_json["plain_text"] = input_json["textContent"] - article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"]) + article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"]) return article_json @@ -158,6 +158,7 @@ def find_module_path(module_name): return None + @contextmanager def chdir(path): """Change directory in context and return to original on exit""" @@ -172,12 +173,14 @@ def chdir(path): def extract_text_blocks_as_plain_text(paragraph_html): # Load article as DOM - soup = BeautifulSoup(paragraph_html, 'html.parser') + soup = BeautifulSoup(paragraph_html, "html.parser") # Select all lists - list_elements = soup.find_all(['ul', 'ol']) + list_elements = soup.find_all(["ul", "ol"]) # Prefix text in all list items with "* " and make lists paragraphs for list_element in list_elements: - plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')]))) + plain_items = "".join( + list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")])) + ) list_element.string = plain_items list_element.name = "p" # Select all text blocks @@ -204,7 +207,7 @@ def plain_text_leaf_node(element): def plain_content(readability_content, content_digests, node_indexes): # Load article as DOM - soup = BeautifulSoup(readability_content, 'html.parser') + soup = BeautifulSoup(readability_content, "html.parser") # Make all elements plain elements = plain_elements(soup.contents, content_digests, node_indexes) if node_indexes: @@ -217,8 +220,7 @@ def plain_content(readability_content, content_digests, node_indexes): def plain_elements(elements, content_digests, node_indexes): # Get plain content versions of all elements - elements = [plain_element(element, content_digests, node_indexes) - for element in elements] + elements = [plain_element(element, content_digests, node_indexes) for element in elements] if content_digests: # Add content digest attribute to nodes elements = [add_content_digest(element) for element in elements] @@ -258,11 +260,9 @@ def add_node_indexes(element, node_index="0"): # Add index to current element element["data-node-index"] = node_index # Add index to child elements - for local_idx, child in enumerate( - [c for c in element.contents if not is_text(c)], start=1): + for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1): # Can't add attributes to leaf string types - child_index = "{stem}.{local}".format( - stem=node_index, local=local_idx) + child_index = "{stem}.{local}".format(stem=node_index, local=local_idx) add_node_indexes(child, node_index=child_index) return element @@ -284,11 +284,16 @@ def strip_control_characters(text): # [Cn]: Other, Not Assigned # [Co]: Other, Private Use # [Cs]: Other, Surrogate - control_chars = {'Cc', 'Cf', 'Cn', 'Co', 'Cs'} - retained_chars = ['\t', '\n', '\r', '\f'] + control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"} + retained_chars = ["\t", "\n", "\r", "\f"] # Remove non-printing control characters - return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text]) + return "".join( + [ + "" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char + for char in text + ] + ) def normalize_unicode(text): @@ -305,8 +310,9 @@ def normalize_whitespace(text): text = text.strip() return text + def is_leaf(element): - return (element.name in ['p', 'li']) + return element.name in ["p", "li"] def is_text(element): @@ -330,7 +336,7 @@ def content_digest(element): if trimmed_string == "": digest = "" else: - digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest() + digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest() else: contents = element.contents num_contents = len(contents) @@ -343,9 +349,8 @@ def content_digest(element): else: # Build content digest from the "non-empty" digests of child nodes digest = hashlib.sha256() - child_digests = list( - filter(lambda x: x != "", [content_digest(content) for content in contents])) + child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents])) for child in child_digests: - digest.update(child.encode('utf-8')) + digest.update(child.encode("utf-8")) digest = digest.hexdigest() return digest diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index ff5505bbbf..94d9fd9eb9 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -10,27 +10,25 @@ class WorkflowToolConfigurationUtils: """ for configuration in configurations: if not WorkflowToolParameterConfiguration(**configuration): - raise ValueError('invalid parameter configuration') + raise ValueError("invalid parameter configuration") @classmethod def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]: """ get workflow graph variables """ - nodes = graph.get('nodes', []) - start_node = next(filter(lambda x: x.get('data', {}).get('type') == 'start', nodes), None) + nodes = graph.get("nodes", []) + start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None) if not start_node: return [] - return [ - VariableEntity(**variable) for variable in start_node.get('data', {}).get('variables', []) - ] - + return [VariableEntity(**variable) for variable in start_node.get("data", {}).get("variables", [])] + @classmethod - def check_is_synced(cls, - variables: list[VariableEntity], - tool_configurations: list[WorkflowToolParameterConfiguration]) -> None: + def check_is_synced( + cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration] + ) -> None: """ check is synced @@ -39,10 +37,10 @@ class WorkflowToolConfigurationUtils: variable_names = [variable.variable for variable in variables] if len(tool_configurations) != len(variables): - raise ValueError('parameter configuration mismatch, please republish the tool to update') - + raise ValueError("parameter configuration mismatch, please republish the tool to update") + for parameter in tool_configurations: if parameter.name not in variable_names: - raise ValueError('parameter configuration mismatch, please republish the tool to update') + raise ValueError("parameter configuration mismatch, please republish the tool to update") - return True \ No newline at end of file + return True diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index f751c43096..bcb061376d 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -18,12 +18,12 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any :return: an object of the YAML content """ try: - with open(file_path, encoding='utf-8') as yaml_file: + with open(file_path, encoding="utf-8") as yaml_file: try: yaml_content = yaml.safe_load(yaml_file) return yaml_content if yaml_content else default_value except Exception as e: - raise YAMLError(f'Failed to load YAML file {file_path}: {e}') + raise YAMLError(f"Failed to load YAML file {file_path}: {e}") except Exception as e: if ignore_error: return default_value diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py index 9015eea85c..83086d1afc 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -5,10 +5,7 @@ from core.workflow.graph_engine.entities.event import GraphEngineEvent class WorkflowCallback(ABC): @abstractmethod - def on_event( - self, - event: GraphEngineEvent - ) -> None: + def on_event(self, event: GraphEngineEvent) -> None: """ Published event """ diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/entities/base_node_data_entities.py index e7e6710cbd..2a864dd7a8 100644 --- a/api/core/workflow/entities/base_node_data_entities.py +++ b/api/core/workflow/entities/base_node_data_entities.py @@ -8,9 +8,11 @@ class BaseNodeData(ABC, BaseModel): title: str desc: Optional[str] = None + class BaseIterationNodeData(BaseNodeData): start_node_id: Optional[str] = None + class BaseIterationState(BaseModel): iteration_node_id: str index: int @@ -19,4 +21,4 @@ class BaseIterationState(BaseModel): class MetaData(BaseModel): pass - metadata: MetaData \ No newline at end of file + metadata: MetaData diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 5e2a5cb466..5353b99ed3 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -12,28 +12,28 @@ class NodeType(Enum): Node Types. """ - START = 'start' - END = 'end' - ANSWER = 'answer' - LLM = 'llm' - KNOWLEDGE_RETRIEVAL = 'knowledge-retrieval' - IF_ELSE = 'if-else' - CODE = 'code' - TEMPLATE_TRANSFORM = 'template-transform' - QUESTION_CLASSIFIER = 'question-classifier' - HTTP_REQUEST = 'http-request' - TOOL = 'tool' - VARIABLE_AGGREGATOR = 'variable-aggregator' + START = "start" + END = "end" + ANSWER = "answer" + LLM = "llm" + KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" + IF_ELSE = "if-else" + CODE = "code" + TEMPLATE_TRANSFORM = "template-transform" + QUESTION_CLASSIFIER = "question-classifier" + HTTP_REQUEST = "http-request" + TOOL = "tool" + VARIABLE_AGGREGATOR = "variable-aggregator" # TODO: merge this into VARIABLE_AGGREGATOR - VARIABLE_ASSIGNER = 'variable-assigner' - LOOP = 'loop' - ITERATION = 'iteration' - ITERATION_START = 'iteration-start' # fake start node for iteration - PARAMETER_EXTRACTOR = 'parameter-extractor' - CONVERSATION_VARIABLE_ASSIGNER = 'assigner' + VARIABLE_ASSIGNER = "variable-assigner" + LOOP = "loop" + ITERATION = "iteration" + ITERATION_START = "iteration-start" # fake start node for iteration + PARAMETER_EXTRACTOR = "parameter-extractor" + CONVERSATION_VARIABLE_ASSIGNER = "assigner" @classmethod - def value_of(cls, value: str) -> 'NodeType': + def value_of(cls, value: str) -> "NodeType": """ Get value of given node type. @@ -43,7 +43,7 @@ class NodeType(Enum): for node_type in cls: if node_type.value == value: return node_type - raise ValueError(f'invalid node type value {value}') + raise ValueError(f"invalid node type value {value}") class NodeRunMetadataKey(Enum): @@ -51,16 +51,16 @@ class NodeRunMetadataKey(Enum): Node Run Metadata Key. """ - TOTAL_TOKENS = 'total_tokens' - TOTAL_PRICE = 'total_price' - CURRENCY = 'currency' - TOOL_INFO = 'tool_info' - ITERATION_ID = 'iteration_id' - ITERATION_INDEX = 'iteration_index' - PARALLEL_ID = 'parallel_id' - PARALLEL_START_NODE_ID = 'parallel_start_node_id' - PARENT_PARALLEL_ID = 'parent_parallel_id' - PARENT_PARALLEL_START_NODE_ID = 'parent_parallel_start_node_id' + TOTAL_TOKENS = "total_tokens" + TOTAL_PRICE = "total_price" + CURRENCY = "currency" + TOOL_INFO = "tool_info" + ITERATION_ID = "iteration_id" + ITERATION_INDEX = "iteration_index" + PARALLEL_ID = "parallel_id" + PARALLEL_START_NODE_ID = "parallel_start_node_id" + PARENT_PARALLEL_ID = "parent_parallel_id" + PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id" class NodeRunResult(BaseModel): @@ -85,6 +85,7 @@ class UserFrom(Enum): """ User from """ + ACCOUNT = "account" END_USER = "end-user" diff --git a/api/core/workflow/entities/variable_entities.py b/api/core/workflow/entities/variable_entities.py index 19d9af2a61..1dfb1852f8 100644 --- a/api/core/workflow/entities/variable_entities.py +++ b/api/core/workflow/entities/variable_entities.py @@ -5,5 +5,6 @@ class VariableSelector(BaseModel): """ Variable Selector. """ + variable: str value_selector: list[str] diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 48a20d25ae..b94b7f7198 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -23,23 +23,19 @@ class VariablePool(BaseModel): # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # elements of the selector except the first one. variable_dictionary: dict[str, dict[int, Segment]] = Field( - description='Variables mapping', - default=defaultdict(dict) + description="Variables mapping", default=defaultdict(dict) ) # TODO: This user inputs is not used for pool. user_inputs: Mapping[str, Any] = Field( - description='User inputs', + description="User inputs", ) system_variables: Mapping[SystemVariableKey, Any] = Field( - description='System variables', + description="System variables", ) - environment_variables: Sequence[Variable] = Field( - description="Environment variables.", - default_factory=list - ) + environment_variables: Sequence[Variable] = Field(description="Environment variables.", default_factory=list) conversation_variables: Sequence[Variable] | None = None diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 4bf4e454bb..0a1eb57de4 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -46,13 +46,16 @@ class WorkflowRunState: current_iteration_state: Optional[BaseIterationState] - def __init__(self, workflow: Workflow, - start_at: float, - variable_pool: VariablePool, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - workflow_call_depth: int): + def __init__( + self, + workflow: Workflow, + start_at: float, + variable_pool: VariablePool, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + workflow_call_depth: int, + ): self.workflow_id = workflow.id self.tenant_id = workflow.tenant_id self.app_id = workflow.app_id diff --git a/api/core/workflow/graph_engine/condition_handlers/base_handler.py b/api/core/workflow/graph_engine/condition_handlers/base_handler.py index 4099def4e2..697392b2a3 100644 --- a/api/core/workflow/graph_engine/condition_handlers/base_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/base_handler.py @@ -8,19 +8,13 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta class RunConditionHandler(ABC): - def __init__(self, - init_params: GraphInitParams, - graph: Graph, - condition: RunCondition): + def __init__(self, init_params: GraphInitParams, graph: Graph, condition: RunCondition): self.init_params = init_params self.graph = graph self.condition = condition @abstractmethod - def check(self, - graph_runtime_state: GraphRuntimeState, - previous_route_node_state: RouteNodeState - ) -> bool: + def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool: """ Check if the condition can be executed diff --git a/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py index 705eb908b1..af695df7d8 100644 --- a/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py @@ -4,10 +4,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta class BranchIdentifyRunConditionHandler(RunConditionHandler): - - def check(self, - graph_runtime_state: GraphRuntimeState, - previous_route_node_state: RouteNodeState) -> bool: + def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool: """ Check if the condition can be executed diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py index 1edaf92da7..eda5fe079c 100644 --- a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py @@ -5,10 +5,7 @@ from core.workflow.utils.condition.processor import ConditionProcessor class ConditionRunConditionHandlerHandler(RunConditionHandler): - def check(self, - graph_runtime_state: GraphRuntimeState, - previous_route_node_state: RouteNodeState - ) -> bool: + def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool: """ Check if the condition can be executed @@ -22,8 +19,7 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler): # process condition condition_processor = ConditionProcessor() input_conditions, group_result = condition_processor.process_conditions( - variable_pool=graph_runtime_state.variable_pool, - conditions=self.condition.conditions + variable_pool=graph_runtime_state.variable_pool, conditions=self.condition.conditions ) # Apply the logical operator for the current case diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_manager.py b/api/core/workflow/graph_engine/condition_handlers/condition_manager.py index 2eb2e58bfc..1c9237d82f 100644 --- a/api/core/workflow/graph_engine/condition_handlers/condition_manager.py +++ b/api/core/workflow/graph_engine/condition_handlers/condition_manager.py @@ -9,9 +9,7 @@ from core.workflow.graph_engine.entities.run_condition import RunCondition class ConditionManager: @staticmethod def get_condition_handler( - init_params: GraphInitParams, - graph: Graph, - run_condition: RunCondition + init_params: GraphInitParams, graph: Graph, run_condition: RunCondition ) -> RunConditionHandler: """ Get condition handler @@ -22,14 +20,6 @@ class ConditionManager: :return: condition handler """ if run_condition.type == "branch_identify": - return BranchIdentifyRunConditionHandler( - init_params=init_params, - graph=graph, - condition=run_condition - ) + return BranchIdentifyRunConditionHandler(init_params=init_params, graph=graph, condition=run_condition) else: - return ConditionRunConditionHandlerHandler( - init_params=init_params, - graph=graph, - condition=run_condition - ) + return ConditionRunConditionHandlerHandler(init_params=init_params, graph=graph, condition=run_condition) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 49007b870d..b54595f780 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -34,38 +34,25 @@ class Graph(BaseModel): root_node_id: str = Field(..., description="root node id of the graph") node_ids: list[str] = Field(default_factory=list, description="graph node ids") node_id_config_mapping: dict[str, dict] = Field( - default_factory=list, - description="node configs mapping (node id: node config)" + default_factory=list, description="node configs mapping (node id: node config)" ) edge_mapping: dict[str, list[GraphEdge]] = Field( - default_factory=dict, - description="graph edge mapping (source node id: edges)" + default_factory=dict, description="graph edge mapping (source node id: edges)" ) reverse_edge_mapping: dict[str, list[GraphEdge]] = Field( - default_factory=dict, - description="reverse graph edge mapping (target node id: edges)" + default_factory=dict, description="reverse graph edge mapping (target node id: edges)" ) parallel_mapping: dict[str, GraphParallel] = Field( - default_factory=dict, - description="graph parallel mapping (parallel id: parallel)" + default_factory=dict, description="graph parallel mapping (parallel id: parallel)" ) node_parallel_mapping: dict[str, str] = Field( - default_factory=dict, - description="graph node parallel mapping (node id: parallel id)" - ) - answer_stream_generate_routes: AnswerStreamGenerateRoute = Field( - ..., - description="answer stream generate routes" - ) - end_stream_param: EndStreamParam = Field( - ..., - description="end stream param" + default_factory=dict, description="graph node parallel mapping (node id: parallel id)" ) + answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(..., description="answer stream generate routes") + end_stream_param: EndStreamParam = Field(..., description="end stream param") @classmethod - def init(cls, - graph_config: Mapping[str, Any], - root_node_id: Optional[str] = None) -> "Graph": + def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = None) -> "Graph": """ Init graph @@ -74,7 +61,7 @@ class Graph(BaseModel): :return: graph """ # edge configs - edge_configs = graph_config.get('edges') + edge_configs = graph_config.get("edges") if edge_configs is None: edge_configs = [] @@ -85,14 +72,14 @@ class Graph(BaseModel): reverse_edge_mapping: dict[str, list[GraphEdge]] = {} target_edge_ids = set() for edge_config in edge_configs: - source_node_id = edge_config.get('source') + source_node_id = edge_config.get("source") if not source_node_id: continue if source_node_id not in edge_mapping: edge_mapping[source_node_id] = [] - target_node_id = edge_config.get('target') + target_node_id = edge_config.get("target") if not target_node_id: continue @@ -107,23 +94,18 @@ class Graph(BaseModel): # parse run condition run_condition = None - if edge_config.get('sourceHandle') and edge_config.get('sourceHandle') != 'source': - run_condition = RunCondition( - type='branch_identify', - branch_identify=edge_config.get('sourceHandle') - ) + if edge_config.get("sourceHandle") and edge_config.get("sourceHandle") != "source": + run_condition = RunCondition(type="branch_identify", branch_identify=edge_config.get("sourceHandle")) graph_edge = GraphEdge( - source_node_id=source_node_id, - target_node_id=target_node_id, - run_condition=run_condition + source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition ) edge_mapping[source_node_id].append(graph_edge) reverse_edge_mapping[target_node_id].append(graph_edge) # node configs - node_configs = graph_config.get('nodes') + node_configs = graph_config.get("nodes") if not node_configs: raise ValueError("Graph must have at least one node") @@ -133,7 +115,7 @@ class Graph(BaseModel): root_node_configs = [] all_node_id_config_mapping: dict[str, dict] = {} for node_config in node_configs: - node_id = node_config.get('id') + node_id = node_config.get("id") if not node_id: continue @@ -142,30 +124,29 @@ class Graph(BaseModel): all_node_id_config_mapping[node_id] = node_config - root_node_ids = [node_config.get('id') for node_config in root_node_configs] + root_node_ids = [node_config.get("id") for node_config in root_node_configs] # fetch root node if not root_node_id: # if no root node id, use the START type node as root node - root_node_id = next((node_config.get("id") for node_config in root_node_configs - if node_config.get('data', {}).get('type', '') == NodeType.START.value), None) + root_node_id = next( + ( + node_config.get("id") + for node_config in root_node_configs + if node_config.get("data", {}).get("type", "") == NodeType.START.value + ), + None, + ) if not root_node_id or root_node_id not in root_node_ids: raise ValueError(f"Root node id {root_node_id} not found in the graph") - + # Check whether it is connected to the previous node - cls._check_connected_to_previous_node( - route=[root_node_id], - edge_mapping=edge_mapping - ) + cls._check_connected_to_previous_node(route=[root_node_id], edge_mapping=edge_mapping) # fetch all node ids from root node node_ids = [root_node_id] - cls._recursively_add_node_ids( - node_ids=node_ids, - edge_mapping=edge_mapping, - node_id=root_node_id - ) + cls._recursively_add_node_ids(node_ids=node_ids, edge_mapping=edge_mapping, node_id=root_node_id) node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids} @@ -177,29 +158,26 @@ class Graph(BaseModel): reverse_edge_mapping=reverse_edge_mapping, start_node_id=root_node_id, parallel_mapping=parallel_mapping, - node_parallel_mapping=node_parallel_mapping + node_parallel_mapping=node_parallel_mapping, ) # Check if it exceeds N layers of parallel for parallel in parallel_mapping.values(): if parallel.parent_parallel_id: cls._check_exceed_parallel_limit( - parallel_mapping=parallel_mapping, - level_limit=3, - parent_parallel_id=parallel.parent_parallel_id + parallel_mapping=parallel_mapping, level_limit=3, parent_parallel_id=parallel.parent_parallel_id ) # init answer stream generate routes answer_stream_generate_routes = AnswerStreamGeneratorRouter.init( - node_id_config_mapping=node_id_config_mapping, - reverse_edge_mapping=reverse_edge_mapping + node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping ) # init end stream param end_stream_param = EndStreamGeneratorRouter.init( node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping, - node_parallel_mapping=node_parallel_mapping + node_parallel_mapping=node_parallel_mapping, ) # init graph @@ -212,14 +190,14 @@ class Graph(BaseModel): parallel_mapping=parallel_mapping, node_parallel_mapping=node_parallel_mapping, answer_stream_generate_routes=answer_stream_generate_routes, - end_stream_param=end_stream_param + end_stream_param=end_stream_param, ) return graph - def add_extra_edge(self, source_node_id: str, - target_node_id: str, - run_condition: Optional[RunCondition] = None) -> None: + def add_extra_edge( + self, source_node_id: str, target_node_id: str, run_condition: Optional[RunCondition] = None + ) -> None: """ Add extra edge to the graph @@ -237,9 +215,7 @@ class Graph(BaseModel): return graph_edge = GraphEdge( - source_node_id=source_node_id, - target_node_id=target_node_id, - run_condition=run_condition + source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition ) self.edge_mapping[source_node_id].append(graph_edge) @@ -254,17 +230,18 @@ class Graph(BaseModel): for node_id in self.node_ids: if node_id not in self.edge_mapping: leaf_node_ids.append(node_id) - elif (len(self.edge_mapping[node_id]) == 1 - and self.edge_mapping[node_id][0].target_node_id == self.root_node_id): + elif ( + len(self.edge_mapping[node_id]) == 1 + and self.edge_mapping[node_id][0].target_node_id == self.root_node_id + ): leaf_node_ids.append(node_id) return leaf_node_ids @classmethod - def _recursively_add_node_ids(cls, - node_ids: list[str], - edge_mapping: dict[str, list[GraphEdge]], - node_id: str) -> None: + def _recursively_add_node_ids( + cls, node_ids: list[str], edge_mapping: dict[str, list[GraphEdge]], node_id: str + ) -> None: """ Recursively add node ids @@ -278,17 +255,11 @@ class Graph(BaseModel): node_ids.append(graph_edge.target_node_id) cls._recursively_add_node_ids( - node_ids=node_ids, - edge_mapping=edge_mapping, - node_id=graph_edge.target_node_id + node_ids=node_ids, edge_mapping=edge_mapping, node_id=graph_edge.target_node_id ) @classmethod - def _check_connected_to_previous_node( - cls, - route: list[str], - edge_mapping: dict[str, list[GraphEdge]] - ) -> None: + def _check_connected_to_previous_node(cls, route: list[str], edge_mapping: dict[str, list[GraphEdge]]) -> None: """ Check whether it is connected to the previous node """ @@ -299,7 +270,9 @@ class Graph(BaseModel): continue if graph_edge.target_node_id in route: - raise ValueError(f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph.") + raise ValueError( + f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph." + ) new_route = route[:] new_route.append(graph_edge.target_node_id) @@ -316,7 +289,7 @@ class Graph(BaseModel): start_node_id: str, parallel_mapping: dict[str, GraphParallel], node_parallel_mapping: dict[str, str], - parent_parallel: Optional[GraphParallel] = None + parent_parallel: Optional[GraphParallel] = None, ) -> None: """ Recursively add parallel ids @@ -355,14 +328,14 @@ class Graph(BaseModel): parallel = GraphParallel( start_from_node_id=start_node_id, parent_parallel_id=parent_parallel.id if parent_parallel else None, - parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None + parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None, ) parallel_mapping[parallel.id] = parallel in_branch_node_ids = cls._fetch_all_node_ids_in_parallels( edge_mapping=edge_mapping, reverse_edge_mapping=reverse_edge_mapping, - parallel_branch_node_ids=parallel_branch_node_ids + parallel_branch_node_ids=parallel_branch_node_ids, ) # collect all branches node ids @@ -403,14 +376,25 @@ class Graph(BaseModel): continue if ( - (node_parallel_mapping.get(target_node_id) and node_parallel_mapping.get(target_node_id) == parent_parallel_id) - or (parent_parallel and parent_parallel.end_to_node_id and target_node_id == parent_parallel.end_to_node_id) + ( + node_parallel_mapping.get(target_node_id) + and node_parallel_mapping.get(target_node_id) == parent_parallel_id + ) + or ( + parent_parallel + and parent_parallel.end_to_node_id + and target_node_id == parent_parallel.end_to_node_id + ) or (not node_parallel_mapping.get(target_node_id) and not parent_parallel) ): outside_parallel_target_node_ids.add(target_node_id) if len(outside_parallel_target_node_ids) == 1: - if parent_parallel and parent_parallel.end_to_node_id and parallel.end_to_node_id == parent_parallel.end_to_node_id: + if ( + parent_parallel + and parent_parallel.end_to_node_id + and parallel.end_to_node_id == parent_parallel.end_to_node_id + ): parallel.end_to_node_id = None else: parallel.end_to_node_id = outside_parallel_target_node_ids.pop() @@ -420,18 +404,20 @@ class Graph(BaseModel): if parallel: current_parallel = parallel elif parent_parallel: - if not parent_parallel.end_to_node_id or (parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id): + if not parent_parallel.end_to_node_id or ( + parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id + ): current_parallel = parent_parallel else: # fetch parent parallel's parent parallel parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id if parent_parallel_parent_parallel_id: parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id) - if ( - parent_parallel_parent_parallel - and ( - not parent_parallel_parent_parallel.end_to_node_id - or (parent_parallel_parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id) + if parent_parallel_parent_parallel and ( + not parent_parallel_parent_parallel.end_to_node_id + or ( + parent_parallel_parent_parallel.end_to_node_id + and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id ) ): current_parallel = parent_parallel_parent_parallel @@ -442,7 +428,7 @@ class Graph(BaseModel): start_node_id=graph_edge.target_node_id, parallel_mapping=parallel_mapping, node_parallel_mapping=node_parallel_mapping, - parent_parallel=current_parallel + parent_parallel=current_parallel, ) @classmethod @@ -451,7 +437,7 @@ class Graph(BaseModel): parallel_mapping: dict[str, GraphParallel], level_limit: int, parent_parallel_id: str, - current_level: int = 1 + current_level: int = 1, ) -> None: """ Check if it exceeds N layers of parallel @@ -459,25 +445,27 @@ class Graph(BaseModel): parent_parallel = parallel_mapping.get(parent_parallel_id) if not parent_parallel: return - + current_level += 1 if current_level > level_limit: raise ValueError(f"Exceeds {level_limit} layers of parallel") - + if parent_parallel.parent_parallel_id: cls._check_exceed_parallel_limit( parallel_mapping=parallel_mapping, level_limit=level_limit, parent_parallel_id=parent_parallel.parent_parallel_id, - current_level=current_level + current_level=current_level, ) @classmethod - def _recursively_add_parallel_node_ids(cls, - branch_node_ids: list[str], - edge_mapping: dict[str, list[GraphEdge]], - merge_node_id: str, - start_node_id: str) -> None: + def _recursively_add_parallel_node_ids( + cls, + branch_node_ids: list[str], + edge_mapping: dict[str, list[GraphEdge]], + merge_node_id: str, + start_node_id: str, + ) -> None: """ Recursively add node ids @@ -487,21 +475,22 @@ class Graph(BaseModel): :param start_node_id: start node id """ for graph_edge in edge_mapping.get(start_node_id, []): - if (graph_edge.target_node_id != merge_node_id - and graph_edge.target_node_id not in branch_node_ids): + if graph_edge.target_node_id != merge_node_id and graph_edge.target_node_id not in branch_node_ids: branch_node_ids.append(graph_edge.target_node_id) cls._recursively_add_parallel_node_ids( branch_node_ids=branch_node_ids, edge_mapping=edge_mapping, merge_node_id=merge_node_id, - start_node_id=graph_edge.target_node_id + start_node_id=graph_edge.target_node_id, ) @classmethod - def _fetch_all_node_ids_in_parallels(cls, - edge_mapping: dict[str, list[GraphEdge]], - reverse_edge_mapping: dict[str, list[GraphEdge]], - parallel_branch_node_ids: list[str]) -> dict[str, list[str]]: + def _fetch_all_node_ids_in_parallels( + cls, + edge_mapping: dict[str, list[GraphEdge]], + reverse_edge_mapping: dict[str, list[GraphEdge]], + parallel_branch_node_ids: list[str], + ) -> dict[str, list[str]]: """ Fetch all node ids in parallels """ @@ -513,7 +502,7 @@ class Graph(BaseModel): cls._recursively_fetch_routes( edge_mapping=edge_mapping, start_node_id=parallel_branch_node_id, - routes_node_ids=routes_node_ids[parallel_branch_node_id] + routes_node_ids=routes_node_ids[parallel_branch_node_id], ) # fetch leaf node ids from routes node ids @@ -529,13 +518,13 @@ class Graph(BaseModel): for branch_node_id2, inner_route2 in routes_node_ids.items(): if ( - branch_node_id != branch_node_id2 + branch_node_id != branch_node_id2 and node_id in inner_route2 and len(reverse_edge_mapping.get(node_id, [])) > 1 and cls._is_node_in_routes( reverse_edge_mapping=reverse_edge_mapping, start_node_id=node_id, - routes_node_ids=routes_node_ids + routes_node_ids=routes_node_ids, ) ): if node_id not in merge_branch_node_ids: @@ -551,23 +540,18 @@ class Graph(BaseModel): for node_id, branch_node_ids in merge_branch_node_ids.items(): for node_id2, branch_node_ids2 in merge_branch_node_ids.items(): if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2): - if (node_id, node_id2) not in duplicate_end_node_ids and (node_id2, node_id) not in duplicate_end_node_ids: + if (node_id, node_id2) not in duplicate_end_node_ids and ( + node_id2, + node_id, + ) not in duplicate_end_node_ids: duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids - + for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items(): # check which node is after - if cls._is_node2_after_node1( - node1_id=node_id, - node2_id=node_id2, - edge_mapping=edge_mapping - ): + if cls._is_node2_after_node1(node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping): if node_id in merge_branch_node_ids: del merge_branch_node_ids[node_id2] - elif cls._is_node2_after_node1( - node1_id=node_id2, - node2_id=node_id, - edge_mapping=edge_mapping - ): + elif cls._is_node2_after_node1(node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping): if node_id2 in merge_branch_node_ids: del merge_branch_node_ids[node_id] @@ -599,16 +583,15 @@ class Graph(BaseModel): branch_node_ids=in_branch_node_ids[branch_node_id], edge_mapping=edge_mapping, merge_node_id=merge_node_id, - start_node_id=branch_node_id + start_node_id=branch_node_id, ) return in_branch_node_ids @classmethod - def _recursively_fetch_routes(cls, - edge_mapping: dict[str, list[GraphEdge]], - start_node_id: str, - routes_node_ids: list[str]) -> None: + def _recursively_fetch_routes( + cls, edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: list[str] + ) -> None: """ Recursively fetch route """ @@ -621,28 +604,25 @@ class Graph(BaseModel): routes_node_ids.append(graph_edge.target_node_id) cls._recursively_fetch_routes( - edge_mapping=edge_mapping, - start_node_id=graph_edge.target_node_id, - routes_node_ids=routes_node_ids + edge_mapping=edge_mapping, start_node_id=graph_edge.target_node_id, routes_node_ids=routes_node_ids ) @classmethod - def _is_node_in_routes(cls, - reverse_edge_mapping: dict[str, list[GraphEdge]], - start_node_id: str, - routes_node_ids: dict[str, list[str]]) -> bool: + def _is_node_in_routes( + cls, reverse_edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: dict[str, list[str]] + ) -> bool: """ Recursively check if the node is in the routes """ if start_node_id not in reverse_edge_mapping: return False - + all_routes_node_ids = set() parallel_start_node_ids: dict[str, list[str]] = {} for branch_node_id, node_ids in routes_node_ids.items(): for node_id in node_ids: all_routes_node_ids.add(node_id) - + if branch_node_id in reverse_edge_mapping: for graph_edge in reverse_edge_mapping[branch_node_id]: if graph_edge.source_node_id not in parallel_start_node_ids: @@ -655,38 +635,34 @@ class Graph(BaseModel): if set(branch_node_ids) == set(routes_node_ids.keys()): parallel_start_node_id = p_start_node_id return True - + if not parallel_start_node_id: raise Exception("Parallel start node id not found") - + for graph_edge in reverse_edge_mapping[start_node_id]: - if graph_edge.source_node_id not in all_routes_node_ids or graph_edge.source_node_id != parallel_start_node_id: + if ( + graph_edge.source_node_id not in all_routes_node_ids + or graph_edge.source_node_id != parallel_start_node_id + ): return False - + return True @classmethod - def _is_node2_after_node1( - cls, - node1_id: str, - node2_id: str, - edge_mapping: dict[str, list[GraphEdge]] - ) -> bool: + def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool: """ is node2 after node1 """ if node1_id not in edge_mapping: return False - + for graph_edge in edge_mapping[node1_id]: if graph_edge.target_node_id == node2_id: return True - + if cls._is_node2_after_node1( - node1_id=graph_edge.target_node_id, - node2_id=node2_id, - edge_mapping=edge_mapping + node1_id=graph_edge.target_node_id, node2_id=node2_id, edge_mapping=edge_mapping ): return True - - return False \ No newline at end of file + + return False diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py index c7d484ddf5..afc09bfac5 100644 --- a/api/core/workflow/graph_engine/entities/graph_runtime_state.py +++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py @@ -10,7 +10,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRoute class GraphRuntimeState(BaseModel): variable_pool: VariablePool = Field(..., description="variable pool") """variable pool""" - + start_at: float = Field(..., description="start time") """start time""" total_tokens: int = 0 diff --git a/api/core/workflow/graph_engine/entities/run_condition.py b/api/core/workflow/graph_engine/entities/run_condition.py index 0362343568..eedce8842b 100644 --- a/api/core/workflow/graph_engine/entities/run_condition.py +++ b/api/core/workflow/graph_engine/entities/run_condition.py @@ -18,4 +18,4 @@ class RunCondition(BaseModel): @property def hash(self) -> str: - return hashlib.sha256(self.model_dump_json().encode()).hexdigest() \ No newline at end of file + return hashlib.sha256(self.model_dump_json().encode()).hexdigest() diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index b5d6e4c09d..8fc8047426 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -68,13 +68,11 @@ class RouteNodeState(BaseModel): class RuntimeRouteState(BaseModel): routes: dict[str, list[str]] = Field( - default_factory=dict, - description="graph state routes (source_node_state_id: target_node_state_id)" + default_factory=dict, description="graph state routes (source_node_state_id: target_node_state_id)" ) node_state_mapping: dict[str, RouteNodeState] = Field( - default_factory=dict, - description="node state mapping (route_node_state_id: route_node_state)" + default_factory=dict, description="node state mapping (route_node_state_id: route_node_state)" ) def create_node_state(self, node_id: str) -> RouteNodeState: @@ -99,13 +97,13 @@ class RuntimeRouteState(BaseModel): self.routes[source_node_state_id].append(target_node_state_id) - def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) \ - -> list[RouteNodeState]: + def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) -> list[RouteNodeState]: """ Get routes with node state by source node id :param source_node_state_id: source node state id :return: routes with node state """ - return [self.node_state_mapping[target_state_id] - for target_state_id in self.routes.get(source_node_state_id, [])] + return [ + self.node_state_mapping[target_state_id] for target_state_id in self.routes.get(source_node_state_id, []) + ] diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 65d9ab8446..c6bd122b37 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -48,8 +48,9 @@ logger = logging.getLogger(__name__) class GraphEngineThreadPool(ThreadPoolExecutor): - def __init__(self, max_workers=None, thread_name_prefix='', - initializer=None, initargs=(), max_submit_count=100) -> None: + def __init__( + self, max_workers=None, thread_name_prefix="", initializer=None, initargs=(), max_submit_count=100 + ) -> None: super().__init__(max_workers, thread_name_prefix, initializer, initargs) self.max_submit_count = max_submit_count self.submit_count = 0 @@ -57,9 +58,9 @@ class GraphEngineThreadPool(ThreadPoolExecutor): def submit(self, fn, *args, **kwargs): self.submit_count += 1 self.check_is_full() - + return super().submit(fn, *args, **kwargs) - + def check_is_full(self) -> None: print(f"submit_count: {self.submit_count}, max_submit_count: {self.max_submit_count}") if self.submit_count > self.max_submit_count: @@ -70,21 +71,21 @@ class GraphEngine: workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {} def __init__( - self, - tenant_id: str, - app_id: str, - workflow_type: WorkflowType, - workflow_id: str, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - call_depth: int, - graph: Graph, - graph_config: Mapping[str, Any], - variable_pool: VariablePool, - max_execution_steps: int, - max_execution_time: int, - thread_pool_id: Optional[str] = None + self, + tenant_id: str, + app_id: str, + workflow_type: WorkflowType, + workflow_id: str, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + call_depth: int, + graph: Graph, + graph_config: Mapping[str, Any], + variable_pool: VariablePool, + max_execution_steps: int, + max_execution_time: int, + thread_pool_id: Optional[str] = None, ) -> None: thread_pool_max_submit_count = 100 thread_pool_max_workers = 10 @@ -93,12 +94,14 @@ class GraphEngine: if thread_pool_id: if not thread_pool_id in GraphEngine.workflow_thread_pool_mapping: raise ValueError(f"Max submit count {thread_pool_max_submit_count} of workflow thread pool reached.") - + self.thread_pool_id = thread_pool_id self.thread_pool = GraphEngine.workflow_thread_pool_mapping[thread_pool_id] self.is_main_thread_pool = False else: - self.thread_pool = GraphEngineThreadPool(max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count) + self.thread_pool = GraphEngineThreadPool( + max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count + ) self.thread_pool_id = str(uuid.uuid4()) self.is_main_thread_pool = True GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool @@ -113,13 +116,10 @@ class GraphEngine: user_id=user_id, user_from=user_from, invoke_from=invoke_from, - call_depth=call_depth + call_depth=call_depth, ) - self.graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=time.perf_counter() - ) + self.graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) self.max_execution_steps = max_execution_steps self.max_execution_time = max_execution_time @@ -136,37 +136,40 @@ class GraphEngine: stream_processor_cls = EndStreamProcessor stream_processor = stream_processor_cls( - graph=self.graph, - variable_pool=self.graph_runtime_state.variable_pool + graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool ) # run graph - generator = stream_processor.process( - self._run(start_node_id=self.graph.root_node_id) - ) + generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id)) for item in generator: try: yield item if isinstance(item, NodeRunFailedEvent): - yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or 'Unknown error.') + yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or "Unknown error.") return elif isinstance(item, NodeRunSucceededEvent): if item.node_type == NodeType.END: - self.graph_runtime_state.outputs = (item.route_node_state.node_run_result.outputs - if item.route_node_state.node_run_result - and item.route_node_state.node_run_result.outputs - else {}) + self.graph_runtime_state.outputs = ( + item.route_node_state.node_run_result.outputs + if item.route_node_state.node_run_result + and item.route_node_state.node_run_result.outputs + else {} + ) elif item.node_type == NodeType.ANSWER: if "answer" not in self.graph_runtime_state.outputs: self.graph_runtime_state.outputs["answer"] = "" - self.graph_runtime_state.outputs["answer"] += "\n" + (item.route_node_state.node_run_result.outputs.get("answer", "") - if item.route_node_state.node_run_result - and item.route_node_state.node_run_result.outputs - else "") - - self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs["answer"].strip() + self.graph_runtime_state.outputs["answer"] += "\n" + ( + item.route_node_state.node_run_result.outputs.get("answer", "") + if item.route_node_state.node_run_result + and item.route_node_state.node_run_result.outputs + else "" + ) + + self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs[ + "answer" + ].strip() except Exception as e: logger.exception(f"Graph run failed: {str(e)}") yield GraphRunFailedEvent(error=str(e)) @@ -186,12 +189,12 @@ class GraphEngine: del GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] def _run( - self, - start_node_id: str, - in_parallel_id: Optional[str] = None, - parent_parallel_id: Optional[str] = None, - parent_parallel_start_node_id: Optional[str] = None - ) -> Generator[GraphEngineEvent, None, None]: + self, + start_node_id: str, + in_parallel_id: Optional[str] = None, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None, + ) -> Generator[GraphEngineEvent, None, None]: parallel_start_node_id = None if in_parallel_id: parallel_start_node_id = start_node_id @@ -201,31 +204,28 @@ class GraphEngine: while True: # max steps reached if self.graph_runtime_state.node_run_steps > self.max_execution_steps: - raise GraphRunFailedError('Max steps {} reached.'.format(self.max_execution_steps)) + raise GraphRunFailedError("Max steps {} reached.".format(self.max_execution_steps)) # or max execution time reached if self._is_timed_out( - start_at=self.graph_runtime_state.start_at, - max_execution_time=self.max_execution_time + start_at=self.graph_runtime_state.start_at, max_execution_time=self.max_execution_time ): - raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time)) + raise GraphRunFailedError("Max execution time {}s reached.".format(self.max_execution_time)) # init route node state - route_node_state = self.graph_runtime_state.node_run_state.create_node_state( - node_id=next_node_id - ) + route_node_state = self.graph_runtime_state.node_run_state.create_node_state(node_id=next_node_id) # get node config node_id = route_node_state.node_id node_config = self.graph.node_id_config_mapping.get(node_id) if not node_config: - raise GraphRunFailedError(f'Node {node_id} config not found.') + raise GraphRunFailedError(f"Node {node_id} config not found.") # convert to specific node - node_type = NodeType.value_of(node_config.get('data', {}).get('type')) + node_type = NodeType.value_of(node_config.get("data", {}).get("type")) node_cls = node_classes.get(node_type) if not node_cls: - raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.') + raise GraphRunFailedError(f"Node {node_id} type {node_type} not found.") previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None @@ -237,7 +237,7 @@ class GraphEngine: graph=self.graph, graph_runtime_state=self.graph_runtime_state, previous_node_id=previous_node_id, - thread_pool_id=self.thread_pool_id + thread_pool_id=self.thread_pool_id, ) try: @@ -248,7 +248,7 @@ class GraphEngine: parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) for item in generator: @@ -263,8 +263,7 @@ class GraphEngine: # append route if previous_route_node_state: self.graph_runtime_state.node_run_state.add_route( - source_node_state_id=previous_route_node_state.id, - target_node_state_id=route_node_state.id + source_node_state_id=previous_route_node_state.id, target_node_state_id=route_node_state.id ) except Exception as e: route_node_state.status = RouteNodeState.Status.FAILED @@ -279,13 +278,15 @@ class GraphEngine: parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) raise e # It may not be necessary, but it is necessary. :) - if (self.graph.node_id_config_mapping[next_node_id] - .get("data", {}).get("type", "").lower() == NodeType.END.value): + if ( + self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() + == NodeType.END.value + ): break previous_route_node_state = route_node_state @@ -305,7 +306,7 @@ class GraphEngine: run_condition=edge.run_condition, ).check( graph_runtime_state=self.graph_runtime_state, - previous_route_node_state=previous_route_node_state + previous_route_node_state=previous_route_node_state, ) if not result: @@ -343,14 +344,14 @@ class GraphEngine: if not result: continue - + if len(sub_edge_mappings) == 1: final_node_id = edge.target_node_id else: parallel_generator = self._run_parallel_branches( edge_mappings=sub_edge_mappings, in_parallel_id=in_parallel_id, - parallel_start_node_id=parallel_start_node_id + parallel_start_node_id=parallel_start_node_id, ) for item in parallel_generator: @@ -369,7 +370,7 @@ class GraphEngine: parallel_generator = self._run_parallel_branches( edge_mappings=edge_mappings, in_parallel_id=in_parallel_id, - parallel_start_node_id=parallel_start_node_id + parallel_start_node_id=parallel_start_node_id, ) for item in parallel_generator: @@ -383,14 +384,14 @@ class GraphEngine: next_node_id = final_node_id - if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id: + if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, "") != in_parallel_id: break def _run_parallel_branches( - self, - edge_mappings: list[GraphEdge], - in_parallel_id: Optional[str] = None, - parallel_start_node_id: Optional[str] = None, + self, + edge_mappings: list[GraphEdge], + in_parallel_id: Optional[str] = None, + parallel_start_node_id: Optional[str] = None, ) -> Generator[GraphEngineEvent | str, None, None]: # if nodes has no run conditions, parallel run all nodes parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id) @@ -398,14 +399,18 @@ class GraphEngine: node_id = edge_mappings[0].target_node_id node_config = self.graph.node_id_config_mapping.get(node_id) if not node_config: - raise GraphRunFailedError(f'Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches.') + raise GraphRunFailedError( + f"Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches." + ) - node_title = node_config.get('data', {}).get('title') - raise GraphRunFailedError(f'Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches.') + node_title = node_config.get("data", {}).get("title") + raise GraphRunFailedError( + f"Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches." + ) parallel = self.graph.parallel_mapping.get(parallel_id) if not parallel: - raise GraphRunFailedError(f'Parallel {parallel_id} not found.') + raise GraphRunFailedError(f"Parallel {parallel_id} not found.") # run parallel nodes, run in new thread and use queue to get results q: queue.Queue = queue.Queue() @@ -417,19 +422,22 @@ class GraphEngine: for edge in edge_mappings: if ( edge.target_node_id not in self.graph.node_parallel_mapping - or self.graph.node_parallel_mapping.get(edge.target_node_id, '') != parallel_id + or self.graph.node_parallel_mapping.get(edge.target_node_id, "") != parallel_id ): continue futures.append( - self.thread_pool.submit(self._run_parallel_node, **{ - 'flask_app': current_app._get_current_object(), # type: ignore[attr-defined] - 'q': q, - 'parallel_id': parallel_id, - 'parallel_start_node_id': edge.target_node_id, - 'parent_parallel_id': in_parallel_id, - 'parent_parallel_start_node_id': parallel_start_node_id, - }) + self.thread_pool.submit( + self._run_parallel_node, + **{ + "flask_app": current_app._get_current_object(), # type: ignore[attr-defined] + "q": q, + "parallel_id": parallel_id, + "parallel_start_node_id": edge.target_node_id, + "parent_parallel_id": in_parallel_id, + "parent_parallel_start_node_id": parallel_start_node_id, + }, + ) ) succeeded_count = 0 @@ -451,7 +459,7 @@ class GraphEngine: raise GraphRunFailedError(event.error) except queue.Empty: continue - + # wait all threads wait(futures) @@ -461,72 +469,80 @@ class GraphEngine: yield final_node_id def _run_parallel_node( - self, - flask_app: Flask, - q: queue.Queue, - parallel_id: str, - parallel_start_node_id: str, - parent_parallel_id: Optional[str] = None, - parent_parallel_start_node_id: Optional[str] = None, + self, + flask_app: Flask, + q: queue.Queue, + parallel_id: str, + parallel_start_node_id: str, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None, ) -> None: """ Run parallel nodes """ with flask_app.app_context(): try: - q.put(ParallelBranchRunStartedEvent( - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id - )) + q.put( + ParallelBranchRunStartedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + ) # run node generator = self._run( start_node_id=parallel_start_node_id, in_parallel_id=parallel_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) for item in generator: q.put(item) # trigger graph run success event - q.put(ParallelBranchRunSucceededEvent( - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id - )) + q.put( + ParallelBranchRunSucceededEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + ) except GraphRunFailedError as e: - q.put(ParallelBranchRunFailedEvent( - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id, - error=e.error - )) + q.put( + ParallelBranchRunFailedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + error=e.error, + ) + ) except Exception as e: logger.exception("Unknown Error when generating in parallel") - q.put(ParallelBranchRunFailedEvent( - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id, - error=str(e) - )) + q.put( + ParallelBranchRunFailedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + error=str(e), + ) + ) finally: db.session.remove() def _run_node( - self, - node_instance: BaseNode, - route_node_state: RouteNodeState, - parallel_id: Optional[str] = None, - parallel_start_node_id: Optional[str] = None, - parent_parallel_id: Optional[str] = None, - parent_parallel_start_node_id: Optional[str] = None, + self, + node_instance: BaseNode, + route_node_state: RouteNodeState, + parallel_id: Optional[str] = None, + parallel_start_node_id: Optional[str] = None, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None, ) -> Generator[GraphEngineEvent, None, None]: """ Run node @@ -542,7 +558,7 @@ class GraphEngine: parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) db.session.close() @@ -567,7 +583,7 @@ class GraphEngine: if run_result.status == WorkflowNodeExecutionStatus.FAILED: yield NodeRunFailedEvent( - error=route_node_state.failed_reason or 'Unknown error.', + error=route_node_state.failed_reason or "Unknown error.", id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, @@ -576,7 +592,7 @@ class GraphEngine: parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): @@ -596,7 +612,7 @@ class GraphEngine: self._append_variables_recursively( node_id=node_instance.node_id, variable_key_list=[variable_key], - variable_value=variable_value + variable_value=variable_value, ) # add parallel info to run result metadata @@ -608,7 +624,9 @@ class GraphEngine: run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id if parent_parallel_id and parent_parallel_start_node_id: run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id - run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = parent_parallel_start_node_id + run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( + parent_parallel_start_node_id + ) yield NodeRunSucceededEvent( id=node_instance.id, @@ -619,7 +637,7 @@ class GraphEngine: parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) break @@ -635,7 +653,7 @@ class GraphEngine: parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) elif isinstance(item, RunRetrieverResourceEvent): yield NodeRunRetrieverResourceEvent( @@ -649,7 +667,7 @@ class GraphEngine: parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) except GenerateTaskStoppedException: # trigger node run failed event @@ -665,7 +683,7 @@ class GraphEngine: parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id + parent_parallel_start_node_id=parent_parallel_start_node_id, ) return except Exception as e: @@ -674,10 +692,7 @@ class GraphEngine: finally: db.session.close() - def _append_variables_recursively(self, - node_id: str, - variable_key_list: list[str], - variable_value: VariableValue): + def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): """ Append variables recursively :param node_id: node id @@ -685,10 +700,7 @@ class GraphEngine: :param variable_value: variable value :return: """ - self.graph_runtime_state.variable_pool.add( - [node_id] + variable_key_list, - variable_value - ) + self.graph_runtime_state.variable_pool.add([node_id] + variable_key_list, variable_value) # if variable_value is a dict, then recursively append variables if isinstance(variable_value, dict): @@ -696,9 +708,7 @@ class GraphEngine: # construct new key list new_key_list = variable_key_list + [key] self._append_variables_recursively( - node_id=node_id, - variable_key_list=new_key_list, - variable_value=value + node_id=node_id, variable_key_list=new_key_list, variable_value=value ) def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 8cf01727ec..deacbbbbb0 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -29,14 +29,12 @@ class AnswerNode(BaseNode): # generate routes generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data) - answer = '' + answer = "" for part in generate_routes: if part.type == GenerateRouteChunk.ChunkType.VAR: part = cast(VarGenerateRouteChunk, part) value_selector = part.value_selector - value = self.graph_runtime_state.variable_pool.get( - value_selector - ) + value = self.graph_runtime_state.variable_pool.get(value_selector) if value: answer += value.markdown @@ -44,19 +42,11 @@ class AnswerNode(BaseNode): part = cast(TextGenerateRouteChunk, part) answer += part.text - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - "answer": answer - } - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer}) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: AnswerNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: AnswerNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -73,6 +63,6 @@ class AnswerNode(BaseNode): variable_mapping = {} for variable_selector in variable_selectors: - variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector + variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector return variable_mapping diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index 6cb80091c9..2562b9ce96 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -1,4 +1,3 @@ - from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.workflow.entities.node_entities import NodeType from core.workflow.nodes.answer.entities import ( @@ -12,12 +11,12 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser class AnswerStreamGeneratorRouter: - @classmethod - def init(cls, - node_id_config_mapping: dict[str, dict], - reverse_edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined] - ) -> AnswerStreamGenerateRoute: + def init( + cls, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + ) -> AnswerStreamGenerateRoute: """ Get stream generate routes. :return: @@ -25,7 +24,7 @@ class AnswerStreamGeneratorRouter: # parse stream output node value selectors of answer nodes answer_generate_route: dict[str, list[GenerateRouteChunk]] = {} for answer_node_id, node_config in node_id_config_mapping.items(): - if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value: + if not node_config.get("data", {}).get("type") == NodeType.ANSWER.value: continue # get generate route for stream output @@ -37,12 +36,11 @@ class AnswerStreamGeneratorRouter: answer_dependencies = cls._fetch_answers_dependencies( answer_node_ids=answer_node_ids, reverse_edge_mapping=reverse_edge_mapping, - node_id_config_mapping=node_id_config_mapping + node_id_config_mapping=node_id_config_mapping, ) return AnswerStreamGenerateRoute( - answer_generate_route=answer_generate_route, - answer_dependencies=answer_dependencies + answer_generate_route=answer_generate_route, answer_dependencies=answer_dependencies ) @classmethod @@ -56,8 +54,7 @@ class AnswerStreamGeneratorRouter: variable_selectors = variable_template_parser.extract_variable_selectors() value_selector_mapping = { - variable_selector.variable: variable_selector.value_selector - for variable_selector in variable_selectors + variable_selector.variable: variable_selector.value_selector for variable_selector in variable_selectors } variable_keys = list(value_selector_mapping.keys()) @@ -71,21 +68,17 @@ class AnswerStreamGeneratorRouter: template = node_data.answer for var in variable_keys: - template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω') + template = template.replace(f"{{{{{var}}}}}", f"Ω{{{{{var}}}}}Ω") generate_routes: list[GenerateRouteChunk] = [] - for part in template.split('Ω'): + for part in template.split("Ω"): if part: if cls._is_variable(part, variable_keys): - var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '') + var_key = part.replace("Ω", "").replace("{{", "").replace("}}", "") value_selector = value_selector_mapping[var_key] - generate_routes.append(VarGenerateRouteChunk( - value_selector=value_selector - )) + generate_routes.append(VarGenerateRouteChunk(value_selector=value_selector)) else: - generate_routes.append(TextGenerateRouteChunk( - text=part - )) + generate_routes.append(TextGenerateRouteChunk(text=part)) return generate_routes @@ -101,15 +94,16 @@ class AnswerStreamGeneratorRouter: @classmethod def _is_variable(cls, part, variable_keys): - cleaned_part = part.replace('{{', '').replace('}}', '') - return part.startswith('{{') and cleaned_part in variable_keys + cleaned_part = part.replace("{{", "").replace("}}", "") + return part.startswith("{{") and cleaned_part in variable_keys @classmethod - def _fetch_answers_dependencies(cls, - answer_node_ids: list[str], - reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] - node_id_config_mapping: dict[str, dict] - ) -> dict[str, list[str]]: + def _fetch_answers_dependencies( + cls, + answer_node_ids: list[str], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_id_config_mapping: dict[str, dict], + ) -> dict[str, list[str]]: """ Fetch answer dependencies :param answer_node_ids: answer node ids @@ -127,19 +121,20 @@ class AnswerStreamGeneratorRouter: answer_node_id=answer_node_id, node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping, - answer_dependencies=answer_dependencies + answer_dependencies=answer_dependencies, ) return answer_dependencies @classmethod - def _recursive_fetch_answer_dependencies(cls, - current_node_id: str, - answer_node_id: str, - node_id_config_mapping: dict[str, dict], - reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] - answer_dependencies: dict[str, list[str]] - ) -> None: + def _recursive_fetch_answer_dependencies( + cls, + current_node_id: str, + answer_node_id: str, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + answer_dependencies: dict[str, list[str]], + ) -> None: """ Recursive fetch answer dependencies :param current_node_id: current node id @@ -152,11 +147,11 @@ class AnswerStreamGeneratorRouter: reverse_edges = reverse_edge_mapping.get(current_node_id, []) for edge in reverse_edges: source_node_id = edge.source_node_id - source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type') + source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") if source_node_type in ( - NodeType.ANSWER.value, - NodeType.IF_ELSE.value, - NodeType.QUESTION_CLASSIFIER, + NodeType.ANSWER.value, + NodeType.IF_ELSE.value, + NodeType.QUESTION_CLASSIFIER, ): answer_dependencies[answer_node_id].append(source_node_id) else: @@ -165,5 +160,5 @@ class AnswerStreamGeneratorRouter: answer_node_id=answer_node_id, node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping, - answer_dependencies=answer_dependencies + answer_dependencies=answer_dependencies, ) diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index c2a5dd5163..9776ce5810 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -18,7 +18,6 @@ logger = logging.getLogger(__name__) class AnswerStreamProcessor(StreamProcessor): - def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: super().__init__(graph, variable_pool) self.generate_routes = graph.answer_stream_generate_routes @@ -27,9 +26,7 @@ class AnswerStreamProcessor(StreamProcessor): self.route_position[answer_node_id] = 0 self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} - def process(self, - generator: Generator[GraphEngineEvent, None, None] - ) -> Generator[GraphEngineEvent, None, None]: + def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: for event in generator: if isinstance(event, NodeRunStartedEvent): if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids: @@ -47,9 +44,9 @@ class AnswerStreamProcessor(StreamProcessor): ] else: stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event) - self.current_stream_chunk_generating_node_ids[ - event.route_node_state.node_id - ] = stream_out_answer_node_ids + self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = ( + stream_out_answer_node_ids + ) for _ in stream_out_answer_node_ids: yield event @@ -77,9 +74,9 @@ class AnswerStreamProcessor(StreamProcessor): self.rest_node_ids = self.graph.node_ids.copy() self.current_stream_chunk_generating_node_ids = {} - def _generate_stream_outputs_when_node_finished(self, - event: NodeRunSucceededEvent - ) -> Generator[GraphEngineEvent, None, None]: + def _generate_stream_outputs_when_node_finished( + self, event: NodeRunSucceededEvent + ) -> Generator[GraphEngineEvent, None, None]: """ Generate stream outputs. :param event: node run succeeded event @@ -87,10 +84,13 @@ class AnswerStreamProcessor(StreamProcessor): """ for answer_node_id, position in self.route_position.items(): # all depends on answer node id not in rest node ids - if (event.route_node_state.node_id != answer_node_id - and (answer_node_id not in self.rest_node_ids - or not all(dep_id not in self.rest_node_ids - for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))): + if event.route_node_state.node_id != answer_node_id and ( + answer_node_id not in self.rest_node_ids + or not all( + dep_id not in self.rest_node_ids + for dep_id in self.generate_routes.answer_dependencies[answer_node_id] + ) + ): continue route_position = self.route_position[answer_node_id] @@ -115,9 +115,7 @@ class AnswerStreamProcessor(StreamProcessor): if not value_selector: break - value = self.variable_pool.get( - value_selector - ) + value = self.variable_pool.get(value_selector) if value is None: break @@ -158,8 +156,9 @@ class AnswerStreamProcessor(StreamProcessor): continue # all depends on answer node id not in rest node ids - if all(dep_id not in self.rest_node_ids - for dep_id in self.generate_routes.answer_dependencies[answer_node_id]): + if all( + dep_id not in self.rest_node_ids for dep_id in self.generate_routes.answer_dependencies[answer_node_id] + ): if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]): continue @@ -213,7 +212,7 @@ class AnswerStreamProcessor(StreamProcessor): return None if isinstance(value, dict): - if '__variant' in value and value['__variant'] == FileVar.__name__: + if "__variant" in value and value["__variant"] == FileVar.__name__: return value elif isinstance(value, FileVar): return value.to_dict() diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index cbabbca37d..36c3fe180a 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -7,16 +7,13 @@ from core.workflow.graph_engine.entities.graph import Graph class StreamProcessor(ABC): - def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: self.graph = graph self.variable_pool = variable_pool self.rest_node_ids = graph.node_ids.copy() @abstractmethod - def process(self, - generator: Generator[GraphEngineEvent, None, None] - ) -> Generator[GraphEngineEvent, None, None]: + def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: raise NotImplementedError def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None: @@ -35,9 +32,11 @@ class StreamProcessor(ABC): reachable_node_ids = [] unreachable_first_node_ids = [] for edge in self.graph.edge_mapping[finished_node_id]: - if (edge.run_condition - and edge.run_condition.branch_identify - and run_result.edge_source_handle == edge.run_condition.branch_identify): + if ( + edge.run_condition + and edge.run_condition.branch_identify + and run_result.edge_source_handle == edge.run_condition.branch_identify + ): reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) continue else: diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py index 620c2c426b..e356e7fd70 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -9,6 +9,7 @@ class AnswerNodeData(BaseNodeData): """ Answer Node Data. """ + answer: str = Field(..., description="answer template string") @@ -28,6 +29,7 @@ class VarGenerateRouteChunk(GenerateRouteChunk): """ Var Generate Route Chunk. """ + type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR """generate route chunk type""" value_selector: list[str] = Field(..., description="value selector") @@ -37,6 +39,7 @@ class TextGenerateRouteChunk(GenerateRouteChunk): """ Text Generate Route Chunk. """ + type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT """generate route chunk type""" text: str = Field(..., description="text") @@ -52,11 +55,10 @@ class AnswerStreamGenerateRoute(BaseModel): """ AnswerStreamGenerateRoute entity """ + answer_dependencies: dict[str, list[str]] = Field( - ..., - description="answer dependencies (answer node id -> dependent answer node ids)" + ..., description="answer dependencies (answer node id -> dependent answer node ids)" ) answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field( - ..., - description="answer generate route (answer node id -> generate route chunks)" + ..., description="answer generate route (answer node id -> generate route chunks)" ) diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index b9912314f1..7bfe45a13c 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -15,14 +15,16 @@ class BaseNode(ABC): _node_data_cls: type[BaseNodeData] _node_type: NodeType - def __init__(self, - id: str, - config: Mapping[str, Any], - graph_init_params: GraphInitParams, - graph: Graph, - graph_runtime_state: GraphRuntimeState, - previous_node_id: Optional[str] = None, - thread_pool_id: Optional[str] = None) -> None: + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: GraphInitParams, + graph: Graph, + graph_runtime_state: GraphRuntimeState, + previous_node_id: Optional[str] = None, + thread_pool_id: Optional[str] = None, + ) -> None: self.id = id self.tenant_id = graph_init_params.tenant_id self.app_id = graph_init_params.app_id @@ -46,8 +48,7 @@ class BaseNode(ABC): self.node_data = self._node_data_cls(**config.get("data", {})) @abstractmethod - def _run(self) \ - -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]: + def _run(self) -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]: """ Run node :return: @@ -62,14 +63,14 @@ class BaseNode(ABC): result = self._run() if isinstance(result, NodeRunResult): - yield RunCompletedEvent( - run_result=result - ) + yield RunCompletedEvent(run_result=result) else: yield from result @classmethod - def extract_variable_selector_to_variable_mapping(cls, graph_config: Mapping[str, Any], config: dict) -> Mapping[str, Sequence[str]]: + def extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], config: dict + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping :param graph_config: graph config @@ -82,17 +83,12 @@ class BaseNode(ABC): node_data = cls._node_data_cls(**config.get("data", {})) return cls._extract_variable_selector_to_variable_mapping( - graph_config=graph_config, - node_id=node_id, - node_data=node_data + graph_config=graph_config, node_id=node_id, node_data=node_data ) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: BaseNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: BaseNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 955afdfa1d..4a1787c8c1 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -25,11 +25,10 @@ class CodeNode(BaseNode): """ code_language = CodeLanguage.PYTHON3 if filters: - code_language = (filters.get("code_language", CodeLanguage.PYTHON3)) + code_language = filters.get("code_language", CodeLanguage.PYTHON3) providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider] - code_provider: type[CodeNodeProvider] = next(p for p in providers - if p.is_accept_language(code_language)) + code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language)) return code_provider.get_default_config() @@ -63,17 +62,9 @@ class CodeNode(BaseNode): # Transform result result = self._transform_result(result, node_data.outputs) except (CodeExecutionException, ValueError) as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error=str(e) - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - outputs=result - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) def _check_string(self, value: str, variable: str) -> str: """ @@ -87,12 +78,14 @@ class CodeNode(BaseNode): return None else: raise ValueError(f"Output variable `{variable}` must be a string") - - if len(value) > dify_config.CODE_MAX_STRING_LENGTH: - raise ValueError(f'The length of output variable `{variable}` must be' - f' less than {dify_config.CODE_MAX_STRING_LENGTH} characters') - return value.replace('\x00', '') + if len(value) > dify_config.CODE_MAX_STRING_LENGTH: + raise ValueError( + f"The length of output variable `{variable}` must be" + f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters" + ) + + return value.replace("\x00", "") def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]: """ @@ -108,20 +101,24 @@ class CodeNode(BaseNode): raise ValueError(f"Output variable `{variable}` must be a number") if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER: - raise ValueError(f'Output variable `{variable}` is out of range,' - f' it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}.') + raise ValueError( + f"Output variable `{variable}` is out of range," + f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}." + ) if isinstance(value, float): # raise error if precision is too high - if len(str(value).split('.')[1]) > dify_config.CODE_MAX_PRECISION: - raise ValueError(f'Output variable `{variable}` has too high precision,' - f' it must be less than {dify_config.CODE_MAX_PRECISION} digits.') + if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION: + raise ValueError( + f"Output variable `{variable}` has too high precision," + f" it must be less than {dify_config.CODE_MAX_PRECISION} digits." + ) return value - def _transform_result(self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], - prefix: str = '', - depth: int = 1) -> dict: + def _transform_result( + self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1 + ) -> dict: """ Transform result :param result: result @@ -139,183 +136,181 @@ class CodeNode(BaseNode): self._transform_result( result=output_value, output_schema=None, - prefix=f'{prefix}.{output_name}' if prefix else output_name, - depth=depth + 1 + prefix=f"{prefix}.{output_name}" if prefix else output_name, + depth=depth + 1, ) elif isinstance(output_value, int | float): self._check_number( - value=output_value, - variable=f'{prefix}.{output_name}' if prefix else output_name + value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name ) elif isinstance(output_value, str): self._check_string( - value=output_value, - variable=f'{prefix}.{output_name}' if prefix else output_name + value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name ) elif isinstance(output_value, list): first_element = output_value[0] if len(output_value) > 0 else None if first_element is not None: - if isinstance(first_element, int | float) and all(value is None or isinstance(value, int | float) for value in output_value): + if isinstance(first_element, int | float) and all( + value is None or isinstance(value, int | float) for value in output_value + ): for i, value in enumerate(output_value): self._check_number( value=value, - variable=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]' + variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", ) - elif isinstance(first_element, str) and all(value is None or isinstance(value, str) for value in output_value): + elif isinstance(first_element, str) and all( + value is None or isinstance(value, str) for value in output_value + ): for i, value in enumerate(output_value): self._check_string( value=value, - variable=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]' + variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", ) - elif isinstance(first_element, dict) and all(value is None or isinstance(value, dict) for value in output_value): + elif isinstance(first_element, dict) and all( + value is None or isinstance(value, dict) for value in output_value + ): for i, value in enumerate(output_value): if value is not None: self._transform_result( result=value, output_schema=None, - prefix=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]', - depth=depth + 1 + prefix=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", + depth=depth + 1, ) else: - raise ValueError(f'Output {prefix}.{output_name} is not a valid array. make sure all elements are of the same type.') + raise ValueError( + f"Output {prefix}.{output_name} is not a valid array. make sure all elements are of the same type." + ) elif isinstance(output_value, type(None)): pass else: - raise ValueError(f'Output {prefix}.{output_name} is not a valid type.') - + raise ValueError(f"Output {prefix}.{output_name} is not a valid type.") + return result parameters_validated = {} for output_name, output_config in output_schema.items(): - dot = '.' if prefix else '' + dot = "." if prefix else "" if output_name not in result: - raise ValueError(f'Output {prefix}{dot}{output_name} is missing.') - - if output_config.type == 'object': + raise ValueError(f"Output {prefix}{dot}{output_name} is missing.") + + if output_config.type == "object": # check if output is object if not isinstance(result.get(output_name), dict): if isinstance(result.get(output_name), type(None)): transformed_result[output_name] = None else: raise ValueError( - f'Output {prefix}{dot}{output_name} is not an object, got {type(result.get(output_name))} instead.' + f"Output {prefix}{dot}{output_name} is not an object, got {type(result.get(output_name))} instead." ) else: transformed_result[output_name] = self._transform_result( result=result[output_name], output_schema=output_config.children, - prefix=f'{prefix}.{output_name}', - depth=depth + 1 + prefix=f"{prefix}.{output_name}", + depth=depth + 1, ) - elif output_config.type == 'number': + elif output_config.type == "number": # check if number available transformed_result[output_name] = self._check_number( - value=result[output_name], - variable=f'{prefix}{dot}{output_name}' + value=result[output_name], variable=f"{prefix}{dot}{output_name}" ) - elif output_config.type == 'string': + elif output_config.type == "string": # check if string available transformed_result[output_name] = self._check_string( value=result[output_name], - variable=f'{prefix}{dot}{output_name}', + variable=f"{prefix}{dot}{output_name}", ) - elif output_config.type == 'array[number]': + elif output_config.type == "array[number]": # check if array of number available if not isinstance(result[output_name], list): if isinstance(result[output_name], type(None)): transformed_result[output_name] = None else: raise ValueError( - f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' + f"Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead." ) else: if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH: raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be' - f' less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements.' + f"The length of output variable `{prefix}{dot}{output_name}` must be" + f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements." ) transformed_result[output_name] = [ - self._check_number( - value=value, - variable=f'{prefix}{dot}{output_name}[{i}]' - ) + self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]") for i, value in enumerate(result[output_name]) ] - elif output_config.type == 'array[string]': + elif output_config.type == "array[string]": # check if array of string available if not isinstance(result[output_name], list): if isinstance(result[output_name], type(None)): transformed_result[output_name] = None else: raise ValueError( - f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' + f"Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead." ) else: if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH: raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be' - f' less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements.' + f"The length of output variable `{prefix}{dot}{output_name}` must be" + f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements." ) transformed_result[output_name] = [ - self._check_string( - value=value, - variable=f'{prefix}{dot}{output_name}[{i}]' - ) + self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]") for i, value in enumerate(result[output_name]) ] - elif output_config.type == 'array[object]': + elif output_config.type == "array[object]": # check if array of object available if not isinstance(result[output_name], list): if isinstance(result[output_name], type(None)): transformed_result[output_name] = None else: raise ValueError( - f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' + f"Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead." ) else: if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH: raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be' - f' less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements.' + f"The length of output variable `{prefix}{dot}{output_name}` must be" + f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements." ) - + for i, value in enumerate(result[output_name]): if not isinstance(value, dict): if isinstance(value, type(None)): pass else: raise ValueError( - f'Output {prefix}{dot}{output_name}[{i}] is not an object, got {type(value)} instead at index {i}.' + f"Output {prefix}{dot}{output_name}[{i}] is not an object, got {type(value)} instead at index {i}." ) transformed_result[output_name] = [ - None if value is None else self._transform_result( + None + if value is None + else self._transform_result( result=value, output_schema=output_config.children, - prefix=f'{prefix}{dot}{output_name}[{i}]', - depth=depth + 1 + prefix=f"{prefix}{dot}{output_name}[{i}]", + depth=depth + 1, ) for i, value in enumerate(result[output_name]) ] else: - raise ValueError(f'Output type {output_config.type} is not supported.') - + raise ValueError(f"Output type {output_config.type} is not supported.") + parameters_validated[output_name] = True # check if all output parameters are validated if len(parameters_validated) != len(result): - raise ValueError('Not all output parameters are validated.') + raise ValueError("Not all output parameters are validated.") return transformed_result @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: CodeNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: CodeNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -325,5 +320,6 @@ class CodeNode(BaseNode): :return: """ return { - node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables + node_id + "." + variable_selector.variable: variable_selector.value_selector + for variable_selector in node_data.variables } diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index c0701ecccd..5eb0e0f63f 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -11,9 +11,10 @@ class CodeNodeData(BaseNodeData): """ Code Node Data. """ + class Output(BaseModel): - type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]'] - children: Optional[dict[str, 'Output']] = None + type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] + children: Optional[dict[str, "Output"]] = None class Dependency(BaseModel): name: str @@ -23,4 +24,4 @@ class CodeNodeData(BaseNodeData): code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT] code: str outputs: dict[str, Output] - dependencies: Optional[list[Dependency]] = None \ No newline at end of file + dependencies: Optional[list[Dependency]] = None diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 552914b308..7b78d67be8 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -25,18 +25,11 @@ class EndNode(BaseNode): value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) outputs[variable_selector.variable] = value - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=outputs, - outputs=outputs - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=outputs, outputs=outputs) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: EndNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: EndNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py index 8390f6d81b..30ce8fe018 100644 --- a/api/core/workflow/nodes/end/end_stream_generate_router.py +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -3,13 +3,13 @@ from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam class EndStreamGeneratorRouter: - @classmethod - def init(cls, - node_id_config_mapping: dict[str, dict], - reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] - node_parallel_mapping: dict[str, str] - ) -> EndStreamParam: + def init( + cls, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_parallel_mapping: dict[str, str], + ) -> EndStreamParam: """ Get stream generate routes. :return: @@ -17,7 +17,7 @@ class EndStreamGeneratorRouter: # parse stream output node value selector of end nodes end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {} for end_node_id, node_config in node_id_config_mapping.items(): - if not node_config.get('data', {}).get('type') == NodeType.END.value: + if not node_config.get("data", {}).get("type") == NodeType.END.value: continue # skip end node in parallel @@ -33,18 +33,18 @@ class EndStreamGeneratorRouter: end_dependencies = cls._fetch_ends_dependencies( end_node_ids=end_node_ids, reverse_edge_mapping=reverse_edge_mapping, - node_id_config_mapping=node_id_config_mapping + node_id_config_mapping=node_id_config_mapping, ) return EndStreamParam( end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping, - end_dependencies=end_dependencies + end_dependencies=end_dependencies, ) @classmethod - def extract_stream_variable_selector_from_node_data(cls, - node_id_config_mapping: dict[str, dict], - node_data: EndNodeData) -> list[list[str]]: + def extract_stream_variable_selector_from_node_data( + cls, node_id_config_mapping: dict[str, dict], node_data: EndNodeData + ) -> list[list[str]]: """ Extract stream variable selector from node data :param node_id_config_mapping: node id config mapping @@ -59,21 +59,22 @@ class EndStreamGeneratorRouter: continue node_id = variable_selector.value_selector[0] - if node_id != 'sys' and node_id in node_id_config_mapping: + if node_id != "sys" and node_id in node_id_config_mapping: node = node_id_config_mapping[node_id] - node_type = node.get('data', {}).get('type') + node_type = node.get("data", {}).get("type") if ( variable_selector.value_selector not in value_selectors - and node_type == NodeType.LLM.value - and variable_selector.value_selector[1] == 'text' + and node_type == NodeType.LLM.value + and variable_selector.value_selector[1] == "text" ): value_selectors.append(variable_selector.value_selector) return value_selectors @classmethod - def _extract_stream_variable_selector(cls, node_id_config_mapping: dict[str, dict], config: dict) \ - -> list[list[str]]: + def _extract_stream_variable_selector( + cls, node_id_config_mapping: dict[str, dict], config: dict + ) -> list[list[str]]: """ Extract stream variable selector from node config :param node_id_config_mapping: node id config mapping @@ -84,11 +85,12 @@ class EndStreamGeneratorRouter: return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data) @classmethod - def _fetch_ends_dependencies(cls, - end_node_ids: list[str], - reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] - node_id_config_mapping: dict[str, dict] - ) -> dict[str, list[str]]: + def _fetch_ends_dependencies( + cls, + end_node_ids: list[str], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_id_config_mapping: dict[str, dict], + ) -> dict[str, list[str]]: """ Fetch end dependencies :param end_node_ids: end node ids @@ -106,20 +108,21 @@ class EndStreamGeneratorRouter: end_node_id=end_node_id, node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping, - end_dependencies=end_dependencies + end_dependencies=end_dependencies, ) return end_dependencies @classmethod - def _recursive_fetch_end_dependencies(cls, - current_node_id: str, - end_node_id: str, - node_id_config_mapping: dict[str, dict], - reverse_edge_mapping: dict[str, list["GraphEdge"]], - # type: ignore[name-defined] - end_dependencies: dict[str, list[str]] - ) -> None: + def _recursive_fetch_end_dependencies( + cls, + current_node_id: str, + end_node_id: str, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], + # type: ignore[name-defined] + end_dependencies: dict[str, list[str]], + ) -> None: """ Recursive fetch end dependencies :param current_node_id: current node id @@ -132,10 +135,10 @@ class EndStreamGeneratorRouter: reverse_edges = reverse_edge_mapping.get(current_node_id, []) for edge in reverse_edges: source_node_id = edge.source_node_id - source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type') + source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") if source_node_type in ( - NodeType.IF_ELSE.value, - NodeType.QUESTION_CLASSIFIER, + NodeType.IF_ELSE.value, + NodeType.QUESTION_CLASSIFIER, ): end_dependencies[end_node_id].append(source_node_id) else: @@ -144,5 +147,5 @@ class EndStreamGeneratorRouter: end_node_id=end_node_id, node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping, - end_dependencies=end_dependencies + end_dependencies=end_dependencies, ) diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py index 4474c2a78a..0366d7965d 100644 --- a/api/core/workflow/nodes/end/end_stream_processor.py +++ b/api/core/workflow/nodes/end/end_stream_processor.py @@ -15,7 +15,6 @@ logger = logging.getLogger(__name__) class EndStreamProcessor(StreamProcessor): - def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: super().__init__(graph, variable_pool) self.end_stream_param = graph.end_stream_param @@ -26,9 +25,7 @@ class EndStreamProcessor(StreamProcessor): self.has_outputed = False self.outputed_node_ids = set() - def process(self, - generator: Generator[GraphEngineEvent, None, None] - ) -> Generator[GraphEngineEvent, None, None]: + def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: for event in generator: if isinstance(event, NodeRunStartedEvent): if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids: @@ -38,7 +35,7 @@ class EndStreamProcessor(StreamProcessor): elif isinstance(event, NodeRunStreamChunkEvent): if event.in_iteration_id: if self.has_outputed and event.node_id not in self.outputed_node_ids: - event.chunk_content = '\n' + event.chunk_content + event.chunk_content = "\n" + event.chunk_content self.outputed_node_ids.add(event.node_id) self.has_outputed = True @@ -51,13 +48,13 @@ class EndStreamProcessor(StreamProcessor): ] else: stream_out_end_node_ids = self._get_stream_out_end_node_ids(event) - self.current_stream_chunk_generating_node_ids[ - event.route_node_state.node_id - ] = stream_out_end_node_ids + self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = ( + stream_out_end_node_ids + ) if stream_out_end_node_ids: if self.has_outputed and event.node_id not in self.outputed_node_ids: - event.chunk_content = '\n' + event.chunk_content + event.chunk_content = "\n" + event.chunk_content self.outputed_node_ids.add(event.node_id) self.has_outputed = True @@ -86,9 +83,9 @@ class EndStreamProcessor(StreamProcessor): self.rest_node_ids = self.graph.node_ids.copy() self.current_stream_chunk_generating_node_ids = {} - def _generate_stream_outputs_when_node_finished(self, - event: NodeRunSucceededEvent - ) -> Generator[GraphEngineEvent, None, None]: + def _generate_stream_outputs_when_node_finished( + self, event: NodeRunSucceededEvent + ) -> Generator[GraphEngineEvent, None, None]: """ Generate stream outputs. :param event: node run succeeded event @@ -96,10 +93,12 @@ class EndStreamProcessor(StreamProcessor): """ for end_node_id, position in self.route_position.items(): # all depends on end node id not in rest node ids - if (event.route_node_state.node_id != end_node_id - and (end_node_id not in self.rest_node_ids - or not all(dep_id not in self.rest_node_ids - for dep_id in self.end_stream_param.end_dependencies[end_node_id]))): + if event.route_node_state.node_id != end_node_id and ( + end_node_id not in self.rest_node_ids + or not all( + dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id] + ) + ): continue route_position = self.route_position[end_node_id] @@ -116,9 +115,7 @@ class EndStreamProcessor(StreamProcessor): if not value_selector: continue - value = self.variable_pool.get( - value_selector - ) + value = self.variable_pool.get(value_selector) if value is None: break @@ -128,7 +125,7 @@ class EndStreamProcessor(StreamProcessor): if text: current_node_id = value_selector[0] if self.has_outputed and current_node_id not in self.outputed_node_ids: - text = '\n' + text + text = "\n" + text self.outputed_node_ids.add(current_node_id) self.has_outputed = True @@ -165,8 +162,7 @@ class EndStreamProcessor(StreamProcessor): continue # all depends on end node id not in rest node ids - if all(dep_id not in self.rest_node_ids - for dep_id in self.end_stream_param.end_dependencies[end_node_id]): + if all(dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id]): if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]): continue @@ -178,7 +174,7 @@ class EndStreamProcessor(StreamProcessor): break position += 1 - + if not value_selector: continue diff --git a/api/core/workflow/nodes/end/entities.py b/api/core/workflow/nodes/end/entities.py index a0edf7b579..c3270ac22a 100644 --- a/api/core/workflow/nodes/end/entities.py +++ b/api/core/workflow/nodes/end/entities.py @@ -8,6 +8,7 @@ class EndNodeData(BaseNodeData): """ END Node Data. """ + outputs: list[VariableSelector] @@ -15,11 +16,10 @@ class EndStreamParam(BaseModel): """ EndStreamParam entity """ + end_dependencies: dict[str, list[str]] = Field( - ..., - description="end dependencies (end node id -> dependent node ids)" + ..., description="end dependencies (end node id -> dependent node ids)" ) end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field( - ..., - description="end stream variable selector mapping (end node id -> stream variable selectors)" + ..., description="end stream variable selector mapping (end node id -> stream variable selectors)" ) diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index c066d469d8..66dd1f2dc6 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -7,32 +7,32 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData class HttpRequestNodeAuthorizationConfig(BaseModel): - type: Literal[None, 'basic', 'bearer', 'custom'] + type: Literal[None, "basic", "bearer", "custom"] api_key: Union[None, str] = None header: Union[None, str] = None class HttpRequestNodeAuthorization(BaseModel): - type: Literal['no-auth', 'api-key'] + type: Literal["no-auth", "api-key"] config: Optional[HttpRequestNodeAuthorizationConfig] = None - @field_validator('config', mode='before') + @field_validator("config", mode="before") @classmethod def check_config(cls, v: HttpRequestNodeAuthorizationConfig, values: ValidationInfo): """ Check config, if type is no-auth, config should be None, otherwise it should be a dict. """ - if values.data['type'] == 'no-auth': + if values.data["type"] == "no-auth": return None else: if not v or not isinstance(v, dict): - raise ValueError('config should be a dict') + raise ValueError("config should be a dict") return v class HttpRequestNodeBody(BaseModel): - type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw-text', 'json'] + type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json"] data: Union[None, str] = None @@ -47,7 +47,7 @@ class HttpRequestNodeData(BaseNodeData): Code Node Data. """ - method: Literal['get', 'post', 'put', 'patch', 'delete', 'head'] + method: Literal["get", "post", "put", "patch", "delete", "head"] url: str authorization: HttpRequestNodeAuthorization headers: str diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index d16bff58bd..49102dc3ab 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -33,12 +33,12 @@ class HttpExecutorResponse: check if response is file """ content_type = self.get_content_type() - file_content_types = ['image', 'audio', 'video'] + file_content_types = ["image", "audio", "video"] return any(v in content_type for v in file_content_types) def get_content_type(self) -> str: - return self.headers.get('content-type', '') + return self.headers.get("content-type", "") def extract_file(self) -> tuple[str, bytes]: """ @@ -47,28 +47,28 @@ class HttpExecutorResponse: if self.is_file: return self.get_content_type(), self.body - return '', b'' + return "", b"" @property def content(self) -> str: if isinstance(self.response, httpx.Response): return self.response.text else: - raise ValueError(f'Invalid response type {type(self.response)}') + raise ValueError(f"Invalid response type {type(self.response)}") @property def body(self) -> bytes: if isinstance(self.response, httpx.Response): return self.response.content else: - raise ValueError(f'Invalid response type {type(self.response)}') + raise ValueError(f"Invalid response type {type(self.response)}") @property def status_code(self) -> int: if isinstance(self.response, httpx.Response): return self.response.status_code else: - raise ValueError(f'Invalid response type {type(self.response)}') + raise ValueError(f"Invalid response type {type(self.response)}") @property def size(self) -> int: @@ -77,11 +77,11 @@ class HttpExecutorResponse: @property def readable_size(self) -> str: if self.size < 1024: - return f'{self.size} bytes' + return f"{self.size} bytes" elif self.size < 1024 * 1024: - return f'{(self.size / 1024):.2f} KB' + return f"{(self.size / 1024):.2f} KB" else: - return f'{(self.size / 1024 / 1024):.2f} MB' + return f"{(self.size / 1024 / 1024):.2f} MB" class HttpExecutor: @@ -120,7 +120,7 @@ class HttpExecutor: """ check if body is json """ - if body and body.type == 'json' and body.data: + if body and body.type == "json" and body.data: try: json.loads(body.data) return True @@ -134,15 +134,15 @@ class HttpExecutor: """ Convert the string like `aa:bb\n cc:dd` to dict `{aa:bb, cc:dd}` """ - kv_paris = convert_text.split('\n') + kv_paris = convert_text.split("\n") result = {} for kv in kv_paris: if not kv.strip(): continue - kv = kv.split(':', maxsplit=1) + kv = kv.split(":", maxsplit=1) if len(kv) == 1: - k, v = kv[0], '' + k, v = kv[0], "" else: k, v = kv result[k.strip()] = v @@ -166,31 +166,31 @@ class HttpExecutor: # check if it's a valid JSON is_valid_json = self._is_json_body(node_data.body) - body_data = node_data.body.data or '' + body_data = node_data.body.data or "" if body_data: body_data, body_data_variable_selectors = self._format_template(body_data, variable_pool, is_valid_json) - content_type_is_set = any(key.lower() == 'content-type' for key in self.headers) - if node_data.body.type == 'json' and not content_type_is_set: - self.headers['Content-Type'] = 'application/json' - elif node_data.body.type == 'x-www-form-urlencoded' and not content_type_is_set: - self.headers['Content-Type'] = 'application/x-www-form-urlencoded' + content_type_is_set = any(key.lower() == "content-type" for key in self.headers) + if node_data.body.type == "json" and not content_type_is_set: + self.headers["Content-Type"] = "application/json" + elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set: + self.headers["Content-Type"] = "application/x-www-form-urlencoded" - if node_data.body.type in ['form-data', 'x-www-form-urlencoded']: + if node_data.body.type in ["form-data", "x-www-form-urlencoded"]: body = self._to_dict(body_data) - if node_data.body.type == 'form-data': - self.files = {k: ('', v) for k, v in body.items()} - random_str = lambda n: ''.join([chr(randint(97, 122)) for _ in range(n)]) - self.boundary = f'----WebKitFormBoundary{random_str(16)}' + if node_data.body.type == "form-data": + self.files = {k: ("", v) for k, v in body.items()} + random_str = lambda n: "".join([chr(randint(97, 122)) for _ in range(n)]) + self.boundary = f"----WebKitFormBoundary{random_str(16)}" - self.headers['Content-Type'] = f'multipart/form-data; boundary={self.boundary}' + self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}" else: self.body = urlencode(body) - elif node_data.body.type in ['json', 'raw-text']: + elif node_data.body.type in ["json", "raw-text"]: self.body = body_data - elif node_data.body.type == 'none': - self.body = '' + elif node_data.body.type == "none": + self.body = "" self.variable_selectors = ( server_url_variable_selectors @@ -202,23 +202,23 @@ class HttpExecutor: def _assembling_headers(self) -> dict[str, Any]: authorization = deepcopy(self.authorization) headers = deepcopy(self.headers) or {} - if self.authorization.type == 'api-key': + if self.authorization.type == "api-key": if self.authorization.config is None: - raise ValueError('self.authorization config is required') + raise ValueError("self.authorization config is required") if authorization.config is None: - raise ValueError('authorization config is required') + raise ValueError("authorization config is required") if self.authorization.config.api_key is None: - raise ValueError('api_key is required') + raise ValueError("api_key is required") if not authorization.config.header: - authorization.config.header = 'Authorization' + authorization.config.header = "Authorization" - if self.authorization.config.type == 'bearer': - headers[authorization.config.header] = f'Bearer {authorization.config.api_key}' - elif self.authorization.config.type == 'basic': - headers[authorization.config.header] = f'Basic {authorization.config.api_key}' - elif self.authorization.config.type == 'custom': + if self.authorization.config.type == "bearer": + headers[authorization.config.header] = f"Bearer {authorization.config.api_key}" + elif self.authorization.config.type == "basic": + headers[authorization.config.header] = f"Basic {authorization.config.api_key}" + elif self.authorization.config.type == "custom": headers[authorization.config.header] = authorization.config.api_key return headers @@ -230,10 +230,13 @@ class HttpExecutor: if isinstance(response, httpx.Response): executor_response = HttpExecutorResponse(response) else: - raise ValueError(f'Invalid response type {type(response)}') + raise ValueError(f"Invalid response type {type(response)}") - threshold_size = dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE if executor_response.is_file \ + threshold_size = ( + dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE + if executor_response.is_file else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE + ) if executor_response.size > threshold_size: raise ValueError( f'{"File" if executor_response.is_file else "Text"} size is too large,' @@ -248,17 +251,17 @@ class HttpExecutor: do http request depending on api bundle """ kwargs = { - 'url': self.server_url, - 'headers': headers, - 'params': self.params, - 'timeout': (self.timeout.connect, self.timeout.read, self.timeout.write), - 'follow_redirects': True, + "url": self.server_url, + "headers": headers, + "params": self.params, + "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), + "follow_redirects": True, } - if self.method in ('get', 'head', 'post', 'put', 'delete', 'patch'): + if self.method in ("get", "head", "post", "put", "delete", "patch"): response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs) else: - raise ValueError(f'Invalid http method {self.method}') + raise ValueError(f"Invalid http method {self.method}") return response def invoke(self) -> HttpExecutorResponse: @@ -280,15 +283,15 @@ class HttpExecutor: """ server_url = self.server_url if self.params: - server_url += f'?{urlencode(self.params)}' + server_url += f"?{urlencode(self.params)}" - raw_request = f'{self.method.upper()} {server_url} HTTP/1.1\n' + raw_request = f"{self.method.upper()} {server_url} HTTP/1.1\n" headers = self._assembling_headers() for k, v in headers.items(): # get authorization header - if self.authorization.type == 'api-key': - authorization_header = 'Authorization' + if self.authorization.type == "api-key": + authorization_header = "Authorization" if self.authorization.config and self.authorization.config.header: authorization_header = self.authorization.config.header @@ -296,21 +299,21 @@ class HttpExecutor: raw_request += f'{k}: {"*" * len(v)}\n' continue - raw_request += f'{k}: {v}\n' + raw_request += f"{k}: {v}\n" - raw_request += '\n' + raw_request += "\n" # if files, use multipart/form-data with boundary if self.files: boundary = self.boundary - raw_request += f'--{boundary}' + raw_request += f"--{boundary}" for k, v in self.files.items(): raw_request += f'\nContent-Disposition: form-data; name="{k}"\n\n' - raw_request += f'{v[1]}\n' - raw_request += f'--{boundary}' - raw_request += '--' + raw_request += f"{v[1]}\n" + raw_request += f"--{boundary}" + raw_request += "--" else: - raw_request += self.body or '' + raw_request += self.body or "" return raw_request @@ -328,9 +331,9 @@ class HttpExecutor: for variable_selector in variable_selectors: variable = variable_pool.get_any(variable_selector.value_selector) if variable is None: - raise ValueError(f'Variable {variable_selector.variable} not found') + raise ValueError(f"Variable {variable_selector.variable} not found") if escape_quotes and isinstance(variable, str): - value = variable.replace('"', '\\"').replace('\n', '\\n') + value = variable.replace('"', '\\"').replace("\n", "\\n") else: value = variable variable_value_mapping[variable_selector.variable] = value diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index 3f68c8b1d0..cd40819126 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -31,18 +31,18 @@ class HttpRequestNode(BaseNode): @classmethod def get_default_config(cls, filters: dict | None = None) -> dict: return { - 'type': 'http-request', - 'config': { - 'method': 'get', - 'authorization': { - 'type': 'no-auth', + "type": "http-request", + "config": { + "method": "get", + "authorization": { + "type": "no-auth", }, - 'body': {'type': 'none'}, - 'timeout': { + "body": {"type": "none"}, + "timeout": { **HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(), - 'max_connect_timeout': dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, - 'max_read_timeout': dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, - 'max_write_timeout': dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, + "max_connect_timeout": dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, + "max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, + "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, }, }, } @@ -52,9 +52,8 @@ class HttpRequestNode(BaseNode): # TODO: Switch to use segment directly if node_data.authorization.config and node_data.authorization.config.api_key: node_data.authorization.config.api_key = parser.convert_template( - template=node_data.authorization.config.api_key, - variable_pool=self.graph_runtime_state.variable_pool - ).text + template=node_data.authorization.config.api_key, variable_pool=self.graph_runtime_state.variable_pool + ).text # init http executor http_executor = None @@ -62,7 +61,7 @@ class HttpRequestNode(BaseNode): http_executor = HttpExecutor( node_data=node_data, timeout=self._get_request_timeout(node_data), - variable_pool=self.graph_runtime_state.variable_pool + variable_pool=self.graph_runtime_state.variable_pool, ) # invoke http executor @@ -71,7 +70,7 @@ class HttpRequestNode(BaseNode): process_data = {} if http_executor: process_data = { - 'request': http_executor.to_raw_request(), + "request": http_executor.to_raw_request(), } return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -84,13 +83,13 @@ class HttpRequestNode(BaseNode): return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={ - 'status_code': response.status_code, - 'body': response.content if not files else '', - 'headers': response.headers, - 'files': files, + "status_code": response.status_code, + "body": response.content if not files else "", + "headers": response.headers, + "files": files, }, process_data={ - 'request': http_executor.to_raw_request(), + "request": http_executor.to_raw_request(), }, ) @@ -107,10 +106,7 @@ class HttpRequestNode(BaseNode): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: HttpRequestNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: HttpRequestNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -126,11 +122,11 @@ class HttpRequestNode(BaseNode): variable_mapping = {} for variable_selector in variable_selectors: - variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector + variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector return variable_mapping except Exception as e: - logging.exception(f'Failed to extract variable selector to variable mapping: {e}') + logging.exception(f"Failed to extract variable selector to variable mapping: {e}") return {} def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]: @@ -144,7 +140,7 @@ class HttpRequestNode(BaseNode): # extract filename from url filename = path.basename(url) # extract extension if possible - extension = guess_extension(mimetype) or '.bin' + extension = guess_extension(mimetype) or ".bin" tool_file = ToolFileManager.create_file_by_raw( user_id=self.user_id, diff --git a/api/core/workflow/nodes/if_else/entities.py b/api/core/workflow/nodes/if_else/entities.py index 338277ace1..54c1081fd3 100644 --- a/api/core/workflow/nodes/if_else/entities.py +++ b/api/core/workflow/nodes/if_else/entities.py @@ -15,6 +15,7 @@ class IfElseNodeData(BaseNodeData): """ Case entity representing a single logical condition group """ + case_id: str logical_operator: Literal["and", "or"] conditions: list[Condition] diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index ca87eecd0d..5b4737c6e5 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -20,13 +20,9 @@ class IfElseNode(BaseNode): node_data = self.node_data node_data = cast(IfElseNodeData, node_data) - node_inputs: dict[str, list] = { - "conditions": [] - } + node_inputs: dict[str, list] = {"conditions": []} - process_datas: dict[str, list] = { - "condition_results": [] - } + process_datas: dict[str, list] = {"condition_results": []} input_conditions = [] final_result = False @@ -37,8 +33,7 @@ class IfElseNode(BaseNode): if node_data.cases: for case in node_data.cases: input_conditions, group_result = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=case.conditions + variable_pool=self.graph_runtime_state.variable_pool, conditions=case.conditions ) # Apply the logical operator for the current case @@ -60,8 +55,7 @@ class IfElseNode(BaseNode): else: # Fallback to old structure if cases are not defined input_conditions, group_result = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=node_data.conditions + variable_pool=self.graph_runtime_state.variable_pool, conditions=node_data.conditions ) final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result) @@ -69,21 +63,14 @@ class IfElseNode(BaseNode): selected_case_id = "true" if final_result else "false" process_datas["condition_results"].append( - { - "group": "default", - "results": group_result, - "final_result": final_result - } + {"group": "default", "results": group_result, "final_result": final_result} ) node_inputs["conditions"] = input_conditions except Exception as e: return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=node_inputs, - process_data=process_datas, - error=str(e) + status=WorkflowNodeExecutionStatus.FAILED, inputs=node_inputs, process_data=process_datas, error=str(e) ) outputs = {"result": final_result, "selected_case_id": selected_case_id} @@ -93,17 +80,14 @@ class IfElseNode(BaseNode): inputs=node_inputs, process_data=process_datas, edge_source_handle=selected_case_id if selected_case_id else "false", # Use case ID or 'default' - outputs=outputs + outputs=outputs, ) return data @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: IfElseNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: IfElseNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py index 5fc5a827ae..3c2c189159 100644 --- a/api/core/workflow/nodes/iteration/entities.py +++ b/api/core/workflow/nodes/iteration/entities.py @@ -7,21 +7,25 @@ class IterationNodeData(BaseIterationNodeData): """ Iteration Node Data. """ - parent_loop_id: Optional[str] = None # redundant field, not used currently - iterator_selector: list[str] # variable selector - output_selector: list[str] # output selector + + parent_loop_id: Optional[str] = None # redundant field, not used currently + iterator_selector: list[str] # variable selector + output_selector: list[str] # output selector class IterationStartNodeData(BaseNodeData): """ Iteration Start Node Data. """ + pass + class IterationState(BaseIterationState): """ Iteration State. """ + outputs: list[Any] = None current_output: Optional[Any] = None @@ -29,6 +33,7 @@ class IterationState(BaseIterationState): """ Data. """ + iterator_length: int def get_last_output(self) -> Optional[Any]: @@ -38,9 +43,9 @@ class IterationState(BaseIterationState): if self.outputs: return self.outputs[-1] return None - + def get_current_output(self) -> Optional[Any]: """ Get current output. """ - return self.current_output \ No newline at end of file + return self.current_output diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 93eff16c33..77b14e36a1 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -33,6 +33,7 @@ class IterationNode(BaseNode): """ Iteration Node. """ + _node_data_cls = IterationNodeData _node_type = NodeType.ITERATION @@ -45,31 +46,26 @@ class IterationNode(BaseNode): if not iterator_list_segment: raise ValueError(f"Iterator variable {self.node_data.iterator_selector} not found") - + iterator_list_value = iterator_list_segment.to_object() if not isinstance(iterator_list_value, list): raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") - inputs = { - "iterator_selector": iterator_list_value - } + inputs = {"iterator_selector": iterator_list_value} graph_config = self.graph_config - + if not self.node_data.start_node_id: - raise ValueError(f'field start_node_id in iteration {self.node_id} not found') + raise ValueError(f"field start_node_id in iteration {self.node_id} not found") root_node_id = self.node_data.start_node_id # init graph - iteration_graph = Graph.init( - graph_config=graph_config, - root_node_id=root_node_id - ) + iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id) if not iteration_graph: - raise ValueError('iteration graph not found') + raise ValueError("iteration graph not found") leaf_node_ids = iteration_graph.get_leaf_node_ids() iteration_leaf_node_ids = [] @@ -97,26 +93,21 @@ class IterationNode(BaseNode): Condition( variable_selector=[self.node_id, "index"], comparison_operator="<", - value=str(len(iterator_list_value)) + value=str(len(iterator_list_value)), ) - ] - ) + ], + ), ) variable_pool = self.graph_runtime_state.variable_pool # append iteration variable (item, index) to variable pool - variable_pool.add( - [self.node_id, 'index'], - 0 - ) - variable_pool.add( - [self.node_id, 'item'], - iterator_list_value[0] - ) + variable_pool.add([self.node_id, "index"], 0) + variable_pool.add([self.node_id, "item"], iterator_list_value[0]) # init graph engine from core.workflow.graph_engine.graph_engine import GraphEngine + graph_engine = GraphEngine( tenant_id=self.tenant_id, app_id=self.app_id, @@ -130,7 +121,7 @@ class IterationNode(BaseNode): graph_config=graph_config, variable_pool=variable_pool, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, - max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME + max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, ) start_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ -142,10 +133,8 @@ class IterationNode(BaseNode): iteration_node_data=self.node_data, start_at=start_at, inputs=inputs, - metadata={ - "iterator_length": len(iterator_list_value) - }, - predecessor_node_id=self.previous_node_id + metadata={"iterator_length": len(iterator_list_value)}, + predecessor_node_id=self.previous_node_id, ) yield IterationRunNextEvent( @@ -154,7 +143,7 @@ class IterationNode(BaseNode): iteration_node_type=self.node_type, iteration_node_data=self.node_data, index=0, - pre_iteration_output=None + pre_iteration_output=None, ) outputs: list[Any] = [] @@ -176,7 +165,9 @@ class IterationNode(BaseNode): if NodeRunMetadataKey.ITERATION_ID not in metadata: metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id - metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index']) + metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any( + [self.node_id, "index"] + ) event.route_node_state.node_run_result.metadata = metadata yield event @@ -192,21 +183,15 @@ class IterationNode(BaseNode): variable_pool.remove_node(node_id) # move to next iteration - current_index = variable_pool.get([self.node_id, 'index']) + current_index = variable_pool.get([self.node_id, "index"]) if current_index is None: - raise ValueError(f'iteration {self.node_id} current index not found') + raise ValueError(f"iteration {self.node_id} current index not found") next_index = int(current_index.to_object()) + 1 - variable_pool.add( - [self.node_id, 'index'], - next_index - ) + variable_pool.add([self.node_id, "index"], next_index) if next_index < len(iterator_list_value): - variable_pool.add( - [self.node_id, 'item'], - iterator_list_value[next_index] - ) + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) yield IterationRunNextEvent( iteration_id=self.id, @@ -214,8 +199,9 @@ class IterationNode(BaseNode): iteration_node_type=self.node_type, iteration_node_data=self.node_data, index=next_index, - pre_iteration_output=jsonable_encoder( - current_iteration_output) if current_iteration_output else None + pre_iteration_output=jsonable_encoder(current_iteration_output) + if current_iteration_output + else None, ) elif isinstance(event, BaseGraphEvent): if isinstance(event, GraphRunFailedEvent): @@ -227,13 +213,9 @@ class IterationNode(BaseNode): iteration_node_data=self.node_data, start_at=start_at, inputs=inputs, - outputs={ - "output": jsonable_encoder(outputs) - }, + outputs={"output": jsonable_encoder(outputs)}, steps=len(iterator_list_value), - metadata={ - "total_tokens": graph_engine.graph_runtime_state.total_tokens - }, + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, error=event.error, ) @@ -255,21 +237,14 @@ class IterationNode(BaseNode): iteration_node_data=self.node_data, start_at=start_at, inputs=inputs, - outputs={ - "output": jsonable_encoder(outputs) - }, + outputs={"output": jsonable_encoder(outputs)}, steps=len(iterator_list_value), - metadata={ - "total_tokens": graph_engine.graph_runtime_state.total_tokens - } + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, ) yield RunCompletedEvent( run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - 'output': jsonable_encoder(outputs) - } + status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": jsonable_encoder(outputs)} ) ) except Exception as e: @@ -282,16 +257,11 @@ class IterationNode(BaseNode): iteration_node_data=self.node_data, start_at=start_at, inputs=inputs, - outputs={ - "output": jsonable_encoder(outputs) - }, + outputs={"output": jsonable_encoder(outputs)}, steps=len(iterator_list_value), - metadata={ - "total_tokens": graph_engine.graph_runtime_state.total_tokens - }, + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, error=str(e), ) - yield RunCompletedEvent( run_result=NodeRunResult( @@ -301,15 +271,12 @@ class IterationNode(BaseNode): ) finally: # remove iteration variable (item, index) from variable pool after iteration run completed - variable_pool.remove([self.node_id, 'index']) - variable_pool.remove([self.node_id, 'item']) - + variable_pool.remove([self.node_id, "index"]) + variable_pool.remove([self.node_id, "item"]) + @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: IterationNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -319,36 +286,33 @@ class IterationNode(BaseNode): :return: """ variable_mapping = { - f'{node_id}.input_selector': node_data.iterator_selector, + f"{node_id}.input_selector": node_data.iterator_selector, } # init graph - iteration_graph = Graph.init( - graph_config=graph_config, - root_node_id=node_data.start_node_id - ) + iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id) if not iteration_graph: - raise ValueError('iteration graph not found') - + raise ValueError("iteration graph not found") + for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items(): - if sub_node_config.get('data', {}).get('iteration_id') != node_id: + if sub_node_config.get("data", {}).get("iteration_id") != node_id: continue # variable selector to variable mapping try: # Get node class from core.workflow.nodes.node_mapping import node_classes - node_type = NodeType.value_of(sub_node_config.get('data', {}).get('type')) + + node_type = NodeType.value_of(sub_node_config.get("data", {}).get("type")) node_cls = node_classes.get(node_type) if not node_cls: continue node_cls = cast(BaseNode, node_cls) - + sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=graph_config, - config=sub_node_config + graph_config=graph_config, config=sub_node_config ) sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping) except NotImplementedError: @@ -356,7 +320,8 @@ class IterationNode(BaseNode): # remove iteration variables sub_node_variable_mapping = { - sub_node_id + '.' + key: value for key, value in sub_node_variable_mapping.items() + sub_node_id + "." + key: value + for key, value in sub_node_variable_mapping.items() if value[0] != node_id } @@ -364,8 +329,7 @@ class IterationNode(BaseNode): # remove variable out from iteration variable_mapping = { - key: value for key, value in variable_mapping.items() - if value[0] not in iteration_graph.node_ids + key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids } - + return variable_mapping diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index 25044cf3eb..88b9665ac6 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -11,6 +11,7 @@ class IterationStartNode(BaseNode): """ Iteration Start Node. """ + _node_data_cls = IterationStartNodeData _node_type = NodeType.ITERATION_START @@ -18,16 +19,11 @@ class IterationStartNode(BaseNode): """ Run the node. """ - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED - ) - + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED) + @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: IterationNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 7cf392277c..1cd88039b1 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -9,6 +9,7 @@ class RerankingModelConfig(BaseModel): """ Reranking Model Config. """ + provider: str model: str @@ -17,6 +18,7 @@ class VectorSetting(BaseModel): """ Vector Setting. """ + vector_weight: float embedding_provider_name: str embedding_model_name: str @@ -26,6 +28,7 @@ class KeywordSetting(BaseModel): """ Keyword Setting. """ + keyword_weight: float @@ -33,6 +36,7 @@ class WeightedScoreConfig(BaseModel): """ Weighted score Config. """ + vector_setting: VectorSetting keyword_setting: KeywordSetting @@ -41,17 +45,20 @@ class MultipleRetrievalConfig(BaseModel): """ Multiple Retrieval Config. """ + top_k: int score_threshold: Optional[float] = None - reranking_mode: str = 'reranking_model' + reranking_mode: str = "reranking_model" reranking_enable: bool = True reranking_model: Optional[RerankingModelConfig] = None weights: Optional[WeightedScoreConfig] = None + class ModelConfig(BaseModel): """ - Model Config. + Model Config. """ + provider: str name: str mode: str @@ -62,6 +69,7 @@ class SingleRetrievalConfig(BaseModel): """ Single Retrieval Config. """ + model: ModelConfig @@ -69,9 +77,10 @@ class KnowledgeRetrievalNodeData(BaseNodeData): """ Knowledge retrieval Node Data. """ - type: str = 'knowledge-retrieval' + + type: str = "knowledge-retrieval" query_variable_selector: list[str] dataset_ids: list[str] - retrieval_mode: Literal['single', 'multiple'] + retrieval_mode: Literal["single", "multiple"] multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None single_retrieval_config: Optional[SingleRetrievalConfig] = None diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 2d1ac4731c..19deca162a 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -24,14 +24,11 @@ from models.workflow import WorkflowNodeExecutionStatus logger = logging.getLogger(__name__) default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } @@ -45,62 +42,47 @@ class KnowledgeRetrievalNode(BaseNode): # extract variables variable = self.graph_runtime_state.variable_pool.get_any(node_data.query_variable_selector) query = variable - variables = { - 'query': query - } + variables = {"query": query} if not query: return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error="Query is required." + status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required." ) # retrieve knowledge try: - results = self._fetch_dataset_retriever( - node_data=node_data, query=query - ) - outputs = { - 'result': results - } + results = self._fetch_dataset_retriever(node_data=node_data, query=query) + outputs = {"result": results} return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - process_data=None, - outputs=outputs + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs ) except Exception as e: logger.exception("Error when running knowledge retrieval node") - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error=str(e) - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) - def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[ - dict[str, Any]]: + def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]: available_datasets = [] dataset_ids = node_data.dataset_ids # Subquery: Count the number of available documents for each dataset - subquery = db.session.query( - Document.dataset_id, - func.count(Document.id).label('available_document_count') - ).filter( - Document.indexing_status == 'completed', - Document.enabled == True, - Document.archived == False, - Document.dataset_id.in_(dataset_ids) - ).group_by(Document.dataset_id).having( - func.count(Document.id) > 0 - ).subquery() + subquery = ( + db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count")) + .filter( + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + Document.dataset_id.in_(dataset_ids), + ) + .group_by(Document.dataset_id) + .having(func.count(Document.id) > 0) + .subquery() + ) - results = db.session.query(Dataset).join( - subquery, Dataset.id == subquery.c.dataset_id - ).filter( - Dataset.tenant_id == self.tenant_id, - Dataset.id.in_(dataset_ids) - ).all() + results = ( + db.session.query(Dataset) + .join(subquery, Dataset.id == subquery.c.dataset_id) + .filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids)) + .all() + ) for dataset in results: # pass if dataset is not available @@ -117,16 +99,14 @@ class KnowledgeRetrievalNode(BaseNode): model_type_instance = cast(LargeLanguageModel, model_type_instance) # get model schema model_schema = model_type_instance.get_model_schema( - model=model_config.model, - credentials=model_config.credentials + model=model_config.model, credentials=model_config.credentials ) if model_schema: planning_strategy = PlanningStrategy.REACT_ROUTER features = model_schema.features if features: - if ModelFeature.TOOL_CALL in features \ - or ModelFeature.MULTI_TOOL_CALL in features: + if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features: planning_strategy = PlanningStrategy.ROUTER all_documents = dataset_retrieval.single_retrieve( available_datasets=available_datasets, @@ -137,111 +117,111 @@ class KnowledgeRetrievalNode(BaseNode): query=query, model_config=model_config, model_instance=model_instance, - planning_strategy=planning_strategy + planning_strategy=planning_strategy, ) elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: - if node_data.multiple_retrieval_config.reranking_mode == 'reranking_model': + if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": reranking_model = { - 'reranking_provider_name': node_data.multiple_retrieval_config.reranking_model.provider, - 'reranking_model_name': node_data.multiple_retrieval_config.reranking_model.model + "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider, + "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model, } weights = None - elif node_data.multiple_retrieval_config.reranking_mode == 'weighted_score': + elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score": reranking_model = None weights = { - 'vector_setting': { + "vector_setting": { "vector_weight": node_data.multiple_retrieval_config.weights.vector_setting.vector_weight, "embedding_provider_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_provider_name, "embedding_model_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_model_name, }, - 'keyword_setting': { + "keyword_setting": { "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight - } + }, } else: reranking_model = None weights = None - all_documents = dataset_retrieval.multiple_retrieve(self.app_id, self.tenant_id, self.user_id, - self.user_from.value, - available_datasets, query, - node_data.multiple_retrieval_config.top_k, - node_data.multiple_retrieval_config.score_threshold, - node_data.multiple_retrieval_config.reranking_mode, - reranking_model, - weights, - node_data.multiple_retrieval_config.reranking_enable, - ) + all_documents = dataset_retrieval.multiple_retrieve( + self.app_id, + self.tenant_id, + self.user_id, + self.user_from.value, + available_datasets, + query, + node_data.multiple_retrieval_config.top_k, + node_data.multiple_retrieval_config.score_threshold, + node_data.multiple_retrieval_config.reranking_mode, + reranking_model, + weights, + node_data.multiple_retrieval_config.reranking_enable, + ) context_list = [] if all_documents: document_score_list = {} page_number_list = {} for item in all_documents: - if item.metadata.get('score'): - document_score_list[item.metadata['doc_id']] = item.metadata['score'] + if item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] # both 'page' and 'score' are metadata fields - if item.metadata.get('page'): - page_number_list[item.metadata['doc_id']] = item.metadata['page'] + if item.metadata.get("page"): + page_number_list[item.metadata["doc_id"]] = item.metadata["page"] - index_node_ids = [document.metadata['doc_id'] for document in all_documents] + index_node_ids = [document.metadata["doc_id"] for document in all_documents] segments = DocumentSegment.query.filter( DocumentSegment.dataset_id.in_(dataset_ids), DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == 'completed', + DocumentSegment.status == "completed", DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids) + DocumentSegment.index_node_id.in_(index_node_ids), ).all() if segments: index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted(segments, - key=lambda segment: index_node_id_to_position.get(segment.index_node_id, - float('inf'))) + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: - dataset = Dataset.query.filter_by( - id=segment.dataset_id + dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + document = Document.query.filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, ).first() - document = Document.query.filter(Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() resource_number = 1 if dataset and document: source = { - 'metadata': { - '_source': 'knowledge', - 'position': resource_number, - 'dataset_id': dataset.id, - 'dataset_name': dataset.name, - 'document_id': document.id, - 'document_name': document.name, - 'document_data_source_type': document.data_source_type, - 'page': page_number_list.get(segment.index_node_id, None), - 'segment_id': segment.id, - 'retriever_from': 'workflow', - 'score': document_score_list.get(segment.index_node_id, None), - 'segment_hit_count': segment.hit_count, - 'segment_word_count': segment.word_count, - 'segment_position': segment.position, - 'segment_index_node_hash': segment.index_node_hash, + "metadata": { + "_source": "knowledge", + "position": resource_number, + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "document_data_source_type": document.data_source_type, + "page": page_number_list.get(segment.index_node_id, None), + "segment_id": segment.id, + "retriever_from": "workflow", + "score": document_score_list.get(segment.index_node_id, None), + "segment_hit_count": segment.hit_count, + "segment_word_count": segment.word_count, + "segment_position": segment.position, + "segment_index_node_hash": segment.index_node_hash, }, - 'title': document.name + "title": document.name, } if segment.answer: - source['content'] = f'question:{segment.get_sign_content()} \nanswer:{segment.answer}' + source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}" else: - source['content'] = segment.get_sign_content() + source["content"] = segment.get_sign_content() context_list.append(source) resource_number += 1 return context_list @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: KnowledgeRetrievalNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: KnowledgeRetrievalNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -251,11 +231,12 @@ class KnowledgeRetrievalNode(BaseNode): :return: """ variable_mapping = {} - variable_mapping[node_id + '.query'] = node_data.query_variable_selector + variable_mapping[node_id + ".query"] = node_data.query_variable_selector return variable_mapping - def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ - ModelInstance, ModelConfigWithCredentialsEntity]: + def _fetch_model_config( + self, node_data: KnowledgeRetrievalNodeData + ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config :param node_data: node data @@ -266,10 +247,7 @@ class KnowledgeRetrievalNode(BaseNode): model_manager = ModelManager() model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, - model_type=ModelType.LLM, - provider=provider_name, - model=model_name + tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name ) provider_model_bundle = model_instance.provider_model_bundle @@ -280,8 +258,7 @@ class KnowledgeRetrievalNode(BaseNode): # check model provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_name, - model_type=ModelType.LLM + model=model_name, model_type=ModelType.LLM ) if provider_model is None: @@ -297,19 +274,16 @@ class KnowledgeRetrievalNode(BaseNode): # model config completion_params = node_data.single_retrieval_config.model.completion_params stop = [] - if 'stop' in completion_params: - stop = completion_params['stop'] - del completion_params['stop'] + if "stop" in completion_params: + stop = completion_params["stop"] + del completion_params["stop"] # get model mode model_mode = node_data.single_retrieval_config.model.mode if not model_mode: raise ValueError("LLM mode is required.") - model_schema = model_type_instance.get_model_schema( - model_name, - model_credentials - ) + model_schema = model_type_instance.get_model_schema(model_name, model_credentials) if not model_schema: raise ValueError(f"Model {model_name} not exist.") diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 1e48a10bc7..93ee0ac250 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -11,6 +11,7 @@ class ModelConfig(BaseModel): """ Model Config. """ + provider: str name: str mode: str @@ -21,6 +22,7 @@ class ContextConfig(BaseModel): """ Context Config. """ + enabled: bool variable_selector: Optional[list[str]] = None @@ -29,37 +31,47 @@ class VisionConfig(BaseModel): """ Vision Config. """ + class Configs(BaseModel): """ Configs. """ - detail: Literal['low', 'high'] + + detail: Literal["low", "high"] enabled: bool configs: Optional[Configs] = None + class PromptConfig(BaseModel): """ Prompt Config. """ + jinja2_variables: Optional[list[VariableSelector]] = None + class LLMNodeChatModelMessage(ChatModelMessage): """ LLM Node Chat Model Message. """ + jinja2_text: Optional[str] = None + class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): """ LLM Node Chat Model Prompt Template. """ + jinja2_text: Optional[str] = None + class LLMNodeData(BaseNodeData): """ LLM Node Data. """ + model: ModelConfig prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate] prompt_config: Optional[PromptConfig] = None diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index f26ec1b0b5..6dfd27861e 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -45,11 +45,11 @@ if TYPE_CHECKING: from core.file.file_obj import FileVar - class ModelInvokeCompleted(BaseModel): """ Model invoke completed """ + text: str usage: LLMUsage finish_reason: Optional[str] = None @@ -89,7 +89,7 @@ class LLMNode(BaseNode): files = self._fetch_files(node_data, variable_pool) if files: - node_inputs['#files#'] = [file.to_dict() for file in files] + node_inputs["#files#"] = [file.to_dict() for file in files] # fetch context value generator = self._fetch_context(node_data, variable_pool) @@ -100,7 +100,7 @@ class LLMNode(BaseNode): yield event if context: - node_inputs['#context#'] = context # type: ignore + node_inputs["#context#"] = context # type: ignore # fetch model config model_instance, model_config = self._fetch_model_config(node_data.model) @@ -111,24 +111,22 @@ class LLMNode(BaseNode): # fetch prompt messages prompt_messages, stop = self._fetch_prompt_messages( node_data=node_data, - query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value]) - if node_data.memory else None, + query=variable_pool.get_any(["sys", SystemVariableKey.QUERY.value]) if node_data.memory else None, query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None, inputs=inputs, files=files, context=context, memory=memory, - model_config=model_config + model_config=model_config, ) process_data = { - 'model_mode': model_config.mode, - 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, - prompt_messages=prompt_messages + "model_mode": model_config.mode, + "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, prompt_messages=prompt_messages ), - 'model_provider': model_config.provider, - 'model_name': model_config.model, + "model_provider": model_config.provider, + "model_name": model_config.model, } # handle invoke result @@ -136,10 +134,10 @@ class LLMNode(BaseNode): node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, - stop=stop + stop=stop, ) - result_text = '' + result_text = "" usage = LLMUsage.empty_usage() finish_reason = None for event in generator: @@ -156,16 +154,12 @@ class LLMNode(BaseNode): status=WorkflowNodeExecutionStatus.FAILED, error=str(e), inputs=node_inputs, - process_data=process_data + process_data=process_data, ) ) return - outputs = { - 'text': result_text, - 'usage': jsonable_encoder(usage), - 'finish_reason': finish_reason - } + outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} yield RunCompletedEvent( run_result=NodeRunResult( @@ -176,17 +170,19 @@ class LLMNode(BaseNode): metadata={ NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency + NodeRunMetadataKey.CURRENCY: usage.currency, }, - llm_usage=usage + llm_usage=usage, ) ) - def _invoke_llm(self, node_data_model: ModelConfig, - model_instance: ModelInstance, - prompt_messages: list[PromptMessage], - stop: Optional[list[str]] = None) \ - -> Generator[RunEvent | ModelInvokeCompleted, None, None]: + def _invoke_llm( + self, + node_data_model: ModelConfig, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + stop: Optional[list[str]] = None, + ) -> Generator[RunEvent | ModelInvokeCompleted, None, None]: """ Invoke large language model :param node_data_model: node data model @@ -206,9 +202,7 @@ class LLMNode(BaseNode): ) # handle invoke result - generator = self._handle_invoke_result( - invoke_result=invoke_result - ) + generator = self._handle_invoke_result(invoke_result=invoke_result) usage = LLMUsage.empty_usage() for event in generator: @@ -219,8 +213,9 @@ class LLMNode(BaseNode): # deduct quota self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) - def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \ - -> Generator[RunEvent | ModelInvokeCompleted, None, None]: + def _handle_invoke_result( + self, invoke_result: LLMResult | Generator + ) -> Generator[RunEvent | ModelInvokeCompleted, None, None]: """ Handle invoke result :param invoke_result: invoke result @@ -231,17 +226,14 @@ class LLMNode(BaseNode): model = None prompt_messages: list[PromptMessage] = [] - full_text = '' + full_text = "" usage = None finish_reason = None for result in invoke_result: text = result.delta.message.content full_text += text - yield RunStreamChunkEvent( - chunk_content=text, - from_variable_selector=[self.node_id, 'text'] - ) + yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"]) if not model: model = result.model @@ -258,15 +250,11 @@ class LLMNode(BaseNode): if not usage: usage = LLMUsage.empty_usage() - yield ModelInvokeCompleted( - text=full_text, - usage=usage, - finish_reason=finish_reason - ) + yield ModelInvokeCompleted(text=full_text, usage=usage, finish_reason=finish_reason) - def _transform_chat_messages(self, - messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate - ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: + def _transform_chat_messages( + self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate + ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: """ Transform chat messages @@ -275,13 +263,13 @@ class LLMNode(BaseNode): """ if isinstance(messages, LLMNodeCompletionModelPromptTemplate): - if messages.edition_type == 'jinja2' and messages.jinja2_text: + if messages.edition_type == "jinja2" and messages.jinja2_text: messages.text = messages.jinja2_text return messages for message in messages: - if message.edition_type == 'jinja2' and message.jinja2_text: + if message.edition_type == "jinja2" and message.jinja2_text: message.text = message.jinja2_text return messages @@ -300,17 +288,15 @@ class LLMNode(BaseNode): for variable_selector in node_data.prompt_config.jinja2_variables or []: variable = variable_selector.variable - value = variable_pool.get_any( - variable_selector.value_selector - ) + value = variable_pool.get_any(variable_selector.value_selector) def parse_dict(d: dict) -> str: """ Parse dict into string """ # check if it's a context structure - if 'metadata' in d and '_source' in d['metadata'] and 'content' in d: - return d['content'] + if "metadata" in d and "_source" in d["metadata"] and "content" in d: + return d["content"] # else, parse the dict try: @@ -321,7 +307,7 @@ class LLMNode(BaseNode): if isinstance(value, str): value = value elif isinstance(value, list): - result = '' + result = "" for item in value: if isinstance(item, dict): result += parse_dict(item) @@ -331,7 +317,7 @@ class LLMNode(BaseNode): result += str(item) else: result += str(item) - result += '\n' + result += "\n" value = result.strip() elif isinstance(value, dict): value = parse_dict(value) @@ -366,18 +352,19 @@ class LLMNode(BaseNode): for variable_selector in variable_selectors: variable_value = variable_pool.get_any(variable_selector.value_selector) if variable_value is None: - raise ValueError(f'Variable {variable_selector.variable} not found') + raise ValueError(f"Variable {variable_selector.variable} not found") inputs[variable_selector.variable] = variable_value memory = node_data.memory if memory and memory.query_prompt_template: - query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template) - .extract_variable_selectors()) + query_variable_selectors = VariableTemplateParser( + template=memory.query_prompt_template + ).extract_variable_selectors() for variable_selector in query_variable_selectors: variable_value = variable_pool.get_any(variable_selector.value_selector) if variable_value is None: - raise ValueError(f'Variable {variable_selector.variable} not found') + raise ValueError(f"Variable {variable_selector.variable} not found") inputs[variable_selector.variable] = variable_value @@ -393,7 +380,7 @@ class LLMNode(BaseNode): if not node_data.vision.enabled: return [] - files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value]) + files = variable_pool.get_any(["sys", SystemVariableKey.FILES.value]) if not files: return [] @@ -415,29 +402,25 @@ class LLMNode(BaseNode): context_value = variable_pool.get_any(node_data.context.variable_selector) if context_value: if isinstance(context_value, str): - yield RunRetrieverResourceEvent( - retriever_resources=[], - context=context_value - ) + yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value) elif isinstance(context_value, list): - context_str = '' + context_str = "" original_retriever_resource = [] for item in context_value: if isinstance(item, str): - context_str += item + '\n' + context_str += item + "\n" else: - if 'content' not in item: - raise ValueError(f'Invalid context structure: {item}') + if "content" not in item: + raise ValueError(f"Invalid context structure: {item}") - context_str += item['content'] + '\n' + context_str += item["content"] + "\n" retriever_resource = self._convert_to_original_retriever_resource(item) if retriever_resource: original_retriever_resource.append(retriever_resource) yield RunRetrieverResourceEvent( - retriever_resources=original_retriever_resource, - context=context_str.strip() + retriever_resources=original_retriever_resource, context=context_str.strip() ) def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: @@ -446,34 +429,38 @@ class LLMNode(BaseNode): :param context_dict: context dict :return: """ - if ('metadata' in context_dict and '_source' in context_dict['metadata'] - and context_dict['metadata']['_source'] == 'knowledge'): - metadata = context_dict.get('metadata', {}) + if ( + "metadata" in context_dict + and "_source" in context_dict["metadata"] + and context_dict["metadata"]["_source"] == "knowledge" + ): + metadata = context_dict.get("metadata", {}) source = { - 'position': metadata.get('position'), - 'dataset_id': metadata.get('dataset_id'), - 'dataset_name': metadata.get('dataset_name'), - 'document_id': metadata.get('document_id'), - 'document_name': metadata.get('document_name'), - 'data_source_type': metadata.get('document_data_source_type'), - 'segment_id': metadata.get('segment_id'), - 'retriever_from': metadata.get('retriever_from'), - 'score': metadata.get('score'), - 'hit_count': metadata.get('segment_hit_count'), - 'word_count': metadata.get('segment_word_count'), - 'segment_position': metadata.get('segment_position'), - 'index_node_hash': metadata.get('segment_index_node_hash'), - 'content': context_dict.get('content'), - 'page': metadata.get('page'), + "position": metadata.get("position"), + "dataset_id": metadata.get("dataset_id"), + "dataset_name": metadata.get("dataset_name"), + "document_id": metadata.get("document_id"), + "document_name": metadata.get("document_name"), + "data_source_type": metadata.get("document_data_source_type"), + "segment_id": metadata.get("segment_id"), + "retriever_from": metadata.get("retriever_from"), + "score": metadata.get("score"), + "hit_count": metadata.get("segment_hit_count"), + "word_count": metadata.get("segment_word_count"), + "segment_position": metadata.get("segment_position"), + "index_node_hash": metadata.get("segment_index_node_hash"), + "content": context_dict.get("content"), + "page": metadata.get("page"), } return source return None - def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ - ModelInstance, ModelConfigWithCredentialsEntity]: + def _fetch_model_config( + self, node_data_model: ModelConfig + ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config :param node_data_model: node data model @@ -484,10 +471,7 @@ class LLMNode(BaseNode): model_manager = ModelManager() model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, - model_type=ModelType.LLM, - provider=provider_name, - model=model_name + tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name ) provider_model_bundle = model_instance.provider_model_bundle @@ -498,8 +482,7 @@ class LLMNode(BaseNode): # check model provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_name, - model_type=ModelType.LLM + model=model_name, model_type=ModelType.LLM ) if provider_model is None: @@ -515,19 +498,16 @@ class LLMNode(BaseNode): # model config completion_params = node_data_model.completion_params stop = [] - if 'stop' in completion_params: - stop = completion_params['stop'] - del completion_params['stop'] + if "stop" in completion_params: + stop = completion_params["stop"] + del completion_params["stop"] # get model mode model_mode = node_data_model.mode if not model_mode: raise ValueError("LLM mode is required.") - model_schema = model_type_instance.get_model_schema( - model_name, - model_credentials - ) + model_schema = model_type_instance.get_model_schema(model_name, model_credentials) if not model_schema: raise ValueError(f"Model {model_name} not exist.") @@ -543,9 +523,9 @@ class LLMNode(BaseNode): stop=stop, ) - def _fetch_memory(self, node_data_memory: Optional[MemoryConfig], - variable_pool: VariablePool, - model_instance: ModelInstance) -> Optional[TokenBufferMemory]: + def _fetch_memory( + self, node_data_memory: Optional[MemoryConfig], variable_pool: VariablePool, model_instance: ModelInstance + ) -> Optional[TokenBufferMemory]: """ Fetch memory :param node_data_memory: node data memory @@ -556,35 +536,35 @@ class LLMNode(BaseNode): return None # get conversation id - conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value]) + conversation_id = variable_pool.get_any(["sys", SystemVariableKey.CONVERSATION_ID.value]) if conversation_id is None: return None # get conversation - conversation = db.session.query(Conversation).filter( - Conversation.app_id == self.app_id, - Conversation.id == conversation_id - ).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id) + .first() + ) if not conversation: return None - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) return memory - def _fetch_prompt_messages(self, node_data: LLMNodeData, - query: Optional[str], - query_prompt_template: Optional[str], - inputs: dict[str, str], - files: list["FileVar"], - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def _fetch_prompt_messages( + self, + node_data: LLMNodeData, + query: Optional[str], + query_prompt_template: Optional[str], + inputs: dict[str, str], + files: list["FileVar"], + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: """ Fetch prompt messages :param node_data: node data @@ -601,7 +581,7 @@ class LLMNode(BaseNode): prompt_messages = prompt_transform.get_prompt( prompt_template=node_data.prompt_template, inputs=inputs, - query=query if query else '', + query=query if query else "", files=files, context=context, memory_config=node_data.memory, @@ -621,8 +601,11 @@ class LLMNode(BaseNode): if not isinstance(prompt_message.content, str): prompt_message_content = [] for content_item in prompt_message.content: - if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance( - content_item, ImagePromptMessageContent): + if ( + vision_enabled + and content_item.type == PromptMessageContentType.IMAGE + and isinstance(content_item, ImagePromptMessageContent) + ): # Override vision config if LLM node has vision config if vision_detail: content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail) @@ -632,15 +615,18 @@ class LLMNode(BaseNode): if len(prompt_message_content) > 1: prompt_message.content = prompt_message_content - elif (len(prompt_message_content) == 1 - and prompt_message_content[0].type == PromptMessageContentType.TEXT): + elif ( + len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT + ): prompt_message.content = prompt_message_content[0].data filtered_prompt_messages.append(prompt_message) if not filtered_prompt_messages: - raise ValueError("No prompt found in the LLM configuration. " - "Please ensure a prompt is properly configured before proceeding.") + raise ValueError( + "No prompt found in the LLM configuration. " + "Please ensure a prompt is properly configured before proceeding." + ) return filtered_prompt_messages, stop @@ -678,7 +664,7 @@ class LLMNode(BaseNode): elif quota_unit == QuotaUnit.CREDITS: used_quota = 1 - if 'gpt-4' in model_instance.model: + if "gpt-4" in model_instance.model: used_quota = 20 else: used_quota = 1 @@ -689,16 +675,13 @@ class LLMNode(BaseNode): Provider.provider_name == model_instance.provider, Provider.provider_type == ProviderType.SYSTEM.value, Provider.quota_type == system_configuration.current_quota_type.value, - Provider.quota_limit > Provider.quota_used - ).update({'quota_used': Provider.quota_used + used_quota}) + Provider.quota_limit > Provider.quota_used, + ).update({"quota_used": Provider.quota_used + used_quota}) db.session.commit() @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: LLMNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: LLMNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -712,11 +695,11 @@ class LLMNode(BaseNode): variable_selectors = [] if isinstance(prompt_template, list): for prompt in prompt_template: - if prompt.edition_type != 'jinja2': + if prompt.edition_type != "jinja2": variable_template_parser = VariableTemplateParser(template=prompt.text) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) else: - if prompt_template.edition_type != 'jinja2': + if prompt_template.edition_type != "jinja2": variable_template_parser = VariableTemplateParser(template=prompt_template.text) variable_selectors = variable_template_parser.extract_variable_selectors() @@ -726,39 +709,38 @@ class LLMNode(BaseNode): memory = node_data.memory if memory and memory.query_prompt_template: - query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template) - .extract_variable_selectors()) + query_variable_selectors = VariableTemplateParser( + template=memory.query_prompt_template + ).extract_variable_selectors() for variable_selector in query_variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector if node_data.context.enabled: - variable_mapping['#context#'] = node_data.context.variable_selector + variable_mapping["#context#"] = node_data.context.variable_selector if node_data.vision.enabled: - variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value] + variable_mapping["#files#"] = ["sys", SystemVariableKey.FILES.value] if node_data.memory: - variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value] + variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value] if node_data.prompt_config: enable_jinja = False if isinstance(prompt_template, list): for prompt in prompt_template: - if prompt.edition_type == 'jinja2': + if prompt.edition_type == "jinja2": enable_jinja = True break else: - if prompt_template.edition_type == 'jinja2': + if prompt_template.edition_type == "jinja2": enable_jinja = True if enable_jinja: for variable_selector in node_data.prompt_config.jinja2_variables or []: variable_mapping[variable_selector.variable] = variable_selector.value_selector - variable_mapping = { - node_id + '.' + key: value for key, value in variable_mapping.items() - } + variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} return variable_mapping @@ -775,26 +757,19 @@ class LLMNode(BaseNode): "prompt_templates": { "chat_model": { "prompts": [ - { - "role": "system", - "text": "You are a helpful AI assistant.", - "edition_type": "basic" - } + {"role": "system", "text": "You are a helpful AI assistant.", "edition_type": "basic"} ] }, "completion_model": { - "conversation_histories_role": { - "user_prefix": "Human", - "assistant_prefix": "Assistant" - }, + "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, "prompt": { "text": "Here is the chat histories between human and assistant, inside " - " XML tags.\n\n\n{{" - "#histories#}}\n\n\n\nHuman: {{#sys.query#}}\n\nAssistant:", - "edition_type": "basic" + " XML tags.\n\n\n{{" + "#histories#}}\n\n\n\nHuman: {{#sys.query#}}\n\nAssistant:", + "edition_type": "basic", }, - "stop": ["Human:"] - } + "stop": ["Human:"], + }, } - } + }, } diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py index 8a5684551e..a8a0debe64 100644 --- a/api/core/workflow/nodes/loop/entities.py +++ b/api/core/workflow/nodes/loop/entities.py @@ -1,4 +1,3 @@ - from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState @@ -7,7 +6,8 @@ class LoopNodeData(BaseIterationNodeData): Loop Node Data. """ + class LoopState(BaseIterationState): """ Loop State. - """ \ No newline at end of file + """ diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 526404e30d..fbc68b79cb 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -10,6 +10,7 @@ class LoopNode(BaseNode): """ Loop Node. """ + _node_data_cls = LoopNodeData _node_type = NodeType.LOOP @@ -21,14 +22,16 @@ class LoopNode(BaseNode): """ Get conditions. """ - node_id = node_config.get('id') + node_id = node_config.get("id") if not node_id: return [] # TODO waiting for implementation - return [Condition( - variable_selector=[node_id, 'index'], - comparison_operator="≤", - value_type="value_selector", - value_selector=[] - )] + return [ + Condition( + variable_selector=[node_id, "index"], + comparison_operator="≤", + value_type="value_selector", + value_selector=[], + ) + ] diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 7bb123b126..802ed31e27 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -8,47 +8,52 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData class ModelConfig(BaseModel): """ - Model Config. + Model Config. """ + provider: str name: str mode: str completion_params: dict[str, Any] = {} + class ParameterConfig(BaseModel): """ Parameter Config. """ + name: str - type: Literal['string', 'number', 'bool', 'select', 'array[string]', 'array[number]', 'array[object]'] + type: Literal["string", "number", "bool", "select", "array[string]", "array[number]", "array[object]"] options: Optional[list[str]] = None description: str required: bool - @field_validator('name', mode='before') + @field_validator("name", mode="before") @classmethod def validate_name(cls, value) -> str: if not value: - raise ValueError('Parameter name is required') - if value in ['__reason', '__is_success']: - raise ValueError('Invalid parameter name, __reason and __is_success are reserved') + raise ValueError("Parameter name is required") + if value in ["__reason", "__is_success"]: + raise ValueError("Invalid parameter name, __reason and __is_success are reserved") return value + class ParameterExtractorNodeData(BaseNodeData): """ Parameter Extractor Node Data. """ + model: ModelConfig query: list[str] parameters: list[ParameterConfig] instruction: Optional[str] = None memory: Optional[MemoryConfig] = None - reasoning_mode: Literal['function_call', 'prompt'] + reasoning_mode: Literal["function_call", "prompt"] - @field_validator('reasoning_mode', mode='before') + @field_validator("reasoning_mode", mode="before") @classmethod def set_reasoning_mode(cls, v) -> str: - return v or 'function_call' + return v or "function_call" def get_parameter_json_schema(self) -> dict: """ @@ -56,32 +61,26 @@ class ParameterExtractorNodeData(BaseNodeData): :return: parameter json schema """ - parameters = { - 'type': 'object', - 'properties': {}, - 'required': [] - } + parameters = {"type": "object", "properties": {}, "required": []} for parameter in self.parameters: - parameter_schema = { - 'description': parameter.description - } + parameter_schema = {"description": parameter.description} - if parameter.type in ['string', 'select']: - parameter_schema['type'] = 'string' - elif parameter.type.startswith('array'): - parameter_schema['type'] = 'array' + if parameter.type in ["string", "select"]: + parameter_schema["type"] = "string" + elif parameter.type.startswith("array"): + parameter_schema["type"] = "array" nested_type = parameter.type[6:-1] - parameter_schema['items'] = {'type': nested_type} + parameter_schema["items"] = {"type": nested_type} else: - parameter_schema['type'] = parameter.type + parameter_schema["type"] = parameter.type - if parameter.type == 'select': - parameter_schema['enum'] = parameter.options + if parameter.type == "select": + parameter_schema["enum"] = parameter.options + + parameters["properties"][parameter.name] = parameter_schema - parameters['properties'][parameter.name] = parameter_schema - if parameter.required: - parameters['required'].append(parameter.name) + parameters["required"].append(parameter.name) - return parameters \ No newline at end of file + return parameters diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 2e65705f10..131d26b19e 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -45,6 +45,7 @@ class ParameterExtractorNode(LLMNode): """ Parameter Extractor Node. """ + _node_data_cls = ParameterExtractorNodeData _node_type = NodeType.PARAMETER_EXTRACTOR @@ -57,11 +58,8 @@ class ParameterExtractorNode(LLMNode): "model": { "prompt_templates": { "completion_model": { - "conversation_histories_role": { - "user_prefix": "Human", - "assistant_prefix": "Assistant" - }, - "stop": ["Human:"] + "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, + "stop": ["Human:"], } } } @@ -78,9 +76,9 @@ class ParameterExtractorNode(LLMNode): query = variable inputs = { - 'query': query, - 'parameters': jsonable_encoder(node_data.parameters), - 'instruction': jsonable_encoder(node_data.instruction), + "query": query, + "parameters": jsonable_encoder(node_data.parameters), + "instruction": jsonable_encoder(node_data.instruction), } model_instance, model_config = self._fetch_model_config(node_data.model) @@ -95,30 +93,29 @@ class ParameterExtractorNode(LLMNode): # fetch memory memory = self._fetch_memory(node_data.memory, self.graph_runtime_state.variable_pool, model_instance) - if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \ - and node_data.reasoning_mode == 'function_call': - # use function call + if ( + set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} + and node_data.reasoning_mode == "function_call" + ): + # use function call prompt_messages, prompt_message_tools = self._generate_function_call_prompt( node_data, query, self.graph_runtime_state.variable_pool, model_config, memory ) else: # use prompt engineering - prompt_messages = self._generate_prompt_engineering_prompt(node_data, - query, - self.graph_runtime_state.variable_pool, - model_config, - memory) + prompt_messages = self._generate_prompt_engineering_prompt( + node_data, query, self.graph_runtime_state.variable_pool, model_config, memory + ) prompt_message_tools = [] process_data = { - 'model_mode': model_config.mode, - 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, - prompt_messages=prompt_messages + "model_mode": model_config.mode, + "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, prompt_messages=prompt_messages ), - 'usage': None, - 'function': {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), - 'tool_call': None, + "usage": None, + "function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), + "tool_call": None, } try: @@ -129,20 +126,17 @@ class ParameterExtractorNode(LLMNode): tools=prompt_message_tools, stop=model_config.stop, ) - process_data['usage'] = jsonable_encoder(usage) - process_data['tool_call'] = jsonable_encoder(tool_call) - process_data['llm_text'] = text + process_data["usage"] = jsonable_encoder(usage) + process_data["tool_call"] = jsonable_encoder(tool_call) + process_data["llm_text"] = text except Exception as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=inputs, process_data=process_data, - outputs={ - '__is_success': 0, - '__reason': str(e) - }, + outputs={"__is_success": 0, "__reason": str(e)}, error=str(e), - metadata={} + metadata={}, ) error = None @@ -167,24 +161,23 @@ class ParameterExtractorNode(LLMNode): status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, process_data=process_data, - outputs={ - '__is_success': 1 if not error else 0, - '__reason': error, - **result - }, + outputs={"__is_success": 1 if not error else 0, "__reason": error, **result}, metadata={ NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency + NodeRunMetadataKey.CURRENCY: usage.currency, }, - llm_usage=usage + llm_usage=usage, ) - def _invoke_llm(self, node_data_model: ModelConfig, - model_instance: ModelInstance, - prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - stop: list[str]) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: + def _invoke_llm( + self, + node_data_model: ModelConfig, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + stop: list[str], + ) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: """ Invoke large language model :param node_data_model: node data model @@ -217,32 +210,35 @@ class ParameterExtractorNode(LLMNode): return text, usage, tool_call - def _generate_function_call_prompt(self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], - ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: + def _generate_function_call_prompt( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: """ Generate function call prompt. """ - query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(content=query, structure=json.dumps( - node_data.get_parameter_json_schema())) + query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format( + content=query, structure=json.dumps(node_data.get_parameter_json_schema()) + ) prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '') - prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, memory, - rest_token) + rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "") + prompt_template = self._get_function_calling_prompt_template( + node_data, query, variable_pool, memory, rest_token + ) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], - context='', + context="", memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) # find last user message @@ -255,124 +251,125 @@ class ParameterExtractorNode(LLMNode): example_messages = [] for example in FUNCTION_CALLING_EXTRACTOR_EXAMPLE: id = uuid.uuid4().hex - example_messages.extend([ - UserPromptMessage(content=example['user']['query']), - AssistantPromptMessage( - content=example['assistant']['text'], - tool_calls=[ - AssistantPromptMessage.ToolCall( - id=id, - type='function', - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=example['assistant']['function_call']['name'], - arguments=json.dumps(example['assistant']['function_call']['parameters'] - ) - )) - ] - ), - ToolPromptMessage( - content='Great! You have called the function with the correct parameters.', - tool_call_id=id - ), - AssistantPromptMessage( - content='I have extracted the parameters, let\'s move on.', - ) - ]) + example_messages.extend( + [ + UserPromptMessage(content=example["user"]["query"]), + AssistantPromptMessage( + content=example["assistant"]["text"], + tool_calls=[ + AssistantPromptMessage.ToolCall( + id=id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=example["assistant"]["function_call"]["name"], + arguments=json.dumps(example["assistant"]["function_call"]["parameters"]), + ), + ) + ], + ), + ToolPromptMessage( + content="Great! You have called the function with the correct parameters.", tool_call_id=id + ), + AssistantPromptMessage( + content="I have extracted the parameters, let's move on.", + ), + ] + ) - prompt_messages = prompt_messages[:last_user_message_idx] + \ - example_messages + prompt_messages[last_user_message_idx:] + prompt_messages = ( + prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:] + ) # generate tool tool = PromptMessageTool( name=FUNCTION_CALLING_EXTRACTOR_NAME, - description='Extract parameters from the natural language text', + description="Extract parameters from the natural language text", parameters=node_data.get_parameter_json_schema(), ) return prompt_messages, [tool] - def _generate_prompt_engineering_prompt(self, - data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], - ) -> list[PromptMessage]: + def _generate_prompt_engineering_prompt( + self, + data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> list[PromptMessage]: """ Generate prompt engineering prompt. """ model_mode = ModelMode.value_of(data.model.mode) if model_mode == ModelMode.COMPLETION: - return self._generate_prompt_engineering_completion_prompt( - data, query, variable_pool, model_config, memory - ) + return self._generate_prompt_engineering_completion_prompt(data, query, variable_pool, model_config, memory) elif model_mode == ModelMode.CHAT: - return self._generate_prompt_engineering_chat_prompt( - data, query, variable_pool, model_config, memory - ) + return self._generate_prompt_engineering_chat_prompt(data, query, variable_pool, model_config, memory) else: raise ValueError(f"Invalid model mode: {model_mode}") - def _generate_prompt_engineering_completion_prompt(self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], - ) -> list[PromptMessage]: + def _generate_prompt_engineering_completion_prompt( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> list[PromptMessage]: """ Generate completion prompt. """ prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '') - prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, memory, - rest_token) + rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "") + prompt_template = self._get_prompt_engineering_prompt_template( + node_data, query, variable_pool, memory, rest_token + ) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, - inputs={ - 'structure': json.dumps(node_data.get_parameter_json_schema()) - }, - query='', + inputs={"structure": json.dumps(node_data.get_parameter_json_schema())}, + query="", files=[], - context='', + context="", memory_config=node_data.memory, memory=memory, - model_config=model_config + model_config=model_config, ) return prompt_messages - def _generate_prompt_engineering_chat_prompt(self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], - ) -> list[PromptMessage]: + def _generate_prompt_engineering_chat_prompt( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + ) -> list[PromptMessage]: """ Generate chat prompt. """ prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '') + rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "") prompt_template = self._get_prompt_engineering_prompt_template( node_data, CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( - structure=json.dumps(node_data.get_parameter_json_schema()), - text=query + structure=json.dumps(node_data.get_parameter_json_schema()), text=query ), - variable_pool, memory, rest_token + variable_pool, + memory, + rest_token, ) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], - context='', + context="", memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) # find last user message @@ -384,18 +381,23 @@ class ParameterExtractorNode(LLMNode): # add example messages before last user message example_messages = [] for example in CHAT_EXAMPLE: - example_messages.extend([ - UserPromptMessage(content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( - structure=json.dumps(example['user']['json']), - text=example['user']['query'], - )), - AssistantPromptMessage( - content=json.dumps(example['assistant']['json']), - ) - ]) + example_messages.extend( + [ + UserPromptMessage( + content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( + structure=json.dumps(example["user"]["json"]), + text=example["user"]["query"], + ) + ), + AssistantPromptMessage( + content=json.dumps(example["assistant"]["json"]), + ), + ] + ) - prompt_messages = prompt_messages[:last_user_message_idx] + \ - example_messages + prompt_messages[last_user_message_idx:] + prompt_messages = ( + prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:] + ) return prompt_messages @@ -410,28 +412,28 @@ class ParameterExtractorNode(LLMNode): if parameter.required and parameter.name not in result: raise ValueError(f"Parameter {parameter.name} is required") - if parameter.type == 'select' and parameter.options and result.get(parameter.name) not in parameter.options: + if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options: raise ValueError(f"Invalid `select` value for parameter {parameter.name}") - if parameter.type == 'number' and not isinstance(result.get(parameter.name), int | float): + if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float): raise ValueError(f"Invalid `number` value for parameter {parameter.name}") - if parameter.type == 'bool' and not isinstance(result.get(parameter.name), bool): + if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool): raise ValueError(f"Invalid `bool` value for parameter {parameter.name}") - if parameter.type == 'string' and not isinstance(result.get(parameter.name), str): + if parameter.type == "string" and not isinstance(result.get(parameter.name), str): raise ValueError(f"Invalid `string` value for parameter {parameter.name}") - if parameter.type.startswith('array'): + if parameter.type.startswith("array"): if not isinstance(result.get(parameter.name), list): raise ValueError(f"Invalid `array` value for parameter {parameter.name}") nested_type = parameter.type[6:-1] for item in result.get(parameter.name): - if nested_type == 'number' and not isinstance(item, int | float): + if nested_type == "number" and not isinstance(item, int | float): raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}") - if nested_type == 'string' and not isinstance(item, str): + if nested_type == "string" and not isinstance(item, str): raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}") - if nested_type == 'object' and not isinstance(item, dict): + if nested_type == "object" and not isinstance(item, dict): raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}") return result @@ -443,12 +445,12 @@ class ParameterExtractorNode(LLMNode): for parameter in data.parameters: if parameter.name in result: # transform value - if parameter.type == 'number': + if parameter.type == "number": if isinstance(result[parameter.name], int | float): transformed_result[parameter.name] = result[parameter.name] elif isinstance(result[parameter.name], str): try: - if '.' in result[parameter.name]: + if "." in result[parameter.name]: result[parameter.name] = float(result[parameter.name]) else: result[parameter.name] = int(result[parameter.name]) @@ -465,40 +467,40 @@ class ParameterExtractorNode(LLMNode): # transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true') # elif isinstance(result[parameter.name], int): # transformed_result[parameter.name] = bool(result[parameter.name]) - elif parameter.type in ['string', 'select']: + elif parameter.type in ["string", "select"]: if isinstance(result[parameter.name], str): transformed_result[parameter.name] = result[parameter.name] - elif parameter.type.startswith('array'): + elif parameter.type.startswith("array"): if isinstance(result[parameter.name], list): nested_type = parameter.type[6:-1] transformed_result[parameter.name] = [] for item in result[parameter.name]: - if nested_type == 'number': + if nested_type == "number": if isinstance(item, int | float): transformed_result[parameter.name].append(item) elif isinstance(item, str): try: - if '.' in item: + if "." in item: transformed_result[parameter.name].append(float(item)) else: transformed_result[parameter.name].append(int(item)) except ValueError: pass - elif nested_type == 'string': + elif nested_type == "string": if isinstance(item, str): transformed_result[parameter.name].append(item) - elif nested_type == 'object': + elif nested_type == "object": if isinstance(item, dict): transformed_result[parameter.name].append(item) if parameter.name not in transformed_result: - if parameter.type == 'number': + if parameter.type == "number": transformed_result[parameter.name] = 0 - elif parameter.type == 'bool': + elif parameter.type == "bool": transformed_result[parameter.name] = False - elif parameter.type in ['string', 'select']: - transformed_result[parameter.name] = '' - elif parameter.type.startswith('array'): + elif parameter.type in ["string", "select"]: + transformed_result[parameter.name] = "" + elif parameter.type.startswith("array"): transformed_result[parameter.name] = [] return transformed_result @@ -514,24 +516,24 @@ class ParameterExtractorNode(LLMNode): """ stack = [] for i, c in enumerate(text): - if c == '{' or c == '[': + if c == "{" or c == "[": stack.append(c) - elif c == '}' or c == ']': + elif c == "}" or c == "]": # check if stack is empty if not stack: return text[:i] # check if the last element in stack is matching - if (c == '}' and stack[-1] == '{') or (c == ']' and stack[-1] == '['): + if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["): stack.pop() if not stack: - return text[:i + 1] + return text[: i + 1] else: return text[:i] return None # extract json from the text for idx in range(len(result)): - if result[idx] == '{' or result[idx] == '[': + if result[idx] == "{" or result[idx] == "[": json_str = extract_json(result[idx:]) if json_str: try: @@ -554,12 +556,12 @@ class ParameterExtractorNode(LLMNode): """ result = {} for parameter in data.parameters: - if parameter.type == 'number': + if parameter.type == "number": result[parameter.name] = 0 - elif parameter.type == 'bool': + elif parameter.type == "bool": result[parameter.name] = False - elif parameter.type in ['string', 'select']: - result[parameter.name] = '' + elif parameter.type in ["string", "select"]: + result[parameter.name] = "" return result @@ -575,71 +577,76 @@ class ParameterExtractorNode(LLMNode): return variable_template_parser.format(inputs) - def _get_function_calling_prompt_template(self, node_data: ParameterExtractorNodeData, query: str, - variable_pool: VariablePool, - memory: Optional[TokenBufferMemory], - max_token_limit: int = 2000) \ - -> list[ChatModelMessage]: + def _get_function_calling_prompt_template( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + memory: Optional[TokenBufferMemory], + max_token_limit: int = 2000, + ) -> list[ChatModelMessage]: model_mode = ModelMode.value_of(node_data.model.mode) input_text = query - memory_str = '' - instruction = self._render_instruction(node_data.instruction or '', variable_pool) + memory_str = "" + instruction = self._render_instruction(node_data.instruction or "", variable_pool) if memory: - memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size) + memory_str = memory.get_history_prompt_text( + max_token_limit=max_token_limit, message_limit=node_data.memory.window.size + ) if model_mode == ModelMode.CHAT: system_prompt_messages = ChatModelMessage( role=PromptMessageRole.SYSTEM, - text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction) - ) - user_prompt_message = ChatModelMessage( - role=PromptMessageRole.USER, - text=input_text + text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), ) + user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] else: raise ValueError(f"Model mode {model_mode} not support.") - def _get_prompt_engineering_prompt_template(self, node_data: ParameterExtractorNodeData, query: str, - variable_pool: VariablePool, - memory: Optional[TokenBufferMemory], - max_token_limit: int = 2000) \ - -> list[ChatModelMessage]: - + def _get_prompt_engineering_prompt_template( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + memory: Optional[TokenBufferMemory], + max_token_limit: int = 2000, + ) -> list[ChatModelMessage]: model_mode = ModelMode.value_of(node_data.model.mode) input_text = query - memory_str = '' - instruction = self._render_instruction(node_data.instruction or '', variable_pool) + memory_str = "" + instruction = self._render_instruction(node_data.instruction or "", variable_pool) if memory: - memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size) + memory_str = memory.get_history_prompt_text( + max_token_limit=max_token_limit, message_limit=node_data.memory.window.size + ) if model_mode == ModelMode.CHAT: system_prompt_messages = ChatModelMessage( role=PromptMessageRole.SYSTEM, - text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction) - ) - user_prompt_message = ChatModelMessage( - role=PromptMessageRole.USER, - text=input_text + text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), ) + user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] elif model_mode == ModelMode.COMPLETION: return CompletionModelPromptTemplate( - text=COMPLETION_GENERATE_JSON_PROMPT.format(histories=memory_str, - text=input_text, - instruction=instruction) - .replace('{γγγ', '') - .replace('}γγγ', '') + text=COMPLETION_GENERATE_JSON_PROMPT.format( + histories=memory_str, text=input_text, instruction=instruction + ) + .replace("{γγγ", "") + .replace("}γγγ", "") ) else: raise ValueError(f"Model mode {model_mode} not support.") - def _calculate_rest_token(self, node_data: ParameterExtractorNodeData, query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - context: Optional[str]) -> int: + def _calculate_rest_token( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + context: Optional[str], + ) -> int: prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) model_instance, model_config = self._fetch_model_config(node_data.model) @@ -659,12 +666,12 @@ class ParameterExtractorNode(LLMNode): prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], context=context, memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) rest_tokens = 2000 @@ -673,26 +680,28 @@ class ParameterExtractorNode(LLMNode): model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) - curr_message_tokens = model_type_instance.get_num_tokens( - model_config.model, - model_config.credentials, - prompt_messages - ) + 1000 # add 1000 to ensure tool call messages + curr_message_tokens = ( + model_type_instance.get_num_tokens(model_config.model, model_config.credentials, prompt_messages) + 1000 + ) # add 1000 to ensure tool call messages max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens rest_tokens = max(rest_tokens, 0) return rest_tokens - def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ - ModelInstance, ModelConfigWithCredentialsEntity]: + def _fetch_model_config( + self, node_data_model: ModelConfig + ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config. """ @@ -703,10 +712,7 @@ class ParameterExtractorNode(LLMNode): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: ParameterExtractorNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: ParameterExtractorNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -715,17 +721,13 @@ class ParameterExtractorNode(LLMNode): :param node_data: node data :return: """ - variable_mapping = { - 'query': node_data.query - } + variable_mapping = {"query": node_data.query} if node_data.instruction: variable_template_parser = VariableTemplateParser(template=node_data.instruction) for selector in variable_template_parser.extract_variable_selectors(): variable_mapping[selector.variable] = selector.value_selector - variable_mapping = { - node_id + '.' + key: value for key, value in variable_mapping.items() - } + variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} return variable_mapping diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/core/workflow/nodes/parameter_extractor/prompts.py index 499c58d505..c63fded4d0 100644 --- a/api/core/workflow/nodes/parameter_extractor/prompts.py +++ b/api/core/workflow/nodes/parameter_extractor/prompts.py @@ -1,4 +1,4 @@ -FUNCTION_CALLING_EXTRACTOR_NAME = 'extract_parameters' +FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters" FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy. ### Task @@ -35,61 +35,48 @@ FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information fr """ -FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [{ - 'user': { - 'query': 'What is the weather today in SF?', - 'function': { - 'name': FUNCTION_CALLING_EXTRACTOR_NAME, - 'parameters': { - 'type': 'object', - 'properties': { - 'location': { - 'type': 'string', - 'description': 'The location to get the weather information', - 'required': True +FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [ + { + "user": { + "query": "What is the weather today in SF?", + "function": { + "name": FUNCTION_CALLING_EXTRACTOR_NAME, + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the weather information", + "required": True, + }, }, + "required": ["location"], }, - 'required': ['location'] - } - } + }, + }, + "assistant": { + "text": "I need always call the function with the correct parameters. in this case, I need to call the function with the location parameter.", + "function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"location": "San Francisco"}}, + }, }, - 'assistant': { - 'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the location parameter.', - 'function_call' : { - 'name': FUNCTION_CALLING_EXTRACTOR_NAME, - 'parameters': { - 'location': 'San Francisco' - } - } - } -}, { - 'user': { - 'query': 'I want to eat some apple pie.', - 'function': { - 'name': FUNCTION_CALLING_EXTRACTOR_NAME, - 'parameters': { - 'type': 'object', - 'properties': { - 'food': { - 'type': 'string', - 'description': 'The food to eat', - 'required': True - } + { + "user": { + "query": "I want to eat some apple pie.", + "function": { + "name": FUNCTION_CALLING_EXTRACTOR_NAME, + "parameters": { + "type": "object", + "properties": {"food": {"type": "string", "description": "The food to eat", "required": True}}, + "required": ["food"], }, - 'required': ['food'] - } - } + }, + }, + "assistant": { + "text": "I need always call the function with the correct parameters. in this case, I need to call the function with the food parameter.", + "function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"food": "apple pie"}}, + }, }, - 'assistant': { - 'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the food parameter.', - 'function_call' : { - 'name': FUNCTION_CALLING_EXTRACTOR_NAME, - 'parameters': { - 'food': 'apple pie' - } - } - } -}] +] COMPLETION_GENERATE_JSON_PROMPT = """### Instructions: Some extra information are provided below, I should always follow the instructions as possible as I can. @@ -161,46 +148,33 @@ Inside XML tags, there is a text that you should convert to a JSON """ -CHAT_EXAMPLE = [{ - 'user': { - 'query': 'What is the weather today in SF?', - 'json': { - 'type': 'object', - 'properties': { - 'location': { - 'type': 'string', - 'description': 'The location to get the weather information', - 'required': True - } +CHAT_EXAMPLE = [ + { + "user": { + "query": "What is the weather today in SF?", + "json": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the weather information", + "required": True, + } + }, + "required": ["location"], }, - 'required': ['location'] - } + }, + "assistant": {"text": "I need to output a valid JSON object.", "json": {"location": "San Francisco"}}, }, - 'assistant': { - 'text': 'I need to output a valid JSON object.', - 'json': { - 'location': 'San Francisco' - } - } -}, { - 'user': { - 'query': 'I want to eat some apple pie.', - 'json': { - 'type': 'object', - 'properties': { - 'food': { - 'type': 'string', - 'description': 'The food to eat', - 'required': True - } + { + "user": { + "query": "I want to eat some apple pie.", + "json": { + "type": "object", + "properties": {"food": {"type": "string", "description": "The food to eat", "required": True}}, + "required": ["food"], }, - 'required': ['food'] - } + }, + "assistant": {"text": "I need to output a valid JSON object.", "json": {"result": "apple pie"}}, }, - 'assistant': { - 'text': 'I need to output a valid JSON object.', - 'json': { - 'result': 'apple pie' - } - } -}] \ No newline at end of file +] diff --git a/api/core/workflow/nodes/question_classifier/entities.py b/api/core/workflow/nodes/question_classifier/entities.py index c0b0a8b696..40f7ce7582 100644 --- a/api/core/workflow/nodes/question_classifier/entities.py +++ b/api/core/workflow/nodes/question_classifier/entities.py @@ -8,8 +8,9 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData class ModelConfig(BaseModel): """ - Model Config. + Model Config. """ + provider: str name: str mode: str @@ -20,6 +21,7 @@ class ClassConfig(BaseModel): """ Class Config. """ + id: str name: str @@ -28,8 +30,9 @@ class QuestionClassifierNodeData(BaseNodeData): """ Knowledge retrieval Node Data. """ + query_variable_selector: list[str] - type: str = 'question-classifier' + type: str = "question-classifier" model: ModelConfig classes: list[ClassConfig] instruction: Optional[str] = None diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index ecab8db9b6..d860f848ec 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -45,34 +45,25 @@ class QuestionClassifierNode(LLMNode): # extract variables variable = variable_pool.get(node_data.query_variable_selector) query = variable.value if variable else None - variables = { - 'query': query - } + variables = {"query": query} # fetch model config model_instance, model_config = self._fetch_model_config(node_data.model) # fetch memory memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) # fetch instruction - instruction = self._format_instruction(node_data.instruction, variable_pool) if node_data.instruction else '' + instruction = self._format_instruction(node_data.instruction, variable_pool) if node_data.instruction else "" node_data.instruction = instruction # fetch prompt messages prompt_messages, stop = self._fetch_prompt( - node_data=node_data, - context='', - query=query, - memory=memory, - model_config=model_config + node_data=node_data, context="", query=query, memory=memory, model_config=model_config ) # handle invoke result generator = self._invoke_llm( - node_data_model=node_data.model, - model_instance=model_instance, - prompt_messages=prompt_messages, - stop=stop + node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop ) - result_text = '' + result_text = "" usage = LLMUsage.empty_usage() finish_reason = None for event in generator: @@ -87,8 +78,8 @@ class QuestionClassifierNode(LLMNode): try: result_text_json = parse_and_check_json_markdown(result_text, []) # result_text_json = json.loads(result_text.strip('```JSON\n')) - if 'category_name' in result_text_json and 'category_id' in result_text_json: - category_id_result = result_text_json['category_id'] + if "category_name" in result_text_json and "category_id" in result_text_json: + category_id_result = result_text_json["category_id"] classes = node_data.classes classes_map = {class_.id: class_.name for class_ in classes} category_ids = [_class.id for _class in classes] @@ -100,17 +91,14 @@ class QuestionClassifierNode(LLMNode): logging.error(f"Failed to parse result text: {result_text}") try: process_data = { - 'model_mode': model_config.mode, - 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, - prompt_messages=prompt_messages + "model_mode": model_config.mode, + "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, prompt_messages=prompt_messages ), - 'usage': jsonable_encoder(usage), - 'finish_reason': finish_reason - } - outputs = { - 'class_name': category_name + "usage": jsonable_encoder(usage), + "finish_reason": finish_reason, } + outputs = {"class_name": category_name} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -121,9 +109,9 @@ class QuestionClassifierNode(LLMNode): metadata={ NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency + NodeRunMetadataKey.CURRENCY: usage.currency, }, - llm_usage=usage + llm_usage=usage, ) except ValueError as e: @@ -134,17 +122,14 @@ class QuestionClassifierNode(LLMNode): metadata={ NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency + NodeRunMetadataKey.CURRENCY: usage.currency, }, - llm_usage=usage + llm_usage=usage, ) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: QuestionClassifierNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: QuestionClassifierNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -153,7 +138,7 @@ class QuestionClassifierNode(LLMNode): :param node_data: node data :return: """ - variable_mapping = {'query': node_data.query_variable_selector} + variable_mapping = {"query": node_data.query_variable_selector} variable_selectors = [] if node_data.instruction: variable_template_parser = VariableTemplateParser(template=node_data.instruction) @@ -161,10 +146,8 @@ class QuestionClassifierNode(LLMNode): for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector - variable_mapping = { - node_id + '.' + key: value for key, value in variable_mapping.items() - } - + variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} + return variable_mapping @classmethod @@ -174,19 +157,16 @@ class QuestionClassifierNode(LLMNode): :param filters: filter by node config parameters. :return: """ - return { - "type": "question-classifier", - "config": { - "instructions": "" - } - } + return {"type": "question-classifier", "config": {"instructions": ""}} - def _fetch_prompt(self, node_data: QuestionClassifierNodeData, - query: str, - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def _fetch_prompt( + self, + node_data: QuestionClassifierNodeData, + query: str, + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: """ Fetch prompt :param node_data: node data @@ -202,118 +182,122 @@ class QuestionClassifierNode(LLMNode): prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], context=context, memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) stop = model_config.stop return prompt_messages, stop - def _calculate_rest_token(self, node_data: QuestionClassifierNodeData, query: str, - model_config: ModelConfigWithCredentialsEntity, - context: Optional[str]) -> int: + def _calculate_rest_token( + self, + node_data: QuestionClassifierNodeData, + query: str, + model_config: ModelConfigWithCredentialsEntity, + context: Optional[str], + ) -> int: prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) prompt_template = self._get_prompt_template(node_data, query, None, 2000) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], context=context, memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) rest_tokens = 2000 model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) - curr_message_tokens = model_instance.get_llm_num_tokens( - prompt_messages - ) + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens rest_tokens = max(rest_tokens, 0) return rest_tokens - def _get_prompt_template(self, node_data: QuestionClassifierNodeData, query: str, - memory: Optional[TokenBufferMemory], - max_token_limit: int = 2000) \ - -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]: + def _get_prompt_template( + self, + node_data: QuestionClassifierNodeData, + query: str, + memory: Optional[TokenBufferMemory], + max_token_limit: int = 2000, + ) -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]: model_mode = ModelMode.value_of(node_data.model.mode) classes = node_data.classes categories = [] for class_ in classes: - category = { - 'category_id': class_.id, - 'category_name': class_.name - } + category = {"category_id": class_.id, "category_name": class_.name} categories.append(category) - instruction = node_data.instruction if node_data.instruction else '' + instruction = node_data.instruction if node_data.instruction else "" input_text = query - memory_str = '' + memory_str = "" if memory: - memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size) + memory_str = memory.get_history_prompt_text( + max_token_limit=max_token_limit, message_limit=node_data.memory.window.size + ) prompt_messages = [] if model_mode == ModelMode.CHAT: system_prompt_messages = ChatModelMessage( - role=PromptMessageRole.SYSTEM, - text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) + role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) ) prompt_messages.append(system_prompt_messages) user_prompt_message_1 = ChatModelMessage( - role=PromptMessageRole.USER, - text=QUESTION_CLASSIFIER_USER_PROMPT_1 + role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_1 ) prompt_messages.append(user_prompt_message_1) assistant_prompt_message_1 = ChatModelMessage( - role=PromptMessageRole.ASSISTANT, - text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 + role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 ) prompt_messages.append(assistant_prompt_message_1) user_prompt_message_2 = ChatModelMessage( - role=PromptMessageRole.USER, - text=QUESTION_CLASSIFIER_USER_PROMPT_2 + role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2 ) prompt_messages.append(user_prompt_message_2) assistant_prompt_message_2 = ChatModelMessage( - role=PromptMessageRole.ASSISTANT, - text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 + role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 ) prompt_messages.append(assistant_prompt_message_2) user_prompt_message_3 = ChatModelMessage( role=PromptMessageRole.USER, - text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(input_text=input_text, - categories=json.dumps(categories, ensure_ascii=False), - classification_instructions=instruction) + text=QUESTION_CLASSIFIER_USER_PROMPT_3.format( + input_text=input_text, + categories=json.dumps(categories, ensure_ascii=False), + classification_instructions=instruction, + ), ) prompt_messages.append(user_prompt_message_3) return prompt_messages elif model_mode == ModelMode.COMPLETION: return CompletionModelPromptTemplate( - text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str, - input_text=input_text, - categories=json.dumps(categories), - classification_instructions=instruction, - ensure_ascii=False) + text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format( + histories=memory_str, + input_text=input_text, + categories=json.dumps(categories), + classification_instructions=instruction, + ensure_ascii=False, + ) ) else: @@ -329,14 +313,12 @@ class QuestionClassifierNode(LLMNode): variable = variable_pool.get(variable_selector.value_selector) variable_value = variable.value if variable else None if variable_value is None: - raise ValueError(f'Variable {variable_selector.variable} not found') + raise ValueError(f"Variable {variable_selector.variable} not found") inputs[variable_selector.variable] = variable_value prompt_template = PromptTemplateParser(template=instruction, with_variable_tmpl=True) prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - instruction = prompt_template.format( - prompt_inputs - ) + instruction = prompt_template.format(prompt_inputs) return instruction diff --git a/api/core/workflow/nodes/question_classifier/template_prompts.py b/api/core/workflow/nodes/question_classifier/template_prompts.py index e0de148cc2..581f986922 100644 --- a/api/core/workflow/nodes/question_classifier/template_prompts.py +++ b/api/core/workflow/nodes/question_classifier/template_prompts.py @@ -1,5 +1,3 @@ - - QUESTION_CLASSIFIER_SYSTEM_PROMPT = """ ### Job Description', You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories. diff --git a/api/core/workflow/nodes/start/entities.py b/api/core/workflow/nodes/start/entities.py index b81ce15bd7..11d2ebe5dd 100644 --- a/api/core/workflow/nodes/start/entities.py +++ b/api/core/workflow/nodes/start/entities.py @@ -10,4 +10,5 @@ class StartNodeData(BaseNodeData): """ Start Node Data """ + variables: Sequence[VariableEntity] = Field(default_factory=list) diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 69cdec6a92..96c887c58d 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,4 +1,3 @@ - from collections.abc import Mapping, Sequence from typing import Any @@ -22,20 +21,13 @@ class StartNode(BaseNode): system_inputs = self.graph_runtime_state.variable_pool.system_variables for var in system_inputs: - node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var] + node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=node_inputs, - outputs=node_inputs - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: StartNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: StartNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/template_transform/entities.py b/api/core/workflow/nodes/template_transform/entities.py index d9099a8118..e934d69fa3 100644 --- a/api/core/workflow/nodes/template_transform/entities.py +++ b/api/core/workflow/nodes/template_transform/entities.py @@ -1,5 +1,3 @@ - - from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector @@ -8,5 +6,6 @@ class TemplateTransformNodeData(BaseNodeData): """ Code Node Data. """ + variables: list[VariableSelector] - template: str \ No newline at end of file + template: str diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index b14a394a0a..2829144ead 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -8,7 +8,7 @@ from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData from models.workflow import WorkflowNodeExecutionStatus -MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get('TEMPLATE_TRANSFORM_MAX_LENGTH', '80000')) +MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000")) class TemplateTransformNode(BaseNode): @@ -24,15 +24,7 @@ class TemplateTransformNode(BaseNode): """ return { "type": "template-transform", - "config": { - "variables": [ - { - "variable": "arg1", - "value_selector": [] - } - ], - "template": "{{ arg1 }}" - } + "config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"}, } def _run(self) -> NodeRunResult: @@ -51,38 +43,25 @@ class TemplateTransformNode(BaseNode): # Run code try: result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, - code=node_data.template, - inputs=variables + language=CodeLanguage.JINJA2, code=node_data.template, inputs=variables ) except CodeExecutionException as e: - return NodeRunResult( - inputs=variables, - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e) - ) + return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) - if len(result['result']) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: + if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: return NodeRunResult( inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, - error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters" + error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters", ) return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - outputs={ - 'output': result['result'] - } + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]} ) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: TemplateTransformNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -92,5 +71,6 @@ class TemplateTransformNode(BaseNode): :return: """ return { - node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables + node_id + "." + variable_selector.variable: variable_selector.value_selector + for variable_selector in node_data.variables } diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 5da5cd0727..28fbf789fd 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -8,46 +8,47 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData class ToolEntity(BaseModel): provider_id: str - provider_type: Literal['builtin', 'api', 'workflow'] - provider_name: str # redundancy + provider_type: Literal["builtin", "api", "workflow"] + provider_name: str # redundancy tool_name: str - tool_label: str # redundancy + tool_label: str # redundancy tool_configurations: dict[str, Any] - @field_validator('tool_configurations', mode='before') + @field_validator("tool_configurations", mode="before") @classmethod def validate_tool_configurations(cls, value, values: ValidationInfo): if not isinstance(value, dict): - raise ValueError('tool_configurations must be a dictionary') - - for key in values.data.get('tool_configurations', {}).keys(): - value = values.data.get('tool_configurations', {}).get(key) + raise ValueError("tool_configurations must be a dictionary") + + for key in values.data.get("tool_configurations", {}).keys(): + value = values.data.get("tool_configurations", {}).get(key) if not isinstance(value, str | int | float | bool): - raise ValueError(f'{key} must be a string') - + raise ValueError(f"{key} must be a string") + return value + class ToolNodeData(BaseNodeData, ToolEntity): class ToolInput(BaseModel): # TODO: check this type value: Union[Any, list[str]] - type: Literal['mixed', 'variable', 'constant'] + type: Literal["mixed", "variable", "constant"] - @field_validator('type', mode='before') + @field_validator("type", mode="before") @classmethod def check_type(cls, value, validation_info: ValidationInfo): typ = value - value = validation_info.data.get('value') - if typ == 'mixed' and not isinstance(value, str): - raise ValueError('value must be a string') - elif typ == 'variable': + value = validation_info.data.get("value") + if typ == "mixed" and not isinstance(value, str): + raise ValueError("value must be a string") + elif typ == "variable": if not isinstance(value, list): - raise ValueError('value must be a list') + raise ValueError("value must be a list") for val in value: if not isinstance(val, str): - raise ValueError('value must be a list of strings') - elif typ == 'constant' and not isinstance(value, str | int | float | bool): - raise ValueError('value must be a string, int, float, or bool') + raise ValueError("value must be a list of strings") + elif typ == "constant" and not isinstance(value, str | int | float | bool): + raise ValueError("value must be a string, int, float, or bool") return typ """ diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index feedeb6dad..e55adfc1f4 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -34,10 +34,7 @@ class ToolNode(BaseNode): node_data = cast(ToolNodeData, self.node_data) # fetch tool icon - tool_info = { - 'provider_type': node_data.provider_type, - 'provider_id': node_data.provider_id - } + tool_info = {"provider_type": node_data.provider_type, "provider_id": node_data.provider_id} # get tool runtime try: @@ -48,16 +45,21 @@ class ToolNode(BaseNode): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs={}, - metadata={ - NodeRunMetadataKey.TOOL_INFO: tool_info - }, - error=f'Failed to get tool runtime: {str(e)}' + metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, + error=f"Failed to get tool runtime: {str(e)}", ) # get parameters tool_parameters = tool_runtime.get_runtime_parameters() or [] - parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data) - parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data, for_log=True) + parameters = self._generate_parameters( + tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data + ) + parameters_for_log = self._generate_parameters( + tool_parameters=tool_parameters, + variable_pool=self.graph_runtime_state.variable_pool, + node_data=node_data, + for_log=True, + ) try: messages = ToolEngine.workflow_invoke( @@ -72,10 +74,8 @@ class ToolNode(BaseNode): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, - metadata={ - NodeRunMetadataKey.TOOL_INFO: tool_info - }, - error=f'Failed to invoke tool: {str(e)}', + metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, + error=f"Failed to invoke tool: {str(e)}", ) # convert tool messages @@ -83,15 +83,9 @@ class ToolNode(BaseNode): return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - 'text': plain_text, - 'files': files, - 'json': json - }, - metadata={ - NodeRunMetadataKey.TOOL_INFO: tool_info - }, - inputs=parameters_for_log + outputs={"text": plain_text, "files": files, "json": json}, + metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, + inputs=parameters_for_log, ) def _generate_parameters( @@ -123,12 +117,10 @@ class ToolNode(BaseNode): result[parameter_name] = None continue if parameter.type == ToolParameter.ToolParameterType.FILE: - result[parameter_name] = [ - v.to_dict() for v in self._fetch_files(variable_pool) - ] + result[parameter_name] = [v.to_dict() for v in self._fetch_files(variable_pool)] else: tool_input = node_data.tool_parameters[parameter_name] - if tool_input.type == 'variable': + if tool_input.type == "variable": # TODO: check if the variable exists in the variable pool parameter_value = variable_pool.get(tool_input.value).value else: @@ -142,12 +134,11 @@ class ToolNode(BaseNode): return result def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: - variable = variable_pool.get(['sys', SystemVariableKey.FILES.value]) + variable = variable_pool.get(["sys", SystemVariableKey.FILES.value]) assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) return list(variable.value) if variable else [] - def _convert_tool_messages(self, messages: list[ToolInvokeMessage])\ - -> tuple[str, list[FileVar], list[dict]]: + def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar], list[dict]]: """ Convert ToolInvokeMessages into tuple[plain_text, files] """ @@ -172,38 +163,44 @@ class ToolNode(BaseNode): result = [] for response in tool_response: - if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ - response.type == ToolInvokeMessage.MessageType.IMAGE: + if ( + response.type == ToolInvokeMessage.MessageType.IMAGE_LINK + or response.type == ToolInvokeMessage.MessageType.IMAGE + ): url = response.message ext = path.splitext(url)[1] - mimetype = response.meta.get('mime_type', 'image/jpeg') - filename = response.save_as or url.split('/')[-1] - transfer_method = response.meta.get('transfer_method', FileTransferMethod.TOOL_FILE) + mimetype = response.meta.get("mime_type", "image/jpeg") + filename = response.save_as or url.split("/")[-1] + transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) # get tool file id - tool_file_id = url.split('/')[-1].split('.')[0] - result.append(FileVar( - tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=transfer_method, - url=url, - related_id=tool_file_id, - filename=filename, - extension=ext, - mime_type=mimetype, - )) + tool_file_id = url.split("/")[-1].split(".")[0] + result.append( + FileVar( + tenant_id=self.tenant_id, + type=FileType.IMAGE, + transfer_method=transfer_method, + url=url, + related_id=tool_file_id, + filename=filename, + extension=ext, + mime_type=mimetype, + ) + ) elif response.type == ToolInvokeMessage.MessageType.BLOB: # get tool file id - tool_file_id = response.message.split('/')[-1].split('.')[0] - result.append(FileVar( - tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.TOOL_FILE, - related_id=tool_file_id, - filename=response.save_as, - extension=path.splitext(response.save_as)[1], - mime_type=response.meta.get('mime_type', 'application/octet-stream'), - )) + tool_file_id = response.message.split("/")[-1].split(".")[0] + result.append( + FileVar( + tenant_id=self.tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id=tool_file_id, + filename=response.save_as, + extension=path.splitext(response.save_as)[1], + mime_type=response.meta.get("mime_type", "application/octet-stream"), + ) + ) elif response.type == ToolInvokeMessage.MessageType.LINK: pass # TODO: @@ -213,21 +210,23 @@ class ToolNode(BaseNode): """ Extract tool response text """ - return '\n'.join([ - f'{message.message}' if message.type == ToolInvokeMessage.MessageType.TEXT else - f'Link: {message.message}' if message.type == ToolInvokeMessage.MessageType.LINK else '' - for message in tool_response - ]) + return "\n".join( + [ + f"{message.message}" + if message.type == ToolInvokeMessage.MessageType.TEXT + else f"Link: {message.message}" + if message.type == ToolInvokeMessage.MessageType.LINK + else "" + for message in tool_response + ] + ) def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]) -> list[dict]: return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON] @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: ToolNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: ToolNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -239,17 +238,15 @@ class ToolNode(BaseNode): result = {} for parameter_name in node_data.tool_parameters: input = node_data.tool_parameters[parameter_name] - if input.type == 'mixed': + if input.type == "mixed": selectors = VariableTemplateParser(input.value).extract_variable_selectors() for selector in selectors: result[selector.variable] = selector.value_selector - elif input.type == 'variable': + elif input.type == "variable": result[parameter_name] = input.value - elif input.type == 'constant': + elif input.type == "constant": pass - result = { - node_id + '.' + key: value for key, value in result.items() - } + result = {node_id + "." + key: value for key, value in result.items()} return result diff --git a/api/core/workflow/nodes/variable_aggregator/entities.py b/api/core/workflow/nodes/variable_aggregator/entities.py index e5de38dc0f..eb893a04e3 100644 --- a/api/core/workflow/nodes/variable_aggregator/entities.py +++ b/api/core/workflow/nodes/variable_aggregator/entities.py @@ -1,5 +1,3 @@ - - from typing import Literal, Optional from pydantic import BaseModel @@ -11,23 +9,27 @@ class AdvancedSettings(BaseModel): """ Advanced setting. """ + group_enabled: bool class Group(BaseModel): """ Group. """ - output_type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]'] + + output_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] variables: list[list[str]] group_name: str groups: list[Group] + class VariableAssignerNodeData(BaseNodeData): """ Knowledge retrieval Node Data. """ - type: str = 'variable-assigner' + + type: str = "variable-assigner" output_type: str variables: list[list[str]] advanced_settings: Optional[AdvancedSettings] = None diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 6944d9e82d..f03eae257a 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -21,13 +21,9 @@ class VariableAggregatorNode(BaseNode): for selector in node_data.variables: variable = self.graph_runtime_state.variable_pool.get_any(selector) if variable is not None: - outputs = { - "output": variable - } + outputs = {"output": variable} - inputs = { - '.'.join(selector[1:]): variable - } + inputs = {".".join(selector[1:]): variable} break else: for group in node_data.advanced_settings.groups: @@ -35,24 +31,15 @@ class VariableAggregatorNode(BaseNode): variable = self.graph_runtime_state.variable_pool.get_any(selector) if variable is not None: - outputs[group.group_name] = { - 'output': variable - } - inputs['.'.join(selector[1:])] = variable + outputs[group.group_name] = {"output": variable} + inputs[".".join(selector[1:])] = variable break - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs=outputs, - inputs=inputs - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs) @classmethod def _extract_variable_selector_to_variable_mapping( - cls, - graph_config: Mapping[str, Any], - node_id: str, - node_data: VariableAssignerNodeData + cls, graph_config: Mapping[str, Any], node_id: str, node_data: VariableAssignerNodeData ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/variable_assigner/__init__.py b/api/core/workflow/nodes/variable_assigner/__init__.py index d791d51523..83da4bdc79 100644 --- a/api/core/workflow/nodes/variable_assigner/__init__.py +++ b/api/core/workflow/nodes/variable_assigner/__init__.py @@ -2,7 +2,7 @@ from .node import VariableAssignerNode from .node_data import VariableAssignerData, WriteMode __all__ = [ - 'VariableAssignerNode', - 'VariableAssignerData', - 'WriteMode', + "VariableAssignerNode", + "VariableAssignerData", + "WriteMode", ] diff --git a/api/core/workflow/nodes/variable_assigner/node.py b/api/core/workflow/nodes/variable_assigner/node.py index b2f32c6aaa..3969299795 100644 --- a/api/core/workflow/nodes/variable_assigner/node.py +++ b/api/core/workflow/nodes/variable_assigner/node.py @@ -24,43 +24,43 @@ class VariableAssignerNode(BaseNode): # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector) if not isinstance(original_variable, Variable): - raise VariableAssignerNodeError('assigned variable not found') + raise VariableAssignerNodeError("assigned variable not found") match data.write_mode: case WriteMode.OVER_WRITE: income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector) if not income_value: - raise VariableAssignerNodeError('input value not found') - updated_variable = original_variable.model_copy(update={'value': income_value.value}) + raise VariableAssignerNodeError("input value not found") + updated_variable = original_variable.model_copy(update={"value": income_value.value}) case WriteMode.APPEND: income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector) if not income_value: - raise VariableAssignerNodeError('input value not found') + raise VariableAssignerNodeError("input value not found") updated_value = original_variable.value + [income_value.value] - updated_variable = original_variable.model_copy(update={'value': updated_value}) + updated_variable = original_variable.model_copy(update={"value": updated_value}) case WriteMode.CLEAR: income_value = get_zero_value(original_variable.value_type) - updated_variable = original_variable.model_copy(update={'value': income_value.to_object()}) + updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) case _: - raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}') + raise VariableAssignerNodeError(f"unsupported write mode: {data.write_mode}") # Over write the variable. self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable) # TODO: Move database operation to the pipeline. # Update conversation variable. - conversation_id = self.graph_runtime_state.variable_pool.get(['sys', 'conversation_id']) + conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) if not conversation_id: - raise VariableAssignerNodeError('conversation_id not found') + raise VariableAssignerNodeError("conversation_id not found") update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={ - 'value': income_value.to_object(), + "value": income_value.to_object(), }, ) @@ -72,7 +72,7 @@ def update_conversation_variable(conversation_id: str, variable: Variable): with Session(db.engine) as session: row = session.scalar(stmt) if not row: - raise VariableAssignerNodeError('conversation variable not found in the database') + raise VariableAssignerNodeError("conversation variable not found in the database") row.data = variable.model_dump_json() session.commit() @@ -84,8 +84,8 @@ def get_zero_value(t: SegmentType): case SegmentType.OBJECT: return factory.build_segment({}) case SegmentType.STRING: - return factory.build_segment('') + return factory.build_segment("") case SegmentType.NUMBER: return factory.build_segment(0) case _: - raise VariableAssignerNodeError(f'unsupported variable type: {t}') + raise VariableAssignerNodeError(f"unsupported variable type: {t}") diff --git a/api/core/workflow/nodes/variable_assigner/node_data.py b/api/core/workflow/nodes/variable_assigner/node_data.py index b3652b6802..8ac8eadf7c 100644 --- a/api/core/workflow/nodes/variable_assigner/node_data.py +++ b/api/core/workflow/nodes/variable_assigner/node_data.py @@ -6,14 +6,14 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData class WriteMode(str, Enum): - OVER_WRITE = 'over-write' - APPEND = 'append' - CLEAR = 'clear' + OVER_WRITE = "over-write" + APPEND = "append" + CLEAR = "clear" class VariableAssignerData(BaseNodeData): - title: str = 'Variable Assigner' - desc: Optional[str] = 'Assign a value to a variable' + title: str = "Variable Assigner" + desc: Optional[str] = "Assign a value to a variable" assigned_variable_selector: Sequence[str] write_mode: WriteMode input_variable_selector: Sequence[str] diff --git a/api/core/workflow/utils/condition/entities.py b/api/core/workflow/utils/condition/entities.py index e195730a31..b8e8b881a5 100644 --- a/api/core/workflow/utils/condition/entities.py +++ b/api/core/workflow/utils/condition/entities.py @@ -7,11 +7,26 @@ class Condition(BaseModel): """ Condition entity """ + variable_selector: list[str] comparison_operator: Literal[ # for string or array - "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", + "contains", + "not contains", + "start with", + "end with", + "is", + "is not", + "empty", + "not empty", # for number - "=", "≠", ">", "<", "≥", "≤", "null", "not null" + "=", + "≠", + ">", + "<", + "≥", + "≤", + "null", + "not null", ] value: Optional[str] = None diff --git a/api/core/workflow/utils/condition/processor.py b/api/core/workflow/utils/condition/processor.py index 5ff61aab3d..395ee82478 100644 --- a/api/core/workflow/utils/condition/processor.py +++ b/api/core/workflow/utils/condition/processor.py @@ -15,9 +15,7 @@ class ConditionProcessor: index = 0 for condition in conditions: index += 1 - actual_value = variable_pool.get_any( - condition.variable_selector - ) + actual_value = variable_pool.get_any(condition.variable_selector) expected_value = None if condition.value is not None: @@ -25,9 +23,7 @@ class ConditionProcessor: variable_selectors = variable_template_parser.extract_variable_selectors() if variable_selectors: for variable_selector in variable_selectors: - value = variable_pool.get_any( - variable_selector.value_selector - ) + value = variable_pool.get_any(variable_selector.value_selector) expected_value = variable_template_parser.format({variable_selector.variable: value}) if expected_value is None: @@ -40,7 +36,7 @@ class ConditionProcessor: { "actual_value": actual_value, "expected_value": expected_value, - "comparison_operator": comparison_operator + "comparison_operator": comparison_operator, } ) @@ -50,10 +46,10 @@ class ConditionProcessor: return input_conditions, group_result def evaluate_condition( - self, - actual_value: Optional[str | int | float | dict[Any, Any] | list[Any] | FileVar | None], - comparison_operator: str, - expected_value: Optional[str] = None + self, + actual_value: Optional[str | int | float | dict[Any, Any] | list[Any] | FileVar | None], + comparison_operator: str, + expected_value: Optional[str] = None, ) -> bool: """ Evaluate condition @@ -109,7 +105,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, str | list): - raise ValueError('Invalid actual value type: string or array') + raise ValueError("Invalid actual value type: string or array") if expected_value not in actual_value: return False @@ -126,7 +122,7 @@ class ConditionProcessor: return True if not isinstance(actual_value, str | list): - raise ValueError('Invalid actual value type: string or array') + raise ValueError("Invalid actual value type: string or array") if expected_value in actual_value: return False @@ -143,7 +139,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') + raise ValueError("Invalid actual value type: string") if not actual_value.startswith(expected_value): return False @@ -160,7 +156,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') + raise ValueError("Invalid actual value type: string") if not actual_value.endswith(expected_value): return False @@ -177,7 +173,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') + raise ValueError("Invalid actual value type: string") if actual_value != expected_value: return False @@ -194,7 +190,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') + raise ValueError("Invalid actual value type: string") if actual_value == expected_value: return False @@ -231,7 +227,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') + raise ValueError("Invalid actual value type: number") if isinstance(actual_value, int): expected_value = int(expected_value) @@ -253,7 +249,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') + raise ValueError("Invalid actual value type: number") if isinstance(actual_value, int): expected_value = int(expected_value) @@ -275,7 +271,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') + raise ValueError("Invalid actual value type: number") if isinstance(actual_value, int): expected_value = int(expected_value) @@ -297,7 +293,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') + raise ValueError("Invalid actual value type: number") if isinstance(actual_value, int): expected_value = int(expected_value) @@ -308,8 +304,9 @@ class ConditionProcessor: return False return True - def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], - expected_value: str | int | float) -> bool: + def _assert_greater_than_or_equal( + self, actual_value: Optional[int | float], expected_value: str | int | float + ) -> bool: """ Assert greater than or equal :param actual_value: actual value @@ -320,7 +317,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') + raise ValueError("Invalid actual value type: number") if isinstance(actual_value, int): expected_value = int(expected_value) @@ -331,8 +328,9 @@ class ConditionProcessor: return False return True - def _assert_less_than_or_equal(self, actual_value: Optional[int | float], - expected_value: str | int | float) -> bool: + def _assert_less_than_or_equal( + self, actual_value: Optional[int | float], expected_value: str | int | float + ) -> bool: """ Assert less than or equal :param actual_value: actual value @@ -343,7 +341,7 @@ class ConditionProcessor: return False if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') + raise ValueError("Invalid actual value type: number") if isinstance(actual_value, int): expected_value = int(expected_value) diff --git a/api/core/workflow/utils/variable_template_parser.py b/api/core/workflow/utils/variable_template_parser.py index c43fde172c..fd0e48b862 100644 --- a/api/core/workflow/utils/variable_template_parser.py +++ b/api/core/workflow/utils/variable_template_parser.py @@ -5,7 +5,7 @@ from typing import Any from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_pool import VariablePool -REGEX = re.compile(r'\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}') +REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}") def parse_mixed_template(*, template: str, variable_pool: VariablePool) -> str: @@ -20,7 +20,7 @@ def parse_mixed_template(*, template: str, variable_pool: VariablePool) -> str: # e.g. ('#node_id.query.name#', ['node_id', 'query', 'name']) key_selectors = filter( lambda t: len(t[1]) >= 2, - ((key, selector.replace('#', '').split('.')) for key, selector in zip(variable_keys, variable_keys)), + ((key, selector.replace("#", "").split(".")) for key, selector in zip(variable_keys, variable_keys)), ) inputs = {key: variable_pool.get_any(selector) for key, selector in key_selectors} @@ -29,13 +29,13 @@ def parse_mixed_template(*, template: str, variable_pool: VariablePool) -> str: # return original matched string if key not found value = inputs.get(key, match.group(0)) if value is None: - value = '' + value = "" value = str(value) # remove template variables if required - return re.sub(REGEX, r'{\1}', value) + return re.sub(REGEX, r"{\1}", value) result = re.sub(REGEX, replacer, template) - result = re.sub(r'<\|.*?\|>', '', result) + result = re.sub(r"<\|.*?\|>", "", result) return result @@ -101,8 +101,8 @@ class VariableTemplateParser: """ variable_selectors = [] for variable_key in self.variable_keys: - remove_hash = variable_key.replace('#', '') - split_result = remove_hash.split('.') + remove_hash = variable_key.replace("#", "") + split_result = remove_hash.split(".") if len(split_result) < 2: continue @@ -127,7 +127,7 @@ class VariableTemplateParser: value = inputs.get(key, match.group(0)) # return original matched string if key not found if value is None: - value = '' + value = "" # convert the value to string if isinstance(value, list | dict | bool | int | float): value = str(value) @@ -136,7 +136,7 @@ class VariableTemplateParser: return VariableTemplateParser.remove_template_variables(value) prompt = re.sub(REGEX, replacer, self.template) - return re.sub(r'<\|.*?\|>', '', prompt) + return re.sub(r"<\|.*?\|>", "", prompt) @classmethod def remove_template_variables(cls, text: str): @@ -149,4 +149,4 @@ class VariableTemplateParser: Returns: The text with template variables removed. """ - return re.sub(REGEX, r'{\1}', text) + return re.sub(REGEX, r"{\1}", text) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index a359bd606e..25021935ee 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -33,19 +33,19 @@ logger = logging.getLogger(__name__) class WorkflowEntry: def __init__( - self, - tenant_id: str, - app_id: str, - workflow_id: str, - workflow_type: WorkflowType, - graph_config: Mapping[str, Any], - graph: Graph, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - call_depth: int, - variable_pool: VariablePool, - thread_pool_id: Optional[str] = None + self, + tenant_id: str, + app_id: str, + workflow_id: str, + workflow_type: WorkflowType, + graph_config: Mapping[str, Any], + graph: Graph, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + call_depth: int, + variable_pool: VariablePool, + thread_pool_id: Optional[str] = None, ) -> None: """ Init workflow entry @@ -65,7 +65,7 @@ class WorkflowEntry: # check call depth workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH if call_depth > workflow_call_max_depth: - raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) + raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth)) # init workflow run state self.graph_engine = GraphEngine( @@ -82,13 +82,13 @@ class WorkflowEntry: variable_pool=variable_pool, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, - thread_pool_id=thread_pool_id + thread_pool_id=thread_pool_id, ) def run( - self, - *, - callbacks: Sequence[WorkflowCallback], + self, + *, + callbacks: Sequence[WorkflowCallback], ) -> Generator[GraphEngineEvent, None, None]: """ :param callbacks: workflow callbacks @@ -101,9 +101,7 @@ class WorkflowEntry: for event in generator: if callbacks: for callback in callbacks: - callback.on_event( - event=event - ) + callback.on_event(event=event) yield event except GenerateTaskStoppedException: pass @@ -111,20 +109,12 @@ class WorkflowEntry: logger.exception("Unknown Error when workflow entry running") if callbacks: for callback in callbacks: - callback.on_event( - event=GraphRunFailedEvent( - error=str(e) - ) - ) + callback.on_event(event=GraphRunFailedEvent(error=str(e))) return @classmethod def single_step_run( - cls, - workflow: Workflow, - node_id: str, - user_id: str, - user_inputs: dict + cls, workflow: Workflow, node_id: str, user_id: str, user_inputs: dict ) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]: """ Single step run workflow node @@ -137,30 +127,30 @@ class WorkflowEntry: # fetch node info from workflow graph graph = workflow.graph_dict if not graph: - raise ValueError('workflow graph not found') + raise ValueError("workflow graph not found") - nodes = graph.get('nodes') + nodes = graph.get("nodes") if not nodes: - raise ValueError('nodes not found in workflow graph') + raise ValueError("nodes not found in workflow graph") # fetch node config from node id node_config = None for node in nodes: - if node.get('id') == node_id: + if node.get("id") == node_id: node_config = node break if not node_config: - raise ValueError('node id not found in workflow graph') + raise ValueError("node id not found in workflow graph") # Get node class - node_type = NodeType.value_of(node_config.get('data', {}).get('type')) + node_type = NodeType.value_of(node_config.get("data", {}).get("type")) node_cls = node_classes.get(node_type) node_cls = cast(type[BaseNode], node_cls) if not node_cls: - raise ValueError(f'Node class not found for node type {node_type}') - + raise ValueError(f"Node class not found for node type {node_type}") + # init variable pool variable_pool = VariablePool( system_variables={}, @@ -169,9 +159,7 @@ class WorkflowEntry: ) # init graph - graph = Graph.init( - graph_config=workflow.graph_dict - ) + graph = Graph.init(graph_config=workflow.graph_dict) # init workflow run state node_instance: BaseNode = node_cls( @@ -186,21 +174,17 @@ class WorkflowEntry: user_id=user_id, user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, - call_depth=0 + call_depth=0, ), graph=graph, - graph_runtime_state=GraphRuntimeState( - variable_pool=variable_pool, - start_at=time.perf_counter() - ) + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), ) try: # variable selector to variable mapping try: variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=workflow.graph_dict, - config=node_config + graph_config=workflow.graph_dict, config=node_config ) except NotImplementedError: variable_mapping = {} @@ -211,7 +195,7 @@ class WorkflowEntry: variable_pool=variable_pool, tenant_id=workflow.tenant_id, node_type=node_type, - node_data=node_instance.node_data + node_data=node_instance.node_data, ) # run node @@ -219,10 +203,7 @@ class WorkflowEntry: return node_instance, generator except Exception as e: - raise WorkflowNodeRunFailedError( - node_instance=node_instance, - error=str(e) - ) + raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) @classmethod def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]: @@ -259,21 +240,20 @@ class WorkflowEntry: variable_pool: VariablePool, tenant_id: str, node_type: NodeType, - node_data: BaseNodeData + node_data: BaseNodeData, ) -> None: for node_variable, variable_selector in variable_mapping.items(): # fetch node id and variable key from node_variable - node_variable_list = node_variable.split('.') + node_variable_list = node_variable.split(".") if len(node_variable_list) < 1: - raise ValueError(f'Invalid node variable {node_variable}') - - node_variable_key = '.'.join(node_variable_list[1:]) + raise ValueError(f"Invalid node variable {node_variable}") - if ( - node_variable_key not in user_inputs - and node_variable not in user_inputs - ) and not variable_pool.get(variable_selector): - raise ValueError(f'Variable key {node_variable} not found in user inputs.') + node_variable_key = ".".join(node_variable_list[1:]) + + if (node_variable_key not in user_inputs and node_variable not in user_inputs) and not variable_pool.get( + variable_selector + ): + raise ValueError(f"Variable key {node_variable} not found in user inputs.") # fetch variable node id from variable selector variable_node_id = variable_selector[0] @@ -294,16 +274,17 @@ class WorkflowEntry: detail = node_data.vision.configs.detail if node_data.vision.configs else None for item in input_value: - if isinstance(item, dict) and 'type' in item and item['type'] == 'image': - transfer_method = FileTransferMethod.value_of(item.get('transfer_method')) + if isinstance(item, dict) and "type" in item and item["type"] == "image": + transfer_method = FileTransferMethod.value_of(item.get("transfer_method")) file = FileVar( tenant_id=tenant_id, type=FileType.IMAGE, transfer_method=transfer_method, - url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, - related_id=item.get( - 'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, - extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None), + url=item.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None, + related_id=item.get("upload_file_id") + if transfer_method == FileTransferMethod.LOCAL_FILE + else None, + extra_config=FileExtraConfig(image_config={"detail": detail} if detail else None), ) new_value.append(file) diff --git a/api/pyproject.toml b/api/pyproject.toml index 69d1fc4ee0..16b8c32c2a 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -68,7 +68,6 @@ ignore = [ [tool.ruff.format] exclude = [ - "core/**/*.py", "models/**/*.py", "migrations/**/*", ]