parent a9b942981d
commit a71f2863ac
@ -28,7 +28,7 @@ from extensions.ext_database import db
|
||||
from libs.rsa import generate_key_pair
|
||||
from models.account import InvitationCode, Tenant, TenantAccountJoin
|
||||
from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
|
||||
from models.model import Account, AppModelConfig, App
|
||||
from models.model import Account, AppModelConfig, App, MessageAnnotation, Message
|
||||
import secrets
|
||||
import base64
|
||||
|
||||
@ -752,6 +752,30 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
|
||||
pbar.update(len(data_batch))
|
||||
|
||||
|
||||
@click.command('add-annotation-question-field-value', help='add question field value for existing message annotations')
|
||||
def add_annotation_question_field_value():
|
||||
click.echo(click.style('Start adding annotation question values.', fg='green'))
|
||||
message_annotations = db.session.query(MessageAnnotation).all()
|
||||
message_annotation_deal_count = 0
|
||||
if message_annotations:
|
||||
for message_annotation in message_annotations:
|
||||
try:
|
||||
if message_annotation.message_id and not message_annotation.question:
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_annotation.message_id
|
||||
).first()
|
||||
message_annotation.question = message.query
|
||||
db.session.add(message_annotation)
|
||||
db.session.commit()
|
||||
message_annotation_deal_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Add annotation question value error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
click.echo(
|
||||
click.style(f'Congratulations! Annotation question values added successfully. Count: {message_annotation_deal_count}', fg='green'))
|
||||
|
||||
|
||||
def register_commands(app):
|
||||
app.cli.add_command(reset_password)
|
||||
app.cli.add_command(reset_email)
|
||||
@ -766,3 +790,4 @@ def register_commands(app):
|
||||
app.cli.add_command(normalization_collections)
|
||||
app.cli.add_command(migrate_default_input_to_dataset_query_variable)
|
||||
app.cli.add_command(add_qdrant_full_text_index)
|
||||
app.cli.add_command(add_annotation_question_field_value)
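Editor's note: as a minimal usage sketch (not part of this commit), the newly registered command could be exercised through Flask's test CLI runner; the `create_app` import path below is an assumption about the project layout.

# Sketch only: run the backfill command via Flask's CLI runner.
# Assumption: the API package exposes a Flask app factory named create_app.
from app import create_app  # hypothetical import path

app = create_app()
runner = app.test_cli_runner()
result = runner.invoke(args=['add-annotation-question-field-value'])
print(result.output)  # the click.echo messages emitted by the command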
|
||||
|
@ -9,7 +9,7 @@ api = ExternalApi(bp)
|
||||
from . import extension, setup, version, apikey, admin
|
||||
|
||||
# Import app controllers
|
||||
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio
|
||||
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio, annotation
|
||||
|
||||
# Import auth controllers
|
||||
from .auth import login, oauth, data_source_oauth, activate
|
||||
|
291  api/controllers/console/app/annotation.py  Normal file
@ -0,0 +1,291 @@
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse, marshal_with, marshal
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import NoFileUploadedError
|
||||
from controllers.console.datasets.error import TooManyFilesError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import annotation_list_fields, annotation_hit_history_list_fields, annotation_fields, \
|
||||
annotation_hit_history_fields
|
||||
from libs.login import login_required
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from flask import request
|
||||
|
||||
|
||||
class AnnotationReplyActionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
def post(self, app_id, action):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('score_threshold', required=True, type=float, location='json')
|
||||
parser.add_argument('embedding_provider_name', required=True, type=str, location='json')
|
||||
parser.add_argument('embedding_model_name', required=True, type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
if action == 'enable':
|
||||
result = AppAnnotationService.enable_app_annotation(args, app_id)
|
||||
elif action == 'disable':
|
||||
result = AppAnnotationService.disable_app_annotation(app_id)
|
||||
else:
|
||||
raise ValueError('Unsupported annotation reply action')
|
||||
return result, 200
|
||||
|
||||
|
||||
class AppAnnotationSettingDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
|
||||
return result, 200
|
||||
|
||||
|
||||
class AppAnnotationSettingUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_id, annotation_setting_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_setting_id = str(annotation_setting_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('score_threshold', required=True, type=float, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
|
||||
return result, 200
|
||||
|
||||
|
||||
class AnnotationReplyActionStatusApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
def get(self, app_id, job_id, action):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
job_id = str(job_id)
|
||||
app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id))
|
||||
cache_result = redis_client.get(app_annotation_job_key)
|
||||
if cache_result is None:
|
||||
raise ValueError("The job is not exist.")
|
||||
|
||||
job_status = cache_result.decode()
|
||||
error_msg = ''
|
||||
if job_status == 'error':
|
||||
app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id))
|
||||
error_msg = redis_client.get(app_annotation_error_key).decode()
|
||||
|
||||
return {
|
||||
'job_id': job_id,
|
||||
'job_status': job_status,
|
||||
'error_msg': error_msg
|
||||
}, 200
|
||||
|
||||
|
||||
class AnnotationListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
page = request.args.get('page', default=1, type=int)
|
||||
limit = request.args.get('limit', default=20, type=int)
|
||||
keyword = request.args.get('keyword', default=None, type=str)
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
|
||||
response = {
|
||||
'data': marshal(annotation_list, annotation_fields),
|
||||
'has_more': len(annotation_list) == limit,
|
||||
'limit': limit,
|
||||
'total': total,
|
||||
'page': page
|
||||
}
|
||||
return response, 200
|
||||
|
||||
|
||||
class AnnotationExportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
||||
response = {
|
||||
'data': marshal(annotation_list, annotation_fields)
|
||||
}
|
||||
return response, 200
|
||||
|
||||
|
||||
class AnnotationCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('question', required=True, type=str, location='json')
|
||||
parser.add_argument('answer', required=True, type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
|
||||
return annotation
|
||||
|
||||
|
||||
class AnnotationUpdateDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id, annotation_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('question', required=True, type=str, location='json')
|
||||
parser.add_argument('answer', required=True, type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
|
||||
return annotation
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
def delete(self, app_id, annotation_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
class AnnotationBatchImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
def post(self, app_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
# check file presence and count before accessing it
if 'file' not in request.files:
raise NoFileUploadedError()

if len(request.files) > 1:
raise TooManyFilesError()
# get file from request (safe only after the presence check above)
file = request.files['file']
# check file type
if not file.filename.endswith('.csv'):
raise ValueError("Invalid file type. Only CSV files are allowed")
|
||||
return AppAnnotationService.batch_import_app_annotations(app_id, file)
|
||||
|
||||
|
||||
class AnnotationBatchImportStatusApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
def get(self, app_id, job_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
job_id = str(job_id)
|
||||
indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is None:
|
||||
raise ValueError("The job is not exist.")
|
||||
job_status = cache_result.decode()
|
||||
error_msg = ''
|
||||
if job_status == 'error':
|
||||
indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
|
||||
error_msg = redis_client.get(indexing_error_msg_key).decode()
|
||||
|
||||
return {
|
||||
'job_id': job_id,
|
||||
'job_status': job_status,
|
||||
'error_msg': error_msg
|
||||
}, 200
|
||||
|
||||
|
||||
class AnnotationHitHistoryListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id, annotation_id):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
page = request.args.get('page', default=1, type=int)
|
||||
limit = request.args.get('limit', default=20, type=int)
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id,
|
||||
page, limit)
|
||||
response = {
|
||||
'data': marshal(annotation_hit_history_list, annotation_hit_history_fields),
|
||||
'has_more': len(annotation_hit_history_list) == limit,
|
||||
'limit': limit,
|
||||
'total': total,
|
||||
'page': page
|
||||
}
|
||||
return response
|
||||
|
||||
|
||||
api.add_resource(AnnotationReplyActionApi, '/apps/<uuid:app_id>/annotation-reply/<string:action>')
|
||||
api.add_resource(AnnotationReplyActionStatusApi,
|
||||
'/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>')
|
||||
api.add_resource(AnnotationListApi, '/apps/<uuid:app_id>/annotations')
|
||||
api.add_resource(AnnotationExportApi, '/apps/<uuid:app_id>/annotations/export')
|
||||
api.add_resource(AnnotationUpdateDeleteApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>')
|
||||
api.add_resource(AnnotationBatchImportApi, '/apps/<uuid:app_id>/annotations/batch-import')
|
||||
api.add_resource(AnnotationBatchImportStatusApi, '/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>')
|
||||
api.add_resource(AnnotationHitHistoryListApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories')
|
||||
api.add_resource(AppAnnotationSettingDetailApi, '/apps/<uuid:app_id>/annotation-setting')
|
||||
api.add_resource(AppAnnotationSettingUpdateApi, '/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>')
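Editor's note: purely as an illustration of the new routes (not part of this commit), a console client might enable annotation reply roughly as sketched below; the URL prefix, host, auth header, and model names are assumptions.

# Sketch only: enable annotation reply for an app (all values are placeholders).
import requests

resp = requests.post(
    'http://localhost:5001/console/api/apps/<app-uuid>/annotation-reply/enable',  # assumed host/prefix
    json={
        'score_threshold': 0.9,
        'embedding_provider_name': 'openai',                 # assumed provider
        'embedding_model_name': 'text-embedding-ada-002',    # assumed model
    },
    headers={'Authorization': 'Bearer <console-token>'},     # assumed auth scheme
)
print(resp.status_code, resp.json())  # body depends on AppAnnotationService.enable_app_annotation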
|
@ -72,4 +72,16 @@ class UnsupportedAudioTypeError(BaseHTTPException):
|
||||
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
|
||||
error_code = 'provider_not_support_speech_to_text'
|
||||
description = "Provider not support speech to text."
|
||||
code = 400
|
||||
code = 400
|
||||
|
||||
|
||||
class NoFileUploadedError(BaseHTTPException):
|
||||
error_code = 'no_file_uploaded'
|
||||
description = "Please upload your file."
|
||||
code = 400
|
||||
|
||||
|
||||
class TooManyFilesError(BaseHTTPException):
|
||||
error_code = 'too_many_files'
|
||||
description = "Only one file is allowed."
|
||||
code = 400
|
||||
|
@ -6,22 +6,23 @@ from flask import Response, stream_with_context
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse, marshal_with, fields
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app import _get_app
|
||||
from controllers.console.app.error import CompletionRequestError, ProviderNotInitializeError, \
|
||||
AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from libs.login import login_required
|
||||
from fields.conversation_fields import message_detail_fields
|
||||
from fields.conversation_fields import message_detail_fields, annotation_fields
|
||||
from libs.helper import uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from extensions.ext_database import db
|
||||
from models.model import MessageAnnotation, Conversation, Message, MessageFeedback
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.completion_service import CompletionService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
@ -151,44 +152,24 @@ class MessageAnnotationApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check('annotation')
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
||||
# get app info
|
||||
app = _get_app(app_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('message_id', required=True, type=uuid_value, location='json')
|
||||
parser.add_argument('content', type=str, location='json')
|
||||
parser.add_argument('message_id', required=False, type=uuid_value, location='json')
|
||||
parser.add_argument('question', required=True, type=str, location='json')
|
||||
parser.add_argument('answer', required=True, type=str, location='json')
|
||||
parser.add_argument('annotation_reply', required=False, type=dict, location='json')
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
|
||||
|
||||
message_id = str(args['message_id'])
|
||||
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app.id
|
||||
).first()
|
||||
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
annotation = message.annotation
|
||||
|
||||
if annotation:
|
||||
annotation.content = args['content']
|
||||
else:
|
||||
annotation = MessageAnnotation(
|
||||
app_id=app.id,
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
content=args['content'],
|
||||
account_id=current_user.id
|
||||
)
|
||||
db.session.add(annotation)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return {'result': 'success'}
|
||||
return annotation
|
||||
|
||||
|
||||
class MessageAnnotationCountApi(Resource):
|
||||
|
@ -24,29 +24,29 @@ class ModelConfigResource(Resource):
|
||||
"""Modify app model config"""
|
||||
app_id = str(app_id)
|
||||
|
||||
app_model = _get_app(app_id)
|
||||
app = _get_app(app_id)
|
||||
|
||||
# validate config
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
account=current_user,
|
||||
config=request.json,
|
||||
mode=app_model.mode
|
||||
mode=app.mode
|
||||
)
|
||||
|
||||
new_app_model_config = AppModelConfig(
|
||||
app_id=app_model.id,
|
||||
app_id=app.id,
|
||||
)
|
||||
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
|
||||
|
||||
db.session.add(new_app_model_config)
|
||||
db.session.flush()
|
||||
|
||||
app_model.app_model_config_id = new_app_model_config.id
|
||||
app.app_model_config_id = new_app_model_config.id
|
||||
db.session.commit()
|
||||
|
||||
app_model_config_was_updated.send(
|
||||
app_model,
|
||||
app,
|
||||
app_model_config=new_app_model_config
|
||||
)
|
||||
|
||||
|
@ -30,6 +30,7 @@ class AppParameterApi(InstalledAppResource):
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
'retriever_resource': fields.Raw,
|
||||
'annotation_reply': fields.Raw,
|
||||
'more_like_this': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
'sensitive_word_avoidance': fields.Raw,
|
||||
@ -49,6 +50,7 @@ class AppParameterApi(InstalledAppResource):
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'retriever_resource': app_model_config.retriever_resource_dict,
|
||||
'annotation_reply': app_model_config.annotation_reply_dict,
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list,
|
||||
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
|
||||
|
@ -17,6 +17,7 @@ class UniversalChatParameterApi(UniversalChatResource):
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
'retriever_resource': fields.Raw,
|
||||
'annotation_reply': fields.Raw
|
||||
}
|
||||
|
||||
@marshal_with(parameters_fields)
|
||||
@ -32,6 +33,7 @@ class UniversalChatParameterApi(UniversalChatResource):
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'retriever_resource': app_model_config.retriever_resource_dict,
|
||||
'annotation_reply': app_model_config.annotation_reply_dict,
|
||||
}
|
||||
|
||||
|
||||
|
@ -47,6 +47,7 @@ def universal_chat_app_required(view=None):
|
||||
suggested_questions=json.dumps([]),
|
||||
suggested_questions_after_answer=json.dumps({'enabled': True}),
|
||||
speech_to_text=json.dumps({'enabled': True}),
|
||||
annotation_reply=json.dumps({'enabled': False}),
|
||||
retriever_resource=json.dumps({'enabled': True}),
|
||||
more_like_this=None,
|
||||
sensitive_word_avoidance=None,
|
||||
|
@ -55,6 +55,7 @@ def cloud_edition_billing_resource_check(resource: str,
|
||||
members = billing_info['members']
|
||||
apps = billing_info['apps']
|
||||
vector_space = billing_info['vector_space']
|
||||
annotation_quota_limit = billing_info['annotation_quota_limit']
|
||||
|
||||
if resource == 'members' and 0 < members['limit'] <= members['size']:
|
||||
abort(403, error_msg)
|
||||
@ -62,6 +63,8 @@ def cloud_edition_billing_resource_check(resource: str,
|
||||
abort(403, error_msg)
|
||||
elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']:
|
||||
abort(403, error_msg)
|
||||
elif resource == 'annotation' and 0 < annotation_quota_limit['limit'] <= annotation_quota_limit['size']:
|
||||
abort(403, error_msg)
|
||||
else:
|
||||
return view(*args, **kwargs)
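Editor's note: the quota comparison used for the new 'annotation' branch (and the existing ones) can be restated in isolation as below; this is an illustrative sketch, not code from the commit.

# A limit of 0 means "unlimited"; otherwise the resource is blocked once
# the current size reaches the limit.
def quota_exceeded(limit: int, size: int) -> bool:
    return 0 < limit <= size

assert quota_exceeded(limit=10, size=10) is True
assert quota_exceeded(limit=0, size=999) is False  # 0 = unlimited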
|
||||
|
||||
|
@ -31,6 +31,7 @@ class AppParameterApi(AppApiResource):
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
'retriever_resource': fields.Raw,
|
||||
'annotation_reply': fields.Raw,
|
||||
'more_like_this': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
'sensitive_word_avoidance': fields.Raw,
|
||||
@ -49,6 +50,7 @@ class AppParameterApi(AppApiResource):
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'retriever_resource': app_model_config.retriever_resource_dict,
|
||||
'annotation_reply': app_model_config.annotation_reply_dict,
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list,
|
||||
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
|
||||
|
@ -30,6 +30,7 @@ class AppParameterApi(WebApiResource):
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
'retriever_resource': fields.Raw,
|
||||
'annotation_reply': fields.Raw,
|
||||
'more_like_this': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
'sensitive_word_avoidance': fields.Raw,
|
||||
@ -48,6 +49,7 @@ class AppParameterApi(WebApiResource):
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'retriever_resource': app_model_config.retriever_resource_dict,
|
||||
'annotation_reply': app_model_config.annotation_reply_dict,
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list,
|
||||
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
|
||||
|
@ -12,8 +12,10 @@ from core.callback_handler.main_chain_gather_callback_handler import MainChainGa
|
||||
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
|
||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
|
||||
ConversationTaskInterruptException
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
from core.file.file_obj import FileObj
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
|
||||
ReadOnlyConversationTokenDBBufferSharedMemory
|
||||
@ -23,9 +25,12 @@ from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.orchestrator_rule_parser import OrchestratorRuleParser
|
||||
from core.prompt.prompt_template import PromptTemplateParser
|
||||
from core.prompt.prompt_transform import PromptTransform
|
||||
from models.dataset import Dataset
|
||||
from models.model import App, AppModelConfig, Account, Conversation, EndUser
|
||||
from core.moderation.base import ModerationException, ModerationAction
|
||||
from core.moderation.factory import ModerationFactory
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
||||
|
||||
class Completion:
|
||||
@ -33,7 +38,7 @@ class Completion:
|
||||
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
|
||||
files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
|
||||
streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
|
||||
auto_generate_name: bool = True):
|
||||
auto_generate_name: bool = True, from_source: str = 'console'):
|
||||
"""
|
||||
errors: ProviderTokenNotInitError
|
||||
"""
|
||||
@ -109,7 +114,10 @@ class Completion:
|
||||
fake_response=str(e)
|
||||
)
|
||||
return
|
||||
|
||||
# check annotation reply
|
||||
annotation_reply = cls.query_app_annotations_to_reply(conversation_message_task, from_source)
|
||||
if annotation_reply:
|
||||
return
|
||||
# fill in variable inputs from external data tools if exists
|
||||
external_data_tools = app_model_config.external_data_tools_list
|
||||
if external_data_tools:
|
||||
@ -166,17 +174,18 @@ class Completion:
|
||||
except ChunkedEncodingError as e:
|
||||
# Interrupt by LLM (like OpenAI), handle it.
|
||||
logging.warning(f'ChunkedEncodingError: {e}')
|
||||
conversation_message_task.end()
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str):
|
||||
def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict,
|
||||
query: str):
|
||||
if not app_model_config.sensitive_word_avoidance_dict['enabled']:
|
||||
return inputs, query
|
||||
|
||||
type = app_model_config.sensitive_word_avoidance_dict['type']
|
||||
|
||||
moderation = ModerationFactory(type, app_id, tenant_id, app_model_config.sensitive_word_avoidance_dict['config'])
|
||||
moderation = ModerationFactory(type, app_id, tenant_id,
|
||||
app_model_config.sensitive_word_avoidance_dict['config'])
|
||||
moderation_result = moderation.moderation_for_inputs(inputs, query)
|
||||
|
||||
if not moderation_result.flagged:
|
||||
@ -324,6 +333,76 @@ class Completion:
|
||||
external_context = memory.load_memory_variables({})
|
||||
return external_context[memory_key]
|
||||
|
||||
@classmethod
|
||||
def query_app_annotations_to_reply(cls, conversation_message_task: ConversationMessageTask,
|
||||
from_source: str) -> bool:
|
||||
"""Get memory messages."""
|
||||
app_model_config = conversation_message_task.app_model_config
|
||||
app = conversation_message_task.app
|
||||
annotation_reply = app_model_config.annotation_reply_dict
|
||||
if annotation_reply['enabled']:
|
||||
score_threshold = annotation_reply.get('score_threshold', 1)
|
||||
embedding_provider_name = annotation_reply['embedding_model']['embedding_provider_name']
|
||||
embedding_model_name = annotation_reply['embedding_model']['embedding_model_name']
|
||||
# get embedding model
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=app.tenant_id,
|
||||
model_provider_name=embedding_provider_name,
|
||||
model_name=embedding_model_name
|
||||
)
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_provider_name,
|
||||
embedding_model_name,
|
||||
'annotation'
|
||||
)
|
||||
|
||||
dataset = Dataset(
|
||||
id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
indexing_technique='high_quality',
|
||||
embedding_model_provider=embedding_provider_name,
|
||||
embedding_model=embedding_model_name,
|
||||
collection_binding_id=dataset_collection_binding.id
|
||||
)
|
||||
|
||||
vector_index = VectorIndex(
|
||||
dataset=dataset,
|
||||
config=current_app.config,
|
||||
embeddings=embeddings
|
||||
)
|
||||
|
||||
documents = vector_index.search(
|
||||
conversation_message_task.query,
|
||||
search_type='similarity_score_threshold',
|
||||
search_kwargs={
|
||||
'k': 1,
|
||||
'score_threshold': score_threshold,
|
||||
'filter': {
|
||||
'group_id': [dataset.id]
|
||||
}
|
||||
}
|
||||
)
|
||||
if documents:
|
||||
annotation_id = documents[0].metadata['annotation_id']
|
||||
score = documents[0].metadata['score']
|
||||
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
|
||||
if annotation:
|
||||
conversation_message_task.annotation_end(annotation.content, annotation.id, annotation.account.name)
|
||||
# insert annotation history
|
||||
AppAnnotationService.add_annotation_history(annotation.id,
|
||||
app.id,
|
||||
annotation.question,
|
||||
annotation.content,
|
||||
conversation_message_task.query,
|
||||
conversation_message_task.user.id,
|
||||
conversation_message_task.message.id,
|
||||
from_source,
|
||||
score)
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
|
||||
conversation: Conversation,
|
||||
|
@ -319,6 +319,10 @@ class ConversationMessageTask:
|
||||
self._pub_handler.pub_message_end(self.retriever_resource)
|
||||
self._pub_handler.pub_end()
|
||||
|
||||
def annotation_end(self, text: str, annotation_id: str, annotation_author_name: str):
|
||||
self._pub_handler.pub_annotation(text, annotation_id, annotation_author_name, self.start_at)
|
||||
self._pub_handler.pub_end()
|
||||
|
||||
|
||||
class PubHandler:
|
||||
def __init__(self, user: Union[Account, EndUser], task_id: str,
|
||||
@ -435,7 +439,7 @@ class PubHandler:
|
||||
'task_id': self._task_id,
|
||||
'message_id': self._message.id,
|
||||
'mode': self._conversation.mode,
|
||||
'conversation_id': self._conversation.id
|
||||
'conversation_id': self._conversation.id,
|
||||
}
|
||||
}
|
||||
if retriever_resource:
|
||||
@ -446,6 +450,30 @@ class PubHandler:
|
||||
self.pub_end()
|
||||
raise ConversationTaskStoppedException()
|
||||
|
||||
def pub_annotation(self, text: str, annotation_id: str, annotation_author_name: str, start_at: float):
|
||||
content = {
|
||||
'event': 'annotation',
|
||||
'data': {
|
||||
'task_id': self._task_id,
|
||||
'message_id': self._message.id,
|
||||
'mode': self._conversation.mode,
|
||||
'conversation_id': self._conversation.id,
|
||||
'text': text,
|
||||
'annotation_id': annotation_id,
|
||||
'annotation_author_name': annotation_author_name
|
||||
}
|
||||
}
|
||||
self._message.answer = text
|
||||
self._message.provider_response_latency = time.perf_counter() - start_at
|
||||
|
||||
db.session.commit()
|
||||
|
||||
redis_client.publish(self._channel, json.dumps(content))
|
||||
|
||||
if self._is_stopped():
|
||||
self.pub_end()
|
||||
raise ConversationTaskStoppedException()
|
||||
|
||||
def pub_end(self):
|
||||
content = {
|
||||
'event': 'end',
|
||||
|
@ -32,6 +32,10 @@ class BaseIndex(ABC):
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_group_id(self, group_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
@ -107,6 +107,9 @@ class KeywordTableIndex(BaseIndex):
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
pass
|
||||
|
||||
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
|
||||
return KeywordTableRetriever(index=self, **kwargs)
|
||||
|
||||
|
@ -121,6 +121,16 @@ class MilvusVectorIndex(BaseVectorIndex):
|
||||
'filter': f'id in {ids}'
|
||||
})
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
ids = vector_store.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
vector_store.del_texts({
|
||||
'filter': f'id in {ids}'
|
||||
})
|
||||
|
||||
def delete_by_ids(self, doc_ids: list[str]) -> None:
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
|
@ -138,6 +138,22 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
],
|
||||
))
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
from qdrant_client.http import models
|
||||
|
||||
vector_store.del_texts(models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}",
|
||||
match=models.MatchValue(value=value),
|
||||
),
|
||||
],
|
||||
))
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
|
@ -141,6 +141,17 @@ class WeaviateVectorIndex(BaseVectorIndex):
|
||||
"valueText": document_id
|
||||
})
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
vector_store.del_texts({
|
||||
"operator": "Equal",
|
||||
"path": [key],
|
||||
"valueText": value
|
||||
})
|
||||
|
||||
def delete_by_group_id(self, group_id: str):
|
||||
if self._is_origin():
|
||||
self.recreate_dataset(self.dataset)
|
||||
|
@ -30,6 +30,16 @@ class MilvusVectorStore(Milvus):
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
result = self.col.query(
|
||||
expr=f'metadata["{key}"] == "{value}"',
|
||||
output_fields=["id"]
|
||||
)
|
||||
if result:
|
||||
return [item["id"] for item in result]
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_ids_by_doc_ids(self, doc_ids: list):
|
||||
result = self.col.query(
|
||||
expr=f'metadata["doc_id"] in {doc_ids}',
|
||||
|
@ -6,13 +6,13 @@ from models.model import AppModelConfig
|
||||
|
||||
@app_model_config_was_updated.connect
|
||||
def handle(sender, **kwargs):
|
||||
app_model = sender
|
||||
app = sender
|
||||
app_model_config = kwargs.get('app_model_config')
|
||||
|
||||
dataset_ids = get_dataset_ids_from_model_config(app_model_config)
|
||||
|
||||
app_dataset_joins = db.session.query(AppDatasetJoin).filter(
|
||||
AppDatasetJoin.app_id == app_model.id
|
||||
AppDatasetJoin.app_id == app.id
|
||||
).all()
|
||||
|
||||
removed_dataset_ids = []
|
||||
@ -29,14 +29,14 @@ def handle(sender, **kwargs):
|
||||
if removed_dataset_ids:
|
||||
for dataset_id in removed_dataset_ids:
|
||||
db.session.query(AppDatasetJoin).filter(
|
||||
AppDatasetJoin.app_id == app_model.id,
|
||||
AppDatasetJoin.app_id == app.id,
|
||||
AppDatasetJoin.dataset_id == dataset_id
|
||||
).delete()
|
||||
|
||||
if added_dataset_ids:
|
||||
for dataset_id in added_dataset_ids:
|
||||
app_dataset_join = AppDatasetJoin(
|
||||
app_id=app_model.id,
|
||||
app_id=app.id,
|
||||
dataset_id=dataset_id
|
||||
)
|
||||
db.session.add(app_dataset_join)
|
||||
|
36  api/fields/annotation_fields.py  Normal file
@ -0,0 +1,36 @@
|
||||
from flask_restful import fields
|
||||
from libs.helper import TimestampField
|
||||
|
||||
account_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'email': fields.String
|
||||
}
|
||||
|
||||
|
||||
annotation_fields = {
|
||||
"id": fields.String,
|
||||
"question": fields.String,
|
||||
"answer": fields.Raw(attribute='content'),
|
||||
"hit_count": fields.Integer,
|
||||
"created_at": TimestampField,
|
||||
# 'account': fields.Nested(account_fields, allow_null=True)
|
||||
}
|
||||
|
||||
annotation_list_fields = {
|
||||
"data": fields.List(fields.Nested(annotation_fields)),
|
||||
}
|
||||
|
||||
annotation_hit_history_fields = {
|
||||
"id": fields.String,
|
||||
"source": fields.String,
|
||||
"score": fields.Float,
|
||||
"question": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"match": fields.String(attribute='annotation_question'),
|
||||
"response": fields.String(attribute='annotation_content')
|
||||
}
|
||||
|
||||
annotation_hit_history_list_fields = {
|
||||
"data": fields.List(fields.Nested(annotation_hit_history_fields)),
|
||||
}
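Editor's note: as a quick illustration of how these field maps are consumed (compare the marshal calls in annotation.py above); the sample record is invented, and the assumption that the project's TimestampField serializes datetimes to Unix timestamps is illustrative only.

# Sketch only: serialize a made-up annotation record with the map above.
# Note that 'answer' is populated from the 'content' attribute.
import datetime
from flask_restful import marshal

sample = {
    'id': 'example-uuid',
    'question': 'What is annotation reply?',
    'content': 'It returns a curated answer when a similar question recurs.',
    'hit_count': 2,
    'created_at': datetime.datetime(2023, 12, 14, 11, 0, 0),
}
print(marshal(sample, annotation_fields))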
|
@ -21,6 +21,7 @@ model_config_fields = {
|
||||
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
|
||||
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
|
||||
'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
|
||||
'annotation_reply': fields.Raw(attribute='annotation_reply_dict'),
|
||||
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
|
||||
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
|
||||
'external_data_tools': fields.Raw(attribute='external_data_tools_list'),
|
||||
|
@ -23,11 +23,18 @@ feedback_fields = {
|
||||
}
|
||||
|
||||
annotation_fields = {
|
||||
'id': fields.String,
|
||||
'question': fields.String,
|
||||
'content': fields.String,
|
||||
'account': fields.Nested(account_fields, allow_null=True),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
annotation_hit_history_fields = {
|
||||
'annotation_id': fields.String,
|
||||
'annotation_create_account': fields.Nested(account_fields, allow_null=True)
|
||||
}
|
||||
|
||||
message_file_fields = {
|
||||
'id': fields.String,
|
||||
'type': fields.String,
|
||||
@ -49,6 +56,7 @@ message_detail_fields = {
|
||||
'from_account_id': fields.String,
|
||||
'feedbacks': fields.List(fields.Nested(feedback_fields)),
|
||||
'annotation': fields.Nested(annotation_fields, allow_null=True),
|
||||
'annotation_hit_history': fields.Nested(annotation_hit_history_fields, allow_null=True),
|
||||
'created_at': TimestampField,
|
||||
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
|
||||
}
|
||||
|
@ -0,0 +1,50 @@
|
||||
"""add_app_anntation_setting
|
||||
|
||||
Revision ID: 246ba09cbbdb
|
||||
Revises: 714aafe25d39
|
||||
Create Date: 2023-12-14 11:26:12.287264
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '246ba09cbbdb'
|
||||
down_revision = '714aafe25d39'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('app_annotation_settings',
|
||||
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('app_id', postgresql.UUID(), nullable=False),
|
||||
sa.Column('score_threshold', sa.Float(), server_default=sa.text('0'), nullable=False),
|
||||
sa.Column('collection_binding_id', postgresql.UUID(), nullable=False),
|
||||
sa.Column('created_user_id', postgresql.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||
sa.Column('updated_user_id', postgresql.UUID(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='app_annotation_settings_pkey')
|
||||
)
|
||||
with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op:
|
||||
batch_op.create_index('app_annotation_settings_app_idx', ['app_id'], unique=False)
|
||||
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.drop_column('annotation_reply')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('annotation_reply', sa.TEXT(), autoincrement=False, nullable=True))
|
||||
|
||||
with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op:
|
||||
batch_op.drop_index('app_annotation_settings_app_idx')
|
||||
|
||||
op.drop_table('app_annotation_settings')
|
||||
# ### end Alembic commands ###
|
@ -0,0 +1,32 @@
|
||||
"""add-annotation-histoiry-score
|
||||
|
||||
Revision ID: 46976cc39132
|
||||
Revises: e1901f623fd0
|
||||
Create Date: 2023-12-13 04:39:59.302971
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '46976cc39132'
|
||||
down_revision = 'e1901f623fd0'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('score', sa.Float(), server_default=sa.text('0'), nullable=False))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
|
||||
batch_op.drop_column('score')
|
||||
|
||||
# ### end Alembic commands ###
|
@ -0,0 +1,34 @@
|
||||
"""add_anntation_history_match_response
|
||||
|
||||
Revision ID: 714aafe25d39
|
||||
Revises: f2a6fc85e260
|
||||
Create Date: 2023-12-14 06:38:02.972527
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '714aafe25d39'
|
||||
down_revision = 'f2a6fc85e260'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('annotation_question', sa.Text(), nullable=False))
|
||||
batch_op.add_column(sa.Column('annotation_content', sa.Text(), nullable=False))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
|
||||
batch_op.drop_column('annotation_content')
|
||||
batch_op.drop_column('annotation_question')
|
||||
|
||||
# ### end Alembic commands ###
|
79  api/migrations/versions/e1901f623fd0_add_annotation_reply.py  Normal file
@ -0,0 +1,79 @@
|
||||
"""add-annotation-reply
|
||||
|
||||
Revision ID: e1901f623fd0
|
||||
Revises: fca025d3b60f
|
||||
Create Date: 2023-12-12 06:58:41.054544
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'e1901f623fd0'
|
||||
down_revision = 'fca025d3b60f'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('app_annotation_hit_histories',
|
||||
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('app_id', postgresql.UUID(), nullable=False),
|
||||
sa.Column('annotation_id', postgresql.UUID(), nullable=False),
|
||||
sa.Column('source', sa.Text(), nullable=False),
|
||||
sa.Column('question', sa.Text(), nullable=False),
|
||||
sa.Column('account_id', postgresql.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='app_annotation_hit_histories_pkey')
|
||||
)
|
||||
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
|
||||
batch_op.create_index('app_annotation_hit_histories_account_idx', ['account_id'], unique=False)
|
||||
batch_op.create_index('app_annotation_hit_histories_annotation_idx', ['annotation_id'], unique=False)
|
||||
batch_op.create_index('app_annotation_hit_histories_app_idx', ['app_id'], unique=False)
|
||||
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('annotation_reply', sa.Text(), nullable=True))
|
||||
|
||||
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('type', sa.String(length=40), server_default=sa.text("'dataset'::character varying"), nullable=False))
|
||||
|
||||
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('question', sa.Text(), nullable=True))
|
||||
batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False))
|
||||
batch_op.alter_column('conversation_id',
|
||||
existing_type=postgresql.UUID(),
|
||||
nullable=True)
|
||||
batch_op.alter_column('message_id',
|
||||
existing_type=postgresql.UUID(),
|
||||
nullable=True)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
|
||||
batch_op.alter_column('message_id',
|
||||
existing_type=postgresql.UUID(),
|
||||
nullable=False)
|
||||
batch_op.alter_column('conversation_id',
|
||||
existing_type=postgresql.UUID(),
|
||||
nullable=False)
|
||||
batch_op.drop_column('hit_count')
|
||||
batch_op.drop_column('question')
|
||||
|
||||
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
|
||||
batch_op.drop_column('type')
|
||||
|
||||
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||
batch_op.drop_column('annotation_reply')
|
||||
|
||||
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
|
||||
batch_op.drop_index('app_annotation_hit_histories_app_idx')
|
||||
batch_op.drop_index('app_annotation_hit_histories_annotation_idx')
|
||||
batch_op.drop_index('app_annotation_hit_histories_account_idx')
|
||||
|
||||
op.drop_table('app_annotation_hit_histories')
|
||||
# ### end Alembic commands ###
|
@ -0,0 +1,34 @@
|
||||
"""add_anntation_history_message_id
|
||||
|
||||
Revision ID: f2a6fc85e260
|
||||
Revises: 46976cc39132
|
||||
Create Date: 2023-12-13 11:09:29.329584
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'f2a6fc85e260'
|
||||
down_revision = '46976cc39132'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('message_id', postgresql.UUID(), nullable=False))
|
||||
batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
|
||||
batch_op.drop_index('app_annotation_hit_histories_message_idx')
|
||||
batch_op.drop_column('message_id')
|
||||
|
||||
# ### end Alembic commands ###
|
@ -475,5 +475,6 @@ class DatasetCollectionBinding(db.Model):
|
||||
id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
|
||||
provider_name = db.Column(db.String(40), nullable=False)
|
||||
model_name = db.Column(db.String(40), nullable=False)
|
||||
type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
|
||||
collection_name = db.Column(db.String(64), nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
|
@ -2,6 +2,7 @@ import json
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import UserMixin
|
||||
from sqlalchemy import Float
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from core.file.upload_file_parser import UploadFileParser
|
||||
@ -128,6 +129,25 @@ class AppModelConfig(db.Model):
|
||||
return json.loads(self.retriever_resource) if self.retriever_resource \
|
||||
else {"enabled": False}
|
||||
|
||||
@property
|
||||
def annotation_reply_dict(self) -> dict:
|
||||
annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == self.app_id).first()
|
||||
if annotation_setting:
|
||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||
return {
|
||||
"id": annotation_setting.id,
|
||||
"enabled": True,
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {
|
||||
"embedding_provider_name": collection_binding_detail.provider_name,
|
||||
"embedding_model_name": collection_binding_detail.model_name
|
||||
}
|
||||
}
|
||||
|
||||
else:
|
||||
return {"enabled": False}
|
||||
|
||||
@property
|
||||
def more_like_this_dict(self) -> dict:
|
||||
return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}
|
||||
@ -170,7 +190,9 @@ class AppModelConfig(db.Model):
|
||||
|
||||
@property
|
||||
def file_upload_dict(self) -> dict:
|
||||
return json.loads(self.file_upload) if self.file_upload else {"image": {"enabled": False, "number_limits": 3, "detail": "high", "transfer_methods": ["remote_url", "local_file"]}}
|
||||
return json.loads(self.file_upload) if self.file_upload else {
|
||||
"image": {"enabled": False, "number_limits": 3, "detail": "high",
|
||||
"transfer_methods": ["remote_url", "local_file"]}}
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
@ -182,6 +204,7 @@ class AppModelConfig(db.Model):
|
||||
"suggested_questions_after_answer": self.suggested_questions_after_answer_dict,
|
||||
"speech_to_text": self.speech_to_text_dict,
|
||||
"retriever_resource": self.retriever_resource_dict,
|
||||
"annotation_reply": self.annotation_reply_dict,
|
||||
"more_like_this": self.more_like_this_dict,
|
||||
"sensitive_word_avoidance": self.sensitive_word_avoidance_dict,
|
||||
"external_data_tools": self.external_data_tools_list,
|
||||
@ -504,6 +527,12 @@ class Message(db.Model):
|
||||
annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == self.id).first()
|
||||
return annotation
|
||||
|
||||
@property
|
||||
def annotation_hit_history(self):
|
||||
annotation_history = (db.session.query(AppAnnotationHitHistory)
|
||||
.filter(AppAnnotationHitHistory.message_id == self.id).first())
|
||||
return annotation_history
|
||||
|
||||
@property
|
||||
def app_model_config(self):
|
||||
conversation = db.session.query(Conversation).filter(Conversation.id == self.conversation_id).first()
|
||||
@ -616,9 +645,11 @@ class MessageAnnotation(db.Model):
|
||||
|
||||
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
|
||||
app_id = db.Column(UUID, nullable=False)
|
||||
conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False)
|
||||
message_id = db.Column(UUID, nullable=False)
|
||||
conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=True)
|
||||
message_id = db.Column(UUID, nullable=True)
|
||||
question = db.Column(db.Text, nullable=True)
|
||||
content = db.Column(db.Text, nullable=False)
|
||||
hit_count = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
|
||||
account_id = db.Column(UUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
@ -629,6 +660,79 @@ class MessageAnnotation(db.Model):
|
||||
return account
|
||||
|
||||
|
||||
class AppAnnotationHitHistory(db.Model):
|
||||
__tablename__ = 'app_annotation_hit_histories'
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint('id', name='app_annotation_hit_histories_pkey'),
|
||||
db.Index('app_annotation_hit_histories_app_idx', 'app_id'),
|
||||
db.Index('app_annotation_hit_histories_account_idx', 'account_id'),
|
||||
db.Index('app_annotation_hit_histories_annotation_idx', 'annotation_id'),
|
||||
db.Index('app_annotation_hit_histories_message_idx', 'message_id'),
|
||||
)
|
||||
|
||||
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
|
||||
app_id = db.Column(UUID, nullable=False)
|
||||
annotation_id = db.Column(UUID, nullable=False)
|
||||
source = db.Column(db.Text, nullable=False)
|
||||
question = db.Column(db.Text, nullable=False)
|
||||
account_id = db.Column(UUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
score = db.Column(Float, nullable=False, server_default=db.text('0'))
|
||||
message_id = db.Column(UUID, nullable=False)
|
||||
annotation_question = db.Column(db.Text, nullable=False)
|
||||
annotation_content = db.Column(db.Text, nullable=False)
|
||||
|
||||
@property
|
||||
def account(self):
|
||||
account = (db.session.query(Account)
|
||||
.join(MessageAnnotation, MessageAnnotation.account_id == Account.id)
|
||||
.filter(MessageAnnotation.id == self.annotation_id).first())
|
||||
return account
|
||||
|
||||
@property
|
||||
def annotation_create_account(self):
|
||||
account = db.session.query(Account).filter(Account.id == self.account_id).first()
|
||||
return account
|
||||
|
||||
|
||||
class AppAnnotationSetting(db.Model):
|
||||
__tablename__ = 'app_annotation_settings'
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint('id', name='app_annotation_settings_pkey'),
|
||||
db.Index('app_annotation_settings_app_idx', 'app_id')
|
||||
)
|
||||
|
||||
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
|
||||
app_id = db.Column(UUID, nullable=False)
|
||||
score_threshold = db.Column(Float, nullable=False, server_default=db.text('0'))
|
||||
collection_binding_id = db.Column(UUID, nullable=False)
|
||||
created_user_id = db.Column(UUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_user_id = db.Column(UUID, nullable=False)
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
|
||||
@property
|
||||
def created_account(self):
|
||||
account = (db.session.query(Account)
|
||||
.join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id)
|
||||
.filter(AppAnnotationSetting.id == self.annotation_id).first())
|
||||
return account
|
||||
|
||||
@property
|
||||
def updated_account(self):
|
||||
account = (db.session.query(Account)
|
||||
.join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id)
|
||||
.filter(AppAnnotationSetting.id == self.annotation_id).first())
|
||||
return account
|
||||
|
||||
@property
|
||||
def collection_binding_detail(self):
|
||||
from .dataset import DatasetCollectionBinding
|
||||
collection_binding_detail = (db.session.query(DatasetCollectionBinding)
|
||||
.filter(DatasetCollectionBinding.id == self.collection_binding_id).first())
|
||||
return collection_binding_detail
|
||||
|
||||
|
||||
class OperationLog(db.Model):
|
||||
__tablename__ = 'operation_logs'
|
||||
__table_args__ = (
|
||||
|
426
api/services/annotation_service.py
Normal file
426
api/services/annotation_service.py
Normal file
@ -0,0 +1,426 @@
|
||||
import datetime
|
||||
import json
|
||||
import uuid
|
||||
|
||||
import pandas as pd
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import or_
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.model import MessageAnnotation, Message, App, AppAnnotationHitHistory, AppAnnotationSetting
|
||||
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
|
||||
from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task
|
||||
from tasks.annotation.disable_annotation_reply_task import disable_annotation_reply_task
|
||||
from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task
|
||||
from tasks.annotation.delete_annotation_index_task import delete_annotation_index_task
|
||||
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
|
||||
|
||||
|
||||
class AppAnnotationService:
|
||||
@classmethod
|
||||
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
if 'message_id' in args and args['message_id']:
|
||||
message_id = str(args['message_id'])
|
||||
# get message info
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app.id
|
||||
).first()
|
||||
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
annotation = message.annotation
|
||||
# save the message annotation
|
||||
if annotation:
|
||||
annotation.content = args['answer']
|
||||
annotation.question = args['question']
|
||||
else:
|
||||
annotation = MessageAnnotation(
|
||||
app_id=app.id,
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
content=args['answer'],
|
||||
question=args['question'],
|
||||
account_id=current_user.id
|
||||
)
|
||||
else:
|
||||
annotation = MessageAnnotation(
|
||||
app_id=app.id,
|
||||
content=args['answer'],
|
||||
question=args['question'],
|
||||
account_id=current_user.id
|
||||
)
|
||||
db.session.add(annotation)
|
||||
db.session.commit()
|
||||
# if annotation reply is enabled , add annotation to index
|
||||
annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id).first()
|
||||
if annotation_setting:
|
||||
add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id,
|
||||
app_id, annotation_setting.collection_binding_id)
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
|
||||
enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id))
|
||||
cache_result = redis_client.get(enable_app_annotation_key)
|
||||
if cache_result is not None:
|
||||
return {
|
||||
'job_id': cache_result,
|
||||
'job_status': 'processing'
|
||||
}
|
||||
|
||||
# async job
|
||||
job_id = str(uuid.uuid4())
|
||||
enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id))
|
||||
# send batch add segments task
|
||||
redis_client.setnx(enable_app_annotation_job_key, 'waiting')
|
||||
enable_annotation_reply_task.delay(str(job_id), app_id, current_user.id, current_user.current_tenant_id,
|
||||
args['score_threshold'],
|
||||
args['embedding_provider_name'], args['embedding_model_name'])
|
||||
return {
|
||||
'job_id': job_id,
|
||||
'job_status': 'waiting'
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def disable_app_annotation(cls, app_id: str) -> dict:
|
||||
disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id))
|
||||
cache_result = redis_client.get(disable_app_annotation_key)
|
||||
if cache_result is not None:
|
||||
return {
|
||||
'job_id': cache_result,
|
||||
'job_status': 'processing'
|
||||
}
|
||||
|
||||
# async job
|
||||
job_id = str(uuid.uuid4())
|
||||
disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id))
|
||||
# send batch add segments task
|
||||
redis_client.setnx(disable_app_annotation_job_key, 'waiting')
|
||||
disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id)
|
||||
return {
|
||||
'job_id': job_id,
|
||||
'job_status': 'waiting'
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
if keyword:
|
||||
annotations = (db.session.query(MessageAnnotation)
|
||||
.filter(MessageAnnotation.app_id == app_id)
|
||||
.filter(
|
||||
or_(
|
||||
MessageAnnotation.question.ilike('%{}%'.format(keyword)),
|
||||
MessageAnnotation.content.ilike('%{}%'.format(keyword))
|
||||
)
|
||||
)
|
||||
.order_by(MessageAnnotation.created_at.desc())
|
||||
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
|
||||
else:
|
||||
annotations = (db.session.query(MessageAnnotation)
|
||||
.filter(MessageAnnotation.app_id == app_id)
|
||||
.order_by(MessageAnnotation.created_at.desc())
|
||||
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
|
||||
return annotations.items, annotations.total
|
||||
|
||||
@classmethod
|
||||
def export_annotation_list_by_app_id(cls, app_id: str):
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
annotations = (db.session.query(MessageAnnotation)
|
||||
.filter(MessageAnnotation.app_id == app_id)
|
||||
.order_by(MessageAnnotation.created_at.desc()).all())
|
||||
return annotations
|
||||
|
||||
@classmethod
|
||||
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
annotation = MessageAnnotation(
|
||||
app_id=app.id,
|
||||
content=args['answer'],
|
||||
question=args['question'],
|
||||
account_id=current_user.id
|
||||
)
|
||||
db.session.add(annotation)
|
||||
db.session.commit()
|
||||
# if annotation reply is enabled , add annotation to index
|
||||
annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id).first()
|
||||
if annotation_setting:
|
||||
add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id,
|
||||
app_id, annotation_setting.collection_binding_id)
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
|
||||
|
||||
if not annotation:
|
||||
raise NotFound("Annotation not found")
|
||||
|
||||
annotation.content = args['answer']
|
||||
annotation.question = args['question']
|
||||
|
||||
db.session.commit()
|
||||
# if annotation reply is enabled , add annotation to index
|
||||
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id
|
||||
).first()
|
||||
|
||||
if app_annotation_setting:
|
||||
update_annotation_to_index_task.delay(annotation.id, annotation.question,
|
||||
current_user.current_tenant_id,
|
||||
app_id, app_annotation_setting.collection_binding_id)
|
||||
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def delete_app_annotation(cls, app_id: str, annotation_id: str):
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
|
||||
|
||||
if not annotation:
|
||||
raise NotFound("Annotation not found")
|
||||
|
||||
db.session.delete(annotation)
|
||||
|
||||
annotation_hit_histories = (db.session.query(AppAnnotationHitHistory)
|
||||
.filter(AppAnnotationHitHistory.annotation_id == annotation_id)
|
||||
.all()
|
||||
)
|
||||
if annotation_hit_histories:
|
||||
for annotation_hit_history in annotation_hit_histories:
|
||||
db.session.delete(annotation_hit_history)
|
||||
|
||||
db.session.commit()
|
||||
# if annotation reply is enabled , delete annotation index
|
||||
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id
|
||||
).first()
|
||||
|
||||
if app_annotation_setting:
|
||||
delete_annotation_index_task.delay(annotation.id, app_id,
|
||||
current_user.current_tenant_id,
|
||||
app_annotation_setting.collection_binding_id)
|
||||
|
||||
@classmethod
|
||||
def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict:
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
try:
|
||||
# Skip the first row
|
||||
df = pd.read_csv(file)
|
||||
result = []
|
||||
for index, row in df.iterrows():
|
||||
content = {
|
||||
'question': row[0],
|
||||
'answer': row[1]
|
||||
}
|
||||
result.append(content)
|
||||
if len(result) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
# async job
|
||||
job_id = str(uuid.uuid4())
|
||||
indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
|
||||
# send batch add segments task
|
||||
redis_client.setnx(indexing_cache_key, 'waiting')
|
||||
batch_import_annotations_task.delay(str(job_id), result, app_id,
|
||||
current_user.current_tenant_id, current_user.id)
|
||||
except Exception as e:
|
||||
return {
|
||||
'error_msg': str(e)
|
||||
}
|
||||
return {
|
||||
'job_id': job_id,
|
||||
'job_status': 'waiting'
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
|
||||
|
||||
if not annotation:
|
||||
raise NotFound("Annotation not found")
|
||||
|
||||
annotation_hit_histories = (db.session.query(AppAnnotationHitHistory)
|
||||
.filter(AppAnnotationHitHistory.app_id == app_id,
|
||||
AppAnnotationHitHistory.annotation_id == annotation_id,
|
||||
)
|
||||
.order_by(AppAnnotationHitHistory.created_at.desc())
|
||||
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
|
||||
return annotation_hit_histories.items, annotation_hit_histories.total
|
||||
|
||||
@classmethod
|
||||
def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None:
|
||||
annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
|
||||
|
||||
if not annotation:
|
||||
return None
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def add_annotation_history(cls, annotation_id: str, app_id: str, annotation_question: str,
|
||||
annotation_content: str, query: str, user_id: str,
|
||||
message_id: str, from_source: str, score: float):
|
||||
# add hit count to annotation
|
||||
db.session.query(MessageAnnotation).filter(
|
||||
MessageAnnotation.id == annotation_id
|
||||
).update(
|
||||
{MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1},
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
annotation_hit_history = AppAnnotationHitHistory(
|
||||
annotation_id=annotation_id,
|
||||
app_id=app_id,
|
||||
account_id=user_id,
|
||||
question=query,
|
||||
source=from_source,
|
||||
score=score,
|
||||
message_id=message_id,
|
||||
annotation_question=annotation_question,
|
||||
annotation_content=annotation_content
|
||||
)
|
||||
db.session.add(annotation_hit_history)
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_app_annotation_setting_by_app_id(cls, app_id: str):
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id).first()
|
||||
if annotation_setting:
|
||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||
return {
|
||||
"id": annotation_setting.id,
|
||||
"enabled": True,
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {
|
||||
"embedding_provider_name": collection_binding_detail.provider_name,
|
||||
"embedding_model_name": collection_binding_detail.model_name
|
||||
}
|
||||
}
|
||||
return {
|
||||
"enabled": False
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id,
|
||||
AppAnnotationSetting.id == annotation_setting_id,
|
||||
).first()
|
||||
if not annotation_setting:
|
||||
raise NotFound("App annotation not found")
|
||||
annotation_setting.score_threshold = args['score_threshold']
|
||||
annotation_setting.updated_user_id = current_user.id
|
||||
annotation_setting.updated_at = datetime.datetime.utcnow()
|
||||
db.session.add(annotation_setting)
|
||||
db.session.commit()
|
||||
|
||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||
|
||||
return {
|
||||
"id": annotation_setting.id,
|
||||
"enabled": True,
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {
|
||||
"embedding_provider_name": collection_binding_detail.provider_name,
|
||||
"embedding_model_name": collection_binding_detail.model_name
|
||||
}
|
||||
}
|
@ -138,7 +138,22 @@ class AppModelConfigService:
|
||||
config["retriever_resource"]["enabled"] = False
|
||||
|
||||
if not isinstance(config["retriever_resource"]["enabled"], bool):
|
||||
raise ValueError("enabled in speech_to_text must be of boolean type")
|
||||
raise ValueError("enabled in retriever_resource must be of boolean type")
|
||||
|
||||
# annotation reply
|
||||
if 'annotation_reply' not in config or not config["annotation_reply"]:
|
||||
config["annotation_reply"] = {
|
||||
"enabled": False
|
||||
}
|
||||
|
||||
if not isinstance(config["annotation_reply"], dict):
|
||||
raise ValueError("annotation_reply must be of dict type")
|
||||
|
||||
if "enabled" not in config["annotation_reply"] or not config["annotation_reply"]["enabled"]:
|
||||
config["annotation_reply"]["enabled"] = False
|
||||
|
||||
if not isinstance(config["annotation_reply"]["enabled"], bool):
|
||||
raise ValueError("enabled in annotation_reply must be of boolean type")
|
||||
|
||||
# more_like_this
|
||||
if 'more_like_this' not in config or not config["more_like_this"]:
|
||||
@ -325,6 +340,7 @@ class AppModelConfigService:
|
||||
"suggested_questions_after_answer": config["suggested_questions_after_answer"],
|
||||
"speech_to_text": config["speech_to_text"],
|
||||
"retriever_resource": config["retriever_resource"],
|
||||
"annotation_reply": config["annotation_reply"],
|
||||
"more_like_this": config["more_like_this"],
|
||||
"sensitive_word_avoidance": config["sensitive_word_avoidance"],
|
||||
"external_data_tools": config["external_data_tools"],
|
||||
|
@ -165,7 +165,8 @@ class CompletionService:
|
||||
'streaming': streaming,
|
||||
'is_model_config_override': is_model_config_override,
|
||||
'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev',
|
||||
'auto_generate_name': auto_generate_name
|
||||
'auto_generate_name': auto_generate_name,
|
||||
'from_source': from_source
|
||||
})
|
||||
|
||||
generate_worker_thread.start()
|
||||
@ -193,7 +194,7 @@ class CompletionService:
|
||||
query: str, inputs: dict, files: List[PromptMessageFile],
|
||||
detached_user: Union[Account, EndUser],
|
||||
detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
|
||||
retriever_from: str = 'dev', auto_generate_name: bool = True):
|
||||
retriever_from: str = 'dev', auto_generate_name: bool = True, from_source: str = 'console'):
|
||||
with flask_app.app_context():
|
||||
# fixed the state of the model object when it detached from the original session
|
||||
user = db.session.merge(detached_user)
|
||||
@ -218,7 +219,8 @@ class CompletionService:
|
||||
streaming=streaming,
|
||||
is_override=is_model_config_override,
|
||||
retriever_from=retriever_from,
|
||||
auto_generate_name=auto_generate_name
|
||||
auto_generate_name=auto_generate_name,
|
||||
from_source=from_source
|
||||
)
|
||||
except (ConversationTaskInterruptException, ConversationTaskStoppedException):
|
||||
pass
|
||||
@ -385,6 +387,9 @@ class CompletionService:
|
||||
result = json.loads(result)
|
||||
if result.get('error'):
|
||||
cls.handle_error(result)
|
||||
if result['event'] == 'annotation' and 'data' in result:
|
||||
message_result['annotation'] = result.get('data')
|
||||
return cls.get_blocking_annotation_message_response_data(message_result)
|
||||
if result['event'] == 'message' and 'data' in result:
|
||||
message_result['message'] = result.get('data')
|
||||
if result['event'] == 'message_end' and 'data' in result:
|
||||
@ -427,6 +432,9 @@ class CompletionService:
|
||||
elif event == 'agent_thought':
|
||||
yield "data: " + json.dumps(
|
||||
cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
|
||||
elif event == 'annotation':
|
||||
yield "data: " + json.dumps(
|
||||
cls.get_annotation_response_data(result.get('data'))) + "\n\n"
|
||||
elif event == 'message_end':
|
||||
yield "data: " + json.dumps(
|
||||
cls.get_message_end_data(result.get('data'))) + "\n\n"
|
||||
@ -499,6 +507,25 @@ class CompletionService:
|
||||
|
||||
return response_data
|
||||
|
||||
@classmethod
|
||||
def get_blocking_annotation_message_response_data(cls, data: dict):
|
||||
message = data.get('annotation')
|
||||
response_data = {
|
||||
'event': 'annotation',
|
||||
'task_id': message.get('task_id'),
|
||||
'id': message.get('message_id'),
|
||||
'answer': message.get('text'),
|
||||
'metadata': {},
|
||||
'created_at': int(time.time()),
|
||||
'annotation_id': message.get('annotation_id'),
|
||||
'annotation_author_name': message.get('annotation_author_name')
|
||||
}
|
||||
|
||||
if message.get('mode') == 'chat':
|
||||
response_data['conversation_id'] = message.get('conversation_id')
|
||||
|
||||
return response_data
|
||||
|
||||
@classmethod
|
||||
def get_message_end_data(cls, data: dict):
|
||||
response_data = {
|
||||
@ -551,6 +578,23 @@ class CompletionService:
|
||||
|
||||
return response_data
|
||||
|
||||
@classmethod
|
||||
def get_annotation_response_data(cls, data: dict):
|
||||
response_data = {
|
||||
'event': 'annotation',
|
||||
'task_id': data.get('task_id'),
|
||||
'id': data.get('message_id'),
|
||||
'answer': data.get('text'),
|
||||
'created_at': int(time.time()),
|
||||
'annotation_id': data.get('annotation_id'),
|
||||
'annotation_author_name': data.get('annotation_author_name'),
|
||||
}
|
||||
|
||||
if data.get('mode') == 'chat':
|
||||
response_data['conversation_id'] = data.get('conversation_id')
|
||||
|
||||
return response_data
|
||||
|
||||
@classmethod
|
||||
def handle_error(cls, result: dict):
|
||||
logging.debug("error: %s", result)
|
||||
|
@ -33,10 +33,7 @@ from tasks.clean_notion_document_task import clean_notion_document_task
|
||||
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
|
||||
from tasks.document_indexing_task import document_indexing_task
|
||||
from tasks.document_indexing_update_task import document_indexing_update_task
|
||||
from tasks.create_segment_to_index_task import create_segment_to_index_task
|
||||
from tasks.update_segment_index_task import update_segment_index_task
|
||||
from tasks.recover_document_indexing_task import recover_document_indexing_task
|
||||
from tasks.update_segment_keyword_index_task import update_segment_keyword_index_task
|
||||
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
|
||||
|
||||
|
||||
@ -1175,10 +1172,12 @@ class SegmentService:
|
||||
|
||||
class DatasetCollectionBindingService:
|
||||
@classmethod
|
||||
def get_dataset_collection_binding(cls, provider_name: str, model_name: str) -> DatasetCollectionBinding:
|
||||
def get_dataset_collection_binding(cls, provider_name: str, model_name: str,
|
||||
collection_type: str = 'dataset') -> DatasetCollectionBinding:
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
||||
filter(DatasetCollectionBinding.provider_name == provider_name,
|
||||
DatasetCollectionBinding.model_name == model_name). \
|
||||
DatasetCollectionBinding.model_name == model_name,
|
||||
DatasetCollectionBinding.type == collection_type). \
|
||||
order_by(DatasetCollectionBinding.created_at). \
|
||||
first()
|
||||
|
||||
@ -1186,8 +1185,20 @@ class DatasetCollectionBindingService:
|
||||
dataset_collection_binding = DatasetCollectionBinding(
|
||||
provider_name=provider_name,
|
||||
model_name=model_name,
|
||||
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
|
||||
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node',
|
||||
type=collection_type
|
||||
)
|
||||
db.session.add(dataset_collection_binding)
|
||||
db.session.flush()
|
||||
db.session.commit()
|
||||
return dataset_collection_binding
|
||||
|
||||
@classmethod
|
||||
def get_dataset_collection_binding_by_id_and_type(cls, collection_binding_id: str,
|
||||
collection_type: str = 'dataset') -> DatasetCollectionBinding:
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
||||
filter(DatasetCollectionBinding.id == collection_binding_id,
|
||||
DatasetCollectionBinding.type == collection_type). \
|
||||
order_by(DatasetCollectionBinding.created_at). \
|
||||
first()
|
||||
|
||||
return dataset_collection_binding
|
||||
|
59
api/tasks/annotation/add_annotation_to_index_task.py
Normal file
59
api/tasks/annotation/add_annotation_to_index_task.py
Normal file
@ -0,0 +1,59 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from langchain.schema import Document
|
||||
|
||||
from core.index.index import IndexBuilder
|
||||
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
||||
|
||||
@shared_task(queue='dataset')
|
||||
def add_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str,
|
||||
collection_binding_id: str):
|
||||
"""
|
||||
Add annotation to index.
|
||||
:param annotation_id: annotation id
|
||||
:param question: question
|
||||
:param tenant_id: tenant id
|
||||
:param app_id: app id
|
||||
:param collection_binding_id: embedding binding id
|
||||
|
||||
Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct)
|
||||
"""
|
||||
logging.info(click.style('Start build index for annotation: {}'.format(annotation_id), fg='green'))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
|
||||
collection_binding_id,
|
||||
'annotation'
|
||||
)
|
||||
dataset = Dataset(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_technique='high_quality',
|
||||
collection_binding_id=dataset_collection_binding.id
|
||||
)
|
||||
|
||||
document = Document(
|
||||
page_content=question,
|
||||
metadata={
|
||||
"annotation_id": annotation_id,
|
||||
"app_id": app_id,
|
||||
"doc_id": annotation_id
|
||||
}
|
||||
)
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
if index:
|
||||
index.add_texts([document])
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style(
|
||||
'Build index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at),
|
||||
fg='green'))
|
||||
except Exception:
|
||||
logging.exception("Build index for annotation failed")
|
99
api/tasks/annotation/batch_import_annotations_task.py
Normal file
99
api/tasks/annotation/batch_import_annotations_task.py
Normal file
@ -0,0 +1,99 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from langchain.schema import Document
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.index.index import IndexBuilder
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
from models.model import MessageAnnotation, App, AppAnnotationSetting
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
||||
|
||||
@shared_task(queue='dataset')
|
||||
def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: str, tenant_id: str,
|
||||
user_id: str):
|
||||
"""
|
||||
Add annotation to index.
|
||||
:param job_id: job_id
|
||||
:param content_list: content list
|
||||
:param tenant_id: tenant id
|
||||
:param app_id: app id
|
||||
:param user_id: user_id
|
||||
|
||||
"""
|
||||
logging.info(click.style('Start batch import annotation: {}'.format(job_id), fg='green'))
|
||||
start_at = time.perf_counter()
|
||||
indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
|
||||
if app:
|
||||
try:
|
||||
documents = []
|
||||
for content in content_list:
|
||||
annotation = MessageAnnotation(
|
||||
app_id=app.id,
|
||||
content=content['answer'],
|
||||
question=content['question'],
|
||||
account_id=user_id
|
||||
)
|
||||
db.session.add(annotation)
|
||||
db.session.flush()
|
||||
|
||||
document = Document(
|
||||
page_content=content['question'],
|
||||
metadata={
|
||||
"annotation_id": annotation.id,
|
||||
"app_id": app_id,
|
||||
"doc_id": annotation.id
|
||||
}
|
||||
)
|
||||
documents.append(document)
|
||||
# if annotation reply is enabled , batch add annotations' index
|
||||
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id
|
||||
).first()
|
||||
|
||||
if app_annotation_setting:
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
|
||||
app_annotation_setting.collection_binding_id,
|
||||
'annotation'
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
raise NotFound("App annotation setting not found")
|
||||
dataset = Dataset(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_technique='high_quality',
|
||||
embedding_model_provider=dataset_collection_binding.provider_name,
|
||||
embedding_model=dataset_collection_binding.model_name,
|
||||
collection_binding_id=dataset_collection_binding.id
|
||||
)
|
||||
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
if index:
|
||||
index.add_texts(documents)
|
||||
|
||||
db.session.commit()
|
||||
redis_client.setex(indexing_cache_key, 600, 'completed')
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style(
|
||||
'Build index successful for batch import annotation: {} latency: {}'.format(job_id, end_at - start_at),
|
||||
fg='green'))
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
redis_client.setex(indexing_cache_key, 600, 'error')
|
||||
indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
|
||||
redis_client.setex(indexing_error_msg_key, 600, str(e))
|
||||
logging.exception("Build index for batch import annotations failed")
|
45
api/tasks/annotation/delete_annotation_index_task.py
Normal file
45
api/tasks/annotation/delete_annotation_index_task.py
Normal file
@ -0,0 +1,45 @@
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from core.index.index import IndexBuilder
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
||||
|
||||
@shared_task(queue='dataset')
|
||||
def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str,
|
||||
collection_binding_id: str):
|
||||
"""
|
||||
Async delete annotation index task
|
||||
"""
|
||||
logging.info(click.style('Start delete app annotation index: {}'.format(app_id), fg='green'))
|
||||
start_at = time.perf_counter()
|
||||
try:
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
|
||||
collection_binding_id,
|
||||
'annotation'
|
||||
)
|
||||
|
||||
dataset = Dataset(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_technique='high_quality',
|
||||
collection_binding_id=dataset_collection_binding.id
|
||||
)
|
||||
|
||||
vector_index = IndexBuilder.get_default_high_quality_index(dataset)
|
||||
if vector_index:
|
||||
try:
|
||||
vector_index.delete_by_metadata_field('annotation_id', annotation_id)
|
||||
except Exception:
|
||||
logging.exception("Delete annotation index failed when annotation deleted.")
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at),
|
||||
fg='green'))
|
||||
except Exception as e:
|
||||
logging.exception("Annotation deleted index failed:{}".format(str(e)))
|
||||
|
74
api/tasks/annotation/disable_annotation_reply_task.py
Normal file
74
api/tasks/annotation/disable_annotation_reply_task.py
Normal file
@ -0,0 +1,74 @@
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.index.index import IndexBuilder
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
from models.model import MessageAnnotation, App, AppAnnotationSetting
|
||||
|
||||
|
||||
@shared_task(queue='dataset')
|
||||
def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
|
||||
"""
|
||||
Async enable annotation reply task
|
||||
"""
|
||||
logging.info(click.style('Start delete app annotations index: {}'.format(app_id), fg='green'))
|
||||
start_at = time.perf_counter()
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id
|
||||
).first()
|
||||
|
||||
if not app_annotation_setting:
|
||||
raise NotFound("App annotation setting not found")
|
||||
|
||||
disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id))
|
||||
disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id))
|
||||
|
||||
try:
|
||||
|
||||
dataset = Dataset(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_technique='high_quality',
|
||||
collection_binding_id=app_annotation_setting.collection_binding_id
|
||||
)
|
||||
|
||||
vector_index = IndexBuilder.get_default_high_quality_index(dataset)
|
||||
if vector_index:
|
||||
try:
|
||||
vector_index.delete_by_metadata_field('app_id', app_id)
|
||||
except Exception:
|
||||
logging.exception("Delete doc index failed when dataset deleted.")
|
||||
redis_client.setex(disable_app_annotation_job_key, 600, 'completed')
|
||||
|
||||
# delete annotation setting
|
||||
db.session.delete(app_annotation_setting)
|
||||
db.session.commit()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at),
|
||||
fg='green'))
|
||||
except Exception as e:
|
||||
logging.exception("Annotation batch deleted index failed:{}".format(str(e)))
|
||||
redis_client.setex(disable_app_annotation_job_key, 600, 'error')
|
||||
disable_app_annotation_error_key = 'disable_app_annotation_error_{}'.format(str(job_id))
|
||||
redis_client.setex(disable_app_annotation_error_key, 600, str(e))
|
||||
finally:
|
||||
redis_client.delete(disable_app_annotation_key)
|
106
api/tasks/annotation/enable_annotation_reply_task.py
Normal file
106
api/tasks/annotation/enable_annotation_reply_task.py
Normal file
@ -0,0 +1,106 @@
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from langchain.schema import Document
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.index.index import IndexBuilder
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
from models.model import MessageAnnotation, App, AppAnnotationSetting
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
||||
|
||||
@shared_task(queue='dataset')
|
||||
def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_id: str, score_threshold: float,
|
||||
embedding_provider_name: str, embedding_model_name: str):
|
||||
"""
|
||||
Async enable annotation reply task
|
||||
"""
|
||||
logging.info(click.style('Start add app annotation to index: {}'.format(app_id), fg='green'))
|
||||
start_at = time.perf_counter()
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).all()
|
||||
enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id))
|
||||
enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id))
|
||||
|
||||
try:
|
||||
documents = []
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_provider_name,
|
||||
embedding_model_name,
|
||||
'annotation'
|
||||
)
|
||||
annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id).first()
|
||||
if annotation_setting:
|
||||
annotation_setting.score_threshold = score_threshold
|
||||
annotation_setting.collection_binding_id = dataset_collection_binding.id
|
||||
annotation_setting.updated_user_id = user_id
|
||||
annotation_setting.updated_at = datetime.datetime.utcnow()
|
||||
db.session.add(annotation_setting)
|
||||
else:
|
||||
new_app_annotation_setting = AppAnnotationSetting(
|
||||
app_id=app_id,
|
||||
score_threshold=score_threshold,
|
||||
collection_binding_id=dataset_collection_binding.id,
|
||||
created_user_id=user_id,
|
||||
updated_user_id=user_id
|
||||
)
|
||||
db.session.add(new_app_annotation_setting)
|
||||
|
||||
dataset = Dataset(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_technique='high_quality',
|
||||
embedding_model_provider=embedding_provider_name,
|
||||
embedding_model=embedding_model_name,
|
||||
collection_binding_id=dataset_collection_binding.id
|
||||
)
|
||||
if annotations:
|
||||
for annotation in annotations:
|
||||
document = Document(
|
||||
page_content=annotation.question,
|
||||
metadata={
|
||||
"annotation_id": annotation.id,
|
||||
"app_id": app_id,
|
||||
"doc_id": annotation.id
|
||||
}
|
||||
)
|
||||
documents.append(document)
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
if index:
|
||||
try:
|
||||
index.delete_by_metadata_field('app_id', app_id)
|
||||
except Exception as e:
|
||||
logging.info(
|
||||
click.style('Delete annotation index error: {}'.format(str(e)),
|
||||
fg='red'))
|
||||
index.add_texts(documents)
|
||||
db.session.commit()
|
||||
redis_client.setex(enable_app_annotation_job_key, 600, 'completed')
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style('App annotations added to index: {} latency: {}'.format(app_id, end_at - start_at),
|
||||
fg='green'))
|
||||
except Exception as e:
|
||||
logging.exception("Annotation batch created index failed:{}".format(str(e)))
|
||||
redis_client.setex(enable_app_annotation_job_key, 600, 'error')
|
||||
enable_app_annotation_error_key = 'enable_app_annotation_error_{}'.format(str(job_id))
|
||||
redis_client.setex(enable_app_annotation_error_key, 600, str(e))
|
||||
db.session.rollback()
|
||||
finally:
|
||||
redis_client.delete(enable_app_annotation_key)
|
63
api/tasks/annotation/update_annotation_to_index_task.py
Normal file
63
api/tasks/annotation/update_annotation_to_index_task.py
Normal file
@ -0,0 +1,63 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from langchain.schema import Document
|
||||
|
||||
from core.index.index import IndexBuilder
|
||||
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
||||
|
||||
@shared_task(queue='dataset')
|
||||
def update_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str,
|
||||
collection_binding_id: str):
|
||||
"""
|
||||
Update annotation to index.
|
||||
:param annotation_id: annotation id
|
||||
:param question: question
|
||||
:param tenant_id: tenant id
|
||||
:param app_id: app id
|
||||
:param collection_binding_id: embedding binding id
|
||||
|
||||
Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct)
|
||||
"""
|
||||
logging.info(click.style('Start update index for annotation: {}'.format(annotation_id), fg='green'))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
|
||||
collection_binding_id,
|
||||
'annotation'
|
||||
)
|
||||
|
||||
dataset = Dataset(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_technique='high_quality',
|
||||
embedding_model_provider=dataset_collection_binding.provider_name,
|
||||
embedding_model=dataset_collection_binding.model_name,
|
||||
collection_binding_id=dataset_collection_binding.id
|
||||
)
|
||||
|
||||
document = Document(
|
||||
page_content=question,
|
||||
metadata={
|
||||
"annotation_id": annotation_id,
|
||||
"app_id": app_id,
|
||||
"doc_id": annotation_id
|
||||
}
|
||||
)
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
if index:
|
||||
index.delete_by_metadata_field('annotation_id', annotation_id)
|
||||
index.add_texts([document])
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style(
|
||||
'Build index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at),
|
||||
fg='green'))
|
||||
except Exception:
|
||||
logging.exception("Build index for annotation failed")
|
Loading…
x
Reference in New Issue
Block a user