remove unused codes, seperate layout detection out as a new api. Add new rag methed 'table' (#55)

This commit is contained in:
KevinHuSh 2024-02-05 18:08:17 +08:00 committed by GitHub
parent f305776217
commit 407b2523b6
33 changed files with 306 additions and 505 deletions

View File

@ -28,8 +28,6 @@ from api.utils import CustomJSONEncoder
from flask_session import Session
from flask_login import LoginManager
from api.settings import RetCode, SECRET_KEY, stat_logger
from api.hook import HookManager
from api.hook.common.parameters import AuthenticationParameters, ClientAuthenticationParameters
from api.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger
from api.utils.api_utils import get_json_result, server_error_response
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
@ -96,37 +94,7 @@ client_urls_prefix = [
]
def client_authentication_before_request():
result = HookManager.client_authentication(ClientAuthenticationParameters(
request.full_path, request.headers,
request.form, request.data, request.json,
))
if result.code != RetCode.SUCCESS:
return get_json_result(result.code, result.message)
def site_authentication_before_request():
for url_prefix in client_urls_prefix:
if request.path.startswith(url_prefix):
return
result = HookManager.site_authentication(AuthenticationParameters(
request.headers.get('site_signature'),
request.json,
))
if result.code != RetCode.SUCCESS:
return get_json_result(result.code, result.message)
@app.before_request
def authentication_before_request():
if CLIENT_AUTHENTICATION:
return client_authentication_before_request()
if SITE_AUTHENTICATION:
return site_authentication_before_request()
@login_manager.request_loader
def load_user(web_request):

View File

@ -57,7 +57,7 @@ def list():
for id in sres.ids:
d = {
"chunk_id": id,
"content_ltks": rmSpace(sres.highlight[id]) if question else sres.field[id]["content_ltks"],
"content_with_weight": rmSpace(sres.highlight[id]) if question else sres.field[id]["content_with_weight"],
"doc_id": sres.field[id]["doc_id"],
"docnm_kwd": sres.field[id]["docnm_kwd"],
"important_kwd": sres.field[id].get("important_kwd", []),
@ -134,7 +134,7 @@ def set():
q, a = rmPrefix(arr[0]), rmPrefix[arr[1]]
d = beAdoc(d, arr[0], arr[1], not any([huqie.is_chinese(t) for t in q+a]))
v, c = embd_mdl.encode([doc.name, req["content_ltks"]])
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist()
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
@ -175,13 +175,13 @@ def rm():
@manager.route('/create', methods=['POST'])
@login_required
@validate_request("doc_id", "content_ltks")
@validate_request("doc_id", "content_with_weight")
def create():
req = request.json
md5 = hashlib.md5()
md5.update((req["content_ltks"] + req["doc_id"]).encode("utf-8"))
md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8"))
chunck_id = md5.hexdigest()
d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_ltks"])}
d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_with_weight"])}
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
d["important_kwd"] = req.get("important_kwd", [])
d["important_tks"] = huqie.qie(" ".join(req.get("important_kwd", [])))
@ -201,7 +201,7 @@ def create():
embd_mdl = TenantLLMService.model_instance(
tenant_id, LLMType.EMBEDDING.value)
v, c = embd_mdl.encode([doc.name, req["content_ltks"]])
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
DocumentService.increment_chunk_num(req["doc_id"], doc.kb_id, c, 1, 0)
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()

View File

@ -175,7 +175,7 @@ def chat(dialog, messages, **kwargs):
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold,
dialog.vector_similarity_weight, top=1024, aggs=False)
knowledges = [ck["content_ltks"] for ck in kbinfos["chunks"]]
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
if not knowledges and prompt_config["empty_response"]:
return {"answer": prompt_config["empty_response"], "retrieval": kbinfos}

View File

@ -73,6 +73,7 @@ def upload():
"id": get_uuid(),
"kb_id": kb.id,
"parser_id": kb.parser_id,
"parser_config": kb.parser_config,
"created_by": current_user.id,
"type": filename_type(filename),
"name": filename,
@ -108,6 +109,7 @@ def create():
"id": get_uuid(),
"kb_id": kb.id,
"parser_id": kb.parser_id,
"parser_config": kb.parser_config,
"created_by": current_user.id,
"type": FileType.VIRTUAL,
"name": req["name"],
@ -128,8 +130,8 @@ def list():
data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
keywords = request.args.get("keywords", "")
page_number = request.args.get("page", 1)
items_per_page = request.args.get("page_size", 15)
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 15))
orderby = request.args.get("orderby", "create_time")
desc = request.args.get("desc", True)
try:
@ -214,7 +216,9 @@ def run():
req = request.json
try:
for id in req["doc_ids"]:
DocumentService.update_by_id(id, {"run": str(req["run"]), "progress": 0})
info = {"run": str(req["run"]), "progress": 0}
if str(req["run"]) == TaskStatus.RUNNING.value:info["progress_msg"] = ""
DocumentService.update_by_id(id, info)
if str(req["run"]) == TaskStatus.CANCEL.value:
tenant_id = DocumentService.get_tenant_id(id)
if not tenant_id:

