mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-11 01:38:59 +08:00
Common: Support PostgreSQL database as the metadata db. (#2357)
https://github.com/infiniflow/ragflow/issues/2356 ### What problem does this PR solve? As title ### Type of change - [X] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
ba834aee26
commit
f8e9a0590f
@ -18,18 +18,19 @@ import os
|
||||
import sys
|
||||
import typing
|
||||
import operator
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
from flask_login import UserMixin
|
||||
from playhouse.migrate import MySQLMigrator, migrate
|
||||
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
|
||||
from peewee import (
|
||||
BigIntegerField, BooleanField, CharField,
|
||||
CompositeKey, IntegerField, TextField, FloatField, DateTimeField,
|
||||
Field, Model, Metadata
|
||||
)
|
||||
from playhouse.pool import PooledMySQLDatabase
|
||||
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
|
||||
from api.db import SerializedType, ParserType
|
||||
from api.settings import DATABASE, stat_logger, SECRET_KEY
|
||||
from api.settings import DATABASE, stat_logger, SECRET_KEY, DATABASE_TYPE
|
||||
from api.utils.log_utils import getLogger
|
||||
from api import utils
|
||||
|
||||
@ -58,8 +59,13 @@ AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {
|
||||
"write_access"}
|
||||
|
||||
|
||||
class TextFieldType(Enum):
    """Column type name used for long-text fields, keyed by database backend.

    Members are looked up by the upper-cased backend name, e.g.
    ``TextFieldType["postgres".upper()].value``.
    """

    # MySQL backend: long-text columns are declared as LONGTEXT.
    MYSQL = 'LONGTEXT'
    # PostgreSQL backend: long-text columns are declared as TEXT.
    POSTGRES = 'TEXT'
|
||||
|
||||
|
||||
class LongTextField(TextField):
    """Text field whose column type follows the configured database backend."""

    # Resolved once, at class-definition time, from DATABASE_TYPE
    # ('LONGTEXT' for MYSQL, 'TEXT' for POSTGRES per TextFieldType).
    field_type = TextFieldType[DATABASE_TYPE.upper()].value
|
||||
|
||||
|
||||
class JSONField(LongTextField):
|
||||
@ -266,18 +272,69 @@ class JsonSerializedField(SerializedField):
|
||||
super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
|
||||
object_pairs_hook=object_pairs_hook, **kwargs)
|
||||
|
||||
class PooledDatabase(Enum):
    """Pooled peewee database class to instantiate for each backend.

    Looked up by the upper-cased DATABASE_TYPE; ``.value`` is the class itself,
    which the caller then instantiates.
    """

    MYSQL = PooledMySQLDatabase
    POSTGRES = PooledPostgresqlDatabase
|
||||
|
||||
|
||||
class DatabaseMigrator(Enum):
    """playhouse.migrate migrator class to use for each backend.

    Looked up by the upper-cased DATABASE_TYPE; ``.value`` is the migrator
    class, which the caller instantiates with the database object.
    """

    MYSQL = MySQLMigrator
    POSTGRES = PostgresqlMigrator
|
||||
|
||||
|
||||
@singleton
class BaseDataBase:
    """Holds the process-wide pooled connection to the metadata database.

    The backend class is selected from ``PooledDatabase`` by the upper-cased
    ``DATABASE_TYPE``; an unknown type raises ``KeyError`` from that lookup.
    The resulting database object is exposed as ``database_connection``.
    """

    def __init__(self):
        # Work on a copy so popping "name" does not mutate the shared settings.
        database_config = DATABASE.copy()
        db_name = database_config.pop("name")
        # Build the pool exactly once, for the configured backend only; every
        # remaining config key is forwarded to the pooled database class.
        pooled_db_cls = PooledDatabase[DATABASE_TYPE.upper()].value
        self.database_connection = pooled_db_cls(db_name, **database_config)
        stat_logger.info('init database on cluster mode successfully')
