Add a lot of APIs (#23)

* clean rust version project

* clean rust version project

* build python version rag-flow

* add a lot of APIs
This commit is contained in:
KevinHuSh 2024-01-15 19:47:25 +08:00 committed by GitHub
parent 30791976d5
commit 3198faf2d2
16 changed files with 339 additions and 58 deletions

View File

@ -35,7 +35,7 @@ class Base(ABC):
class HuEmbedding(Base): class HuEmbedding(Base):
def __init__(self): def __init__(self, key="", model_name=""):
""" """
If you have trouble downloading HuggingFace models, -_^ this might help!! If you have trouble downloading HuggingFace models, -_^ this might help!!

View File

@ -411,8 +411,11 @@ class TextChunker(HuChunker):
flds = self.Fields() flds = self.Fields()
if self.is_binary_file(fnm): if self.is_binary_file(fnm):
return flds return flds
txt = ""
if isinstance(fnm, str):
with open(fnm, "r") as f: with open(fnm, "r") as f:
txt = f.read() txt = f.read()
else: txt = fnm.decode("utf-8")
flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)] flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
flds.table_chunks = [] flds.table_chunks = []
return flds return flds

View File

@ -8,7 +8,7 @@ from rag.nlp import huqie, query
import numpy as np import numpy as np
def index_name(uid): return f"docgpt_{uid}" def index_name(uid): return f"ragflow_{uid}"
class Dealer: class Dealer:

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
# #
import json import json
import logging
import os import os
import hashlib import hashlib
import copy import copy
@ -24,9 +25,10 @@ from timeit import default_timer as timer
from rag.llm import EmbeddingModel, CvModel from rag.llm import EmbeddingModel, CvModel
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
from rag.utils import ELASTICSEARCH, num_tokens_from_string from rag.utils import ELASTICSEARCH
from rag.utils import MINIO from rag.utils import MINIO
from rag.utils import rmSpace, findMaxDt from rag.utils import rmSpace, findMaxTm
from rag.nlp import huchunk, huqie, search from rag.nlp import huchunk, huqie, search
from io import BytesIO from io import BytesIO
import pandas as pd import pandas as pd
@ -47,6 +49,7 @@ from rag.nlp.huchunk import (
from web_server.db import LLMType from web_server.db import LLMType
from web_server.db.services.document_service import DocumentService from web_server.db.services.document_service import DocumentService
from web_server.db.services.llm_service import TenantLLMService from web_server.db.services.llm_service import TenantLLMService
from web_server.settings import database_logger
from web_server.utils import get_format_time from web_server.utils import get_format_time
from web_server.utils.file_utils import get_project_base_directory from web_server.utils.file_utils import get_project_base_directory
@ -83,7 +86,7 @@ def collect(comm, mod, tm):
if len(docs) == 0: if len(docs) == 0:
return pd.DataFrame() return pd.DataFrame()
docs = pd.DataFrame(docs) docs = pd.DataFrame(docs)
mtm = str(docs["update_time"].max())[:19] mtm = docs["update_time"].max()
cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm)) cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
return docs return docs
@ -99,11 +102,12 @@ def set_progress(docid, prog, msg="Processing...", begin=False):
cron_logger.error("set_progress:({}), {}".format(docid, str(e))) cron_logger.error("set_progress:({}), {}".format(docid, str(e)))
def build(row): def build(row, cvmdl):
if row["size"] > DOC_MAXIMUM_SIZE: if row["size"] > DOC_MAXIMUM_SIZE:
set_progress(row["id"], -1, "File size exceeds( <= %dMb )" % set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
(int(DOC_MAXIMUM_SIZE / 1024 / 1024))) (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
return [] return []
res = ELASTICSEARCH.search(Q("term", doc_id=row["id"])) res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
if ELASTICSEARCH.getTotal(res) > 0: if ELASTICSEARCH.getTotal(res) > 0:
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]), ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
@ -120,7 +124,8 @@ def build(row):
set_progress(row["id"], random.randint(0, 20) / set_progress(row["id"], random.randint(0, 20) /
100., "Finished preparing! Start to slice file!", True) 100., "Finished preparing! Start to slice file!", True)
try: try:
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"])) cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl)
except Exception as e: except Exception as e:
if re.search("(No such file|not found)", str(e)): if re.search("(No such file|not found)", str(e)):
set_progress( set_progress(
@ -131,6 +136,9 @@ def build(row):
row["id"], -1, f"Internal server error: %s" % row["id"], -1, f"Internal server error: %s" %
str(e).replace( str(e).replace(
"'", "")) "'", ""))
cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e)))
return [] return []
if not obj.text_chunks and not obj.table_chunks: if not obj.text_chunks and not obj.table_chunks:
@ -144,7 +152,7 @@ def build(row):
"Finished slicing files. Start to embedding the content.") "Finished slicing files. Start to embedding the content.")
doc = { doc = {
"doc_id": row["did"], "doc_id": row["id"],
"kb_id": [str(row["kb_id"])], "kb_id": [str(row["kb_id"])],
"docnm_kwd": os.path.split(row["location"])[-1], "docnm_kwd": os.path.split(row["location"])[-1],
"title_tks": huqie.qie(row["name"]), "title_tks": huqie.qie(row["name"]),
@ -164,10 +172,10 @@ def build(row):
docs.append(d) docs.append(d)
continue continue
if isinstance(img, Image): if isinstance(img, bytes):
img.save(output_buffer, format='JPEG')
else:
output_buffer = BytesIO(img) output_buffer = BytesIO(img)
else:
img.save(output_buffer, format='JPEG')
MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue()) MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"]) d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
@ -215,15 +223,16 @@ def embedding(docs, mdl):
def model_instance(tenant_id, llm_type):
    """Build the model client of kind ``llm_type`` for a tenant.

    The tenant's LLM configuration is looked up through
    ``TenantLLMService.get_api_key``; when nothing is configured a "local"
    factory with empty credentials is used instead.

    Returns the instantiated embedding / image-to-text model, or ``None``
    when the configured factory is not registered for the requested type
    (or the type is neither EMBEDDING nor IMAGE2TEXT).
    """
    # NOTE(review): the lookup always uses LLMType.EMBEDDING, even when an
    # IMAGE2TEXT model is requested -- confirm whether llm_type should be used.
    model_config = TenantLLMService.get_api_key(tenant_id, model_type=LLMType.EMBEDDING)
    if not model_config:
        model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
    else:
        # get_api_key returns a single record (not a list), so it must not
        # be indexed before being converted to a dict.
        model_config = model_config.to_dict()

    # From here on model_config is always a dict: use key access only.
    factory = model_config["llm_factory"]
    if llm_type == LLMType.EMBEDDING:
        if factory not in EmbeddingModel:
            return None
        return EmbeddingModel[factory](model_config["api_key"], model_config["llm_name"])
    if llm_type == LLMType.IMAGE2TEXT:
        if factory not in CvModel:
            return None
        return CvModel[factory](model_config["api_key"], model_config["llm_name"])
    return None
def main(comm, mod): def main(comm, mod):
@ -231,7 +240,7 @@ def main(comm, mod):
from rag.llm import HuEmbedding from rag.llm import HuEmbedding
model = HuEmbedding() model = HuEmbedding()
tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm") tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
tm = findMaxDt(tm_fnm) tm = findMaxTm(tm_fnm)
rows = collect(comm, mod, tm) rows = collect(comm, mod, tm)
if len(rows) == 0: if len(rows) == 0:
return return
@ -247,7 +256,7 @@ def main(comm, mod):
st_tm = timer() st_tm = timer()
cks = build(r, cv_mdl) cks = build(r, cv_mdl)
if not cks: if not cks:
tmf.write(str(r["updated_at"]) + "\n") tmf.write(str(r["update_time"]) + "\n")
continue continue
# TODO: exception handler # TODO: exception handler
## set_progress(r["did"], -1, "ERROR: ") ## set_progress(r["did"], -1, "ERROR: ")
@ -268,12 +277,19 @@ def main(comm, mod):
cron_logger.error(str(es_r)) cron_logger.error(str(es_r))
else: else:
set_progress(r["id"], 1., "Done!") set_progress(r["id"], 1., "Done!")
DocumentService.update_by_id(r["id"], {"token_num": tk_count, "chunk_num": len(cks), "process_duation": timer()-st_tm}) DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, len(cks), timer()-st_tm)
cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))
tmf.write(str(r["update_time"]) + "\n") tmf.write(str(r["update_time"]) + "\n")
tmf.close() tmf.close()
if __name__ == "__main__": if __name__ == "__main__":
peewee_logger = logging.getLogger('peewee')
peewee_logger.propagate = False
peewee_logger.addHandler(database_logger.handlers[0])
peewee_logger.setLevel(database_logger.level)
from mpi4py import MPI from mpi4py import MPI
comm = MPI.COMM_WORLD comm = MPI.COMM_WORLD
main(comm.Get_size(), comm.Get_rank()) main(comm.Get_size(), comm.Get_rank())