View File

@ -29,7 +29,7 @@ from api.utils.api_utils import get_json_result
@manager.route('/create', methods=['post'])
@login_required
@validate_request("name", "description", "permission", "parser_id")
@validate_request("name")
def create():
req = request.json
req["name"] = req["name"].strip()

View File

@ -77,3 +77,4 @@ class ParserType(StrEnum):
RESUME = "resume"
BOOK = "book"
QA = "qa"
TABLE = "table"

View File

@ -29,7 +29,7 @@ from peewee import (
)
from playhouse.pool import PooledMySQLDatabase
from api.db import SerializedType
from api.db import SerializedType, ParserType
from api.settings import DATABASE, stat_logger, SECRET_KEY
from api.utils.log_utils import getLogger
from api import utils
@ -381,7 +381,8 @@ class Tenant(DataBaseModel):
embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID")
img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID")
parser_ids = CharField(max_length=128, null=False, help_text="default image to text model ID")
parser_ids = CharField(max_length=128, null=False, help_text="document processors")
credit = IntegerField(default=512)
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
class Meta:
@ -472,7 +473,8 @@ class Knowledgebase(DataBaseModel):
similarity_threshold = FloatField(default=0.2)
vector_similarity_weight = FloatField(default=0.3)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.GENERAL.value)
parser_config = JSONField(null=False, default={"from_page":0, "to_page": 100000})
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
def __str__(self):
@ -487,6 +489,7 @@ class Document(DataBaseModel):
thumbnail = TextField(null=True, help_text="thumbnail base64 string")
kb_id = CharField(max_length=256, null=False, index=True)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
parser_config = JSONField(null=False, default={"from_page":0, "to_page": 100000})
source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document from")
type = CharField(max_length=32, null=False, help_text="file extension")
created_by = CharField(max_length=32, null=False, help_text="who created it")

View File

@ -1,157 +0,0 @@
#
# Copyright 2021 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import abc
import json
import time
from functools import wraps
from shortuuid import ShortUUID
from api.versions import get_rag_version
from api.errors.error_services import *
from api.settings import (
GRPC_PORT, HOST, HTTP_PORT,
RANDOM_INSTANCE_ID, stat_logger,
)
instance_id = ShortUUID().random(length=8) if RANDOM_INSTANCE_ID else f'flow-{HOST}-{HTTP_PORT}'
server_instance = (
f'{HOST}:{GRPC_PORT}',
json.dumps({
'instance_id': instance_id,
'timestamp': round(time.time() * 1000),
'version': get_rag_version() or '',
'host': HOST,
'grpc_port': GRPC_PORT,
'http_port': HTTP_PORT,
}),
)
def check_service_supported(method):
"""Decorator to check if `service_name` is supported.
The attribute `supported_services` MUST be defined in class.
The first and second arguments of `method` MUST be `self` and `service_name`.
:param Callable method: The class method.
:return: The inner wrapper function.
:rtype: Callable
"""
@wraps(method)
def magic(self, service_name, *args, **kwargs):
if service_name not in self.supported_services:
raise ServiceNotSupported(service_name=service_name)
return method(self, service_name, *args, **kwargs)
return magic
class ServicesDB(abc.ABC):
"""Database for storage service urls.
Abstract base class for the real backends.
"""
@property
@abc.abstractmethod
def supported_services(self):
"""The names of supported services.
The returned list SHOULD contain `ragflow` (model download) and `servings` (RAG-Serving).
:return: The service names.
:rtype: list
"""
pass
@abc.abstractmethod
def _get_serving(self):
pass
def get_serving(self):
try:
return self._get_serving()
except ServicesError as e:
stat_logger.exception(e)
return []
@abc.abstractmethod
def _insert(self, service_name, service_url, value=''):
pass
@check_service_supported
def insert(self, service_name, service_url, value=''):
"""Insert a service url to database.
:param str service_name: The service name.
:param str service_url: The service url.
:return: None
"""
try:
self._insert(service_name, service_url, value)
except ServicesError as e:
stat_logger.exception(e)
@abc.abstractmethod
def _delete(self, service_name, service_url):
pass
@check_service_supported
def delete(self, service_name, service_url):
"""Delete a service url from database.
:param str service_name: The service name.
:param str service_url: The service url.
:return: None
"""
try:
self._delete(service_name, service_url)
except ServicesError as e:
stat_logger.exception(e)
def register_flow(self):
"""Call `self.insert` for insert the flow server address to databae.
:return: None
"""
self.insert('flow-server', *server_instance)
def unregister_flow(self):
"""Call `self.delete` for delete the flow server address from databae.
:return: None
"""
self.delete('flow-server', server_instance[0])
@abc.abstractmethod
def _get_urls(self, service_name, with_values=False):
pass
@check_service_supported
def get_urls(self, service_name, with_values=False):
"""Query service urls from database. The urls may belong to other nodes.
Currently, only `ragflow` (model download) urls and `servings` (RAG-Serving) urls are supported.
`ragflow` is a url containing scheme, host, port and path,
while `servings` only contains host and port.
:param str service_name: The service name.
:return: The service urls.
:rtype: list
"""
try:
return self._get_urls(service_name, with_values)
except ServicesError as e:
stat_logger.exception(e)
return []

