add dialog api (#33)

This commit is contained in:
KevinHuSh 2024-01-17 20:20:42 +08:00 committed by GitHub
parent 6be3dd56fa
commit 9bf75d4511
50 changed files with 511 additions and 273 deletions

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -21,17 +21,17 @@ from flask import Blueprint, Flask, request
from werkzeug.wrappers.request import Request from werkzeug.wrappers.request import Request
from flask_cors import CORS from flask_cors import CORS
from web_server.db import StatusEnum from api.db import StatusEnum
from web_server.db.services import UserService from api.db.services import UserService
from web_server.utils import CustomJSONEncoder from api.utils import CustomJSONEncoder
from flask_session import Session from flask_session import Session
from flask_login import LoginManager from flask_login import LoginManager
from web_server.settings import RetCode, SECRET_KEY, stat_logger from api.settings import RetCode, SECRET_KEY, stat_logger
from web_server.hook import HookManager from api.hook import HookManager
from web_server.hook.common.parameters import AuthenticationParameters, ClientAuthenticationParameters from api.hook.common.parameters import AuthenticationParameters, ClientAuthenticationParameters
from web_server.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger from api.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger
from web_server.utils.api_utils import get_json_result, server_error_response from api.utils.api_utils import get_json_result, server_error_response
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
__all__ = ['app'] __all__ = ['app']
@ -68,7 +68,7 @@ def search_pages_path(pages_dir):
def register_page(page_path): def register_page(page_path):
page_name = page_path.stem.rstrip('_app') page_name = page_path.stem.rstrip('_app')
module_name = '.'.join(page_path.parts[page_path.parts.index('web_server'):-1] + (page_name, )) module_name = '.'.join(page_path.parts[page_path.parts.index('api'):-1] + (page_name, ))
spec = spec_from_file_location(module_name, page_path) spec = spec_from_file_location(module_name, page_path)
page = module_from_spec(spec) page = module_from_spec(spec)
@ -86,7 +86,7 @@ def register_page(page_path):
pages_dir = [ pages_dir = [
Path(__file__).parent, Path(__file__).parent,
Path(__file__).parent.parent / 'web_server' / 'apps', Path(__file__).parent.parent / 'api' / 'apps',
] ]
client_urls_prefix = [ client_urls_prefix = [

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,31 +13,26 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import base64
import hashlib import hashlib
import pathlib
import re import re
from elasticsearch_dsl import Q import numpy as np
from flask import request from flask import request
from flask_login import login_required, current_user from flask_login import login_required, current_user
from rag.nlp import search, huqie from rag.nlp import search, huqie
from rag.utils import ELASTICSEARCH, rmSpace from rag.utils import ELASTICSEARCH, rmSpace
from web_server.db import LLMType from api.db import LLMType
from web_server.db.services import duplicate_name from api.db.services import duplicate_name
from web_server.db.services.kb_service import KnowledgebaseService from api.db.services.kb_service import KnowledgebaseService
from web_server.db.services.llm_service import TenantLLMService from api.db.services.llm_service import TenantLLMService
from web_server.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from web_server.utils import get_uuid from api.db.services.document_service import DocumentService
from web_server.db.services.document_service import DocumentService from api.settings import RetCode
from web_server.settings import RetCode from api.utils.api_utils import get_json_result
from web_server.utils.api_utils import get_json_result
from rag.utils.minio_conn import MINIO
from web_server.utils.file_utils import filename_type
retrival = search.Dealer(ELASTICSEARCH, None) retrival = search.Dealer(ELASTICSEARCH)
@manager.route('/list', methods=['POST']) @manager.route('/list', methods=['POST'])
@login_required @login_required
@ -45,16 +40,29 @@ retrival = search.Dealer(ELASTICSEARCH, None)
def list(): def list():
req = request.json req = request.json
doc_id = req["doc_id"] doc_id = req["doc_id"]
page = req.get("page", 1) page = int(req.get("page", 1))
size = req.get("size", 30) size = int(req.get("size", 30))
question = req.get("keywords", "") question = req.get("keywords", "")
try: try:
tenants = UserTenantService.query(user_id=current_user.id) tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenants: if not tenant_id: return get_data_error_result(retmsg="Tenant not found!")
return get_data_error_result(retmsg="Tenant not found!") query = {
res = retrival.search({
"doc_ids": [doc_id], "page": page, "size": size, "question": question "doc_ids": [doc_id], "page": page, "size": size, "question": question
}, search.index_name(tenants[0].tenant_id)) }
if "available_int" in req: query["available_int"] = int(req["available_int"])
sres = retrival.search(query, search.index_name(tenant_id))
res = {"total": sres.total, "chunks": []}
for id in sres.ids:
d = {
"chunk_id": id,
"content_ltks": rmSpace(sres.highlight[id]) if question else sres.field[id]["content_ltks"],
"doc_id": sres.field[id]["doc_id"],
"docnm_kwd": sres.field[id]["docnm_kwd"],
"important_kwd": sres.field[id].get("important_kwd", []),
"img_id": sres.field[id].get("img_id", ""),
"available_int": sres.field[id].get("available_int", 1),
}
res["chunks"].append(d)
return get_json_result(data=res) return get_json_result(data=res)
except Exception as e: except Exception as e:
if str(e).find("not_found") > 0: if str(e).find("not_found") > 0:
@ -102,6 +110,7 @@ def set():
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"]) d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
d["important_kwd"] = req["important_kwd"] d["important_kwd"] = req["important_kwd"]
d["important_tks"] = huqie.qie(" ".join(req["important_kwd"])) d["important_tks"] = huqie.qie(" ".join(req["important_kwd"]))
if "available_int" in req: d["available_int"] = req["available_int"]
try: try:
tenant_id = DocumentService.get_tenant_id(req["doc_id"]) tenant_id = DocumentService.get_tenant_id(req["doc_id"])
@ -116,10 +125,27 @@ def set():
return server_error_response(e) return server_error_response(e)
@manager.route('/switch', methods=['POST'])
@login_required
@validate_request("chunk_ids", "available_int", "doc_id")
def switch():
req = request.json
try:
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id: return get_data_error_result(retmsg="Tenant not found!")
if not ELASTICSEARCH.upsert([{"id": i, "available_int": int(req["available_int"])} for i in req["chunk_ids"]],
search.index_name(tenant_id)):
return get_data_error_result(retmsg="Index updating failure")
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@manager.route('/create', methods=['POST']) @manager.route('/create', methods=['POST'])
@login_required @login_required
@validate_request("doc_id", "content_ltks", "important_kwd") @validate_request("doc_id", "content_ltks", "important_kwd")
def set(): def create():
req = request.json req = request.json
md5 = hashlib.md5() md5 = hashlib.md5()
md5.update((req["content_ltks"] + req["doc_id"]).encode("utf-8")) md5.update((req["content_ltks"] + req["doc_id"]).encode("utf-8"))
@ -148,3 +174,64 @@ def set():
return get_json_result(data={"chunk_id": chunck_id}) return get_json_result(data={"chunk_id": chunck_id})
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/retrieval_test', methods=['POST'])
@login_required
@validate_request("kb_id", "question")
def retrieval_test():
req = request.json
page = int(req.get("page", 1))
size = int(req.get("size", 30))
question = req["question"]
kb_id = req["kb_id"]
doc_ids = req.get("doc_ids", [])
similarity_threshold = float(req.get("similarity_threshold", 0.4))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
top = int(req.get("top", 1024))
try:
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
return get_data_error_result(retmsg="Knowledgebase not found!")
embd_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.EMBEDDING.value)
sres = retrival.search({"kb_ids": [kb_id], "doc_ids": doc_ids, "size": top,
"question": question, "vector": True,
"similarity": similarity_threshold},
search.index_name(kb.tenant_id),
embd_mdl)
sim, tsim, vsim = retrival.rerank(sres, question, 1-vector_similarity_weight, vector_similarity_weight)
idx = np.argsort(sim*-1)
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
start_idx = (page-1)*size
for i in idx:
ranks["total"] += 1
if sim[i] < similarity_threshold: break
start_idx -= 1
if start_idx >= 0:continue
if len(ranks["chunks"]) == size:continue
id = sres.ids[i]
dnm = sres.field[id]["docnm_kwd"]
d = {
"chunk_id": id,
"content_ltks": sres.field[id]["content_ltks"],
"doc_id": sres.field[id]["doc_id"],
"docnm_kwd": dnm,
"kb_id": sres.field[id]["kb_id"],
"important_kwd": sres.field[id].get("important_kwd", []),
"img_id": sres.field[id].get("img_id", ""),
"similarity": sim[i],
"vector_similarity": vsim[i],
"term_similarity": tsim[i]
}
ranks["chunks"].append(d)
if dnm not in ranks["doc_aggs"]:ranks["doc_aggs"][dnm] = 0
ranks["doc_aggs"][dnm] += 1
return get_json_result(data=ranks)
except Exception as e:
if str(e).find("not_found") > 0:
return get_json_result(data=False, retmsg=f'Index not found!',
retcode=RetCode.DATA_ERROR)
return server_error_response(e)

163
api/apps/dialog_app.py Normal file
View File

@ -0,0 +1,163 @@
#
# Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import hashlib
import re
import numpy as np
from flask import request
from flask_login import login_required, current_user
from api.db.services.dialog_service import DialogService
from rag.nlp import search, huqie
from rag.utils import ELASTICSEARCH, rmSpace
from api.db import LLMType, StatusEnum
from api.db.services import duplicate_name
from api.db.services.kb_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import UserTenantService, TenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.db.services.document_service import DocumentService
from api.settings import RetCode, stat_logger
from api.utils.api_utils import get_json_result
from rag.utils.minio_conn import MINIO
from api.utils.file_utils import filename_type
@manager.route('/set', methods=['POST'])
@login_required
def set():
req = request.json
dialog_id = req.get("dialog_id")
name = req.get("name", "New Dialog")
description = req.get("description", "A helpful Dialog")
language = req.get("language", "Chinese")
llm_setting_type = req.get("llm_setting_type", "Precise")
llm_setting = req.get("llm_setting", {
"Creative": {
"temperature": 0.9,
"top_p": 0.9,
"frequency_penalty": 0.2,
"presence_penalty": 0.4,
"max_tokens": 512
},
"Precise": {
"temperature": 0.1,
"top_p": 0.3,
"frequency_penalty": 0.7,
"presence_penalty": 0.4,
"max_tokens": 215
},
"Evenly": {
"temperature": 0.5,
"top_p": 0.5,
"frequency_penalty": 0.7,
"presence_penalty": 0.4,
"max_tokens": 215
},
"Custom": {
"temperature": 0.2,
"top_p": 0.3,
"frequency_penalty": 0.6,
"presence_penalty": 0.3,
"max_tokens": 215
},
})
prompt_config = req.get("prompt_config", {
"system": """你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答。当所有知识库内容都与问题无关时,你的回答必须包括“知识库中未找到您要的答案!”这句话。回答需要考虑聊天历史。
以下是知识库
{knowledge}
以上是知识库""",
"prologue": "您好我是您的助手小樱长得可爱又善良can I help you?",
"parameters": [
{"key": "knowledge", "optional": False}
],
"empty_response": "Sorry! 知识库中未找到相关内容!"
})
if len(prompt_config["parameters"]) < 1:
return get_data_error_result(retmsg="'knowledge' should be in parameters")
for p in prompt_config["parameters"]:
if prompt_config["system"].find("{%s}"%p["key"]) < 0:
return get_data_error_result(retmsg="Parameter '{}' is not used".format(p["key"]))
try:
e, tenant = TenantService.get_by_id(current_user.id)
if not e:return get_data_error_result(retmsg="Tenant not found!")
llm_id = req.get("llm_id", tenant.llm_id)
if not dialog_id:
dia = {
"id": get_uuid(),
"tenant_id": current_user.id,
"name": name,
"description": description,
"language": language,
"llm_id": llm_id,
"llm_setting_type": llm_setting_type,
"llm_setting": llm_setting,
"prompt_config": prompt_config
}
if not DialogService.save(**dia): return get_data_error_result(retmsg="Fail to new a dialog!")
e, dia = DialogService.get_by_id(dia["id"])
if not e: return get_data_error_result(retmsg="Fail to new a dialog!")
return get_json_result(data=dia.to_json())
else:
del req["dialog_id"]
if "kb_names" in req: del req["kb_names"]
if not DialogService.update_by_id(dialog_id, req):
return get_data_error_result(retmsg="Dialog not found!")
e, dia = DialogService.get_by_id(dialog_id)
if not e: return get_data_error_result(retmsg="Fail to update a dialog!")
dia = dia.to_dict()
dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
return get_json_result(data=dia)
except Exception as e:
return server_error_response(e)
@manager.route('/get', methods=['GET'])
@login_required
def get():
dialog_id = request.args["dialog_id"]
try:
e,dia = DialogService.get_by_id(dialog_id)
if not e: return get_data_error_result(retmsg="Dialog not found!")
dia = dia.to_dict()
dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
return get_json_result(data=dia)
except Exception as e:
return server_error_response(e)
def get_kb_names(kb_ids):
ids, nms = [], []
for kid in kb_ids:
e, kb = KnowledgebaseService.get_by_id(kid)
if not e or kb.status != StatusEnum.VALID.value:continue
ids.append(kid)
nms.append(kb.name)
return ids, nms
@manager.route('/list', methods=['GET'])
@login_required
def list():
try:
diags = DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value)
diags = [d.to_dict() for d in diags]
for d in diags:
d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
return get_json_result(data=diags)
except Exception as e:
return server_error_response(e)

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,22 +16,23 @@
import base64 import base64
import pathlib import pathlib
import flask
from elasticsearch_dsl import Q from elasticsearch_dsl import Q
from flask import request from flask import request
from flask_login import login_required, current_user from flask_login import login_required, current_user
from rag.nlp import search from rag.nlp import search
from rag.utils import ELASTICSEARCH from rag.utils import ELASTICSEARCH
from web_server.db.services import duplicate_name from api.db.services import duplicate_name
from web_server.db.services.kb_service import KnowledgebaseService from api.db.services.kb_service import KnowledgebaseService
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from web_server.utils import get_uuid from api.utils import get_uuid
from web_server.db import FileType from api.db import FileType
from web_server.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from web_server.settings import RetCode from api.settings import RetCode
from web_server.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
from rag.utils.minio_conn import MINIO from rag.utils.minio_conn import MINIO
from web_server.utils.file_utils import filename_type from api.utils.file_utils import filename_type
@manager.route('/upload', methods=['POST']) @manager.route('/upload', methods=['POST'])
@ -163,21 +164,13 @@ def change_status():
if str(req["status"]) == "0": if str(req["status"]) == "0":
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]), ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
scripts=""" scripts="ctx._source.available_int=0;",
if(ctx._source.kb_id.contains('%s'))
ctx._source.kb_id.remove(
ctx._source.kb_id.indexOf('%s')
);
""" % (doc.kb_id, doc.kb_id),
idxnm=search.index_name( idxnm=search.index_name(
kb.tenant_id) kb.tenant_id)
) )
else: else:
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]), ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
scripts=""" scripts="ctx._source.available_int=1;",
if(!ctx._source.kb_id.contains('%s'))
ctx._source.kb_id.add('%s');
""" % (doc.kb_id, doc.kb_id),
idxnm=search.index_name( idxnm=search.index_name(
kb.tenant_id) kb.tenant_id)
) )
@ -195,8 +188,7 @@ def rm():
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
if not e: if not e:
return get_data_error_result(retmsg="Document not found!") return get_data_error_result(retmsg="Document not found!")
if not ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)): ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id))
return get_json_result(data=False, retmsg='Remove from ES failure"', retcode=RetCode.SERVER_ERROR)
DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0) DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0)
if not DocumentService.delete_by_id(req["doc_id"]): if not DocumentService.delete_by_id(req["doc_id"]):
@ -277,3 +269,15 @@ def change_parser():
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/image/<image_id>', methods=['GET'])
@login_required
def get_image(image_id):
try:
bkt, nm = image_id.split("-")
response = flask.make_response(MINIO.get(bkt, nm))
response.headers.set('Content-Type', 'image/JPEG')
return response
except Exception as e:
return server_error_response(e)

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,15 +16,15 @@
from flask import request from flask import request
from flask_login import login_required, current_user from flask_login import login_required, current_user
from web_server.db.services import duplicate_name from api.db.services import duplicate_name
from web_server.db.services.user_service import TenantService, UserTenantService from api.db.services.user_service import TenantService, UserTenantService
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from web_server.utils import get_uuid, get_format_time from api.utils import get_uuid, get_format_time
from web_server.db import StatusEnum, UserTenantRole from api.db import StatusEnum, UserTenantRole
from web_server.db.services.kb_service import KnowledgebaseService from api.db.services.kb_service import KnowledgebaseService
from web_server.db.db_models import Knowledgebase from api.db.db_models import Knowledgebase
from web_server.settings import stat_logger, RetCode from api.settings import stat_logger, RetCode
from web_server.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
@manager.route('/create', methods=['post']) @manager.route('/create', methods=['post'])

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,16 +16,16 @@
from flask import request from flask import request
from flask_login import login_required, current_user from flask_login import login_required, current_user
from web_server.db.services import duplicate_name from api.db.services import duplicate_name
from web_server.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
from web_server.db.services.user_service import TenantService, UserTenantService from api.db.services.user_service import TenantService, UserTenantService
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from web_server.utils import get_uuid, get_format_time from api.utils import get_uuid, get_format_time
from web_server.db import StatusEnum, UserTenantRole from api.db import StatusEnum, UserTenantRole
from web_server.db.services.kb_service import KnowledgebaseService from api.db.services.kb_service import KnowledgebaseService
from web_server.db.db_models import Knowledgebase, TenantLLM from api.db.db_models import Knowledgebase, TenantLLM
from web_server.settings import stat_logger, RetCode from api.settings import stat_logger, RetCode
from web_server.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
@manager.route('/factories', methods=['GET']) @manager.route('/factories', methods=['GET'])

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -17,15 +17,15 @@ from flask import request, session, redirect, url_for
from werkzeug.security import generate_password_hash, check_password_hash from werkzeug.security import generate_password_hash, check_password_hash
from flask_login import login_required, current_user, login_user, logout_user from flask_login import login_required, current_user, login_user, logout_user
from web_server.db.db_models import TenantLLM from api.db.db_models import TenantLLM
from web_server.db.services.llm_service import TenantLLMService from api.db.services.llm_service import TenantLLMService
from web_server.utils.api_utils import server_error_response, validate_request from api.utils.api_utils import server_error_response, validate_request
from web_server.utils import get_uuid, get_format_time, decrypt, download_img from api.utils import get_uuid, get_format_time, decrypt, download_img
from web_server.db import UserTenantRole, LLMType from api.db import UserTenantRole, LLMType
from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS from api.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
from web_server.db.services.user_service import UserService, TenantService, UserTenantService from api.db.services.user_service import UserService, TenantService, UserTenantService
from web_server.settings import stat_logger from api.settings import stat_logger
from web_server.utils.api_utils import get_json_result, cors_reponse from api.utils.api_utils import get_json_result, cors_reponse
@manager.route('/login', methods=['POST', 'GET']) @manager.route('/login', methods=['POST', 'GET'])

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -51,4 +51,11 @@ class LLMType(StrEnum):
CHAT = 'chat' CHAT = 'chat'
EMBEDDING = 'embedding' EMBEDDING = 'embedding'
SPEECH2TEXT = 'speech2text' SPEECH2TEXT = 'speech2text'
IMAGE2TEXT = 'image2text' IMAGE2TEXT = 'image2text'
class ChatStyle(StrEnum):
CREATIVE = 'Creative'
PRECISE = 'Precise'
EVENLY = 'Evenly'
CUSTOM = 'Custom'

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -29,10 +29,10 @@ from peewee import (
) )
from playhouse.pool import PooledMySQLDatabase from playhouse.pool import PooledMySQLDatabase
from web_server.db import SerializedType from api.db import SerializedType
from web_server.settings import DATABASE, stat_logger, SECRET_KEY from api.settings import DATABASE, stat_logger, SECRET_KEY
from web_server.utils.log_utils import getLogger from api.utils.log_utils import getLogger
from web_server import utils from api import utils
LOGGER = getLogger() LOGGER = getLogger()
@ -467,6 +467,8 @@ class Knowledgebase(DataBaseModel):
doc_num = IntegerField(default=0) doc_num = IntegerField(default=0)
token_num = IntegerField(default=0) token_num = IntegerField(default=0)
chunk_num = IntegerField(default=0) chunk_num = IntegerField(default=0)
similarity_threshold = FloatField(default=0.4)
vector_similarity_weight = FloatField(default=0.3)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID") parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1") status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
@ -516,19 +518,20 @@ class Dialog(DataBaseModel):
prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced") prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced")
prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好我是您的助手小樱长得可爱又善良can I help you?", prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好我是您的助手小樱长得可爱又善良can I help you?",
"parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"}) "parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
kb_ids = JSONField(null=False, default=[])
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1") status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
class Meta: class Meta:
db_table = "dialog" db_table = "dialog"
class DialogKb(DataBaseModel): # class DialogKb(DataBaseModel):
dialog_id = CharField(max_length=32, null=False, index=True) # dialog_id = CharField(max_length=32, null=False, index=True)
kb_id = CharField(max_length=32, null=False) # kb_id = CharField(max_length=32, null=False)
#
class Meta: # class Meta:
db_table = "dialog_kb" # db_table = "dialog_kb"
primary_key = CompositeKey('dialog_id', 'kb_id') # primary_key = CompositeKey('dialog_id', 'kb_id')
class Conversation(DataBaseModel): class Conversation(DataBaseModel):

