mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 14:45:53 +08:00
fix: refactor conversation pagination to use SQLAlchemy session manag… (#11956)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
366857cd26
commit
3d07a94bd7
@ -1,12 +1,14 @@
|
|||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restful import marshal_with, reqparse
|
from flask_restful import marshal_with, reqparse
|
||||||
from flask_restful.inputs import int_range
|
from flask_restful.inputs import int_range
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.explore.error import NotChatAppError
|
from controllers.console.explore.error import NotChatAppError
|
||||||
from controllers.console.explore.wraps import InstalledAppResource
|
from controllers.console.explore.wraps import InstalledAppResource
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from extensions.ext_database import db
|
||||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
@ -34,14 +36,16 @@ class ConversationListApi(InstalledAppResource):
|
|||||||
pinned = True if args["pinned"] == "true" else False
|
pinned = True if args["pinned"] == "true" else False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return WebConversationService.pagination_by_last_id(
|
with Session(db.engine) as session:
|
||||||
app_model=app_model,
|
return WebConversationService.pagination_by_last_id(
|
||||||
user=current_user,
|
session=session,
|
||||||
last_id=args["last_id"],
|
app_model=app_model,
|
||||||
limit=args["limit"],
|
user=current_user,
|
||||||
invoke_from=InvokeFrom.EXPLORE,
|
last_id=args["last_id"],
|
||||||
pinned=pinned,
|
limit=args["limit"],
|
||||||
)
|
invoke_from=InvokeFrom.EXPLORE,
|
||||||
|
pinned=pinned,
|
||||||
|
)
|
||||||
except LastConversationNotExistsError:
|
except LastConversationNotExistsError:
|
||||||
raise NotFound("Last Conversation Not Exists.")
|
raise NotFound("Last Conversation Not Exists.")
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from flask_restful import Resource, marshal_with, reqparse
|
from flask_restful import Resource, marshal_with, reqparse
|
||||||
from flask_restful.inputs import int_range
|
from flask_restful.inputs import int_range
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
@ -7,6 +8,7 @@ from controllers.service_api import api
|
|||||||
from controllers.service_api.app.error import NotChatAppError
|
from controllers.service_api.app.error import NotChatAppError
|
||||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from extensions.ext_database import db
|
||||||
from fields.conversation_fields import (
|
from fields.conversation_fields import (
|
||||||
conversation_delete_fields,
|
conversation_delete_fields,
|
||||||
conversation_infinite_scroll_pagination_fields,
|
conversation_infinite_scroll_pagination_fields,
|
||||||
@ -39,14 +41,16 @@ class ConversationApi(Resource):
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return ConversationService.pagination_by_last_id(
|
with Session(db.engine) as session:
|
||||||
app_model=app_model,
|
return ConversationService.pagination_by_last_id(
|
||||||
user=end_user,
|
session=session,
|
||||||
last_id=args["last_id"],
|
app_model=app_model,
|
||||||
limit=args["limit"],
|
user=end_user,
|
||||||
invoke_from=InvokeFrom.SERVICE_API,
|
last_id=args["last_id"],
|
||||||
sort_by=args["sort_by"],
|
limit=args["limit"],
|
||||||
)
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
|
sort_by=args["sort_by"],
|
||||||
|
)
|
||||||
except services.errors.conversation.LastConversationNotExistsError:
|
except services.errors.conversation.LastConversationNotExistsError:
|
||||||
raise NotFound("Last Conversation Not Exists.")
|
raise NotFound("Last Conversation Not Exists.")
|
||||||
|
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
from flask_restful import marshal_with, reqparse
|
from flask_restful import marshal_with, reqparse
|
||||||
from flask_restful.inputs import int_range
|
from flask_restful.inputs import int_range
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from controllers.web import api
|
from controllers.web import api
|
||||||
from controllers.web.error import NotChatAppError
|
from controllers.web.error import NotChatAppError
|
||||||
from controllers.web.wraps import WebApiResource
|
from controllers.web.wraps import WebApiResource
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from extensions.ext_database import db
|
||||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
@ -40,15 +42,17 @@ class ConversationListApi(WebApiResource):
|
|||||||
pinned = True if args["pinned"] == "true" else False
|
pinned = True if args["pinned"] == "true" else False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return WebConversationService.pagination_by_last_id(
|
with Session(db.engine) as session:
|
||||||
app_model=app_model,
|
return WebConversationService.pagination_by_last_id(
|
||||||
user=end_user,
|
session=session,
|
||||||
last_id=args["last_id"],
|
app_model=app_model,
|
||||||
limit=args["limit"],
|
user=end_user,
|
||||||
invoke_from=InvokeFrom.WEB_APP,
|
last_id=args["last_id"],
|
||||||
pinned=pinned,
|
limit=args["limit"],
|
||||||
sort_by=args["sort_by"],
|
invoke_from=InvokeFrom.WEB_APP,
|
||||||
)
|
pinned=pinned,
|
||||||
|
sort_by=args["sort_by"],
|
||||||
|
)
|
||||||
except LastConversationNotExistsError:
|
except LastConversationNotExistsError:
|
||||||
raise NotFound("Last Conversation Not Exists.")
|
raise NotFound("Last Conversation Not Exists.")
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from .engine import db
|
from .engine import db
|
||||||
from .model import Message
|
from .model import Message
|
||||||
@ -33,7 +34,7 @@ class PinnedConversation(db.Model):
|
|||||||
|
|
||||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||||
app_id = db.Column(StringUUID, nullable=False)
|
app_id = db.Column(StringUUID, nullable=False)
|
||||||
conversation_id = db.Column(StringUUID, nullable=False)
|
conversation_id: Mapped[str] = mapped_column(StringUUID)
|
||||||
created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
|
created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
|
||||||
created_by = db.Column(StringUUID, nullable=False)
|
created_by = db.Column(StringUUID, nullable=False)
|
||||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable, Sequence
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from sqlalchemy import asc, desc, or_
|
from sqlalchemy import asc, desc, func, or_, select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.llm_generator.llm_generator import LLMGenerator
|
from core.llm_generator.llm_generator import LLMGenerator
|
||||||
@ -18,19 +19,21 @@ class ConversationService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def pagination_by_last_id(
|
def pagination_by_last_id(
|
||||||
cls,
|
cls,
|
||||||
|
*,
|
||||||
|
session: Session,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
user: Optional[Union[Account, EndUser]],
|
user: Optional[Union[Account, EndUser]],
|
||||||
last_id: Optional[str],
|
last_id: Optional[str],
|
||||||
limit: int,
|
limit: int,
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
include_ids: Optional[list] = None,
|
include_ids: Optional[Sequence[str]] = None,
|
||||||
exclude_ids: Optional[list] = None,
|
exclude_ids: Optional[Sequence[str]] = None,
|
||||||
sort_by: str = "-updated_at",
|
sort_by: str = "-updated_at",
|
||||||
) -> InfiniteScrollPagination:
|
) -> InfiniteScrollPagination:
|
||||||
if not user:
|
if not user:
|
||||||
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
|
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
|
||||||
|
|
||||||
base_query = db.session.query(Conversation).filter(
|
stmt = select(Conversation).where(
|
||||||
Conversation.is_deleted == False,
|
Conversation.is_deleted == False,
|
||||||
Conversation.app_id == app_model.id,
|
Conversation.app_id == app_model.id,
|
||||||
Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
|
Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
|
||||||
@ -38,37 +41,40 @@ class ConversationService:
|
|||||||
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
|
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||||
or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value),
|
or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value),
|
||||||
)
|
)
|
||||||
|
|
||||||
if include_ids is not None:
|
if include_ids is not None:
|
||||||
base_query = base_query.filter(Conversation.id.in_(include_ids))
|
stmt = stmt.where(Conversation.id.in_(include_ids))
|
||||||
|
|
||||||
if exclude_ids is not None:
|
if exclude_ids is not None:
|
||||||
base_query = base_query.filter(~Conversation.id.in_(exclude_ids))
|
stmt = stmt.where(~Conversation.id.in_(exclude_ids))
|
||||||
|
|
||||||
# define sort fields and directions
|
# define sort fields and directions
|
||||||
sort_field, sort_direction = cls._get_sort_params(sort_by)
|
sort_field, sort_direction = cls._get_sort_params(sort_by)
|
||||||
|
|
||||||
if last_id:
|
if last_id:
|
||||||
last_conversation = base_query.filter(Conversation.id == last_id).first()
|
last_conversation = session.scalar(stmt.where(Conversation.id == last_id))
|
||||||
if not last_conversation:
|
if not last_conversation:
|
||||||
raise LastConversationNotExistsError()
|
raise LastConversationNotExistsError()
|
||||||
|
|
||||||
# build filters based on sorting
|
# build filters based on sorting
|
||||||
filter_condition = cls._build_filter_condition(sort_field, sort_direction, last_conversation)
|
filter_condition = cls._build_filter_condition(
|
||||||
base_query = base_query.filter(filter_condition)
|
sort_field=sort_field,
|
||||||
|
sort_direction=sort_direction,
|
||||||
base_query = base_query.order_by(sort_direction(getattr(Conversation, sort_field)))
|
reference_conversation=last_conversation,
|
||||||
|
)
|
||||||
conversations = base_query.limit(limit).all()
|
stmt = stmt.where(filter_condition)
|
||||||
|
query_stmt = stmt.order_by(sort_direction(getattr(Conversation, sort_field))).limit(limit)
|
||||||
|
conversations = session.scalars(query_stmt).all()
|
||||||
|
|
||||||
has_more = False
|
has_more = False
|
||||||
if len(conversations) == limit:
|
if len(conversations) == limit:
|
||||||
current_page_last_conversation = conversations[-1]
|
current_page_last_conversation = conversations[-1]
|
||||||
rest_filter_condition = cls._build_filter_condition(
|
rest_filter_condition = cls._build_filter_condition(
|
||||||
sort_field, sort_direction, current_page_last_conversation, is_next_page=True
|
sort_field=sort_field,
|
||||||
|
sort_direction=sort_direction,
|
||||||
|
reference_conversation=current_page_last_conversation,
|
||||||
)
|
)
|
||||||
rest_count = base_query.filter(rest_filter_condition).count()
|
count_stmt = stmt.where(rest_filter_condition)
|
||||||
|
count_stmt = select(func.count()).select_from(count_stmt.subquery())
|
||||||
|
rest_count = session.scalar(count_stmt) or 0
|
||||||
if rest_count > 0:
|
if rest_count > 0:
|
||||||
has_more = True
|
has_more = True
|
||||||
|
|
||||||
@ -81,11 +87,9 @@ class ConversationService:
|
|||||||
return sort_by, asc
|
return sort_by, asc
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _build_filter_condition(
|
def _build_filter_condition(cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation):
|
||||||
cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation, is_next_page: bool = False
|
|
||||||
):
|
|
||||||
field_value = getattr(reference_conversation, sort_field)
|
field_value = getattr(reference_conversation, sort_field)
|
||||||
if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page):
|
if sort_direction == desc:
|
||||||
return getattr(Conversation, sort_field) < field_value
|
return getattr(Conversation, sort_field) < field_value
|
||||||
else:
|
else:
|
||||||
return getattr(Conversation, sort_field) > field_value
|
return getattr(Conversation, sort_field) > field_value
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||||
@ -13,6 +16,8 @@ class WebConversationService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def pagination_by_last_id(
|
def pagination_by_last_id(
|
||||||
cls,
|
cls,
|
||||||
|
*,
|
||||||
|
session: Session,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
user: Optional[Union[Account, EndUser]],
|
user: Optional[Union[Account, EndUser]],
|
||||||
last_id: Optional[str],
|
last_id: Optional[str],
|
||||||
@ -23,24 +28,25 @@ class WebConversationService:
|
|||||||
) -> InfiniteScrollPagination:
|
) -> InfiniteScrollPagination:
|
||||||
include_ids = None
|
include_ids = None
|
||||||
exclude_ids = None
|
exclude_ids = None
|
||||||
if pinned is not None:
|
if pinned is not None and user:
|
||||||
pinned_conversations = (
|
stmt = (
|
||||||
db.session.query(PinnedConversation)
|
select(PinnedConversation.conversation_id)
|
||||||
.filter(
|
.where(
|
||||||
PinnedConversation.app_id == app_model.id,
|
PinnedConversation.app_id == app_model.id,
|
||||||
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
||||||
PinnedConversation.created_by == user.id,
|
PinnedConversation.created_by == user.id,
|
||||||
)
|
)
|
||||||
.order_by(PinnedConversation.created_at.desc())
|
.order_by(PinnedConversation.created_at.desc())
|
||||||
.all()
|
|
||||||
)
|
)
|
||||||
pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations]
|
pinned_conversation_ids = session.scalars(stmt).all()
|
||||||
|
|
||||||
if pinned:
|
if pinned:
|
||||||
include_ids = pinned_conversation_ids
|
include_ids = pinned_conversation_ids
|
||||||
else:
|
else:
|
||||||
exclude_ids = pinned_conversation_ids
|
exclude_ids = pinned_conversation_ids
|
||||||
|
|
||||||
return ConversationService.pagination_by_last_id(
|
return ConversationService.pagination_by_last_id(
|
||||||
|
session=session,
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
user=user,
|
user=user,
|
||||||
last_id=last_id,
|
last_id=last_id,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user