View File

@ -63,7 +63,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_newly_uploaded(cls, tm, mod=0, comm=1, items_per_page=64):
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.type, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.parser_config, cls.model.name, cls.model.type, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
docs = cls.model.select(*fields) \
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\

View File

@ -52,7 +52,8 @@ class KnowledgebaseService(CommonService):
cls.model.doc_num,
cls.model.token_num,
cls.model.chunk_num,
cls.model.parser_id]
cls.model.parser_id,
cls.model.parser_config]
kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where(
(cls.model.id == kb_id),
(cls.model.status == StatusEnum.VALID.value)

View File

@ -27,7 +27,7 @@ class TaskService(CommonService):
@classmethod
@DB.connection_context()
def get_tasks(cls, tm, mod=0, comm=1, items_per_page=64):
fields = [cls.model.id, cls.model.doc_id, cls.model.from_page,cls.model.to_page, Document.kb_id, Document.parser_id, Document.name, Document.type, Document.location, Document.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
fields = [cls.model.id, cls.model.doc_id, cls.model.from_page,cls.model.to_page, Document.kb_id, Document.parser_id, Document.parser_config, Document.name, Document.type, Document.location, Document.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
docs = cls.model.select(*fields) \
.join(Document, on=(cls.model.doc_id == Document.id)) \
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \
@ -53,3 +53,13 @@ class TaskService(CommonService):
except Exception as e:
pass
return True
@classmethod
@DB.connection_context()
def update_progress(cls, id, info):
cls.model.update(progress_msg=cls.model.progress_msg + "\n"+info["progress_msg"]).where(
cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
cls.model.id == id).execute()

View File

@ -92,6 +92,12 @@ class TenantService(CommonService):
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role==UserTenantRole.NORMAL.value)))\
.where(cls.model.status == StatusEnum.VALID.value).dicts())
@classmethod
@DB.connection_context()
def decrease(cls, user_id, num):
num = cls.model.update(credit=cls.model.credit - num).where(
cls.model.id == user_id).execute()
if num == 0: raise LookupError("Tenant not found which is supposed to be there")
class UserTenantService(CommonService):
model = UserTenant

View File

@ -1,10 +0,0 @@
from .general_error import *
class RagFlowError(Exception):
message = 'Unknown Rag Flow Error'
def __init__(self, message=None, *args, **kwargs):
message = str(message) if message is not None else self.message
message = message.format(*args, **kwargs)
super().__init__(message)

View File

@ -1,13 +0,0 @@
from api.errors import RagFlowError
__all__ = ['ServicesError', 'ServiceNotSupported', 'ZooKeeperNotConfigured',
'MissingZooKeeperUsernameOrPassword', 'ZooKeeperBackendError']
class ServicesError(RagFlowError):
message = 'Unknown services error'
class ServiceNotSupported(ServicesError):
message = 'The service {service_name} is not supported'

View File

@ -1,21 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class ParameterError(Exception):
pass
class PassError(Exception):
pass

View File

@ -1,57 +0,0 @@
import importlib
from api.hook.common.parameters import SignatureParameters, AuthenticationParameters, \
SignatureReturn, AuthenticationReturn, PermissionReturn, ClientAuthenticationReturn, ClientAuthenticationParameters
from api.settings import HOOK_MODULE, stat_logger,RetCode
class HookManager:
SITE_SIGNATURE = []
SITE_AUTHENTICATION = []
CLIENT_AUTHENTICATION = []
PERMISSION_CHECK = []
@staticmethod
def init():
if HOOK_MODULE is not None:
for modules in HOOK_MODULE.values():
for module in modules.split(";"):
try:
importlib.import_module(module)
except Exception as e:
stat_logger.exception(e)
@staticmethod
def register_site_signature_hook(func):
HookManager.SITE_SIGNATURE.append(func)
@staticmethod
def register_site_authentication_hook(func):
HookManager.SITE_AUTHENTICATION.append(func)
@staticmethod
def register_client_authentication_hook(func):
HookManager.CLIENT_AUTHENTICATION.append(func)
@staticmethod
def register_permission_check_hook(func):
HookManager.PERMISSION_CHECK.append(func)
@staticmethod
def client_authentication(parm: ClientAuthenticationParameters) -> ClientAuthenticationReturn:
if HookManager.CLIENT_AUTHENTICATION:
return HookManager.CLIENT_AUTHENTICATION[0](parm)
return ClientAuthenticationReturn()
@staticmethod
def site_signature(parm: SignatureParameters) -> SignatureReturn:
if HookManager.SITE_SIGNATURE:
return HookManager.SITE_SIGNATURE[0](parm)
return SignatureReturn()
@staticmethod
def site_authentication(parm: AuthenticationParameters) -> AuthenticationReturn:
if HookManager.SITE_AUTHENTICATION:
return HookManager.SITE_AUTHENTICATION[0](parm)
return AuthenticationReturn()