View File

@ -40,6 +40,25 @@ def findMaxDt(fnm):
print("WARNING: can't find " + fnm) print("WARNING: can't find " + fnm)
return m return m
def findMaxTm(fnm):
    """Return the largest integer found in the line-oriented file ``fnm``.

    Each line is expected to hold one integer. Lines that cannot be parsed
    as an integer (e.g. ``nan`` markers or blank lines) are skipped instead
    of aborting the scan, so one bad line no longer hides later, larger
    values.

    Returns 0 when the file is missing, unreadable or holds no positive
    integer; a warning is printed when the file cannot be read.
    """
    m = 0
    try:
        with open(fnm, "r") as f:
            for line in f:
                try:
                    tm = int(line.strip())
                except ValueError:
                    # "nan", empty or otherwise malformed line: ignore it.
                    continue
                if tm > m:
                    m = tm
    except (OSError, UnicodeError):
        # Best effort: a missing/unreadable file simply means "start from 0".
        print("WARNING: can't find " + fnm)
    return m
def num_tokens_from_string(string: str) -> int: def num_tokens_from_string(string: str) -> int:
"""Returns the number of tokens in a text string.""" """Returns the number of tokens in a text string."""
encoding = tiktoken.get_encoding('cl100k_base') encoding = tiktoken.get_encoding('cl100k_base')

