mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 21:59:00 +08:00
feat: auto rule generator (#273)
This commit is contained in:
parent
44a1aa5e44
commit
490858a4d5
@ -9,7 +9,7 @@ api = ExternalApi(bp)
|
|||||||
from . import setup, version, apikey, admin
|
from . import setup, version, apikey, admin
|
||||||
|
|
||||||
# Import app controllers
|
# Import app controllers
|
||||||
from .app import app, site, completion, model_config, statistic, conversation, message
|
from .app import app, site, completion, model_config, statistic, conversation, message, generator
|
||||||
|
|
||||||
# Import auth controllers
|
# Import auth controllers
|
||||||
from .auth import login, oauth
|
from .auth import login, oauth
|
||||||
|
@ -9,18 +9,13 @@ from werkzeug.exceptions import Unauthorized, Forbidden
|
|||||||
|
|
||||||
from constants.model_template import model_templates, demo_model_templates
|
from constants.model_template import model_templates, demo_model_templates
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError, ProviderQuotaExceededError, \
|
from controllers.console.app.error import AppNotFoundError
|
||||||
CompletionRequestError, ProviderModelCurrentlyNotSupportError
|
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from core.generator.llm_generator import LLMGenerator
|
|
||||||
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
|
|
||||||
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
|
|
||||||
from events.app_event import app_was_created, app_was_deleted
|
from events.app_event import app_was_created, app_was_deleted
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import App, AppModelConfig, Site, InstalledApp
|
from models.model import App, AppModelConfig, Site
|
||||||
from services.account_service import TenantService
|
|
||||||
from services.app_model_config_service import AppModelConfigService
|
from services.app_model_config_service import AppModelConfigService
|
||||||
|
|
||||||
model_config_fields = {
|
model_config_fields = {
|
||||||
@ -478,35 +473,6 @@ class AppExport(Resource):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class IntroductionGenerateApi(Resource):
|
|
||||||
@setup_required
|
|
||||||
@login_required
|
|
||||||
@account_initialization_required
|
|
||||||
def post(self):
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument('prompt_template', type=str, required=True, location='json')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
account = current_user
|
|
||||||
|
|
||||||
try:
|
|
||||||
answer = LLMGenerator.generate_introduction(
|
|
||||||
account.current_tenant_id,
|
|
||||||
args['prompt_template']
|
|
||||||
)
|
|
||||||
except ProviderTokenNotInitError:
|
|
||||||
raise ProviderNotInitializeError()
|
|
||||||
except QuotaExceededError:
|
|
||||||
raise ProviderQuotaExceededError()
|
|
||||||
except ModelCurrentlyNotSupportError:
|
|
||||||
raise ProviderModelCurrentlyNotSupportError()
|
|
||||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
|
||||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
|
||||||
raise CompletionRequestError(str(e))
|
|
||||||
|
|
||||||
return {'introduction': answer}
|
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(AppListApi, '/apps')
|
api.add_resource(AppListApi, '/apps')
|
||||||
api.add_resource(AppTemplateApi, '/app-templates')
|
api.add_resource(AppTemplateApi, '/app-templates')
|
||||||
api.add_resource(AppApi, '/apps/<uuid:app_id>')
|
api.add_resource(AppApi, '/apps/<uuid:app_id>')
|
||||||
@ -515,4 +481,3 @@ api.add_resource(AppNameApi, '/apps/<uuid:app_id>/name')
|
|||||||
api.add_resource(AppSiteStatus, '/apps/<uuid:app_id>/site-enable')
|
api.add_resource(AppSiteStatus, '/apps/<uuid:app_id>/site-enable')
|
||||||
api.add_resource(AppApiStatus, '/apps/<uuid:app_id>/api-enable')
|
api.add_resource(AppApiStatus, '/apps/<uuid:app_id>/api-enable')
|
||||||
api.add_resource(AppRateLimit, '/apps/<uuid:app_id>/rate-limit')
|
api.add_resource(AppRateLimit, '/apps/<uuid:app_id>/rate-limit')
|
||||||
api.add_resource(IntroductionGenerateApi, '/introduction-generate')
|
|
||||||
|
75
api/controllers/console/app/generator.py
Normal file
75
api/controllers/console/app/generator.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
from flask_login import login_required, current_user
|
||||||
|
from flask_restful import Resource, reqparse
|
||||||
|
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.app.error import ProviderNotInitializeError, ProviderQuotaExceededError, \
|
||||||
|
CompletionRequestError, ProviderModelCurrentlyNotSupportError
|
||||||
|
from controllers.console.setup import setup_required
|
||||||
|
from controllers.console.wraps import account_initialization_required
|
||||||
|
from core.generator.llm_generator import LLMGenerator
|
||||||
|
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
|
||||||
|
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
|
||||||
|
|
||||||
|
|
||||||
|
class IntroductionGenerateApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('prompt_template', type=str, required=True, location='json')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
account = current_user
|
||||||
|
|
||||||
|
try:
|
||||||
|
answer = LLMGenerator.generate_introduction(
|
||||||
|
account.current_tenant_id,
|
||||||
|
args['prompt_template']
|
||||||
|
)
|
||||||
|
except ProviderTokenNotInitError:
|
||||||
|
raise ProviderNotInitializeError()
|
||||||
|
except QuotaExceededError:
|
||||||
|
raise ProviderQuotaExceededError()
|
||||||
|
except ModelCurrentlyNotSupportError:
|
||||||
|
raise ProviderModelCurrentlyNotSupportError()
|
||||||
|
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||||
|
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||||
|
raise CompletionRequestError(str(e))
|
||||||
|
|
||||||
|
return {'introduction': answer}
|
||||||
|
|
||||||
|
|
||||||
|
class RuleGenerateApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('audiences', type=str, required=True, nullable=False, location='json')
|
||||||
|
parser.add_argument('hoping_to_solve', type=str, required=True, nullable=False, location='json')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
account = current_user
|
||||||
|
|
||||||
|
try:
|
||||||
|
rules = LLMGenerator.generate_rule_config(
|
||||||
|
account.current_tenant_id,
|
||||||
|
args['audiences'],
|
||||||
|
args['hoping_to_solve']
|
||||||
|
)
|
||||||
|
except ProviderTokenNotInitError:
|
||||||
|
raise ProviderNotInitializeError()
|
||||||
|
except QuotaExceededError:
|
||||||
|
raise ProviderQuotaExceededError()
|
||||||
|
except ModelCurrentlyNotSupportError:
|
||||||
|
raise ProviderModelCurrentlyNotSupportError()
|
||||||
|
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||||
|
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||||
|
raise CompletionRequestError(str(e))
|
||||||
|
|
||||||
|
return rules
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(IntroductionGenerateApi, '/introduction-generate')
|
||||||
|
api.add_resource(RuleGenerateApi, '/rule-generate')
|
@ -11,6 +11,8 @@ from langchain.chains import LLMChain
|
|||||||
from langchain.prompts import BasePromptTemplate
|
from langchain.prompts import BasePromptTemplate
|
||||||
from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
|
from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
|
||||||
|
|
||||||
|
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||||
|
|
||||||
|
|
||||||
class Route(NamedTuple):
|
class Route(NamedTuple):
|
||||||
destination: Optional[str]
|
destination: Optional[str]
|
||||||
@ -82,42 +84,10 @@ class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
|
|||||||
next_inputs_type: Type = str
|
next_inputs_type: Type = str
|
||||||
next_inputs_inner_key: str = "input"
|
next_inputs_inner_key: str = "input"
|
||||||
|
|
||||||
def parse_json_markdown(self, json_string: str) -> dict:
|
|
||||||
# Remove the triple backticks if present
|
|
||||||
json_string = json_string.strip()
|
|
||||||
start_index = json_string.find("```json")
|
|
||||||
end_index = json_string.find("```", start_index + len("```json"))
|
|
||||||
|
|
||||||
if start_index != -1 and end_index != -1:
|
|
||||||
extracted_content = json_string[start_index + len("```json"):end_index].strip()
|
|
||||||
|
|
||||||
# Parse the JSON string into a Python dictionary
|
|
||||||
parsed = json.loads(extracted_content)
|
|
||||||
elif json_string.startswith("{"):
|
|
||||||
# Parse the JSON string into a Python dictionary
|
|
||||||
parsed = json.loads(json_string)
|
|
||||||
else:
|
|
||||||
raise Exception("Could not find JSON block in the output.")
|
|
||||||
|
|
||||||
return parsed
|
|
||||||
|
|
||||||
def parse_and_check_json_markdown(self, text: str, expected_keys: List[str]) -> dict:
|
|
||||||
try:
|
|
||||||
json_obj = self.parse_json_markdown(text)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
raise OutputParserException(f"Got invalid JSON object. Error: {e}")
|
|
||||||
for key in expected_keys:
|
|
||||||
if key not in json_obj:
|
|
||||||
raise OutputParserException(
|
|
||||||
f"Got invalid return object. Expected key `{key}` "
|
|
||||||
f"to be present, but got {json_obj}"
|
|
||||||
)
|
|
||||||
return json_obj
|
|
||||||
|
|
||||||
def parse(self, text: str) -> Dict[str, Any]:
|
def parse(self, text: str) -> Dict[str, Any]:
|
||||||
try:
|
try:
|
||||||
expected_keys = ["destination", "next_inputs"]
|
expected_keys = ["destination", "next_inputs"]
|
||||||
parsed = self.parse_and_check_json_markdown(text, expected_keys)
|
parsed = parse_and_check_json_markdown(text, expected_keys)
|
||||||
if not isinstance(parsed["destination"], str):
|
if not isinstance(parsed["destination"], str):
|
||||||
raise ValueError("Expected 'destination' to be a string.")
|
raise ValueError("Expected 'destination' to be a string.")
|
||||||
if not isinstance(parsed["next_inputs"], self.next_inputs_type):
|
if not isinstance(parsed["next_inputs"], self.next_inputs_type):
|
||||||
@ -135,5 +105,5 @@ class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
|
|||||||
return parsed
|
return parsed
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise OutputParserException(
|
raise OutputParserException(
|
||||||
f"Parsing text\n{text}\n raised following error:\n{e}"
|
f"Parsing text\n{text}\n of llm router raised following error:\n{e}"
|
||||||
)
|
)
|
||||||
|
@ -23,7 +23,8 @@ think that revising it will ultimately lead to a better response from the langua
|
|||||||
model.
|
model.
|
||||||
|
|
||||||
<< FORMATTING >>
|
<< FORMATTING >>
|
||||||
Return a markdown code snippet with a JSON object formatted to look like:
|
Return a markdown code snippet with a JSON object formatted to look like, \
|
||||||
|
no any other string out of markdown code snippet:
|
||||||
```json
|
```json
|
||||||
{{{{
|
{{{{
|
||||||
"destination": string \\ name of the prompt to use or "DEFAULT"
|
"destination": string \\ name of the prompt to use or "DEFAULT"
|
||||||
|
@ -7,6 +7,7 @@ from core.constant import llm_constant
|
|||||||
from core.llm.llm_builder import LLMBuilder
|
from core.llm.llm_builder import LLMBuilder
|
||||||
from core.llm.streamable_open_ai import StreamableOpenAI
|
from core.llm.streamable_open_ai import StreamableOpenAI
|
||||||
from core.llm.token_calculator import TokenCalculator
|
from core.llm.token_calculator import TokenCalculator
|
||||||
|
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
|
||||||
|
|
||||||
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
|
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
|
||||||
from core.prompt.prompt_template import OutLinePromptTemplate
|
from core.prompt.prompt_template import OutLinePromptTemplate
|
||||||
@ -118,3 +119,46 @@ class LLMGenerator:
|
|||||||
questions = []
|
questions = []
|
||||||
|
|
||||||
return questions
|
return questions
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
|
||||||
|
output_parser = RuleConfigGeneratorOutputParser()
|
||||||
|
|
||||||
|
prompt = OutLinePromptTemplate(
|
||||||
|
template=output_parser.get_format_instructions(),
|
||||||
|
input_variables=["audiences", "hoping_to_solve"],
|
||||||
|
partial_variables={
|
||||||
|
"variable": '{variable}',
|
||||||
|
"lanA": '{lanA}',
|
||||||
|
"lanB": '{lanB}',
|
||||||
|
"topic": '{topic}'
|
||||||
|
},
|
||||||
|
validate_template=False
|
||||||
|
)
|
||||||
|
|
||||||
|
_input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)
|
||||||
|
|
||||||
|
llm: StreamableOpenAI = LLMBuilder.to_llm(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
model_name=generate_base_model,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=512
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(llm, BaseChatModel):
|
||||||
|
query = [HumanMessage(content=_input.to_string())]
|
||||||
|
else:
|
||||||
|
query = _input.to_string()
|
||||||
|
|
||||||
|
try:
|
||||||
|
output = llm(query)
|
||||||
|
rule_config = output_parser.parse(output)
|
||||||
|
except Exception:
|
||||||
|
logging.exception("Error generating prompt")
|
||||||
|
rule_config = {
|
||||||
|
"prompt": "",
|
||||||
|
"variables": [],
|
||||||
|
"opening_statement": ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return rule_config
|
||||||
|
32
api/core/prompt/output_parser/rule_config_generator.py
Normal file
32
api/core/prompt/output_parser/rule_config_generator.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.schema import BaseOutputParser, OutputParserException
|
||||||
|
from core.prompt.prompts import RULE_CONFIG_GENERATE_TEMPLATE
|
||||||
|
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||||
|
|
||||||
|
|
||||||
|
class RuleConfigGeneratorOutputParser(BaseOutputParser):
|
||||||
|
|
||||||
|
def get_format_instructions(self) -> str:
|
||||||
|
return RULE_CONFIG_GENERATE_TEMPLATE
|
||||||
|
|
||||||
|
def parse(self, text: str) -> Any:
|
||||||
|
try:
|
||||||
|
expected_keys = ["prompt", "variables", "opening_statement"]
|
||||||
|
parsed = parse_and_check_json_markdown(text, expected_keys)
|
||||||
|
if not isinstance(parsed["prompt"], str):
|
||||||
|
raise ValueError("Expected 'prompt' to be a string.")
|
||||||
|
if not isinstance(parsed["variables"], list):
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected 'variables' to be a list."
|
||||||
|
)
|
||||||
|
if not isinstance(parsed["opening_statement"], str):
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected 'opening_statement' to be a str."
|
||||||
|
)
|
||||||
|
return parsed
|
||||||
|
except Exception as e:
|
||||||
|
raise OutputParserException(
|
||||||
|
f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}"
|
||||||
|
)
|
||||||
|
|
@ -61,3 +61,60 @@ QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = (
|
|||||||
QUERY_KEYWORD_EXTRACT_TEMPLATE = QueryKeywordExtractPrompt(
|
QUERY_KEYWORD_EXTRACT_TEMPLATE = QueryKeywordExtractPrompt(
|
||||||
QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL
|
QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL
|
||||||
)
|
)
|
||||||
|
|
||||||
|
RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
|
||||||
|
the model prompt that best suits the input.
|
||||||
|
You will be provided with the prompt, variables, and an opening statement.
|
||||||
|
Only the content enclosed in double curly braces, such as {{variable}}, in the prompt can be considered as a variable; \
|
||||||
|
otherwise, it cannot exist as a variable in the variables.
|
||||||
|
If you believe revising the original input will result in a better response from the language model, you may \
|
||||||
|
suggest revisions.
|
||||||
|
|
||||||
|
<< FORMATTING >>
|
||||||
|
Return a markdown code snippet with a JSON object formatted to look like, \
|
||||||
|
no any other string out of markdown code snippet:
|
||||||
|
```json
|
||||||
|
{{{{
|
||||||
|
"prompt": string \\ generated prompt
|
||||||
|
"variables": list of string \\ variables
|
||||||
|
"opening_statement": string \\ an opening statement to guide users on how to ask questions with generated prompt \
|
||||||
|
and fill in variables, with a welcome sentence, and keep TLDR.
|
||||||
|
}}}}
|
||||||
|
```
|
||||||
|
|
||||||
|
<< EXAMPLES >>
|
||||||
|
[EXAMPLE A]
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"prompt": "Write a letter about love",
|
||||||
|
"variables": [],
|
||||||
|
"opening_statement": "Hi! I'm your love letter writer AI."
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
[EXAMPLE B]
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"prompt": "Translate from {{lanA}} to {{lanB}}",
|
||||||
|
"variables": ["lanA", "lanB"],
|
||||||
|
"opening_statement": "Welcome to use translate app"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
[EXAMPLE C]
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"prompt": "Write a story about {{topic}}",
|
||||||
|
"variables": ["topic"],
|
||||||
|
"opening_statement": "I'm your story writer"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
<< MY INTENDED AUDIENCES >>
|
||||||
|
{audiences}
|
||||||
|
|
||||||
|
<< HOPING TO SOLVE >>
|
||||||
|
{hoping_to_solve}
|
||||||
|
|
||||||
|
<< OUTPUT >>
|
||||||
|
"""
|
38
api/libs/json_in_md_parser.py
Normal file
38
api/libs/json_in_md_parser.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
import json
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain.schema import OutputParserException
|
||||||
|
|
||||||
|
|
||||||
|
def parse_json_markdown(json_string: str) -> dict:
|
||||||
|
# Remove the triple backticks if present
|
||||||
|
json_string = json_string.strip()
|
||||||
|
start_index = json_string.find("```json")
|
||||||
|
end_index = json_string.find("```", start_index + len("```json"))
|
||||||
|
|
||||||
|
if start_index != -1 and end_index != -1:
|
||||||
|
extracted_content = json_string[start_index + len("```json"):end_index].strip()
|
||||||
|
|
||||||
|
# Parse the JSON string into a Python dictionary
|
||||||
|
parsed = json.loads(extracted_content)
|
||||||
|
elif json_string.startswith("{"):
|
||||||
|
# Parse the JSON string into a Python dictionary
|
||||||
|
parsed = json.loads(json_string)
|
||||||
|
else:
|
||||||
|
raise Exception("Could not find JSON block in the output.")
|
||||||
|
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
|
def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
|
||||||
|
try:
|
||||||
|
json_obj = parse_json_markdown(text)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise OutputParserException(f"Got invalid JSON object. Error: {e}")
|
||||||
|
for key in expected_keys:
|
||||||
|
if key not in json_obj:
|
||||||
|
raise OutputParserException(
|
||||||
|
f"Got invalid return object. Expected key `{key}` "
|
||||||
|
f"to be present, but got {json_obj}"
|
||||||
|
)
|
||||||
|
return json_obj
|
Loading…
x
Reference in New Issue
Block a user