fix: refactor conversation pagination to use SQLAlchemy session manag… (#11956)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2024-12-22 10:39:29 +08:00 committed by GitHub
parent 366857cd26
commit 3d07a94bd7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 78 additions and 55 deletions

View File

@ -1,12 +1,14 @@
from flask_login import current_user from flask_login import current_user
from flask_restful import marshal_with, reqparse from flask_restful import marshal_with, reqparse
from flask_restful.inputs import int_range from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import api from controllers.console import api
from controllers.console.explore.error import NotChatAppError from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value from libs.helper import uuid_value
from models.model import AppMode from models.model import AppMode
@ -34,14 +36,16 @@ class ConversationListApi(InstalledAppResource):
pinned = True if args["pinned"] == "true" else False pinned = True if args["pinned"] == "true" else False
try: try:
return WebConversationService.pagination_by_last_id( with Session(db.engine) as session:
app_model=app_model, return WebConversationService.pagination_by_last_id(
user=current_user, session=session,
last_id=args["last_id"], app_model=app_model,
limit=args["limit"], user=current_user,
invoke_from=InvokeFrom.EXPLORE, last_id=args["last_id"],
pinned=pinned, limit=args["limit"],
) invoke_from=InvokeFrom.EXPLORE,
pinned=pinned,
)
except LastConversationNotExistsError: except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.") raise NotFound("Last Conversation Not Exists.")

View File

@ -1,5 +1,6 @@
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse
from flask_restful.inputs import int_range from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
import services import services
@ -7,6 +8,7 @@ from controllers.service_api import api
from controllers.service_api.app.error import NotChatAppError from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import ( from fields.conversation_fields import (
conversation_delete_fields, conversation_delete_fields,
conversation_infinite_scroll_pagination_fields, conversation_infinite_scroll_pagination_fields,
@ -39,14 +41,16 @@ class ConversationApi(Resource):
args = parser.parse_args() args = parser.parse_args()
try: try:
return ConversationService.pagination_by_last_id( with Session(db.engine) as session:
app_model=app_model, return ConversationService.pagination_by_last_id(
user=end_user, session=session,
last_id=args["last_id"], app_model=app_model,
limit=args["limit"], user=end_user,
invoke_from=InvokeFrom.SERVICE_API, last_id=args["last_id"],
sort_by=args["sort_by"], limit=args["limit"],
) invoke_from=InvokeFrom.SERVICE_API,
sort_by=args["sort_by"],
)
except services.errors.conversation.LastConversationNotExistsError: except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.") raise NotFound("Last Conversation Not Exists.")

View File

@ -1,11 +1,13 @@
from flask_restful import marshal_with, reqparse from flask_restful import marshal_with, reqparse
from flask_restful.inputs import int_range from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.web import api from controllers.web import api
from controllers.web.error import NotChatAppError from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value from libs.helper import uuid_value
from models.model import AppMode from models.model import AppMode
@ -40,15 +42,17 @@ class ConversationListApi(WebApiResource):
pinned = True if args["pinned"] == "true" else False pinned = True if args["pinned"] == "true" else False
try: try:
return WebConversationService.pagination_by_last_id( with Session(db.engine) as session:
app_model=app_model, return WebConversationService.pagination_by_last_id(
user=end_user, session=session,
last_id=args["last_id"], app_model=app_model,
limit=args["limit"], user=end_user,
invoke_from=InvokeFrom.WEB_APP, last_id=args["last_id"],
pinned=pinned, limit=args["limit"],
sort_by=args["sort_by"], invoke_from=InvokeFrom.WEB_APP,
) pinned=pinned,
sort_by=args["sort_by"],
)
except LastConversationNotExistsError: except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.") raise NotFound("Last Conversation Not Exists.")

View File

@ -1,4 +1,5 @@
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Mapped, mapped_column
from .engine import db from .engine import db
from .model import Message from .model import Message
@ -33,7 +34,7 @@ class PinnedConversation(db.Model):
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
app_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
conversation_id = db.Column(StringUUID, nullable=False) conversation_id: Mapped[str] = mapped_column(StringUUID)
created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
created_by = db.Column(StringUUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

View File

@ -1,8 +1,9 @@
from collections.abc import Callable from collections.abc import Callable, Sequence
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Optional, Union from typing import Optional, Union
from sqlalchemy import asc, desc, or_ from sqlalchemy import asc, desc, func, or_, select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.llm_generator import LLMGenerator from core.llm_generator.llm_generator import LLMGenerator
@ -18,19 +19,21 @@ class ConversationService:
@classmethod @classmethod
def pagination_by_last_id( def pagination_by_last_id(
cls, cls,
*,
session: Session,
app_model: App, app_model: App,
user: Optional[Union[Account, EndUser]], user: Optional[Union[Account, EndUser]],
last_id: Optional[str], last_id: Optional[str],
limit: int, limit: int,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
include_ids: Optional[list] = None, include_ids: Optional[Sequence[str]] = None,
exclude_ids: Optional[list] = None, exclude_ids: Optional[Sequence[str]] = None,
sort_by: str = "-updated_at", sort_by: str = "-updated_at",
) -> InfiniteScrollPagination: ) -> InfiniteScrollPagination:
if not user: if not user:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False) return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
base_query = db.session.query(Conversation).filter( stmt = select(Conversation).where(
Conversation.is_deleted == False, Conversation.is_deleted == False,
Conversation.app_id == app_model.id, Conversation.app_id == app_model.id,
Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
@ -38,37 +41,40 @@ class ConversationService:
Conversation.from_account_id == (user.id if isinstance(user, Account) else None), Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value), or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value),
) )
if include_ids is not None: if include_ids is not None:
base_query = base_query.filter(Conversation.id.in_(include_ids)) stmt = stmt.where(Conversation.id.in_(include_ids))
if exclude_ids is not None: if exclude_ids is not None:
base_query = base_query.filter(~Conversation.id.in_(exclude_ids)) stmt = stmt.where(~Conversation.id.in_(exclude_ids))
# define sort fields and directions # define sort fields and directions
sort_field, sort_direction = cls._get_sort_params(sort_by) sort_field, sort_direction = cls._get_sort_params(sort_by)
if last_id: if last_id:
last_conversation = base_query.filter(Conversation.id == last_id).first() last_conversation = session.scalar(stmt.where(Conversation.id == last_id))
if not last_conversation: if not last_conversation:
raise LastConversationNotExistsError() raise LastConversationNotExistsError()
# build filters based on sorting # build filters based on sorting
filter_condition = cls._build_filter_condition(sort_field, sort_direction, last_conversation) filter_condition = cls._build_filter_condition(
base_query = base_query.filter(filter_condition) sort_field=sort_field,
sort_direction=sort_direction,
base_query = base_query.order_by(sort_direction(getattr(Conversation, sort_field))) reference_conversation=last_conversation,
)
conversations = base_query.limit(limit).all() stmt = stmt.where(filter_condition)
query_stmt = stmt.order_by(sort_direction(getattr(Conversation, sort_field))).limit(limit)
conversations = session.scalars(query_stmt).all()
has_more = False has_more = False
if len(conversations) == limit: if len(conversations) == limit:
current_page_last_conversation = conversations[-1] current_page_last_conversation = conversations[-1]
rest_filter_condition = cls._build_filter_condition( rest_filter_condition = cls._build_filter_condition(
sort_field, sort_direction, current_page_last_conversation, is_next_page=True sort_field=sort_field,
sort_direction=sort_direction,
reference_conversation=current_page_last_conversation,
) )
rest_count = base_query.filter(rest_filter_condition).count() count_stmt = stmt.where(rest_filter_condition)
count_stmt = select(func.count()).select_from(count_stmt.subquery())
rest_count = session.scalar(count_stmt) or 0
if rest_count > 0: if rest_count > 0:
has_more = True has_more = True
@ -81,11 +87,9 @@ class ConversationService:
return sort_by, asc return sort_by, asc
@classmethod @classmethod
def _build_filter_condition( def _build_filter_condition(cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation):
cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation, is_next_page: bool = False
):
field_value = getattr(reference_conversation, sort_field) field_value = getattr(reference_conversation, sort_field)
if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page): if sort_direction == desc:
return getattr(Conversation, sort_field) < field_value return getattr(Conversation, sort_field) < field_value
else: else:
return getattr(Conversation, sort_field) > field_value return getattr(Conversation, sort_field) > field_value