View File

@ -294,6 +294,7 @@ class HuEs:
except Exception as e: except Exception as e:
es_logger.error("ES updateByQuery deleteByQuery: " + es_logger.error("ES updateByQuery deleteByQuery: " +
str(e) + "【Q】" + str(query.to_dict())) str(e) + "【Q】" + str(query.to_dict()))
if str(e).find("NotFoundError") > 0: return True
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue continue

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import base64
import pathlib import pathlib
from elasticsearch_dsl import Q from elasticsearch_dsl import Q
@ -195,11 +196,15 @@ def rm():
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
if not e: if not e:
return get_data_error_result(retmsg="Document not found!") return get_data_error_result(retmsg="Document not found!")
if not ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)):
return get_json_result(data=False, retmsg='Remove from ES failure"', retcode=RetCode.SERVER_ERROR)
DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0)
if not DocumentService.delete_by_id(req["doc_id"]): if not DocumentService.delete_by_id(req["doc_id"]):
return get_data_error_result( return get_data_error_result(
retmsg="Database error (Document removal)!") retmsg="Database error (Document removal)!")
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
MINIO.rm(kb.id, doc.location) MINIO.rm(doc.kb_id, doc.location)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@ -233,3 +238,43 @@ def rename():
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/get', methods=['GET'])
@login_required
def get():
    """Return the stored content of a document, base64-encoded.

    Reads ``doc_id`` from the query string, fetches the document's blob
    from MINIO and returns it as ``{"base64": <text>}``. Responds with a
    data error when the document does not exist and with a server error
    response on any exception.
    """
    doc_id = request.args["doc_id"]
    try:
        e, doc = DocumentService.get_by_id(doc_id)
        if not e:
            return get_data_error_result(retmsg="Document not found!")

        blob = MINIO.get(doc.kb_id, doc.location)
        # Encode (not decode) the raw bytes, and turn the result into text
        # so that it can be placed in the JSON response.
        encoded = base64.b64encode(blob).decode("utf-8")
        return get_json_result(data={"base64": encoded})
    except Exception as e:
        return server_error_response(e)