View File

@ -1,5 +1,5 @@
# #
# Copyright 2021 The RAG Flow Authors. All Rights Reserved. # Copyright 2021 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -19,10 +19,10 @@ import time
from functools import wraps from functools import wraps
from shortuuid import ShortUUID from shortuuid import ShortUUID
from web_server.versions import get_rag_version from api.versions import get_rag_version
from web_server.errors.error_services import * from api.errors.error_services import *
from web_server.settings import ( from api.settings import (
GRPC_PORT, HOST, HTTP_PORT, GRPC_PORT, HOST, HTTP_PORT,
RANDOM_INSTANCE_ID, stat_logger, RANDOM_INSTANCE_ID, stat_logger,
) )

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -17,11 +17,11 @@ import operator
from functools import reduce from functools import reduce
from typing import Dict, Type, Union from typing import Dict, Type, Union
from web_server.utils import current_timestamp, timestamp_to_date from api.utils import current_timestamp, timestamp_to_date
from web_server.db.db_models import DB, DataBaseModel from api.db.db_models import DB, DataBaseModel
from web_server.db.runtime_config import RuntimeConfig from api.db.runtime_config import RuntimeConfig
from web_server.utils.log_utils import getLogger from api.utils.log_utils import getLogger
from enum import Enum from enum import Enum
@ -123,9 +123,3 @@ def query_db(model: Type[DataBaseModel], limit: int = 0, offset: int = 0,
data = data.offset(offset) data = data.offset(offset)
return list(data), count return list(data), count
class StatusEnum(Enum):
# 样本可用状态
VALID = "1"
IN_VALID = "0"

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,10 +16,10 @@
import time import time
import uuid import uuid
from web_server.db import LLMType from api.db import LLMType
from web_server.db.db_models import init_database_tables as init_web_db from api.db.db_models import init_database_tables as init_web_db
from web_server.db.services import UserService from api.db.services import UserService
from web_server.db.services.llm_service import LLMFactoriesService, LLMService from api.db.services.llm_service import LLMFactoriesService, LLMService
def init_superuser(): def init_superuser():

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -17,5 +17,5 @@
import operator import operator
import time import time
import typing import typing
from web_server.utils.log_utils import sql_logger from api.utils.log_utils import sql_logger
import peewee import peewee

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from web_server.versions import get_versions from api.versions import get_versions
from .reload_config_base import ReloadConfigBase from .reload_config_base import ReloadConfigBase

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -17,8 +17,8 @@ from datetime import datetime
import peewee import peewee
from web_server.db.db_models import DB from api.db.db_models import DB
from web_server.utils import datetime_format from api.utils import datetime_format
class CommonService: class CommonService:

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,14 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import peewee from api.db.db_models import Dialog, Conversation
from werkzeug.security import generate_password_hash, check_password_hash from api.db.services.common_service import CommonService
from web_server.db.db_models import DB, UserTenant
from web_server.db.db_models import Dialog, Conversation, DialogKb
from web_server.db.services.common_service import CommonService
from web_server.utils import get_uuid, get_format_time
from web_server.db.db_utils import StatusEnum
class DialogService(CommonService): class DialogService(CommonService):
@ -29,7 +23,3 @@ class DialogService(CommonService):
class ConversationService(CommonService): class ConversationService(CommonService):
model = Conversation model = Conversation
class DialogKbService(CommonService):
model = DialogKb

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -15,12 +15,12 @@
# #
from peewee import Expression from peewee import Expression
from web_server.db import TenantPermission, FileType from api.db import TenantPermission, FileType
from web_server.db.db_models import DB, Knowledgebase, Tenant from api.db.db_models import DB, Knowledgebase, Tenant
from web_server.db.db_models import Document from api.db.db_models import Document
from web_server.db.services.common_service import CommonService from api.db.services.common_service import CommonService
from web_server.db.services.kb_service import KnowledgebaseService from api.db.services.kb_service import KnowledgebaseService
from web_server.db.db_utils import StatusEnum from api.db import StatusEnum
class DocumentService(CommonService): class DocumentService(CommonService):

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,15 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from web_server.db import TenantPermission from api.db import TenantPermission
from web_server.db.db_models import DB, UserTenant, Tenant from api.db.db_models import DB, Tenant
from web_server.db.db_models import Knowledgebase from api.db.db_models import Knowledgebase
from web_server.db.services.common_service import CommonService from api.db.services.common_service import CommonService
from web_server.utils import get_uuid, get_format_time from api.db import StatusEnum
from web_server.db.db_utils import StatusEnum
class KnowledgebaseService(CommonService): class KnowledgebaseService(CommonService):

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,14 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import peewee from api.db.db_models import Knowledgebase, Document
from werkzeug.security import generate_password_hash, check_password_hash from api.db.services.common_service import CommonService
from web_server.db.db_models import DB, UserTenant
from web_server.db.db_models import Knowledgebase, Document
from web_server.db.services.common_service import CommonService
from web_server.utils import get_uuid, get_format_time
from web_server.db.db_utils import StatusEnum
class KnowledgebaseService(CommonService): class KnowledgebaseService(CommonService):

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,15 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from rag.llm import EmbeddingModel, CvModel from rag.llm import EmbeddingModel, CvModel
from web_server.db import LLMType from api.db import LLMType
from web_server.db.db_models import DB, UserTenant from api.db.db_models import DB, UserTenant
from web_server.db.db_models import LLMFactories, LLM, TenantLLM from api.db.db_models import LLMFactories, LLM, TenantLLM
from web_server.db.services.common_service import CommonService from api.db.services.common_service import CommonService
from web_server.db.db_utils import StatusEnum from api.db import StatusEnum
class LLMFactoriesService(CommonService): class LLMFactoriesService(CommonService):

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,12 +16,12 @@
import peewee import peewee
from werkzeug.security import generate_password_hash, check_password_hash from werkzeug.security import generate_password_hash, check_password_hash
from web_server.db import UserTenantRole from api.db import UserTenantRole
from web_server.db.db_models import DB, UserTenant from api.db.db_models import DB, UserTenant
from web_server.db.db_models import User, Tenant from api.db.db_models import User, Tenant
from web_server.db.services.common_service import CommonService from api.db.services.common_service import CommonService
from web_server.utils import get_uuid, get_format_time from api.utils import get_uuid, get_format_time
from web_server.db.db_utils import StatusEnum from api.db import StatusEnum
class UserService(CommonService): class UserService(CommonService):

View File

@ -1,4 +1,4 @@
from web_server.errors import RagFlowError from api.errors import RagFlowError
__all__ = ['ServicesError', 'ServiceNotSupported', 'ZooKeeperNotConfigured', __all__ = ['ServicesError', 'ServiceNotSupported', 'ZooKeeperNotConfigured',
'MissingZooKeeperUsernameOrPassword', 'ZooKeeperBackendError'] 'MissingZooKeeperUsernameOrPassword', 'ZooKeeperBackendError']

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,8 +1,8 @@
import importlib import importlib
from web_server.hook.common.parameters import SignatureParameters, AuthenticationParameters, \ from api.hook.common.parameters import SignatureParameters, AuthenticationParameters, \
SignatureReturn, AuthenticationReturn, PermissionReturn, ClientAuthenticationReturn, ClientAuthenticationParameters SignatureReturn, AuthenticationReturn, PermissionReturn, ClientAuthenticationReturn, ClientAuthenticationParameters
from web_server.settings import HOOK_MODULE, stat_logger,RetCode from api.settings import HOOK_MODULE, stat_logger,RetCode
class HookManager: class HookManager:

View File

@ -1,10 +1,10 @@
import requests import requests
from web_server.db.service_registry import ServiceRegistry from api.db.service_registry import ServiceRegistry
from web_server.settings import RegistryServiceName from api.settings import RegistryServiceName
from web_server.hook import HookManager from api.hook import HookManager
from web_server.hook.common.parameters import ClientAuthenticationParameters, ClientAuthenticationReturn from api.hook.common.parameters import ClientAuthenticationParameters, ClientAuthenticationReturn
from web_server.settings import HOOK_SERVER_NAME from api.settings import HOOK_SERVER_NAME
@HookManager.register_client_authentication_hook @HookManager.register_client_authentication_hook

View File

@ -1,10 +1,10 @@
import requests import requests
from web_server.db.service_registry import ServiceRegistry from api.db.service_registry import ServiceRegistry
from web_server.settings import RegistryServiceName from api.settings import RegistryServiceName
from web_server.hook import HookManager from api.hook import HookManager
from web_server.hook.common.parameters import PermissionCheckParameters, PermissionReturn from api.hook.common.parameters import PermissionCheckParameters, PermissionReturn
from web_server.settings import HOOK_SERVER_NAME from api.settings import HOOK_SERVER_NAME
@HookManager.register_permission_check_hook @HookManager.register_permission_check_hook

View File

@ -1,11 +1,11 @@
import requests import requests
from web_server.db.service_registry import ServiceRegistry from api.db.service_registry import ServiceRegistry
from web_server.settings import RegistryServiceName from api.settings import RegistryServiceName
from web_server.hook import HookManager from api.hook import HookManager
from web_server.hook.common.parameters import SignatureParameters, AuthenticationParameters, AuthenticationReturn,\ from api.hook.common.parameters import SignatureParameters, AuthenticationParameters, AuthenticationReturn,\
SignatureReturn SignatureReturn
from web_server.settings import HOOK_SERVER_NAME, PARTY_ID from api.settings import HOOK_SERVER_NAME, PARTY_ID
@HookManager.register_site_signature_hook @HookManager.register_site_signature_hook

View File

@ -1,4 +1,4 @@
from web_server.settings import RetCode from api.settings import RetCode
class ParametersBase: class ParametersBase:

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -23,17 +23,17 @@ import traceback
from werkzeug.serving import run_simple from werkzeug.serving import run_simple
from web_server.apps import app from api.apps import app
from web_server.db.runtime_config import RuntimeConfig from api.db.runtime_config import RuntimeConfig
from web_server.hook import HookManager from api.hook import HookManager
from web_server.settings import ( from api.settings import (
HOST, HTTP_PORT, access_logger, database_logger, stat_logger, HOST, HTTP_PORT, access_logger, database_logger, stat_logger,
) )
from web_server import utils from api import utils
from web_server.db.db_models import init_database_tables as init_web_db from api.db.db_models import init_database_tables as init_web_db
from web_server.db.init_data import init_web_data from api.db.init_data import init_web_data
from web_server.versions import get_versions from api.versions import get_versions
if __name__ == '__main__': if __name__ == '__main__':
stat_logger.info( stat_logger.info(

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -17,9 +17,9 @@ import os
from enum import IntEnum, Enum from enum import IntEnum, Enum
from web_server.utils import get_base_config,decrypt_database_config from api.utils import get_base_config,decrypt_database_config
from web_server.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
from web_server.utils.log_utils import LoggerFactory, getLogger from api.utils.log_utils import LoggerFactory, getLogger
# Server # Server
@ -71,7 +71,7 @@ PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol")
DATABASE = decrypt_database_config() DATABASE = decrypt_database_config()
# Logger # Logger
LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "web_server")) LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "api"))
# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0} # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
LoggerFactory.LEVEL = 10 LoggerFactory.LEVEL = 10

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -24,16 +24,16 @@ from flask import (
) )
from werkzeug.http import HTTP_STATUS_CODES from werkzeug.http import HTTP_STATUS_CODES
from web_server.utils import json_dumps from api.utils import json_dumps
from web_server.versions import get_rag_version from api.versions import get_rag_version
from web_server.settings import RetCode from api.settings import RetCode
from web_server.settings import ( from api.settings import (
REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC, REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
stat_logger,CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY stat_logger,CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
) )
import requests import requests
import functools import functools
from web_server.utils import CustomJSONEncoder from api.utils import CustomJSONEncoder
from uuid import uuid1 from uuid import uuid1
from base64 import b64encode from base64 import b64encode
from hmac import HMAC from hmac import HMAC

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -21,7 +21,7 @@ import re
from cachetools import LRUCache, cached from cachetools import LRUCache, cached
from ruamel.yaml import YAML from ruamel.yaml import YAML
from web_server.db import FileType from api.db import FileType
PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE") PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
RAG_BASE = os.getenv("RAG_BASE") RAG_BASE = os.getenv("RAG_BASE")

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -21,7 +21,7 @@ import inspect
from logging.handlers import TimedRotatingFileHandler from logging.handlers import TimedRotatingFileHandler
from threading import RLock from threading import RLock
from web_server.utils import file_utils from api.utils import file_utils
class LoggerFactory(object): class LoggerFactory(object):
TYPE = "FILE" TYPE = "FILE"