|
||||
|
||||
class PostgresDatabaseLock:
    """Named lock built on PostgreSQL session-level advisory locks.

    Usable as a context manager (``with PostgresDatabaseLock("name"):``) or as
    a decorator (``@PostgresDatabaseLock("name")``).

    Args:
        lock_name: Human-readable lock name. It is hashed into the signed
            64-bit key that the advisory-lock functions take.
        timeout: Maximum number of seconds ``lock()`` keeps retrying before
            giving up; ``0`` means a single attempt.
        db: Object exposing ``execute_sql(sql, params)``; defaults to the
            module-level ``DB``.

    NOTE(review): an advisory lock belongs to the connection that acquired it,
    so ``lock()`` and ``unlock()`` must run on the same connection. This is
    presumably guaranteed by the pooled database's per-thread connection --
    confirm against the connection-closing code paths.
    """

    # Seconds to wait between attempts while another session holds the lock.
    _RETRY_INTERVAL = 0.1

    def __init__(self, lock_name, timeout=10, db=None):
        import hashlib  # local import: the file-level import block is not changed here

        self.lock_name = lock_name
        self.timeout = int(timeout)
        self.db = db if db else DB
        # Advisory locks are keyed by a bigint, not by a string, so derive a
        # stable signed 64-bit key from the name. (Python's built-in hash() is
        # salted per process and would give a different key in every worker.)
        digest = hashlib.sha256(str(lock_name).encode("utf-8")).digest()
        self.lock_id = int.from_bytes(digest[:8], byteorder="big", signed=True)

    def lock(self):
        """Acquire the lock, retrying until ``timeout`` seconds have elapsed.

        Returns:
            True once the lock is held.

        Raises:
            Exception: if the lock is still held elsewhere when the timeout
                expires, or if the server returns an unexpected result.
        """
        import time

        deadline = time.monotonic() + self.timeout
        while True:
            # pg_try_advisory_lock never blocks; it reports success as a boolean.
            cursor = self.db.execute_sql("SELECT pg_try_advisory_lock(%s)", (self.lock_id,))
            ret = cursor.fetchone()
            if ret[0] == 1:
                return True
            if ret[0] != 0:
                raise Exception(f'failed to acquire lock {self.lock_name}')
            if time.monotonic() >= deadline:
                raise Exception(f'acquire postgres lock {self.lock_name} timeout')
            time.sleep(self._RETRY_INTERVAL)

    def unlock(self):
        """Release the lock.

        Returns:
            True when the lock was released.

        Raises:
            Exception: if this session did not hold the lock, or if the server
                returns an unexpected result.
        """
        cursor = self.db.execute_sql("SELECT pg_advisory_unlock(%s)", (self.lock_id,))
        ret = cursor.fetchone()
        if ret[0] == 0:
            raise Exception(
                f'postgres lock {self.lock_name} was not established by this thread')
        elif ret[0] == 1:
            return True
        else:
            raise Exception(f'postgres lock {self.lock_name} does not exist')

    def __enter__(self):
        # Always lock: this class is only selected for the PostgreSQL backend,
        # so no further type check on self.db is needed.
        self.lock()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.unlock()

    def __call__(self, func):
        """Decorate ``func`` so that each call runs while holding the lock."""
        @wraps(func)
        def magic(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return magic
|
||||
|
||||
class MysqlDatabaseLock:
|
||||
def __init__(self, lock_name, timeout=10, db=None):
|
||||
self.lock_name = lock_name
|
||||
self.timeout = int(timeout)
|
||||
@ -325,8 +382,13 @@ class DatabaseLock:
|
||||
return magic
|
||||
|
||||
|
||||
class DatabaseLock(Enum):
    """Lock implementation to use for each database backend.

    Looked up by the upper-cased DATABASE_TYPE; ``.value`` is the lock class.
    """

    MYSQL = MysqlDatabaseLock
    POSTGRES = PostgresDatabaseLock
|
||||
|
||||
|
||||
DB = BaseDataBase().database_connection
|
||||
DB.lock = DatabaseLock
|
||||
DB.lock = DatabaseLock[DATABASE_TYPE.upper()].value
|
||||
|
||||
|
||||
def close_connection():
|
||||
@ -918,7 +980,7 @@ class CanvasTemplate(DataBaseModel):
|
||||
|
||||
def migrate_db():
|
||||
with DB.transaction():
|
||||
migrator = MySQLMigrator(DB)
|
||||
migrator = DatabaseMigrator[DATABASE_TYPE.upper()].value(DB)
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",
|
||||
|
@ -17,6 +17,8 @@ import operator
|
||||
from functools import reduce
|
||||
from typing import Dict, Type, Union
|
||||
|
||||
from playhouse.pool import PooledMySQLDatabase
|
||||
|
||||
from api.utils import current_timestamp, timestamp_to_date
|
||||
|
||||
from api.db.db_models import DB, DataBaseModel
|
||||
@ -49,7 +51,10 @@ def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
|
||||
with DB.atomic():
|
||||
query = model.insert_many(data_source[i:i + batch_size])
|
||||
if replace_on_conflict:
|
||||
query = query.on_conflict(preserve=preserve)
|
||||
if isinstance(DB, PooledMySQLDatabase):
|
||||
query = query.on_conflict(preserve=preserve)
|
||||
else:
|
||||
query = query.on_conflict(conflict_target="id", preserve=preserve)
|
||||
query.execute()
|
||||
|
||||
|
||||
|
@ -164,7 +164,8 @@ RANDOM_INSTANCE_ID = get_base_config(
|
||||
PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy")
|
||||
PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol")
|
||||
|
||||
DATABASE = decrypt_database_config(name="mysql")
|
||||
DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
|
||||
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
|
||||
|
||||
# Switch
|
||||
# upload
|
||||
|
@ -9,6 +9,14 @@ mysql:
|
||||
port: 3306
|
||||
max_connections: 100
|
||||
stale_timeout: 30
|
||||
postgres:
|
||||
name: 'rag_flow'
|
||||
user: 'rag_flow'
|
||||
password: 'infini_rag_flow'
|
||||
host: 'postgres'
|
||||
port: 5432
|
||||
max_connections: 100
|
||||
stale_timeout: 30
|
||||
minio:
|
||||
user: 'rag_flow'
|
||||
password: 'infini_rag_flow'
|
||||
|
@ -31,6 +31,7 @@ Flask==3.0.3
|
||||
Flask_Cors==5.0.0
|
||||
Flask_Login==0.6.3
|
||||
flask_session==0.8.0
|
||||
psycopg2==2.9.9
|
||||
google_search_results==2.4.2
|
||||
groq==0.9.0
|
||||
hanziconv==0.3.2
|
||||
|
Loading…
x
Reference in New Issue
Block a user