@manager.route('/change_parser', methods=['POST'])
@login_required
@validate_request("doc_id", "parser_id")
def change_parser():
    """Switch a document to another parser and reset its processing state.

    Nothing is changed when the requested parser equals the current one
    (compared case-insensitively). Otherwise the parser id is stored,
    progress is cleared, and the document's token count, chunk count and
    processing duration are subtracted through
    ``DocumentService.increment_chunk_num``.
    """
    payload = request.json
    try:
        found, doc = DocumentService.get_by_id(payload["doc_id"])
        if not found:
            return get_data_error_result(retmsg="Document not found!")

        new_parser = payload["parser_id"]
        if doc.parser_id.lower() == new_parser.lower():
            # Same parser requested: nothing to do.
            return get_json_result(data=True)

        reset_fields = {"parser_id": new_parser, "progress": 0, "progress_msg": ""}
        updated = DocumentService.update_by_id(doc.id, reset_fields)
        if not updated:
            return get_data_error_result(retmsg="Document not found!")

        # Pass negated totals so the stored counters are decremented.
        decremented = DocumentService.increment_chunk_num(
            doc.id,
            doc.kb_id,
            doc.token_num * -1,
            doc.chunk_num * -1,
            doc.process_duation * -1,
        )
        if not decremented:
            return get_data_error_result(retmsg="Document not found!")

        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)

View File

@ -29,7 +29,7 @@ from web_server.utils.api_utils import get_json_result
@manager.route('/create', methods=['post']) @manager.route('/create', methods=['post'])
@login_required @login_required
@validate_request("name", "description", "permission", "embd_id", "parser_id") @validate_request("name", "description", "permission", "parser_id")
def create(): def create():
req = request.json req = request.json
req["name"] = req["name"].strip() req["name"] = req["name"].strip()
@ -46,7 +46,7 @@ def create():
@manager.route('/update', methods=['post']) @manager.route('/update', methods=['post'])
@login_required @login_required
@validate_request("kb_id", "name", "description", "permission", "embd_id", "parser_id") @validate_request("kb_id", "name", "description", "permission", "parser_id")
def update(): def update():
req = request.json req = request.json
req["name"] = req["name"].strip() req["name"] = req["name"].strip()
@ -72,6 +72,18 @@ def update():
return server_error_response(e) return server_error_response(e)
@manager.route('/detail', methods=['GET'])
@login_required
def detail():
    """Return the detail record of the knowledge base given by ``kb_id``.

    ``kb_id`` is read from the query string. A data error is returned when
    ``KnowledgebaseService.get_detail`` yields nothing.
    """
    kb_id = request.args["kb_id"]
    try:
        record = KnowledgebaseService.get_detail(kb_id)
        if record:
            return get_json_result(data=record)
        return get_data_error_result(retmsg="Can't find this knowledgebase!")
    except Exception as e:
        return server_error_response(e)
@manager.route('/list', methods=['GET']) @manager.route('/list', methods=['GET'])
@login_required @login_required
def list(): def list():

View File

@ -0,0 +1,95 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import request
from flask_login import login_required, current_user
from web_server.db.services import duplicate_name
from web_server.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
from web_server.db.services.user_service import TenantService, UserTenantService
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
from web_server.utils import get_uuid, get_format_time
from web_server.db import StatusEnum, UserTenantRole
from web_server.db.services.kb_service import KnowledgebaseService
from web_server.db.db_models import Knowledgebase, TenantLLM
from web_server.settings import stat_logger, RetCode
from web_server.utils.api_utils import get_json_result
@manager.route('/factories', methods=['GET'])
@login_required
def factories():
    """Return every LLM factory as provided by ``LLMFactoriesService.get_all``."""
    try:
        all_factories = LLMFactoriesService.get_all()
        payload = all_factories.to_json()
        return get_json_result(data=payload)
    except Exception as e:
        return server_error_response(e)
