diff --git a/api/.env.example b/api/.env.example index 80ef185e51..cf3a0f302d 100644 --- a/api/.env.example +++ b/api/.env.example @@ -183,6 +183,7 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10 # Model Configuration MULTIMODAL_SEND_IMAGE_FORMAT=base64 +PROMPT_GENERATION_MAX_TOKENS=512 # Mail configuration, support: resend, smtp MAIL_TYPE= diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 4e228a70ff..6803775e20 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,3 +1,5 @@ +import os + from flask_login import current_user from flask_restful import Resource, reqparse @@ -28,13 +30,15 @@ class RuleGenerateApi(Resource): args = parser.parse_args() account = current_user + PROMPT_GENERATION_MAX_TOKENS = int(os.getenv('PROMPT_GENERATION_MAX_TOKENS', '512')) try: rules = LLMGenerator.generate_rule_config( tenant_id=account.current_tenant_id, instruction=args['instruction'], model_config=args['model_config'], - no_variable=args['no_variable'] + no_variable=args['no_variable'], + rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index d6a4399fc7..0b5029460a 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -118,7 +118,7 @@ class LLMGenerator: return questions @classmethod - def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool) -> dict: + def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool, rule_config_max_tokens: int = 512) -> dict: output_parser = RuleConfigGeneratorOutputParser() error = "" @@ -130,7 +130,7 @@ class LLMGenerator: "error": "" } model_parameters = { - "max_tokens": 512, + "max_tokens": rule_config_max_tokens, "temperature": 0.01 }