mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 20:59:00 +08:00
Fix/remove tsne position test (#5858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
parent
d468f8b75c
commit
0944ca9d91
@ -4,10 +4,6 @@ import time
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.manifold import TSNE
|
from sklearn.manifold import TSNE
|
||||||
|
|
||||||
from core.embedding.cached_embedding import CacheEmbedding
|
|
||||||
from core.model_manager import ModelManager
|
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
|
||||||
from core.rag.datasource.entity.embedding import Embeddings
|
|
||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from core.rag.retrieval.retrival_methods import RetrievalMethod
|
from core.rag.retrieval.retrival_methods import RetrievalMethod
|
||||||
@ -45,17 +41,6 @@ class HitTestingService:
|
|||||||
if not retrieval_model:
|
if not retrieval_model:
|
||||||
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||||
|
|
||||||
# get embedding model
|
|
||||||
model_manager = ModelManager()
|
|
||||||
embedding_model = model_manager.get_model_instance(
|
|
||||||
tenant_id=dataset.tenant_id,
|
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
|
||||||
provider=dataset.embedding_model_provider,
|
|
||||||
model=dataset.embedding_model
|
|
||||||
)
|
|
||||||
|
|
||||||
embeddings = CacheEmbedding(embedding_model)
|
|
||||||
|
|
||||||
all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
|
all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
|
||||||
dataset_id=dataset.id,
|
dataset_id=dataset.id,
|
||||||
query=query,
|
query=query,
|
||||||
@ -80,20 +65,10 @@ class HitTestingService:
|
|||||||
db.session.add(dataset_query)
|
db.session.add(dataset_query)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return cls.compact_retrieve_response(dataset, embeddings, query, all_documents)
|
return cls.compact_retrieve_response(dataset, query, all_documents)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: list[Document]):
|
def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]):
|
||||||
text_embeddings = [
|
|
||||||
embeddings.embed_query(query)
|
|
||||||
]
|
|
||||||
|
|
||||||
text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents]))
|
|
||||||
|
|
||||||
tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings)
|
|
||||||
|
|
||||||
query_position = tsne_position_data.pop(0)
|
|
||||||
|
|
||||||
i = 0
|
i = 0
|
||||||
records = []
|
records = []
|
||||||
for document in documents:
|
for document in documents:
|
||||||
@ -113,7 +88,6 @@ class HitTestingService:
|
|||||||
record = {
|
record = {
|
||||||
"segment": segment,
|
"segment": segment,
|
||||||
"score": document.metadata.get('score', None),
|
"score": document.metadata.get('score', None),
|
||||||
"tsne_position": tsne_position_data[i]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
records.append(record)
|
records.append(record)
|
||||||
@ -123,7 +97,6 @@ class HitTestingService:
|
|||||||
return {
|
return {
|
||||||
"query": {
|
"query": {
|
||||||
"content": query,
|
"content": query,
|
||||||
"tsne_position": query_position,
|
|
||||||
},
|
},
|
||||||
"records": records
|
"records": records
|
||||||
}
|
}
|
||||||
|
@ -2,51 +2,16 @@ import type { FC } from 'react'
|
|||||||
import React from 'react'
|
import React from 'react'
|
||||||
import cn from 'classnames'
|
import cn from 'classnames'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import ReactECharts from 'echarts-for-react'
|
|
||||||
import { SegmentIndexTag } from '../documents/detail/completed'
|
import { SegmentIndexTag } from '../documents/detail/completed'
|
||||||
import s from '../documents/detail/completed/style.module.css'
|
import s from '../documents/detail/completed/style.module.css'
|
||||||
import type { SegmentDetailModel } from '@/models/datasets'
|
import type { SegmentDetailModel } from '@/models/datasets'
|
||||||
import Divider from '@/app/components/base/divider'
|
import Divider from '@/app/components/base/divider'
|
||||||
|
|
||||||
type IScatterChartProps = {
|
|
||||||
data: Array<number[]>
|
|
||||||
curr: Array<number[]>
|
|
||||||
}
|
|
||||||
|
|
||||||
const ScatterChart: FC<IScatterChartProps> = ({ data, curr }) => {
|
|
||||||
const option = {
|
|
||||||
xAxis: {},
|
|
||||||
yAxis: {},
|
|
||||||
tooltip: {
|
|
||||||
trigger: 'item',
|
|
||||||
axisPointer: {
|
|
||||||
type: 'cross',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
series: [
|
|
||||||
{
|
|
||||||
type: 'effectScatter',
|
|
||||||
symbolSize: 5,
|
|
||||||
data: curr,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
type: 'scatter',
|
|
||||||
symbolSize: 5,
|
|
||||||
data,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
return (
|
|
||||||
<ReactECharts option={option} style={{ height: 380, width: 430 }} />
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
type IHitDetailProps = {
|
type IHitDetailProps = {
|
||||||
segInfo?: Partial<SegmentDetailModel> & { id: string }
|
segInfo?: Partial<SegmentDetailModel> & { id: string }
|
||||||
vectorInfo?: { curr: Array<number[]>; points: Array<number[]> }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const HitDetail: FC<IHitDetailProps> = ({ segInfo, vectorInfo }) => {
|
const HitDetail: FC<IHitDetailProps> = ({ segInfo }) => {
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
|
|
||||||
const renderContent = () => {
|
const renderContent = () => {
|
||||||
@ -65,8 +30,8 @@ const HitDetail: FC<IHitDetailProps> = ({ segInfo, vectorInfo }) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className='flex flex-row overflow-x-auto'>
|
<div className='overflow-x-auto'>
|
||||||
<div className="flex-1 bg-gray-25 p-6 min-w-[300px]">
|
<div className="bg-gray-25 p-6">
|
||||||
<div className="flex items-center">
|
<div className="flex items-center">
|
||||||
<SegmentIndexTag
|
<SegmentIndexTag
|
||||||
positionId={segInfo?.position || ''}
|
positionId={segInfo?.position || ''}
|
||||||
@ -94,20 +59,6 @@ const HitDetail: FC<IHitDetailProps> = ({ segInfo, vectorInfo }) => {
|
|||||||
})}
|
})}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div className="flex-1 bg-white p-6">
|
|
||||||
<div className="flex items-center">
|
|
||||||
<div className={cn(s.commonIcon, s.bezierCurveIcon)} />
|
|
||||||
<span className={s.numberInfo}>
|
|
||||||
{t('datasetDocuments.segment.vectorHash')}
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
<div
|
|
||||||
className={cn(s.numberInfo, 'w-[400px] truncate text-gray-700 mt-1')}
|
|
||||||
>
|
|
||||||
{segInfo?.index_node_hash}
|
|
||||||
</div>
|
|
||||||
<ScatterChart data={vectorInfo?.points || []} curr={vectorInfo?.curr || []} />
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
'use client'
|
'use client'
|
||||||
import type { FC } from 'react'
|
import type { FC } from 'react'
|
||||||
import React, { useEffect, useMemo, useState } from 'react'
|
import React, { useEffect, useState } from 'react'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import useSWR from 'swr'
|
import useSWR from 'swr'
|
||||||
import { omit } from 'lodash-es'
|
import { omit } from 'lodash-es'
|
||||||
@ -62,8 +62,6 @@ const HitTesting: FC<Props> = ({ datasetId }: Props) => {
|
|||||||
|
|
||||||
const total = recordsRes?.total || 0
|
const total = recordsRes?.total || 0
|
||||||
|
|
||||||
const points = useMemo(() => (hitResult?.records.map(v => [v.tsne_position.x, v.tsne_position.y]) || []), [hitResult?.records])
|
|
||||||
|
|
||||||
const onClickCard = (detail: HitTestingType) => {
|
const onClickCard = (detail: HitTestingType) => {
|
||||||
setCurrParagraph({ paraInfo: detail, showModal: true })
|
setCurrParagraph({ paraInfo: detail, showModal: true })
|
||||||
}
|
}
|
||||||
@ -194,17 +192,13 @@ const HitTesting: FC<Props> = ({ datasetId }: Props) => {
|
|||||||
</div>
|
</div>
|
||||||
</FloatRightContainer>
|
</FloatRightContainer>
|
||||||
<Modal
|
<Modal
|
||||||
className='!max-w-[960px] !p-0'
|
className='w-[520px] p-0'
|
||||||
closable
|
closable
|
||||||
onClose={() => setCurrParagraph({ showModal: false })}
|
onClose={() => setCurrParagraph({ showModal: false })}
|
||||||
isShow={currParagraph.showModal}
|
isShow={currParagraph.showModal}
|
||||||
>
|
>
|
||||||
{currParagraph.showModal && <HitDetail
|
{currParagraph.showModal && <HitDetail
|
||||||
segInfo={currParagraph.paraInfo?.segment}
|
segInfo={currParagraph.paraInfo?.segment}
|
||||||
vectorInfo={{
|
|
||||||
curr: [[currParagraph.paraInfo?.tsne_position?.x || 0, currParagraph.paraInfo?.tsne_position.y || 0]],
|
|
||||||
points,
|
|
||||||
}}
|
|
||||||
/>}
|
/>}
|
||||||
</Modal>
|
</Modal>
|
||||||
<Drawer isOpen={isShowModifyRetrievalModal} onClose={() => setIsShowModifyRetrievalModal(false)} footer={null} mask={isMobile} panelClassname='mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[640px] rounded-xl'>
|
<Drawer isOpen={isShowModifyRetrievalModal} onClose={() => setIsShowModifyRetrievalModal(false)} footer={null} mask={isMobile} panelClassname='mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[640px] rounded-xl'>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user