optimize question classifier prompt and support keyword hit test (#3565)

This commit is contained in:
Jyong 2024-04-17 17:40:40 +08:00 committed by GitHub
parent 40b48510f4
commit 394ceee141
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 12 deletions

View File

@ -12,7 +12,7 @@ from controllers.console.app.error import (
ProviderNotInitializeError, ProviderNotInitializeError,
ProviderQuotaExceededError, ProviderQuotaExceededError,
) )
from controllers.console.datasets.error import DatasetNotInitializedError, HighQualityDatasetOnlyError from controllers.console.datasets.error import DatasetNotInitializedError
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.errors.error import ( from core.errors.error import (
@ -45,10 +45,6 @@ class HitTestingApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# only high quality dataset can be used for hit testing
if dataset.indexing_technique != 'high_quality':
raise HighQualityDatasetOnlyError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('query', type=str, location='json') parser.add_argument('query', type=str, location='json')
parser.add_argument('retrieval_model', type=dict, required=False, location='json') parser.add_argument('retrieval_model', type=dict, required=False, location='json')

View File

@ -1,4 +1,3 @@
import json
import logging import logging
from typing import Optional, Union, cast from typing import Optional, Union, cast
@ -26,6 +25,7 @@ from core.workflow.nodes.question_classifier.template_prompts import (
QUESTION_CLASSIFIER_USER_PROMPT_2, QUESTION_CLASSIFIER_USER_PROMPT_2,
QUESTION_CLASSIFIER_USER_PROMPT_3, QUESTION_CLASSIFIER_USER_PROMPT_3,
) )
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
@ -64,7 +64,8 @@ class QuestionClassifierNode(LLMNode):
) )
categories = [_class.name for _class in node_data.classes] categories = [_class.name for _class in node_data.classes]
try: try:
result_text_json = json.loads(result_text.strip('```JSON\n')) result_text_json = parse_and_check_json_markdown(result_text, [])
#result_text_json = json.loads(result_text.strip('```JSON\n'))
categories_result = result_text_json.get('categories', []) categories_result = result_text_json.get('categories', [])
if categories_result: if categories_result:
categories = categories_result categories = categories_result

View File

@ -19,29 +19,33 @@ QUESTION_CLASSIFIER_SYSTEM_PROMPT = """
QUESTION_CLASSIFIER_USER_PROMPT_1 = """ QUESTION_CLASSIFIER_USER_PROMPT_1 = """
{ "input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."], { "input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."],
"categories": ["Customer Service", "Satisfaction", "Sales", "Product"], "categories": ["Customer Service", "Satisfaction", "Sales", "Product"],
"classification_instructions": ["classify the text based on the feedback provided by customer"]}```JSON "classification_instructions": ["classify the text based on the feedback provided by customer"]}
""" """
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """ QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """
```json
{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"], {"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],
"categories": ["Customer Service"]}``` "categories": ["Customer Service"]}
```
""" """
QUESTION_CLASSIFIER_USER_PROMPT_2 = """ QUESTION_CLASSIFIER_USER_PROMPT_2 = """
{"input_text": ["bad service, slow to bring the food"], {"input_text": ["bad service, slow to bring the food"],
"categories": ["Food Quality", "Experience", "Price" ], "categories": ["Food Quality", "Experience", "Price" ],
"classification_instructions": []}```JSON "classification_instructions": []}
""" """
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """ QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """
```json
{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"], {"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],
"categories": ["Experience"]}``` "categories": ["Experience"]}
```
""" """
QUESTION_CLASSIFIER_USER_PROMPT_3 = """ QUESTION_CLASSIFIER_USER_PROMPT_3 = """
'{{"input_text": ["{input_text}"],', '{{"input_text": ["{input_text}"],',
'"categories": ["{categories}" ], ', '"categories": ["{categories}" ], ',
'"classification_instructions": ["{classification_instructions}"]}}```JSON' '"classification_instructions": ["{classification_instructions}"]}}'
""" """
QUESTION_CLASSIFIER_COMPLETION_PROMPT = """ QUESTION_CLASSIFIER_COMPLETION_PROMPT = """