mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-04-21 05:29:57 +08:00

### What problem does this PR solve? Feat: Add user registration toggle feature. Added a user registration toggle, `REGISTER_ENABLED`, to the settings and the .env config file. The user creation interface now checks the state of this toggle to enable or disable the user registration feature. The front-end implementation is done: the registration button does not appear if registration is not allowed. I tested this on my local server and it worked smoothly. ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: wenju.li <wenju.li@deepctr.cn> Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
180 lines
6.1 KiB
Python
180 lines
6.1 KiB
Python
#
|
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
import os
|
|
from datetime import date
|
|
from enum import IntEnum, Enum
|
|
import json
|
|
import rag.utils.es_conn
|
|
import rag.utils.infinity_conn
|
|
|
|
import rag.utils
|
|
from rag.nlp import search
|
|
from graphrag import search as kg_search
|
|
from api.utils import get_base_config, decrypt_database_config
|
|
from api.constants import RAG_FLOW_SERVICE_NAME
|
|
from api.utils.file_utils import get_project_base_directory
|
|
|
|
# Integer flag taken from the LIGHTEN environment variable (0 when unset).
LIGHTEN = int(os.getenv("LIGHTEN", "0"))

# LLM-related settings; these are placeholders that init_settings() overwrites.
LLM = LLM_FACTORY = LLM_BASE_URL = None
CHAT_MDL = EMBEDDING_MDL = RERANK_MDL = ASR_MDL = IMAGE2TEXT_MDL = ""
API_KEY = PARSERS = None

# Service host/port, secret key and the LLM factory catalogue (set by init_settings()).
HOST_IP = HOST_PORT = None
SECRET_KEY = None
FACTORY_LLM_INFOS = None

# Database backend name (DB_TYPE environment variable, "mysql" when unset)
# and the configuration returned for it by decrypt_database_config().
DATABASE_TYPE = os.environ.get("DB_TYPE", "mysql")
DATABASE = decrypt_database_config(name=DATABASE_TYPE)

# authentication
AUTHENTICATION_CONF = None

# client
CLIENT_AUTHENTICATION = HTTP_APP_KEY = None
GITHUB_OAUTH = FEISHU_OAUTH = None

# Document engine name and its connection object (set by init_settings()).
DOC_ENGINE = None
docStoreConn = None

# Retrieval helpers built on top of docStoreConn (set by init_settings()).
retrievaler = kg_retrievaler = None

# user registration switch
REGISTER_ENABLED = 1
|
|
|
|
|
|
def _qualify_with_factory(model_name, factory):
    """Append ``@<factory>`` to a model name that does not name a factory yet.

    Empty names and names that already contain ``@`` are returned unchanged.
    """
    if "@" not in model_name and model_name != "":
        return f"{model_name}@{factory}"
    return model_name