@manager.route('/set_api_key', methods=['POST'])
@login_required
@validate_request("llm_factory", "api_key")
def set_api_key():
    """Store (insert or replace) an API key for the current user's tenant.

    ``llm_factory`` and ``api_key`` are required in the JSON body;
    ``model_type`` and ``llm_name`` are copied over when present. Failures
    are reported through ``server_error_response``, like the other
    handlers of this module.
    """
    req = request.json
    try:
        llm = {
            "tenant_id": current_user.id,
            "llm_factory": req["llm_factory"],
            "api_key": req["api_key"],
        }
        # TODO: Test api_key
        for n in ["model_type", "llm_name"]:
            if n in req:
                llm[n] = req[n]

        TenantLLM.insert(**llm).on_conflict("replace").execute()
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
@manager.route('/my_llms', methods=['GET'])
@login_required
def my_llms():
    """Return the current user's tenant LLM records with the API key removed."""
    try:
        records = []
        for entry in TenantLLMService.query(tenant_id=current_user.id):
            as_dict = entry.to_dict()
            # Never send the secret back to the client.
            as_dict.pop("api_key")
            records.append(as_dict)
        return get_json_result(data=records)
    except Exception as e:
        return server_error_response(e)
@manager.route('/list', methods=['GET'])
@login_required
def list():
    """Return all valid LLMs grouped by factory id, flagged with availability.

    An LLM is marked ``available`` when the current user's tenant has an API
    key for its factory and either no specific model names are recorded for
    that factory or the LLM's name is among them.
    """
    try:
        # factory -> model names for which the tenant holds an API key
        names_by_factory = {}
        for rec in TenantLLMService.query(tenant_id=current_user.id):
            if not rec.api_key:
                continue
            rec = rec.to_dict()
            names = names_by_factory.setdefault(rec["llm_factory"], [])
            if rec["llm_name"]:
                names.append(rec["llm_name"])

        valid_llms = [
            m.to_dict()
            for m in LLMService.get_all()
            if m.status == StatusEnum.VALID.value
        ]

        grouped = {}
        for llm in valid_llms:
            fid = llm["fid"]
            if fid in names_by_factory:
                allowed = names_by_factory[fid]
                llm["available"] = (not allowed) or (llm["llm_name"] in allowed)
            else:
                llm["available"] = False
            grouped.setdefault(fid, []).append(llm)

        return get_json_result(data=grouped)
    except Exception as e:
        return server_error_response(e)

View File