View File

@ -1,29 +0,0 @@
import requests
from api.db.service_registry import ServiceRegistry
from api.settings import RegistryServiceName
from api.hook import HookManager
from api.hook.common.parameters import ClientAuthenticationParameters, ClientAuthenticationReturn
from api.settings import HOOK_SERVER_NAME
@HookManager.register_client_authentication_hook
def authentication(parm: ClientAuthenticationParameters) -> ClientAuthenticationReturn:
service_list = ServiceRegistry.load_service(
server_name=HOOK_SERVER_NAME,
service_name=RegistryServiceName.CLIENT_AUTHENTICATION.value
)
if not service_list:
raise Exception(f"client authentication error: no found server"
f" {HOOK_SERVER_NAME} service client_authentication")
service = service_list[0]
response = getattr(requests, service.f_method.lower(), None)(
url=service.f_url,
json=parm.to_dict()
)
if response.status_code != 200:
raise Exception(
f"client authentication error: request authentication url failed, status code {response.status_code}")
elif response.json().get("code") != 0:
return ClientAuthenticationReturn(code=response.json().get("code"), message=response.json().get("msg"))
return ClientAuthenticationReturn()

View File

@ -1,25 +0,0 @@
import requests
from api.db.service_registry import ServiceRegistry
from api.settings import RegistryServiceName
from api.hook import HookManager
from api.hook.common.parameters import PermissionCheckParameters, PermissionReturn
from api.settings import HOOK_SERVER_NAME
@HookManager.register_permission_check_hook
def permission(parm: PermissionCheckParameters) -> PermissionReturn:
service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME, service_name=RegistryServiceName.PERMISSION_CHECK.value)
if not service_list:
raise Exception(f"permission check error: no found server {HOOK_SERVER_NAME} service permission")
service = service_list[0]
response = getattr(requests, service.f_method.lower(), None)(
url=service.f_url,
json=parm.to_dict()
)
if response.status_code != 200:
raise Exception(
f"permission check error: request permission url failed, status code {response.status_code}")
elif response.json().get("code") != 0:
return PermissionReturn(code=response.json().get("code"), message=response.json().get("msg"))
return PermissionReturn()

View File

@ -1,49 +0,0 @@
import requests
from api.db.service_registry import ServiceRegistry
from api.settings import RegistryServiceName
from api.hook import HookManager
from api.hook.common.parameters import SignatureParameters, AuthenticationParameters, AuthenticationReturn,\
SignatureReturn
from api.settings import HOOK_SERVER_NAME, PARTY_ID
@HookManager.register_site_signature_hook
def signature(parm: SignatureParameters) -> SignatureReturn:
service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME, service_name=RegistryServiceName.SIGNATURE.value)
if not service_list:
raise Exception(f"signature error: no found server {HOOK_SERVER_NAME} service signature")
service = service_list[0]
response = getattr(requests, service.f_method.lower(), None)(
url=service.f_url,
json=parm.to_dict()
)
if response.status_code == 200:
if response.json().get("code") == 0:
return SignatureReturn(site_signature=response.json().get("data"))
else:
raise Exception(f"signature error: request signature url failed, result: {response.json()}")
else:
raise Exception(f"signature error: request signature url failed, status code {response.status_code}")
@HookManager.register_site_authentication_hook
def authentication(parm: AuthenticationParameters) -> AuthenticationReturn:
if not parm.src_party_id or str(parm.src_party_id) == "0":
parm.src_party_id = PARTY_ID
service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME,
service_name=RegistryServiceName.SITE_AUTHENTICATION.value)
if not service_list:
raise Exception(
f"site authentication error: no found server {HOOK_SERVER_NAME} service site_authentication")
service = service_list[0]
response = getattr(requests, service.f_method.lower(), None)(
url=service.f_url,
json=parm.to_dict()
)
if response.status_code != 200:
raise Exception(
f"site authentication error: request site_authentication url failed, status code {response.status_code}")
elif response.json().get("code") != 0:
return AuthenticationReturn(code=response.json().get("code"), message=response.json().get("msg"))
return AuthenticationReturn()

View File

@ -1,56 +0,0 @@
from api.settings import RetCode
class ParametersBase:
def to_dict(self):
d = {}
for k, v in self.__dict__.items():
d[k] = v
return d
class ClientAuthenticationParameters(ParametersBase):
def __init__(self, full_path, headers, form, data, json):
self.full_path = full_path
self.headers = headers
self.form = form
self.data = data
self.json = json
class ClientAuthenticationReturn(ParametersBase):
def __init__(self, code=RetCode.SUCCESS, message="success"):
self.code = code
self.message = message
class SignatureParameters(ParametersBase):
def __init__(self, party_id, body):
self.party_id = party_id
self.body = body
class SignatureReturn(ParametersBase):
def __init__(self, code=RetCode.SUCCESS, site_signature=None):
self.code = code
self.site_signature = site_signature
class AuthenticationParameters(ParametersBase):
def __init__(self, site_signature, body):
self.site_signature = site_signature
self.body = body
class AuthenticationReturn(ParametersBase):
def __init__(self, code=RetCode.SUCCESS, message="success"):
self.code = code
self.message = message
class PermissionReturn(ParametersBase):
def __init__(self, code=RetCode.SUCCESS, message="success"):
self.code = code
self.message = message