def init_settings():
    """(Re)load the module-level settings from the environment and config files.

    Reads the LIGHTEN, DB_TYPE, REGISTER_ENABLED and DOC_ENGINE environment
    variables, the "user_default_llm", service, "authentication" and "oauth"
    sections returned by get_base_config(), and conf/llm_factories.json under
    the project base directory. The results are stored in this module's
    globals; nothing is returned.

    Best-effort parts:
        * An unparsable REGISTER_ENABLED value is logged and the previous
          value of the global is kept.
        * Any failure to load conf/llm_factories.json is logged and
          FACTORY_LLM_INFOS falls back to an empty list.

    Raises:
        Exception: if DOC_ENGINE is neither "elasticsearch" nor "infinity"
            (compared case-insensitively).
    """
    # Function-scope import: this module does not import logging at file level.
    import logging

    global LLM, LLM_FACTORY, LLM_BASE_URL, LIGHTEN, DATABASE_TYPE, DATABASE, FACTORY_LLM_INFOS, REGISTER_ENABLED
    global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
    global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
    global AUTHENTICATION_CONF, CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH
    global DOC_ENGINE, docStoreConn, retrievaler, kg_retrievaler

    LIGHTEN = int(os.environ.get('LIGHTEN', "0"))
    DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
    DATABASE = decrypt_database_config(name=DATABASE_TYPE)

    LLM = get_base_config("user_default_llm", {})
    llm_default_models = LLM.get("default_models", {})
    LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
    LLM_BASE_URL = LLM.get("base_url")

    # User registration switch. A malformed value must not abort start-up, but
    # it should not be ignored silently either.
    raw_register_enabled = os.environ.get("REGISTER_ENABLED", "1")
    try:
        REGISTER_ENABLED = int(raw_register_enabled)
    except ValueError:
        logging.warning(
            "Invalid REGISTER_ENABLED value %r; keeping %r",
            raw_register_enabled, REGISTER_ENABLED)

    # Catalogue of LLM factories. The whole lookup stays inside the try block
    # so that every failure mode keeps the original "fall back to []" outcome.
    try:
        factories_path = os.path.join(get_project_base_directory(), "conf", "llm_factories.json")
        # Explicit encoding: do not depend on the platform's locale encoding.
        with open(factories_path, "r", encoding="utf-8") as f:
            FACTORY_LLM_INFOS = json.load(f)["factory_llm_infos"]
    except Exception as e:
        logging.warning("Failed to load conf/llm_factories.json, using an empty factory list: %s", e)
        FACTORY_LLM_INFOS = []

    if not LIGHTEN:
        EMBEDDING_MDL = "BAAI/bge-large-zh-v1.5@BAAI"

    if llm_default_models:
        CHAT_MDL = llm_default_models.get("chat_model", CHAT_MDL)
        EMBEDDING_MDL = llm_default_models.get("embedding_model", EMBEDDING_MDL)
        RERANK_MDL = llm_default_models.get("rerank_model", RERANK_MDL)
        ASR_MDL = llm_default_models.get("asr_model", ASR_MDL)
        IMAGE2TEXT_MDL = llm_default_models.get("image2text_model", IMAGE2TEXT_MDL)

    # factory can be specified in the config name with "@". LLM_FACTORY will be used if not specified
    CHAT_MDL = _qualify_with_factory(CHAT_MDL, LLM_FACTORY)
    EMBEDDING_MDL = _qualify_with_factory(EMBEDDING_MDL, LLM_FACTORY)
    RERANK_MDL = _qualify_with_factory(RERANK_MDL, LLM_FACTORY)
    ASR_MDL = _qualify_with_factory(ASR_MDL, LLM_FACTORY)
    IMAGE2TEXT_MDL = _qualify_with_factory(IMAGE2TEXT_MDL, LLM_FACTORY)

    API_KEY = LLM.get("api_key", "")
    PARSERS = LLM.get(
        "parsers",
        "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email,tag:Tag")

    # Read the service section once instead of once per key.
    service_conf = get_base_config(RAG_FLOW_SERVICE_NAME, {})
    HOST_IP = service_conf.get("host", "127.0.0.1")
    HOST_PORT = service_conf.get("http_port")
    # NOTE(review): when no "secret_key" is configured the key defaults to
    # today's date, which is predictable. Left unchanged here because other
    # code may rely on it; deployments should configure an explicit secret.
    SECRET_KEY = service_conf.get("secret_key", str(date.today()))

    # authentication
    AUTHENTICATION_CONF = get_base_config("authentication", {})

    # client
    client_conf = AUTHENTICATION_CONF.get("client", {})
    CLIENT_AUTHENTICATION = client_conf.get("switch", False)
    HTTP_APP_KEY = client_conf.get("http_app_key")

    oauth_conf = get_base_config("oauth", {})
    GITHUB_OAUTH = oauth_conf.get("github")
    FEISHU_OAUTH = oauth_conf.get("feishu")

    # Pick the document store backend; the engine name is matched case-insensitively.
    DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch")
    lower_case_doc_engine = DOC_ENGINE.lower()
    if lower_case_doc_engine == "elasticsearch":
        docStoreConn = rag.utils.es_conn.ESConnection()
    elif lower_case_doc_engine == "infinity":
        docStoreConn = rag.utils.infinity_conn.InfinityConnection()
    else:
        raise Exception(f"Not supported doc engine: {DOC_ENGINE}")

    retrievaler = search.Dealer(docStoreConn)
    kg_retrievaler = kg_search.KGSearch(docStoreConn)
|
|
|
|
|
|
class CustomEnum(Enum):
    """Enum base class with helpers for validating and listing members."""

    @classmethod
    def valid(cls, value):
        """Return True if ``value`` is the value of a member of this enum.

        The check is performed by attempting the normal ``cls(value)`` lookup.
        Any ordinary exception raised by that lookup means "not valid".
        Interpreter-level signals (``KeyboardInterrupt``, ``SystemExit``, ...)
        are deliberately not caught, so they are never misreported as an
        invalid value.
        """
        try:
            cls(value)
        except Exception:
            return False
        return True

    @classmethod
    def values(cls):
        """Return the values of all entries in ``__members__``, in definition order.

        ``__members__`` includes aliases, so a value shared by several names
        appears once per name.
        """
        return [member.value for member in cls.__members__.values()]

    @classmethod
    def names(cls):
        """Return the member names for all entries in ``__members__``, in definition order.

        For an alias entry the name reported is that of the member it refers to.
        """
        return [member.name for member in cls.__members__.values()]
|
|
|
|
|
|
class RetCode(IntEnum, CustomEnum):
    """Integer return codes.

    Members are ints (via IntEnum) and also inherit the helper classmethods
    of CustomEnum. Definition order is significant for iteration, so it is
    kept exactly as declared.
    """

    # Non-error outcomes.
    SUCCESS = 0
    NOT_EFFECTIVE = 10

    # Codes in the 100 range.
    EXCEPTION_ERROR = 100
    ARGUMENT_ERROR = 101
    DATA_ERROR = 102
    OPERATING_ERROR = 103
    CONNECTION_ERROR = 105
    RUNNING = 106
    PERMISSION_ERROR = 108
    AUTHENTICATION_ERROR = 109

    # These values are numerically equal to the HTTP status codes with the
    # corresponding meaning (401, 500, 403, 404).
    UNAUTHORIZED = 401
    SERVER_ERROR = 500
    FORBIDDEN = 403
    NOT_FOUND = 404
|