@ -16,9 +16,12 @@
from flask import request, session, redirect, url_for from flask import request, session, redirect, url_for
from werkzeug.security import generate_password_hash, check_password_hash from werkzeug.security import generate_password_hash, check_password_hash
from flask_login import login_required, current_user, login_user, logout_user from flask_login import login_required, current_user, login_user, logout_user
from web_server.db.db_models import TenantLLM
from web_server.db.services.llm_service import TenantLLMService
from web_server.utils.api_utils import server_error_response, validate_request from web_server.utils.api_utils import server_error_response, validate_request
from web_server.utils import get_uuid, get_format_time, decrypt, download_img from web_server.utils import get_uuid, get_format_time, decrypt, download_img
from web_server.db import UserTenantRole from web_server.db import UserTenantRole, LLMType
from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
from web_server.db.services.user_service import UserService, TenantService, UserTenantService from web_server.db.services.user_service import UserService, TenantService, UserTenantService
from web_server.settings import stat_logger from web_server.settings import stat_logger
@ -47,8 +50,9 @@ def login():
avatar = download_img(userinfo["avatar_url"]) avatar = download_img(userinfo["avatar_url"])
except Exception as e: except Exception as e:
stat_logger.exception(e) stat_logger.exception(e)
user_id = get_uuid()
try: try:
users = user_register({ users = user_register(user_id, {
"access_token": session["access_token"], "access_token": session["access_token"],
"email": userinfo["email"], "email": userinfo["email"],
"avatar": avatar, "avatar": avatar,
@ -63,6 +67,7 @@ def login():
login_user(user) login_user(user)
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!") return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
except Exception as e: except Exception as e:
rollback_user_registration(user_id)
stat_logger.exception(e) stat_logger.exception(e)
return server_error_response(e) return server_error_response(e)
elif not request.json: elif not request.json:
@ -162,7 +167,25 @@ def user_info():
return get_json_result(data=current_user.to_dict()) return get_json_result(data=current_user.to_dict())
def rollback_user_registration(user_id):
    """Best-effort removal of the records created while registering a user.

    Deletes the tenant with id ``user_id``, the first user-tenant link found
    for that tenant, and the tenant's LLM rows. Each step is attempted
    independently; a failure is logged and does not stop the other steps
    or propagate to the caller.
    """
    try:
        TenantService.delete_by_id(user_id)
    except Exception as e:
        stat_logger.exception(e)

    try:
        u = UserTenantService.query(tenant_id=user_id)
        if u:
            UserTenantService.delete_by_id(u[0].id)
    except Exception as e:
        stat_logger.exception(e)

    try:
        # Was ".excute()": the AttributeError was silently swallowed, so the
        # tenant's LLM rows were never removed.
        TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
    except Exception as e:
        stat_logger.exception(e)
def user_register(user_id, user):
user_id = get_uuid() user_id = get_uuid()
user["id"] = user_id user["id"] = user_id
tenant = { tenant = {
@ -180,10 +203,12 @@ def user_register(user):
"invited_by": user_id, "invited_by": user_id,
"role": UserTenantRole.OWNER "role": UserTenantRole.OWNER
} }
tenant_llm = {"tenant_id": user_id, "llm_factory": "OpenAI", "api_key": "infiniflow API Key"}
if not UserService.save(**user):return if not UserService.save(**user):return
TenantService.save(**tenant) TenantService.save(**tenant)
UserTenantService.save(**usr_tenant) UserTenantService.save(**usr_tenant)
TenantLLMService.save(**tenant_llm)
return UserService.query(email=user["email"]) return UserService.query(email=user["email"])
@ -203,14 +228,17 @@ def user_add():
"last_login_time": get_format_time(), "last_login_time": get_format_time(),
"is_superuser": False, "is_superuser": False,
} }
user_id = get_uuid()
try: try:
users = user_register(user_dict) users = user_register(user_id, user_dict)
if not users: raise Exception('Register user failure.') if not users: raise Exception('Register user failure.')
if len(users) > 1: raise Exception('Same E-mail exist!') if len(users) > 1: raise Exception('Same E-mail exist!')
user = users[0] user = users[0]
login_user(user) login_user(user)
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!") return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
except Exception as e: except Exception as e:
rollback_user_registration(user_id)
stat_logger.exception(e) stat_logger.exception(e)
return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR) return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
@ -220,7 +248,7 @@ def user_add():
@login_required @login_required
def tenant_info(): def tenant_info():
try: try:
tenants = TenantService.get_by_user_id(current_user.id) tenants = TenantService.get_by_user_id(current_user.id)[0]
return get_json_result(data=tenants) return get_json_result(data=tenants)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)

View File