View File

@ -1,5 +1,8 @@
from typing import Optional, Union from typing import Optional, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
@ -13,6 +16,8 @@ class WebConversationService:
@classmethod @classmethod
def pagination_by_last_id( def pagination_by_last_id(
cls, cls,
*,
session: Session,
app_model: App, app_model: App,
user: Optional[Union[Account, EndUser]], user: Optional[Union[Account, EndUser]],
last_id: Optional[str], last_id: Optional[str],
@ -23,24 +28,25 @@ class WebConversationService:
) -> InfiniteScrollPagination: ) -> InfiniteScrollPagination:
include_ids = None include_ids = None
exclude_ids = None exclude_ids = None
if pinned is not None: if pinned is not None and user:
pinned_conversations = ( stmt = (
db.session.query(PinnedConversation) select(PinnedConversation.conversation_id)
.filter( .where(
PinnedConversation.app_id == app_model.id, PinnedConversation.app_id == app_model.id,
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
PinnedConversation.created_by == user.id, PinnedConversation.created_by == user.id,
) )
.order_by(PinnedConversation.created_at.desc()) .order_by(PinnedConversation.created_at.desc())
.all()
) )
pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations] pinned_conversation_ids = session.scalars(stmt).all()
if pinned: if pinned:
include_ids = pinned_conversation_ids include_ids = pinned_conversation_ids
else: else:
exclude_ids = pinned_conversation_ids exclude_ids = pinned_conversation_ids
return ConversationService.pagination_by_last_id( return ConversationService.pagination_by_last_id(
session=session,
app_model=app_model, app_model=app_model,
user=user, user=user,
last_id=last_id, last_id=last_id,