From cf93d8d6e2407571108c23f4085a9ec34dc3c6af Mon Sep 17 00:00:00 2001 From: KVOJJJin Date: Fri, 28 Jul 2023 20:47:15 +0800 Subject: [PATCH] Feat: Q&A format segmentation support (#668) Co-authored-by: jyong <718720800@qq.com> Co-authored-by: StyleZhang --- api/controllers/console/datasets/datasets.py | 5 +- .../console/datasets/datasets_document.py | 12 +- .../console/datasets/datasets_segments.py | 93 ++++++- .../console/datasets/hit_testing.py | 1 + api/core/data_loader/loader/excel.py | 2 +- api/core/docstore/dataset_docstore.py | 7 +- api/core/generator/llm_generator.py | 33 ++- .../keyword_table_index.py | 10 + api/core/index/vector_index/test-embedding.py | 123 +++++++++ api/core/indexing_runner.py | 165 +++++++++++- api/core/prompt/prompts.py | 10 + api/core/tool/dataset_index_tool.py | 102 +++++++ api/core/tool/dataset_retriever_tool.py | 19 +- .../8d2d099ceb74_add_qa_model_support.py | 42 +++ api/models/dataset.py | 8 +- api/services/completion_service.py | 1 + api/services/dataset_service.py | 103 +++++++- api/tasks/create_segment_to_index_task.py | 102 +++++++ ...ask.py => enable_segment_to_index_task.py} | 12 +- api/tasks/generate_test_task.py | 24 ++ api/tasks/update_segment_index_task.py | 114 ++++++++ .../update_segment_keyword_index_task.py | 81 ++++++ .../base/auto-height-textarea/common.tsx | 52 ++++ .../public/common/message-chat-square.svg | 4 + .../assets/vender/line/general/edit-03.svg | 5 + .../assets/vender/line/general/hash-02.svg | 5 + .../src/public/common/MessageChatSquare.json | 37 +++ .../src/public/common/MessageChatSquare.tsx | 14 + .../base/icons/src/public/common/index.ts | 1 + .../icons/src/vender/line/general/Edit03.json | 39 +++ .../icons/src/vender/line/general/Edit03.tsx | 14 + .../icons/src/vender/line/general/Hash02.json | 38 +++ .../icons/src/vender/line/general/Hash02.tsx | 14 + .../icons/src/vender/line/general/index.ts | 2 + .../datasets/create/step-two/index.module.css | 2 +- .../datasets/create/step-two/index.tsx | 132 +++++++-- .../create/step-two/preview-item/index.tsx | 41 ++- .../detail/completed/InfiniteVirtualList.tsx | 58 ++-- .../detail/completed/SegmentCard.tsx | 250 ++++++++++-------- .../documents/detail/completed/index.tsx | 209 ++++++++++++--- .../detail/completed/style.module.css | 3 + .../documents/detail/embedding/index.tsx | 4 +- .../datasets/documents/detail/index.tsx | 15 +- .../documents/detail/new-segment-modal.tsx | 140 ++++++++++ .../components/datasets/documents/list.tsx | 15 +- .../datasets/hit-testing/hit-detail.tsx | 78 +++--- web/i18n/lang/dataset-creation.en.ts | 5 + web/i18n/lang/dataset-creation.zh.ts | 5 + web/i18n/lang/dataset-documents.en.ts | 9 + web/i18n/lang/dataset-documents.zh.ts | 9 + web/models/datasets.ts | 14 + web/service/datasets.ts | 29 +- 52 files changed, 2038 insertions(+), 274 deletions(-) create mode 100644 api/core/index/vector_index/test-embedding.py create mode 100644 api/core/tool/dataset_index_tool.py create mode 100644 api/migrations/versions/8d2d099ceb74_add_qa_model_support.py create mode 100644 api/tasks/create_segment_to_index_task.py rename api/tasks/{add_segment_to_index_task.py => enable_segment_to_index_task.py} (84%) create mode 100644 api/tasks/generate_test_task.py create mode 100644 api/tasks/update_segment_index_task.py create mode 100644 api/tasks/update_segment_keyword_index_task.py create mode 100644 web/app/components/base/auto-height-textarea/common.tsx create mode 100644 web/app/components/base/icons/assets/public/common/message-chat-square.svg create mode 100644 
web/app/components/base/icons/assets/vender/line/general/edit-03.svg create mode 100644 web/app/components/base/icons/assets/vender/line/general/hash-02.svg create mode 100644 web/app/components/base/icons/src/public/common/MessageChatSquare.json create mode 100644 web/app/components/base/icons/src/public/common/MessageChatSquare.tsx create mode 100644 web/app/components/base/icons/src/vender/line/general/Edit03.json create mode 100644 web/app/components/base/icons/src/vender/line/general/Edit03.tsx create mode 100644 web/app/components/base/icons/src/vender/line/general/Hash02.json create mode 100644 web/app/components/base/icons/src/vender/line/general/Hash02.tsx create mode 100644 web/app/components/datasets/documents/detail/new-segment-modal.tsx diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 4c880e5243..1881103447 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -220,6 +220,7 @@ class DatasetIndexingEstimateApi(Resource): parser = reqparse.RequestParser() parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json') parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') + parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') args = parser.parse_args() # validate args DocumentService.estimate_args_validate(args) @@ -234,12 +235,12 @@ class DatasetIndexingEstimateApi(Resource): raise NotFound("File not found.") indexing_runner = IndexingRunner() - response = indexing_runner.file_indexing_estimate(file_details, args['process_rule']) + response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'], args['doc_form']) elif args['info_list']['data_source_type'] == 'notion_import': indexing_runner = IndexingRunner() response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'], - args['process_rule']) + args['process_rule'], args['doc_form']) else: raise ValueError('Data source type not support') return response, 200 diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index e165d2130a..02ddfbf467 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -60,6 +60,7 @@ document_fields = { 'display_status': fields.String, 'word_count': fields.Integer, 'hit_count': fields.Integer, + 'doc_form': fields.String, } document_with_segments_fields = { @@ -86,6 +87,7 @@ document_with_segments_fields = { 'total_segments': fields.Integer } + class DocumentResource(Resource): def get_document(self, dataset_id: str, document_id: str) -> Document: dataset = DatasetService.get_dataset(dataset_id) @@ -269,6 +271,7 @@ class DatasetDocumentListApi(Resource): parser.add_argument('process_rule', type=dict, required=False, location='json') parser.add_argument('duplicate', type=bool, nullable=False, location='json') parser.add_argument('original_document_id', type=str, required=False, location='json') + parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') args = parser.parse_args() if not dataset.indexing_technique and not args['indexing_technique']: @@ -313,6 +316,7 @@ class DatasetInitApi(Resource): nullable=False, location='json') parser.add_argument('data_source', type=dict, required=True, nullable=True, 
location='json') parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') + parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') args = parser.parse_args() # validate args @@ -488,6 +492,8 @@ class DocumentBatchIndexingStatusApi(DocumentResource): DocumentSegment.status != 're_segment').count() document.completed_segments = completed_segments document.total_segments = total_segments + if document.is_paused: + document.indexing_status = 'paused' documents_status.append(marshal(document, self.document_status_fields)) data = { 'data': documents_status @@ -583,7 +589,8 @@ class DocumentDetailApi(DocumentResource): 'segment_count': document.segment_count, 'average_segment_length': document.average_segment_length, 'hit_count': document.hit_count, - 'display_status': document.display_status + 'display_status': document.display_status, + 'doc_form': document.doc_form } else: process_rules = DatasetService.get_process_rules(dataset_id) @@ -614,7 +621,8 @@ class DocumentDetailApi(DocumentResource): 'segment_count': document.segment_count, 'average_segment_length': document.average_segment_length, 'hit_count': document.hit_count, - 'display_status': document.display_status + 'display_status': document.display_status, + 'doc_form': document.doc_form } return response, 200 diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index ffa01cbb48..57abe97a0b 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -15,8 +15,8 @@ from extensions.ext_redis import redis_client from models.dataset import DocumentSegment from libs.helper import TimestampField -from services.dataset_service import DatasetService, DocumentService -from tasks.add_segment_to_index_task import add_segment_to_index_task +from services.dataset_service import DatasetService, DocumentService, SegmentService +from tasks.enable_segment_to_index_task import enable_segment_to_index_task from tasks.remove_segment_from_index_task import remove_segment_from_index_task segment_fields = { @@ -24,6 +24,7 @@ segment_fields = { 'position': fields.Integer, 'document_id': fields.String, 'content': fields.String, + 'answer': fields.String, 'word_count': fields.Integer, 'tokens': fields.Integer, 'keywords': fields.List(fields.String), @@ -125,6 +126,7 @@ class DatasetDocumentSegmentListApi(Resource): return { 'data': marshal(segments, segment_fields), + 'doc_form': document.doc_form, 'has_more': has_more, 'limit': limit, 'total': total @@ -180,7 +182,7 @@ class DatasetDocumentSegmentApi(Resource): # Set cache to prevent indexing the same segment multiple times redis_client.setex(indexing_cache_key, 600, 1) - add_segment_to_index_task.delay(segment.id) + enable_segment_to_index_task.delay(segment.id) return {'result': 'success'}, 200 elif action == "disable": @@ -202,7 +204,92 @@ class DatasetDocumentSegmentApi(Resource): raise InvalidActionError() +class DatasetDocumentSegmentAddApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, dataset_id, document_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound('Dataset not found.') + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise 
NotFound('Document not found.') + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + # validate args + parser = reqparse.RequestParser() + parser.add_argument('content', type=str, required=True, nullable=False, location='json') + parser.add_argument('answer', type=str, required=False, nullable=True, location='json') + parser.add_argument('keywords', type=list, required=False, nullable=True, location='json') + args = parser.parse_args() + SegmentService.segment_create_args_validate(args, document) + segment = SegmentService.create_segment(args, document) + return { + 'data': marshal(segment, segment_fields), + 'doc_form': document.doc_form + }, 200 + + +class DatasetDocumentSegmentUpdateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def patch(self, dataset_id, document_id, segment_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound('Dataset not found.') + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound('Document not found.') + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), + DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound('Segment not found.') + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + # validate args + parser = reqparse.RequestParser() + parser.add_argument('content', type=str, required=True, nullable=False, location='json') + parser.add_argument('answer', type=str, required=False, nullable=True, location='json') + parser.add_argument('keywords', type=list, required=False, nullable=True, location='json') + args = parser.parse_args() + SegmentService.segment_create_args_validate(args, document) + segment = SegmentService.update_segment(args, segment, document) + return { + 'data': marshal(segment, segment_fields), + 'doc_form': document.doc_form + }, 200 + + api.add_resource(DatasetDocumentSegmentListApi, '/datasets//documents//segments') api.add_resource(DatasetDocumentSegmentApi, '/datasets//segments//') +api.add_resource(DatasetDocumentSegmentAddApi, + '/datasets//documents//segment') +api.add_resource(DatasetDocumentSegmentUpdateApi, + '/datasets//documents//segments/') + diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index c1ccb30dd6..f627949d33 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -28,6 +28,7 @@ segment_fields = { 'position': fields.Integer, 'document_id': fields.String, 'content': fields.String, + 'answer': fields.String, 'word_count': fields.Integer, 'tokens': fields.Integer, 'keywords': fields.List(fields.String), diff --git a/api/core/data_loader/loader/excel.py b/api/core/data_loader/loader/excel.py index 202068432a..0d4ff02242 
100644 --- a/api/core/data_loader/loader/excel.py +++ b/api/core/data_loader/loader/excel.py @@ -39,7 +39,7 @@ class ExcelLoader(BaseLoader): row_dict = dict(zip(keys, list(map(str, row)))) row_dict = {k: v for k, v in row_dict.items() if v} item = ''.join(f'{k}:{v}\n' for k, v in row_dict.items()) - document = Document(page_content=item) + document = Document(page_content=item, metadata={'source': self._file_path}) data.append(document) return data diff --git a/api/core/docstore/dataset_docstore.py b/api/core/docstore/dataset_docstore.py index b8af3bf01b..016e711378 100644 --- a/api/core/docstore/dataset_docstore.py +++ b/api/core/docstore/dataset_docstore.py @@ -68,7 +68,7 @@ class DatesetDocumentStore: self, docs: Sequence[Document], allow_update: bool = True ) -> None: max_position = db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document == self._document_id + DocumentSegment.document_id == self._document_id ).scalar() if max_position is None: @@ -105,9 +105,14 @@ class DatesetDocumentStore: tokens=tokens, created_by=self._user_id, ) + if 'answer' in doc.metadata and doc.metadata['answer']: + segment_document.answer = doc.metadata.pop('answer', '') + db.session.add(segment_document) else: segment_document.content = doc.page_content + if 'answer' in doc.metadata and doc.metadata['answer']: + segment_document.answer = doc.metadata.pop('answer', '') segment_document.index_node_hash = doc.metadata['doc_hash'] segment_document.word_count = len(doc.page_content) segment_document.tokens = tokens diff --git a/api/core/generator/llm_generator.py b/api/core/generator/llm_generator.py index 9c6dbdfc8f..a5294add23 100644 --- a/api/core/generator/llm_generator.py +++ b/api/core/generator/llm_generator.py @@ -2,7 +2,7 @@ import logging from langchain import PromptTemplate from langchain.chat_models.base import BaseChatModel -from langchain.schema import HumanMessage, OutputParserException, BaseMessage +from langchain.schema import HumanMessage, OutputParserException, BaseMessage, SystemMessage from core.constant import llm_constant from core.llm.llm_builder import LLMBuilder @@ -12,8 +12,8 @@ from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorO from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate -from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT - +from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \ + GENERATOR_QA_PROMPT # gpt-3.5-turbo works not well generate_base_model = 'text-davinci-003' @@ -31,7 +31,8 @@ class LLMGenerator: llm: StreamableOpenAI = LLMBuilder.to_llm( tenant_id=tenant_id, model_name='gpt-3.5-turbo', - max_tokens=50 + max_tokens=50, + timeout=600 ) if isinstance(llm, BaseChatModel): @@ -185,3 +186,27 @@ class LLMGenerator: } return rule_config + + @classmethod + async def generate_qa_document(cls, llm: StreamableOpenAI, query): + prompt = GENERATOR_QA_PROMPT + + + if isinstance(llm, BaseChatModel): + prompt = [SystemMessage(content=prompt), HumanMessage(content=query)] + + response = llm.generate([prompt]) + answer = response.generations[0][0].text + return answer.strip() + + @classmethod + def generate_qa_document_sync(cls, llm: StreamableOpenAI, query): + prompt = GENERATOR_QA_PROMPT + + + if isinstance(llm, BaseChatModel): + prompt = 
[SystemMessage(content=prompt), HumanMessage(content=query)] + + response = llm.generate([prompt]) + answer = response.generations[0][0].text + return answer.strip() diff --git a/api/core/index/keyword_table_index/keyword_table_index.py b/api/core/index/keyword_table_index/keyword_table_index.py index 1a205cd572..34ee7c8ff7 100644 --- a/api/core/index/keyword_table_index/keyword_table_index.py +++ b/api/core/index/keyword_table_index/keyword_table_index.py @@ -205,6 +205,16 @@ class KeywordTableIndex(BaseIndex): document_segment.keywords = keywords db.session.commit() + def create_segment_keywords(self, node_id: str, keywords: List[str]): + keyword_table = self._get_dataset_keyword_table() + self._update_segment_keywords(node_id, keywords) + keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) + self._save_dataset_keyword_table(keyword_table) + + def update_segment_keywords_index(self, node_id: str, keywords: List[str]): + keyword_table = self._get_dataset_keyword_table() + keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) + self._save_dataset_keyword_table(keyword_table) class KeywordTableRetriever(BaseRetriever, BaseModel): index: KeywordTableIndex diff --git a/api/core/index/vector_index/test-embedding.py b/api/core/index/vector_index/test-embedding.py new file mode 100644 index 0000000000..dc23d48f5d --- /dev/null +++ b/api/core/index/vector_index/test-embedding.py @@ -0,0 +1,123 @@ +import numpy as np +import sklearn.decomposition +import pickle +import time + + +# Apply 'Algorithm 1' to the ada-002 embeddings to make them isotropic, taken from the paper: +# ALL-BUT-THE-TOP: SIMPLE AND EFFECTIVE POST-PROCESSING FOR WORD REPRESENTATIONS +# Jiaqi Mu, Pramod Viswanath + +# This uses Principal Component Analysis (PCA) to 'evenly distribute' the embedding vectors (make them isotropic) +# For more information on PCA, see https://jamesmccaffrey.wordpress.com/2021/07/16/computing-pca-using-numpy-without-scikit/ + + +# get the file pointer of the pickle containing the embeddings +fp = open('/path/to/your/data/Embedding-Latest.pkl', 'rb') + + +# the embedding data here is a dict consisting of key / value pairs +# the key is the hash of the message (SHA3-256), the value is the embedding from ada-002 (array of dimension 1536) +# the hash can be used to look up the original text in a database +E = pickle.load(fp) # load the data into memory + +# separate the keys (hashes) and values (embeddings) into separate vectors +K = list(E.keys()) # vector of all the hash values +X = np.array(list(E.values())) # vector of all the embeddings, converted to numpy arrays + + +# list the total number of embeddings +# this can be truncated if there are too many embeddings to do PCA on +print(f"Total number of embeddings: {len(X)}") + +# get dimension of embeddings, used later +Dim = len(X[0]) + +# print out the first few embeddings +print("First two embeddings are: ") +print(X[0]) +print(f"First embedding length: {len(X[0])}") +print(X[1]) +print(f"Second embedding length: {len(X[1])}") + + +# compute the mean of all the embeddings, and print the result +mu = np.mean(X, axis=0) # same as mu in paper +print(f"Mean embedding vector: {mu}") +print(f"Mean embedding vector length: {len(mu)}") + + +# subtract the mean vector from each embedding vector ...
vectorized in numpy +X_tilde = X - mu # same as v_tilde(w) in paper + + + +# do the heavy lifting of extracting the principal components +# note that this is a function of the embeddings you currently have here, and this set may grow over time +# therefore the PCA basis vectors may change over time, and your final isotropic embeddings may drift over time +# but the drift should stabilize after you have extracted enough embedding data to characterize the nature of the embedding engine +print(f"Performing PCA on the normalized embeddings ...") +pca = sklearn.decomposition.PCA() # new object +TICK = time.time() # start timer +pca.fit(X_tilde) # do the heavy lifting! +TOCK = time.time() # end timer +DELTA = TOCK - TICK + +print(f"PCA finished in {DELTA} seconds ...") + +# dimensional reduction stage (the only hyperparameter) +# pick max dimension of PCA components to express embeddings +# in general this is some integer less than or equal to the dimension of your embeddings +# it could be set as a high percentile, say 95th percentile of pca.explained_variance_ratio_ +# but just hardcoding a constant here +D = 15 # hyperparameter on dimension (out of 1536 for ada-002), paper recommends D = Dim/100 + + +# form the set of v_prime(w), which is the final embedding +# this could be vectorized in numpy to speed it up, but coding it directly here in a double for-loop to avoid errors and to be transparent +E_prime = dict() # output dict of the new embeddings +N = len(X_tilde) +N10 = round(N/10) +U = pca.components_ # set of PCA basis vectors, sorted by most significant to least significant +print(f"Shape of full set of PCA components {U.shape}") +U = U[0:D,:] # take the top D dimensions (or take them all if D is the size of the embedding vector) +print(f"Shape of downselected PCA components {U.shape}") +for ii in range(N): + v_tilde = X_tilde[ii] + v = X[ii] + v_projection = np.zeros(Dim) # start to build the projection + # project the original embedding onto the PCA basis vectors, use only first D dimensions + for jj in range(D): + u_jj = U[jj,:] # vector + v_jj = np.dot(u_jj,v) # scalar + v_projection += v_jj*u_jj # vector + v_prime = v_tilde - v_projection # final embedding vector + v_prime = v_prime/np.linalg.norm(v_prime) # create unit vector + E_prime[K[ii]] = v_prime + + if (ii%N10 == 0) or (ii == N-1): + print(f"Finished with {ii+1} embeddings out of {N} ({round(100*ii/N)}% done)") + + +# save as new pickle +print("Saving new pickle ...") +embeddingName = '/path/to/your/data/Embedding-Latest-Isotropic.pkl' +with open(embeddingName, 'wb') as f: # Python 3: open(..., 'wb') + pickle.dump([E_prime,mu,U], f) + print(embeddingName) + +print("Done!") + +# When working with live data with a new embedding from ada-002, be sure to transform it first with this function before comparing it +# +def projectEmbedding(v,mu,U): + v = np.array(v) + v_tilde = v - mu + v_projection = np.zeros(len(v)) # start to build the projection + # project the original embedding onto the PCA basis vectors, use only first D dimensions + for u in U: + v_jj = np.dot(u,v) # scalar + v_projection += v_jj*u # vector + v_prime = v_tilde - v_projection # final embedding vector + v_prime = v_prime/np.linalg.norm(v_prime) # create unit vector + return v_prime \ No newline at end of file diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 30ad5a99df..9c6b723e14 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -1,13 +1,20 @@ +import asyncio +import concurrent import datetime import
json import logging import re +import threading import time import uuid +from multiprocessing import Process from typing import Optional, List, cast -from flask import current_app +import openai +from billiard.pool import Pool +from flask import current_app, Flask from flask_login import current_user +from gevent.threadpool import ThreadPoolExecutor from langchain.embeddings import OpenAIEmbeddings from langchain.schema import Document from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter @@ -16,11 +23,13 @@ from core.data_loader.file_extractor import FileExtractor from core.data_loader.loader.notion import NotionLoader from core.docstore.dataset_docstore import DatesetDocumentStore from core.embedding.cached_embedding import CacheEmbedding +from core.generator.llm_generator import LLMGenerator from core.index.index import IndexBuilder from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig from core.index.vector_index.vector_index import VectorIndex from core.llm.error import ProviderTokenNotInitError from core.llm.llm_builder import LLMBuilder +from core.llm.streamable_open_ai import StreamableOpenAI from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter from core.llm.token_calculator import TokenCalculator from extensions.ext_database import db @@ -70,7 +79,13 @@ class IndexingRunner: dataset_document=dataset_document, processing_rule=processing_rule ) - + # new_documents = [] + # for document in documents: + # response = LLMGenerator.generate_qa_document(dataset.tenant_id, document.page_content) + # document_qa_list = self.format_split_text(response) + # for result in document_qa_list: + # document = Document(page_content=result['question'], metadata={'source': result['answer']}) + # new_documents.append(document) # build index self._build_index( dataset=dataset, @@ -91,6 +106,22 @@ class IndexingRunner: dataset_document.stopped_at = datetime.datetime.utcnow() db.session.commit() + def format_split_text(self, text): + regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)" + matches = re.findall(regex, text, re.MULTILINE) + + result = [] + for match in matches: + q = match[0] + a = match[1] + if q and a: + result.append({ + "question": q, + "answer": re.sub(r"\n\s*", "\n", a.strip()) + }) + + return result + def run_in_splitting_status(self, dataset_document: DatasetDocument): """Run the indexing process when the index_status is splitting.""" try: @@ -205,7 +236,8 @@ class IndexingRunner: dataset_document.stopped_at = datetime.datetime.utcnow() db.session.commit() - def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict: + def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict, + doc_form: str = None) -> dict: """ Estimate the indexing for the document. 
""" @@ -225,7 +257,7 @@ class IndexingRunner: splitter = self._get_splitter(processing_rule) # split to documents - documents = self._split_to_documents( + documents = self._split_to_documents_for_estimate( text_docs=text_docs, splitter=splitter, processing_rule=processing_rule @@ -237,7 +269,25 @@ class IndexingRunner: tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, self.filter_string(document.page_content)) - + if doc_form and doc_form == 'qa_model': + if len(preview_texts) > 0: + # qa model document + llm: StreamableOpenAI = LLMBuilder.to_llm( + tenant_id=current_user.current_tenant_id, + model_name='gpt-3.5-turbo', + max_tokens=2000 + ) + response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0]) + document_qa_list = self.format_split_text(response) + return { + "total_segments": total_segments * 20, + "tokens": total_segments * 2000, + "total_price": '{:f}'.format( + TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')), + "currency": TokenCalculator.get_currency(self.embedding_model_name), + "qa_preview": document_qa_list, + "preview": preview_texts + } return { "total_segments": total_segments, "tokens": tokens, @@ -246,7 +296,7 @@ class IndexingRunner: "preview": preview_texts } - def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict) -> dict: + def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict: """ Estimate the indexing for the document. """ @@ -285,7 +335,7 @@ class IndexingRunner: splitter = self._get_splitter(processing_rule) # split to documents - documents = self._split_to_documents( + documents = self._split_to_documents_for_estimate( text_docs=documents, splitter=splitter, processing_rule=processing_rule @@ -296,7 +346,25 @@ class IndexingRunner: preview_texts.append(document.page_content) tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content) - + if doc_form and doc_form == 'qa_model': + if len(preview_texts) > 0: + # qa model document + llm: StreamableOpenAI = LLMBuilder.to_llm( + tenant_id=current_user.current_tenant_id, + model_name='gpt-3.5-turbo', + max_tokens=2000 + ) + response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0]) + document_qa_list = self.format_split_text(response) + return { + "total_segments": total_segments * 20, + "tokens": total_segments * 2000, + "total_price": '{:f}'.format( + TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')), + "currency": TokenCalculator.get_currency(self.embedding_model_name), + "qa_preview": document_qa_list, + "preview": preview_texts + } return { "total_segments": total_segments, "tokens": tokens, @@ -391,7 +459,9 @@ class IndexingRunner: documents = self._split_to_documents( text_docs=text_docs, splitter=splitter, - processing_rule=processing_rule + processing_rule=processing_rule, + tenant_id=dataset.tenant_id, + document_form=dataset_document.doc_form ) # save node to document segment @@ -428,7 +498,64 @@ class IndexingRunner: return documents def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter, - processing_rule: DatasetProcessRule) -> List[Document]: + processing_rule: DatasetProcessRule, tenant_id: str, document_form: str) -> List[Document]: + """ + Split the text documents into nodes. 
+ """ + all_documents = [] + for text_doc in text_docs: + # document clean + document_text = self._document_clean(text_doc.page_content, processing_rule) + text_doc.page_content = document_text + + # parse document to nodes + documents = splitter.split_documents([text_doc]) + split_documents = [] + llm: StreamableOpenAI = LLMBuilder.to_llm( + tenant_id=tenant_id, + model_name='gpt-3.5-turbo', + max_tokens=2000 + ) + self.format_document(llm, documents, split_documents, document_form) + all_documents.extend(split_documents) + + return all_documents + + def format_document(self, llm: StreamableOpenAI, documents: List[Document], split_documents: List, document_form: str): + for document_node in documents: + format_documents = [] + if document_node.page_content is None or not document_node.page_content.strip(): + return format_documents + if document_form == 'text_model': + # text model document + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(document_node.page_content) + + document_node.metadata['doc_id'] = doc_id + document_node.metadata['doc_hash'] = hash + + format_documents.append(document_node) + elif document_form == 'qa_model': + try: + # qa model document + response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content) + document_qa_list = self.format_split_text(response) + qa_documents = [] + for result in document_qa_list: + qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy()) + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(result['question']) + qa_document.metadata['answer'] = result['answer'] + qa_document.metadata['doc_id'] = doc_id + qa_document.metadata['doc_hash'] = hash + qa_documents.append(qa_document) + format_documents.extend(qa_documents) + except Exception: + continue + split_documents.extend(format_documents) + + def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter, + processing_rule: DatasetProcessRule) -> List[Document]: """ Split the text documents into nodes. """ @@ -445,7 +572,6 @@ class IndexingRunner: for document in documents: if document.page_content is None or not document.page_content.strip(): continue - doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document.page_content) @@ -487,6 +613,23 @@ class IndexingRunner: return text + def format_split_text(self, text): + regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)" # regex to match Q and A pairs + matches = re.findall(regex, text, re.MULTILINE) # get all matches + + result = [] # store the final results + for match in matches: + q = match[0] + a = match[1] + if q and a: + # only add to the results if both Q and A are present + result.append({ + "question": q, + "answer": re.sub(r"\n\s*", "\n", a.strip()) + }) + + return result + def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None: """ Build the index for the document. diff --git a/api/core/prompt/prompts.py b/api/core/prompt/prompts.py index fcfad5465e..781e8164cb 100644 --- a/api/core/prompt/prompts.py +++ b/api/core/prompt/prompts.py @@ -43,6 +43,16 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( "[\"question1\",\"question2\",\"question3\"]\n" ) +GENERATOR_QA_PROMPT = ( + "Please respond according to the language of the user's input text.
If the text is in language [A], you must also reply in language [A].\n" + 'Step 1: Understand and summarize the main content of this text.\n' + 'Step 2: What key information or concepts are mentioned in this text?\n' + 'Step 3: Decompose or combine multiple pieces of information and concepts.\n' + 'Step 4: Generate 20 questions and answers based on these key information and concepts.' + 'The questions should be clear and detailed, and the answers should be detailed and complete.\n' + "Answer in the following format: Q1:\nA1:\nQ2:\nA2:...\n" +) + RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \ the model prompt that best suits the input. You will be provided with the prompt, variables, and an opening statement. diff --git a/api/core/tool/dataset_index_tool.py b/api/core/tool/dataset_index_tool.py new file mode 100644 index 0000000000..c459ebaf13 --- /dev/null +++ b/api/core/tool/dataset_index_tool.py @@ -0,0 +1,102 @@ +from flask import current_app +from langchain.embeddings import OpenAIEmbeddings +from langchain.tools import BaseTool + +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.embedding.cached_embedding import CacheEmbedding +from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig +from core.index.vector_index.vector_index import VectorIndex +from core.llm.llm_builder import LLMBuilder +from models.dataset import Dataset, DocumentSegment + + +class DatasetTool(BaseTool): + """Tool for querying a Dataset.""" + + dataset: Dataset + k: int = 2 + + def _run(self, tool_input: str) -> str: + if self.dataset.indexing_technique == "economy": + # use keyword table query + kw_table_index = KeywordTableIndex( + dataset=self.dataset, + config=KeywordTableConfig( + max_keywords_per_chunk=5 + ) + ) + + documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k}) + return str("\n".join([document.page_content for document in documents])) + else: + model_credentials = LLMBuilder.get_model_credentials( + tenant_id=self.dataset.tenant_id, + model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'), + model_name='text-embedding-ada-002' + ) + + embeddings = CacheEmbedding(OpenAIEmbeddings( + **model_credentials + )) + + vector_index = VectorIndex( + dataset=self.dataset, + config=current_app.config, + embeddings=embeddings + ) + + documents = vector_index.search( + tool_input, + search_type='similarity', + search_kwargs={ + 'k': self.k + } + ) + + hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id) + hit_callback.on_tool_end(documents) + document_context_list = [] + index_node_ids = [document.metadata['doc_id'] for document in documents] + segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == 'completed', + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids) + ).all() + + if segments: + for segment in segments: + if segment.answer: + document_context_list.append(segment.answer) + else: + document_context_list.append(segment.content) + + return str("\n".join(document_context_list)) + + async def _arun(self, tool_input: str) -> str: + model_credentials = LLMBuilder.get_model_credentials( + tenant_id=self.dataset.tenant_id, + model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'), + model_name='text-embedding-ada-002' + ) + + embeddings = 
CacheEmbedding(OpenAIEmbeddings( + **model_credentials + )) + + vector_index = VectorIndex( + dataset=self.dataset, + config=current_app.config, + embeddings=embeddings + ) + + documents = await vector_index.asearch( + tool_input, + search_type='similarity', + search_kwargs={ + 'k': 10 + } + ) + + hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id) + hit_callback.on_tool_end(documents) + return str("\n".join([document.page_content for document in documents])) diff --git a/api/core/tool/dataset_retriever_tool.py b/api/core/tool/dataset_retriever_tool.py index d728528267..45c5a9f226 100644 --- a/api/core/tool/dataset_retriever_tool.py +++ b/api/core/tool/dataset_retriever_tool.py @@ -12,7 +12,7 @@ from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex from core.index.vector_index.vector_index import VectorIndex from core.llm.llm_builder import LLMBuilder from extensions.ext_database import db -from models.dataset import Dataset +from models.dataset import Dataset, DocumentSegment class DatasetRetrieverToolInput(BaseModel): @@ -69,6 +69,7 @@ class DatasetRetrieverTool(BaseTool): ) documents = kw_table_index.search(query, search_kwargs={'k': self.k}) + return str("\n".join([document.page_content for document in documents])) else: model_credentials = LLMBuilder.get_model_credentials( tenant_id=dataset.tenant_id, @@ -99,8 +100,22 @@ class DatasetRetrieverTool(BaseTool): hit_callback = DatasetIndexToolCallbackHandler(dataset.id) hit_callback.on_tool_end(documents) + document_context_list = [] + index_node_ids = [document.metadata['doc_id'] for document in documents] + segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == 'completed', + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids) + ).all() - return str("\n".join([document.page_content for document in documents])) + if segments: + for segment in segments: + if segment.answer: + document_context_list.append(f'question:{segment.content} \nanswer:{segment.answer}') + else: + document_context_list.append(segment.content) + + return str("\n".join(document_context_list)) async def _arun(self, tool_input: str) -> str: raise NotImplementedError() diff --git a/api/migrations/versions/8d2d099ceb74_add_qa_model_support.py b/api/migrations/versions/8d2d099ceb74_add_qa_model_support.py new file mode 100644 index 0000000000..e0915a5fb1 --- /dev/null +++ b/api/migrations/versions/8d2d099ceb74_add_qa_model_support.py @@ -0,0 +1,42 @@ +"""add_qa_model_support + +Revision ID: 8d2d099ceb74 +Revises: a5b56fb053ef +Create Date: 2023-07-18 15:25:15.293438 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '8d2d099ceb74' +down_revision = '7ce5a52e4eee' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('document_segments', schema=None) as batch_op: + batch_op.add_column(sa.Column('answer', sa.Text(), nullable=True)) + batch_op.add_column(sa.Column('updated_by', postgresql.UUID(), nullable=True)) + batch_op.add_column(sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False)) + + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.add_column(sa.Column('doc_form', sa.String(length=255), server_default=sa.text("'text_model'::character varying"), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.drop_column('doc_form') + + with op.batch_alter_table('document_segments', schema=None) as batch_op: + batch_op.drop_column('updated_at') + batch_op.drop_column('updated_by') + batch_op.drop_column('answer') + + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index 345eea5f47..b63b898df4 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -206,6 +206,8 @@ class Document(db.Model): server_default=db.text('CURRENT_TIMESTAMP(0)')) doc_type = db.Column(db.String(40), nullable=True) doc_metadata = db.Column(db.JSON, nullable=True) + doc_form = db.Column(db.String( + 255), nullable=False, server_default=db.text("'text_model'::character varying")) DATA_SOURCES = ['upload_file', 'notion_import'] @@ -308,6 +310,7 @@ class DocumentSegment(db.Model): document_id = db.Column(UUID, nullable=False) position = db.Column(db.Integer, nullable=False) content = db.Column(db.Text, nullable=False) + answer = db.Column(db.Text, nullable=True) word_count = db.Column(db.Integer, nullable=False) tokens = db.Column(db.Integer, nullable=False) @@ -327,6 +330,9 @@ class DocumentSegment(db.Model): created_by = db.Column(UUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_by = db.Column(UUID, nullable=True) + updated_at = db.Column(db.DateTime, nullable=False, + server_default=db.text('CURRENT_TIMESTAMP(0)')) indexing_at = db.Column(db.DateTime, nullable=True) completed_at = db.Column(db.DateTime, nullable=True) error = db.Column(db.Text, nullable=True) @@ -442,4 +448,4 @@ class Embedding(db.Model): self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) def get_embedding(self) -> list[float]: - return pickle.loads(self.embedding) + return pickle.loads(self.embedding) \ No newline at end of file diff --git a/api/services/completion_service.py b/api/services/completion_service.py index 6b8ee520eb..c081d8ec08 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -201,6 +201,7 @@ class CompletionService: conversation = db.session.query(Conversation).filter_by(id=conversation.id).first() # run + Completion.generate( task_id=generate_task_id, app=app_model, diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 182adc6c78..c0a649544a 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -3,16 +3,20 @@ import logging import datetime import time import random +import uuid from typing import Optional, List from flask import current_app +from sqlalchemy import func +from core.llm.token_calculator import TokenCalculator from extensions.ext_redis import redis_client from flask_login import current_user from 
events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted from extensions.ext_database import db +from libs import helper from models.account import Account from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment from models.model import UploadFile @@ -25,6 +29,10 @@ from tasks.clean_notion_document_task import clean_notion_document_task from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task from tasks.document_indexing_task import document_indexing_task from tasks.document_indexing_update_task import document_indexing_update_task +from tasks.create_segment_to_index_task import create_segment_to_index_task +from tasks.update_segment_index_task import update_segment_index_task +from tasks.update_segment_keyword_index_task\ + import update_segment_keyword_index_task class DatasetService: @@ -308,6 +316,7 @@ class DocumentService: ).all() return documents + @staticmethod def get_document_file_detail(file_id: str): file_detail = db.session.query(UploadFile). \ @@ -440,6 +449,7 @@ class DocumentService: } document = DocumentService.save_document(dataset, dataset_process_rule.id, document_data["data_source"]["type"], + document_data["doc_form"], data_source_info, created_from, position, account, file_name, batch) db.session.add(document) @@ -484,6 +494,7 @@ class DocumentService: } document = DocumentService.save_document(dataset, dataset_process_rule.id, document_data["data_source"]["type"], + document_data["doc_form"], data_source_info, created_from, position, account, page['page_name'], batch) # if page['type'] == 'database': @@ -514,8 +525,9 @@ class DocumentService: return documents, batch @staticmethod - def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, data_source_info: dict, - created_from: str, position: int, account: Account, name: str, batch: str): + def save_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str, + data_source_info: dict, created_from: str, position: int, account: Account, name: str, + batch: str): document = Document( tenant_id=dataset.tenant_id, dataset_id=dataset.id, @@ -527,6 +539,7 @@ class DocumentService: name=name, created_from=created_from, created_by=account.id, + doc_form=document_form ) return document @@ -618,6 +631,7 @@ class DocumentService: document.splitting_completed_at = None document.updated_at = datetime.datetime.utcnow() document.created_from = created_from + document.doc_form = document_data['doc_form'] db.session.add(document) db.session.commit() # update document segment @@ -667,7 +681,7 @@ class DocumentService: DocumentService.data_source_args_validate(args) DocumentService.process_rule_args_validate(args) else: - if ('data_source' not in args and not args['data_source'])\ + if ('data_source' not in args and not args['data_source']) \ and ('process_rule' not in args and not args['process_rule']): raise ValueError("Data source or Process rule is required") else: @@ -694,10 +708,12 @@ class DocumentService: raise ValueError("Data source info is required") if args['data_source']['type'] == 'upload_file': - if 'file_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['file_info_list']: + if 'file_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][ + 'file_info_list']: raise ValueError("File source info is required") if args['data_source']['type'] == 'notion_import': - if 
'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list']['notion_info_list']: + if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][ + 'notion_info_list']: raise ValueError("Notion source info is required") @classmethod @@ -843,3 +859,80 @@ class DocumentService: if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int): raise ValueError("Process rule segmentation max_tokens is invalid") + + +class SegmentService: + @classmethod + def segment_create_args_validate(cls, args: dict, document: Document): + if document.doc_form == 'qa_model': + if 'answer' not in args or not args['answer']: + raise ValueError("Answer is required") + + @classmethod + def create_segment(cls, args: dict, document: Document): + content = args['content'] + doc_id = str(uuid.uuid4()) + segment_hash = helper.generate_text_hash(content) + # calc embedding use tokens + tokens = TokenCalculator.get_num_tokens('text-embedding-ada-002', content) + max_position = db.session.query(func.max(DocumentSegment.position)).filter( + DocumentSegment.document_id == document.id + ).scalar() + segment_document = DocumentSegment( + tenant_id=current_user.current_tenant_id, + dataset_id=document.dataset_id, + document_id=document.id, + index_node_id=doc_id, + index_node_hash=segment_hash, + position=max_position + 1 if max_position else 1, + content=content, + word_count=len(content), + tokens=tokens, + created_by=current_user.id + ) + if document.doc_form == 'qa_model': + segment_document.answer = args['answer'] + + db.session.add(segment_document) + db.session.commit() + indexing_cache_key = 'segment_{}_indexing'.format(segment_document.id) + redis_client.setex(indexing_cache_key, 600, 1) + create_segment_to_index_task.delay(segment_document.id, args['keywords']) + return segment_document + + @classmethod + def update_segment(cls, args: dict, segment: DocumentSegment, document: Document): + indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + raise ValueError("Segment is indexing, please try again later") + content = args['content'] + if segment.content == content: + if document.doc_form == 'qa_model': + segment.answer = args['answer'] + if args['keywords']: + segment.keywords = args['keywords'] + db.session.add(segment) + db.session.commit() + # update segment index task + redis_client.setex(indexing_cache_key, 600, 1) + update_segment_keyword_index_task.delay(segment.id) + else: + segment_hash = helper.generate_text_hash(content) + # calc embedding use tokens + tokens = TokenCalculator.get_num_tokens('text-embedding-ada-002', content) + segment.content = content + segment.index_node_hash = segment_hash + segment.word_count = len(content) + segment.tokens = tokens + segment.status = 'updating' + segment.updated_by = current_user.id + segment.updated_at = datetime.datetime.utcnow() + if document.doc_form == 'qa_model': + segment.answer = args['answer'] + db.session.add(segment) + db.session.commit() + # update segment index task + redis_client.setex(indexing_cache_key, 600, 1) + update_segment_index_task.delay(segment.id, args['keywords']) + return segment diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py new file mode 100644 index 0000000000..0f7bf6ebaa --- /dev/null +++ b/api/tasks/create_segment_to_index_task.py @@ -0,0 +1,102 @@ +import datetime +import logging +import time 
+from typing import Optional, List + +import click +from celery import shared_task +from langchain.schema import Document +from werkzeug.exceptions import NotFound + +from core.index.index import IndexBuilder +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import DocumentSegment + + +@shared_task +def create_segment_to_index_task(segment_id: str, keywords: Optional[List[str]] = None): + """ + Async create segment to index + :param segment_id: + :param keywords: + Usage: create_segment_to_index_task.delay(segment_id) + """ + logging.info(click.style('Start create segment to index: {}'.format(segment_id), fg='green')) + start_at = time.perf_counter() + + segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() + if not segment: + raise NotFound('Segment not found') + + if segment.status != 'waiting': + return + + indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + + try: + # update segment status to indexing + update_params = { + DocumentSegment.status: "indexing", + DocumentSegment.indexing_at: datetime.datetime.utcnow() + } + DocumentSegment.query.filter_by(id=segment.id).update(update_params) + db.session.commit() + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + } + ) + + dataset = segment.dataset + + if not dataset: + logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) + return + + dataset_document = segment.document + + if not dataset_document: + logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan')) + return + + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': + logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) + return + + # save vector index + index = IndexBuilder.get_index(dataset, 'high_quality') + if index: + index.add_texts([document], duplicate_check=True) + + # save keyword index + index = IndexBuilder.get_index(dataset, 'economy') + if index: + if keywords and len(keywords) > 0: + index.create_segment_keywords(segment.index_node_id, keywords) + else: + index.add_texts([document]) + + # update segment to completed + update_params = { + DocumentSegment.status: "completed", + DocumentSegment.completed_at: datetime.datetime.utcnow() + } + DocumentSegment.query.filter_by(id=segment.id).update(update_params) + db.session.commit() + + end_at = time.perf_counter() + logging.info(click.style('Segment created to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) + except Exception as e: + logging.exception("create segment to index failed") + segment.enabled = False + segment.disabled_at = datetime.datetime.utcnow() + segment.status = 'error' + segment.error = str(e) + db.session.commit() + finally: + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/add_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py similarity index 84% rename from api/tasks/add_segment_to_index_task.py rename to api/tasks/enable_segment_to_index_task.py index bf96a0dc0b..5553420196 100644 --- a/api/tasks/add_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -14,14 +14,14 @@ from models.dataset import DocumentSegment @shared_task -def add_segment_to_index_task(segment_id: 
str): +def enable_segment_to_index_task(segment_id: str): """ - Async Add segment to index + Async enable segment to index :param segment_id: - Usage: add_segment_to_index.delay(segment_id) + Usage: enable_segment_to_index_task.delay(segment_id) """ - logging.info(click.style('Start add segment to index: {}'.format(segment_id), fg='green')) + logging.info(click.style('Start enable segment to index: {}'.format(segment_id), fg='green')) start_at = time.perf_counter() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() @@ -71,9 +71,9 @@ def add_segment_to_index_task(segment_id: str): index.add_texts([document]) end_at = time.perf_counter() - logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) + logging.info(click.style('Segment enabled to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) except Exception as e: - logging.exception("add segment to index failed") + logging.exception("enable segment to index failed") segment.enabled = False segment.disabled_at = datetime.datetime.utcnow() segment.status = 'error' diff --git a/api/tasks/generate_test_task.py b/api/tasks/generate_test_task.py new file mode 100644 index 0000000000..b2506ab0be --- /dev/null +++ b/api/tasks/generate_test_task.py @@ -0,0 +1,24 @@ +import logging +import time + +import click +import requests +from celery import shared_task + +from core.generator.llm_generator import LLMGenerator + + +@shared_task +def generate_test_task(): + logging.info(click.style('Start generate test', fg='green')) + start_at = time.perf_counter() + + try: + #res = requests.post('https://api.openai.com/v1/chat/completions') + answer = LLMGenerator.generate_conversation_name('84b2202c-c359-46b7-a810-bce50feaa4d1', 'avb', 'ccc') + print(f'answer: {answer}') + + end_at = time.perf_counter() + logging.info(click.style('Conversation test, latency: {}'.format(end_at - start_at), fg='green')) + except Exception: + logging.exception("generate test failed") diff --git a/api/tasks/update_segment_index_task.py b/api/tasks/update_segment_index_task.py new file mode 100644 index 0000000000..cf36793919 --- /dev/null +++ b/api/tasks/update_segment_index_task.py @@ -0,0 +1,114 @@ +import datetime +import logging +import time +from typing import List, Optional + +import click +from celery import shared_task +from langchain.schema import Document +from werkzeug.exceptions import NotFound + +from core.index.index import IndexBuilder +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import DocumentSegment + + +@shared_task +def update_segment_index_task(segment_id: str, keywords: Optional[List[str]] = None): + """ + Async update segment index + :param segment_id: + :param keywords: + Usage: update_segment_index_task.delay(segment_id) + """ + logging.info(click.style('Start update segment index: {}'.format(segment_id), fg='green')) + start_at = time.perf_counter() + + segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() + if not segment: + raise NotFound('Segment not found') + + if segment.status != 'updating': + return + + indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + + try: + dataset = segment.dataset + + if not dataset: + logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) + return + + dataset_document = segment.document + + if not dataset_document: + logging.info(click.style('Segment 
{} has no document, pass.'.format(segment.id), fg='cyan')) + return + + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': + logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) + return + + # update segment status to indexing + update_params = { + DocumentSegment.status: "indexing", + DocumentSegment.indexing_at: datetime.datetime.utcnow() + } + DocumentSegment.query.filter_by(id=segment.id).update(update_params) + db.session.commit() + + vector_index = IndexBuilder.get_index(dataset, 'high_quality') + kw_index = IndexBuilder.get_index(dataset, 'economy') + + # delete from vector index + if vector_index: + vector_index.delete_by_ids([segment.index_node_id]) + + # delete from keyword index + kw_index.delete_by_ids([segment.index_node_id]) + + # add new index + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + } + ) + + # save vector index + index = IndexBuilder.get_index(dataset, 'high_quality') + if index: + index.add_texts([document], duplicate_check=True) + + # save keyword index + index = IndexBuilder.get_index(dataset, 'economy') + if index: + if keywords and len(keywords) > 0: + index.create_segment_keywords(segment.index_node_id, keywords) + else: + index.add_texts([document]) + + # update segment to completed + update_params = { + DocumentSegment.status: "completed", + DocumentSegment.completed_at: datetime.datetime.utcnow() + } + DocumentSegment.query.filter_by(id=segment.id).update(update_params) + db.session.commit() + + end_at = time.perf_counter() + logging.info(click.style('Segment update index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) + except Exception as e: + logging.exception("update segment index failed") + segment.enabled = False + segment.disabled_at = datetime.datetime.utcnow() + segment.status = 'error' + segment.error = str(e) + db.session.commit() + finally: + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/update_segment_keyword_index_task.py b/api/tasks/update_segment_keyword_index_task.py new file mode 100644 index 0000000000..831076e708 --- /dev/null +++ b/api/tasks/update_segment_keyword_index_task.py @@ -0,0 +1,81 @@ +import datetime +import logging +import time +from typing import List, Optional + +import click +from celery import shared_task +from langchain.schema import Document +from werkzeug.exceptions import NotFound + +from core.index.index import IndexBuilder +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import DocumentSegment + + +@shared_task +def update_segment_keyword_index_task(segment_id: str): + """ + Async update segment index + :param segment_id: + Usage: update_segment_keyword_index_task.delay(segment_id) + """ + logging.info(click.style('Start update segment keyword index: {}'.format(segment_id), fg='green')) + start_at = time.perf_counter() + + segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() + if not segment: + raise NotFound('Segment not found') + + indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + + try: + dataset = segment.dataset + + if not dataset: + logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) + return + + dataset_document = segment.document + + if not 
dataset_document:
+            logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
+            return
+
+        if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
+            logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
+            return
+
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
+
+        # delete from keyword index
+        kw_index.delete_by_ids([segment.index_node_id])
+
+        # add new index
+        document = Document(
+            page_content=segment.content,
+            metadata={
+                "doc_id": segment.index_node_id,
+                "doc_hash": segment.index_node_hash,
+                "document_id": segment.document_id,
+                "dataset_id": segment.dataset_id,
+            }
+        )
+
+        # save keyword index
+        index = IndexBuilder.get_index(dataset, 'economy')
+        if index:
+            index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
+
+        end_at = time.perf_counter()
+        logging.info(click.style('Segment update index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
+    except Exception as e:
+        logging.exception("update segment index failed")
+        segment.enabled = False
+        segment.disabled_at = datetime.datetime.utcnow()
+        segment.status = 'error'
+        segment.error = str(e)
+        db.session.commit()
+    finally:
+        redis_client.delete(indexing_cache_key)
diff --git a/web/app/components/base/auto-height-textarea/common.tsx b/web/app/components/base/auto-height-textarea/common.tsx
new file mode 100644
index 0000000000..0c500974d0
--- /dev/null
+++ b/web/app/components/base/auto-height-textarea/common.tsx
@@ -0,0 +1,52 @@
+import { forwardRef, useEffect, useRef } from 'react'
+import cn from 'classnames'
+
+type AutoHeightTextareaProps =
+  & React.DetailedHTMLProps<React.TextareaHTMLAttributes<HTMLTextAreaElement>, HTMLTextAreaElement>
+  & { outerClassName?: string }
+
+const AutoHeightTextarea = forwardRef<HTMLTextAreaElement, AutoHeightTextareaProps>(
+  (
+    {
+      outerClassName,
+      value,
+      className,
+      placeholder,
+      autoFocus,
+      disabled,
+      ...rest
+    },
+    outRef,
+  ) => {
+    const innerRef = useRef<HTMLTextAreaElement>(null)
+    const ref = outRef || innerRef
+
+    useEffect(() => {
+      if (autoFocus && !disabled && value) {
+        if (typeof ref !== 'function') {
+          ref.current?.setSelectionRange(`${value}`.length, `${value}`.length)
+          ref.current?.focus()
+        }
+      }
+    }, [autoFocus, disabled, ref])
+    return (
+      <div className={outerClassName}>
+        <div className={cn(className, 'invisible whitespace-pre-wrap break-all')}>
+          {!value ? placeholder : `${value}`.replace(/\n$/, '\n ')}
+        </div>
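+        {/*
+          The invisible div above mirrors the current value (or the placeholder),
+          so the wrapper grows to match the rendered text; the textarea that follows
+          fills the wrapper, which is what makes the field auto-size as the user types.
+        */}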