@ -428,6 +428,7 @@ class LLMFactories(DataBaseModel):
class LLM(DataBaseModel): class LLM(DataBaseModel):
# defautlt LLMs for every users # defautlt LLMs for every users
llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True) llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
fid = CharField(max_length=128, null=False, help_text="LLM factory id") fid = CharField(max_length=128, null=False, help_text="LLM factory id")
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...") tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1") status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
@ -442,8 +443,8 @@ class LLM(DataBaseModel):
class TenantLLM(DataBaseModel): class TenantLLM(DataBaseModel):
tenant_id = CharField(max_length=32, null=False) tenant_id = CharField(max_length=32, null=False)
llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name") llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR") model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR")
llm_name = CharField(max_length=128, null=False, help_text="LLM name") llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="")
api_key = CharField(max_length=255, null=True, help_text="API KEY") api_key = CharField(max_length=255, null=True, help_text="API KEY")
api_base = CharField(max_length=255, null=True, help_text="API Base") api_base = CharField(max_length=255, null=True, help_text="API Base")
@ -452,7 +453,7 @@ class TenantLLM(DataBaseModel):
class Meta: class Meta:
db_table = "tenant_llm" db_table = "tenant_llm"
primary_key = CompositeKey('tenant_id', 'llm_factory') primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name')
class Knowledgebase(DataBaseModel): class Knowledgebase(DataBaseModel):
@ -464,7 +465,9 @@ class Knowledgebase(DataBaseModel):
permission = CharField(max_length=16, null=False, help_text="me|team") permission = CharField(max_length=16, null=False, help_text="me|team")
created_by = CharField(max_length=32, null=False) created_by = CharField(max_length=32, null=False)
doc_num = IntegerField(default=0) doc_num = IntegerField(default=0)
embd_id = CharField(max_length=32, null=False, help_text="default embedding model ID") token_num = IntegerField(default=0)
chunk_num = IntegerField(default=0)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID") parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1") status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")

View File

@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from peewee import Expression
from web_server.db import TenantPermission, FileType from web_server.db import TenantPermission, FileType
from web_server.db.db_models import DB, Knowledgebase from web_server.db.db_models import DB, Knowledgebase, Tenant
from web_server.db.db_models import Document from web_server.db.db_models import Document
from web_server.db.services.common_service import CommonService from web_server.db.services.common_service import CommonService
from web_server.db.services.kb_service import KnowledgebaseService from web_server.db.services.kb_service import KnowledgebaseService
from web_server.utils import get_uuid, get_format_time
from web_server.db.db_utils import StatusEnum from web_server.db.db_utils import StatusEnum
@ -61,15 +62,28 @@ class DocumentService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64): def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64):
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, Knowledgebase.tenant_id] fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, cls.model.update_time]
docs = cls.model.select(fields).join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)).where( docs = cls.model.select(*fields) \
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
.where(
cls.model.status == StatusEnum.VALID.value, cls.model.status == StatusEnum.VALID.value,
cls.model.type != FileType.VIRTUAL, ~(cls.model.type == FileType.VIRTUAL.value),
cls.model.progress == 0, cls.model.progress == 0,
cls.model.update_time >= tm, cls.model.update_time >= tm,
cls.model.create_time % (Expression(cls.model.create_time, "%%", comm) == mod))\
comm == mod).order_by( .order_by(cls.model.update_time.asc())\
cls.model.update_time.asc()).paginate( .paginate(1, items_per_page)
1,
items_per_page)
return list(docs.dicts()) return list(docs.dicts())
@classmethod
@DB.connection_context()
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
    """Add the given deltas to a document's counters and to its knowledge base.

    ``token_num``, ``chunk_num`` and ``duation`` are added to the document
    row with id ``doc_id``; ``token_num`` and ``chunk_num`` are then added
    to the knowledge base row with id ``kb_id``.

    Returns the number of knowledge base rows updated.
    Raises LookupError when no document row was updated.
    """
    doc_model = cls.model
    doc_update = doc_model.update(
        token_num=doc_model.token_num + token_num,
        chunk_num=doc_model.chunk_num + chunk_num,
        process_duation=doc_model.process_duation + duation,
    ).where(doc_model.id == doc_id)
    if doc_update.execute() == 0:
        raise LookupError("Document not found which is supposed to be there")

    kb_update = Knowledgebase.update(
        token_num=Knowledgebase.token_num + token_num,
        chunk_num=Knowledgebase.chunk_num + chunk_num,
    ).where(Knowledgebase.id == kb_id)
    return kb_update.execute()

View File