View File

@ -1,7 +1,7 @@
import base64, os, sys import base64, os, sys
from Cryptodome.PublicKey import RSA from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5 from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
from web_server.utils import decrypt, file_utils from api.utils import decrypt, file_utils
def crypt(line): def crypt(line):
file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem") file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem")

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -18,7 +18,7 @@ import os
import dotenv import dotenv
import typing import typing
from web_server.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
def get_versions() -> typing.Mapping[str, typing.Any]: def get_versions() -> typing.Mapping[str, typing.Any]:

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -60,6 +60,10 @@ class HuEmbedding(Base):
res.extend(self.model.encode(texts[i:i + batch_size]).tolist()) res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
return np.array(res), token_count return np.array(res), token_count
def encode_queries(self, text: str):
token_count = num_tokens_from_string(text)
return self.model.encode_queries([text]).tolist()[0], token_count
class OpenAIEmbed(Base): class OpenAIEmbed(Base):
def __init__(self, key, model_name="text-embedding-ada-002"): def __init__(self, key, model_name="text-embedding-ada-002"):

View File

@ -9,7 +9,7 @@ import string
import sys import sys
from hanziconv import HanziConv from hanziconv import HanziConv
from web_server.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
class Huqie: class Huqie:

View File

@ -147,7 +147,7 @@ class EsQueryer:
atks = toDict(atks) atks = toDict(atks)
btkss = [toDict(tks) for tks in btkss] btkss = [toDict(tks) for tks in btkss]
tksim = [self.similarity(atks, btks) for btks in btkss] tksim = [self.similarity(atks, btks) for btks in btkss]
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, sims[0], tksim
def similarity(self, qtwt, dtwt): def similarity(self, qtwt, dtwt):
if isinstance(dtwt, type("")): if isinstance(dtwt, type("")):

