diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py
index 6834d3a0c5..9c68e8ecdd 100644
--- a/api/controllers/console/__init__.py
+++ b/api/controllers/console/__init__.py
@@ -9,7 +9,7 @@ api = ExternalApi(bp)
 from . import setup, version, apikey, admin
 
 # Import app controllers
-from .app import app, site, completion, model_config, statistic, conversation, message
+from .app import app, site, completion, model_config, statistic, conversation, message, generator
 
 # Import auth controllers
 from .auth import login, oauth
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index fbb28fb4ae..83d6840838 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -9,18 +9,13 @@ from werkzeug.exceptions import Unauthorized, Forbidden
 from constants.model_template import model_templates, demo_model_templates
 from controllers.console import api
-from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError, ProviderQuotaExceededError, \
-    CompletionRequestError, ProviderModelCurrentlyNotSupportError
+from controllers.console.app.error import AppNotFoundError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.generator.llm_generator import LLMGenerator
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
-    LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
 from events.app_event import app_was_created, app_was_deleted
 from libs.helper import TimestampField
 from extensions.ext_database import db
-from models.model import App, AppModelConfig, Site, InstalledApp
-from services.account_service import TenantService
+from models.model import App, AppModelConfig, Site
 from services.app_model_config_service import AppModelConfigService
 
 
 model_config_fields = {
@@ -478,35 +473,6 @@ class AppExport(Resource):
         pass
 
 
-class IntroductionGenerateApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument('prompt_template', type=str, required=True, location='json')
-        args = parser.parse_args()
-
-        account = current_user
-
-        try:
-            answer = LLMGenerator.generate_introduction(
-                account.current_tenant_id,
-                args['prompt_template']
-            )
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
-        except QuotaExceededError:
-            raise ProviderQuotaExceededError()
-        except ModelCurrentlyNotSupportError:
-            raise ProviderModelCurrentlyNotSupportError()
-        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
-                LLMRateLimitError, LLMAuthorizationError) as e:
-            raise CompletionRequestError(str(e))
-
-        return {'introduction': answer}
-
-
 api.add_resource(AppListApi, '/apps')
 api.add_resource(AppTemplateApi, '/app-templates')
 api.add_resource(AppApi, '/apps/<uuid:app_id>')
@@ -515,4 +481,3 @@ api.add_resource(AppNameApi, '/apps/<uuid:app_id>/name')
 api.add_resource(AppSiteStatus, '/apps/<uuid:app_id>/site-enable')
 api.add_resource(AppApiStatus, '/apps/<uuid:app_id>/api-enable')
 api.add_resource(AppRateLimit, '/apps/<uuid:app_id>/rate-limit')
-api.add_resource(IntroductionGenerateApi, '/introduction-generate')
diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py
new file mode 100644
index 0000000000..6a74bf2584
--- /dev/null
+++ b/api/controllers/console/app/generator.py
@@ -0,0 +1,75 @@
+from flask_login import login_required, current_user
+from flask_restful import Resource, reqparse
+
+from controllers.console import api
+from controllers.console.app.error import ProviderNotInitializeError, ProviderQuotaExceededError, \
+    CompletionRequestError, ProviderModelCurrentlyNotSupportError
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from core.generator.llm_generator import LLMGenerator
+from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
+    LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
+
+
+class IntroductionGenerateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('prompt_template', type=str, required=True, location='json')
+        args = parser.parse_args()
+
+        account = current_user
+
+        try:
+            answer = LLMGenerator.generate_introduction(
+                account.current_tenant_id,
+                args['prompt_template']
+            )
+        except ProviderTokenNotInitError:
+            raise ProviderNotInitializeError()
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
+                LLMRateLimitError, LLMAuthorizationError) as e:
+            raise CompletionRequestError(str(e))
+
+        return {'introduction': answer}
+
+
+class RuleGenerateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('audiences', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('hoping_to_solve', type=str, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        account = current_user
+
+        try:
+            rules = LLMGenerator.generate_rule_config(
+                account.current_tenant_id,
+                args['audiences'],
+                args['hoping_to_solve']
+            )
+        except ProviderTokenNotInitError:
+            raise ProviderNotInitializeError()
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
+                LLMRateLimitError, LLMAuthorizationError) as e:
+            raise CompletionRequestError(str(e))
+
+        return rules
+
+
+api.add_resource(IntroductionGenerateApi, '/introduction-generate')
+api.add_resource(RuleGenerateApi, '/rule-generate')
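Note: a quick way to exercise the two routes registered above. Everything in
this sketch is an assumption for illustration: the /console/api prefix, the
port, and the Flask-Login session cookie of an already-initialized account
come from the console blueprint, which this diff does not show.

    # Hypothetical smoke test for the new console endpoints (not part of the patch).
    import requests

    BASE = "http://localhost:5001/console/api"  # assumed console API mount point

    resp = requests.post(
        f"{BASE}/rule-generate",
        json={
            "audiences": "indie developers",
            "hoping_to_solve": "turning commit logs into release notes",
        },
        cookies={"session": "<session cookie of a logged-in console user>"},
    )
    # RuleGenerateApi returns the parsed rule config as-is:
    # {"prompt": "...", "variables": [...], "opening_statement": "..."}
    print(resp.json())
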
diff --git a/api/core/chain/llm_router_chain.py b/api/core/chain/llm_router_chain.py
index 1e3ee57c21..e3779c3612 100644
--- a/api/core/chain/llm_router_chain.py
+++ b/api/core/chain/llm_router_chain.py
@@ -11,6 +11,8 @@ from langchain.chains import LLMChain
 from langchain.prompts import BasePromptTemplate
 from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
 
+from libs.json_in_md_parser import parse_and_check_json_markdown
+
 
 class Route(NamedTuple):
     destination: Optional[str]
@@ -82,42 +84,10 @@ class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
     next_inputs_type: Type = str
     next_inputs_inner_key: str = "input"
 
-    def parse_json_markdown(self, json_string: str) -> dict:
-        # Remove the triple backticks if present
-        json_string = json_string.strip()
-        start_index = json_string.find("```json")
-        end_index = json_string.find("```", start_index + len("```json"))
-
-        if start_index != -1 and end_index != -1:
-            extracted_content = json_string[start_index + len("```json"):end_index].strip()
-
-            # Parse the JSON string into a Python dictionary
-            parsed = json.loads(extracted_content)
-        elif json_string.startswith("{"):
-            # Parse the JSON string into a Python dictionary
-            parsed = json.loads(json_string)
-        else:
-            raise Exception("Could not find JSON block in the output.")
-
-        return parsed
-
-    def parse_and_check_json_markdown(self, text: str, expected_keys: List[str]) -> dict:
-        try:
-            json_obj = self.parse_json_markdown(text)
-        except json.JSONDecodeError as e:
-            raise OutputParserException(f"Got invalid JSON object. Error: {e}")
-        for key in expected_keys:
-            if key not in json_obj:
-                raise OutputParserException(
-                    f"Got invalid return object. Expected key `{key}` "
-                    f"to be present, but got {json_obj}"
-                )
-        return json_obj
-
     def parse(self, text: str) -> Dict[str, Any]:
         try:
             expected_keys = ["destination", "next_inputs"]
-            parsed = self.parse_and_check_json_markdown(text, expected_keys)
+            parsed = parse_and_check_json_markdown(text, expected_keys)
             if not isinstance(parsed["destination"], str):
                 raise ValueError("Expected 'destination' to be a string.")
             if not isinstance(parsed["next_inputs"], self.next_inputs_type):
@@ -135,5 +105,5 @@ class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
             return parsed
         except Exception as e:
             raise OutputParserException(
-                f"Parsing text\n{text}\n raised following error:\n{e}"
+                f"Parsing text\n{text}\n of llm router raised the following error:\n{e}"
            )
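Note: parse() now delegates JSON extraction to the shared libs helper. A
sketch of the round trip, assuming the elided middle of the method still
mirrors LangChain's RouterOutputParser, which re-wraps a string
"next_inputs" under next_inputs_inner_key ("input"):

    from core.chain.llm_router_chain import RouterOutputParser

    text = '''```json
    {"destination": "dataset-1", "next_inputs": "How do I reset my password?"}
    ```'''
    parsed = RouterOutputParser().parse(text)
    # expected: {'destination': 'dataset-1',
    #            'next_inputs': {'input': 'How do I reset my password?'}}
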
diff --git a/api/core/chain/multi_dataset_router_chain.py b/api/core/chain/multi_dataset_router_chain.py
index a66b5fdb11..fb0bc35f93 100644
--- a/api/core/chain/multi_dataset_router_chain.py
+++ b/api/core/chain/multi_dataset_router_chain.py
@@ -23,7 +23,8 @@ think that revising it will ultimately lead to a better response from the langua
 model.
 
 << FORMATTING >>
-Return a markdown code snippet with a JSON object formatted to look like:
+Return a markdown code snippet with a JSON object formatted as shown below, \
+with no other text outside the markdown code snippet:
 ```json
 {{{{
     "destination": string \\ name of the prompt to use or "DEFAULT"
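Note: the quadruple braces in this template are deliberate. If the chain
follows the LangChain multi-prompt router pattern, the string is formatted
twice (once when the destinations are filled in, once when the router prompt
renders the query), and each str.format pass halves the braces:

    # Each .format() pass turns '{{' into '{', so '{{{{' survives two passes
    # as one literal brace.
    s = '{{{{ "destination": string }}}}'
    once = s.format()       # '{{ "destination": string }}'
    twice = once.format()   # '{ "destination": string }'
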
diff --git a/api/core/generator/llm_generator.py b/api/core/generator/llm_generator.py
index 67e5753007..055e54f4c7 100644
--- a/api/core/generator/llm_generator.py
+++ b/api/core/generator/llm_generator.py
@@ -7,6 +7,7 @@ from core.constant import llm_constant
 from core.llm.llm_builder import LLMBuilder
 from core.llm.streamable_open_ai import StreamableOpenAI
 from core.llm.token_calculator import TokenCalculator
+from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
 from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
 from core.prompt.prompt_template import OutLinePromptTemplate
 
@@ -118,3 +119,46 @@ class LLMGenerator:
         questions = []
 
         return questions
+
+    @classmethod
+    def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
+        output_parser = RuleConfigGeneratorOutputParser()
+
+        prompt = OutLinePromptTemplate(
+            template=output_parser.get_format_instructions(),
+            input_variables=["audiences", "hoping_to_solve"],
+            partial_variables={
+                "variable": '{variable}',
+                "lanA": '{lanA}',
+                "lanB": '{lanB}',
+                "topic": '{topic}'
+            },
+            validate_template=False
+        )
+
+        _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)
+
+        llm: StreamableOpenAI = LLMBuilder.to_llm(
+            tenant_id=tenant_id,
+            model_name=generate_base_model,
+            temperature=0,
+            max_tokens=512
+        )
+
+        if isinstance(llm, BaseChatModel):
+            query = [HumanMessage(content=_input.to_string())]
+        else:
+            query = _input.to_string()
+
+        try:
+            output = llm(query)
+            rule_config = output_parser.parse(output)
+        except Exception:
+            logging.exception("Error generating prompt")
+            rule_config = {
+                "prompt": "",
+                "variables": [],
+                "opening_statement": ""
+            }
+
+        return rule_config
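Note: generate_rule_config pins "variable", "lanA", "lanB" and "topic" to
literal brace-wrapped strings so the few-shot examples in the template keep
their placeholders after formatting. A rough sketch of the same trick using
LangChain's stock PromptTemplate (OutLinePromptTemplate is Dify's more
lenient subclass, so details may differ):

    from langchain.prompts import PromptTemplate

    prompt = PromptTemplate(
        template="Example: translate from {lanA} to {lanB}.\nAudiences: {audiences}",
        input_variables=["audiences"],
        partial_variables={"lanA": "{lanA}", "lanB": "{lanB}"},
    )
    print(prompt.format(audiences="language learners"))
    # Example: translate from {lanA} to {lanB}.
    # Audiences: language learners
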
diff --git a/api/core/prompt/output_parser/rule_config_generator.py b/api/core/prompt/output_parser/rule_config_generator.py
new file mode 100644
index 0000000000..84df4d0c34
--- /dev/null
+++ b/api/core/prompt/output_parser/rule_config_generator.py
@@ -0,0 +1,32 @@
+from typing import Any
+
+from langchain.schema import BaseOutputParser, OutputParserException
+from core.prompt.prompts import RULE_CONFIG_GENERATE_TEMPLATE
+from libs.json_in_md_parser import parse_and_check_json_markdown
+
+
+class RuleConfigGeneratorOutputParser(BaseOutputParser):
+
+    def get_format_instructions(self) -> str:
+        return RULE_CONFIG_GENERATE_TEMPLATE
+
+    def parse(self, text: str) -> Any:
+        try:
+            expected_keys = ["prompt", "variables", "opening_statement"]
+            parsed = parse_and_check_json_markdown(text, expected_keys)
+            if not isinstance(parsed["prompt"], str):
+                raise ValueError("Expected 'prompt' to be a string.")
+            if not isinstance(parsed["variables"], list):
+                raise ValueError(
+                    "Expected 'variables' to be a list."
+                )
+            if not isinstance(parsed["opening_statement"], str):
+                raise ValueError(
+                    "Expected 'opening_statement' to be a string."
+                )
+            return parsed
+        except Exception as e:
+            raise OutputParserException(
+                f"Parsing text\n{text}\n of rule config generator raised the following error:\n{e}"
+            )
+
diff --git a/api/core/prompt/prompts.py b/api/core/prompt/prompts.py
index 1d9c00990c..af17075408 100644
--- a/api/core/prompt/prompts.py
+++ b/api/core/prompt/prompts.py
@@ -61,3 +61,60 @@ QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = (
 QUERY_KEYWORD_EXTRACT_TEMPLATE = QueryKeywordExtractPrompt(
     QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL
 )
+
+RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and what I am HOPING TO SOLVE with a language model, please select \
+the model prompt that best suits the input.
+You are to provide the prompt, its variables, and an opening statement.
+Only content enclosed in double curly braces in the prompt, such as {{variable}}, can be treated as a variable; \
+nothing else may appear in the variables list.
+If you believe revising the original input will result in a better response from the language model, you may \
+suggest revisions.
+
+<< FORMATTING >>
+Return a markdown code snippet with a JSON object formatted as shown below, \
+with no other text outside the markdown code snippet:
+```json
+{{{{
+    "prompt": string \\ generated prompt
+    "variables": list of string \\ variables
+    "opening_statement": string \\ an opening statement to guide users on how to ask questions with the generated prompt \
+and fill in variables, with a welcome sentence; keep it brief (TL;DR).
+}}}}
+```
+
+<< EXAMPLES >>
+[EXAMPLE A]
+```json
+{
+    "prompt": "Write a letter about love",
+    "variables": [],
+    "opening_statement": "Hi! I'm your love letter writer AI."
+}
+```
+
+[EXAMPLE B]
+```json
+{
+    "prompt": "Translate from {{lanA}} to {{lanB}}",
+    "variables": ["lanA", "lanB"],
+    "opening_statement": "Welcome to the translation app"
+}
+```
+
+[EXAMPLE C]
+```json
+{
+    "prompt": "Write a story about {{topic}}",
+    "variables": ["topic"],
+    "opening_statement": "I'm your story writer"
+}
+```
+
+<< MY INTENDED AUDIENCES >>
+{audiences}
+
+<< HOPING TO SOLVE >>
+{hoping_to_solve}
+
+<< OUTPUT >>
+"""
\ No newline at end of file
diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py
new file mode 100644
index 0000000000..b2dba695cc
--- /dev/null
+++ b/api/libs/json_in_md_parser.py
@@ -0,0 +1,38 @@
+import json
+from typing import List
+
+from langchain.schema import OutputParserException
+
+
+def parse_json_markdown(json_string: str) -> dict:
+    # Remove the triple backticks if present
+    json_string = json_string.strip()
+    start_index = json_string.find("```json")
+    end_index = json_string.find("```", start_index + len("```json"))
+
+    if start_index != -1 and end_index != -1:
+        extracted_content = json_string[start_index + len("```json"):end_index].strip()
+
+        # Parse the JSON string into a Python dictionary
+        parsed = json.loads(extracted_content)
+    elif json_string.startswith("{"):
+        # Parse the JSON string into a Python dictionary
+        parsed = json.loads(json_string)
+    else:
+        raise Exception("Could not find JSON block in the output.")
+
+    return parsed
+
+
+def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
+    try:
+        json_obj = parse_json_markdown(text)
+    except json.JSONDecodeError as e:
+        raise OutputParserException(f"Got invalid JSON object. Error: {e}")
+    for key in expected_keys:
+        if key not in json_obj:
+            raise OutputParserException(
+                f"Got invalid return object. Expected key `{key}` "
+                f"to be present, but got {json_obj}"
+            )
+    return json_obj
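Note: both helpers above are extracted from RouterOutputParser, so their
behavior can be checked directly against this file:

    from libs.json_in_md_parser import parse_and_check_json_markdown

    good = '''```json
    {"prompt": "Write about {topic}", "variables": ["topic"], "opening_statement": "Hi!"}
    ```'''
    parse_and_check_json_markdown(good, ["prompt", "variables", "opening_statement"])
    # -> {'prompt': 'Write about {topic}', 'variables': ['topic'],
    #     'opening_statement': 'Hi!'}

    # A reply missing an expected key raises OutputParserException; text with
    # neither a ```json fence nor a leading '{' raises the bare Exception from
    # parse_json_markdown, which parse_and_check_json_markdown does not catch.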