View File

@ -20,12 +20,9 @@ import os
import signal
import sys
import traceback
from werkzeug.serving import run_simple
from api.apps import app
from api.db.runtime_config import RuntimeConfig
from api.hook import HookManager
from api.settings import (
HOST, HTTP_PORT, access_logger, database_logger, stat_logger,
)
@ -60,8 +57,6 @@ if __name__ == '__main__':
RuntimeConfig.init_env()
RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT)
HookManager.init()
peewee_logger = logging.getLogger('peewee')
peewee_logger.propagate = False
# rag_arch.common.log.ROpenHandler

View File

@ -47,7 +47,7 @@ LLM = get_base_config("llm", {})
CHAT_MDL = LLM.get("chat_model", "gpt-3.5-turbo")
EMBEDDING_MDL = LLM.get("embedding_model", "text-embedding-ada-002")
ASR_MDL = LLM.get("asr_model", "whisper-1")
PARSERS = LLM.get("parsers", "General,Resume,Laws,Product Instructions,Books,Paper,Q&A,Programming Code,Power Point,Research Report")
PARSERS = LLM.get("parsers", "general:General,resume:esume,laws:Laws,manual:Manual,book:Book,paper:Paper,qa:Q&A,presentation:Presentation")
IMAGE2TEXT_MDL = LLM.get("image2text_model", "gpt-4-vision-preview")
# distribution

View File

@ -3,7 +3,7 @@ import random
import re
import numpy as np
from rag.parser import bullets_category, BULLET_PATTERN, is_english, tokenize, remove_contents_table, \
hierarchical_merge, make_colon_as_title, naive_merge
hierarchical_merge, make_colon_as_title, naive_merge, random_choices
from rag.nlp import huqie
from rag.parser.docx_parser import HuDocxParser
from rag.parser.pdf_parser import HuParser
@ -51,7 +51,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
doc_parser = HuDocxParser()
# TODO: table of contents need to be removed
sections, tbls = doc_parser(binary if binary else filename, from_page=from_page, to_page=to_page)
remove_contents_table(sections, eng=is_english(random.choices([t for t,_ in sections], k=200)))
remove_contents_table(sections, eng=is_english(random_choices([t for t,_ in sections], k=200)))
callback(0.8, "Finish parsing.")
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf()
@ -67,20 +67,20 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
l = f.readline()
if not l:break
txt += l
sections = txt.split("\n")
sections = txt.split("\n")
sections = [(l,"") for l in sections if l]
remove_contents_table(sections, eng = is_english(random.choices([t for t,_ in sections], k=200)))
remove_contents_table(sections, eng = is_english(random_choices([t for t,_ in sections], k=200)))
callback(0.8, "Finish parsing.")
else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")
make_colon_as_title(sections)
bull = bullets_category([t for t in random.choices([t for t,_ in sections], k=100)])
bull = bullets_category([t for t in random_choices([t for t,_ in sections], k=100)])
if bull >= 0: cks = hierarchical_merge(bull, sections, 3)
else: cks = naive_merge(sections, kwargs.get("chunk_token_num", 256), kwargs.get("delimer", "\n。;!?"))
sections = [t for t, _ in sections]
# is it English
eng = is_english(random.choices(sections, k=218))
eng = is_english(random_choices(sections, k=218))
res = []
# add tables

View File

@ -86,7 +86,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
l = f.readline()
if not l:break
txt += l
sections = txt.split("\n")
sections = txt.split("\n")
sections = txt.split("\n")
sections = [l for l in sections if l]
callback(0.8, "Finish parsing.")
else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")

View File

@ -52,7 +52,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
l = f.readline()
if not l:break
txt += l
sections = txt.split("\n")
sections = txt.split("\n")
sections = [(l,"") for l in sections if l]
callback(0.8, "Finish parsing.")
else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")

View File