View File

@ -15,7 +15,7 @@ def index_name(uid): return f"ragflow_{uid}"
class Dealer: class Dealer:
def __init__(self, es, emb_mdl): def __init__(self, es):
self.qryr = query.EsQueryer(es) self.qryr = query.EsQueryer(es)
self.qryr.flds = [ self.qryr.flds = [
"title_tks^10", "title_tks^10",
@ -23,7 +23,6 @@ class Dealer:
"content_ltks^2", "content_ltks^2",
"content_sm_ltks"] "content_sm_ltks"]
self.es = es self.es = es
self.emb_mdl = emb_mdl
@dataclass @dataclass
class SearchResult: class SearchResult:
@ -36,23 +35,26 @@ class Dealer:
keywords: Optional[List[str]] = None keywords: Optional[List[str]] = None
group_docs: List[List] = None group_docs: List[List] = None
def _vector(self, txt, sim=0.8, topk=10): def _vector(self, txt, emb_mdl, sim=0.8, topk=10):
qv = self.emb_mdl.encode_queries(txt) qv, c = emb_mdl.encode_queries(txt)
return { return {
"field": "q_%d_vec"%len(qv), "field": "q_%d_vec"%len(qv),
"k": topk, "k": topk,
"similarity": sim, "similarity": sim,
"num_candidates": 1000, "num_candidates": topk*2,
"query_vector": qv "query_vector": qv
} }
def search(self, req, idxnm, tks_num=3): def search(self, req, idxnm, emb_mdl=None):
qst = req.get("question", "") qst = req.get("question", "")
bqry, keywords = self.qryr.question(qst) bqry, keywords = self.qryr.question(qst)
if req.get("kb_ids"): if req.get("kb_ids"):
bqry.filter.append(Q("terms", kb_id=req["kb_ids"])) bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
if req.get("doc_ids"): if req.get("doc_ids"):
bqry.filter.append(Q("terms", doc_id=req["doc_ids"])) bqry.filter.append(Q("terms", doc_id=req["doc_ids"]))
if "available_int" in req:
if req["available_int"] == 0: bqry.filter.append(Q("range", available_int={"lt": 1}))
else: bqry.filter.append(Q("bool", must_not=Q("range", available_int={"lt": 1})))
bqry.boost = 0.05 bqry.boost = 0.05
s = Search() s = Search()
@ -60,7 +62,7 @@ class Dealer:
ps = int(req.get("size", 1000)) ps = int(req.get("size", 1000))
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id","img_id", src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id","img_id",
"image_id", "doc_id", "q_512_vec", "q_768_vec", "image_id", "doc_id", "q_512_vec", "q_768_vec",
"q_1024_vec", "q_1536_vec"]) "q_1024_vec", "q_1536_vec", "available_int"])
s = s.query(bqry)[pg * ps:(pg + 1) * ps] s = s.query(bqry)[pg * ps:(pg + 1) * ps]
s = s.highlight("content_ltks") s = s.highlight("content_ltks")
@ -80,7 +82,8 @@ class Dealer:
s = s.to_dict() s = s.to_dict()
q_vec = [] q_vec = []
if req.get("vector"): if req.get("vector"):
s["knn"] = self._vector(qst, req.get("similarity", 0.4), ps) assert emb_mdl, "No embedding model selected"
s["knn"] = self._vector(qst, emb_mdl, req.get("similarity", 0.4), ps)
s["knn"]["filter"] = bqry.to_dict() s["knn"]["filter"] = bqry.to_dict()
if "highlight" in s: del s["highlight"] if "highlight" in s: del s["highlight"]
q_vec = s["knn"]["query_vector"] q_vec = s["knn"]["query_vector"]
@ -168,7 +171,7 @@ class Dealer:
def trans2floats(txt): def trans2floats(txt):
return [float(t) for t in txt.split("\t")] return [float(t) for t in txt.split("\t")]
def insert_citations(self, ans, top_idx, sres, def insert_citations(self, ans, top_idx, sres, emb_mdl,
vfield="q_vec", cfield="content_ltks"): vfield="q_vec", cfield="content_ltks"):
ins_embd = [Dealer.trans2floats( ins_embd = [Dealer.trans2floats(
@ -179,15 +182,14 @@ class Dealer:
res = "" res = ""
def citeit(): def citeit():
nonlocal s, e, ans, res nonlocal s, e, ans, res, emb_mdl
if not ins_embd: if not ins_embd:
return return
embd = self.emb_mdl.encode(ans[s: e]) embd = emb_mdl.encode(ans[s: e])
sim = self.qryr.hybrid_similarity(embd, sim = self.qryr.hybrid_similarity(embd,
ins_embd, ins_embd,
huqie.qie(ans[s:e]).split(" "), huqie.qie(ans[s:e]).split(" "),
ins_tw) ins_tw)
print(ans[s: e], sim)
mx = np.max(sim) * 0.99 mx = np.max(sim) * 0.99
if mx < 0.55: if mx < 0.55:
return return
@ -225,20 +227,18 @@ class Dealer:
return res return res
def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks"):
vfield="q_vec", cfield="content_ltks"):
ins_embd = [ ins_embd = [
Dealer.trans2floats( Dealer.trans2floats(
sres.field[i]["q_vec"]) for i in sres.ids] sres.field[i]["q_%d_vec"%len(sres.query_vector)]) for i in sres.ids]
if not ins_embd: if not ins_embd:
return [] return []
ins_tw = [sres.field[i][cfield].split(" ") for i in sres.ids] ins_tw = [sres.field[i][cfield].split(" ") for i in sres.ids]
# return CosineSimilarity([sres.query_vector], ins_embd)[0] sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
sim = self.qryr.hybrid_similarity(sres.query_vector,
ins_embd, ins_embd,
huqie.qie(query).split(" "), huqie.qie(query).split(" "),
ins_tw, tkweight, vtweight) ins_tw, tkweight, vtweight)
return sim return sim, tksim, vtsim

View File

@ -4,7 +4,7 @@ import time
import logging import logging
import re import re
from web_server.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
class Dealer: class Dealer:

View File

@ -5,7 +5,7 @@ import re
import os import os
import numpy as np import numpy as np
from rag.nlp import huqie from rag.nlp import huqie
from web_server.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
class Dealer: class Dealer:

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,9 +14,9 @@
# limitations under the License. # limitations under the License.
# #
import os import os
from web_server.utils import get_base_config,decrypt_database_config from api.utils import get_base_config,decrypt_database_config
from web_server.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
from web_server.utils.log_utils import LoggerFactory, getLogger from api.utils.log_utils import LoggerFactory, getLogger
# Server # Server

View File

@ -1,5 +1,5 @@
# #
# Copyright 2019 The RAG Flow Authors. All Rights Reserved. # Copyright 2019 The InfiniFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -47,12 +47,12 @@ from rag.nlp.huchunk import (
PptChunker, PptChunker,
TextChunker TextChunker
) )
from web_server.db import LLMType from api.db import LLMType
from web_server.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from web_server.db.services.llm_service import TenantLLMService from api.db.services.llm_service import TenantLLMService
from web_server.settings import database_logger from api.settings import database_logger
from web_server.utils import get_format_time from api.utils import get_format_time
from web_server.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
BATCH_SIZE = 64 BATCH_SIZE = 64
@ -257,7 +257,6 @@ def main(comm, mod):
cron_logger.error(str(e)) cron_logger.error(str(e))
continue continue
set_progress(r["id"], random.randint(70, 95) / 100., set_progress(r["id"], random.randint(70, 95) / 100.,
"Finished embedding! Start to build index!") "Finished embedding! Start to build index!")
init_kb(r) init_kb(r)

View File

@ -66,7 +66,6 @@ class HuEs:
body=d, body=d,
id=id, id=id,
refresh=False, refresh=False,
doc_type="_doc",
retry_on_conflict=100) retry_on_conflict=100)
es_logger.info("Successfully upsert: %s" % id) es_logger.info("Successfully upsert: %s" % id)
T = True T = True