From cb79a90031d16b09ef1747e227cb1e174cc8a435 Mon Sep 17 00:00:00 2001 From: Onelevenvy <49232224+Onelevenvy@users.noreply.github.com> Date: Mon, 18 Mar 2024 16:22:48 +0800 Subject: [PATCH] feat: Add tools for open weather search and image generation using the Spark API. (#2845) --- .../baichuan/text_embedding/text_embedding.py | 2 +- .../builtin/openweather/_assets/icon.svg | 12 ++ .../builtin/openweather/openweather.py | 36 ++++ .../builtin/openweather/openweather.yaml | 29 ++++ .../builtin/openweather/tools/weather.py | 60 +++++++ .../builtin/openweather/tools/weather.yaml | 80 +++++++++ .../tools/provider/builtin/spark/__init__.py | 0 .../provider/builtin/spark/_assets/icon.svg | 5 + .../tools/provider/builtin/spark/spark.py | 40 +++++ .../tools/provider/builtin/spark/spark.yaml | 59 +++++++ .../spark/tools/spark_img_generation.py | 154 ++++++++++++++++++ .../spark/tools/spark_img_generation.yaml | 36 ++++ sdks/python-client/dify_client/__init__.py | 2 +- 13 files changed, 513 insertions(+), 2 deletions(-) create mode 100644 api/core/tools/provider/builtin/openweather/_assets/icon.svg create mode 100644 api/core/tools/provider/builtin/openweather/openweather.py create mode 100644 api/core/tools/provider/builtin/openweather/openweather.yaml create mode 100644 api/core/tools/provider/builtin/openweather/tools/weather.py create mode 100644 api/core/tools/provider/builtin/openweather/tools/weather.yaml create mode 100644 api/core/tools/provider/builtin/spark/__init__.py create mode 100644 api/core/tools/provider/builtin/spark/_assets/icon.svg create mode 100644 api/core/tools/provider/builtin/spark/spark.py create mode 100644 api/core/tools/provider/builtin/spark/spark.yaml create mode 100644 api/core/tools/provider/builtin/spark/tools/spark_img_generation.py create mode 100644 api/core/tools/provider/builtin/spark/tools/spark_img_generation.yaml diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py index 535714f663..5ae90d54b5 100644 --- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py @@ -124,7 +124,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): elif err == 'insufficient_quota': raise InsufficientAccountBalance(msg) elif err == 'invalid_authentication': - raise InvalidAuthenticationError(msg) + raise InvalidAuthenticationError(msg) elif err and 'rate' in err: raise RateLimitReachedError(msg) elif err and 'internal' in err: diff --git a/api/core/tools/provider/builtin/openweather/_assets/icon.svg b/api/core/tools/provider/builtin/openweather/_assets/icon.svg new file mode 100644 index 0000000000..f06cd87e64 --- /dev/null +++ b/api/core/tools/provider/builtin/openweather/_assets/icon.svg @@ -0,0 +1,12 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/openweather/openweather.py b/api/core/tools/provider/builtin/openweather/openweather.py new file mode 100644 index 0000000000..a2827177a3 --- /dev/null +++ b/api/core/tools/provider/builtin/openweather/openweather.py @@ -0,0 +1,36 @@ +import requests + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +def query_weather(city="Beijing", units="metric", language="zh_cn", api_key=None): + + url = "https://api.openweathermap.org/data/2.5/weather" + params = {"q": city, "appid": api_key, "units": units, "lang": language} + + return requests.get(url, params=params) + + +class OpenweatherProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + try: + if "api_key" not in credentials or not credentials.get("api_key"): + raise ToolProviderCredentialValidationError( + "Open weather API key is required." + ) + apikey = credentials.get("api_key") + try: + response = query_weather(api_key=apikey) + if response.status_code == 200: + pass + else: + raise ToolProviderCredentialValidationError( + (response.json()).get("info") + ) + except Exception as e: + raise ToolProviderCredentialValidationError( + "Open weather API Key is invalid. {}".format(e) + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/openweather/openweather.yaml b/api/core/tools/provider/builtin/openweather/openweather.yaml new file mode 100644 index 0000000000..60bb33c36d --- /dev/null +++ b/api/core/tools/provider/builtin/openweather/openweather.yaml @@ -0,0 +1,29 @@ +identity: + author: Onelevenvy + name: openweather + label: + en_US: Open weather query + zh_Hans: Open Weather + pt_BR: Consulta de clima open weather + description: + en_US: Weather query toolkit based on Open Weather + zh_Hans: 基于open weather的天气查询工具包 + pt_BR: Kit de consulta de clima baseado no Open Weather + icon: icon.svg +credentials_for_provider: + api_key: + type: secret-input + required: true + label: + en_US: API Key + zh_Hans: API Key + pt_BR: Fogo a chave + placeholder: + en_US: Please enter your open weather API Key + zh_Hans: 请输入你的open weather API Key + pt_BR: Insira sua chave de API open weather + help: + en_US: Get your API Key from open weather + zh_Hans: 从open weather获取您的 API Key + pt_BR: Obtenha sua chave de API do open weather + url: https://openweathermap.org diff --git a/api/core/tools/provider/builtin/openweather/tools/weather.py b/api/core/tools/provider/builtin/openweather/tools/weather.py new file mode 100644 index 0000000000..536a3511f4 --- /dev/null +++ b/api/core/tools/provider/builtin/openweather/tools/weather.py @@ -0,0 +1,60 @@ +import json +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class OpenweatherTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + city = tool_parameters.get("city", "") + if not city: + return self.create_text_message("Please tell me your city") + if ( + "api_key" not in self.runtime.credentials + or not self.runtime.credentials.get("api_key") + ): + return self.create_text_message("OpenWeather API key is required.") + + units = tool_parameters.get("units", "metric") + lang = tool_parameters.get("lang", "zh_cn") + try: + # request URL + url = "https://api.openweathermap.org/data/2.5/weather" + + # request parmas + params = { + "q": city, + "appid": self.runtime.credentials.get("api_key"), + "units": units, + "lang": lang, + } + response = requests.get(url, params=params) + + if response.status_code == 200: + + data = response.json() + return self.create_text_message( + self.summary( + user_id=user_id, content=json.dumps(data, ensure_ascii=False) + ) + ) + else: + error_message = { + "error": f"failed:{response.status_code}", + "data": response.text, + } + # return error + return json.dumps(error_message) + + except Exception as e: + return self.create_text_message( + "Openweather API Key is invalid. {}".format(e) + ) diff --git a/api/core/tools/provider/builtin/openweather/tools/weather.yaml b/api/core/tools/provider/builtin/openweather/tools/weather.yaml new file mode 100644 index 0000000000..f2dae5c2df --- /dev/null +++ b/api/core/tools/provider/builtin/openweather/tools/weather.yaml @@ -0,0 +1,80 @@ +identity: + name: weather + author: Onelevenvy + label: + en_US: Open Weather Query + zh_Hans: 天气查询 + pt_BR: Previsão do tempo + icon: icon.svg +description: + human: + en_US: Weather forecast inquiry + zh_Hans: 天气查询 + pt_BR: Inquérito sobre previsão meteorológica + llm: A tool when you want to ask about the weather or weather-related question +parameters: + - name: city + type: string + required: true + label: + en_US: city + zh_Hans: 城市 + pt_BR: cidade + human_description: + en_US: Target city for weather forecast query + zh_Hans: 天气预报查询的目标城市 + pt_BR: Cidade de destino para consulta de previsão do tempo + llm_description: If you don't know you can extract the city name from the + question or you can reply:Please tell me your city. You have to extract + the Chinese city name from the question.If the input region is in Chinese + characters for China, it should be replaced with the corresponding English + name, such as '北京' for correct input is 'Beijing' + form: llm + - name: lang + type: select + required: true + human_description: + en_US: language + zh_Hans: 语言 + pt_BR: language + label: + en_US: language + zh_Hans: 语言 + pt_BR: language + form: form + options: + - value: zh_cn + label: + en_US: cn + zh_Hans: 中国 + pt_BR: cn + - value: en_us + label: + en_US: usa + zh_Hans: 美国 + pt_BR: usa + default: zh_cn + - name: units + type: select + required: true + human_description: + en_US: units for temperature + zh_Hans: 温度单位 + pt_BR: units for temperature + label: + en_US: units + zh_Hans: 单位 + pt_BR: units + form: form + options: + - value: metric + label: + en_US: metric + zh_Hans: ℃ + pt_BR: metric + - value: imperial + label: + en_US: imperial + zh_Hans: ℉ + pt_BR: imperial + default: metric diff --git a/api/core/tools/provider/builtin/spark/__init__.py b/api/core/tools/provider/builtin/spark/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/tools/provider/builtin/spark/_assets/icon.svg b/api/core/tools/provider/builtin/spark/_assets/icon.svg new file mode 100644 index 0000000000..ef0a9131a4 --- /dev/null +++ b/api/core/tools/provider/builtin/spark/_assets/icon.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/api/core/tools/provider/builtin/spark/spark.py b/api/core/tools/provider/builtin/spark/spark.py new file mode 100644 index 0000000000..cb8e69a59f --- /dev/null +++ b/api/core/tools/provider/builtin/spark/spark.py @@ -0,0 +1,40 @@ +import json + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.spark.tools.spark_img_generation import spark_response +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class SparkProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + try: + if "APPID" not in credentials or not credentials.get("APPID"): + raise ToolProviderCredentialValidationError("APPID is required.") + if "APISecret" not in credentials or not credentials.get("APISecret"): + raise ToolProviderCredentialValidationError("APISecret is required.") + if "APIKey" not in credentials or not credentials.get("APIKey"): + raise ToolProviderCredentialValidationError("APIKey is required.") + + appid = credentials.get("APPID") + apisecret = credentials.get("APISecret") + apikey = credentials.get("APIKey") + prompt = "a cute black dog" + + try: + response = spark_response(prompt, appid, apikey, apisecret) + data = json.loads(response) + code = data["header"]["code"] + + if code == 0: + # 0 success, + pass + else: + raise ToolProviderCredentialValidationError( + "image generate error, code:{}".format(code) + ) + except Exception as e: + raise ToolProviderCredentialValidationError( + "APPID APISecret APIKey is invalid. {}".format(e) + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/spark/spark.yaml b/api/core/tools/provider/builtin/spark/spark.yaml new file mode 100644 index 0000000000..f2b9c89e96 --- /dev/null +++ b/api/core/tools/provider/builtin/spark/spark.yaml @@ -0,0 +1,59 @@ +identity: + author: Onelevenvy + name: spark + label: + en_US: Spark + zh_Hans: 讯飞星火 + pt_BR: Spark + description: + en_US: Spark Platform Toolkit + zh_Hans: 讯飞星火平台工具 + pt_BR: Pacote de Ferramentas da Plataforma Spark + icon: icon.svg +credentials_for_provider: + APPID: + type: secret-input + required: true + label: + en_US: Spark APPID + zh_Hans: APPID + pt_BR: Spark APPID + help: + en_US: Please input your APPID + zh_Hans: 请输入你的 APPID + pt_BR: Please input your APPID + placeholder: + en_US: Please input your APPID + zh_Hans: 请输入你的 APPID + pt_BR: Please input your APPID + APISecret: + type: secret-input + required: true + label: + en_US: Spark APISecret + zh_Hans: APISecret + pt_BR: Spark APISecret + help: + en_US: Please input your Spark APISecret + zh_Hans: 请输入你的 APISecret + pt_BR: Please input your Spark APISecret + placeholder: + en_US: Please input your Spark APISecret + zh_Hans: 请输入你的 APISecret + pt_BR: Please input your Spark APISecret + APIKey: + type: secret-input + required: true + label: + en_US: Spark APIKey + zh_Hans: APIKey + pt_BR: Spark APIKey + help: + en_US: Please input your Spark APIKey + zh_Hans: 请输入你的 APIKey + pt_BR: Please input your Spark APIKey + placeholder: + en_US: Please input your Spark APIKey + zh_Hans: 请输入你的 APIKey + pt_BR: Please input Spark APIKey + url: https://console.xfyun.cn/services diff --git a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py new file mode 100644 index 0000000000..a977af2b76 --- /dev/null +++ b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py @@ -0,0 +1,154 @@ +import base64 +import hashlib +import hmac +import json +from base64 import b64decode +from datetime import datetime +from time import mktime +from typing import Any, Union +from urllib.parse import urlencode +from wsgiref.handlers import format_date_time + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class AssembleHeaderException(Exception): + def __init__(self, msg): + self.message = msg + + +class Url: + def __init__(this, host, path, schema): + this.host = host + this.path = path + this.schema = schema + + +# calculate sha256 and encode to base64 +def sha256base64(data): + sha256 = hashlib.sha256() + sha256.update(data) + digest = base64.b64encode(sha256.digest()).decode(encoding="utf-8") + return digest + + +def parse_url(requset_url): + stidx = requset_url.index("://") + host = requset_url[stidx + 3 :] + schema = requset_url[: stidx + 3] + edidx = host.index("/") + if edidx <= 0: + raise AssembleHeaderException("invalid request url:" + requset_url) + path = host[edidx:] + host = host[:edidx] + u = Url(host, path, schema) + return u + +def assemble_ws_auth_url(requset_url, method="GET", api_key="", api_secret=""): + u = parse_url(requset_url) + host = u.host + path = u.path + now = datetime.now() + date = format_date_time(mktime(now.timetuple())) + signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format( + host, date, method, path + ) + signature_sha = hmac.new( + api_secret.encode("utf-8"), + signature_origin.encode("utf-8"), + digestmod=hashlib.sha256, + ).digest() + signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8") + authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha}"' + + authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode( + encoding="utf-8" + ) + values = {"host": host, "date": date, "authorization": authorization} + + return requset_url + "?" + urlencode(values) + + +def get_body(appid, text): + body = { + "header": {"app_id": appid, "uid": "123456789"}, + "parameter": { + "chat": {"domain": "general", "temperature": 0.5, "max_tokens": 4096} + }, + "payload": {"message": {"text": [{"role": "user", "content": text}]}}, + } + return body + + +def spark_response(text, appid, apikey, apisecret): + host = "http://spark-api.cn-huabei-1.xf-yun.com/v2.1/tti" + url = assemble_ws_auth_url( + host, method="POST", api_key=apikey, api_secret=apisecret + ) + content = get_body(appid, text) + response = requests.post( + url, json=content, headers={"content-type": "application/json"} + ).text + return response + + +class SparkImgGeneratorTool(BuiltinTool): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + + if "APPID" not in self.runtime.credentials or not self.runtime.credentials.get( + "APPID" + ): + return self.create_text_message("APPID is required.") + if ( + "APISecret" not in self.runtime.credentials + or not self.runtime.credentials.get("APISecret") + ): + return self.create_text_message("APISecret is required.") + if ( + "APIKey" not in self.runtime.credentials + or not self.runtime.credentials.get("APIKey") + ): + return self.create_text_message("APIKey is required.") + + prompt = tool_parameters.get("prompt", "") + if not prompt: + return self.create_text_message("Please input prompt") + res = self.img_generation(prompt) + result = [] + for image in res: + result.append( + self.create_blob_message( + blob=b64decode(image["base64_image"]), + meta={"mime_type": "image/png"}, + save_as=self.VARIABLE_KEY.IMAGE.value, + ) + ) + return result + + def img_generation(self, prompt): + response = spark_response( + text=prompt, + appid=self.runtime.credentials.get("APPID"), + apikey=self.runtime.credentials.get("APIKey"), + apisecret=self.runtime.credentials.get("APISecret"), + ) + data = json.loads(response) + code = data["header"]["code"] + if code != 0: + return self.create_text_message(f"error: {code}, {data}") + else: + text = data["payload"]["choices"]["text"] + image_content = text[0] + image_base = image_content["content"] + json_data = {"base64_image": image_base} + return [json_data] diff --git a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.yaml b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.yaml new file mode 100644 index 0000000000..d44bbc9564 --- /dev/null +++ b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.yaml @@ -0,0 +1,36 @@ +identity: + name: spark_img_generation + author: Onelevenvy + label: + en_US: Spark Image Generation + zh_Hans: 图片生成 + pt_BR: Geração de imagens Spark + icon: icon.svg + description: + en_US: Spark Image Generation + zh_Hans: 图片生成 + pt_BR: Geração de imagens Spark +description: + human: + en_US: Generate images based on user input, with image generation API + provided by Spark + zh_Hans: 根据用户的输入生成图片,由讯飞星火提供图片生成api + pt_BR: Gerar imagens com base na entrada do usuário, com API de geração + de imagem fornecida pela Spark + llm: spark_img_generation is a tool used to generate images from text +parameters: + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + pt_BR: Prompt + human_description: + en_US: Image prompt + zh_Hans: 图像提示词 + pt_BR: Image prompt + llm_description: Image prompt of spark_img_generation tooll, you should + describe the image you want to generate as a list of words as possible + as detailed + form: llm diff --git a/sdks/python-client/dify_client/__init__.py b/sdks/python-client/dify_client/__init__.py index 6fa9d190e5..6ef0017fee 100644 --- a/sdks/python-client/dify_client/__init__.py +++ b/sdks/python-client/dify_client/__init__.py @@ -1 +1 @@ -from dify_client.client import ChatClient, CompletionClient, DifyClient +from dify_client.client import ChatClient, CompletionClient, DifyClient \ No newline at end of file