mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 02:09:02 +08:00
Feat: AIPPT & DynamicToolParamter (#2725)
This commit is contained in:
parent
7052565380
commit
27e678480e
@ -9,6 +9,7 @@
|
|||||||
- azuredalle
|
- azuredalle
|
||||||
- stablediffusion
|
- stablediffusion
|
||||||
- webscraper
|
- webscraper
|
||||||
|
- aippt
|
||||||
- youtube
|
- youtube
|
||||||
- wolframalpha
|
- wolframalpha
|
||||||
- maths
|
- maths
|
||||||
|
BIN
api/core/tools/provider/builtin/aippt/_assets/icon.png
Normal file
BIN
api/core/tools/provider/builtin/aippt/_assets/icon.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.9 KiB |
11
api/core/tools/provider/builtin/aippt/aippt.py
Normal file
11
api/core/tools/provider/builtin/aippt/aippt.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
from core.tools.errors import ToolProviderCredentialValidationError
|
||||||
|
from core.tools.provider.builtin.aippt.tools.aippt import AIPPTGenerateTool
|
||||||
|
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||||
|
|
||||||
|
|
||||||
|
class AIPPTProvider(BuiltinToolProviderController):
|
||||||
|
def _validate_credentials(self, credentials: dict) -> None:
|
||||||
|
try:
|
||||||
|
AIPPTGenerateTool._get_api_token(credentials, user_id='__dify_system__')
|
||||||
|
except Exception as e:
|
||||||
|
raise ToolProviderCredentialValidationError(str(e))
|
42
api/core/tools/provider/builtin/aippt/aippt.yaml
Normal file
42
api/core/tools/provider/builtin/aippt/aippt.yaml
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
identity:
|
||||||
|
author: Dify
|
||||||
|
name: aippt
|
||||||
|
label:
|
||||||
|
en_US: AIPPT
|
||||||
|
zh_Hans: AIPPT
|
||||||
|
description:
|
||||||
|
en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop
|
||||||
|
zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底
|
||||||
|
icon: icon.png
|
||||||
|
credentials_for_provider:
|
||||||
|
aippt_access_key:
|
||||||
|
type: secret-input
|
||||||
|
required: true
|
||||||
|
label:
|
||||||
|
en_US: AIPPT API key
|
||||||
|
zh_Hans: AIPPT API key
|
||||||
|
pt_BR: AIPPT API key
|
||||||
|
help:
|
||||||
|
en_US: Please input your AIPPT API key
|
||||||
|
zh_Hans: 请输入你的 AIPPT API key
|
||||||
|
pt_BR: Please input your AIPPT API key
|
||||||
|
placeholder:
|
||||||
|
en_US: Please input your AIPPT API key
|
||||||
|
zh_Hans: 请输入你的 AIPPT API key
|
||||||
|
pt_BR: Please input your AIPPT API key
|
||||||
|
url: https://www.aippt.cn
|
||||||
|
aippt_secret_key:
|
||||||
|
type: secret-input
|
||||||
|
required: true
|
||||||
|
label:
|
||||||
|
en_US: AIPPT Secret key
|
||||||
|
zh_Hans: AIPPT Secret key
|
||||||
|
pt_BR: AIPPT Secret key
|
||||||
|
help:
|
||||||
|
en_US: Please input your AIPPT Secret key
|
||||||
|
zh_Hans: 请输入你的 AIPPT Secret key
|
||||||
|
pt_BR: Please input your AIPPT Secret key
|
||||||
|
placeholder:
|
||||||
|
en_US: Please input your AIPPT Secret key
|
||||||
|
zh_Hans: 请输入你的 AIPPT Secret key
|
||||||
|
pt_BR: Please input your AIPPT Secret key
|
509
api/core/tools/provider/builtin/aippt/tools/aippt.py
Normal file
509
api/core/tools/provider/builtin/aippt/tools/aippt.py
Normal file
@ -0,0 +1,509 @@
|
|||||||
|
from base64 import b64encode
|
||||||
|
from hashlib import sha1
|
||||||
|
from hmac import new as hmac_new
|
||||||
|
from json import loads as json_loads
|
||||||
|
from threading import Lock
|
||||||
|
from time import sleep, time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from httpx import get, post
|
||||||
|
from requests import get as requests_get
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
|
from core.tools.entities.common_entities import I18nObject
|
||||||
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
|
||||||
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
|
|
||||||
|
|
||||||
|
class AIPPTGenerateTool(BuiltinTool):
|
||||||
|
"""
|
||||||
|
A tool for generating a ppt
|
||||||
|
"""
|
||||||
|
|
||||||
|
_api_base_url = URL('https://co.aippt.cn/api')
|
||||||
|
_api_token_cache = {}
|
||||||
|
_api_token_cache_lock = Lock()
|
||||||
|
|
||||||
|
_task = {}
|
||||||
|
_task_type_map = {
|
||||||
|
'auto': 1,
|
||||||
|
'markdown': 7,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||||
|
"""
|
||||||
|
Invokes the AIPPT generate tool with the given user ID and tool parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): The ID of the user invoking the tool.
|
||||||
|
tool_parameters (dict[str, Any]): The parameters for the tool
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages.
|
||||||
|
"""
|
||||||
|
title = tool_parameters.get('title', '')
|
||||||
|
if not title:
|
||||||
|
return self.create_text_message('Please provide a title for the ppt')
|
||||||
|
|
||||||
|
model = tool_parameters.get('model', 'aippt')
|
||||||
|
if not model:
|
||||||
|
return self.create_text_message('Please provide a model for the ppt')
|
||||||
|
|
||||||
|
outline = tool_parameters.get('outline', '')
|
||||||
|
|
||||||
|
# create task
|
||||||
|
task_id = self._create_task(
|
||||||
|
type=self._task_type_map['auto' if not outline else 'markdown'],
|
||||||
|
title=title,
|
||||||
|
content=outline,
|
||||||
|
user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# get suit
|
||||||
|
color = tool_parameters.get('color')
|
||||||
|
style = tool_parameters.get('style')
|
||||||
|
|
||||||
|
if color == '__default__':
|
||||||
|
color_id = ''
|
||||||
|
else:
|
||||||
|
color_id = int(color.split('-')[1])
|
||||||
|
|
||||||
|
if style == '__default__':
|
||||||
|
style_id = ''
|
||||||
|
else:
|
||||||
|
style_id = int(style.split('-')[1])
|
||||||
|
|
||||||
|
suit_id = self._get_suit(style_id=style_id, colour_id=color_id)
|
||||||
|
|
||||||
|
# generate outline
|
||||||
|
if not outline:
|
||||||
|
self._generate_outline(
|
||||||
|
task_id=task_id,
|
||||||
|
model=model,
|
||||||
|
user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# generate content
|
||||||
|
self._generate_content(
|
||||||
|
task_id=task_id,
|
||||||
|
model=model,
|
||||||
|
user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# generate ppt
|
||||||
|
_, ppt_url = self._generate_ppt(
|
||||||
|
task_id=task_id,
|
||||||
|
suit_id=suit_id,
|
||||||
|
user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.create_text_message('''the ppt has been created successfully,'''
|
||||||
|
f'''the ppt url is {ppt_url}'''
|
||||||
|
'''please give the ppt url to user and direct user to download it.''')
|
||||||
|
|
||||||
|
def _create_task(self, type: int, title: str, content: str, user_id: str) -> str:
|
||||||
|
"""
|
||||||
|
Create a task
|
||||||
|
|
||||||
|
:param type: the task type
|
||||||
|
:param title: the task title
|
||||||
|
:param content: the task content
|
||||||
|
|
||||||
|
:return: the task ID
|
||||||
|
"""
|
||||||
|
headers = {
|
||||||
|
'x-channel': '',
|
||||||
|
'x-api-key': self.runtime.credentials['aippt_access_key'],
|
||||||
|
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||||
|
}
|
||||||
|
response = post(
|
||||||
|
str(self._api_base_url / 'ai' / 'chat' / 'v2' / 'task'),
|
||||||
|
headers=headers,
|
||||||
|
files={
|
||||||
|
'type': ('', str(type)),
|
||||||
|
'title': ('', title),
|
||||||
|
'content': ('', content)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception(f'Failed to connect to aippt: {response.text}')
|
||||||
|
|
||||||
|
response = response.json()
|
||||||
|
if response.get('code') != 0:
|
||||||
|
raise Exception(f'Failed to create task: {response.get("msg")}')
|
||||||
|
|
||||||
|
return response.get('data', {}).get('id')
|
||||||
|
|
||||||
|
def _generate_outline(self, task_id: str, model: str, user_id: str) -> str:
|
||||||
|
api_url = self._api_base_url / 'ai' / 'chat' / 'outline' if model == 'aippt' else \
|
||||||
|
self._api_base_url / 'ai' / 'chat' / 'wx' / 'outline'
|
||||||
|
api_url %= {'task_id': task_id}
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
'x-channel': '',
|
||||||
|
'x-api-key': self.runtime.credentials['aippt_access_key'],
|
||||||
|
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests_get(
|
||||||
|
url=api_url,
|
||||||
|
headers=headers,
|
||||||
|
stream=True,
|
||||||
|
timeout=(10, 60)
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception(f'Failed to connect to aippt: {response.text}')
|
||||||
|
|
||||||
|
outline = ''
|
||||||
|
for chunk in response.iter_lines(delimiter=b'\n\n'):
|
||||||
|
if not chunk:
|
||||||
|
continue
|
||||||
|
|
||||||
|
event = ''
|
||||||
|
lines = chunk.decode('utf-8').split('\n')
|
||||||
|
for line in lines:
|
||||||
|
if line.startswith('event:'):
|
||||||
|
event = line[6:]
|
||||||
|
elif line.startswith('data:'):
|
||||||
|
data = line[5:]
|
||||||
|
if event == 'message':
|
||||||
|
try:
|
||||||
|
data = json_loads(data)
|
||||||
|
outline += data.get('content', '')
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
elif event == 'close':
|
||||||
|
break
|
||||||
|
elif event == 'error' or event == 'filter':
|
||||||
|
raise Exception(f'Failed to generate outline: {data}')
|
||||||
|
|
||||||
|
return outline
|
||||||
|
|
||||||
|
def _generate_content(self, task_id: str, model: str, user_id: str) -> str:
|
||||||
|
api_url = self._api_base_url / 'ai' / 'chat' / 'content' if model == 'aippt' else \
|
||||||
|
self._api_base_url / 'ai' / 'chat' / 'wx' / 'content'
|
||||||
|
api_url %= {'task_id': task_id}
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
'x-channel': '',
|
||||||
|
'x-api-key': self.runtime.credentials['aippt_access_key'],
|
||||||
|
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests_get(
|
||||||
|
url=api_url,
|
||||||
|
headers=headers,
|
||||||
|
stream=True,
|
||||||
|
timeout=(10, 60)
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception(f'Failed to connect to aippt: {response.text}')
|
||||||
|
|
||||||
|
if model == 'aippt':
|
||||||
|
content = ''
|
||||||
|
for chunk in response.iter_lines(delimiter=b'\n\n'):
|
||||||
|
if not chunk:
|
||||||
|
continue
|
||||||
|
|
||||||
|
event = ''
|
||||||
|
lines = chunk.decode('utf-8').split('\n')
|
||||||
|
for line in lines:
|
||||||
|
if line.startswith('event:'):
|
||||||
|
event = line[6:]
|
||||||
|
elif line.startswith('data:'):
|
||||||
|
data = line[5:]
|
||||||
|
if event == 'message':
|
||||||
|
try:
|
||||||
|
data = json_loads(data)
|
||||||
|
content += data.get('content', '')
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
elif event == 'close':
|
||||||
|
break
|
||||||
|
elif event == 'error' or event == 'filter':
|
||||||
|
raise Exception(f'Failed to generate content: {data}')
|
||||||
|
|
||||||
|
return content
|
||||||
|
elif model == 'wenxin':
|
||||||
|
response = response.json()
|
||||||
|
if response.get('code') != 0:
|
||||||
|
raise Exception(f'Failed to generate content: {response.get("msg")}')
|
||||||
|
|
||||||
|
return response.get('data', '')
|
||||||
|
|
||||||
|
return ''
|
||||||
|
|
||||||
|
def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Generate a ppt
|
||||||
|
|
||||||
|
:param task_id: the task ID
|
||||||
|
:param suit_id: the suit ID
|
||||||
|
:return: the cover url of the ppt and the ppt url
|
||||||
|
"""
|
||||||
|
headers = {
|
||||||
|
'x-channel': '',
|
||||||
|
'x-api-key': self.runtime.credentials['aippt_access_key'],
|
||||||
|
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||||
|
}
|
||||||
|
|
||||||
|
response = post(
|
||||||
|
str(self._api_base_url / 'design' / 'v2' / 'save'),
|
||||||
|
headers=headers,
|
||||||
|
data={
|
||||||
|
'task_id': task_id,
|
||||||
|
'template_id': suit_id
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception(f'Failed to connect to aippt: {response.text}')
|
||||||
|
|
||||||
|
response = response.json()
|
||||||
|
if response.get('code') != 0:
|
||||||
|
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
|
||||||
|
|
||||||
|
id = response.get('data', {}).get('id')
|
||||||
|
cover_url = response.get('data', {}).get('cover_url')
|
||||||
|
|
||||||
|
response = post(
|
||||||
|
str(self._api_base_url / 'download' / 'export' / 'file'),
|
||||||
|
headers=headers,
|
||||||
|
data={
|
||||||
|
'id': id,
|
||||||
|
'format': 'ppt',
|
||||||
|
'files_to_zip': False,
|
||||||
|
'edit': True
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception(f'Failed to connect to aippt: {response.text}')
|
||||||
|
|
||||||
|
response = response.json()
|
||||||
|
if response.get('code') != 0:
|
||||||
|
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
|
||||||
|
|
||||||
|
export_code = response.get('data')
|
||||||
|
if not export_code:
|
||||||
|
raise Exception('Failed to generate ppt, the export code is empty')
|
||||||
|
|
||||||
|
current_iteration = 0
|
||||||
|
while current_iteration < 50:
|
||||||
|
# get ppt url
|
||||||
|
response = post(
|
||||||
|
str(self._api_base_url / 'download' / 'export' / 'file' / 'result'),
|
||||||
|
headers=headers,
|
||||||
|
data={
|
||||||
|
'task_key': export_code
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception(f'Failed to connect to aippt: {response.text}')
|
||||||
|
|
||||||
|
response = response.json()
|
||||||
|
if response.get('code') != 0:
|
||||||
|
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
|
||||||
|
|
||||||
|
if response.get('msg') == '导出中':
|
||||||
|
current_iteration += 1
|
||||||
|
sleep(2)
|
||||||
|
continue
|
||||||
|
|
||||||
|
ppt_url = response.get('data', [])
|
||||||
|
if len(ppt_url) == 0:
|
||||||
|
raise Exception('Failed to generate ppt, the ppt url is empty')
|
||||||
|
|
||||||
|
return cover_url, ppt_url[0]
|
||||||
|
|
||||||
|
raise Exception('Failed to generate ppt, the export is timeout')
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str:
|
||||||
|
"""
|
||||||
|
Get API token
|
||||||
|
|
||||||
|
:param credentials: the credentials
|
||||||
|
:return: the API token
|
||||||
|
"""
|
||||||
|
access_key = credentials['aippt_access_key']
|
||||||
|
secret_key = credentials['aippt_secret_key']
|
||||||
|
|
||||||
|
cache_key = f'{access_key}#@#{user_id}'
|
||||||
|
|
||||||
|
with cls._api_token_cache_lock:
|
||||||
|
# clear expired tokens
|
||||||
|
now = time()
|
||||||
|
for key in list(cls._api_token_cache.keys()):
|
||||||
|
if cls._api_token_cache[key]['expire'] < now:
|
||||||
|
del cls._api_token_cache[key]
|
||||||
|
|
||||||
|
if cache_key in cls._api_token_cache:
|
||||||
|
return cls._api_token_cache[cache_key]['token']
|
||||||
|
|
||||||
|
# get token
|
||||||
|
headers = {
|
||||||
|
'x-api-key': access_key,
|
||||||
|
'x-timestamp': str(int(now)),
|
||||||
|
'x-signature': cls._calculate_sign(access_key, secret_key, int(now))
|
||||||
|
}
|
||||||
|
|
||||||
|
param = {
|
||||||
|
'uid': user_id,
|
||||||
|
'channel': ''
|
||||||
|
}
|
||||||
|
|
||||||
|
response = get(
|
||||||
|
str(cls._api_base_url / 'grant' / 'token'),
|
||||||
|
params=param,
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception(f'Failed to connect to aippt: {response.text}')
|
||||||
|
response = response.json()
|
||||||
|
if response.get('code') != 0:
|
||||||
|
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
|
||||||
|
|
||||||
|
token = response.get('data', {}).get('token')
|
||||||
|
expire = response.get('data', {}).get('time_expire')
|
||||||
|
|
||||||
|
with cls._api_token_cache_lock:
|
||||||
|
cls._api_token_cache[cache_key] = {
|
||||||
|
'token': token,
|
||||||
|
'expire': now + expire
|
||||||
|
}
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str:
|
||||||
|
return b64encode(
|
||||||
|
hmac_new(
|
||||||
|
key=secret_key.encode('utf-8'),
|
||||||
|
msg=f'GET@/api/grant/token/@{timestamp}'.encode(),
|
||||||
|
digestmod=sha1
|
||||||
|
).digest()
|
||||||
|
).decode('utf-8')
|
||||||
|
|
||||||
|
def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
|
||||||
|
"""
|
||||||
|
Get styles
|
||||||
|
|
||||||
|
:param credentials: the credentials
|
||||||
|
:return: Tuple[list[dict[id, color]], list[dict[id, style]]
|
||||||
|
"""
|
||||||
|
headers = {
|
||||||
|
'x-channel': '',
|
||||||
|
'x-api-key': self.runtime.credentials['aippt_access_key'],
|
||||||
|
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id)
|
||||||
|
}
|
||||||
|
response = get(
|
||||||
|
str(self._api_base_url / 'template_component' / 'suit' / 'select'),
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception(f'Failed to connect to aippt: {response.text}')
|
||||||
|
|
||||||
|
response = response.json()
|
||||||
|
|
||||||
|
if response.get('code') != 0:
|
||||||
|
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
|
||||||
|
|
||||||
|
colors = [{
|
||||||
|
'id': f'id-{item.get("id")}',
|
||||||
|
'name': item.get('name'),
|
||||||
|
'en_name': item.get('en_name', item.get('name')),
|
||||||
|
} for item in response.get('data', {}).get('colour') or []]
|
||||||
|
styles = [{
|
||||||
|
'id': f'id-{item.get("id")}',
|
||||||
|
'name': item.get('title'),
|
||||||
|
} for item in response.get('data', {}).get('suit_style') or []]
|
||||||
|
|
||||||
|
return colors, styles
|
||||||
|
|
||||||
|
def _get_suit(self, style_id: int, colour_id: int) -> int:
|
||||||
|
"""
|
||||||
|
Get suit
|
||||||
|
"""
|
||||||
|
headers = {
|
||||||
|
'x-channel': '',
|
||||||
|
'x-api-key': self.runtime.credentials['aippt_access_key'],
|
||||||
|
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id='__dify_system__')
|
||||||
|
}
|
||||||
|
response = get(
|
||||||
|
str(self._api_base_url / 'template_component' / 'suit' / 'search'),
|
||||||
|
headers=headers,
|
||||||
|
params={
|
||||||
|
'style_id': style_id,
|
||||||
|
'colour_id': colour_id,
|
||||||
|
'page': 1,
|
||||||
|
'page_size': 1
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception(f'Failed to connect to aippt: {response.text}')
|
||||||
|
|
||||||
|
response = response.json()
|
||||||
|
|
||||||
|
if response.get('code') != 0:
|
||||||
|
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
|
||||||
|
|
||||||
|
if len(response.get('data', {}).get('list') or []) > 0:
|
||||||
|
return response.get('data', {}).get('list')[0].get('id')
|
||||||
|
|
||||||
|
raise Exception('Failed to get suit, the suit does not exist, please check the style and color')
|
||||||
|
|
||||||
|
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||||
|
"""
|
||||||
|
Get runtime parameters
|
||||||
|
|
||||||
|
Override this method to add runtime parameters to the tool.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
colors, styles = self.get_styles(user_id='__dify_system__')
|
||||||
|
except Exception as e:
|
||||||
|
colors, styles = [
|
||||||
|
{'id': -1, 'name': '__default__'}
|
||||||
|
], [
|
||||||
|
{'id': -1, 'name': '__default__'}
|
||||||
|
]
|
||||||
|
|
||||||
|
return [
|
||||||
|
ToolParameter(
|
||||||
|
name='color',
|
||||||
|
label=I18nObject(zh_Hans='颜色', en_US='Color'),
|
||||||
|
human_description=I18nObject(zh_Hans='颜色', en_US='Color'),
|
||||||
|
type=ToolParameter.ToolParameterType.SELECT,
|
||||||
|
form=ToolParameter.ToolParameterForm.FORM,
|
||||||
|
required=False,
|
||||||
|
default=colors[0]['id'],
|
||||||
|
options=[
|
||||||
|
ToolParameterOption(
|
||||||
|
value=color['id'],
|
||||||
|
label=I18nObject(zh_Hans=color['name'], en_US=color['en_name'])
|
||||||
|
) for color in colors
|
||||||
|
]
|
||||||
|
),
|
||||||
|
ToolParameter(
|
||||||
|
name='style',
|
||||||
|
label=I18nObject(zh_Hans='风格', en_US='Style'),
|
||||||
|
human_description=I18nObject(zh_Hans='风格', en_US='Style'),
|
||||||
|
type=ToolParameter.ToolParameterType.SELECT,
|
||||||
|
form=ToolParameter.ToolParameterForm.FORM,
|
||||||
|
required=False,
|
||||||
|
default=styles[0]['id'],
|
||||||
|
options=[
|
||||||
|
ToolParameterOption(
|
||||||
|
value=style['id'],
|
||||||
|
label=I18nObject(zh_Hans=style['name'], en_US=style['name'])
|
||||||
|
) for style in styles
|
||||||
|
]
|
||||||
|
),
|
||||||
|
]
|
54
api/core/tools/provider/builtin/aippt/tools/aippt.yaml
Normal file
54
api/core/tools/provider/builtin/aippt/tools/aippt.yaml
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
identity:
|
||||||
|
name: aippt
|
||||||
|
author: Dify
|
||||||
|
label:
|
||||||
|
en_US: AIPPT
|
||||||
|
zh_Hans: AIPPT
|
||||||
|
description:
|
||||||
|
human:
|
||||||
|
en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop
|
||||||
|
zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底
|
||||||
|
llm: A tool used to generate PPT with AI, input your content topic, and let AI generate PPT for you.
|
||||||
|
parameters:
|
||||||
|
- name: title
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
label:
|
||||||
|
en_US: Title
|
||||||
|
zh_Hans: 标题
|
||||||
|
human_description:
|
||||||
|
en_US: The title of the PPT.
|
||||||
|
zh_Hans: PPT的标题。
|
||||||
|
llm_description: The title of the PPT, which will be used to generate the PPT outline.
|
||||||
|
form: llm
|
||||||
|
- name: outline
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: Outline
|
||||||
|
zh_Hans: 大纲
|
||||||
|
human_description:
|
||||||
|
en_US: The outline of the PPT
|
||||||
|
zh_Hans: PPT的大纲
|
||||||
|
llm_description: The outline of the PPT, which will be used to generate the PPT content. provide it if you have.
|
||||||
|
form: llm
|
||||||
|
- name: llm
|
||||||
|
type: select
|
||||||
|
required: true
|
||||||
|
label:
|
||||||
|
en_US: LLM model
|
||||||
|
zh_Hans: 生成大纲的LLM
|
||||||
|
options:
|
||||||
|
- value: aippt
|
||||||
|
label:
|
||||||
|
en_US: AIPPT default model
|
||||||
|
zh_Hans: AIPPT默认模型
|
||||||
|
- value: wenxin
|
||||||
|
label:
|
||||||
|
en_US: Wenxin ErnieBot
|
||||||
|
zh_Hans: 文心一言
|
||||||
|
default: aippt
|
||||||
|
human_description:
|
||||||
|
en_US: The LLM model used for generating PPT outline.
|
||||||
|
zh_Hans: 用于生成PPT大纲的LLM模型。
|
||||||
|
form: form
|
@ -2,11 +2,11 @@ import io
|
|||||||
import json
|
import json
|
||||||
from base64 import b64decode, b64encode
|
from base64 import b64decode, b64encode
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from os.path import join
|
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
from httpx import get, post
|
from httpx import get, post
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
|
||||||
@ -79,7 +79,7 @@ class StableDiffusionTool(BuiltinTool):
|
|||||||
|
|
||||||
# set model
|
# set model
|
||||||
try:
|
try:
|
||||||
url = join(base_url, 'sdapi/v1/options')
|
url = str(URL(base_url) / 'sdapi' / 'v1' / 'options')
|
||||||
response = post(url, data=json.dumps({
|
response = post(url, data=json.dumps({
|
||||||
'sd_model_checkpoint': model
|
'sd_model_checkpoint': model
|
||||||
}))
|
}))
|
||||||
@ -153,8 +153,21 @@ class StableDiffusionTool(BuiltinTool):
|
|||||||
if not model:
|
if not model:
|
||||||
raise ToolProviderCredentialValidationError('Please input model')
|
raise ToolProviderCredentialValidationError('Please input model')
|
||||||
|
|
||||||
response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=120)
|
api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
|
||||||
if response.status_code != 200:
|
response = get(url=api_url, timeout=10)
|
||||||
|
if response.status_code == 404:
|
||||||
|
# try draw a picture
|
||||||
|
self._invoke(
|
||||||
|
user_id='test',
|
||||||
|
tool_parameters={
|
||||||
|
'prompt': 'a cat',
|
||||||
|
'width': 1024,
|
||||||
|
'height': 1024,
|
||||||
|
'steps': 1,
|
||||||
|
'lora': '',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif response.status_code != 200:
|
||||||
raise ToolProviderCredentialValidationError('Failed to get models')
|
raise ToolProviderCredentialValidationError('Failed to get models')
|
||||||
else:
|
else:
|
||||||
models = [d['model_name'] for d in response.json()]
|
models = [d['model_name'] for d in response.json()]
|
||||||
@ -165,6 +178,23 @@ class StableDiffusionTool(BuiltinTool):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ToolProviderCredentialValidationError(f'Failed to get models, {e}')
|
raise ToolProviderCredentialValidationError(f'Failed to get models, {e}')
|
||||||
|
|
||||||
|
def get_sd_models(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
get sd models
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
base_url = self.runtime.credentials.get('base_url', None)
|
||||||
|
if not base_url:
|
||||||
|
return []
|
||||||
|
api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
|
||||||
|
response = get(url=api_url, timeout=10)
|
||||||
|
if response.status_code != 200:
|
||||||
|
return []
|
||||||
|
else:
|
||||||
|
return [d['model_name'] for d in response.json()]
|
||||||
|
except Exception as e:
|
||||||
|
return []
|
||||||
|
|
||||||
def img2img(self, base_url: str, lora: str, image_binary: bytes,
|
def img2img(self, base_url: str, lora: str, image_binary: bytes,
|
||||||
prompt: str, negative_prompt: str,
|
prompt: str, negative_prompt: str,
|
||||||
width: int, height: int, steps: int) \
|
width: int, height: int, steps: int) \
|
||||||
@ -192,7 +222,7 @@ class StableDiffusionTool(BuiltinTool):
|
|||||||
draw_options['prompt'] = prompt
|
draw_options['prompt'] = prompt
|
||||||
|
|
||||||
try:
|
try:
|
||||||
url = join(base_url, 'sdapi/v1/img2img')
|
url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img')
|
||||||
response = post(url, data=json.dumps(draw_options), timeout=120)
|
response = post(url, data=json.dumps(draw_options), timeout=120)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
return self.create_text_message('Failed to generate image')
|
return self.create_text_message('Failed to generate image')
|
||||||
@ -225,7 +255,7 @@ class StableDiffusionTool(BuiltinTool):
|
|||||||
draw_options['negative_prompt'] = negative_prompt
|
draw_options['negative_prompt'] = negative_prompt
|
||||||
|
|
||||||
try:
|
try:
|
||||||
url = join(base_url, 'sdapi/v1/txt2img')
|
url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img')
|
||||||
response = post(url, data=json.dumps(draw_options), timeout=120)
|
response = post(url, data=json.dumps(draw_options), timeout=120)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
return self.create_text_message('Failed to generate image')
|
return self.create_text_message('Failed to generate image')
|
||||||
@ -269,5 +299,29 @@ class StableDiffusionTool(BuiltinTool):
|
|||||||
label=I18nObject(en_US=i.name, zh_Hans=i.name)
|
label=I18nObject(en_US=i.name, zh_Hans=i.name)
|
||||||
) for i in self.list_default_image_variables()])
|
) for i in self.list_default_image_variables()])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.runtime.credentials:
|
||||||
|
try:
|
||||||
|
models = self.get_sd_models()
|
||||||
|
if len(models) != 0:
|
||||||
|
parameters.append(
|
||||||
|
ToolParameter(name='model',
|
||||||
|
label=I18nObject(en_US='Model', zh_Hans='Model'),
|
||||||
|
human_description=I18nObject(
|
||||||
|
en_US='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion',
|
||||||
|
zh_Hans='Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档',
|
||||||
|
),
|
||||||
|
type=ToolParameter.ToolParameterType.SELECT,
|
||||||
|
form=ToolParameter.ToolParameterForm.FORM,
|
||||||
|
llm_description='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion',
|
||||||
|
required=True,
|
||||||
|
default=models[0],
|
||||||
|
options=[ToolParameterOption(
|
||||||
|
value=i,
|
||||||
|
label=I18nObject(en_US=i, zh_Hans=i)
|
||||||
|
) for i in models])
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
return parameters
|
return parameters
|
||||||
|
@ -9,6 +9,7 @@ from core.tools.entities.tool_entities import (
|
|||||||
ApiProviderAuthType,
|
ApiProviderAuthType,
|
||||||
ApiProviderSchemaType,
|
ApiProviderSchemaType,
|
||||||
ToolCredentialsOption,
|
ToolCredentialsOption,
|
||||||
|
ToolParameter,
|
||||||
ToolProviderCredentials,
|
ToolProviderCredentials,
|
||||||
)
|
)
|
||||||
from core.tools.entities.user_entities import UserTool, UserToolProvider
|
from core.tools.entities.user_entities import UserTool, UserToolProvider
|
||||||
@ -73,15 +74,52 @@ class ToolManageService:
|
|||||||
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
|
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
|
||||||
tools = provider_controller.get_tools()
|
tools = provider_controller.get_tools()
|
||||||
|
|
||||||
result = [
|
tool_provider_configurations = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||||
UserTool(
|
# check if user has added the provider
|
||||||
|
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
|
||||||
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
|
BuiltinToolProvider.provider == provider,
|
||||||
|
).first()
|
||||||
|
|
||||||
|
credentials = {}
|
||||||
|
if builtin_provider is not None:
|
||||||
|
# get credentials
|
||||||
|
credentials = builtin_provider.credentials
|
||||||
|
credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for tool in tools:
|
||||||
|
# fork tool runtime
|
||||||
|
tool = tool.fork_tool_runtime(meta={
|
||||||
|
'credentials': credentials,
|
||||||
|
'tenant_id': tenant_id,
|
||||||
|
})
|
||||||
|
|
||||||
|
# get tool parameters
|
||||||
|
parameters = tool.parameters or []
|
||||||
|
# get tool runtime parameters
|
||||||
|
runtime_parameters = tool.get_runtime_parameters()
|
||||||
|
# override parameters
|
||||||
|
current_parameters = parameters.copy()
|
||||||
|
for runtime_parameter in runtime_parameters:
|
||||||
|
found = False
|
||||||
|
for index, parameter in enumerate(current_parameters):
|
||||||
|
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
|
||||||
|
current_parameters[index] = runtime_parameter
|
||||||
|
found = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||||
|
current_parameters.append(runtime_parameter)
|
||||||
|
|
||||||
|
user_tool = UserTool(
|
||||||
author=tool.identity.author,
|
author=tool.identity.author,
|
||||||
name=tool.identity.name,
|
name=tool.identity.name,
|
||||||
label=tool.identity.label,
|
label=tool.identity.label,
|
||||||
description=tool.description.human,
|
description=tool.description.human,
|
||||||
parameters=tool.parameters or []
|
parameters=current_parameters
|
||||||
) for tool in tools
|
)
|
||||||
]
|
result.append(user_tool)
|
||||||
|
|
||||||
return json.loads(
|
return json.loads(
|
||||||
serialize_base_model_array(result)
|
serialize_base_model_array(result)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user