@ -1,6 +1,9 @@
import copy
import re
from collections import Counter
from api.db import ParserType
from rag.cv.ppdetection import PPDet
from rag.parser import tokenize
from rag.nlp import huqie
from rag.parser.pdf_parser import HuParser
@ -9,6 +12,10 @@ from rag.utils import num_tokens_from_string
class Pdf(HuParser):
def __init__(self):
self.model_speciess = ParserType.PAPER.value
super().__init__()
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None):
self.__images__(
@ -63,6 +70,15 @@ class Pdf(HuParser):
"[0-9. 一、i]*(introduction|abstract|摘要|引言|keywords|key words|关键词|background|背景|目录|前言|contents)",
txt.lower().strip())
if from_page > 0:
return {
"title":"",
"authors": "",
"abstract": "",
"lines": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes[i:] if
re.match(r"(text|title)", b.get("layoutno", "text"))],
"tables": tbls
}
# get title and authors
title = ""
authors = []
@ -115,18 +131,13 @@ class Pdf(HuParser):
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
pdf_parser = None
paper = {}
if re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf()
paper = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)
else: raise NotImplementedError("file type not supported yet(pdf supported)")
doc = {
"docnm_kwd": paper["title"] if paper["title"] else filename,
"authors_tks": paper["authors"]
}
doc["title_tks"] = huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", doc["docnm_kwd"]))
doc = {"docnm_kwd": filename, "authors_tks": paper["authors"],
"title_tks": huqie.qie(paper["title"] if paper["title"] else filename)}
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
doc["authors_sm_tks"] = huqie.qieqie(doc["authors_tks"])
# is it English

View File

