show avatar dialog instead of default (#4033)

show avatar dialog instead of default

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
so95 2024-12-17 16:29:35 +07:00 committed by GitHub
parent 09436f6c60
commit 251592eeeb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 159 additions and 12 deletions

View File

@ -23,6 +23,7 @@ from api.utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
from agent.canvas import Canvas from agent.canvas import Canvas
from peewee import MySQLDatabase, PostgresqlDatabase from peewee import MySQLDatabase, PostgresqlDatabase
from api.db.db_models import APIToken
@manager.route('/templates', methods=['GET']) # noqa: F821 @manager.route('/templates', methods=['GET']) # noqa: F821
@ -85,6 +86,20 @@ def get(canvas_id):
return get_data_error_result(message="canvas not found.") return get_data_error_result(message="canvas not found.")
return get_json_result(data=c.to_dict()) return get_json_result(data=c.to_dict())
@manager.route('/getsse/<canvas_id>', methods=['GET']) # type: ignore # noqa: F821
def getsse(canvas_id):
token = request.headers.get('Authorization').split()
if len(token) != 2:
return get_data_error_result(message='Authorization is not valid!"')
token = token[1]
objs = APIToken.query(beta=token)
if not objs:
return get_data_error_result(message='Token is not valid!"')
e, c = UserCanvasService.get_by_id(canvas_id)
if not e:
return get_data_error_result(message="canvas not found.")
return get_json_result(data=c.to_dict())
@manager.route('/completion', methods=['POST']) # noqa: F821 @manager.route('/completion', methods=['POST']) # noqa: F821
@validate_request("id") @validate_request("id")

View File

@ -17,6 +17,7 @@ import json
import re import re
import traceback import traceback
from copy import deepcopy from copy import deepcopy
from api.db.db_models import APIToken
from api.db.services.conversation_service import ConversationService, structure_answer from api.db.services.conversation_service import ConversationService, structure_answer
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
@ -32,7 +33,6 @@ from api.utils.api_utils import get_json_result
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from graphrag.mind_map_extractor import MindMapExtractor from graphrag.mind_map_extractor import MindMapExtractor
@manager.route('/set', methods=['POST']) # noqa: F821 @manager.route('/set', methods=['POST']) # noqa: F821
@login_required @login_required
def set_conversation(): def set_conversation():
@ -79,12 +79,16 @@ def set_conversation():
def get(): def get():
conv_id = request.args["conversation_id"] conv_id = request.args["conversation_id"]
try: try:
e, conv = ConversationService.get_by_id(conv_id) e, conv = ConversationService.get_by_id(conv_id)
if not e: if not e:
return get_data_error_result(message="Conversation not found!") return get_data_error_result(message="Conversation not found!")
tenants = UserTenantService.query(user_id=current_user.id) tenants = UserTenantService.query(user_id=current_user.id)
avatar =None
for tenant in tenants: for tenant in tenants:
if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id): dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id)
if dialog and len(dialog)>0:
avatar = dialog[0].icon
break break
else: else:
return get_json_result( return get_json_result(
@ -108,10 +112,31 @@ def get():
} for ck in ref.get("chunks", [])] } for ck in ref.get("chunks", [])]
conv = conv.to_dict() conv = conv.to_dict()
conv["avatar"]=avatar
return get_json_result(data=conv) return get_json_result(data=conv)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/getsse/<dialog_id>', methods=['GET']) # type: ignore # noqa: F821
def getsse(dialog_id):
token = request.headers.get('Authorization').split()
if len(token) != 2:
return get_data_error_result(message='Authorization is not valid!"')
token = token[1]
objs = APIToken.query(beta=token)
if not objs:
return get_data_error_result(message='Token is not valid!"')
try:
e, conv = DialogService.get_by_id(dialog_id)
if not e:
return get_data_error_result(message="Dialog not found!")
conv = conv.to_dict()
conv["avatar"]= conv["icon"]
del conv["icon"]
return get_json_result(data=conv)
except Exception as e:
return server_error_response(e)
@manager.route('/rm', methods=['POST']) # noqa: F821 @manager.route('/rm', methods=['POST']) # noqa: F821
@login_required @login_required

View File

