mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 22:35:54 +08:00
feat: update the xinf tool's API key to optional (#9073)
This commit is contained in:
parent
8204e0e14a
commit
5213650fed
@ -104,14 +104,15 @@ class StableDiffusionTool(BuiltinTool):
|
|||||||
model = self.runtime.credentials.get("model", None)
|
model = self.runtime.credentials.get("model", None)
|
||||||
if not model:
|
if not model:
|
||||||
return self.create_text_message("Please input model")
|
return self.create_text_message("Please input model")
|
||||||
|
api_key = self.runtime.credentials.get("api_key") or "abc"
|
||||||
|
headers = {"Authorization": f"Bearer {api_key}"}
|
||||||
# set model
|
# set model
|
||||||
try:
|
try:
|
||||||
url = str(URL(base_url) / "sdapi" / "v1" / "options")
|
url = str(URL(base_url) / "sdapi" / "v1" / "options")
|
||||||
response = post(
|
response = post(
|
||||||
url,
|
url,
|
||||||
json={"sd_model_checkpoint": model},
|
json={"sd_model_checkpoint": model},
|
||||||
headers={"Authorization": f"Bearer {self.runtime.credentials['api_key']}"},
|
headers=headers,
|
||||||
)
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
raise ToolProviderCredentialValidationError("Failed to set model, please tell user to set model")
|
raise ToolProviderCredentialValidationError("Failed to set model, please tell user to set model")
|
||||||
@ -257,14 +258,15 @@ class StableDiffusionTool(BuiltinTool):
|
|||||||
draw_options["prompt"] = f"{lora},{prompt}"
|
draw_options["prompt"] = f"{lora},{prompt}"
|
||||||
else:
|
else:
|
||||||
draw_options["prompt"] = prompt
|
draw_options["prompt"] = prompt
|
||||||
|
api_key = self.runtime.credentials.get("api_key") or "abc"
|
||||||
|
headers = {"Authorization": f"Bearer {api_key}"}
|
||||||
try:
|
try:
|
||||||
url = str(URL(base_url) / "sdapi" / "v1" / "img2img")
|
url = str(URL(base_url) / "sdapi" / "v1" / "img2img")
|
||||||
response = post(
|
response = post(
|
||||||
url,
|
url,
|
||||||
json=draw_options,
|
json=draw_options,
|
||||||
timeout=120,
|
timeout=120,
|
||||||
headers={"Authorization": f"Bearer {self.runtime.credentials['api_key']}"},
|
headers=headers,
|
||||||
)
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
return self.create_text_message("Failed to generate image")
|
return self.create_text_message("Failed to generate image")
|
||||||
@ -298,14 +300,15 @@ class StableDiffusionTool(BuiltinTool):
|
|||||||
else:
|
else:
|
||||||
draw_options["prompt"] = prompt
|
draw_options["prompt"] = prompt
|
||||||
draw_options["override_settings"]["sd_model_checkpoint"] = model
|
draw_options["override_settings"]["sd_model_checkpoint"] = model
|
||||||
|
api_key = self.runtime.credentials.get("api_key") or "abc"
|
||||||
|
headers = {"Authorization": f"Bearer {api_key}"}
|
||||||
try:
|
try:
|
||||||
url = str(URL(base_url) / "sdapi" / "v1" / "txt2img")
|
url = str(URL(base_url) / "sdapi" / "v1" / "txt2img")
|
||||||
response = post(
|
response = post(
|
||||||
url,
|
url,
|
||||||
json=draw_options,
|
json=draw_options,
|
||||||
timeout=120,
|
timeout=120,
|
||||||
headers={"Authorization": f"Bearer {self.runtime.credentials['api_key']}"},
|
headers=headers,
|
||||||
)
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
return self.create_text_message("Failed to generate image")
|
return self.create_text_message("Failed to generate image")
|
||||||
|
@ -6,12 +6,18 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
|
|||||||
|
|
||||||
class XinferenceProvider(BuiltinToolProviderController):
|
class XinferenceProvider(BuiltinToolProviderController):
|
||||||
def _validate_credentials(self, credentials: dict) -> None:
|
def _validate_credentials(self, credentials: dict) -> None:
|
||||||
base_url = credentials.get("base_url")
|
base_url = credentials.get("base_url", "").removesuffix("/")
|
||||||
api_key = credentials.get("api_key")
|
api_key = credentials.get("api_key", "")
|
||||||
model = credentials.get("model")
|
if not api_key:
|
||||||
|
api_key = "abc"
|
||||||
|
credentials["api_key"] = api_key
|
||||||
|
model = credentials.get("model", "")
|
||||||
|
if not base_url or not model:
|
||||||
|
raise ToolProviderCredentialValidationError("Xinference base_url and model is required")
|
||||||
|
headers = {"Authorization": f"Bearer {api_key}"}
|
||||||
res = requests.post(
|
res = requests.post(
|
||||||
f"{base_url}/sdapi/v1/options",
|
f"{base_url}/sdapi/v1/options",
|
||||||
headers={"Authorization": f"Bearer {api_key}"},
|
headers=headers,
|
||||||
json={"sd_model_checkpoint": model},
|
json={"sd_model_checkpoint": model},
|
||||||
)
|
)
|
||||||
if res.status_code != 200:
|
if res.status_code != 200:
|
||||||
|
@ -31,7 +31,7 @@ credentials_for_provider:
|
|||||||
zh_Hans: 请输入你的模型名称
|
zh_Hans: 请输入你的模型名称
|
||||||
api_key:
|
api_key:
|
||||||
type: secret-input
|
type: secret-input
|
||||||
required: true
|
required: false
|
||||||
label:
|
label:
|
||||||
en_US: API Key
|
en_US: API Key
|
||||||
zh_Hans: Xinference 服务器的 API Key
|
zh_Hans: Xinference 服务器的 API Key
|
||||||
|
Loading…
x
Reference in New Issue
Block a user