From cb79a90031d16b09ef1747e227cb1e174cc8a435 Mon Sep 17 00:00:00 2001
From: Onelevenvy <49232224+Onelevenvy@users.noreply.github.com>
Date: Mon, 18 Mar 2024 16:22:48 +0800
Subject: [PATCH] feat: Add tools for open weather search and image generation
using the Spark API. (#2845)
---
.../baichuan/text_embedding/text_embedding.py | 2 +-
.../builtin/openweather/_assets/icon.svg | 12 ++
.../builtin/openweather/openweather.py | 36 ++++
.../builtin/openweather/openweather.yaml | 29 ++++
.../builtin/openweather/tools/weather.py | 60 +++++++
.../builtin/openweather/tools/weather.yaml | 80 +++++++++
.../tools/provider/builtin/spark/__init__.py | 0
.../provider/builtin/spark/_assets/icon.svg | 5 +
.../tools/provider/builtin/spark/spark.py | 40 +++++
.../tools/provider/builtin/spark/spark.yaml | 59 +++++++
.../spark/tools/spark_img_generation.py | 154 ++++++++++++++++++
.../spark/tools/spark_img_generation.yaml | 36 ++++
sdks/python-client/dify_client/__init__.py | 2 +-
13 files changed, 513 insertions(+), 2 deletions(-)
create mode 100644 api/core/tools/provider/builtin/openweather/_assets/icon.svg
create mode 100644 api/core/tools/provider/builtin/openweather/openweather.py
create mode 100644 api/core/tools/provider/builtin/openweather/openweather.yaml
create mode 100644 api/core/tools/provider/builtin/openweather/tools/weather.py
create mode 100644 api/core/tools/provider/builtin/openweather/tools/weather.yaml
create mode 100644 api/core/tools/provider/builtin/spark/__init__.py
create mode 100644 api/core/tools/provider/builtin/spark/_assets/icon.svg
create mode 100644 api/core/tools/provider/builtin/spark/spark.py
create mode 100644 api/core/tools/provider/builtin/spark/spark.yaml
create mode 100644 api/core/tools/provider/builtin/spark/tools/spark_img_generation.py
create mode 100644 api/core/tools/provider/builtin/spark/tools/spark_img_generation.yaml
diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py
index 535714f663..5ae90d54b5 100644
--- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py
@@ -124,7 +124,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
elif err == 'insufficient_quota':
raise InsufficientAccountBalance(msg)
elif err == 'invalid_authentication':
- raise InvalidAuthenticationError(msg)
+ raise InvalidAuthenticationError(msg)
elif err and 'rate' in err:
raise RateLimitReachedError(msg)
elif err and 'internal' in err:
diff --git a/api/core/tools/provider/builtin/openweather/_assets/icon.svg b/api/core/tools/provider/builtin/openweather/_assets/icon.svg
new file mode 100644
index 0000000000..f06cd87e64
--- /dev/null
+++ b/api/core/tools/provider/builtin/openweather/_assets/icon.svg
@@ -0,0 +1,12 @@
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/openweather/openweather.py b/api/core/tools/provider/builtin/openweather/openweather.py
new file mode 100644
index 0000000000..a2827177a3
--- /dev/null
+++ b/api/core/tools/provider/builtin/openweather/openweather.py
@@ -0,0 +1,36 @@
+import requests
+
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+def query_weather(city="Beijing", units="metric", language="zh_cn", api_key=None):
+
+ url = "https://api.openweathermap.org/data/2.5/weather"
+ params = {"q": city, "appid": api_key, "units": units, "lang": language}
+
+ return requests.get(url, params=params)
+
+
+class OpenweatherProvider(BuiltinToolProviderController):
+ def _validate_credentials(self, credentials: dict) -> None:
+ try:
+ if "api_key" not in credentials or not credentials.get("api_key"):
+ raise ToolProviderCredentialValidationError(
+ "Open weather API key is required."
+ )
+ apikey = credentials.get("api_key")
+ try:
+ response = query_weather(api_key=apikey)
+ if response.status_code == 200:
+ pass
+ else:
+ raise ToolProviderCredentialValidationError(
+ (response.json()).get("info")
+ )
+ except Exception as e:
+ raise ToolProviderCredentialValidationError(
+ "Open weather API Key is invalid. {}".format(e)
+ )
+ except Exception as e:
+ raise ToolProviderCredentialValidationError(str(e))
diff --git a/api/core/tools/provider/builtin/openweather/openweather.yaml b/api/core/tools/provider/builtin/openweather/openweather.yaml
new file mode 100644
index 0000000000..60bb33c36d
--- /dev/null
+++ b/api/core/tools/provider/builtin/openweather/openweather.yaml
@@ -0,0 +1,29 @@
+identity:
+ author: Onelevenvy
+ name: openweather
+ label:
+ en_US: Open weather query
+ zh_Hans: Open Weather
+ pt_BR: Consulta de clima open weather
+ description:
+ en_US: Weather query toolkit based on Open Weather
+ zh_Hans: 基于open weather的天气查询工具包
+ pt_BR: Kit de consulta de clima baseado no Open Weather
+ icon: icon.svg
+credentials_for_provider:
+ api_key:
+ type: secret-input
+ required: true
+ label:
+ en_US: API Key
+ zh_Hans: API Key
+ pt_BR: Fogo a chave
+ placeholder:
+ en_US: Please enter your open weather API Key
+ zh_Hans: 请输入你的open weather API Key
+ pt_BR: Insira sua chave de API open weather
+ help:
+ en_US: Get your API Key from open weather
+ zh_Hans: 从open weather获取您的 API Key
+ pt_BR: Obtenha sua chave de API do open weather
+ url: https://openweathermap.org
diff --git a/api/core/tools/provider/builtin/openweather/tools/weather.py b/api/core/tools/provider/builtin/openweather/tools/weather.py
new file mode 100644
index 0000000000..536a3511f4
--- /dev/null
+++ b/api/core/tools/provider/builtin/openweather/tools/weather.py
@@ -0,0 +1,60 @@
+import json
+from typing import Any, Union
+
+import requests
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class OpenweatherTool(BuiltinTool):
+ def _invoke(
+ self, user_id: str, tool_parameters: dict[str, Any]
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+ """
+ invoke tools
+ """
+ city = tool_parameters.get("city", "")
+ if not city:
+ return self.create_text_message("Please tell me your city")
+ if (
+ "api_key" not in self.runtime.credentials
+ or not self.runtime.credentials.get("api_key")
+ ):
+ return self.create_text_message("OpenWeather API key is required.")
+
+ units = tool_parameters.get("units", "metric")
+ lang = tool_parameters.get("lang", "zh_cn")
+ try:
+ # request URL
+ url = "https://api.openweathermap.org/data/2.5/weather"
+
+ # request parmas
+ params = {
+ "q": city,
+ "appid": self.runtime.credentials.get("api_key"),
+ "units": units,
+ "lang": lang,
+ }
+ response = requests.get(url, params=params)
+
+ if response.status_code == 200:
+
+ data = response.json()
+ return self.create_text_message(
+ self.summary(
+ user_id=user_id, content=json.dumps(data, ensure_ascii=False)
+ )
+ )
+ else:
+ error_message = {
+ "error": f"failed:{response.status_code}",
+ "data": response.text,
+ }
+ # return error
+ return json.dumps(error_message)
+
+ except Exception as e:
+ return self.create_text_message(
+ "Openweather API Key is invalid. {}".format(e)
+ )
diff --git a/api/core/tools/provider/builtin/openweather/tools/weather.yaml b/api/core/tools/provider/builtin/openweather/tools/weather.yaml
new file mode 100644
index 0000000000..f2dae5c2df
--- /dev/null
+++ b/api/core/tools/provider/builtin/openweather/tools/weather.yaml
@@ -0,0 +1,80 @@
+identity:
+ name: weather
+ author: Onelevenvy
+ label:
+ en_US: Open Weather Query
+ zh_Hans: 天气查询
+ pt_BR: Previsão do tempo
+ icon: icon.svg
+description:
+ human:
+ en_US: Weather forecast inquiry
+ zh_Hans: 天气查询
+ pt_BR: Inquérito sobre previsão meteorológica
+ llm: A tool when you want to ask about the weather or weather-related question
+parameters:
+ - name: city
+ type: string
+ required: true
+ label:
+ en_US: city
+ zh_Hans: 城市
+ pt_BR: cidade
+ human_description:
+ en_US: Target city for weather forecast query
+ zh_Hans: 天气预报查询的目标城市
+ pt_BR: Cidade de destino para consulta de previsão do tempo
+ llm_description: If you don't know you can extract the city name from the
+ question or you can reply:Please tell me your city. You have to extract
+ the Chinese city name from the question.If the input region is in Chinese
+ characters for China, it should be replaced with the corresponding English
+ name, such as '北京' for correct input is 'Beijing'
+ form: llm
+ - name: lang
+ type: select
+ required: true
+ human_description:
+ en_US: language
+ zh_Hans: 语言
+ pt_BR: language
+ label:
+ en_US: language
+ zh_Hans: 语言
+ pt_BR: language
+ form: form
+ options:
+ - value: zh_cn
+ label:
+ en_US: cn
+ zh_Hans: 中国
+ pt_BR: cn
+ - value: en_us
+ label:
+ en_US: usa
+ zh_Hans: 美国
+ pt_BR: usa
+ default: zh_cn
+ - name: units
+ type: select
+ required: true
+ human_description:
+ en_US: units for temperature
+ zh_Hans: 温度单位
+ pt_BR: units for temperature
+ label:
+ en_US: units
+ zh_Hans: 单位
+ pt_BR: units
+ form: form
+ options:
+ - value: metric
+ label:
+ en_US: metric
+ zh_Hans: ℃
+ pt_BR: metric
+ - value: imperial
+ label:
+ en_US: imperial
+ zh_Hans: ℉
+ pt_BR: imperial
+ default: metric
diff --git a/api/core/tools/provider/builtin/spark/__init__.py b/api/core/tools/provider/builtin/spark/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/tools/provider/builtin/spark/_assets/icon.svg b/api/core/tools/provider/builtin/spark/_assets/icon.svg
new file mode 100644
index 0000000000..ef0a9131a4
--- /dev/null
+++ b/api/core/tools/provider/builtin/spark/_assets/icon.svg
@@ -0,0 +1,5 @@
+
diff --git a/api/core/tools/provider/builtin/spark/spark.py b/api/core/tools/provider/builtin/spark/spark.py
new file mode 100644
index 0000000000..cb8e69a59f
--- /dev/null
+++ b/api/core/tools/provider/builtin/spark/spark.py
@@ -0,0 +1,40 @@
+import json
+
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.provider.builtin.spark.tools.spark_img_generation import spark_response
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+class SparkProvider(BuiltinToolProviderController):
+ def _validate_credentials(self, credentials: dict) -> None:
+ try:
+ if "APPID" not in credentials or not credentials.get("APPID"):
+ raise ToolProviderCredentialValidationError("APPID is required.")
+ if "APISecret" not in credentials or not credentials.get("APISecret"):
+ raise ToolProviderCredentialValidationError("APISecret is required.")
+ if "APIKey" not in credentials or not credentials.get("APIKey"):
+ raise ToolProviderCredentialValidationError("APIKey is required.")
+
+ appid = credentials.get("APPID")
+ apisecret = credentials.get("APISecret")
+ apikey = credentials.get("APIKey")
+ prompt = "a cute black dog"
+
+ try:
+ response = spark_response(prompt, appid, apikey, apisecret)
+ data = json.loads(response)
+ code = data["header"]["code"]
+
+ if code == 0:
+ # 0 success,
+ pass
+ else:
+ raise ToolProviderCredentialValidationError(
+ "image generate error, code:{}".format(code)
+ )
+ except Exception as e:
+ raise ToolProviderCredentialValidationError(
+ "APPID APISecret APIKey is invalid. {}".format(e)
+ )
+ except Exception as e:
+ raise ToolProviderCredentialValidationError(str(e))
diff --git a/api/core/tools/provider/builtin/spark/spark.yaml b/api/core/tools/provider/builtin/spark/spark.yaml
new file mode 100644
index 0000000000..f2b9c89e96
--- /dev/null
+++ b/api/core/tools/provider/builtin/spark/spark.yaml
@@ -0,0 +1,59 @@
+identity:
+ author: Onelevenvy
+ name: spark
+ label:
+ en_US: Spark
+ zh_Hans: 讯飞星火
+ pt_BR: Spark
+ description:
+ en_US: Spark Platform Toolkit
+ zh_Hans: 讯飞星火平台工具
+ pt_BR: Pacote de Ferramentas da Plataforma Spark
+ icon: icon.svg
+credentials_for_provider:
+ APPID:
+ type: secret-input
+ required: true
+ label:
+ en_US: Spark APPID
+ zh_Hans: APPID
+ pt_BR: Spark APPID
+ help:
+ en_US: Please input your APPID
+ zh_Hans: 请输入你的 APPID
+ pt_BR: Please input your APPID
+ placeholder:
+ en_US: Please input your APPID
+ zh_Hans: 请输入你的 APPID
+ pt_BR: Please input your APPID
+ APISecret:
+ type: secret-input
+ required: true
+ label:
+ en_US: Spark APISecret
+ zh_Hans: APISecret
+ pt_BR: Spark APISecret
+ help:
+ en_US: Please input your Spark APISecret
+ zh_Hans: 请输入你的 APISecret
+ pt_BR: Please input your Spark APISecret
+ placeholder:
+ en_US: Please input your Spark APISecret
+ zh_Hans: 请输入你的 APISecret
+ pt_BR: Please input your Spark APISecret
+ APIKey:
+ type: secret-input
+ required: true
+ label:
+ en_US: Spark APIKey
+ zh_Hans: APIKey
+ pt_BR: Spark APIKey
+ help:
+ en_US: Please input your Spark APIKey
+ zh_Hans: 请输入你的 APIKey
+ pt_BR: Please input your Spark APIKey
+ placeholder:
+ en_US: Please input your Spark APIKey
+ zh_Hans: 请输入你的 APIKey
+ pt_BR: Please input Spark APIKey
+ url: https://console.xfyun.cn/services
diff --git a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py
new file mode 100644
index 0000000000..a977af2b76
--- /dev/null
+++ b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py
@@ -0,0 +1,154 @@
+import base64
+import hashlib
+import hmac
+import json
+from base64 import b64decode
+from datetime import datetime
+from time import mktime
+from typing import Any, Union
+from urllib.parse import urlencode
+from wsgiref.handlers import format_date_time
+
+import requests
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class AssembleHeaderException(Exception):
+ def __init__(self, msg):
+ self.message = msg
+
+
+class Url:
+ def __init__(this, host, path, schema):
+ this.host = host
+ this.path = path
+ this.schema = schema
+
+
+# calculate sha256 and encode to base64
+def sha256base64(data):
+ sha256 = hashlib.sha256()
+ sha256.update(data)
+ digest = base64.b64encode(sha256.digest()).decode(encoding="utf-8")
+ return digest
+
+
+def parse_url(requset_url):
+ stidx = requset_url.index("://")
+ host = requset_url[stidx + 3 :]
+ schema = requset_url[: stidx + 3]
+ edidx = host.index("/")
+ if edidx <= 0:
+ raise AssembleHeaderException("invalid request url:" + requset_url)
+ path = host[edidx:]
+ host = host[:edidx]
+ u = Url(host, path, schema)
+ return u
+
+def assemble_ws_auth_url(requset_url, method="GET", api_key="", api_secret=""):
+ u = parse_url(requset_url)
+ host = u.host
+ path = u.path
+ now = datetime.now()
+ date = format_date_time(mktime(now.timetuple()))
+ signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format(
+ host, date, method, path
+ )
+ signature_sha = hmac.new(
+ api_secret.encode("utf-8"),
+ signature_origin.encode("utf-8"),
+ digestmod=hashlib.sha256,
+ ).digest()
+ signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
+ authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha}"'
+
+ authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
+ encoding="utf-8"
+ )
+ values = {"host": host, "date": date, "authorization": authorization}
+
+ return requset_url + "?" + urlencode(values)
+
+
+def get_body(appid, text):
+ body = {
+ "header": {"app_id": appid, "uid": "123456789"},
+ "parameter": {
+ "chat": {"domain": "general", "temperature": 0.5, "max_tokens": 4096}
+ },
+ "payload": {"message": {"text": [{"role": "user", "content": text}]}},
+ }
+ return body
+
+
+def spark_response(text, appid, apikey, apisecret):
+ host = "http://spark-api.cn-huabei-1.xf-yun.com/v2.1/tti"
+ url = assemble_ws_auth_url(
+ host, method="POST", api_key=apikey, api_secret=apisecret
+ )
+ content = get_body(appid, text)
+ response = requests.post(
+ url, json=content, headers={"content-type": "application/json"}
+ ).text
+ return response
+
+
+class SparkImgGeneratorTool(BuiltinTool):
+ def _invoke(
+ self,
+ user_id: str,
+ tool_parameters: dict[str, Any],
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+ """
+ invoke tools
+ """
+
+ if "APPID" not in self.runtime.credentials or not self.runtime.credentials.get(
+ "APPID"
+ ):
+ return self.create_text_message("APPID is required.")
+ if (
+ "APISecret" not in self.runtime.credentials
+ or not self.runtime.credentials.get("APISecret")
+ ):
+ return self.create_text_message("APISecret is required.")
+ if (
+ "APIKey" not in self.runtime.credentials
+ or not self.runtime.credentials.get("APIKey")
+ ):
+ return self.create_text_message("APIKey is required.")
+
+ prompt = tool_parameters.get("prompt", "")
+ if not prompt:
+ return self.create_text_message("Please input prompt")
+ res = self.img_generation(prompt)
+ result = []
+ for image in res:
+ result.append(
+ self.create_blob_message(
+ blob=b64decode(image["base64_image"]),
+ meta={"mime_type": "image/png"},
+ save_as=self.VARIABLE_KEY.IMAGE.value,
+ )
+ )
+ return result
+
+ def img_generation(self, prompt):
+ response = spark_response(
+ text=prompt,
+ appid=self.runtime.credentials.get("APPID"),
+ apikey=self.runtime.credentials.get("APIKey"),
+ apisecret=self.runtime.credentials.get("APISecret"),
+ )
+ data = json.loads(response)
+ code = data["header"]["code"]
+ if code != 0:
+ return self.create_text_message(f"error: {code}, {data}")
+ else:
+ text = data["payload"]["choices"]["text"]
+ image_content = text[0]
+ image_base = image_content["content"]
+ json_data = {"base64_image": image_base}
+ return [json_data]
diff --git a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.yaml b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.yaml
new file mode 100644
index 0000000000..d44bbc9564
--- /dev/null
+++ b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.yaml
@@ -0,0 +1,36 @@
+identity:
+ name: spark_img_generation
+ author: Onelevenvy
+ label:
+ en_US: Spark Image Generation
+ zh_Hans: 图片生成
+ pt_BR: Geração de imagens Spark
+ icon: icon.svg
+ description:
+ en_US: Spark Image Generation
+ zh_Hans: 图片生成
+ pt_BR: Geração de imagens Spark
+description:
+ human:
+ en_US: Generate images based on user input, with image generation API
+ provided by Spark
+ zh_Hans: 根据用户的输入生成图片,由讯飞星火提供图片生成api
+ pt_BR: Gerar imagens com base na entrada do usuário, com API de geração
+ de imagem fornecida pela Spark
+ llm: spark_img_generation is a tool used to generate images from text
+parameters:
+ - name: prompt
+ type: string
+ required: true
+ label:
+ en_US: Prompt
+ zh_Hans: 提示词
+ pt_BR: Prompt
+ human_description:
+ en_US: Image prompt
+ zh_Hans: 图像提示词
+ pt_BR: Image prompt
+ llm_description: Image prompt of spark_img_generation tooll, you should
+ describe the image you want to generate as a list of words as possible
+ as detailed
+ form: llm
diff --git a/sdks/python-client/dify_client/__init__.py b/sdks/python-client/dify_client/__init__.py
index 6fa9d190e5..6ef0017fee 100644
--- a/sdks/python-client/dify_client/__init__.py
+++ b/sdks/python-client/dify_client/__init__.py
@@ -1 +1 @@
-from dify_client.client import ChatClient, CompletionClient, DifyClient
+from dify_client.client import ChatClient, CompletionClient, DifyClient
\ No newline at end of file