@ -3,7 +3,7 @@ import re
from io import BytesIO
from nltk import word_tokenize
from openpyxl import load_workbook
from rag.parser import is_english
from rag.parser import is_english, random_choices
from rag.nlp import huqie, stemmer
@ -33,9 +33,9 @@ class Excel(object):
if len(res) % 999 == 0:
callback(len(res)*0.6/total, ("Extract Q&A: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..."%(",".join(fails[:3])) if fails else "")))
callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
callback(0.6, ("Extract Q&A: {}. ".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
self.is_english = is_english([rmPrefix(q) for q, _ in random.choices(res, k=30) if len(q)>1])
self.is_english = is_english([rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q)>1])
return res

170
rag/app/table.py Normal file
View File

@ -0,0 +1,170 @@
import copy
import random
import re
from io import BytesIO
from xpinyin import Pinyin
import numpy as np
import pandas as pd
from nltk import word_tokenize
from openpyxl import load_workbook
from dateutil.parser import parse as datetime_parse
from rag.parser import is_english, tokenize
from rag.nlp import huqie, stemmer
class Excel(object):
def __call__(self, fnm, binary=None, callback=None):
if not binary:
wb = load_workbook(fnm)
else:
wb = load_workbook(BytesIO(binary))
total = 0
for sheetname in wb.sheetnames:
total += len(list(wb[sheetname].rows))
res, fails, done = [], [], 0
for sheetname in wb.sheetnames:
ws = wb[sheetname]
rows = list(ws.rows)
headers = [cell.value for cell in rows[0]]
missed = set([i for i,h in enumerate(headers) if h is None])
headers = [cell.value for i,cell in enumerate(rows[0]) if i not in missed]
data = []
for i, r in enumerate(rows[1:]):
row = [cell.value for ii,cell in enumerate(r) if ii not in missed]
if len(row) != len(headers):
fails.append(str(i))
continue
data.append(row)
done += 1
if done % 999 == 0:
callback(done * 0.6/total, ("Extract records: {}".format(len(res)) + (f"{len(fails)} failure({sheetname}), line: %s..."%(",".join(fails[:3])) if fails else "")))
res.append(pd.DataFrame(np.array(data), columns=headers))
callback(0.6, ("Extract records: {}. ".format(done) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
return res
def trans_datatime(s):
try:
return datetime_parse(s.strip()).strftime("%Y-%m-%dT%H:%M:%S")
except Exception as e:
pass
def trans_bool(s):
if re.match(r"(true|yes|是)$", str(s).strip(), flags=re.IGNORECASE): return ["yes", ""]
if re.match(r"(false|no|否)$", str(s).strip(), flags=re.IGNORECASE): return ["no", ""]
def column_data_type(arr):
uni = len(set([a for a in arr if a is not None]))
counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
trans = {t:f for f,t in [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
for a in arr:
if a is None:continue
if re.match(r"[+-]?[0-9]+(\.0+)?$", str(a).replace("%%", "")):
counts["int"] += 1
elif re.match(r"[+-]?[0-9.]+$", str(a).replace("%%", "")):
counts["float"] += 1
elif re.match(r"(true|false|yes|no|是|否)$", str(a), flags=re.IGNORECASE):
counts["bool"] += 1
elif trans_datatime(str(a)):
counts["datetime"] += 1
else: counts["text"] += 1
counts = sorted(counts.items(), key=lambda x: x[1]*-1)
ty = counts[0][0]
for i in range(len(arr)):
if arr[i] is None:continue
try:
arr[i] = trans[ty](str(arr[i]))
except Exception as e:
arr[i] = None
if ty == "text":
if len(arr) > 128 and uni/len(arr) < 0.1:
ty = "keyword"
return arr, ty
def chunk(filename, binary=None, callback=None, **kwargs):
dfs = []
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
excel_parser = Excel()
dfs = excel_parser(filename, binary, callback)
elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
txt = ""
if binary:
txt = binary.decode("utf-8")
else:
with open(filename, "r") as f:
while True:
l = f.readline()
if not l: break
txt += l
lines = txt.split("\n")
fails = []
headers = lines[0].split(kwargs.get("delimiter", "\t"))
rows = []
for i, line in enumerate(lines[1:]):
row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
if len(row) != len(headers):
fails.append(str(i))
continue
rows.append(row)
if len(rows) % 999 == 0:
callback(len(rows) * 0.6 / len(lines), ("Extract records: {}".format(len(rows)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
callback(0.6, ("Extract records: {}".format(len(rows)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
dfs = [pd.DataFrame(np.array(rows), columns=headers)]
else: raise NotImplementedError("file type not supported yet(excel, text, csv supported)")
res = []
PY = Pinyin()
fieds_map = {"text": "_tks", "int": "_int", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"}
for df in dfs:
for n in ["id", "_id", "index", "idx"]:
if n in df.columns:del df[n]
clmns = df.columns.values
txts = list(copy.deepcopy(clmns))
py_clmns = [PY.get_pinyins(n)[0].replace("-", "_") for n in clmns]
clmn_tys = []
for j in range(len(clmns)):
cln,ty = column_data_type(df[clmns[j]])
clmn_tys.append(ty)
df[clmns[j]] = cln
if ty == "text": txts.extend([str(c) for c in cln if c])
clmns_map = [(py_clmns[j] + fieds_map[clmn_tys[j]], clmns[j]) for i in range(len(clmns))]
# TODO: set this column map to KB parser configuration
eng = is_english(txts)
for ii,row in df.iterrows():
d = {}
row_txt = []
for j in range(len(clmns)):
if row[clmns[j]] is None:continue
fld = clmns_map[j][0]
d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else huqie.qie(row[clmns[j]])
row_txt.append("{}:{}".format(clmns[j], row[clmns[j]]))
if not row_txt:continue
tokenize(d, "; ".join(row_txt), eng)
print(d)
res.append(d)
callback(0.6, "")
return res
if __name__== "__main__":
import sys
def dummy(a, b):
pass
chunk(sys.argv[1], callback=dummy)

View File

@ -67,7 +67,7 @@ class Dealer:
ps = int(req.get("size", 1000))
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id",
"image_id", "doc_id", "q_512_vec", "q_768_vec",
"q_1024_vec", "q_1536_vec", "available_int"])
"q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])
s = s.query(bqry)[pg * ps:(pg + 1) * ps]
s = s.highlight("content_ltks")
@ -234,7 +234,7 @@ class Dealer:
sres.field[i].get("q_%d_vec" % len(sres.query_vector), "\t".join(["0"] * len(sres.query_vector)))) for i in sres.ids]
if not ins_embd:
return [], [], []
ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ")
ins_tw = [sres.field[i][cfield].split(" ")
for i in sres.ids]
sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
ins_embd,
@ -281,6 +281,7 @@ class Dealer:
d = {
"chunk_id": id,
"content_ltks": sres.field[id]["content_ltks"],
"content_with_weight": sres.field[id]["content_with_weight"],
"doc_id": sres.field[id]["doc_id"],
"docnm_kwd": dnm,
"kb_id": sres.field[id]["kb_id"],

View File

@ -1,4 +1,5 @@
import copy
import random
from .pdf_parser import HuParser as PdfParser
from .docx_parser import HuDocxParser as DocxParser
@ -38,6 +39,9 @@ BULLET_PATTERN = [[
]
]
def random_choices(arr, k):
k = min(len(arr), k)
return random.choices(arr, k=k)
def bullets_category(sections):
global BULLET_PATTERN

View File

@ -1,7 +1,10 @@
# -*- coding: utf-8 -*-
import os
import random
from functools import partial
import fitz
import requests
import xgboost as xgb
from io import BytesIO
import torch
@ -10,13 +13,14 @@ import pdfplumber
import logging
from PIL import Image
import numpy as np
from api.db import ParserType
from rag.nlp import huqie
from collections import Counter
from copy import deepcopy
from rag.cv.table_recognize import TableTransformer
from rag.cv.ppdetection import PPDet
from huggingface_hub import hf_hub_download
logging.getLogger("pdfminer").setLevel(logging.WARNING)
@ -25,8 +29,10 @@ class HuParser:
from paddleocr import PaddleOCR
logging.getLogger("ppocr").setLevel(logging.ERROR)
self.ocr = PaddleOCR(use_angle_cls=False, lang="ch")
self.layouter = PPDet("/data/newpeak/medical-gpt/res/ppdet")
self.tbl_det = PPDet("/data/newpeak/medical-gpt/res/ppdet.tbl")
if not hasattr(self, "model_speciess"):
self.model_speciess = ParserType.GENERAL.value
self.layouter = partial(self.__remote_call, self.model_speciess)
self.tbl_det = partial(self.__remote_call, "table_component")
self.updown_cnt_mdl = xgb.Booster()
if torch.cuda.is_available():
@ -45,6 +51,38 @@ class HuParser:
"""
def __remote_call(self, species, images, thr=0.7):
url = os.environ.get("INFINIFLOW_SERVER")
if not url:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_SERVER'")
token = os.environ.get("INFINIFLOW_TOKEN")
if not token:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_TOKEN'")
def convert_image_to_bytes(PILimage):
image = BytesIO()
PILimage.save(image, format='png')
image.seek(0)
return image.getvalue()
images = [convert_image_to_bytes(img) for img in images]
def remote_call():
nonlocal images, thr
res = requests.post(url+"/v1/layout/detect/"+species, files=[("image", img) for img in images], data={"threashold": thr},
headers={"Authorization": token}, timeout=len(images) * 10)
res = res.json()
if res["retcode"] != 0: raise RuntimeError(res["retmsg"])
return res["data"]
for _ in range(3):
try:
return remote_call()
except RuntimeError as e:
raise e
except Exception as e:
logging.error("layout_predict:"+str(e))
return remote_call()
def __char_width(self, c):
return (c["x1"] - c["x0"]) // len(c["text"])
@ -344,7 +382,7 @@ class HuParser:
return layouts
def __table_paddle(self, images):
tbls = self.tbl_det([np.array(img) for img in images], thr=0.5)
tbls = self.tbl_det(images, thr=0.5)
res = []
# align left&right for rows, align top&bottom for columns
for tbl in tbls:
@ -522,7 +560,7 @@ class HuParser:
assert len(self.page_images) == len(self.boxes)
# Tag layout type
boxes = []
layouts = self.layouter([np.array(img) for img in self.page_images])
layouts = self.layouter(self.page_images)
assert len(self.page_images) == len(layouts)
for pn, lts in enumerate(layouts):
bxs = self.boxes[pn]
@ -1705,7 +1743,8 @@ class HuParser:
self.__ocr_paddle(i + 1, img, chars, zoomin)
if not self.is_english and not any([c for c in self.page_chars]) and self.boxes:
self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join([b["text"] for b in random.choices([b for bxs in self.boxes for b in bxs], k=30)]))
bxes = [b for bxs in self.boxes for b in bxs]
self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join([b["text"] for b in random.choices(bxes, k=min(30, len(bxes)))]))
logging.info("Is it English:", self.is_english)

View File

@ -134,5 +134,5 @@ if __name__ == "__main__":
while True:
dispatch()
time.sleep(3)
time.sleep(1)
update_progress()

View File

@ -36,7 +36,7 @@ from rag.nlp import search
from io import BytesIO
import pandas as pd
from rag.app import laws, paper, presentation, manual, qa
from rag.app import laws, paper, presentation, manual, qa, table,book
from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService
@ -49,10 +49,12 @@ BATCH_SIZE = 64
FACTORY = {
ParserType.GENERAL.value: laws,
ParserType.PAPER.value: paper,
ParserType.BOOK.value: book,
ParserType.PRESENTATION.value: presentation,
ParserType.MANUAL.value: manual,
ParserType.LAWS.value: laws,
ParserType.QA.value: qa,
ParserType.TABLE.value: table,
}
@ -66,7 +68,7 @@ def set_progress(task_id, from_page, to_page, prog=None, msg="Processing..."):
d = {"progress_msg": msg}
if prog is not None: d["progress"] = prog
try:
TaskService.update_by_id(task_id, d)
TaskService.update_progress(task_id, d)
except Exception as e:
cron_logger.error("set_progress:({}), {}".format(task_id, str(e)))
@ -113,7 +115,7 @@ def build(row, cvmdl):
return []
callback = partial(set_progress, row["id"], row["from_page"], row["to_page"])
chunker = FACTORY[row["parser_id"]]
chunker = FACTORY[row["parser_id"].lower()]
try:
cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
cks = chunker.chunk(row["name"], MINIO.get(row["kb_id"], row["location"]), row["from_page"], row["to_page"],
@ -154,6 +156,7 @@ def build(row, cvmdl):
MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
del d["image"]
docs.append(d)
return docs
@ -168,7 +171,7 @@ def init_kb(row):
def embedding(docs, mdl):
tts, cnts = [d["docnm_kwd"] for d in docs if d.get("docnm_kwd")], [d["content_with_weight"] for d in docs]
tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [d["content_with_weight"] for d in docs]
tk_count = 0
if len(tts) == len(cnts):
tts, c = mdl.encode(tts)
@ -207,6 +210,7 @@ def main(comm, mod):
cks = build(r, cv_mdl)
if not cks:
tmf.write(str(r["update_time"]) + "\n")
callback(1., "No chunk! Done!")
continue
# TODO: exception handler
## set_progress(r["did"], -1, "ERROR: ")
@ -215,7 +219,6 @@ def main(comm, mod):
except Exception as e:
callback(-1, "Embedding error:{}".format(str(e)))
cron_logger.error(str(e))
continue
callback(msg="Finished embedding! Start to build index!")
init_kb(r)
@ -227,6 +230,7 @@ def main(comm, mod):
else:
if TaskService.do_cancel(r["id"]):
ELASTICSEARCH.deleteByQuery(Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
continue
callback(1., "Done!")
DocumentService.increment_chunk_num(r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))