From 0944ca9d91b72f380073e0be634fc6fce7ebb89c Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Tue, 2 Jul 2024 17:57:42 +0800 Subject: [PATCH] Fix/remove tsne position test (#5858) Co-authored-by: StyleZhang --- api/services/hit_testing_service.py | 31 +---------- .../datasets/hit-testing/hit-detail.tsx | 55 +------------------ .../components/datasets/hit-testing/index.tsx | 10 +--- 3 files changed, 7 insertions(+), 89 deletions(-) diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 8ff96c7337..0378370d88 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -4,10 +4,6 @@ import time import numpy as np from sklearn.manifold import TSNE -from core.embedding.cached_embedding import CacheEmbedding -from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.retrieval_service import RetrievalService from core.rag.models.document import Document from core.rag.retrieval.retrival_methods import RetrievalMethod @@ -45,17 +41,6 @@ class HitTestingService: if not retrieval_model: retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model - # get embedding model - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - model_type=ModelType.TEXT_EMBEDDING, - provider=dataset.embedding_model_provider, - model=dataset.embedding_model - ) - - embeddings = CacheEmbedding(embedding_model) - all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], dataset_id=dataset.id, query=query, @@ -80,20 +65,10 @@ class HitTestingService: db.session.add(dataset_query) db.session.commit() - return cls.compact_retrieve_response(dataset, embeddings, query, all_documents) + return cls.compact_retrieve_response(dataset, query, all_documents) @classmethod - def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: list[Document]): - text_embeddings = [ - embeddings.embed_query(query) - ] - - text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents])) - - tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings) - - query_position = tsne_position_data.pop(0) - + def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]): i = 0 records = [] for document in documents: @@ -113,7 +88,6 @@ class HitTestingService: record = { "segment": segment, "score": document.metadata.get('score', None), - "tsne_position": tsne_position_data[i] } records.append(record) @@ -123,7 +97,6 @@ class HitTestingService: return { "query": { "content": query, - "tsne_position": query_position, }, "records": records } diff --git a/web/app/components/datasets/hit-testing/hit-detail.tsx b/web/app/components/datasets/hit-testing/hit-detail.tsx index 806662aeb5..5af022202b 100644 --- a/web/app/components/datasets/hit-testing/hit-detail.tsx +++ b/web/app/components/datasets/hit-testing/hit-detail.tsx @@ -2,51 +2,16 @@ import type { FC } from 'react' import React from 'react' import cn from 'classnames' import { useTranslation } from 'react-i18next' -import ReactECharts from 'echarts-for-react' import { SegmentIndexTag } from '../documents/detail/completed' import s from '../documents/detail/completed/style.module.css' import type { SegmentDetailModel } from '@/models/datasets' import Divider from '@/app/components/base/divider' -type IScatterChartProps = { - data: Array - curr: Array -} - -const ScatterChart: FC = ({ data, curr }) => { - const option = { - xAxis: {}, - yAxis: {}, - tooltip: { - trigger: 'item', - axisPointer: { - type: 'cross', - }, - }, - series: [ - { - type: 'effectScatter', - symbolSize: 5, - data: curr, - }, - { - type: 'scatter', - symbolSize: 5, - data, - }, - ], - } - return ( - - ) -} - type IHitDetailProps = { segInfo?: Partial & { id: string } - vectorInfo?: { curr: Array; points: Array } } -const HitDetail: FC = ({ segInfo, vectorInfo }) => { +const HitDetail: FC = ({ segInfo }) => { const { t } = useTranslation() const renderContent = () => { @@ -65,8 +30,8 @@ const HitDetail: FC = ({ segInfo, vectorInfo }) => { } return ( -
-
+
+
= ({ segInfo, vectorInfo }) => { })}
-
-
-
- - {t('datasetDocuments.segment.vectorHash')} - -
-
- {segInfo?.index_node_hash} -
- -
) } diff --git a/web/app/components/datasets/hit-testing/index.tsx b/web/app/components/datasets/hit-testing/index.tsx index bff13314b0..8c665b1889 100644 --- a/web/app/components/datasets/hit-testing/index.tsx +++ b/web/app/components/datasets/hit-testing/index.tsx @@ -1,6 +1,6 @@ 'use client' import type { FC } from 'react' -import React, { useEffect, useMemo, useState } from 'react' +import React, { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import useSWR from 'swr' import { omit } from 'lodash-es' @@ -62,8 +62,6 @@ const HitTesting: FC = ({ datasetId }: Props) => { const total = recordsRes?.total || 0 - const points = useMemo(() => (hitResult?.records.map(v => [v.tsne_position.x, v.tsne_position.y]) || []), [hitResult?.records]) - const onClickCard = (detail: HitTestingType) => { setCurrParagraph({ paraInfo: detail, showModal: true }) } @@ -194,17 +192,13 @@ const HitTesting: FC = ({ datasetId }: Props) => {
setCurrParagraph({ showModal: false })} isShow={currParagraph.showModal} > {currParagraph.showModal && } setIsShowModifyRetrievalModal(false)} footer={null} mask={isMobile} panelClassname='mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[640px] rounded-xl'>