Common: Support PostgreSQL database as the metadata db. (#2357)

https://github.com/infiniflow/ragflow/issues/2356

### What problem does this PR solve?

Adds support for PostgreSQL as the metadata database, selectable via the `DB_TYPE` environment variable (default: `mysql`). See the linked issue #2356 above.

### Type of change

- [X] New Feature (non-breaking change which adds functionality)
Author: Fachuan Bai · 2024-09-12 15:12:39 +08:00 · committed by GitHub
Commit: f8e9a0590f (parent: ba834aee26)
5 changed files with 89 additions and 12 deletions

#### api/db/db_models.py

```diff
@@ -18,18 +18,19 @@ import os
 import sys
 import typing
 import operator
+from enum import Enum
 from functools import wraps
 from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
 from flask_login import UserMixin
-from playhouse.migrate import MySQLMigrator, migrate
+from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
 from peewee import (
     BigIntegerField, BooleanField, CharField,
     CompositeKey, IntegerField, TextField, FloatField, DateTimeField,
     Field, Model, Metadata
 )
-from playhouse.pool import PooledMySQLDatabase
+from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
 from api.db import SerializedType, ParserType
-from api.settings import DATABASE, stat_logger, SECRET_KEY
+from api.settings import DATABASE, stat_logger, SECRET_KEY, DATABASE_TYPE
 from api.utils.log_utils import getLogger
 from api import utils
```

```diff
@@ -58,8 +59,13 @@ AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {
     "write_access"}


+class TextFieldType(Enum):
+    MYSQL = 'LONGTEXT'
+    POSTGRES = 'TEXT'
+
+
 class LongTextField(TextField):
-    field_type = 'LONGTEXT'
+    field_type = TextFieldType[DATABASE_TYPE.upper()].value


 class JSONField(LongTextField):
```

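The enum exists because the two backends need different column types for large text: MySQL's plain `TEXT` caps out at 64 KB, so RAGFlow uses `LONGTEXT` there, while PostgreSQL's `TEXT` is already unbounded. A standalone sketch of the same dispatch pattern (the name `DemoTextField` is illustrative, not from the commit):

```python
import os
from enum import Enum

from peewee import TextField

DB_TYPE = os.getenv("DB_TYPE", "mysql")  # same default the commit uses


class TextFieldType(Enum):
    MYSQL = 'LONGTEXT'   # plain TEXT caps at 64 KB on MySQL
    POSTGRES = 'TEXT'    # unbounded on PostgreSQL


class DemoTextField(TextField):
    # peewee emits field_type as the column type in CREATE TABLE DDL
    field_type = TextFieldType[DB_TYPE.upper()].value


print(DemoTextField.field_type)  # LONGTEXT or TEXT
```
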
```diff
@@ -266,18 +272,69 @@ class JsonSerializedField(SerializedField):
         super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
                                                   object_pairs_hook=object_pairs_hook, **kwargs)


+class PooledDatabase(Enum):
+    MYSQL = PooledMySQLDatabase
+    POSTGRES = PooledPostgresqlDatabase
+
+
+class DatabaseMigrator(Enum):
+    MYSQL = MySQLMigrator
+    POSTGRES = PostgresqlMigrator
+
+
 @singleton
 class BaseDataBase:
     def __init__(self):
         database_config = DATABASE.copy()
         db_name = database_config.pop("name")
-        self.database_connection = PooledMySQLDatabase(
-            db_name, **database_config)
-        stat_logger.info('init mysql database on cluster mode successfully')
+        self.database_connection = PooledDatabase[DATABASE_TYPE.upper()].value(db_name, **database_config)
+        stat_logger.info('init database on cluster mode successfully')


+class PostgresDatabaseLock:
+    def __init__(self, lock_name, timeout=10, db=None):
+        self.lock_name = lock_name
+        self.timeout = int(timeout)
+        self.db = db if db else DB
-class DatabaseLock:
+
+    def lock(self):
+        cursor = self.db.execute_sql("SELECT pg_try_advisory_lock(%s)", self.timeout)
+        ret = cursor.fetchone()
+        if ret[0] == 0:
+            raise Exception(f'acquire postgres lock {self.lock_name} timeout')
+        elif ret[0] == 1:
+            return True
+        else:
+            raise Exception(f'failed to acquire lock {self.lock_name}')
+
+    def unlock(self):
+        cursor = self.db.execute_sql("SELECT pg_advisory_unlock(%s)", self.timeout)
+        ret = cursor.fetchone()
+        if ret[0] == 0:
+            raise Exception(
+                f'postgres lock {self.lock_name} was not established by this thread')
+        elif ret[0] == 1:
+            return True
+        else:
+            raise Exception(f'postgres lock {self.lock_name} does not exist')
+
+    def __enter__(self):
+        if isinstance(self.db, PostgresDatabaseLock):
+            self.lock()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if isinstance(self.db, PostgresDatabaseLock):
+            self.unlock()
+
+    def __call__(self, func):
+        @wraps(func)
+        def magic(*args, **kwargs):
+            with self:
+                return func(*args, **kwargs)
+
+        return magic
+
+
+class MysqlDatabaseLock:
     def __init__(self, lock_name, timeout=10, db=None):
         self.lock_name = lock_name
         self.timeout = int(timeout)
```

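`PostgresDatabaseLock` mirrors the MySQL lock's `GET_LOCK`/`RELEASE_LOCK` pair with advisory locks, but the semantics differ: `GET_LOCK(name, timeout)` blocks for up to `timeout` seconds on a string name, while `pg_try_advisory_lock(key)` returns immediately and is keyed by a 64-bit integer rather than a name. A standalone sketch of the underlying calls, assuming a reachable Postgres with the credentials from this PR's `service_conf.yaml`; the `advisory_key` hash helper illustrates one way to map a name onto the integer key space and is not part of the commit:

```python
import zlib

import psycopg2  # pinned by this PR in requirements.txt


def advisory_key(lock_name: str) -> int:
    # crc32 yields a stable 32-bit key that fits the bigint lock argument
    return zlib.crc32(lock_name.encode("utf-8"))


conn = psycopg2.connect(host="postgres", port=5432, dbname="rag_flow",
                        user="rag_flow", password="infini_rag_flow")
with conn.cursor() as cur:
    cur.execute("SELECT pg_try_advisory_lock(%s)", (advisory_key("demo"),))
    if cur.fetchone()[0]:  # True only if no other session holds the key
        try:
            pass  # critical section
        finally:
            cur.execute("SELECT pg_advisory_unlock(%s)", (advisory_key("demo"),))
```
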
```diff
@@ -325,8 +382,13 @@ class DatabaseLock:
         return magic


+class DatabaseLock(Enum):
+    MYSQL = MysqlDatabaseLock
+    POSTGRES = PostgresDatabaseLock
+
+
 DB = BaseDataBase().database_connection
-DB.lock = DatabaseLock
+DB.lock = DatabaseLock[DATABASE_TYPE.upper()].value


 def close_connection():
```

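With the enum in place, `DB.lock` resolves to whichever lock class matches the configured backend, so call sites stay backend-agnostic. A sketch of the two intended call patterns inside the RAGFlow codebase (the lock names here are invented):

```python
from api.db.db_models import DB

# as a context manager
with DB.lock("demo_section", 10):
    ...  # critical section, serialized across processes


# as a decorator, via the lock class's __call__
@DB.lock("demo_task", 10)
def demo_task():
    ...
```
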
```diff
@@ -918,7 +980,7 @@ class CanvasTemplate(DataBaseModel):
 def migrate_db():
     with DB.transaction():
-        migrator = MySQLMigrator(DB)
+        migrator = DatabaseMigrator[DATABASE_TYPE.upper()].value(DB)
         try:
             migrate(
                 migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",
```

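`MySQLMigrator` and `PostgresqlMigrator` expose the same `playhouse.migrate` interface, which is why `migrate_db` only has to swap the migrator class and can leave every `add_column` call untouched. A minimal sketch of the pattern (the wrapper function is illustrative):

```python
from enum import Enum

from peewee import CharField
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate


class DatabaseMigrator(Enum):
    MYSQL = MySQLMigrator
    POSTGRES = PostgresqlMigrator


def add_source_type_column(db, database_type: str):
    # both migrator classes accept the same operations, so only the
    # constructor differs between backends
    migrator = DatabaseMigrator[database_type.upper()].value(db)
    migrate(
        migrator.add_column('file', 'source_type',
                            CharField(max_length=128, null=False, default="")),
    )
```
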

#### api/db/db_utils.py

```diff
@@ -17,6 +17,8 @@ import operator
 from functools import reduce
 from typing import Dict, Type, Union

+from playhouse.pool import PooledMySQLDatabase
+
 from api.utils import current_timestamp, timestamp_to_date
 from api.db.db_models import DB, DataBaseModel
```

```diff
@@ -49,7 +51,10 @@ def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
         with DB.atomic():
             query = model.insert_many(data_source[i:i + batch_size])
             if replace_on_conflict:
-                query = query.on_conflict(preserve=preserve)
+                if isinstance(DB, PooledMySQLDatabase):
+                    query = query.on_conflict(preserve=preserve)
+                else:
+                    query = query.on_conflict(conflict_target="id", preserve=preserve)
             query.execute()
```

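The branch is needed because peewee lowers `on_conflict()` to backend-specific SQL: MySQL emits `INSERT ... ON DUPLICATE KEY UPDATE` and infers the conflicting key from the table's unique indexes, whereas Postgres emits `INSERT ... ON CONFLICT (...) DO UPDATE` and requires an explicit `conflict_target`. A sketch with a hypothetical model (queries are built, not executed):

```python
from peewee import Model, BigIntegerField, CharField


class Document(Model):  # hypothetical model for illustration
    id = BigIntegerField(primary_key=True)
    name = CharField()


rows = [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}]

# MySQL: INSERT ... ON DUPLICATE KEY UPDATE name = VALUES(name)
q_mysql = Document.insert_many(rows).on_conflict(preserve=[Document.name])

# Postgres: INSERT ... ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name
q_pg = Document.insert_many(rows).on_conflict(
    conflict_target=[Document.id], preserve=[Document.name])
```
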

#### api/settings.py

```diff
@@ -164,7 +164,8 @@ RANDOM_INSTANCE_ID = get_base_config(
 PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy")
 PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol")

-DATABASE = decrypt_database_config(name="mysql")
+DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
+DATABASE = decrypt_database_config(name=DATABASE_TYPE)

 # Switch
 # upload
```

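Backend selection comes down to this one environment variable: `decrypt_database_config` loads the matching top-level section of `service_conf.yaml`, and every `DATABASE_TYPE.upper()` enum lookup dispatches off the same value, so an unrecognized `DB_TYPE` fails fast with a `KeyError`. A sketch of flipping the switch before the app imports its settings (assumes the RAGFlow package layout and a readable `service_conf.yaml`):

```python
import os

# must be set before api.settings is imported, e.g. in a launcher script
os.environ["DB_TYPE"] = "postgres"  # default is 'mysql'

from api.settings import DATABASE, DATABASE_TYPE

print(DATABASE_TYPE)     # 'postgres'
print(DATABASE["port"])  # 5432, from the postgres section of service_conf.yaml
```
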

#### conf/service_conf.yaml

```diff
@@ -9,6 +9,14 @@ mysql:
   port: 3306
   max_connections: 100
   stale_timeout: 30
+postgres:
+  name: 'rag_flow'
+  user: 'rag_flow'
+  password: 'infini_rag_flow'
+  host: 'postgres'
+  port: 5432
+  max_connections: 100
+  stale_timeout: 30
 minio:
   user: 'rag_flow'
   password: 'infini_rag_flow'
```

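The `postgres` block deliberately mirrors the `mysql` one key for key, which is what keeps `BaseDataBase` generic: `name` is popped off and the remaining keys pass straight through as pooled-database constructor arguments. Roughly what happens at startup with this config (a sketch, with the YAML inlined as a dict):

```python
from playhouse.pool import PooledPostgresqlDatabase

database_config = {  # the postgres section of service_conf.yaml
    "name": "rag_flow",
    "user": "rag_flow",
    "password": "infini_rag_flow",
    "host": "postgres",
    "port": 5432,
    "max_connections": 100,
    "stale_timeout": 30,
}
db_name = database_config.pop("name")
# max_connections / stale_timeout are pool options; the rest go to psycopg2
db = PooledPostgresqlDatabase(db_name, **database_config)
```
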

#### requirements.txt

```diff
@@ -31,6 +31,7 @@ Flask==3.0.3
 Flask_Cors==5.0.0
 Flask_Login==0.6.3
 flask_session==0.8.0
+psycopg2==2.9.9
 google_search_results==2.4.2
 groq==0.9.0
 hanziconv==0.3.2
```