Merge branch 'main' into fix/chore-fix
commit 7a2b2a04c9
.github/workflows/db-migration-test.yml
@@ -6,6 +6,7 @@ on:
       - main
     paths:
       - api/migrations/**
+      - .github/workflows/db-migration-test.yml

 concurrency:
   group: db-migration-test-${{ github.ref }}
@@ -285,8 +285,9 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
 UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
 UPLOAD_AUDIO_FILE_SIZE_LIMIT=50

-# Model Configuration
+# Model configuration
 MULTIMODAL_SEND_IMAGE_FORMAT=base64
+MULTIMODAL_SEND_VIDEO_FORMAT=base64
 PROMPT_GENERATION_MAX_TOKENS=512
 CODE_GENERATION_MAX_TOKENS=1024
@@ -324,10 +325,10 @@ UNSTRUCTURED_API_KEY=
 SSRF_PROXY_HTTP_URL=
 SSRF_PROXY_HTTPS_URL=
 SSRF_DEFAULT_MAX_RETRIES=3
-SSRF_DEFAULT_TIME_OUT=
-SSRF_DEFAULT_CONNECT_TIME_OUT=
-SSRF_DEFAULT_READ_TIME_OUT=
-SSRF_DEFAULT_WRITE_TIME_OUT=
+SSRF_DEFAULT_TIME_OUT=5
+SSRF_DEFAULT_CONNECT_TIME_OUT=5
+SSRF_DEFAULT_READ_TIME_OUT=5
+SSRF_DEFAULT_WRITE_TIME_OUT=5

 BATCH_UPLOAD_LIMIT=10
 KEYWORD_DATA_SOURCE_TYPE=database
@@ -2,7 +2,7 @@ import os

 from configs import dify_config

-if os.environ.get("DEBUG", "false").lower() != "true":
+if not dify_config.DEBUG:
     from gevent import monkey

     monkey.patch_all()
@@ -1,6 +1,8 @@
 import os

-if os.environ.get("DEBUG", "false").lower() != "true":
+from configs import dify_config
+
+if not dify_config.DEBUG:
     from gevent import monkey

     monkey.patch_all()
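Reviewer note on the two hunks above (and the matching DEBUG hunks further down): the hand-rolled env-var parsing moves behind a typed config field. A minimal sketch, assuming dify_config declares DEBUG as a plain pydantic-settings bool (the real declaration lives in the configs package and is not shown in this diff):

from pydantic_settings import BaseSettings

class AppConfig(BaseSettings):
    DEBUG: bool = False  # read from the DEBUG env var; "true"/"True"/"1" all coerce to True

config = AppConfig()
if not config.DEBUG:
    # monkey-patching stays off in debug mode so step debuggers keep working
    from gevent import monkey

    monkey.patch_all()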
@@ -329,6 +329,16 @@ class HttpConfig(BaseSettings):
         default=1 * 1024 * 1024,
     )

+    SSRF_DEFAULT_MAX_RETRIES: PositiveInt = Field(
+        description="Maximum number of retries for network requests (SSRF)",
+        default=3,
+    )
+
+    SSRF_PROXY_ALL_URL: Optional[str] = Field(
+        description="Proxy URL for HTTP or HTTPS requests to prevent Server-Side Request Forgery (SSRF)",
+        default=None,
+    )
+
     SSRF_PROXY_HTTP_URL: Optional[str] = Field(
         description="Proxy URL for HTTP requests to prevent Server-Side Request Forgery (SSRF)",
         default=None,
@@ -677,12 +687,17 @@ class IndexingConfig(BaseSettings):
     )


-class ImageFormatConfig(BaseSettings):
+class VisionFormatConfig(BaseSettings):
     MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
         description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
         default="base64",
     )

+    MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field(
+        description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64",
+        default="base64",
+    )
+

 class CeleryBeatConfig(BaseSettings):
     CELERY_BEAT_SCHEDULER_TIME: int = Field(
@@ -787,7 +802,7 @@ class FeatureConfig(
     FileAccessConfig,
     FileUploadConfig,
     HttpConfig,
-    ImageFormatConfig,
+    VisionFormatConfig,
     InnerAPIConfig,
     IndexingConfig,
     LoggingConfig,
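For orientation, a trimmed sketch of the composition pattern this mixin list relies on: each concern is its own BaseSettings class and FeatureConfig inherits them all, so every field resolves once from the environment or its default (field sets reduced to one entry each here):

from pydantic import Field
from pydantic_settings import BaseSettings

class VisionFormatConfig(BaseSettings):
    MULTIMODAL_SEND_VIDEO_FORMAT: str = Field(default="base64")

class HttpConfig(BaseSettings):
    SSRF_DEFAULT_MAX_RETRIES: int = Field(default=3)

class FeatureConfig(VisionFormatConfig, HttpConfig):
    pass

config = FeatureConfig()
print(config.MULTIMODAL_SEND_VIDEO_FORMAT, config.SSRF_DEFAULT_MAX_RETRIES)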
@@ -956,7 +956,7 @@ class DocumentRetryApi(DocumentResource):
                     raise DocumentAlreadyFinishedError()
                 retry_documents.append(document)
             except Exception as e:
-                logging.error(f"Document {document_id} retry failed: {str(e)}")
+                logging.exception(f"Document {document_id} retry failed: {str(e)}")
                 continue
         # retry document
         DocumentService.retry_document(dataset_id, retry_documents)
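The logging.error -> logging.exception swap here (repeated in many hunks below) changes what lands in the log: exception() logs at ERROR level and appends the active traceback. A minimal standalone illustration:

import logging

try:
    1 / 0
except Exception as e:
    logging.error(f"retry failed: {e}")      # one line: the message only
    logging.exception(f"retry failed: {e}")  # the message plus the full traceback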
@@ -7,7 +7,11 @@ from controllers.service_api import api
 from controllers.service_api.app.error import NotChatAppError
 from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.app.entities.app_invoke_entities import InvokeFrom
-from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
+from fields.conversation_fields import (
+    conversation_delete_fields,
+    conversation_infinite_scroll_pagination_fields,
+    simple_conversation_fields,
+)
 from libs.helper import uuid_value
 from models.model import App, AppMode, EndUser
 from services.conversation_service import ConversationService
@@ -49,7 +53,7 @@ class ConversationApi(Resource):

 class ConversationDetailApi(Resource):
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
-    @marshal_with(simple_conversation_fields)
+    @marshal_with(conversation_delete_fields)
     def delete(self, app_model: App, end_user: EndUser, c_id):
         app_mode = AppMode.value_of(app_model.mode)
         if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -58,10 +62,9 @@ class ConversationDetailApi(Resource):
         conversation_id = str(c_id)

         try:
-            ConversationService.delete(app_model, conversation_id, end_user)
+            return ConversationService.delete(app_model, conversation_id, end_user)
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
-        return {"result": "success"}, 200


 class ConversationRenameApi(Resource):
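On the @marshal_with swap above: flask-restful filters the returned value down to the declared field set, so whatever ConversationService.delete returns is shaped by conversation_delete_fields. The real field set lives in fields/conversation_fields.py and is not shown in this diff; the definition below is assumed for illustration:

from flask_restful import fields, marshal

conversation_delete_fields = {"result": fields.String}  # hypothetical definition

print(marshal({"result": "success", "internal_id": 42}, conversation_delete_fields))
# {'result': 'success'} -- undeclared keys never reach the client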
@@ -1,6 +1,5 @@
 import contextvars
 import logging
-import os
 import threading
 import uuid
 from collections.abc import Generator
@@ -10,6 +9,7 @@ from flask import Flask, current_app
 from pydantic import ValidationError

 import contexts
+from configs import dify_config
 from constants import UUID_NIL
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
@@ -328,7 +328,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             logger.exception("Validation Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         except (ValueError, InvokeError) as e:
-            if os.environ.get("DEBUG", "false").lower() == "true":
+            if dify_config.DEBUG:
                 logger.exception("Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         except Exception as e:
@@ -242,7 +242,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     start_listener_time = time.time()
                     yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
             except Exception as e:
-                logger.error(e)
+                logger.exception(e)
                 break
         if tts_publisher:
             yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
@@ -1,5 +1,4 @@
 import logging
-import os
 import threading
 import uuid
 from collections.abc import Generator
@@ -8,6 +7,7 @@ from typing import Any, Literal, Union, overload
 from flask import Flask, current_app
 from pydantic import ValidationError

+from configs import dify_config
 from constants import UUID_NIL
 from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@@ -235,7 +235,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             logger.exception("Validation Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         except (ValueError, InvokeError) as e:
-            if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
+            if dify_config.DEBUG:
                 logger.exception("Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         except Exception as e:
@@ -1,5 +1,4 @@
 import logging
-import os
 import threading
 import uuid
 from collections.abc import Generator
@@ -8,6 +7,7 @@ from typing import Any, Literal, Union, overload
 from flask import Flask, current_app
 from pydantic import ValidationError

+from configs import dify_config
 from constants import UUID_NIL
 from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@@ -237,7 +237,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             logger.exception("Validation Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         except (ValueError, InvokeError) as e:
-            if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
+            if dify_config.DEBUG:
                 logger.exception("Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         except Exception as e:
@@ -1,5 +1,4 @@
 import logging
-import os
 import threading
 import uuid
 from collections.abc import Generator
@@ -8,6 +7,7 @@ from typing import Any, Literal, Union, overload
 from flask import Flask, current_app
 from pydantic import ValidationError

+from configs import dify_config
 from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
@@ -213,7 +213,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             logger.exception("Validation Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         except (ValueError, InvokeError) as e:
-            if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
+            if dify_config.DEBUG:
                 logger.exception("Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         except Exception as e:
@@ -1,6 +1,5 @@
 import contextvars
 import logging
-import os
 import threading
 import uuid
 from collections.abc import Generator, Mapping, Sequence
@@ -10,6 +9,7 @@ from flask import Flask, current_app
 from pydantic import ValidationError

 import contexts
+from configs import dify_config
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.base_app_generator import BaseAppGenerator
 from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
@@ -273,7 +273,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
             logger.exception("Validation Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         except (ValueError, InvokeError) as e:
-            if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == "true":
+            if dify_config.DEBUG:
                 logger.exception("Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         except Exception as e:
@@ -216,7 +216,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 else:
                     yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
             except Exception as e:
-                logger.error(e)
+                logger.exception(e)
                 break
         if tts_publisher:
             yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
@@ -3,7 +3,7 @@ import base64
 from configs import dify_config
 from core.file import file_repository
 from core.helper import ssrf_proxy
-from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent
+from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent
 from extensions.ext_database import db
 from extensions.ext_storage import storage

@@ -71,6 +71,12 @@ def to_prompt_message_content(f: File, /):
             if f.extension is None:
                 raise ValueError("Missing file extension")
             return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
+        case FileType.VIDEO:
+            if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
+                data = _to_url(f)
+            else:
+                data = _to_base64_data_string(f)
+            return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
         case _:
             raise ValueError(f"file type {f.type} is not supported")
@@ -112,7 +118,7 @@ def _download_file_content(path: str, /):
 def _get_encoded_string(f: File, /):
     match f.transfer_method:
         case FileTransferMethod.REMOTE_URL:
-            response = ssrf_proxy.get(f.remote_url)
+            response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
             response.raise_for_status()
             content = response.content
             encoded_string = base64.b64encode(content).decode("utf-8")
@@ -140,6 +146,8 @@ def _file_to_encoded_string(f: File, /):
     match f.type:
         case FileType.IMAGE:
             return _to_base64_data_string(f)
+        case FileType.VIDEO:
+            return _to_base64_data_string(f)
         case FileType.AUDIO:
             return _get_encoded_string(f)
         case _:
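Reviewer note on the follow_redirects=True change in _get_encoded_string: httpx, unlike requests, does not follow redirects by default, so a remote URL answering 301/302 previously handed back the redirect stub and the encoder would base64 the wrong body. A minimal illustration (URL is a stand-in):

import httpx

resp = httpx.get("http://example.com/file")                         # a 3xx response is returned as-is
resp = httpx.get("http://example.com/file", follow_redirects=True)  # the redirect chain is resolved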
@@ -3,26 +3,20 @@ Proxy requests to avoid SSRF
 """

 import logging
-import os
 import time

 import httpx

-SSRF_PROXY_ALL_URL = os.getenv("SSRF_PROXY_ALL_URL", "")
-SSRF_PROXY_HTTP_URL = os.getenv("SSRF_PROXY_HTTP_URL", "")
-SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "")
-SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3"))
-SSRF_DEFAULT_TIME_OUT = float(os.getenv("SSRF_DEFAULT_TIME_OUT", "5"))
-SSRF_DEFAULT_CONNECT_TIME_OUT = float(os.getenv("SSRF_DEFAULT_CONNECT_TIME_OUT", "5"))
-SSRF_DEFAULT_READ_TIME_OUT = float(os.getenv("SSRF_DEFAULT_READ_TIME_OUT", "5"))
-SSRF_DEFAULT_WRITE_TIME_OUT = float(os.getenv("SSRF_DEFAULT_WRITE_TIME_OUT", "5"))
+from configs import dify_config
+
+SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES

 proxy_mounts = (
     {
-        "http://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTP_URL),
-        "https://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTPS_URL),
+        "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL),
+        "https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL),
     }
-    if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL
+    if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL
     else None
 )
@@ -38,17 +32,17 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):

     if "timeout" not in kwargs:
         kwargs["timeout"] = httpx.Timeout(
-            SSRF_DEFAULT_TIME_OUT,
-            connect=SSRF_DEFAULT_CONNECT_TIME_OUT,
-            read=SSRF_DEFAULT_READ_TIME_OUT,
-            write=SSRF_DEFAULT_WRITE_TIME_OUT,
+            timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
+            connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
+            read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
+            write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
         )

     retries = 0
     while retries <= max_retries:
         try:
-            if SSRF_PROXY_ALL_URL:
-                with httpx.Client(proxy=SSRF_PROXY_ALL_URL) as client:
+            if dify_config.SSRF_PROXY_ALL_URL:
+                with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL) as client:
                     response = client.request(method=method, url=url, **kwargs)
             elif proxy_mounts:
                 with httpx.Client(mounts=proxy_mounts) as client:
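The rewritten module now reads everything from dify_config; for reference, a self-contained sketch of the two httpx features it leans on, per-scheme proxy mounts and a split Timeout (the proxy URL is a placeholder):

import httpx

proxy_mounts = {
    "http://": httpx.HTTPTransport(proxy="http://ssrf-proxy.internal:3128"),
    "https://": httpx.HTTPTransport(proxy="http://ssrf-proxy.internal:3128"),
}
timeout = httpx.Timeout(5, connect=5, read=5, write=5)

with httpx.Client(mounts=proxy_mounts, timeout=timeout) as client:
    response = client.request("GET", "https://example.com")
    response.raise_for_status()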
@@ -1,8 +1,8 @@
 import logging
-import os
 from collections.abc import Callable, Generator, Iterable, Sequence
 from typing import IO, Any, Literal, Optional, Union, cast, overload

+from configs import dify_config
 from core.entities.embedding_type import EmbeddingInputType
 from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
 from core.entities.provider_entities import ModelLoadBalancingConfiguration
@@ -509,7 +509,7 @@ class LBModelManager:

                 continue

-            if bool(os.environ.get("DEBUG", "False").lower() == "true"):
+            if dify_config.DEBUG:
                 logger.info(
                     f"Model LB\nid: {config.id}\nname:{config.name}\n"
                     f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"
@@ -12,11 +12,13 @@ from .message_entities import (
     TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
+    VideoPromptMessageContent,
 )
 from .model_entities import ModelPropertyKey

 __all__ = [
     "ImagePromptMessageContent",
+    "VideoPromptMessageContent",
     "PromptMessage",
     "PromptMessageRole",
     "LLMUsage",
@@ -56,6 +56,7 @@ class PromptMessageContentType(Enum):
     TEXT = "text"
     IMAGE = "image"
     AUDIO = "audio"
+    VIDEO = "video"


 class PromptMessageContent(BaseModel):
@@ -75,6 +76,12 @@ class TextPromptMessageContent(PromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.TEXT


+class VideoPromptMessageContent(PromptMessageContent):
+    type: PromptMessageContentType = PromptMessageContentType.VIDEO
+    data: str = Field(..., description="Base64 encoded video data")
+    format: str = Field(..., description="Video format")
+
+
 class AudioPromptMessageContent(PromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.AUDIO
     data: str = Field(..., description="Base64 encoded audio data")
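The hunks above register VIDEO as a first-class prompt content type. A hedged sketch of how a video part would be constructed, mirroring the existing image/audio usage (the payload string is an assumed-shape placeholder):

video = VideoPromptMessageContent(
    data="data:video/mp4;base64,AAAA",  # truncated placeholder payload
    format="mp4",
)
assert video.type == PromptMessageContentType.VIDEO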
@@ -126,6 +126,6 @@ class OutputModeration(BaseModel):
             result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
             return result
         except Exception as e:
-            logger.error("Moderation Output error: %s", e)
+            logger.exception("Moderation Output error: %s", e)

         return None
@@ -708,7 +708,7 @@ class TraceQueueManager:
                 trace_task.app_id = self.app_id
                 trace_manager_queue.put(trace_task)
         except Exception as e:
-            logging.error(f"Error adding trace task: {e}")
+            logging.exception(f"Error adding trace task: {e}")
         finally:
             self.start_timer()

@@ -727,7 +727,7 @@ class TraceQueueManager:
                 if tasks:
                     self.send_to_celery(tasks)
         except Exception as e:
-            logging.error(f"Error processing trace tasks: {e}")
+            logging.exception(f"Error processing trace tasks: {e}")

     def start_timer(self):
         global trace_manager_timer
@@ -242,7 +242,7 @@ class CouchbaseVector(BaseVector):
         try:
             self._cluster.query(query, named_parameters={"doc_ids": ids}).execute()
         except Exception as e:
-            logger.error(e)
+            logger.exception(e)

     def delete_by_document_id(self, document_id: str):
         query = f"""
@@ -79,7 +79,7 @@ class LindormVectorStore(BaseVector):
                 existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, _source=False)
                 return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
             except Exception as e:
-                logger.error(f"Error fetching batch {batch_ids}: {e}")
+                logger.exception(f"Error fetching batch {batch_ids}: {e}")
                 return set()

         @retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
@@ -96,7 +96,7 @@ class LindormVectorStore(BaseVector):
                 )
                 return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
             except Exception as e:
-                logger.error(f"Error fetching batch {batch_ids}: {e}")
+                logger.exception(f"Error fetching batch {batch_ids}: {e}")
                 return set()

         if ids is None:
@@ -177,7 +177,7 @@ class LindormVectorStore(BaseVector):
             else:
                 logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
         except Exception as e:
-            logger.error(f"Error occurred while deleting the index: {e}")
+            logger.exception(f"Error occurred while deleting the index: {e}")
             raise e

     def text_exists(self, id: str) -> bool:
@@ -201,7 +201,7 @@ class LindormVectorStore(BaseVector):
         try:
             response = self._client.search(index=self._collection_name, body=query)
         except Exception as e:
-            logger.error(f"Error executing search: {e}")
+            logger.exception(f"Error executing search: {e}")
             raise

         docs_and_scores = []
@@ -86,7 +86,7 @@ class MilvusVector(BaseVector):
                 ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
                 pks.extend(ids)
             except MilvusException as e:
-                logger.error("Failed to insert batch starting at entity: %s/%s", i, total_count)
+                logger.exception("Failed to insert batch starting at entity: %s/%s", i, total_count)
                 raise e
         return pks

@@ -142,7 +142,7 @@ class MyScaleVector(BaseVector):
                 for r in self._client.query(sql).named_results()
             ]
         except Exception as e:
-            logging.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
+            logging.exception(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
             return []

     def delete(self) -> None:
@@ -129,7 +129,7 @@ class OpenSearchVector(BaseVector):
             if status == 404:
                 logger.warning(f"Document not found for deletion: {doc_id}")
             else:
-                logger.error(f"Error deleting document: {error}")
+                logger.exception(f"Error deleting document: {error}")

     def delete(self) -> None:
         self._client.indices.delete(index=self._collection_name.lower())
@@ -158,7 +158,7 @@ class OpenSearchVector(BaseVector):
         try:
             response = self._client.search(index=self._collection_name.lower(), body=query)
         except Exception as e:
-            logger.error(f"Error executing search: {e}")
+            logger.exception(f"Error executing search: {e}")
             raise

         docs = []
@@ -89,7 +89,7 @@ class CacheEmbedding(Embeddings):
                 db.session.rollback()
             except Exception as ex:
                 db.session.rollback()
-                logger.error("Failed to embed documents: %s", ex)
+                logger.exception("Failed to embed documents: %s", ex)
                 raise ex

         return text_embeddings
@@ -28,7 +28,6 @@ logger = logging.getLogger(__name__)
 class WordExtractor(BaseExtractor):
     """Load docx files.

-
     Args:
         file_path: Path to the file to load.
     """
@@ -51,9 +50,9 @@ class WordExtractor(BaseExtractor):

             self.web_path = self.file_path
             # TODO: use a better way to handle the file
-            self.temp_file = tempfile.NamedTemporaryFile()  # noqa: SIM115
-            self.temp_file.write(r.content)
-            self.file_path = self.temp_file.name
+            with tempfile.NamedTemporaryFile(delete=False) as self.temp_file:
+                self.temp_file.write(r.content)
+                self.file_path = self.temp_file.name
         elif not os.path.isfile(self.file_path):
             raise ValueError(f"File path {self.file_path} is not a valid file or url")
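Note on the tempfile change above: delete=False is what lets the downloaded .docx survive the with block, since the parser reopens it by path afterwards. A standalone sketch of the pattern:

import os
import tempfile

with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"downloaded bytes")
    path = tmp.name

with open(path, "rb") as f:  # still readable after the context manager closed it
    assert f.read() == b"downloaded bytes"
os.unlink(path)  # the file is now the caller's responsibility to clean up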
@@ -230,7 +229,7 @@ class WordExtractor(BaseExtractor):
                     for i in url_pattern.findall(x.text):
                         hyperlinks_url = str(i)
                 except Exception as e:
-                    logger.error(e)
+                    logger.exception(e)

         def parse_paragraph(paragraph):
             paragraph_content = []
@@ -98,7 +98,7 @@ class ToolFileManager:
             response.raise_for_status()
             blob = response.content
         except Exception as e:
-            logger.error(f"Failed to download file from {file_url}: {e}")
+            logger.exception(f"Failed to download file from {file_url}: {e}")
             raise

         mimetype = guess_type(file_url)[0] or "octet/stream"
@@ -526,7 +526,7 @@ class ToolManager:
                 yield provider

             except Exception as e:
-                logger.error(f"load builtin provider error: {e}")
+                logger.exception(f"load builtin provider {provider} error: {e}")
                 continue
         # set builtin providers loaded
         cls._builtin_providers_loaded = True
@@ -127,7 +129,9 @@ class FeishuRequest:
             "folder_token": folder_token,
         }
         res = self._send_request(url, payload=payload)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def write_document(self, document_id: str, content: str, position: str = "end") -> dict:
         url = f"{self.API_BASE_URL}/document/write_document"
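Every FeishuRequest method below gains the same three-line guard. A hedged sketch of a helper that would express the pattern once (a suggestion, not part of this diff):

def _extract_data(res: dict):
    # Return res["data"] when the API wrapped its payload, otherwise the raw response.
    return res.get("data") if "data" in res else res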
@@ -135,7 +137,7 @@ class FeishuRequest:
         res = self._send_request(url, payload=payload)
         return res

-    def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> dict:
+    def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str:
         """
         API url: https://open.larkoffice.com/document/server-docs/docs/docs/docx-v1/document/raw_content
         Example Response:
@@ -154,7 +156,9 @@ class FeishuRequest:
         }
         url = f"{self.API_BASE_URL}/document/get_document_content"
         res = self._send_request(url, method="GET", params=params)
-        return res.get("data").get("content")
+        if "data" in res:
+            return res.get("data").get("content")
+        return ""

     def list_document_blocks(
         self, document_id: str, page_token: str, user_id_type: str = "open_id", page_size: int = 500
@@ -170,7 +174,9 @@ class FeishuRequest:
         }
         url = f"{self.API_BASE_URL}/document/list_document_blocks"
         res = self._send_request(url, method="GET", params=params)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict:
         """
@@ -186,7 +192,9 @@ class FeishuRequest:
             "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
         }
         res = self._send_request(url, params=params, payload=payload)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict:
         url = f"{self.API_BASE_URL}/message/send_webhook_message"
@@ -220,7 +228,9 @@ class FeishuRequest:
             "page_size": page_size,
         }
         res = self._send_request(url, method="GET", params=params)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def get_thread_messages(
         self, container_id: str, page_token: str, sort_type: str = "ByCreateTimeAsc", page_size: int = 20
@@ -236,7 +246,9 @@ class FeishuRequest:
             "page_size": page_size,
         }
         res = self._send_request(url, method="GET", params=params)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict:
         # Create a task
@@ -249,7 +261,9 @@ class FeishuRequest:
             "description": description,
         }
         res = self._send_request(url, payload=payload)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def update_task(
         self, task_guid: str, summary: str, start_time: str, end_time: str, completed_time: str, description: str
@@ -265,7 +279,9 @@ class FeishuRequest:
             "description": description,
         }
         res = self._send_request(url, method="PATCH", payload=payload)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def delete_task(self, task_guid: str) -> dict:
         # Delete a task
@@ -297,7 +313,9 @@ class FeishuRequest:
             "page_size": page_size,
         }
         res = self._send_request(url, payload=payload)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def get_primary_calendar(self, user_id_type: str = "open_id") -> dict:
         url = f"{self.API_BASE_URL}/calendar/get_primary_calendar"
@@ -305,7 +323,9 @@ class FeishuRequest:
             "user_id_type": user_id_type,
         }
         res = self._send_request(url, method="GET", params=params)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def create_event(
         self,
@@ -328,7 +348,9 @@ class FeishuRequest:
             "attendee_ability": attendee_ability,
         }
         res = self._send_request(url, payload=payload)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def update_event(
         self,
@@ -374,7 +396,9 @@ class FeishuRequest:
             "page_size": page_size,
         }
         res = self._send_request(url, method="GET", params=params)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def search_events(
         self,
@@ -395,7 +419,9 @@ class FeishuRequest:
             "page_size": page_size,
         }
         res = self._send_request(url, payload=payload)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict:
         # Add attendees to an event
@@ -406,7 +432,9 @@ class FeishuRequest:
             "need_notification": need_notification,
         }
         res = self._send_request(url, payload=payload)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def create_spreadsheet(
         self,
@@ -420,7 +448,9 @@ class FeishuRequest:
             "folder_token": folder_token,
         }
         res = self._send_request(url, payload=payload)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def get_spreadsheet(
         self,
@@ -434,7 +464,9 @@ class FeishuRequest:
             "user_id_type": user_id_type,
         }
         res = self._send_request(url, method="GET", params=params)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def list_spreadsheet_sheets(
         self,
@@ -446,7 +478,9 @@ class FeishuRequest:
             "spreadsheet_token": spreadsheet_token,
         }
         res = self._send_request(url, method="GET", params=params)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def add_rows(
         self,
@@ -466,7 +500,9 @@ class FeishuRequest:
             "values": values,
         }
         res = self._send_request(url, payload=payload)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def add_cols(
         self,
@@ -486,7 +522,9 @@ class FeishuRequest:
             "values": values,
         }
         res = self._send_request(url, payload=payload)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def read_rows(
         self,
@@ -508,7 +546,9 @@ class FeishuRequest:
             "user_id_type": user_id_type,
         }
         res = self._send_request(url, method="GET", params=params)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def read_cols(
         self,
@@ -530,7 +570,9 @@ class FeishuRequest:
             "user_id_type": user_id_type,
         }
         res = self._send_request(url, method="GET", params=params)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def read_table(
         self,
@@ -552,7 +594,9 @@ class FeishuRequest:
             "user_id_type": user_id_type,
         }
         res = self._send_request(url, method="GET", params=params)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def create_base(
         self,
@@ -566,7 +610,9 @@ class FeishuRequest:
             "folder_token": folder_token,
         }
         res = self._send_request(url, payload=payload)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def add_records(
         self,
@@ -588,7 +634,9 @@ class FeishuRequest:
             "records": convert_add_records(records),
         }
         res = self._send_request(url, params=params, payload=payload)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def update_records(
         self,
@@ -610,7 +658,9 @@ class FeishuRequest:
             "records": convert_update_records(records),
         }
         res = self._send_request(url, params=params, payload=payload)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def delete_records(
         self,
@@ -637,7 +687,9 @@ class FeishuRequest:
             "records": record_id_list,
         }
         res = self._send_request(url, params=params, payload=payload)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def search_record(
         self,
@@ -701,7 +753,10 @@ class FeishuRequest:
         if automatic_fields:
             payload["automatic_fields"] = automatic_fields
         res = self._send_request(url, params=params, payload=payload)
-        return res.get("data")
+
+        if "data" in res:
+            return res.get("data")
+        return res

     def get_base_info(
         self,
@@ -713,7 +768,9 @@ class FeishuRequest:
             "app_token": app_token,
         }
         res = self._send_request(url, method="GET", params=params)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def create_table(
         self,
@@ -741,7 +798,9 @@ class FeishuRequest:
         if default_view_name:
             payload["default_view_name"] = default_view_name
         res = self._send_request(url, params=params, payload=payload)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def delete_tables(
         self,
@@ -774,8 +833,11 @@ class FeishuRequest:
             "table_ids": table_id_list,
             "table_names": table_name_list,
         }

         res = self._send_request(url, params=params, payload=payload)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def list_tables(
         self,
@@ -791,7 +853,9 @@ class FeishuRequest:
             "page_size": page_size,
         }
         res = self._send_request(url, method="GET", params=params)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res

     def read_records(
         self,
@@ -819,4 +883,6 @@ class FeishuRequest:
             "user_id_type": user_id_type,
         }
         res = self._send_request(url, method="GET", params=params, payload=payload)
-        return res.get("data")
+        if "data" in res:
+            return res.get("data")
+        return res
api/core/tools/utils/lark_api_utils.py (new file, 820 lines)
@@ -0,0 +1,820 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
def lark_auth(credentials):
|
||||
app_id = credentials.get("app_id")
|
||||
app_secret = credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ToolProviderCredentialValidationError("app_id and app_secret is required")
|
||||
try:
|
||||
assert LarkRequest(app_id, app_secret).tenant_access_token is not None
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
|
||||
class LarkRequest:
|
||||
API_BASE_URL = "https://lark-plugin-api.solutionsuite.ai/lark-plugin"
|
||||
|
||||
def __init__(self, app_id: str, app_secret: str):
|
||||
self.app_id = app_id
|
||||
self.app_secret = app_secret
|
||||
|
||||
def convert_add_records(self, json_str):
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
if not isinstance(data, list):
|
||||
raise ValueError("Parsed data must be a list")
|
||||
converted_data = [{"fields": json.dumps(item, ensure_ascii=False)} for item in data]
|
||||
return converted_data
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
except Exception as e:
|
||||
raise ValueError(f"An error occurred while processing the data: {e}")
|
||||
|
||||
def convert_update_records(self, json_str):
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
if not isinstance(data, list):
|
||||
raise ValueError("Parsed data must be a list")
|
||||
|
||||
converted_data = [
|
||||
{"fields": json.dumps(record["fields"], ensure_ascii=False), "record_id": record["record_id"]}
|
||||
for record in data
|
||||
if "fields" in record and "record_id" in record
|
||||
]
|
||||
|
||||
if len(converted_data) != len(data):
|
||||
raise ValueError("Each record must contain 'fields' and 'record_id'")
|
||||
|
||||
return converted_data
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
except Exception as e:
|
||||
raise ValueError(f"An error occurred while processing the data: {e}")
|
||||
|
||||
@property
|
||||
def tenant_access_token(self) -> str:
|
||||
feishu_tenant_access_token = f"tools:{self.app_id}:feishu_tenant_access_token"
|
||||
if redis_client.exists(feishu_tenant_access_token):
|
||||
return redis_client.get(feishu_tenant_access_token).decode()
|
||||
res = self.get_tenant_access_token(self.app_id, self.app_secret)
|
||||
redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token"))
|
||||
if "tenant_access_token" in res:
|
||||
return res.get("tenant_access_token")
|
||||
return ""
|
||||
|
||||
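Reviewer note on tenant_access_token: as written, redis_client.setex runs before the "tenant_access_token" in res check, so a response missing the token (or its "expire" field) would reach setex with None arguments. A hedged sketch of the intended order, with an assumed 7200-second fallback TTL:

def cached_tenant_access_token(redis_client, app_id: str, fetch) -> str:
    key = f"tools:{app_id}:feishu_tenant_access_token"
    cached = redis_client.get(key)
    if cached:
        return cached.decode()
    res = fetch()  # e.g. get_tenant_access_token(app_id, app_secret)
    token = res.get("tenant_access_token")
    if not token:
        return ""
    redis_client.setex(key, res.get("expire", 7200), token)  # fallback TTL is an assumption
    return token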
    def _send_request(
        self,
        url: str,
        method: str = "post",
        require_token: bool = True,
        payload: Optional[dict] = None,
        params: Optional[dict] = None,
    ):
        headers = {
            "Content-Type": "application/json",
            "user-agent": "Dify",
        }
        if require_token:
            headers["tenant-access-token"] = f"{self.tenant_access_token}"
        res = httpx.request(method=method, url=url, headers=headers, json=payload, params=params, timeout=30).json()
        if res.get("code") != 0:
            raise Exception(res)
        return res

    def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict:
        url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token"
        payload = {"app_id": app_id, "app_secret": app_secret}
        res = self._send_request(url, require_token=False, payload=payload)
        return res

    def create_document(self, title: str, content: str, folder_token: str) -> dict:
        url = f"{self.API_BASE_URL}/document/create_document"
        payload = {
            "title": title,
            "content": content,
            "folder_token": folder_token,
        }
        res = self._send_request(url, payload=payload)
        if "data" in res:
            return res.get("data")
        return res

    def write_document(self, document_id: str, content: str, position: str = "end") -> dict:
        url = f"{self.API_BASE_URL}/document/write_document"
        payload = {"document_id": document_id, "content": content, "position": position}
        res = self._send_request(url, payload=payload)
        return res

    def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str | dict:
        params = {
            "document_id": document_id,
            "mode": mode,
            "lang": lang,
        }
        url = f"{self.API_BASE_URL}/document/get_document_content"
        res = self._send_request(url, method="GET", params=params)
        if "data" in res:
            return res.get("data").get("content")
        return ""

    def list_document_blocks(
        self, document_id: str, page_token: str, user_id_type: str = "open_id", page_size: int = 500
    ) -> dict:
        params = {
            "user_id_type": user_id_type,
            "document_id": document_id,
            "page_size": page_size,
            "page_token": page_token,
        }
        url = f"{self.API_BASE_URL}/document/list_document_blocks"
        res = self._send_request(url, method="GET", params=params)
        if "data" in res:
            return res.get("data")
        return res

    def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict:
        url = f"{self.API_BASE_URL}/message/send_bot_message"
        params = {
            "receive_id_type": receive_id_type,
        }
        payload = {
            "receive_id": receive_id,
            "msg_type": msg_type,
            "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
        }
        res = self._send_request(url, params=params, payload=payload)
        if "data" in res:
            return res.get("data")
        return res

    def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict:
        url = f"{self.API_BASE_URL}/message/send_webhook_message"
        payload = {
            "webhook": webhook,
            "msg_type": msg_type,
            "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
        }
        res = self._send_request(url, require_token=False, payload=payload)
        return res

    def get_chat_messages(
        self,
        container_id: str,
        start_time: str,
        end_time: str,
        page_token: str,
        sort_type: str = "ByCreateTimeAsc",
        page_size: int = 20,
    ) -> dict:
        url = f"{self.API_BASE_URL}/message/get_chat_messages"
        params = {
            "container_id": container_id,
            "start_time": start_time,
            "end_time": end_time,
            "sort_type": sort_type,
            "page_token": page_token,
            "page_size": page_size,
        }
        res = self._send_request(url, method="GET", params=params)
        if "data" in res:
            return res.get("data")
        return res

    def get_thread_messages(
        self, container_id: str, page_token: str, sort_type: str = "ByCreateTimeAsc", page_size: int = 20
    ) -> dict:
        url = f"{self.API_BASE_URL}/message/get_thread_messages"
        params = {
            "container_id": container_id,
            "sort_type": sort_type,
            "page_token": page_token,
            "page_size": page_size,
        }
        res = self._send_request(url, method="GET", params=params)
        if "data" in res:
            return res.get("data")
        return res

    def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict:
        url = f"{self.API_BASE_URL}/task/create_task"
        payload = {
            "summary": summary,
            "start_time": start_time,
            "end_time": end_time,
            "completed_at": completed_time,
            "description": description,
        }
        res = self._send_request(url, payload=payload)
        if "data" in res:
            return res.get("data")
        return res

    def update_task(
        self, task_guid: str, summary: str, start_time: str, end_time: str, completed_time: str, description: str
    ) -> dict:
        url = f"{self.API_BASE_URL}/task/update_task"
        payload = {
            "task_guid": task_guid,
            "summary": summary,
            "start_time": start_time,
            "end_time": end_time,
            "completed_time": completed_time,
            "description": description,
        }
        res = self._send_request(url, method="PATCH", payload=payload)
        if "data" in res:
            return res.get("data")
        return res

    def delete_task(self, task_guid: str) -> dict:
        url = f"{self.API_BASE_URL}/task/delete_task"
        payload = {
            "task_guid": task_guid,
        }
        res = self._send_request(url, method="DELETE", payload=payload)
        if "data" in res:
            return res.get("data")
        return res

    def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict:
        url = f"{self.API_BASE_URL}/task/add_members"
        payload = {
            "task_guid": task_guid,
            "member_phone_or_email": member_phone_or_email,
            "member_role": member_role,
        }
        res = self._send_request(url, payload=payload)
        if "data" in res:
            return res.get("data")
        return res

    def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict:
        url = f"{self.API_BASE_URL}/wiki/get_wiki_nodes"
        payload = {
            "space_id": space_id,
            "parent_node_token": parent_node_token,
            "page_token": page_token,
            "page_size": page_size,
        }
        res = self._send_request(url, payload=payload)
        if "data" in res:
            return res.get("data")
        return res

    def get_primary_calendar(self, user_id_type: str = "open_id") -> dict:
        url = f"{self.API_BASE_URL}/calendar/get_primary_calendar"
        params = {
            "user_id_type": user_id_type,
        }
        res = self._send_request(url, method="GET", params=params)
        if "data" in res:
            return res.get("data")
        return res

    def create_event(
        self,
        summary: str,
        description: str,
        start_time: str,
        end_time: str,
        attendee_ability: str,
        need_notification: bool = True,
        auto_record: bool = False,
    ) -> dict:
        url = f"{self.API_BASE_URL}/calendar/create_event"
        payload = {
            "summary": summary,
            "description": description,
            "need_notification": need_notification,
            "start_time": start_time,
            "end_time": end_time,
            "auto_record": auto_record,
            "attendee_ability": attendee_ability,
        }
        res = self._send_request(url, payload=payload)
        if "data" in res:
            return res.get("data")
        return res

    def update_event(
        self,
        event_id: str,
        summary: str,
        description: str,
        need_notification: bool,
        start_time: str,
        end_time: str,
        auto_record: bool,
    ) -> dict:
        url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}"
        payload = {}
        if summary:
            payload["summary"] = summary
        if description:
            payload["description"] = description
        if start_time:
            payload["start_time"] = start_time
        if end_time:
            payload["end_time"] = end_time
        if need_notification:
            payload["need_notification"] = need_notification
        if auto_record:
            payload["auto_record"] = auto_record
        res = self._send_request(url, method="PATCH", payload=payload)
        return res
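One note on update_event above: building the payload with truthiness checks (if need_notification:) silently drops explicit False values, so "turn notifications off" can never be sent. A hedged sketch of a stricter filter (not part of this diff):

from typing import Optional

def build_payload(**fields: Optional[object]) -> dict:
    # Compare against None so explicit False/0 still reach the API.
    return {k: v for k, v in fields.items() if v is not None}

payload = build_payload(summary="sync", need_notification=False)
# {'summary': 'sync', 'need_notification': False}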
def delete_event(self, event_id: str, need_notification: bool = True) -> dict:
|
||||
url = f"{self.API_BASE_URL}/calendar/delete_event/{event_id}"
|
||||
params = {
|
||||
"need_notification": need_notification,
|
||||
}
|
||||
res = self._send_request(url, method="DELETE", params=params)
|
||||
return res
|
||||
|
||||
def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict:
|
||||
url = f"{self.API_BASE_URL}/calendar/list_events"
|
||||
params = {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"page_token": page_token,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
return res
|
||||
|
||||
def search_events(
|
||||
self,
|
||||
query: str,
|
||||
start_time: str,
|
||||
end_time: str,
|
||||
page_token: str,
|
||||
user_id_type: str = "open_id",
|
||||
page_size: int = 20,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/calendar/search_events"
|
||||
payload = {
|
||||
"query": query,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"page_token": page_token,
|
||||
"user_id_type": user_id_type,
|
||||
"page_size": page_size,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
return res
|
||||
|
||||
def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict:
|
||||
url = f"{self.API_BASE_URL}/calendar/add_event_attendees"
|
||||
payload = {
|
||||
"event_id": event_id,
|
||||
"attendee_phone_or_email": attendee_phone_or_email,
|
||||
"need_notification": need_notification,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
return res
|
||||
|
||||
def create_spreadsheet(
|
||||
self,
|
||||
title: str,
|
||||
folder_token: str,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/create_spreadsheet"
|
||||
payload = {
|
||||
"title": title,
|
||||
"folder_token": folder_token,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
return res
|
||||
|
||||
def get_spreadsheet(
|
||||
self,
|
||||
spreadsheet_token: str,
|
||||
user_id_type: str = "open_id",
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/get_spreadsheet"
|
||||
params = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
return res
|
||||
|
||||
def list_spreadsheet_sheets(
|
||||
self,
|
||||
spreadsheet_token: str,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/list_spreadsheet_sheets"
|
||||
params = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
return res
|
||||
|
||||
def add_rows(
|
||||
self,
|
||||
spreadsheet_token: str,
|
||||
sheet_id: str,
|
||||
sheet_name: str,
|
||||
length: int,
|
||||
values: str,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/add_rows"
|
||||
payload = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
"sheet_id": sheet_id,
|
||||
"sheet_name": sheet_name,
|
||||
"length": length,
|
||||
"values": values,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
return res
|
||||
|
||||
def add_cols(
|
||||
self,
|
||||
spreadsheet_token: str,
|
||||
sheet_id: str,
|
||||
sheet_name: str,
|
||||
length: int,
|
||||
values: str,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/add_cols"
|
||||
payload = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
"sheet_id": sheet_id,
|
||||
"sheet_name": sheet_name,
|
||||
"length": length,
|
||||
"values": values,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
return res
|
||||
|
||||
def read_rows(
|
||||
self,
|
||||
spreadsheet_token: str,
|
||||
sheet_id: str,
|
||||
sheet_name: str,
|
||||
start_row: int,
|
||||
num_rows: int,
|
||||
user_id_type: str = "open_id",
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/read_rows"
|
||||
params = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
"sheet_id": sheet_id,
|
||||
"sheet_name": sheet_name,
|
||||
"start_row": start_row,
|
||||
"num_rows": num_rows,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
return res
|
||||
|
||||
def read_cols(
|
||||
self,
|
||||
spreadsheet_token: str,
|
||||
sheet_id: str,
|
||||
sheet_name: str,
|
||||
start_col: int,
|
||||
num_cols: int,
|
||||
user_id_type: str = "open_id",
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/read_cols"
|
||||
params = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
"sheet_id": sheet_id,
|
||||
"sheet_name": sheet_name,
|
||||
"start_col": start_col,
|
||||
"num_cols": num_cols,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
return res
|
||||
|
||||
def read_table(
|
||||
self,
|
||||
spreadsheet_token: str,
|
||||
sheet_id: str,
|
||||
sheet_name: str,
|
||||
num_range: str,
|
||||
query: str,
|
||||
user_id_type: str = "open_id",
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/spreadsheet/read_table"
|
||||
params = {
|
||||
"spreadsheet_token": spreadsheet_token,
|
||||
"sheet_id": sheet_id,
|
||||
"sheet_name": sheet_name,
|
||||
"range": num_range,
|
||||
"query": query,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
res = self._send_request(url, method="GET", params=params)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
return res
|
||||
|
||||
def create_base(
|
||||
self,
|
||||
name: str,
|
||||
folder_token: str,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/base/create_base"
|
||||
payload = {
|
||||
"name": name,
|
||||
"folder_token": folder_token,
|
||||
}
|
||||
res = self._send_request(url, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
return res
|
||||
|
||||
def add_records(
|
||||
self,
|
||||
app_token: str,
|
||||
table_id: str,
|
||||
table_name: str,
|
||||
records: str,
|
||||
user_id_type: str = "open_id",
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/base/add_records"
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
"table_id": table_id,
|
||||
"table_name": table_name,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
payload = {
|
||||
"records": self.convert_add_records(records),
|
||||
}
|
||||
res = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
return res
|
||||
|
||||
def update_records(
|
||||
self,
|
||||
app_token: str,
|
||||
table_id: str,
|
||||
table_name: str,
|
||||
records: str,
|
||||
user_id_type: str,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/base/update_records"
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
"table_id": table_id,
|
||||
"table_name": table_name,
|
||||
"user_id_type": user_id_type,
|
||||
}
|
||||
payload = {
|
||||
"records": self.convert_update_records(records),
|
||||
}
|
||||
res = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
return res
|
||||
|
||||
def delete_records(
|
||||
self,
|
||||
app_token: str,
|
||||
table_id: str,
|
||||
table_name: str,
|
||||
record_ids: str,
|
||||
) -> dict:
|
||||
url = f"{self.API_BASE_URL}/base/delete_records"
|
||||
params = {
|
||||
"app_token": app_token,
|
||||
"table_id": table_id,
|
||||
"table_name": table_name,
|
||||
}
|
||||
if not record_ids:
|
||||
record_id_list = []
|
||||
else:
|
||||
try:
|
||||
record_id_list = json.loads(record_ids)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("The input string is not valid JSON")
|
||||
payload = {
|
||||
"records": record_id_list,
|
||||
}
|
||||
res = self._send_request(url, params=params, payload=payload)
|
||||
if "data" in res:
|
||||
return res.get("data")
|
||||
return res

    def search_record(
        self,
        app_token: str,
        table_id: str,
        table_name: str,
        view_id: str,
        field_names: str,
        sort: str,
        filters: str,
        page_token: str,
        automatic_fields: bool = False,
        user_id_type: str = "open_id",
        page_size: int = 20,
    ) -> dict:
        url = f"{self.API_BASE_URL}/base/search_record"

        params = {
            "app_token": app_token,
            "table_id": table_id,
            "table_name": table_name,
            "user_id_type": user_id_type,
            "page_token": page_token,
            "page_size": page_size,
        }

        if not field_names:
            field_name_list = []
        else:
            try:
                field_name_list = json.loads(field_names)
            except json.JSONDecodeError:
                raise ValueError("The input string is not valid JSON")

        if not sort:
            sort_list = []
        else:
            try:
                sort_list = json.loads(sort)
            except json.JSONDecodeError:
                raise ValueError("The input string is not valid JSON")

        if not filters:
            filter_dict = {}
        else:
            try:
                filter_dict = json.loads(filters)
            except json.JSONDecodeError:
                raise ValueError("The input string is not valid JSON")

        payload = {}

        if view_id:
            payload["view_id"] = view_id
        if field_names:
            payload["field_names"] = field_name_list
        if sort:
            payload["sort"] = sort_list
        if filters:
            payload["filter"] = filter_dict
        if automatic_fields:
            payload["automatic_fields"] = automatic_fields
        res = self._send_request(url, params=params, payload=payload)
        if "data" in res:
            return res.get("data")
        return res
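search_record accepts its list- and dict-valued options as JSON-encoded strings and decodes them itself. A hedged usage sketch; `client` stands in for an instance of the surrounding request class, and every token value below is a placeholder, not taken from the diff:

data = client.search_record(
    app_token="bascnXXXX",                               # placeholder
    table_id="tblXXXX",                                  # placeholder
    table_name="",
    view_id="",
    field_names='["Name", "Status"]',                    # JSON list, not a Python list
    sort='[{"field_name": "Name", "desc": true}]',       # JSON list of sort specs
    filters='{"conjunction": "and", "conditions": []}',  # JSON object
    page_token="",
    page_size=20,
)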

    def get_base_info(
        self,
        app_token: str,
    ) -> dict:
        url = f"{self.API_BASE_URL}/base/get_base_info"
        params = {
            "app_token": app_token,
        }
        res = self._send_request(url, method="GET", params=params)
        if "data" in res:
            return res.get("data")
        return res

    def create_table(
        self,
        app_token: str,
        table_name: str,
        default_view_name: str,
        fields: str,
    ) -> dict:
        url = f"{self.API_BASE_URL}/base/create_table"
        params = {
            "app_token": app_token,
        }
        if not fields:
            fields_list = []
        else:
            try:
                fields_list = json.loads(fields)
            except json.JSONDecodeError:
                raise ValueError("The input string is not valid JSON")
        payload = {
            "name": table_name,
            "fields": fields_list,
        }
        if default_view_name:
            payload["default_view_name"] = default_view_name
        res = self._send_request(url, params=params, payload=payload)
        if "data" in res:
            return res.get("data")
        return res

    def delete_tables(
        self,
        app_token: str,
        table_ids: str,
        table_names: str,
    ) -> dict:
        url = f"{self.API_BASE_URL}/base/delete_tables"
        params = {
            "app_token": app_token,
        }
        if not table_ids:
            table_id_list = []
        else:
            try:
                table_id_list = json.loads(table_ids)
            except json.JSONDecodeError:
                raise ValueError("The input string is not valid JSON")

        if not table_names:
            table_name_list = []
        else:
            try:
                table_name_list = json.loads(table_names)
            except json.JSONDecodeError:
                raise ValueError("The input string is not valid JSON")

        payload = {
            "table_ids": table_id_list,
            "table_names": table_name_list,
        }
        res = self._send_request(url, params=params, payload=payload)
        if "data" in res:
            return res.get("data")
        return res

    def list_tables(
        self,
        app_token: str,
        page_token: str,
        page_size: int = 20,
    ) -> dict:
        url = f"{self.API_BASE_URL}/base/list_tables"
        params = {
            "app_token": app_token,
            "page_token": page_token,
            "page_size": page_size,
        }
        res = self._send_request(url, method="GET", params=params)
        if "data" in res:
            return res.get("data")
        return res

    def read_records(
        self,
        app_token: str,
        table_id: str,
        table_name: str,
        record_ids: str,
        user_id_type: str = "open_id",
    ) -> dict:
        url = f"{self.API_BASE_URL}/base/read_records"
        params = {
            "app_token": app_token,
            "table_id": table_id,
            "table_name": table_name,
        }
        if not record_ids:
            record_id_list = []
        else:
            try:
                record_id_list = json.loads(record_ids)
            except json.JSONDecodeError:
                raise ValueError("The input string is not valid JSON")
        payload = {
            "record_ids": record_id_list,
            "user_id_type": user_id_type,
        }
        res = self._send_request(url, method="POST", params=params, payload=payload)
        if "data" in res:
            return res.get("data")
        return res
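Every method above unwraps the response envelope the same way: if a "data" key is present, its value is returned; otherwise the raw response comes back unchanged. A sketch of a call against read_records, with placeholder tokens; the shape of the error envelope is an assumption, not shown in this diff:

records = client.read_records(
    app_token="bascnXXXX",                  # placeholder
    table_id="tblXXXX",                     # placeholder
    table_name="",
    record_ids='["recXXXX1", "recXXXX2"]',  # JSON-encoded list of record IDs
)
# On success the "data" payload is returned directly; on failure the raw
# response (presumably carrying an error code and message) is passed through.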
@ -69,7 +69,7 @@ class BaseNode(Generic[GenericNodeData]):
        try:
            result = self._run()
        except Exception as e:
            logger.error(f"Node {self.node_id} failed to run: {e}")
            logger.exception(f"Node {self.node_id} failed to run: {e}")
            result = NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
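This is the first of several hunks in this merge that replace logger.error with logger.exception inside an except block; the pyproject.toml change below adds Ruff's TRY400 (error-instead-of-exception) rule to enforce exactly this. The practical difference is the traceback, as this minimal standalone sketch shows:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    1 / 0
except Exception as e:
    logger.error(f"failed: {e}")      # logs the message only
    logger.exception(f"failed: {e}")  # logs the message plus the full traceback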
@ -97,15 +97,6 @@ class Executor:
        headers = self.variable_pool.convert_template(self.node_data.headers).text
        self.headers = _plain_text_to_dict(headers)

        body = self.node_data.body
        if body is None:
            return
        if "content-type" not in (k.lower() for k in self.headers) and body.type in BODY_TYPE_TO_CONTENT_TYPE:
            self.headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type]
        if body.type == "form-data":
            self.boundary = f"----WebKitFormBoundary{_generate_random_string(16)}"
            self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}"

    def _init_body(self):
        body = self.node_data.body
        if body is not None:
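The boundary generated in the form-data branch is what ties the Content-Type header to the multipart body built later. A self-contained sketch of the same idea; _generate_random_string is not shown in this diff, so a simplified stand-in is used:

import secrets
import string

def _random_string(n: int) -> str:
    # simplified stand-in for the node's _generate_random_string helper
    alphabet = string.ascii_letters + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(n))

boundary = f"----WebKitFormBoundary{_random_string(16)}"
headers = {"Content-Type": f"multipart/form-data; boundary={boundary}"}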
@ -154,9 +145,8 @@ class Executor:
                for k, v in files.items()
                if v.related_id is not None
            }

            self.data = form_data
            self.files = files
            self.files = files or None

    def _assembling_headers(self) -> dict[str, Any]:
        authorization = deepcopy(self.auth)
@ -217,6 +207,7 @@ class Executor:
            "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
            "follow_redirects": True,
        }
        # request_args = {k: v for k, v in request_args.items() if v is not None}

        response = getattr(ssrf_proxy, self.method)(**request_args)
        return response
@ -244,6 +235,13 @@ class Executor:
        raw += f"Host: {url_parts.netloc}\r\n"

        headers = self._assembling_headers()
        body = self.node_data.body
        boundary = f"----WebKitFormBoundary{_generate_random_string(16)}"
        if body:
            if "content-type" not in (k.lower() for k in self.headers) and body.type in BODY_TYPE_TO_CONTENT_TYPE:
                headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type]
            if body.type == "form-data":
                headers["Content-Type"] = f"multipart/form-data; boundary={boundary}"
        for k, v in headers.items():
            if self.auth.type == "api-key":
                authorization_header = "Authorization"
@ -256,7 +254,6 @@ class Executor:

        body = ""
        if self.files:
            boundary = self.boundary
            for k, v in self.files.items():
                body += f"--{boundary}\r\n"
                body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n'
@ -271,7 +268,6 @@ class Executor:
        elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
            body = urlencode(self.data)
        elif self.data and self.node_data.body.type == "form-data":
            boundary = self.boundary
            for key, value in self.data.items():
                body += f"--{boundary}\r\n"
                body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
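The loop above writes each form field into the raw-request log using the stored boundary. The hunk is truncated before the value lines and the closing terminator, so the following sketch fills those in from the multipart format itself rather than from the diff:

boundary = "----WebKitFormBoundaryEXAMPLE0123456"  # normally randomized
data = {"text_field": "Hello, World!", "number_field": "42"}

body = ""
for key, value in data.items():
    body += f"--{boundary}\r\n"
    body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
    body += f"{value}\r\n"       # inferred: the field value follows its headers
body += f"--{boundary}--\r\n"    # inferred: multipart bodies end with a terminator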
@ -14,6 +14,7 @@ from core.model_runtime.entities import (
    PromptMessage,
    PromptMessageContentType,
    TextPromptMessageContent,
    VideoPromptMessageContent,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.model_entities import ModelType
@ -560,7 +561,9 @@ class LLMNode(BaseNode[LLMNodeData]):
                    # cuz vision detail is related to the configuration from FileUpload feature.
                    content_item.detail = vision_detail
                    prompt_message_content.append(content_item)
                elif isinstance(content_item, TextPromptMessageContent | AudioPromptMessageContent):
                elif isinstance(
                    content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent
                ):
                    prompt_message_content.append(content_item)

            if len(prompt_message_content) > 1:
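The reformatted isinstance call relies on PEP 604 union syntax (X | Y) as its second argument, which isinstance accepts on Python 3.10 and later. A toy illustration:

class TextContent: ...
class AudioContent: ...
class VideoContent: ...

item = VideoContent()
# A | B | C unions work directly as the classinfo argument on Python 3.10+
assert isinstance(item, TextContent | AudioContent | VideoContent)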
@ -127,7 +127,7 @@ class QuestionClassifierNode(LLMNode):
                category_id = category_id_result

        except OutputParserError:
            logging.error(f"Failed to parse result text: {result_text}")
            logging.exception(f"Failed to parse result text: {result_text}")
        try:
            process_data = {
                "model_mode": model_config.mode,
@ -1,3 +1,4 @@
import posixpath
from collections.abc import Generator

import oss2 as aliyun_s3
@ -50,9 +51,4 @@ class AliyunOssStorage(BaseStorage):
        self.client.delete_object(self.__wrapper_folder_filename(filename))

    def __wrapper_folder_filename(self, filename) -> str:
        if self.folder:
            if self.folder.endswith("/"):
                filename = self.folder + filename
            else:
                filename = self.folder + "/" + filename
            return filename
        return posixpath.join(self.folder, filename) if self.folder else filename
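posixpath.join collapses the endswith("/") branching because it never doubles a separator, and OSS object keys use forward slashes regardless of the host OS, which is why posixpath is used rather than os.path. One behavioral caveat is worth noting:

import posixpath

assert posixpath.join("folder", "file.txt") == "folder/file.txt"
assert posixpath.join("folder/", "file.txt") == "folder/file.txt"
# Caveat: an absolute second argument discards the folder entirely,
# unlike the plain string concatenation it replaces.
assert posixpath.join("folder", "/file.txt") == "/file.txt"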
@ -202,6 +202,10 @@ simple_conversation_fields = {
    "updated_at": TimestampField,
}

conversation_delete_fields = {
    "result": fields.String,
}

conversation_infinite_scroll_pagination_fields = {
    "limit": fields.Integer,
    "has_more": fields.Boolean,
@ -39,13 +39,13 @@ class SMTPClient:

                smtp.sendmail(self._from, mail["to"], msg.as_string())
        except smtplib.SMTPException as e:
            logging.error(f"SMTP error occurred: {str(e)}")
            logging.exception(f"SMTP error occurred: {str(e)}")
            raise
        except TimeoutError as e:
            logging.error(f"Timeout occurred while sending email: {str(e)}")
            logging.exception(f"Timeout occurred while sending email: {str(e)}")
            raise
        except Exception as e:
            logging.error(f"Unexpected error occurred while sending email: {str(e)}")
            logging.exception(f"Unexpected error occurred while sending email: {str(e)}")
            raise
        finally:
            if smtp:
@ -34,6 +34,7 @@ select = [
    "RUF101", # redirected-noqa
    "S506", # unsafe-yaml-load
    "SIM", # flake8-simplify rules
    "TRY400", # error-instead-of-exception
    "UP", # pyupgrade rules
    "W191", # tab-indentation
    "W605", # invalid-escape-sequence
@ -821,7 +821,7 @@ class RegisterService:
            db.session.rollback()
        except Exception as e:
            db.session.rollback()
            logging.error(f"Register failed: {e}")
            logging.exception(f"Register failed: {e}")
            raise AccountRegisterError(f"Registration failed: {e}") from e

        return account
@ -160,4 +160,5 @@ class ConversationService:
        conversation = cls.get_conversation(app_model, conversation_id, user)

        conversation.is_deleted = True
        conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
        db.session.commit()
@ -195,7 +195,7 @@ class ApiToolManageService:
            # try to parse schema, avoid SSRF attack
            ApiToolManageService.parser_api_schema(schema)
        except Exception as e:
            logger.error(f"parse api schema error: {str(e)}")
            logger.exception(f"parse api schema error: {str(e)}")
            raise ValueError("invalid schema, please check the url you provided")

        return {"schema": schema}
@ -196,8 +196,7 @@ class ToolTransformService:

                username = user.name
            except Exception as e:
                logger.error(f"failed to get user name for api provider {db_provider.id}: {str(e)}")
                logger.exception(f"failed to get user name for api provider {db_provider.id}: {str(e)}")
            # add provider into providers
            credentials = db_provider.credentials
            result = ToolProviderApiEntity(
@ -196,3 +196,72 @@ def test_extract_selectors_from_template_with_newline():
    )

    assert executor.params == {"test": "line1\nline2"}


def test_executor_with_form_data():
    # Prepare the variable pool
    variable_pool = VariablePool(
        system_variables={},
        user_inputs={},
    )
    variable_pool.add(["pre_node_id", "text_field"], "Hello, World!")
    variable_pool.add(["pre_node_id", "number_field"], 42)

    # Prepare the node data
    node_data = HttpRequestNodeData(
        title="Test Form Data",
        method="post",
        url="https://api.example.com/upload",
        authorization=HttpRequestNodeAuthorization(type="no-auth"),
        headers="Content-Type: multipart/form-data",
        params="",
        body=HttpRequestNodeBody(
            type="form-data",
            data=[
                BodyData(
                    key="text_field",
                    type="text",
                    value="{{#pre_node_id.text_field#}}",
                ),
                BodyData(
                    key="number_field",
                    type="text",
                    value="{{#pre_node_id.number_field#}}",
                ),
            ],
        ),
    )

    # Initialize the Executor
    executor = Executor(
        node_data=node_data,
        timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
        variable_pool=variable_pool,
    )

    # Check the executor's data
    assert executor.method == "post"
    assert executor.url == "https://api.example.com/upload"
    assert "Content-Type" in executor.headers
    assert "multipart/form-data" in executor.headers["Content-Type"]
    assert executor.params == {}
    assert executor.json is None
    assert executor.files is None
    assert executor.content is None

    # Check that the form data is correctly loaded in executor.data
    assert isinstance(executor.data, dict)
    assert "text_field" in executor.data
    assert executor.data["text_field"] == "Hello, World!"
    assert "number_field" in executor.data
    assert executor.data["number_field"] == "42"

    # Check the raw request (to_log method)
    raw_request = executor.to_log()
    assert "POST /upload HTTP/1.1" in raw_request
    assert "Host: api.example.com" in raw_request
    assert "Content-Type: multipart/form-data" in raw_request
    assert "text_field" in raw_request
    assert "Hello, World!" in raw_request
    assert "number_field" in raw_request
    assert "42" in raw_request
@ -1115,7 +1115,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
    title="Request"
    tag="POST"
    label="/datasets/{dataset_id}/retrieve"
    targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/retrieve' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{
    targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/retrieve' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{
    "query": "test",
    "retrieval_model": {
        "search_method": "keyword_search",
@ -1116,7 +1116,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
    title="Request"
    tag="POST"
    label="/datasets/{dataset_id}/retrieve"
    targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/retrieve' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{
    targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/retrieve' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{
    "query": "test",
    "retrieval_model": {
        "search_method": "keyword_search",
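Both language versions of the dataset API docs showed GET in the sample command for an endpoint that is actually POST. The corrected request expressed in Python, with placeholder values for the base URL, dataset ID, and key; the retrieval_model payload is truncated to the fields visible in this hunk:

import requests

API_BASE_URL = "https://api.dify.ai/v1"  # placeholder base URL
dataset_id = "your-dataset-id"           # placeholder
api_key = "your-api-key"                 # placeholder

resp = requests.post(
    f"{API_BASE_URL}/datasets/{dataset_id}/retrieve",
    headers={
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    },
    json={
        "query": "test",
        "retrieval_model": {"search_method": "keyword_search"},
    },
)
print(resp.status_code, resp.json())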
@ -468,8 +468,8 @@ const Configuration: FC = () => {
        transfer_methods: modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'],
      },
      enabled: !!(modelConfig.file_upload?.enabled || modelConfig.file_upload?.image?.enabled),
      allowed_file_types: modelConfig.file_upload?.allowed_file_types || [SupportUploadFileTypes.image],
      allowed_file_extensions: modelConfig.file_upload?.allowed_file_extensions || FILE_EXTS[SupportUploadFileTypes.image].map(ext => `.${ext}`),
      allowed_file_types: modelConfig.file_upload?.allowed_file_types || [SupportUploadFileTypes.image, SupportUploadFileTypes.video],
      allowed_file_extensions: modelConfig.file_upload?.allowed_file_extensions || [...FILE_EXTS[SupportUploadFileTypes.image], ...FILE_EXTS[SupportUploadFileTypes.video]].map(ext => `.${ext}`),
      allowed_file_upload_methods: modelConfig.file_upload?.allowed_file_upload_methods || modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'],
      number_limits: modelConfig.file_upload?.number_limits || modelConfig.file_upload?.image?.number_limits || 3,
      fileUploadConfig: fileUploadConfigResponse,
@ -1,6 +1,5 @@
import {
  useCallback,
  useRef,
  useState,
} from 'react'
import Textarea from 'rc-textarea'
@ -63,7 +62,6 @@ const ChatInputArea = ({
    isMultipleLine,
  } = useTextAreaHeight()
  const [query, setQuery] = useState('')
  const isUseInputMethod = useRef(false)
  const [showVoiceInput, setShowVoiceInput] = useState(false)
  const filesStore = useFileStore()
  const {
@ -95,20 +93,11 @@ const ChatInputArea = ({
      }
  }

  const handleKeyUp = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
    if (e.key === 'Enter') {
      e.preventDefault()
      // prevent send message when using input method enter
      if (!e.shiftKey && !isUseInputMethod.current)
        handleSend()
    }
  }

  const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
    isUseInputMethod.current = e.nativeEvent.isComposing
    if (e.key === 'Enter' && !e.shiftKey) {
      setQuery(query.replace(/\n$/, ''))
    if (e.key === 'Enter' && !e.shiftKey && !e.nativeEvent.isComposing) {
      e.preventDefault()
      setQuery(query.replace(/\n$/, ''))
      handleSend()
    }
  }

@ -165,7 +154,6 @@ const ChatInputArea = ({
            setQuery(e.target.value)
            handleTextareaResize()
          }}
          onKeyUp={handleKeyUp}
          onKeyDown={handleKeyDown}
          onPaste={handleClipboardPasteFile}
          onDragEnter={handleDragFileEnter}
@ -120,7 +120,7 @@ const ConfigCredential: FC<Props> = ({
              <input
                value={tempCredential.api_key_header}
                onChange={e => setTempCredential({ ...tempCredential, api_key_header: e.target.value })}
                className='w-full h-10 px-3 text-sm font-normal bg-gray-100 rounded-lg grow'
                className='w-full h-10 px-3 text-sm font-normal border border-transparent bg-gray-100 rounded-lg grow outline-none focus:bg-components-input-bg-active focus:border-components-input-border-active focus:shadow-xs'
                placeholder={t('tools.createTool.authMethod.types.apiKeyPlaceholder')!}
              />
            </div>
@ -129,7 +129,7 @@ const ConfigCredential: FC<Props> = ({
              <input
                value={tempCredential.api_key_value}
                onChange={e => setTempCredential({ ...tempCredential, api_key_value: e.target.value })}
                className='w-full h-10 px-3 text-sm font-normal bg-gray-100 rounded-lg grow'
                className='w-full h-10 px-3 text-sm font-normal border border-transparent bg-gray-100 rounded-lg grow outline-none focus:bg-components-input-bg-active focus:border-components-input-border-active focus:shadow-xs'
                placeholder={t('tools.createTool.authMethod.types.apiValuePlaceholder')!}
              />
            </div>
@ -70,7 +70,7 @@ const GetSchema: FC<Props> = ({
          <div className='relative'>
            <input
              type='text'
              className='w-[244px] h-8 pl-1.5 pr-[44px] overflow-x-auto border border-gray-200 rounded-lg text-[13px]'
              className='w-[244px] h-8 pl-1.5 pr-[44px] overflow-x-auto border border-gray-200 rounded-lg text-[13px] focus:outline-none focus:border-components-input-border-active'
              placeholder={t('tools.createTool.importFromUrlPlaceHolder')!}
              value={importUrl}
              onChange={e => setImportUrl(e.target.value)}
@ -89,7 +89,7 @@ const GetSchema: FC<Props> = ({
          </div>
        )}
      </div>
      <div className='relative' ref={showExamplesRef}>
      <div className='relative -mt-0.5' ref={showExamplesRef}>
        <Button
          size='small'
          className='space-x-1'
@ -186,8 +186,8 @@ const EditCustomCollectionModal: FC<Props> = ({
      positionCenter={isAdd && !positionLeft}
      onHide={onHide}
      title={t(`tools.createTool.${isAdd ? 'title' : 'editTitle'}`)!}
      panelClassName='mt-2 !w-[630px]'
      maxWidthClassName='!max-w-[630px]'
      panelClassName='mt-2 !w-[640px]'
      maxWidthClassName='!max-w-[640px]'
      height='calc(100vh - 16px)'
      headerClassName='!border-b-black/5'
      body={
@ -27,8 +27,8 @@ const Contribute = ({ onRefreshData }: Props) => {

  const linkUrl = useMemo(() => {
    if (language.startsWith('zh_'))
      return 'https://docs.dify.ai/v/zh-hans/guides/gong-ju/quick-tool-integration'
    return 'https://docs.dify.ai/tutorials/quick-tool-integration'
      return 'https://docs.dify.ai/zh-hans/guides/tools#ru-he-chuang-jian-zi-ding-yi-gong-ju'
    return 'https://docs.dify.ai/guides/tools#how-to-create-custom-tools'
  }, [language])

  const [isShowEditCollectionToolModal, setIsShowEditCustomCollectionModal] = useState(false)