@ -17,7 +17,7 @@ import peewee
from werkzeug.security import generate_password_hash, check_password_hash from werkzeug.security import generate_password_hash, check_password_hash
from web_server.db import TenantPermission from web_server.db import TenantPermission
from web_server.db.db_models import DB, UserTenant from web_server.db.db_models import DB, UserTenant, Tenant
from web_server.db.db_models import Knowledgebase from web_server.db.db_models import Knowledgebase
from web_server.db.services.common_service import CommonService from web_server.db.services.common_service import CommonService
from web_server.utils import get_uuid, get_format_time from web_server.utils import get_uuid, get_format_time
@ -29,15 +29,42 @@ class KnowledgebaseService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc): def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
page_number, items_per_page, orderby, desc):
kbs = cls.model.select().where( kbs = cls.model.select().where(
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id)) ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
& (cls.model.status == StatusEnum.VALID.value) & (cls.model.status == StatusEnum.VALID.value)
) )
if desc: kbs = kbs.order_by(cls.model.getter_by(orderby).desc()) if desc:
else: kbs = kbs.order_by(cls.model.getter_by(orderby).asc()) kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
else:
kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
kbs = kbs.paginate(page_number, items_per_page) kbs = kbs.paginate(page_number, items_per_page)
return list(kbs.dicts()) return list(kbs.dicts())
@classmethod
@DB.connection_context()
def get_detail(cls, kb_id):
    """Return a dict describing the valid knowledge base ``kb_id``.

    The knowledge base is joined with its valid tenant so that the tenant's
    ``embd_id`` can be included in the result. Returns ``None`` when no
    matching row exists.
    """
    kb_model = cls.model
    selected = [
        kb_model.id,
        Tenant.embd_id,
        kb_model.avatar,
        kb_model.name,
        kb_model.description,
        kb_model.permission,
        kb_model.doc_num,
        kb_model.token_num,
        kb_model.chunk_num,
        kb_model.parser_id,
    ]
    valid_tenant = (
        (Tenant.id == kb_model.tenant_id)
        & (Tenant.status == StatusEnum.VALID.value)
    )
    rows = kb_model.select(*selected).join(Tenant, on=valid_tenant).where(
        (kb_model.id == kb_id),
        (kb_model.status == StatusEnum.VALID.value),
    )
    if not rows:
        return
    detail = rows[0].to_dict()
    # embd_id is selected from the joined tenant, so copy it over explicitly.
    detail["embd_id"] = rows[0].tenant.embd_id
    return detail

View File

@ -33,3 +33,21 @@ class LLMService(CommonService):
class TenantLLMService(CommonService): class TenantLLMService(CommonService):
model = TenantLLM model = TenantLLM
@classmethod
@DB.connection_context()
def get_api_key(cls, tenant_id, model_type):
    """Return the tenant's LLM record for ``model_type``, or ``None``.

    The tenant's own record is returned when it names a concrete model.
    Otherwise the tenant's records are joined with the valid entries of the
    LLM table (matched on factory) and the first joined row is returned.
    """
    objs = cls.query(tenant_id=tenant_id, model_type=model_type)
    if objs and objs[0].llm_name:
        return objs[0]

    fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
    objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where(
        (cls.model.tenant_id == tenant_id),
        (cls.model.model_type == model_type),
        # Compare against the stored value, as every other status filter
        # does, not against the enum member itself.
        (LLM.status == StatusEnum.VALID.value)
    )
    if not objs:
        return None
    return objs[0]

View File

@ -79,7 +79,7 @@ class TenantService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_by_user_id(cls, user_id): def get_by_user_id(cls, user_id):
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role] fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role]
return list(cls.model.select(*fields)\ return list(cls.model.select(*fields)\
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\ .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\
.where(cls.model.status == StatusEnum.VALID.value).dicts()) .where(cls.model.status == StatusEnum.VALID.value).dicts())

View File

@ -143,7 +143,7 @@ def filename_type(filename):
if re.match(r".*\.pdf$", filename): if re.match(r".*\.pdf$", filename):
return FileType.PDF.value return FileType.PDF.value
if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)$", filename): if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
return FileType.DOC.value return FileType.DOC.value
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename): if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):