@ -30,6 +30,7 @@ interface IProps extends Partial<IRemoveMessageById>, IRegenerateMessage {
sendLoading?: boolean; sendLoading?: boolean;
nickname?: string; nickname?: string;
avatar?: string; avatar?: string;
avatardialog?: string | null;
clickDocumentButton?: (documentId: string, chunk: IReferenceChunk) => void; clickDocumentButton?: (documentId: string, chunk: IReferenceChunk) => void;
index: number; index: number;
showLikeButton?: boolean; showLikeButton?: boolean;
@ -40,6 +41,7 @@ const MessageItem = ({
reference, reference,
loading = false, loading = false,
avatar, avatar,
avatardialog,
sendLoading = false, sendLoading = false,
clickDocumentButton, clickDocumentButton,
index, index,
@ -103,8 +105,10 @@ const MessageItem = ({
> >
{item.role === MessageType.User ? ( {item.role === MessageType.User ? (
<Avatar size={40} src={avatar ?? '/logo.svg'} /> <Avatar size={40} src={avatar ?? '/logo.svg'} />
) : avatardialog ? (
<Avatar size={40} src={avatardialog} />
) : ( ) : (
<AssistantIcon></AssistantIcon> <AssistantIcon />
)} )}
<Flex vertical gap={8} flex={1}> <Flex vertical gap={8} flex={1}>
<Space> <Space>

View File

@ -11,6 +11,7 @@ import {
} from '@/interfaces/request/chat'; } from '@/interfaces/request/chat';
import i18n from '@/locales/config'; import i18n from '@/locales/config';
import { IClientConversation } from '@/pages/chat/interface'; import { IClientConversation } from '@/pages/chat/interface';
import { useGetSharedChatSearchParams } from '@/pages/chat/shared-hooks';
import chatService from '@/services/chat-service'; import chatService from '@/services/chat-service';
import { import {
buildMessageListWithUuid, buildMessageListWithUuid,
@ -27,6 +28,7 @@ import { history, useSearchParams } from 'umi';
//#region logic //#region logic
export const useClickDialogCard = () => { export const useClickDialogCard = () => {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const [_, setSearchParams] = useSearchParams(); const [_, setSearchParams] = useSearchParams();
const newQueryParameters: URLSearchParams = useMemo(() => { const newQueryParameters: URLSearchParams = useMemo(() => {
@ -243,6 +245,7 @@ export const useFetchNextConversationList = () => {
export const useFetchNextConversation = () => { export const useFetchNextConversation = () => {
const { isNew, conversationId } = useGetChatSearchParams(); const { isNew, conversationId } = useGetChatSearchParams();
const { sharedId } = useGetSharedChatSearchParams();
const { const {
data, data,
isFetching: loading, isFetching: loading,
@ -254,8 +257,13 @@ export const useFetchNextConversation = () => {
gcTime: 0, gcTime: 0,
refetchOnWindowFocus: false, refetchOnWindowFocus: false,
queryFn: async () => { queryFn: async () => {
if (isNew !== 'true' && isConversationIdExist(conversationId)) { if (
const { data } = await chatService.getConversation({ conversationId }); isNew !== 'true' &&
isConversationIdExist(sharedId || conversationId)
) {
const { data } = await chatService.getConversation({
conversationId: conversationId || sharedId,
});
const conversation = data?.data ?? {}; const conversation = data?.data ?? {};
@ -270,6 +278,33 @@ export const useFetchNextConversation = () => {
return { data, loading, refetch }; return { data, loading, refetch };
}; };
export const useFetchNextConversationSSE = () => {
const { isNew } = useGetChatSearchParams();
const { sharedId } = useGetSharedChatSearchParams();
const {
data,
isFetching: loading,
refetch,
} = useQuery<IClientConversation>({
queryKey: ['fetchConversationSSE', sharedId],
initialData: {} as IClientConversation,
gcTime: 0,
refetchOnWindowFocus: false,
queryFn: async () => {
if (isNew !== 'true' && isConversationIdExist(sharedId || '')) {
if (!sharedId) return {};
const { data } = await chatService.getConversationSSE({}, sharedId);
const conversation = data?.data ?? {};
const messageList = buildMessageListWithUuid(conversation?.message);
return { ...conversation, message: messageList };
}
return { message: [] };
},
});
return { data, loading, refetch };
};
export const useFetchManualConversation = () => { export const useFetchManualConversation = () => {
const { const {
data, data,
@ -547,7 +582,7 @@ export const useFetchMindMap = () => {
try { try {
const ret = await chatService.getMindMap(params); const ret = await chatService.getMindMap(params);
return ret?.data?.data ?? {}; return ret?.data?.data ?? {};
} catch (error) { } catch (error: any) {
if (has(error, 'message')) { if (has(error, 'message')) {
message.error(error.message); message.error(error.message);
} }

View File

@ -2,6 +2,7 @@ import { ResponseType } from '@/interfaces/database/base';
import { DSL, IFlow, IFlowTemplate } from '@/interfaces/database/flow'; import { DSL, IFlow, IFlowTemplate } from '@/interfaces/database/flow';
import { IDebugSingleRequestBody } from '@/interfaces/request/flow'; import { IDebugSingleRequestBody } from '@/interfaces/request/flow';
import i18n from '@/locales/config'; import i18n from '@/locales/config';
import { useGetSharedChatSearchParams } from '@/pages/chat/shared-hooks';
import flowService from '@/services/flow-service'; import flowService from '@/services/flow-service';
import { buildMessageListWithUuid } from '@/utils/chat'; import { buildMessageListWithUuid } from '@/utils/chat';
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
@ -91,6 +92,8 @@ export const useFetchFlow = (): {
refetch: () => void; refetch: () => void;
} => { } => {
const { id } = useParams(); const { id } = useParams();
const { sharedId } = useGetSharedChatSearchParams();
const { const {
data, data,
isFetching: loading, isFetching: loading,
@ -103,7 +106,41 @@ export const useFetchFlow = (): {
refetchOnWindowFocus: false, refetchOnWindowFocus: false,
gcTime: 0, gcTime: 0,
queryFn: async () => { queryFn: async () => {
const { data } = await flowService.getCanvas({}, id); const { data } = await flowService.getCanvas({}, sharedId || id);
const messageList = buildMessageListWithUuid(
get(data, 'data.dsl.messages', []),
);
set(data, 'data.dsl.messages', messageList);
return data?.data ?? {};
},
});
return { data, loading, refetch };
};
export const useFetchFlowSSE = (): {
data: IFlow;
loading: boolean;
refetch: () => void;
} => {
const { sharedId } = useGetSharedChatSearchParams();
const {
data,
isFetching: loading,
refetch,
} = useQuery({
queryKey: ['flowDetailSSE'],
initialData: {} as IFlow,
refetchOnReconnect: false,
refetchOnMount: false,
refetchOnWindowFocus: false,
gcTime: 0,
queryFn: async () => {
if (!sharedId) return {};
const { data } = await flowService.getCanvasSSE({}, sharedId);
const messageList = buildMessageListWithUuid( const messageList = buildMessageListWithUuid(
get(data, 'data.dsl.messages', []), get(data, 'data.dsl.messages', []),

View File

@ -57,6 +57,7 @@ export interface IConversation {
create_time: number; create_time: number;
dialog_id: string; dialog_id: string;
id: string; id: string;
avatar: string;
message: Message[]; message: Message[];
reference: IReference[]; reference: IReference[];
name: string; name: string;

View File

@ -29,8 +29,8 @@ export interface IGraph {
edges: Edge[]; edges: Edge[];
} }
export interface IFlow { export declare interface IFlow {
avatar: null; avatar?: null | string;
canvas_type: null; canvas_type: null;
create_date: string; create_date: string;
create_time: number; create_time: number;

View File

@ -68,6 +68,7 @@ const ChatContainer = ({ controller }: IProps) => {
item={message} item={message}
nickname={userInfo.nickname} nickname={userInfo.nickname}
avatar={userInfo.avatar} avatar={userInfo.avatar}
avatardialog={conversation.avatar}
reference={buildMessageItemReference( reference={buildMessageItemReference(
{ {
message: derivedMessages, message: derivedMessages,

View File

@ -4,7 +4,7 @@ import { useClickDrawer } from '@/components/pdf-drawer/hooks';
import { MessageType } from '@/constants/chat'; import { MessageType } from '@/constants/chat';
import { useSendButtonDisabled } from '@/pages/chat/hooks'; import { useSendButtonDisabled } from '@/pages/chat/hooks';
import { Flex, Spin } from 'antd'; import { Flex, Spin } from 'antd';
import { forwardRef } from 'react'; import { forwardRef, useMemo } from 'react';
import { import {
useGetSharedChatSearchParams, useGetSharedChatSearchParams,
useSendSharedMessage, useSendSharedMessage,
@ -12,6 +12,8 @@ import {
import { buildMessageItemReference } from '../utils'; import { buildMessageItemReference } from '../utils';
import PdfDrawer from '@/components/pdf-drawer'; import PdfDrawer from '@/components/pdf-drawer';
import { useFetchNextConversationSSE } from '@/hooks/chat-hooks';
import { useFetchFlowSSE } from '@/hooks/flow-hooks';
import styles from './index.less'; import styles from './index.less';
const ChatContainer = () => { const ChatContainer = () => {
@ -30,6 +32,14 @@ const ChatContainer = () => {
hasError, hasError,
} = useSendSharedMessage(); } = useSendSharedMessage();
const sendDisabled = useSendButtonDisabled(value); const sendDisabled = useSendButtonDisabled(value);
const useData = (from: SharedFrom) =>
useMemo(() => {
return from === SharedFrom.Agent
? useFetchFlowSSE
: useFetchNextConversationSSE;
}, [from]);
const { data: InforForm } = useData(from)();
if (!conversationId) { if (!conversationId) {
return <div>empty</div>; return <div>empty</div>;
@ -45,6 +55,7 @@ const ChatContainer = () => {
return ( return (
<MessageItem <MessageItem
key={message.id} key={message.id}
avatardialog={InforForm?.avatar}
item={message} item={message}
nickname="You" nickname="You"
reference={buildMessageItemReference( reference={buildMessageItemReference(

View File

@ -9,6 +9,7 @@ import { useSendNextMessage } from './hooks';
import PdfDrawer from '@/components/pdf-drawer'; import PdfDrawer from '@/components/pdf-drawer';
import { useClickDrawer } from '@/components/pdf-drawer/hooks'; import { useClickDrawer } from '@/components/pdf-drawer/hooks';
import { useFetchFlow } from '@/hooks/flow-hooks';
import { useFetchUserInfo } from '@/hooks/user-setting-hooks'; import { useFetchUserInfo } from '@/hooks/user-setting-hooks';
import styles from './index.less'; import styles from './index.less';
@ -29,6 +30,7 @@ const FlowChatBox = () => {
useGetFileIcon(); useGetFileIcon();
const { t } = useTranslate('chat'); const { t } = useTranslate('chat');
const { data: userInfo } = useFetchUserInfo(); const { data: userInfo } = useFetchUserInfo();
const { data: cavasInfo } = useFetchFlow();
return ( return (
<> <>
@ -47,6 +49,7 @@ const FlowChatBox = () => {
key={message.id} key={message.id}
nickname={userInfo.nickname} nickname={userInfo.nickname}
avatar={userInfo.avatar} avatar={userInfo.avatar}
avatardialog={cavasInfo.avatar}
item={message} item={message}
reference={buildMessageItemReference( reference={buildMessageItemReference(
{ message: derivedMessages, reference }, { message: derivedMessages, reference },

View File

@ -89,9 +89,12 @@ const KnowledgeList = () => {
className={styles.knowledgeCardContainer} className={styles.knowledgeCardContainer}
> >
{nextList?.length > 0 ? ( {nextList?.length > 0 ? (
nextList.map((item: any) => { nextList.map((item: any, index: number) => {
return ( return (
<KnowledgeCard item={item} key={item.name}></KnowledgeCard> <KnowledgeCard
item={item}
key={`${item?.name}-${index}`}
></KnowledgeCard>
); );
}) })
) : ( ) : (

View File

@ -8,6 +8,7 @@ const {
listDialog, listDialog,
removeDialog, removeDialog,
getConversation, getConversation,
getConversationSSE,
setConversation, setConversation,
completeConversation, completeConversation,
listConversation, listConversation,
@ -53,6 +54,10 @@ const methods = {
url: getConversation, url: getConversation,
method: 'get', method: 'get',
}, },
getConversationSSE: {
url: getConversationSSE,
method: 'get',
},
setConversation: { setConversation: {
url: setConversation, url: setConversation,
method: 'post', method: 'post',

View File

@ -4,6 +4,7 @@ import request from '@/utils/request';
const { const {
getCanvas, getCanvas,
getCanvasSSE,
setCanvas, setCanvas,
listCanvas, listCanvas,
resetCanvas, resetCanvas,
@ -20,6 +21,10 @@ const methods = {
url: getCanvas, url: getCanvas,
method: 'get', method: 'get',
}, },
getCanvasSSE: {
url: getCanvasSSE,
method: 'get',
},
setCanvas: { setCanvas: {
url: setCanvas, url: setCanvas,
method: 'post', method: 'post',

View File

@ -71,6 +71,7 @@ export default {
listDialog: `${api_host}/dialog/list`, listDialog: `${api_host}/dialog/list`,
setConversation: `${api_host}/conversation/set`, setConversation: `${api_host}/conversation/set`,
getConversation: `${api_host}/conversation/get`, getConversation: `${api_host}/conversation/get`,
getConversationSSE: `${api_host}/conversation/getsse`,
listConversation: `${api_host}/conversation/list`, listConversation: `${api_host}/conversation/list`,
removeConversation: `${api_host}/conversation/rm`, removeConversation: `${api_host}/conversation/rm`,
completeConversation: `${api_host}/conversation/completion`, completeConversation: `${api_host}/conversation/completion`,
@ -113,6 +114,7 @@ export default {
listTemplates: `${api_host}/canvas/templates`, listTemplates: `${api_host}/canvas/templates`,
listCanvas: `${api_host}/canvas/list`, listCanvas: `${api_host}/canvas/list`,
getCanvas: `${api_host}/canvas/get`, getCanvas: `${api_host}/canvas/get`,
getCanvasSSE: `${api_host}/canvas/getsse`,
removeCanvas: `${api_host}/canvas/rm`, removeCanvas: `${api_host}/canvas/rm`,
setCanvas: `${api_host}/canvas/set`, setCanvas: `${api_host}/canvas/set`,
resetCanvas: `${api_host}/canvas/reset`, resetCanvas: `${api_host}/canvas/reset`,