mirror of https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-04-19 12:39:59 +08:00

build python version rag-flow (#21)

* clean rust version project
* clean rust version project
* build python version rag-flow

This commit is contained in:
parent db8cae3f1e
commit 30791976d5
@@ -1,9 +0,0 @@
# Database
HOST=127.0.0.1
PORT=8000
DATABASE_URL="postgresql://infiniflow:infiniflow@localhost/docgpt"

# S3 Storage
MINIO_HOST="127.0.0.1:9000"
MINIO_USR="infiniflow"
MINIO_PWD="infiniflow_docgpt"
42 Cargo.toml
@@ -1,42 +0,0 @@
[package]
name = "doc_gpt"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
actix-web = "4.3.1"
actix-rt = "2.8.0"
actix-files = "0.6.2"
actix-multipart = "0.4"
actix-session = { version = "0.5" }
actix-identity = { version = "0.4" }
actix-web-httpauth = { version = "0.6" }
actix-ws = "0.2.5"
uuid = { version = "1.6.1", features = [
    "v4",
    "fast-rng",
    "macro-diagnostics",
] }
thiserror = "1.0"
postgres = "0.19.7"
sea-orm = { version = "0.12.9", features = ["sqlx-postgres", "runtime-tokio-native-tls", "macros"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1.0"
tracing-subscriber = "0.3.18"
dotenvy = "0.15.7"
listenfd = "1.0.1"
chrono = "0.4.31"
migration = { path = "./migration" }
minio = "0.1.0"
futures-util = "0.3.29"
actix-multipart-extract = "0.1.5"
regex = "1.10.2"
tokio = { version = "1.35.1", features = ["rt", "time", "macros"] }

[[bin]]
name = "doc_gpt"

[workspace]
members = [".", "migration"]
0 python/conf/mapping.json → conf/mapping.json
Executable file → Normal file
30 conf/private.pem
Normal file
@ -0,0 +1,30 @@
|
||||
-----BEGIN RSA PRIVATE KEY-----
|
||||
Proc-Type: 4,ENCRYPTED
|
||||
DEK-Info: DES-EDE3-CBC,EFF8327C41E531AD
|
||||
|
||||
7jdPFDAA6fiTzOIU7XGzKuT324JKZEcK5vBRJqBkA5XO6ENN1wLdhh3zQbl1Ejfv
|
||||
KMSUIgbtQEJB4bvOzS//okbZa1vCNYuTS/NGcpKUnhqdOmAL3hl/kOtOLLjTZrwo
|
||||
3KX8iujLH7wQ64GxArtpUuaFq1k0whN1BB5RGJp3IO/L6pMpSWVRKO+JPUrD1Ujr
|
||||
XA/LUKQJaZtXVUVOYPtIwbyqPsh93QBetJnRwwV3gNOwGpcX2jDpyTxDUkLJCPPg
|
||||
6Hw0pwlQEd8A11sjxCBbASwLeJO1L0w69QiX9chyOkZ+sfDsVpPt/wf1NexA7Cdj
|
||||
9uifJ4JGbby39QD6mInZGtnRzQRdafjuXlBR2I0Qa7fBRu8QsfhmLbWZfWno7j08
|
||||
4bAAoqB1vRNfSu8LVJXdEEh/HKuwu11pgRr5eH8WQ3hJg+Y2k7zDHpp1VaHL7/Kn
|
||||
S+aN5bhQ4Xt0Ujdi1+rsmNchnF6LWsDezHWJeWUM6X7dJnqIBl8oCyghbghT8Tyw
|
||||
aEKWXc2+7FsP5yd0NfG3PFYOLdLgfI43pHTAv5PEQ47w9r1XOwfblKKBUDEzaput
|
||||
T3t5wQ6wxdyhRxeO4arCHfe/i+j3fzvhlwgbuwrmrkWGWSS86eMTaoGM8+uUrHv0
|
||||
6TbU0tj6DKKUslVk1dCHh9TnmNsXZuLJkceZF38PSKNxhzudU8OTtzhS0tFL91HX
|
||||
vo7N+XdiGMs8oOSpjE6RPlhFhVAKGJpXwBj/vXLLcmzesA7ZB2kYtFKMIdsUQpls
|
||||
PE/4K5PEX2d8pxA5zxo0HleA1YjW8i5WEcDQThZQzj2sWvg06zSjenVFrbCm9Bro
|
||||
hFpAB/3zJHxdRN2MpNpvK35WITy1aDUdX1WdyrlcRtIE5ssFTSoxSj9ibbDZ78+z
|
||||
gtbw/MUi6vU6Yz1EjvoYu/bmZAHt9Aagcxw6k58fjO2cEB9njK7xbbiZUSwpJhEe
|
||||
U/PxK+SdOU/MmGKeqdgqSfhJkq0vhacvsEjFGRAfivSCHkL0UjhObU+rSJ3g1RMO
|
||||
oukAev6TOAwbTKVWjg3/EX+pl/zorAgaPNYFX64TSH4lE3VjeWApITb9Z5C/sVxR
|
||||
xW6hU9qyjzWYWY+91y16nkw1l7VQvWHUZwV7QzTScC2BOzDVpeqY1KiYJxgoo6sX
|
||||
ZCqR5oh4vToG4W8ZrRyauwUaZJ3r+zhAgm+6n6TJQNwFEl0muji+1nPl32EiFsRs
|
||||
qR6CtuhUOVQM4VnILDwFJfuGYRFtKzQgvseLNU4ZqAVqQj8l4ARGAP2P1Au/uUKy
|
||||
oGzI7a+b5MvRHuvkxPAclOgXgX/8yyOLaBg+mgaqv9h2JIJD28PzouFl3BajRaVB
|
||||
7GWTnROJYhX5SuX/g585SLRKoQUtK0WhdJCjTRfyRJPwfdppgdTbWO99R4G+ir02
|
||||
JQdSkZf2vmZRXenPNTEPDOUY6nVN6sUuBjmtOwoUF194ODgpYB6IaHqK08sa1pUh
|
||||
1mZyxitHdPbygePTe20XWMZFoK2knAqN0JPPbbNjCqiVV+7oqQAnkDIutspu9t2m
|
||||
ny3jefFmNozbblQMghLUrq+x9wOEgvS76Sqvq3DG/2BkLzJF3MNkvw==
|
||||
-----END RSA PRIVATE KEY-----
|
9 conf/public.pem
Normal file
@ -0,0 +1,9 @@
|
||||
-----BEGIN PUBLIC KEY-----
|
||||
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/
|
||||
z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp
|
||||
2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOO
|
||||
UEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVK
|
||||
RNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK
|
||||
6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs
|
||||
2wIDAQAB
|
||||
-----END PUBLIC KEY-----
|
28 conf/service_conf.yaml
Normal file
@@ -0,0 +1,28 @@
authentication:
  client:
    switch: false
    http_app_key:
    http_secret_key:
  site:
    switch: false
permission:
  switch: false
  component: false
  dataset: false
ragflow:
  # you must set a real IP address; 127.0.0.1 and 0.0.0.0 are not supported
  host: 127.0.0.1
  http_port: 9380
database:
  name: 'rag_flow'
  user: 'root'
  passwd: 'infini_rag_flow'
  host: '123.60.95.134'
  port: 5455
  max_connections: 100
  stale_timeout: 30
oauth:
  github:
    client_id: 302129228f0d96055bee
    secret_key: e518e55ccfcdfcae8996afc40f110e9c95f14fc4
    url: https://github.com/login/oauth/access_token
21 docker/.env
@@ -1,21 +0,0 @@
# Version of Elastic products
STACK_VERSION=8.11.3

# Set the cluster name
CLUSTER_NAME=docgpt

# Port to expose Elasticsearch HTTP API to the host
ES_PORT=9200

# Port to expose Kibana to the host
KIBANA_PORT=6601

# Increase or decrease based on the available host memory (in bytes)
MEM_LIMIT=4073741824

POSTGRES_USER=root
POSTGRES_PASSWORD=infiniflow_docgpt
POSTGRES_DB=docgpt

MINIO_USER=infiniflow
MINIO_PASSWORD=infiniflow_docgpt
@ -1,7 +1,7 @@
|
||||
version: '2.2'
|
||||
services:
|
||||
es01:
|
||||
container_name: docgpt-es-01
|
||||
container_name: ragflow-es-01
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:${STACK_VERSION}
|
||||
volumes:
|
||||
- esdata01:/usr/share/elasticsearch/data
|
||||
@ -20,14 +20,14 @@ services:
|
||||
soft: -1
|
||||
hard: -1
|
||||
networks:
|
||||
- docgpt
|
||||
- ragflow
|
||||
restart: always
|
||||
|
||||
kibana:
|
||||
depends_on:
|
||||
- es01
|
||||
image: docker.elastic.co/kibana/kibana:${STACK_VERSION}
|
||||
container_name: docgpt-kibana
|
||||
container_name: ragflow-kibana
|
||||
volumes:
|
||||
- kibanadata:/usr/share/kibana/data
|
||||
ports:
|
||||
@ -37,26 +37,39 @@ services:
|
||||
- ELASTICSEARCH_HOSTS=http://es01:9200
|
||||
mem_limit: ${MEM_LIMIT}
|
||||
networks:
|
||||
- docgpt
|
||||
- ragflow
|
||||
|
||||
postgres:
|
||||
image: postgres
|
||||
container_name: docgpt-postgres
|
||||
mysql:
|
||||
image: mysql:5.7.18
|
||||
container_name: ragflow-mysql
|
||||
environment:
|
||||
- POSTGRES_USER=${POSTGRES_USER}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
|
||||
- POSTGRES_DB=${POSTGRES_DB}
|
||||
- MYSQL_ROOT_PASSWORD=${MYSQL_PASSWORD}
|
||||
- TZ="Asia/Shanghai"
|
||||
command:
|
||||
--max_connections=1000
|
||||
--character-set-server=utf8mb4
|
||||
--collation-server=utf8mb4_general_ci
|
||||
--default-authentication-plugin=mysql_native_password
|
||||
--tls_version="TLSv1.2,TLSv1.3"
|
||||
--init-file /data/application/init.sql
|
||||
ports:
|
||||
- 5455:5432
|
||||
- ${MYSQL_PORT}:3306
|
||||
volumes:
|
||||
- pg_data:/var/lib/postgresql/data
|
||||
- mysql_data:/var/lib/mysql
|
||||
- ./init.sql:/data/application/init.sql
|
||||
networks:
|
||||
- docgpt
|
||||
- ragflow
|
||||
healthcheck:
|
||||
test: [ "CMD-SHELL", "curl --silent localhost:3306 >/dev/null || exit 1" ]
|
||||
interval: 10s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
restart: always
|
||||
|
||||
|
||||
minio:
|
||||
image: quay.io/minio/minio:RELEASE.2023-12-20T01-00-02Z
|
||||
container_name: docgpt-minio
|
||||
container_name: ragflow-minio
|
||||
command: server --console-address ":9001" /data
|
||||
ports:
|
||||
- 9000:9000
|
||||
@ -67,7 +80,7 @@ services:
|
||||
volumes:
|
||||
- minio_data:/data
|
||||
networks:
|
||||
- docgpt
|
||||
- ragflow
|
||||
restart: always
|
||||
|
||||
|
||||
@ -76,11 +89,11 @@ volumes:
|
||||
driver: local
|
||||
kibanadata:
|
||||
driver: local
|
||||
pg_data:
|
||||
mysql_data:
|
||||
driver: local
|
||||
minio_data:
|
||||
driver: local
|
||||
|
||||
networks:
|
||||
docgpt:
|
||||
ragflow:
|
||||
driver: bridge
|
||||
|
2 docker/init.sql
Normal file
@@ -0,0 +1,2 @@
CREATE DATABASE IF NOT EXISTS rag_flow;
USE rag_flow;
@@ -1,20 +0,0 @@
[package]
name = "migration"
version = "0.1.0"
edition = "2021"
publish = false

[lib]
name = "migration"
path = "src/lib.rs"

[dependencies]
async-std = { version = "1", features = ["attributes", "tokio1"] }
chrono = "0.4.31"

[dependencies.sea-orm-migration]
version = "0.12.0"
features = [
    "runtime-tokio-rustls", # `ASYNC_RUNTIME` feature
    "sqlx-postgres",        # `DATABASE_DRIVER` feature
]
@@ -1,41 +0,0 @@
# Running Migrator CLI

- Generate a new migration file
    ```sh
    cargo run -- generate MIGRATION_NAME
    ```
- Apply all pending migrations
    ```sh
    cargo run
    ```
    ```sh
    cargo run -- up
    ```
- Apply first 10 pending migrations
    ```sh
    cargo run -- up -n 10
    ```
- Rollback last applied migrations
    ```sh
    cargo run -- down
    ```
- Rollback last 10 applied migrations
    ```sh
    cargo run -- down -n 10
    ```
- Drop all tables from the database, then reapply all migrations
    ```sh
    cargo run -- fresh
    ```
- Rollback all applied migrations, then reapply all migrations
    ```sh
    cargo run -- refresh
    ```
- Rollback all applied migrations
    ```sh
    cargo run -- reset
    ```
- Check the status of all migrations
    ```sh
    cargo run -- status
    ```
@@ -1,12 +0,0 @@
pub use sea_orm_migration::prelude::*;

mod m20220101_000001_create_table;

pub struct Migrator;

#[async_trait::async_trait]
impl MigratorTrait for Migrator {
    fn migrations() -> Vec<Box<dyn MigrationTrait>> {
        vec![Box::new(m20220101_000001_create_table::Migration)]
    }
}
@ -1,440 +0,0 @@
|
||||
use sea_orm_migration::prelude::*;
|
||||
use chrono::{ FixedOffset, Utc };
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn now() -> chrono::DateTime<FixedOffset> {
|
||||
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
|
||||
}
|
||||
#[derive(DeriveMigrationName)]
|
||||
pub struct Migration;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl MigrationTrait for Migration {
|
||||
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
|
||||
manager.create_table(
|
||||
Table::create()
|
||||
.table(UserInfo::Table)
|
||||
.if_not_exists()
|
||||
.col(
|
||||
ColumnDef::new(UserInfo::Uid)
|
||||
.big_integer()
|
||||
.not_null()
|
||||
.auto_increment()
|
||||
.primary_key()
|
||||
)
|
||||
.col(ColumnDef::new(UserInfo::Email).string().not_null())
|
||||
.col(ColumnDef::new(UserInfo::Nickname).string().not_null())
|
||||
.col(ColumnDef::new(UserInfo::AvatarBase64).string())
|
||||
.col(ColumnDef::new(UserInfo::ColorScheme).string().default("dark"))
|
||||
.col(ColumnDef::new(UserInfo::ListStyle).string().default("list"))
|
||||
.col(ColumnDef::new(UserInfo::Language).string().default("chinese"))
|
||||
.col(ColumnDef::new(UserInfo::Password).string().not_null())
|
||||
.col(
|
||||
ColumnDef::new(UserInfo::LastLoginAt)
|
||||
.timestamp_with_time_zone()
|
||||
.default(Expr::current_timestamp())
|
||||
)
|
||||
.col(
|
||||
ColumnDef::new(UserInfo::CreatedAt)
|
||||
.timestamp_with_time_zone()
|
||||
.default(Expr::current_timestamp())
|
||||
.not_null()
|
||||
)
|
||||
.col(
|
||||
ColumnDef::new(UserInfo::UpdatedAt)
|
||||
.timestamp_with_time_zone()
|
||||
.default(Expr::current_timestamp())
|
||||
.not_null()
|
||||
)
|
||||
.col(ColumnDef::new(UserInfo::IsDeleted).boolean().default(false))
|
||||
.to_owned()
|
||||
).await?;
|
||||
|
||||
manager.create_table(
|
||||
Table::create()
|
||||
.table(TagInfo::Table)
|
||||
.if_not_exists()
|
||||
.col(
|
||||
ColumnDef::new(TagInfo::Tid)
|
||||
.big_integer()
|
||||
.not_null()
|
||||
.auto_increment()
|
||||
.primary_key()
|
||||
)
|
||||
.col(ColumnDef::new(TagInfo::Uid).big_integer().not_null())
|
||||
.col(ColumnDef::new(TagInfo::TagName).string().not_null())
|
||||
.col(ColumnDef::new(TagInfo::Regx).string())
|
||||
.col(ColumnDef::new(TagInfo::Color).tiny_unsigned().default(1))
|
||||
.col(ColumnDef::new(TagInfo::Icon).tiny_unsigned().default(1))
|
||||
.col(ColumnDef::new(TagInfo::FolderId).big_integer())
|
||||
.col(
|
||||
ColumnDef::new(TagInfo::CreatedAt)
|
||||
.timestamp_with_time_zone()
|
||||
.default(Expr::current_timestamp())
|
||||
.not_null()
|
||||
)
|
||||
.col(
|
||||
ColumnDef::new(TagInfo::UpdatedAt)
|
||||
.timestamp_with_time_zone()
|
||||
.default(Expr::current_timestamp())
|
||||
.not_null()
|
||||
)
|
||||
.col(ColumnDef::new(TagInfo::IsDeleted).boolean().default(false))
|
||||
.to_owned()
|
||||
).await?;
|
||||
|
||||
manager.create_table(
|
||||
Table::create()
|
||||
.table(Tag2Doc::Table)
|
||||
.if_not_exists()
|
||||
.col(
|
||||
ColumnDef::new(Tag2Doc::Id)
|
||||
.big_integer()
|
||||
.not_null()
|
||||
.auto_increment()
|
||||
.primary_key()
|
||||
)
|
||||
.col(ColumnDef::new(Tag2Doc::TagId).big_integer())
|
||||
.col(ColumnDef::new(Tag2Doc::Did).big_integer())
|
||||
.to_owned()
|
||||
).await?;
|
||||
|
||||
manager.create_table(
|
||||
Table::create()
|
||||
.table(Kb2Doc::Table)
|
||||
.if_not_exists()
|
||||
.col(
|
||||
ColumnDef::new(Kb2Doc::Id)
|
||||
.big_integer()
|
||||
.not_null()
|
||||
.auto_increment()
|
||||
.primary_key()
|
||||
)
|
||||
.col(ColumnDef::new(Kb2Doc::KbId).big_integer())
|
||||
.col(ColumnDef::new(Kb2Doc::Did).big_integer())
|
||||
.col(ColumnDef::new(Kb2Doc::KbProgress).float().default(0))
|
||||
.col(ColumnDef::new(Kb2Doc::KbProgressMsg).string().default(""))
|
||||
.col(
|
||||
ColumnDef::new(Kb2Doc::UpdatedAt)
|
||||
.timestamp_with_time_zone()
|
||||
.default(Expr::current_timestamp())
|
||||
.not_null()
|
||||
)
|
||||
.col(ColumnDef::new(Kb2Doc::IsDeleted).boolean().default(false))
|
||||
.to_owned()
|
||||
).await?;
|
||||
|
||||
manager.create_table(
|
||||
Table::create()
|
||||
.table(Dialog2Kb::Table)
|
||||
.if_not_exists()
|
||||
.col(
|
||||
ColumnDef::new(Dialog2Kb::Id)
|
||||
.big_integer()
|
||||
.not_null()
|
||||
.auto_increment()
|
||||
.primary_key()
|
||||
)
|
||||
.col(ColumnDef::new(Dialog2Kb::DialogId).big_integer())
|
||||
.col(ColumnDef::new(Dialog2Kb::KbId).big_integer())
|
||||
.to_owned()
|
||||
).await?;
|
||||
|
||||
manager.create_table(
|
||||
Table::create()
|
||||
.table(Doc2Doc::Table)
|
||||
.if_not_exists()
|
||||
.col(
|
||||
ColumnDef::new(Doc2Doc::Id)
|
||||
.big_integer()
|
||||
.not_null()
|
||||
.auto_increment()
|
||||
.primary_key()
|
||||
)
|
||||
.col(ColumnDef::new(Doc2Doc::ParentId).big_integer())
|
||||
.col(ColumnDef::new(Doc2Doc::Did).big_integer())
|
||||
.to_owned()
|
||||
).await?;
|
||||
|
||||
manager.create_table(
|
||||
Table::create()
|
||||
.table(KbInfo::Table)
|
||||
.if_not_exists()
|
||||
.col(
|
||||
ColumnDef::new(KbInfo::KbId)
|
||||
.big_integer()
|
||||
.auto_increment()
|
||||
.not_null()
|
||||
.primary_key()
|
||||
)
|
||||
.col(ColumnDef::new(KbInfo::Uid).big_integer().not_null())
|
||||
.col(ColumnDef::new(KbInfo::KbName).string().not_null())
|
||||
.col(ColumnDef::new(KbInfo::Icon).tiny_unsigned().default(1))
|
||||
.col(
|
||||
ColumnDef::new(KbInfo::CreatedAt)
|
||||
.timestamp_with_time_zone()
|
||||
.default(Expr::current_timestamp())
|
||||
.not_null()
|
||||
)
|
||||
.col(
|
||||
ColumnDef::new(KbInfo::UpdatedAt)
|
||||
.timestamp_with_time_zone()
|
||||
.default(Expr::current_timestamp())
|
||||
.not_null()
|
||||
)
|
||||
.col(ColumnDef::new(KbInfo::IsDeleted).boolean().default(false))
|
||||
.to_owned()
|
||||
).await?;
|
||||
|
||||
manager.create_table(
|
||||
Table::create()
|
||||
.table(DocInfo::Table)
|
||||
.if_not_exists()
|
||||
.col(
|
||||
ColumnDef::new(DocInfo::Did)
|
||||
.big_integer()
|
||||
.not_null()
|
||||
.auto_increment()
|
||||
.primary_key()
|
||||
)
|
||||
.col(ColumnDef::new(DocInfo::Uid).big_integer().not_null())
|
||||
.col(ColumnDef::new(DocInfo::DocName).string().not_null())
|
||||
.col(ColumnDef::new(DocInfo::Location).string().not_null())
|
||||
.col(ColumnDef::new(DocInfo::Size).big_integer().not_null())
|
||||
.col(ColumnDef::new(DocInfo::Type).string().not_null())
|
||||
.col(ColumnDef::new(DocInfo::ThumbnailBase64).string().default(""))
|
||||
.comment("doc type|folder")
|
||||
.col(
|
||||
ColumnDef::new(DocInfo::CreatedAt)
|
||||
.timestamp_with_time_zone()
|
||||
.default(Expr::current_timestamp())
|
||||
.not_null()
|
||||
)
|
||||
.col(
|
||||
ColumnDef::new(DocInfo::UpdatedAt)
|
||||
.timestamp_with_time_zone()
|
||||
.default(Expr::current_timestamp())
|
||||
.not_null()
|
||||
)
|
||||
.col(ColumnDef::new(DocInfo::IsDeleted).boolean().default(false))
|
||||
.to_owned()
|
||||
).await?;
|
||||
|
||||
manager.create_table(
|
||||
Table::create()
|
||||
.table(DialogInfo::Table)
|
||||
.if_not_exists()
|
||||
.col(
|
||||
ColumnDef::new(DialogInfo::DialogId)
|
||||
.big_integer()
|
||||
.not_null()
|
||||
.auto_increment()
|
||||
.primary_key()
|
||||
)
|
||||
.col(ColumnDef::new(DialogInfo::Uid).big_integer().not_null())
|
||||
.col(ColumnDef::new(DialogInfo::KbId).big_integer().not_null())
|
||||
.col(ColumnDef::new(DialogInfo::DialogName).string().not_null())
|
||||
.col(ColumnDef::new(DialogInfo::History).string().comment("json"))
|
||||
.col(
|
||||
ColumnDef::new(DialogInfo::CreatedAt)
|
||||
.timestamp_with_time_zone()
|
||||
.default(Expr::current_timestamp())
|
||||
.not_null()
|
||||
)
|
||||
.col(
|
||||
ColumnDef::new(DialogInfo::UpdatedAt)
|
||||
.timestamp_with_time_zone()
|
||||
.default(Expr::current_timestamp())
|
||||
.not_null()
|
||||
)
|
||||
.col(ColumnDef::new(DialogInfo::IsDeleted).boolean().default(false))
|
||||
.to_owned()
|
||||
).await?;
|
||||
|
||||
let root_insert = Query::insert()
|
||||
.into_table(UserInfo::Table)
|
||||
.columns([UserInfo::Email, UserInfo::Nickname, UserInfo::Password])
|
||||
.values_panic(["kai.hu@infiniflow.org".into(), "root".into(), "123456".into()])
|
||||
.to_owned();
|
||||
|
||||
let doc_insert = Query::insert()
|
||||
.into_table(DocInfo::Table)
|
||||
.columns([
|
||||
DocInfo::Uid,
|
||||
DocInfo::DocName,
|
||||
DocInfo::Size,
|
||||
DocInfo::Type,
|
||||
DocInfo::Location,
|
||||
])
|
||||
.values_panic([(1).into(), "/".into(), (0).into(), "folder".into(), "".into()])
|
||||
.to_owned();
|
||||
|
||||
let tag_insert = Query::insert()
|
||||
.into_table(TagInfo::Table)
|
||||
.columns([TagInfo::Uid, TagInfo::TagName, TagInfo::Regx, TagInfo::Color, TagInfo::Icon])
|
||||
.values_panic([
|
||||
(1).into(),
|
||||
"Video".into(),
|
||||
".*\\.(mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)".into(),
|
||||
(1).into(),
|
||||
(1).into(),
|
||||
])
|
||||
.values_panic([
|
||||
(1).into(),
|
||||
"Picture".into(),
|
||||
".*\\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)".into(),
|
||||
(2).into(),
|
||||
(2).into(),
|
||||
])
|
||||
.values_panic([
|
||||
(1).into(),
|
||||
"Music".into(),
|
||||
".*\\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)".into(),
|
||||
(3).into(),
|
||||
(3).into(),
|
||||
])
|
||||
.values_panic([
|
||||
(1).into(),
|
||||
"Document".into(),
|
||||
".*\\.(pdf|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)".into(),
|
||||
(3).into(),
|
||||
(3).into(),
|
||||
])
|
||||
.to_owned();
|
||||
|
||||
manager.exec_stmt(root_insert).await?;
|
||||
manager.exec_stmt(doc_insert).await?;
|
||||
manager.exec_stmt(tag_insert).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
|
||||
manager.drop_table(Table::drop().table(UserInfo::Table).to_owned()).await?;
|
||||
|
||||
manager.drop_table(Table::drop().table(TagInfo::Table).to_owned()).await?;
|
||||
|
||||
manager.drop_table(Table::drop().table(Tag2Doc::Table).to_owned()).await?;
|
||||
|
||||
manager.drop_table(Table::drop().table(Kb2Doc::Table).to_owned()).await?;
|
||||
|
||||
manager.drop_table(Table::drop().table(Dialog2Kb::Table).to_owned()).await?;
|
||||
|
||||
manager.drop_table(Table::drop().table(Doc2Doc::Table).to_owned()).await?;
|
||||
|
||||
manager.drop_table(Table::drop().table(KbInfo::Table).to_owned()).await?;
|
||||
|
||||
manager.drop_table(Table::drop().table(DocInfo::Table).to_owned()).await?;
|
||||
|
||||
manager.drop_table(Table::drop().table(DialogInfo::Table).to_owned()).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(DeriveIden)]
|
||||
enum UserInfo {
|
||||
Table,
|
||||
Uid,
|
||||
Email,
|
||||
Nickname,
|
||||
AvatarBase64,
|
||||
ColorScheme,
|
||||
ListStyle,
|
||||
Language,
|
||||
Password,
|
||||
LastLoginAt,
|
||||
CreatedAt,
|
||||
UpdatedAt,
|
||||
IsDeleted,
|
||||
}
|
||||
|
||||
#[derive(DeriveIden)]
|
||||
enum TagInfo {
|
||||
Table,
|
||||
Tid,
|
||||
Uid,
|
||||
TagName,
|
||||
Regx,
|
||||
Color,
|
||||
Icon,
|
||||
FolderId,
|
||||
CreatedAt,
|
||||
UpdatedAt,
|
||||
IsDeleted,
|
||||
}
|
||||
|
||||
#[derive(DeriveIden)]
|
||||
enum Tag2Doc {
|
||||
Table,
|
||||
Id,
|
||||
TagId,
|
||||
Did,
|
||||
}
|
||||
|
||||
#[derive(DeriveIden)]
|
||||
enum Kb2Doc {
|
||||
Table,
|
||||
Id,
|
||||
KbId,
|
||||
Did,
|
||||
KbProgress,
|
||||
KbProgressMsg,
|
||||
UpdatedAt,
|
||||
IsDeleted,
|
||||
}
|
||||
|
||||
#[derive(DeriveIden)]
|
||||
enum Dialog2Kb {
|
||||
Table,
|
||||
Id,
|
||||
DialogId,
|
||||
KbId,
|
||||
}
|
||||
|
||||
#[derive(DeriveIden)]
|
||||
enum Doc2Doc {
|
||||
Table,
|
||||
Id,
|
||||
ParentId,
|
||||
Did,
|
||||
}
|
||||
|
||||
#[derive(DeriveIden)]
|
||||
enum KbInfo {
|
||||
Table,
|
||||
KbId,
|
||||
Uid,
|
||||
KbName,
|
||||
Icon,
|
||||
CreatedAt,
|
||||
UpdatedAt,
|
||||
IsDeleted,
|
||||
}
|
||||
|
||||
#[derive(DeriveIden)]
|
||||
enum DocInfo {
|
||||
Table,
|
||||
Did,
|
||||
Uid,
|
||||
DocName,
|
||||
Location,
|
||||
Size,
|
||||
Type,
|
||||
ThumbnailBase64,
|
||||
CreatedAt,
|
||||
UpdatedAt,
|
||||
IsDeleted,
|
||||
}
|
||||
|
||||
#[derive(DeriveIden)]
|
||||
enum DialogInfo {
|
||||
Table,
|
||||
Uid,
|
||||
KbId,
|
||||
DialogId,
|
||||
DialogName,
|
||||
History,
|
||||
CreatedAt,
|
||||
UpdatedAt,
|
||||
IsDeleted,
|
||||
}
|
@ -1,6 +0,0 @@
|
||||
use sea_orm_migration::prelude::*;
|
||||
|
||||
#[async_std::main]
|
||||
async fn main() {
|
||||
cli::run_cli(migration::Migrator).await;
|
||||
}
|
29 python/Dockerfile
Normal file
@@ -0,0 +1,29 @@
FROM ubuntu:22.04 as base

RUN apt-get update

ENV TZ="Asia/Taipei"
RUN apt-get install -yq \
    build-essential \
    curl \
    libncursesw5-dev \
    libssl-dev \
    libsqlite3-dev \
    libgdbm-dev \
    libc6-dev \
    libbz2-dev \
    software-properties-common \
    python3.11 python3.11-dev python3-pip

RUN apt-get install -yq git
RUN pip3 config set global.index-url https://mirror.baidu.com/pypi/simple
RUN pip3 config set global.trusted-host mirror.baidu.com
RUN pip3 install --upgrade pip
RUN pip3 install torch==2.0.1
RUN pip3 install torch-model-archiver==0.8.2
RUN pip3 install torchvision==0.15.2
COPY requirements.txt .

WORKDIR /docgpt
ENV PYTHONPATH=/docgpt/
@ -1,22 +0,0 @@
|
||||
|
||||
```shell
|
||||
|
||||
docker pull postgres
|
||||
|
||||
LOCAL_POSTGRES_DATA=./postgres-data
|
||||
|
||||
docker run
|
||||
--name docass-postgres
|
||||
-p 5455:5432
|
||||
-v $LOCAL_POSTGRES_DATA:/var/lib/postgresql/data
|
||||
-e POSTGRES_USER=root
|
||||
-e POSTGRES_PASSWORD=infiniflow_docass
|
||||
-e POSTGRES_DB=docass
|
||||
-d
|
||||
postgres
|
||||
|
||||
docker network create elastic
|
||||
docker pull elasticsearch:8.11.3;
|
||||
docker pull docker.elastic.co/kibana/kibana:8.11.3
|
||||
|
||||
```
|
63 python/]
Normal file
@@ -0,0 +1,63 @@
from abc import ABC
from openai import OpenAI
import os
import base64
from io import BytesIO


class Base(ABC):
    def describe(self, image, max_tokens=300):
        raise NotImplementedError("Please implement describe method!")


class GptV4(Base):
    def __init__(self):
        # the OpenAI client reads the key from the OPENAI_API_KEY environment variable
        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

    def describe(self, image, max_tokens=300):
        buffered = BytesIO()
        try:
            image.save(buffered, format="JPEG")
        except Exception:
            image.save(buffered, format="PNG")
        b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        res = self.client.chat.completions.create(
            model="gpt-4-vision-preview",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等。",
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{b64}"
                            },
                        },
                    ],
                }
            ],
            max_tokens=max_tokens,
        )
        return res.choices[0].message.content.strip()


class QWen(Base):
    def chat(self, system, history, gen_conf):
        from http import HTTPStatus
        from dashscope import Generation
        from dashscope.api_entities.dashscope_response import Role
        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
        # build the message list from the system prompt and the chat history
        messages = [{"role": Role.SYSTEM, "content": system}] + history
        response = Generation.call(
            Generation.Models.qwen_turbo,
            messages=messages,
            result_format='message'
        )
        if response.status_code == HTTPStatus.OK:
            return response.output.choices[0]['message']['content']
        return response.message
@ -1,41 +0,0 @@
|
||||
{
|
||||
"version":1,
|
||||
"disable_existing_loggers":false,
|
||||
"formatters":{
|
||||
"simple":{
|
||||
"format":"%(asctime)s - %(name)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s"
|
||||
}
|
||||
},
|
||||
"handlers":{
|
||||
"console":{
|
||||
"class":"logging.StreamHandler",
|
||||
"level":"DEBUG",
|
||||
"formatter":"simple",
|
||||
"stream":"ext://sys.stdout"
|
||||
},
|
||||
"info_file_handler":{
|
||||
"class":"logging.handlers.TimedRotatingFileHandler",
|
||||
"level":"INFO",
|
||||
"formatter":"simple",
|
||||
"filename":"log/info.log",
|
||||
"when": "MIDNIGHT",
|
||||
"interval":1,
|
||||
"backupCount":30,
|
||||
"encoding":"utf8"
|
||||
},
|
||||
"error_file_handler":{
|
||||
"class":"logging.handlers.TimedRotatingFileHandler",
|
||||
"level":"ERROR",
|
||||
"formatter":"simple",
|
||||
"filename":"log/errors.log",
|
||||
"when": "MIDNIGHT",
|
||||
"interval":1,
|
||||
"backupCount":30,
|
||||
"encoding":"utf8"
|
||||
}
|
||||
},
|
||||
"root":{
|
||||
"level":"DEBUG",
|
||||
"handlers":["console","info_file_handler","error_file_handler"]
|
||||
}
|
||||
}
|
@ -1,9 +0,0 @@
|
||||
[infiniflow]
|
||||
es=http://es01:9200
|
||||
postgres_user=root
|
||||
postgres_password=infiniflow_docgpt
|
||||
postgres_host=postgres
|
||||
postgres_port=5432
|
||||
minio_host=minio:9000
|
||||
minio_user=infiniflow
|
||||
minio_password=infiniflow_docgpt
|
@ -1,21 +0,0 @@
|
||||
import os
|
||||
from .embedding_model import *
|
||||
from .chat_model import *
|
||||
from .cv_model import *
|
||||
|
||||
EmbeddingModel = None
|
||||
ChatModel = None
|
||||
CvModel = None
|
||||
|
||||
|
||||
if os.environ.get("OPENAI_API_KEY"):
|
||||
EmbeddingModel = GptEmbed()
|
||||
ChatModel = GptTurbo()
|
||||
CvModel = GptV4()
|
||||
|
||||
elif os.environ.get("DASHSCOPE_API_KEY"):
|
||||
EmbeddingModel = QWenEmbd()
|
||||
ChatModel = QWenChat()
|
||||
CvModel = QWenCV()
|
||||
else:
|
||||
EmbeddingModel = HuEmbedding()
|
@ -1,61 +0,0 @@
|
||||
from abc import ABC
|
||||
from openai import OpenAI
|
||||
from FlagEmbedding import FlagModel
|
||||
import torch
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def encode(self, texts: list, batch_size=32):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
|
||||
class HuEmbedding(Base):
|
||||
def __init__(self):
|
||||
"""
|
||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||
|
||||
For Linux:
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
For Windows:
|
||||
Good luck
|
||||
^_-
|
||||
|
||||
"""
|
||||
self.model = FlagModel("BAAI/bge-large-zh-v1.5",
|
||||
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
||||
use_fp16=torch.cuda.is_available())
|
||||
|
||||
def encode(self, texts: list, batch_size=32):
|
||||
res = []
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
|
||||
return np.array(res)
|
||||
|
||||
|
||||
class GptEmbed(Base):
|
||||
def __init__(self):
|
||||
self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
|
||||
|
||||
def encode(self, texts: list, batch_size=32):
|
||||
res = self.client.embeddings.create(input=texts,
|
||||
model="text-embedding-ada-002")
|
||||
return [d["embedding"] for d in res["data"]]
|
||||
|
||||
|
||||
class QWenEmbd(Base):
|
||||
def encode(self, texts: list, batch_size=32, text_type="document"):
|
||||
# export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
|
||||
import dashscope
|
||||
from http import HTTPStatus
|
||||
res = []
|
||||
for txt in texts:
|
||||
resp = dashscope.TextEmbedding.call(
|
||||
model=dashscope.TextEmbedding.Models.text_embedding_v2,
|
||||
input=txt[:2048],
|
||||
text_type=text_type
|
||||
)
|
||||
res.append(resp["output"]["embeddings"][0]["embedding"])
|
||||
return res
|
0 python/output/ToPDF.pdf
Normal file
@ -1,25 +0,0 @@
|
||||
from openpyxl import load_workbook
|
||||
import sys
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
class HuExcelParser:
|
||||
def __call__(self, fnm):
|
||||
if isinstance(fnm, str):
|
||||
wb = load_workbook(fnm)
|
||||
else:
|
||||
wb = load_workbook(BytesIO(fnm))
|
||||
res = []
|
||||
for sheetname in wb.sheetnames:
|
||||
ws = wb[sheetname]
|
||||
lines = []
|
||||
for r in ws.rows:
|
||||
lines.append(
|
||||
"\t".join([str(c.value) if c.value is not None else "" for c in r]))
|
||||
res.append(f"《{sheetname}》\n" + "\n".join(lines))
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
psr = HuExcelParser()
|
||||
psr(sys.argv[1])
|
@ -1,194 +0,0 @@
|
||||
accelerate==0.24.1
|
||||
addict==2.4.0
|
||||
aiobotocore==2.7.0
|
||||
aiofiles==23.2.1
|
||||
aiohttp==3.8.6
|
||||
aioitertools==0.11.0
|
||||
aiosignal==1.3.1
|
||||
aliyun-python-sdk-core==2.14.0
|
||||
aliyun-python-sdk-kms==2.16.2
|
||||
altair==5.1.2
|
||||
anyio==3.7.1
|
||||
astor==0.8.1
|
||||
async-timeout==4.0.3
|
||||
attrdict==2.0.1
|
||||
attrs==23.1.0
|
||||
Babel==2.13.1
|
||||
bce-python-sdk==0.8.92
|
||||
beautifulsoup4==4.12.2
|
||||
bitsandbytes==0.41.1
|
||||
blinker==1.7.0
|
||||
botocore==1.31.64
|
||||
cachetools==5.3.2
|
||||
certifi==2023.7.22
|
||||
cffi==1.16.0
|
||||
charset-normalizer==3.3.2
|
||||
click==8.1.7
|
||||
cloudpickle==3.0.0
|
||||
contourpy==1.2.0
|
||||
crcmod==1.7
|
||||
cryptography==41.0.5
|
||||
cssselect==1.2.0
|
||||
cssutils==2.9.0
|
||||
cycler==0.12.1
|
||||
Cython==3.0.5
|
||||
datasets==2.13.0
|
||||
datrie==0.8.2
|
||||
decorator==5.1.1
|
||||
defusedxml==0.7.1
|
||||
dill==0.3.6
|
||||
einops==0.7.0
|
||||
elastic-transport==8.10.0
|
||||
elasticsearch==8.10.1
|
||||
elasticsearch-dsl==8.9.0
|
||||
et-xmlfile==1.1.0
|
||||
fastapi==0.104.1
|
||||
ffmpy==0.3.1
|
||||
filelock==3.13.1
|
||||
fire==0.5.0
|
||||
FlagEmbedding==1.1.5
|
||||
Flask==3.0.0
|
||||
flask-babel==4.0.0
|
||||
fonttools==4.44.0
|
||||
frozenlist==1.4.0
|
||||
fsspec==2023.10.0
|
||||
future==0.18.3
|
||||
gast==0.5.4
|
||||
-e
|
||||
git+https://github.com/ggerganov/llama.cpp.git@5f6e0c0dff1e7a89331e6b25eca9a9fd71324069#egg=gguf&subdirectory=gguf-py
|
||||
gradio==3.50.2
|
||||
gradio_client==0.6.1
|
||||
greenlet==3.0.1
|
||||
h11==0.14.0
|
||||
hanziconv==0.3.2
|
||||
httpcore==1.0.1
|
||||
httpx==0.25.1
|
||||
huggingface-hub==0.17.3
|
||||
idna==3.4
|
||||
imageio==2.31.6
|
||||
imgaug==0.4.0
|
||||
importlib-metadata==6.8.0
|
||||
importlib-resources==6.1.0
|
||||
install==1.3.5
|
||||
itsdangerous==2.1.2
|
||||
Jinja2==3.1.2
|
||||
jmespath==0.10.0
|
||||
joblib==1.3.2
|
||||
jsonschema==4.19.2
|
||||
jsonschema-specifications==2023.7.1
|
||||
kiwisolver==1.4.5
|
||||
lazy_loader==0.3
|
||||
lmdb==1.4.1
|
||||
lxml==4.9.3
|
||||
MarkupSafe==2.1.3
|
||||
matplotlib==3.8.1
|
||||
modelscope==1.9.4
|
||||
mpmath==1.3.0
|
||||
multidict==6.0.4
|
||||
multiprocess==0.70.14
|
||||
networkx==3.2.1
|
||||
nltk==3.8.1
|
||||
numpy==1.24.4
|
||||
nvidia-cublas-cu12==12.1.3.1
|
||||
nvidia-cuda-cupti-cu12==12.1.105
|
||||
nvidia-cuda-nvrtc-cu12==12.1.105
|
||||
nvidia-cuda-runtime-cu12==12.1.105
|
||||
nvidia-cudnn-cu12==8.9.2.26
|
||||
nvidia-cufft-cu12==11.0.2.54
|
||||
nvidia-curand-cu12==10.3.2.106
|
||||
nvidia-cusolver-cu12==11.4.5.107
|
||||
nvidia-cusparse-cu12==12.1.0.106
|
||||
nvidia-nccl-cu12==2.18.1
|
||||
nvidia-nvjitlink-cu12==12.3.52
|
||||
nvidia-nvtx-cu12==12.1.105
|
||||
opencv-contrib-python==4.6.0.66
|
||||
opencv-python==4.6.0.66
|
||||
openpyxl==3.1.2
|
||||
opt-einsum==3.3.0
|
||||
orjson==3.9.10
|
||||
oss2==2.18.3
|
||||
packaging==23.2
|
||||
paddleocr==2.7.0.3
|
||||
paddlepaddle-gpu==2.5.2.post120
|
||||
pandas==2.1.2
|
||||
pdf2docx==0.5.5
|
||||
pdfminer.six==20221105
|
||||
pdfplumber==0.10.3
|
||||
Pillow==10.0.1
|
||||
platformdirs==3.11.0
|
||||
premailer==3.10.0
|
||||
protobuf==4.25.0
|
||||
psutil==5.9.6
|
||||
pyarrow==14.0.0
|
||||
pyclipper==1.3.0.post5
|
||||
pycocotools==2.0.7
|
||||
pycparser==2.21
|
||||
pycryptodome==3.19.0
|
||||
pydantic==1.10.13
|
||||
pydub==0.25.1
|
||||
PyMuPDF==1.20.2
|
||||
pyparsing==3.1.1
|
||||
pypdfium2==4.23.1
|
||||
python-dateutil==2.8.2
|
||||
python-docx==1.1.0
|
||||
python-multipart==0.0.6
|
||||
pytz==2023.3.post1
|
||||
PyYAML==6.0.1
|
||||
rapidfuzz==3.5.2
|
||||
rarfile==4.1
|
||||
referencing==0.30.2
|
||||
regex==2023.10.3
|
||||
requests==2.31.0
|
||||
rpds-py==0.12.0
|
||||
s3fs==2023.10.0
|
||||
safetensors==0.4.0
|
||||
scikit-image==0.22.0
|
||||
scikit-learn==1.3.2
|
||||
scipy==1.11.3
|
||||
semantic-version==2.10.0
|
||||
sentence-transformers==2.2.2
|
||||
sentencepiece==0.1.98
|
||||
shapely==2.0.2
|
||||
simplejson==3.19.2
|
||||
six==1.16.0
|
||||
sniffio==1.3.0
|
||||
sortedcontainers==2.4.0
|
||||
soupsieve==2.5
|
||||
SQLAlchemy==2.0.23
|
||||
starlette==0.27.0
|
||||
sympy==1.12
|
||||
tabulate==0.9.0
|
||||
tblib==3.0.0
|
||||
termcolor==2.3.0
|
||||
threadpoolctl==3.2.0
|
||||
tifffile==2023.9.26
|
||||
tiktoken==0.5.1
|
||||
timm==0.9.10
|
||||
tokenizers==0.13.3
|
||||
tomli==2.0.1
|
||||
toolz==0.12.0
|
||||
torch==2.1.0
|
||||
torchaudio==2.1.0
|
||||
torchvision==0.16.0
|
||||
tornado==6.3.3
|
||||
tqdm==4.66.1
|
||||
transformers==4.33.0
|
||||
transformers-stream-generator==0.0.4
|
||||
triton==2.1.0
|
||||
typing_extensions==4.8.0
|
||||
tzdata==2023.3
|
||||
urllib3==2.0.7
|
||||
uvicorn==0.24.0
|
||||
uvloop==0.19.0
|
||||
visualdl==2.5.3
|
||||
websockets==11.0.3
|
||||
Werkzeug==3.0.1
|
||||
wrapt==1.15.0
|
||||
xgboost==2.0.1
|
||||
xinference==0.6.0
|
||||
xorbits==0.7.0
|
||||
xoscar==0.1.3
|
||||
xxhash==3.4.1
|
||||
yapf==0.40.2
|
||||
yarl==1.9.2
|
||||
zipp==3.17.0
|
8 python/res/1-0.tm
Normal file
@ -0,0 +1,8 @@
|
||||
2023-12-20 11:44:08.791336+00:00
|
||||
2023-12-20 11:44:08.853249+00:00
|
||||
2023-12-20 11:44:08.909933+00:00
|
||||
2023-12-21 00:47:09.996757+00:00
|
||||
2023-12-20 11:44:08.965855+00:00
|
||||
2023-12-20 11:44:09.011682+00:00
|
||||
2023-12-21 00:47:10.063326+00:00
|
||||
2023-12-20 11:44:09.069486+00:00
|
3 python/res/thumbnail-1-0.tm
Normal file
@ -0,0 +1,3 @@
|
||||
2023-12-27 08:21:49.309802+00:00
|
||||
2023-12-27 08:37:22.407772+00:00
|
||||
2023-12-27 08:59:18.845627+00:00
|
@ -1,118 +0,0 @@
|
||||
import sys, datetime, random, re, cv2
|
||||
from os.path import dirname, realpath
|
||||
sys.path.append(dirname(realpath(__file__)) + "/../")
|
||||
from util.db_conn import Postgres
|
||||
from util.minio_conn import HuMinio
|
||||
from util import findMaxDt
|
||||
import base64
|
||||
from io import BytesIO
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
import pdfplumber
|
||||
|
||||
|
||||
PG = Postgres("infiniflow", "docgpt")
|
||||
MINIO = HuMinio("infiniflow")
|
||||
def set_thumbnail(did, base64):
|
||||
sql = f"""
|
||||
update doc_info set thumbnail_base64='{base64}'
|
||||
where
|
||||
did={did}
|
||||
"""
|
||||
PG.update(sql)
|
||||
|
||||
|
||||
def collect(comm, mod, tm):
|
||||
sql = f"""
|
||||
select
|
||||
did, uid, doc_name, location, updated_at
|
||||
from doc_info
|
||||
where
|
||||
updated_at >= '{tm}'
|
||||
and MOD(did, {comm}) = {mod}
|
||||
and is_deleted=false
|
||||
and type <> 'folder'
|
||||
and thumbnail_base64=''
|
||||
order by updated_at asc
|
||||
limit 10
|
||||
"""
|
||||
docs = PG.select(sql)
|
||||
if len(docs) == 0:return pd.DataFrame()
|
||||
|
||||
mtm = str(docs["updated_at"].max())[:19]
|
||||
print("TOTAL:", len(docs), "To: ", mtm)
|
||||
return docs
|
||||
|
||||
|
||||
def build(row):
|
||||
if not re.search(r"\.(pdf|jpg|jpeg|png|gif|svg|apng|icon|ico|webp|mpg|mpeg|avi|rm|rmvb|mov|wmv|mp4)$",
|
||||
row["doc_name"].lower().strip()):
|
||||
set_thumbnail(row["did"], "_")
|
||||
return
|
||||
|
||||
def thumbnail(img, SIZE=128):
|
||||
w,h = img.size
|
||||
p = SIZE/max(w, h)
|
||||
w, h = int(w*p), int(h*p)
|
||||
img.thumbnail((w, h))
|
||||
buffered = BytesIO()
|
||||
try:
|
||||
img.save(buffered, format="JPEG")
|
||||
except Exception as e:
|
||||
try:
|
||||
img.save(buffered, format="PNG")
|
||||
except Exception as ee:
|
||||
pass
|
||||
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
iobytes = BytesIO(MINIO.get("%s-upload"%str(row["uid"]), row["location"]))
|
||||
if re.search(r"\.pdf$", row["doc_name"].lower().strip()):
|
||||
pdf = pdfplumber.open(iobytes)
|
||||
img = pdf.pages[0].to_image().annotated
|
||||
set_thumbnail(row["did"], thumbnail(img))
|
||||
|
||||
if re.search(r"\.(jpg|jpeg|png|gif|svg|apng|webp|icon|ico)$", row["doc_name"].lower().strip()):
|
||||
img = Image.open(iobytes)
|
||||
set_thumbnail(row["did"], thumbnail(img))
|
||||
|
||||
if re.search(r"\.(mpg|mpeg|avi|rm|rmvb|mov|wmv|mp4)$", row["doc_name"].lower().strip()):
|
||||
url = MINIO.get_presigned_url("%s-upload"%str(row["uid"]),
|
||||
row["location"],
|
||||
expires=datetime.timedelta(seconds=60)
|
||||
)
|
||||
cap = cv2.VideoCapture(url)
|
||||
succ = cap.isOpened()
|
||||
i = random.randint(1, 11)
|
||||
while succ:
|
||||
ret, frame = cap.read()
|
||||
if not ret: break
|
||||
if i > 0:
|
||||
i -= 1
|
||||
continue
|
||||
img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
||||
print(img.size)
|
||||
set_thumbnail(row["did"], thumbnail(img))
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
||||
def main(comm, mod):
|
||||
global model
|
||||
tm_fnm = f"res/thumbnail-{comm}-{mod}.tm"
|
||||
tm = findMaxDt(tm_fnm)
|
||||
rows = collect(comm, mod, tm)
|
||||
if len(rows) == 0:return
|
||||
|
||||
tmf = open(tm_fnm, "a+")
|
||||
for _, r in rows.iterrows():
|
||||
build(r)
|
||||
tmf.write(str(r["updated_at"]) + "\n")
|
||||
tmf.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from mpi4py import MPI
|
||||
comm = MPI.COMM_WORLD
|
||||
main(comm.Get_size(), comm.Get_rank())
|
||||
|
@ -1,165 +0,0 @@
|
||||
#-*- coding:utf-8 -*-
|
||||
import sys, os, re,inspect,json,traceback,logging,argparse, copy
|
||||
sys.path.append(os.path.realpath(os.path.dirname(inspect.getfile(inspect.currentframe())))+"/../")
|
||||
from tornado.web import RequestHandler,Application
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.options import define,options
|
||||
from util import es_conn, setup_logging
|
||||
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
||||
from nlp import huqie
|
||||
from nlp import query as Query
|
||||
from nlp import search
|
||||
from llm import HuEmbedding, GptTurbo
|
||||
import numpy as np
|
||||
from io import BytesIO
|
||||
from util import config
|
||||
from timeit import default_timer as timer
|
||||
from collections import OrderedDict
|
||||
from llm import ChatModel, EmbeddingModel
|
||||
|
||||
SE = None
|
||||
CFIELD="content_ltks"
|
||||
EMBEDDING = EmbeddingModel
|
||||
LLM = ChatModel
|
||||
|
||||
def get_QA_pairs(hists):
|
||||
pa = []
|
||||
for h in hists:
|
||||
for k in ["user", "assistant"]:
|
||||
if h.get(k):
|
||||
pa.append({
|
||||
"content": h[k],
|
||||
"role": k,
|
||||
})
|
||||
|
||||
for p in pa[:-1]: assert len(p) == 2, p
|
||||
return pa
|
||||
|
||||
|
||||
|
||||
def get_instruction(sres, top_i, max_len=8096, fld="content_ltks"):
|
||||
max_len //= len(top_i)
|
||||
# add instruction to prompt
|
||||
instructions = [re.sub(r"[\r\n]+", " ", sres.field[sres.ids[i]][fld]) for i in top_i]
|
||||
if len(instructions)>2:
|
||||
# Said that LLM is sensitive to the first and the last one, so
|
||||
# rearrange the order of references
|
||||
instructions.append(copy.deepcopy(instructions[1]))
|
||||
instructions.pop(1)
|
||||
|
||||
def token_num(txt):
|
||||
c = 0
|
||||
for tk in re.split(r"[,。/?‘’”“:;:;!!]", txt):
|
||||
if re.match(r"[a-zA-Z-]+$", tk):
|
||||
c += 1
|
||||
continue
|
||||
c += len(tk)
|
||||
return c
|
||||
|
||||
_inst = ""
|
||||
for ins in instructions:
|
||||
if token_num(_inst) > 4096:
|
||||
_inst += "\n知识库:" + instructions[-1][:max_len]
|
||||
break
|
||||
_inst += "\n知识库:" + ins[:max_len]
|
||||
return _inst
|
||||
|
||||
|
||||
def prompt_and_answer(history, inst):
|
||||
hist = get_QA_pairs(history)
|
||||
chks = []
|
||||
for s in re.split(r"[::;;。\n\r]+", inst):
|
||||
if s: chks.append(s)
|
||||
chks = len(set(chks))/(0.1+len(chks))
|
||||
print("Duplication portion:", chks)
|
||||
|
||||
system = """
|
||||
你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答%s。当所有知识库内容都与问题无关时,你的回答必须包括"知识库中未找到您要的答案!这是我所知道的,仅作参考。"这句话。回答需要考虑聊天历史。
|
||||
以下是知识库:
|
||||
%s
|
||||
以上是知识库。
|
||||
"""%((",最好总结成表格" if chks<0.6 and chks>0 else ""), inst)
|
||||
|
||||
print("【PROMPT】:", system)
|
||||
start = timer()
|
||||
response = LLM.chat(system, hist, {"temperature": 0.2, "max_tokens": 512})
|
||||
print("GENERATE: ", timer()-start)
|
||||
print("===>>", response)
|
||||
return response
|
||||
|
||||
|
||||
class Handler(RequestHandler):
|
||||
def post(self):
|
||||
global SE,MUST_TK_NUM
|
||||
param = json.loads(self.request.body.decode('utf-8'))
|
||||
try:
|
||||
question = param.get("history",[{"user": "Hi!"}])[-1]["user"]
|
||||
res = SE.search({
|
||||
"question": question,
|
||||
"kb_ids": param.get("kb_ids", []),
|
||||
"size": param.get("topn", 15)},
|
||||
search.index_name(param["uid"])
|
||||
)
|
||||
|
||||
sim = SE.rerank(res, question)
|
||||
rk_idx = np.argsort(sim*-1)
|
||||
topidx = [i for i in rk_idx if sim[i] >= param.get("similarity", 0.5)][:param.get("topn",12)]
|
||||
inst = get_instruction(res, topidx)
|
||||
|
||||
ans, topidx = prompt_and_answer(param["history"], inst)
|
||||
ans = SE.insert_citations(ans, topidx, res)
|
||||
|
||||
refer = OrderedDict()
|
||||
docnms = {}
|
||||
for i in rk_idx:
|
||||
did = res.field[res.ids[i]]["doc_id"]
|
||||
if did not in docnms: docnms[did] = res.field[res.ids[i]]["docnm_kwd"]
|
||||
if did not in refer: refer[did] = []
|
||||
refer[did].append({
|
||||
"chunk_id": res.ids[i],
|
||||
"content": res.field[res.ids[i]]["content_ltks"],
|
||||
"image": ""
|
||||
})
|
||||
|
||||
print("::::::::::::::", ans)
|
||||
self.write(json.dumps({
|
||||
"code":0,
|
||||
"msg":"success",
|
||||
"data":{
|
||||
"uid": param["uid"],
|
||||
"dialog_id": param["dialog_id"],
|
||||
"assistant": ans,
|
||||
"refer": [{
|
||||
"did": did,
|
||||
"doc_name": docnms[did],
|
||||
"chunks": chunks
|
||||
} for did, chunks in refer.items()]
|
||||
}
|
||||
}))
|
||||
logging.info("SUCCESS[%d]"%(res.total)+json.dumps(param, ensure_ascii=False))
|
||||
|
||||
except Exception as e:
|
||||
logging.error("Request 500: "+str(e))
|
||||
self.write(json.dumps({
|
||||
"code":500,
|
||||
"msg":str(e),
|
||||
"data":{}
|
||||
}))
|
||||
print(traceback.format_exc())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--port", default=4455, type=int, help="Port used for service")
|
||||
ARGS = parser.parse_args()
|
||||
|
||||
SE = search.Dealer(es_conn.HuEs("infiniflow"), EMBEDDING)
|
||||
|
||||
app = Application([(r'/v1/chat/completions', Handler)],debug=False)
|
||||
http_server = HTTPServer(app)
|
||||
http_server.bind(ARGS.port)
|
||||
http_server.start(3)
|
||||
|
||||
IOLoop.current().start()
|
||||
|
@ -1,258 +0,0 @@
|
||||
import json, os, sys, hashlib, copy, time, random, re
|
||||
from os.path import dirname, realpath
|
||||
sys.path.append(dirname(realpath(__file__)) + "/../")
|
||||
from util.es_conn import HuEs
|
||||
from util.db_conn import Postgres
|
||||
from util.minio_conn import HuMinio
|
||||
from util import rmSpace, findMaxDt
|
||||
from FlagEmbedding import FlagModel
|
||||
from nlp import huchunk, huqie, search
|
||||
from io import BytesIO
|
||||
import pandas as pd
|
||||
from elasticsearch_dsl import Q
|
||||
from PIL import Image
|
||||
from parser import (
|
||||
PdfParser,
|
||||
DocxParser,
|
||||
ExcelParser
|
||||
)
|
||||
from nlp.huchunk import (
|
||||
PdfChunker,
|
||||
DocxChunker,
|
||||
ExcelChunker,
|
||||
PptChunker,
|
||||
TextChunker
|
||||
)
|
||||
|
||||
ES = HuEs("infiniflow")
|
||||
BATCH_SIZE = 64
|
||||
PG = Postgres("infiniflow", "docgpt")
|
||||
MINIO = HuMinio("infiniflow")
|
||||
|
||||
PDF = PdfChunker(PdfParser())
|
||||
DOC = DocxChunker(DocxParser())
|
||||
EXC = ExcelChunker(ExcelParser())
|
||||
PPT = PptChunker()
|
||||
|
||||
def chuck_doc(name, binary):
|
||||
suff = os.path.split(name)[-1].lower().split(".")[-1]
|
||||
if suff.find("pdf") >= 0: return PDF(binary)
|
||||
if suff.find("doc") >= 0: return DOC(binary)
|
||||
if re.match(r"(xlsx|xlsm|xltx|xltm)", suff): return EXC(binary)
|
||||
if suff.find("ppt") >= 0: return PPT(binary)
|
||||
if os.environ.get("PARSE_IMAGE") \
|
||||
and re.search(r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$",
|
||||
name.lower()):
|
||||
from llm import CvModel
|
||||
txt = CvModel.describe(binary)
|
||||
field = TextChunker.Fields()
|
||||
field.text_chunks = [(txt, binary)]
|
||||
field.table_chunks = []
|
||||
|
||||
|
||||
return TextChunker()(binary)
|
||||
|
||||
|
||||
def collect(comm, mod, tm):
|
||||
sql = f"""
|
||||
select
|
||||
id as kb2doc_id,
|
||||
kb_id,
|
||||
did,
|
||||
updated_at,
|
||||
is_deleted
|
||||
from kb2_doc
|
||||
where
|
||||
updated_at >= '{tm}'
|
||||
and kb_progress = 0
|
||||
and MOD(did, {comm}) = {mod}
|
||||
order by updated_at asc
|
||||
limit 1000
|
||||
"""
|
||||
kb2doc = PG.select(sql)
|
||||
if len(kb2doc) == 0:return pd.DataFrame()
|
||||
|
||||
sql = """
|
||||
select
|
||||
did,
|
||||
uid,
|
||||
doc_name,
|
||||
location,
|
||||
size
|
||||
from doc_info
|
||||
where
|
||||
did in (%s)
|
||||
"""%",".join([str(i) for i in kb2doc["did"].unique()])
|
||||
docs = PG.select(sql)
|
||||
docs = docs.fillna("")
|
||||
docs = docs.join(kb2doc.set_index("did"), on="did", how="left")
|
||||
|
||||
mtm = str(docs["updated_at"].max())[:19]
|
||||
print("TOTAL:", len(docs), "To: ", mtm)
|
||||
return docs
|
||||
|
||||
|
||||
def set_progress(kb2doc_id, prog, msg="Processing..."):
|
||||
sql = f"""
|
||||
update kb2_doc set kb_progress={prog}, kb_progress_msg='{msg}'
|
||||
where
|
||||
id={kb2doc_id}
|
||||
"""
|
||||
PG.update(sql)
|
||||
|
||||
|
||||
def build(row):
|
||||
if row["size"] > 256000000:
|
||||
set_progress(row["kb2doc_id"], -1, "File size exceeds( <= 256Mb )")
|
||||
return []
|
||||
res = ES.search(Q("term", doc_id=row["did"]))
|
||||
if ES.getTotal(res) > 0:
|
||||
ES.updateScriptByQuery(Q("term", doc_id=row["did"]),
|
||||
scripts="""
|
||||
if(!ctx._source.kb_id.contains('%s'))
|
||||
ctx._source.kb_id.add('%s');
|
||||
"""%(str(row["kb_id"]), str(row["kb_id"])),
|
||||
idxnm = search.index_name(row["uid"])
|
||||
)
|
||||
set_progress(row["kb2doc_id"], 1, "Done")
|
||||
return []
|
||||
|
||||
random.seed(time.time())
|
||||
set_progress(row["kb2doc_id"], random.randint(0, 20)/100., "Finished preparing! Start to slice file!")
|
||||
try:
|
||||
obj = chuck_doc(row["doc_name"], MINIO.get("%s-upload"%str(row["uid"]), row["location"]))
|
||||
except Exception as e:
|
||||
if re.search("(No such file|not found)", str(e)):
|
||||
set_progress(row["kb2doc_id"], -1, "Can not find file <%s>"%row["doc_name"])
|
||||
else:
|
||||
set_progress(row["kb2doc_id"], -1, f"Internal system error: %s"%str(e).replace("'", ""))
|
||||
return []
|
||||
|
||||
if not obj.text_chunks and not obj.table_chunks:
|
||||
set_progress(row["kb2doc_id"], 1, "Nothing added! Mostly, file type unsupported yet.")
|
||||
return []
|
||||
|
||||
set_progress(row["kb2doc_id"], random.randint(20, 60)/100., "Finished slicing files. Start to embedding the content.")
|
||||
|
||||
doc = {
|
||||
"doc_id": row["did"],
|
||||
"kb_id": [str(row["kb_id"])],
|
||||
"docnm_kwd": os.path.split(row["location"])[-1],
|
||||
"title_tks": huqie.qie(os.path.split(row["location"])[-1]),
|
||||
"updated_at": str(row["updated_at"]).replace("T", " ")[:19]
|
||||
}
|
||||
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
|
||||
output_buffer = BytesIO()
|
||||
docs = []
|
||||
md5 = hashlib.md5()
|
||||
for txt, img in obj.text_chunks:
|
||||
d = copy.deepcopy(doc)
|
||||
md5.update((txt + str(d["doc_id"])).encode("utf-8"))
|
||||
d["_id"] = md5.hexdigest()
|
||||
d["content_ltks"] = huqie.qie(txt)
|
||||
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
|
||||
if not img:
|
||||
docs.append(d)
|
||||
continue
|
||||
|
||||
if isinstance(img, Image.Image): img.save(output_buffer, format='JPEG')
|
||||
else: output_buffer = BytesIO(img)
|
||||
|
||||
MINIO.put("{}-{}".format(row["uid"], row["kb_id"]), d["_id"],
|
||||
output_buffer.getvalue())
|
||||
d["img_id"] = "{}-{}".format(row["uid"], row["kb_id"])
|
||||
docs.append(d)
|
||||
|
||||
for arr, img in obj.table_chunks:
|
||||
for i, txt in enumerate(arr):
|
||||
d = copy.deepcopy(doc)
|
||||
d["content_ltks"] = huqie.qie(txt)
|
||||
md5.update((txt + str(d["doc_id"])).encode("utf-8"))
|
||||
d["_id"] = md5.hexdigest()
|
||||
if not img:
|
||||
docs.append(d)
|
||||
continue
|
||||
img.save(output_buffer, format='JPEG')
|
||||
MINIO.put("{}-{}".format(row["uid"], row["kb_id"]), d["_id"],
|
||||
output_buffer.getvalue())
|
||||
d["img_id"] = "{}-{}".format(row["uid"], row["kb_id"])
|
||||
docs.append(d)
|
||||
set_progress(row["kb2doc_id"], random.randint(60, 70)/100., "Continue embedding the content.")
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
def init_kb(row):
|
||||
idxnm = search.index_name(row["uid"])
|
||||
if ES.indexExist(idxnm): return
|
||||
return ES.createIdx(idxnm, json.load(open("conf/mapping.json", "r")))
|
||||
|
||||
|
||||
model = None
|
||||
def embedding(docs):
|
||||
global model
|
||||
tts = model.encode([rmSpace(d["title_tks"]) for d in docs])
|
||||
cnts = model.encode([rmSpace(d["content_ltks"]) for d in docs])
|
||||
vects = 0.1 * tts + 0.9 * cnts
|
||||
assert len(vects) == len(docs)
|
||||
for i,d in enumerate(docs):d["q_vec"] = vects[i].tolist()
|
||||
|
||||
|
||||
def rm_doc_from_kb(df):
|
||||
if len(df) == 0:return
|
||||
for _,r in df.iterrows():
|
||||
ES.updateScriptByQuery(Q("term", doc_id=r["did"]),
|
||||
scripts="""
|
||||
if(ctx._source.kb_id.contains('%s'))
|
||||
ctx._source.kb_id.remove(
|
||||
ctx._source.kb_id.indexOf('%s')
|
||||
);
|
||||
"""%(str(r["kb_id"]),str(r["kb_id"])),
|
||||
idxnm = search.index_name(r["uid"])
|
||||
)
|
||||
if len(df) == 0:return
|
||||
sql = """
|
||||
delete from kb2_doc where id in (%s)
|
||||
"""%",".join([str(i) for i in df["kb2doc_id"]])
|
||||
PG.update(sql)
|
||||
|
||||
|
||||
def main(comm, mod):
|
||||
global model
|
||||
from llm import HuEmbedding
|
||||
model = HuEmbedding()
|
||||
tm_fnm = f"res/{comm}-{mod}.tm"
|
||||
tm = findMaxDt(tm_fnm)
|
||||
rows = collect(comm, mod, tm)
|
||||
if len(rows) == 0:return
|
||||
|
||||
rm_doc_from_kb(rows.loc[rows.is_deleted == True])
|
||||
rows = rows.loc[rows.is_deleted == False].reset_index(drop=True)
|
||||
if len(rows) == 0:return
|
||||
tmf = open(tm_fnm, "a+")
|
||||
for _, r in rows.iterrows():
|
||||
cks = build(r)
|
||||
if not cks:
|
||||
tmf.write(str(r["updated_at"]) + "\n")
|
||||
continue
|
||||
## TODO: exception handler
|
||||
## set_progress(r["did"], -1, "ERROR: ")
|
||||
embedding(cks)
|
||||
|
||||
set_progress(r["kb2doc_id"], random.randint(70, 95)/100.,
|
||||
"Finished embedding! Start to build index!")
|
||||
init_kb(r)
|
||||
es_r = ES.bulk(cks, search.index_name(r["uid"]))
|
||||
if es_r:
|
||||
set_progress(r["kb2doc_id"], -1, "Index failure!")
|
||||
print(es_r)
|
||||
else: set_progress(r["kb2doc_id"], 1., "Done!")
|
||||
tmf.write(str(r["updated_at"]) + "\n")
|
||||
tmf.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from mpi4py import MPI
|
||||
comm = MPI.COMM_WORLD
|
||||
main(comm.Get_size(), comm.Get_rank())
|
||||
|
15 python/tmp.log
Normal file
@ -0,0 +1,15 @@
|
||||
Fetching 6 files: 0%| | 0/6 [00:00<?, ?it/s]
Fetching 6 files: 100%|██████████| 6/6 [00:00<00:00, 106184.91it/s]
|
||||
----------- Model Configuration -----------
|
||||
Model Arch: GFL
|
||||
Transform Order:
|
||||
--transform op: Resize
|
||||
--transform op: NormalizeImage
|
||||
--transform op: Permute
|
||||
--transform op: PadStride
|
||||
--------------------------------------------
|
||||
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
|
||||
The `max_size` parameter is deprecated and will be removed in v4.26. Please specify in `size['longest_edge'] instead`.
|
||||
Some weights of the model checkpoint at microsoft/table-transformer-structure-recognition were not used when initializing TableTransformerForObjectDetection: ['model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked']
|
||||
- This IS expected if you are initializing TableTransformerForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
|
||||
- This IS NOT expected if you are initializing TableTransformerForObjectDetection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
|
||||
WARNING:root:The files are stored in /opt/home/kevinhu/docgpt/, please check it!
|
@ -1,31 +0,0 @@
|
||||
from configparser import ConfigParser
|
||||
import os
|
||||
import inspect
|
||||
|
||||
CF = ConfigParser()
|
||||
__fnm = os.path.join(os.path.dirname(__file__), '../conf/sys.cnf')
|
||||
if not os.path.exists(__fnm):
|
||||
__fnm = os.path.join(os.path.dirname(__file__), '../../conf/sys.cnf')
|
||||
assert os.path.exists(
|
||||
__fnm), f"【EXCEPTION】can't find {__fnm}." + os.path.dirname(__file__)
|
||||
if not os.path.exists(__fnm):
|
||||
__fnm = "./sys.cnf"
|
||||
|
||||
CF.read(__fnm)
|
||||
|
||||
|
||||
class Config:
|
||||
def __init__(self, env):
|
||||
self.env = env
|
||||
if env == "spark":
|
||||
CF.read("./cv.cnf")
|
||||
|
||||
def get(self, key, default=None):
|
||||
global CF
|
||||
return os.environ.get(key.upper(),
|
||||
CF[self.env].get(key, default)
|
||||
)
|
||||
|
||||
|
||||
def init(env):
|
||||
return Config(env)
|
@ -1,70 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
from util import config
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class Postgres(object):
|
||||
def __init__(self, env, dbnm):
|
||||
self.config = config.init(env)
|
||||
self.conn = None
|
||||
self.dbnm = dbnm
|
||||
self.__open__()
|
||||
|
||||
def __open__(self):
|
||||
import psycopg2
|
||||
try:
|
||||
if self.conn:
|
||||
self.__close__()
|
||||
del self.conn
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
try:
|
||||
self.conn = psycopg2.connect(f"""dbname={self.dbnm}
|
||||
user={self.config.get('postgres_user')}
|
||||
password={self.config.get('postgres_password')}
|
||||
host={self.config.get('postgres_host')}
|
||||
port={self.config.get('postgres_port')}""")
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
"Fail to connect %s " %
|
||||
self.config.get("pgdb_host") + str(e))
|
||||
|
||||
def __close__(self):
|
||||
try:
|
||||
self.conn.close()
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
"Fail to close %s " %
|
||||
self.config.get("pgdb_host") + str(e))
|
||||
|
||||
def select(self, sql):
|
||||
for _ in range(10):
|
||||
try:
|
||||
return pd.read_sql(sql, self.conn)
|
||||
except Exception as e:
|
||||
logging.error(f"Fail to exec {sql} " + str(e))
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
|
||||
return pd.DataFrame()
|
||||
|
||||
def update(self, sql):
|
||||
for _ in range(10):
|
||||
try:
|
||||
cur = self.conn.cursor()
|
||||
cur.execute(sql)
|
||||
updated_rows = cur.rowcount
|
||||
self.conn.commit()
|
||||
cur.close()
|
||||
return updated_rows
|
||||
except Exception as e:
|
||||
logging.error(f"Fail to exec {sql} " + str(e))
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Postgres("infiniflow", "docgpt")
|
@ -1,36 +0,0 @@
|
||||
import json
|
||||
import logging.config
|
||||
import os
|
||||
|
||||
|
||||
def log_dir():
|
||||
fnm = os.path.join(os.path.dirname(__file__), '../log/')
|
||||
if not os.path.exists(fnm):
|
||||
fnm = os.path.join(os.path.dirname(__file__), '../../log/')
|
||||
assert os.path.exists(fnm), f"Can't locate log dir: {fnm}"
|
||||
return fnm
|
||||
|
||||
|
||||
def setup_logging(default_path="conf/logging.json",
|
||||
default_level=logging.INFO,
|
||||
env_key="LOG_CFG"):
|
||||
path = default_path
|
||||
value = os.getenv(env_key, None)
|
||||
if value:
|
||||
path = value
|
||||
if os.path.exists(path):
|
||||
with open(path, "r") as f:
|
||||
config = json.load(f)
|
||||
fnm = log_dir()
|
||||
|
||||
config["handlers"]["info_file_handler"]["filename"] = fnm + "info.log"
|
||||
config["handlers"]["error_file_handler"]["filename"] = fnm + "error.log"
|
||||
logging.config.dictConfig(config)
|
||||
else:
|
||||
logging.basicConfig(level=default_level)
|
||||
|
||||
|
||||
__fnm = os.path.join(os.path.dirname(__file__), 'conf/logging.json')
|
||||
if not os.path.exists(__fnm):
|
||||
__fnm = os.path.join(os.path.dirname(__file__), '../../conf/logging.json')
|
||||
setup_logging(__fnm)
|
0
rag/__init__.py
Normal file
32
rag/llm/__init__.py
Normal file
@ -0,0 +1,32 @@
|
||||
#
|
||||
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from .embedding_model import *
|
||||
from .chat_model import *
|
||||
from .cv_model import *
|
||||
|
||||
|
||||
EmbeddingModel = {
|
||||
"local": HuEmbedding,
|
||||
"OpenAI": OpenAIEmbed,
|
||||
"通义千问": QWenEmbed,
|
||||
}
|
||||
|
||||
|
||||
CvModel = {
|
||||
"OpenAI": GptV4,
|
||||
"通义千问": QWenCV,
|
||||
}
|
||||
|
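The two dicts above act as factory maps from an LLM vendor name to a model class. A minimal sketch of how such an entry is resolved, mirroring model_instance() in rag/svr/parse_user_docs.py later in this commit; the tenant record below is hypothetical and only the lookup pattern comes from the diff.

```python
# Sketch only: resolve a factory entry by vendor name, as model_instance()
# in rag/svr/parse_user_docs.py does. tenant_cfg values are made up.
from rag.llm import EmbeddingModel

tenant_cfg = {"llm_factory": "OpenAI",                  # key into the factory dict
              "api_key": "sk-xxxx",                     # hypothetical credential
              "llm_name": "text-embedding-ada-002"}

embd_cls = EmbeddingModel.get(tenant_cfg["llm_factory"])
if embd_cls:
    # The "local" entry (HuEmbedding) takes no constructor arguments, so a
    # real caller has to special-case it; hosted vendors take (key, model_name).
    embd_mdl = embd_cls(tenant_cfg["api_key"], tenant_cfg["llm_name"])
```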
@ -1,3 +1,18 @@
|
||||
#
|
||||
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
from openai import OpenAI
|
||||
import os
|
@ -1,3 +1,18 @@
|
||||
#
|
||||
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
from openai import OpenAI
|
||||
import os
|
||||
@ -6,6 +21,9 @@ from io import BytesIO
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, key, model_name):
|
||||
pass
|
||||
|
||||
def describe(self, image, max_tokens=300):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
@ -40,14 +58,15 @@ class Base(ABC):
|
||||
|
||||
|
||||
class GptV4(Base):
|
||||
def __init__(self):
|
||||
self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
|
||||
def __init__(self, key, model_name="gpt-4-vision-preview"):
|
||||
self.client = OpenAI(api_key=key)
|
||||
self.model_name = model_name
|
||||
|
||||
def describe(self, image, max_tokens=300):
|
||||
b64 = self.image2base64(image)
|
||||
|
||||
res = self.client.chat.completions.create(
|
||||
model="gpt-4-vision-preview",
|
||||
model=self.model_name,
|
||||
messages=self.prompt(b64),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
@ -55,11 +74,15 @@ class GptV4(Base):
|
||||
|
||||
|
||||
class QWenCV(Base):
|
||||
def __init__(self, key, model_name="qwen-vl-chat-v1"):
|
||||
import dashscope
|
||||
dashscope.api_key = key
|
||||
self.model_name = model_name
|
||||
|
||||
def describe(self, image, max_tokens=300):
|
||||
from http import HTTPStatus
|
||||
from dashscope import MultiModalConversation
|
||||
# export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
|
||||
response = MultiModalConversation.call(model=MultiModalConversation.Models.qwen_vl_chat_v1,
|
||||
response = MultiModalConversation.call(model=self.model_name,
|
||||
messages=self.prompt(self.image2base64(image)))
|
||||
if response.status_code == HTTPStatus.OK:
|
||||
return response.output.choices[0]['message']['content']
|
94
rag/llm/embedding_model.py
Normal file
@ -0,0 +1,94 @@
|
||||
#
|
||||
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
|
||||
import dashscope
|
||||
from openai import OpenAI
|
||||
from FlagEmbedding import FlagModel
|
||||
import torch
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from rag.utils import num_tokens_from_string
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, key, model_name):
|
||||
pass
|
||||
|
||||
|
||||
def encode(self, texts: list, batch_size=32):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
|
||||
class HuEmbedding(Base):
|
||||
def __init__(self):
|
||||
"""
|
||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||
|
||||
For Linux:
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
For Windows:
|
||||
Good luck
|
||||
^_-
|
||||
|
||||
"""
|
||||
self.model = FlagModel("BAAI/bge-large-zh-v1.5",
|
||||
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
||||
use_fp16=torch.cuda.is_available())
|
||||
|
||||
|
||||
def encode(self, texts: list, batch_size=32):
|
||||
token_count = 0
|
||||
for t in texts: token_count += num_tokens_from_string(t)
|
||||
res = []
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
|
||||
return np.array(res), token_count
|
||||
|
||||
|
||||
class OpenAIEmbed(Base):
|
||||
def __init__(self, key, model_name="text-embedding-ada-002"):
|
||||
self.client = OpenAI(api_key=key)
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list, batch_size=32):
|
||||
token_count = 0
|
||||
for t in texts: token_count += num_tokens_from_string(t)
|
||||
res = self.client.embeddings.create(input=texts,
|
||||
model=self.model_name)
|
||||
return [d["embedding"] for d in res["data"]], token_count
|
||||
|
||||
|
||||
class QWenEmbed(Base):
|
||||
def __init__(self, key, model_name="text_embedding_v2"):
|
||||
dashscope.api_key = key
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list, batch_size=32, text_type="document"):
|
||||
import dashscope
|
||||
res = []
|
||||
token_count = 0
|
||||
for txt in texts:
|
||||
resp = dashscope.TextEmbedding.call(
|
||||
model=self.model_name,
|
||||
input=txt[:2048],
|
||||
text_type=text_type
|
||||
)
|
||||
res.append(resp["output"]["embeddings"][0]["embedding"])
|
||||
token_count += resp["usage"]["total_tokens"]
|
||||
return res, token_count
|
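A short usage sketch for the classes above; it assumes FlagEmbedding and tiktoken are installed and that the BAAI/bge-large-zh-v1.5 weights can be downloaded (see the HF_ENDPOINT hint in the docstring).

```python
# Sketch: encode a small batch with the local embedder added above.
# Downloads BAAI/bge-large-zh-v1.5 on first use; fp16 is enabled only when
# CUDA is available, per the constructor.
from rag.llm.embedding_model import HuEmbedding

mdl = HuEmbedding()
vecs, n_tokens = mdl.encode(["什么是知识库?", "How do I upload a PDF?"])
print(vecs.shape, n_tokens)   # roughly (2, 1024) plus the tiktoken count
```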
0
rag/nlp/__init__.py
Normal file
11
python/nlp/huqie.py → rag/nlp/huqie.py
Executable file → Normal file
@ -9,6 +9,8 @@ import string
|
||||
import sys
|
||||
from hanziconv import HanziConv
|
||||
|
||||
from web_server.utils.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
class Huqie:
|
||||
def key_(self, line):
|
||||
@ -41,14 +43,7 @@ class Huqie:
|
||||
self.DEBUG = debug
|
||||
self.DENOMINATOR = 1000000
|
||||
self.trie_ = datrie.Trie(string.printable)
|
||||
self.DIR_ = ""
|
||||
if os.path.exists("../res/huqie.txt"):
|
||||
self.DIR_ = "../res/huqie"
|
||||
if os.path.exists("./res/huqie.txt"):
|
||||
self.DIR_ = "./res/huqie"
|
||||
if os.path.exists("./huqie.txt"):
|
||||
self.DIR_ = "./huqie"
|
||||
assert self.DIR_, f"【Can't find huqie】"
|
||||
self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")
|
||||
|
||||
self.SPLIT_CHAR = r"([ ,\.<>/?;'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-z\.-]+|[0-9,\.-]+)"
|
||||
try:
|
6
python/nlp/query.py → rag/nlp/query.py
Executable file → Normal file
@ -1,12 +1,12 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
import copy
|
||||
import math
|
||||
from elasticsearch_dsl import Q, Search
|
||||
from nlp import huqie, term_weight, synonym
|
||||
from rag.nlp import huqie, term_weight, synonym
|
||||
|
||||
|
||||
class EsQueryer:
|
@ -1,13 +1,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import re
|
||||
from elasticsearch_dsl import Q, Search, A
|
||||
from typing import List, Optional, Tuple, Dict, Union
|
||||
from dataclasses import dataclass
|
||||
from util import setup_logging, rmSpace
|
||||
from nlp import huqie, query
|
||||
from datetime import datetime
|
||||
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
||||
from rag.utils import rmSpace
|
||||
from rag.nlp import huqie, query
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
def index_name(uid): return f"docgpt_{uid}"
|
17
python/nlp/synonym.py → rag/nlp/synonym.py
Executable file → Normal file
@ -1,8 +1,11 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import re
|
||||
|
||||
from web_server.utils.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
class Dealer:
|
||||
def __init__(self, redis=None):
|
||||
@ -10,18 +13,12 @@ class Dealer:
|
||||
self.lookup_num = 100000000
|
||||
self.load_tm = time.time() - 1000000
|
||||
self.dictionary = None
|
||||
path = os.path.join(get_project_base_directory(), "rag/res", "synonym.json")
|
||||
try:
|
||||
self.dictionary = json.load(open("./synonym.json", 'r'))
|
||||
self.dictionary = json.load(open(path, 'r'))
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
self.dictionary = json.load(open("./res/synonym.json", 'r'))
|
||||
except Exception as e:
|
||||
try:
|
||||
self.dictionary = json.load(open("../res/synonym.json", 'r'))
|
||||
except Exception as e:
|
||||
logging.warn("Miss synonym.json")
|
||||
self.dictionary = {}
|
||||
logging.warn("Miss synonym.json")
|
||||
self.dictionary = {}
|
||||
|
||||
if not redis:
|
||||
logging.warning(
|
12
python/nlp/term_weight.py → rag/nlp/term_weight.py
Executable file → Normal file
@ -1,9 +1,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import math
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
import numpy as np
|
||||
from nlp import huqie
|
||||
from rag.nlp import huqie
|
||||
from web_server.utils.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
class Dealer:
|
||||
@ -60,16 +62,14 @@ class Dealer:
|
||||
return set(res.keys())
|
||||
return res
|
||||
|
||||
fnm = os.path.join(os.path.dirname(__file__), '../res/')
|
||||
if not os.path.exists(fnm):
|
||||
fnm = os.path.join(os.path.dirname(__file__), '../../res/')
|
||||
fnm = os.path.join(get_project_base_directory(), "res")
|
||||
self.ne, self.df = {}, {}
|
||||
try:
|
||||
self.ne = json.load(open(fnm + "ner.json", "r"))
|
||||
self.ne = json.load(open(os.path.join(fnm, "ner.json"), "r"))
|
||||
except Exception as e:
|
||||
print("[WARNING] Load ner.json FAIL!")
|
||||
try:
|
||||
self.df = load_dict(fnm + "term.freq")
|
||||
self.df = load_dict(os.path.join(fnm, "term.freq"))
|
||||
except Exception as e:
|
||||
print("[WARNING] Load term.freq FAIL!")
|
||||
|
@ -1,8 +1,9 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from docx import Document
|
||||
import re
|
||||
import pandas as pd
|
||||
from collections import Counter
|
||||
from nlp import huqie
|
||||
from rag.nlp import huqie
|
||||
from io import BytesIO
|
||||
|
||||
|
33
rag/parser/excel_parser.py
Normal file
@ -0,0 +1,33 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from openpyxl import load_workbook
|
||||
import sys
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
class HuExcelParser:
|
||||
def __call__(self, fnm):
|
||||
if isinstance(fnm, str):
|
||||
wb = load_workbook(fnm)
|
||||
else:
|
||||
wb = load_workbook(BytesIO(fnm))
|
||||
res = []
|
||||
for sheetname in wb.sheetnames:
|
||||
ws = wb[sheetname]
|
||||
rows = list(ws.rows)
|
||||
ti = list(rows[0])
|
||||
for r in list(rows[1:]):
|
||||
l = []
|
||||
for i,c in enumerate(r):
|
||||
if not c.value:continue
|
||||
t = str(ti[i].value) if i < len(ti) else ""
|
||||
t += (":" if t else "") + str(c.value)
|
||||
l.append(t)
|
||||
l = "; ".join(l)
|
||||
if sheetname.lower().find("sheet") <0: l += " ——"+sheetname
|
||||
res.append(l)
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
psr = HuExcelParser()
|
||||
psr(sys.argv[1])
|
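For reference, a small sketch of how HuExcelParser is called: each data row of each sheet becomes a single "header:value; header:value" string, with the sheet name appended after "——" when it is not a default "Sheet N" name. The workbook path and sample output below are made up.

```python
# Sketch only; "products.xlsx" is a hypothetical file.
from rag.parser.excel_parser import HuExcelParser

psr = HuExcelParser()
for line in psr("products.xlsx"):
    print(line)   # e.g. "name:widget-a; price:9.9; stock:120 ——inventory"
```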
@ -1,3 +1,4 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import xgboost as xgb
|
||||
from io import BytesIO
|
||||
import torch
|
||||
@ -6,11 +7,11 @@ import pdfplumber
|
||||
import logging
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from nlp import huqie
|
||||
from rag.nlp import huqie
|
||||
from collections import Counter
|
||||
from copy import deepcopy
|
||||
from cv.table_recognize import TableTransformer
|
||||
from cv.ppdetection import PPDet
|
||||
from rag.cv.table_recognize import TableTransformer
|
||||
from rag.cv.ppdetection import PPDet
|
||||
from huggingface_hub import hf_hub_download
|
||||
logging.getLogger("pdfminer").setLevel(logging.WARNING)
|
||||
|
0
python/res/ner.json → rag/res/ner.json
Executable file → Normal file
37
rag/settings.py
Normal file
@ -0,0 +1,37 @@
|
||||
#
|
||||
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
from web_server.utils import get_base_config,decrypt_database_config
|
||||
from web_server.utils.file_utils import get_project_base_directory
|
||||
from web_server.utils.log_utils import LoggerFactory, getLogger
|
||||
|
||||
|
||||
# Server
|
||||
RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
|
||||
SUBPROCESS_STD_LOG_NAME = "std.log"
|
||||
|
||||
ES = get_base_config("es", {})
|
||||
MINIO = decrypt_database_config(name="minio")
|
||||
DOC_MAXIMUM_SIZE = 64 * 1024 * 1024
|
||||
|
||||
# Logger
|
||||
LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "rag"))
|
||||
# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
|
||||
LoggerFactory.LEVEL = 10
|
||||
|
||||
es_logger = getLogger("es")
|
||||
minio_logger = getLogger("minio")
|
||||
cron_logger = getLogger("cron_logger")
|
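ES and MINIO here are plain dicts loaded from the service configuration. Judging from how the connection helpers later in this commit read them, they are expected to look roughly like the sketch below; the key names are inferred from usage, not from a documented schema, and the values are placeholders.

```python
# Inferred shape only: es_conn.py reads ES["hosts"] and ES.get("index_name"),
# minio_conn.py reads MINIO["host"], MINIO["user"] and MINIO["passwd"].
ES = {
    "hosts": "http://127.0.0.1:9200",   # comma-separated list, split on ","
    "index_name": "",                   # optional default index
}
MINIO = {
    "host": "127.0.0.1:9000",
    "user": "minio-user",
    "passwd": "minio-password",
}
```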
279
rag/svr/parse_user_docs.py
Normal file
@ -0,0 +1,279 @@
|
||||
#
|
||||
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import os
|
||||
import hashlib
|
||||
import copy
|
||||
import time
|
||||
import random
|
||||
import re
|
||||
from timeit import default_timer as timer
|
||||
|
||||
from rag.llm import EmbeddingModel, CvModel
|
||||
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
|
||||
from rag.utils import ELASTICSEARCH, num_tokens_from_string
|
||||
from rag.utils import MINIO
|
||||
from rag.utils import rmSpace, findMaxDt
|
||||
from rag.nlp import huchunk, huqie, search
|
||||
from io import BytesIO
|
||||
import pandas as pd
|
||||
from elasticsearch_dsl import Q
|
||||
from PIL import Image
|
||||
from rag.parser import (
|
||||
PdfParser,
|
||||
DocxParser,
|
||||
ExcelParser
|
||||
)
|
||||
from rag.nlp.huchunk import (
|
||||
PdfChunker,
|
||||
DocxChunker,
|
||||
ExcelChunker,
|
||||
PptChunker,
|
||||
TextChunker
|
||||
)
|
||||
from web_server.db import LLMType
|
||||
from web_server.db.services.document_service import DocumentService
|
||||
from web_server.db.services.llm_service import TenantLLMService
|
||||
from web_server.utils import get_format_time
|
||||
from web_server.utils.file_utils import get_project_base_directory
|
||||
|
||||
BATCH_SIZE = 64
|
||||
|
||||
PDF = PdfChunker(PdfParser())
|
||||
DOC = DocxChunker(DocxParser())
|
||||
EXC = ExcelChunker(ExcelParser())
|
||||
PPT = PptChunker()
|
||||
|
||||
|
||||
def chuck_doc(name, binary, cvmdl=None):
|
||||
suff = os.path.split(name)[-1].lower().split(".")[-1]
|
||||
if suff.find("pdf") >= 0:
|
||||
return PDF(binary)
|
||||
if suff.find("doc") >= 0:
|
||||
return DOC(binary)
|
||||
if re.match(r"(xlsx|xlsm|xltx|xltm)", suff):
|
||||
return EXC(binary)
|
||||
if suff.find("ppt") >= 0:
|
||||
return PPT(binary)
|
||||
if cvmdl and re.search(r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$",
|
||||
name.lower()):
|
||||
txt = cvmdl.describe(binary)
|
||||
field = TextChunker.Fields()
|
||||
field.text_chunks = [(txt, binary)]
|
||||
field.table_chunks = []
return field
|
||||
|
||||
return TextChunker()(binary)
|
||||
|
||||
|
||||
def collect(comm, mod, tm):
|
||||
docs = DocumentService.get_newly_uploaded(tm, mod, comm)
|
||||
if len(docs) == 0:
|
||||
return pd.DataFrame()
|
||||
docs = pd.DataFrame(docs)
|
||||
mtm = str(docs["update_time"].max())[:19]
|
||||
cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
|
||||
return docs
|
||||
|
||||
|
||||
def set_progress(docid, prog, msg="Processing...", begin=False):
|
||||
d = {"progress": prog, "progress_msg": msg}
|
||||
if begin:
|
||||
d["process_begin_at"] = get_format_time()
|
||||
try:
|
||||
DocumentService.update_by_id(
|
||||
docid, {"progress": prog, "progress_msg": msg})
|
||||
except Exception as e:
|
||||
cron_logger.error("set_progress:({}), {}".format(docid, str(e)))
|
||||
|
||||
|
||||
def build(row, cv_mdl=None):
|
||||
if row["size"] > DOC_MAXIMUM_SIZE:
|
||||
set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
|
||||
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
|
||||
return []
|
||||
res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
|
||||
if ELASTICSEARCH.getTotal(res) > 0:
|
||||
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
|
||||
scripts="""
|
||||
if(!ctx._source.kb_id.contains('%s'))
|
||||
ctx._source.kb_id.add('%s');
|
||||
""" % (str(row["kb_id"]), str(row["kb_id"])),
|
||||
idxnm=search.index_name(row["tenant_id"])
|
||||
)
|
||||
set_progress(row["id"], 1, "Done")
|
||||
return []
|
||||
|
||||
random.seed(time.time())
|
||||
set_progress(row["id"], random.randint(0, 20) /
|
||||
100., "Finished preparing! Start to slice file!", True)
|
||||
try:
|
||||
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]))
|
||||
except Exception as e:
|
||||
if re.search("(No such file|not found)", str(e)):
|
||||
set_progress(
|
||||
row["id"], -1, "Can not find file <%s>" %
|
||||
row["doc_name"])
|
||||
else:
|
||||
set_progress(
|
||||
row["id"], -1, f"Internal server error: %s" %
|
||||
str(e).replace(
|
||||
"'", ""))
|
||||
return []
|
||||
|
||||
if not obj.text_chunks and not obj.table_chunks:
|
||||
set_progress(
|
||||
row["id"],
|
||||
1,
|
||||
"Nothing added! Mostly, file type unsupported yet.")
|
||||
return []
|
||||
|
||||
set_progress(row["id"], random.randint(20, 60) / 100.,
|
||||
"Finished slicing files. Start to embedding the content.")
|
||||
|
||||
doc = {
|
||||
"doc_id": row["did"],
|
||||
"kb_id": [str(row["kb_id"])],
|
||||
"docnm_kwd": os.path.split(row["location"])[-1],
|
||||
"title_tks": huqie.qie(row["name"]),
|
||||
"updated_at": str(row["update_time"]).replace("T", " ")[:19]
|
||||
}
|
||||
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
|
||||
output_buffer = BytesIO()
|
||||
docs = []
|
||||
md5 = hashlib.md5()
|
||||
for txt, img in obj.text_chunks:
|
||||
d = copy.deepcopy(doc)
|
||||
md5.update((txt + str(d["doc_id"])).encode("utf-8"))
|
||||
d["_id"] = md5.hexdigest()
|
||||
d["content_ltks"] = huqie.qie(txt)
|
||||
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
|
||||
if not img:
|
||||
docs.append(d)
|
||||
continue
|
||||
|
||||
if isinstance(img, Image.Image):
|
||||
img.save(output_buffer, format='JPEG')
|
||||
else:
|
||||
output_buffer = BytesIO(img)
|
||||
|
||||
MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
|
||||
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
|
||||
docs.append(d)
|
||||
|
||||
for arr, img in obj.table_chunks:
|
||||
for i, txt in enumerate(arr):
|
||||
d = copy.deepcopy(doc)
|
||||
d["content_ltks"] = huqie.qie(txt)
|
||||
md5.update((txt + str(d["doc_id"])).encode("utf-8"))
|
||||
d["_id"] = md5.hexdigest()
|
||||
if not img:
|
||||
docs.append(d)
|
||||
continue
|
||||
img.save(output_buffer, format='JPEG')
|
||||
MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
|
||||
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
|
||||
docs.append(d)
|
||||
set_progress(row["id"], random.randint(60, 70) /
|
||||
100., "Continue embedding the content.")
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
def init_kb(row):
|
||||
idxnm = search.index_name(row["tenant_id"])
|
||||
if ELASTICSEARCH.indexExist(idxnm):
|
||||
return
|
||||
return ELASTICSEARCH.createIdx(idxnm, json.load(
|
||||
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
|
||||
|
||||
|
||||
def embedding(docs, mdl):
|
||||
tts, cnts = [rmSpace(d["title_tks"]) for d in docs], [rmSpace(d["content_ltks"]) for d in docs]
|
||||
tk_count = 0
|
||||
tts, c = mdl.encode(tts)
|
||||
tk_count += c
|
||||
cnts, c = mdl.encode(cnts)
|
||||
tk_count += c
|
||||
vects = 0.1 * tts + 0.9 * cnts
|
||||
assert len(vects) == len(docs)
|
||||
for i, d in enumerate(docs):
|
||||
d["q_vec"] = vects[i].tolist()
|
||||
return tk_count
|
||||
|
||||
|
||||
def model_instance(tenant_id, llm_type):
|
||||
model_config = TenantLLMService.query(tenant_id=tenant_id, model_type=LLMType.EMBEDDING)
|
||||
if not model_config:return
|
||||
model_config = model_config[0]
|
||||
if llm_type == LLMType.EMBEDDING:
|
||||
if model_config.llm_factory not in EmbeddingModel: return
|
||||
return EmbeddingModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)
|
||||
if llm_type == LLMType.IMAGE2TEXT:
|
||||
if model_config.llm_factory not in CvModel: return
|
||||
return CvModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)
|
||||
|
||||
|
||||
def main(comm, mod):
|
||||
global model
|
||||
from rag.llm import HuEmbedding
|
||||
model = HuEmbedding()
|
||||
tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
|
||||
tm = findMaxDt(tm_fnm)
|
||||
rows = collect(comm, mod, tm)
|
||||
if len(rows) == 0:
|
||||
return
|
||||
|
||||
tmf = open(tm_fnm, "a+")
|
||||
for _, r in rows.iterrows():
|
||||
embd_mdl = model_instance(r["tenant_id"], LLMType.EMBEDDING)
|
||||
if not embd_mdl:
|
||||
set_progress(r["id"], -1, "Can't find embedding model!")
|
||||
cron_logger.error("Tenant({}) can't find embedding model!".format(r["tenant_id"]))
|
||||
continue
|
||||
cv_mdl = model_instance(r["tenant_id"], LLMType.IMAGE2TEXT)
|
||||
st_tm = timer()
|
||||
cks = build(r, cv_mdl)
|
||||
if not cks:
|
||||
tmf.write(str(r["updated_at"]) + "\n")
|
||||
continue
|
||||
# TODO: exception handler
|
||||
## set_progress(r["did"], -1, "ERROR: ")
|
||||
try:
|
||||
tk_count = embedding(cks, embd_mdl)
|
||||
except Exception as e:
|
||||
set_progress(r["id"], -1, "Embedding error:{}".format(str(e)))
|
||||
cron_logger.error(str(e))
|
||||
continue
|
||||
|
||||
|
||||
set_progress(r["id"], random.randint(70, 95) / 100.,
|
||||
"Finished embedding! Start to build index!")
|
||||
init_kb(r)
|
||||
es_r = ELASTICSEARCH.bulk(cks, search.index_name(r["tenant_id"]))
|
||||
if es_r:
|
||||
set_progress(r["id"], -1, "Index failure!")
|
||||
cron_logger.error(str(es_r))
|
||||
else:
|
||||
set_progress(r["id"], 1., "Done!")
|
||||
DocumentService.update_by_id(r["id"], {"token_num": tk_count, "chunk_num": len(cks), "process_duation": timer()-st_tm})
|
||||
tmf.write(str(r["update_time"]) + "\n")
|
||||
tmf.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from mpi4py import MPI
|
||||
comm = MPI.COMM_WORLD
|
||||
main(comm.Get_size(), comm.Get_rank())
|
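One detail of the pipeline above worth calling out: embedding() mixes the title and body vectors with a fixed 0.1/0.9 weighting before each chunk is stored as q_vec. A standalone sketch with toy vectors:

```python
# Toy illustration of the 0.1 * title + 0.9 * content weighting used in
# embedding() above; real vectors come from the embedding model, these are made up.
import numpy as np

title_vecs   = np.array([[1.0, 0.0], [0.0, 1.0]])
content_vecs = np.array([[0.0, 1.0], [1.0, 0.0]])
q_vec = 0.1 * title_vecs + 0.9 * content_vecs
print(q_vec)   # [[0.1 0.9]
               #  [0.9 0.1]]
```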
@ -1,6 +1,23 @@
|
||||
import os
|
||||
import re
|
||||
import tiktoken
|
||||
|
||||
|
||||
def singleton(cls, *args, **kw):
|
||||
instances = {}
|
||||
|
||||
def _singleton():
|
||||
key = str(cls) + str(os.getpid())
|
||||
if key not in instances:
|
||||
instances[key] = cls(*args, **kw)
|
||||
return instances[key]
|
||||
|
||||
return _singleton
|
||||
|
||||
|
||||
from .minio_conn import MINIO
|
||||
from .es_conn import ELASTICSEARCH
|
||||
|
||||
def rmSpace(txt):
|
||||
txt = re.sub(r"([^a-z0-9.,]) +([^ ])", r"\1\2", txt)
|
||||
return re.sub(r"([^ ]) +([^a-z0-9.,])", r"\1\2", txt)
|
||||
@ -22,3 +39,9 @@ def findMaxDt(fnm):
|
||||
except Exception as e:
|
||||
print("WARNING: can't find " + fnm)
|
||||
return m
|
||||
|
||||
def num_tokens_from_string(string: str) -> int:
|
||||
"""Returns the number of tokens in a text string."""
|
||||
encoding = tiktoken.get_encoding('cl100k_base')
|
||||
num_tokens = len(encoding.encode(string))
|
||||
return num_tokens
|
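The two helpers added here are used throughout the new rag package: singleton caches one instance per class per process, and num_tokens_from_string counts tokens with tiktoken's cl100k_base encoding. A small sketch, assuming tiktoken is installed and the rag package imports cleanly (its __init__ also constructs the MinIO and Elasticsearch connections):

```python
# Sketch of the helpers above; assumes `rag` is importable with its service
# config in place, since rag.utils also builds MINIO and ELASTICSEARCH on import.
from rag.utils import singleton, num_tokens_from_string

@singleton
class Cache:
    def __init__(self):
        self.store = {}

assert Cache() is Cache()                     # one instance per process
print(num_tokens_from_string("hello world"))  # token count under cl100k_base
```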
104
python/util/es_conn.py → rag/utils/es_conn.py
Executable file → Normal file
@ -1,51 +1,39 @@
|
||||
import re
|
||||
import logging
|
||||
import json
|
||||
import time
|
||||
import copy
|
||||
import elasticsearch
|
||||
from elasticsearch import Elasticsearch
|
||||
from elasticsearch_dsl import UpdateByQuery, Search, Index, Q
|
||||
from util import config
|
||||
from elasticsearch_dsl import UpdateByQuery, Search, Index
|
||||
from rag.settings import es_logger
|
||||
from rag import settings
|
||||
from rag.utils import singleton
|
||||
|
||||
logging.info("Elasticsearch version: ", elasticsearch.__version__)
|
||||
|
||||
|
||||
def instance(env):
|
||||
CF = config.init(env)
|
||||
ES_DRESS = CF.get("es").split(",")
|
||||
|
||||
ES = Elasticsearch(
|
||||
ES_DRESS,
|
||||
timeout=600
|
||||
)
|
||||
|
||||
logging.info("ES: ", ES_DRESS, ES.info())
|
||||
|
||||
return ES
|
||||
es_logger.info("Elasticsearch version: "+ str(elasticsearch.__version__))
|
||||
|
||||
|
||||
@singleton
|
||||
class HuEs:
|
||||
def __init__(self, env):
|
||||
self.env = env
|
||||
def __init__(self):
|
||||
self.info = {}
|
||||
self.config = config.init(env)
|
||||
self.conn()
|
||||
self.idxnm = self.config.get("idx_nm", "")
|
||||
self.idxnm = settings.ES.get("index_name", "")
|
||||
if not self.es.ping():
|
||||
raise Exception("Can't connect to ES cluster")
|
||||
|
||||
def conn(self):
|
||||
for _ in range(10):
|
||||
try:
|
||||
c = instance(self.env)
|
||||
if c:
|
||||
self.es = c
|
||||
self.info = c.info()
|
||||
logging.info("Connect to es.")
|
||||
self.es = Elasticsearch(
|
||||
settings.ES["hosts"].split(","),
|
||||
timeout=600
|
||||
)
|
||||
if self.es:
|
||||
self.info = self.es.info()
|
||||
es_logger.info("Connect to es.")
|
||||
break
|
||||
except Exception as e:
|
||||
logging.error("Fail to connect to es: " + str(e))
|
||||
es_logger.error("Fail to connect to es: " + str(e))
|
||||
time.sleep(1)
|
||||
|
||||
def version(self):
|
||||
@ -80,12 +68,12 @@ class HuEs:
|
||||
refresh=False,
|
||||
doc_type="_doc",
|
||||
retry_on_conflict=100)
|
||||
logging.info("Successfully upsert: %s" % id)
|
||||
es_logger.info("Successfully upsert: %s" % id)
|
||||
T = True
|
||||
break
|
||||
except Exception as e:
|
||||
logging.warning("Fail to index: " +
|
||||
json.dumps(d, ensure_ascii=False) + str(e))
|
||||
es_logger.warning("Fail to index: " +
|
||||
json.dumps(d, ensure_ascii=False) + str(e))
|
||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
||||
time.sleep(3)
|
||||
continue
|
||||
@ -94,7 +82,7 @@ class HuEs:
|
||||
|
||||
if not T:
|
||||
res.append(d)
|
||||
logging.error(
|
||||
es_logger.error(
|
||||
"Fail to index: " +
|
||||
re.sub(
|
||||
"[\r\n]",
|
||||
@ -147,7 +135,7 @@ class HuEs:
|
||||
|
||||
return res
|
||||
except Exception as e:
|
||||
logging.warn("Fail to bulk: " + str(e))
|
||||
es_logger.warn("Fail to bulk: " + str(e))
|
||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
||||
time.sleep(3)
|
||||
continue
|
||||
@ -162,7 +150,7 @@ class HuEs:
|
||||
ids[id] = copy.deepcopy(d["raw"])
|
||||
acts.append({"update": {"_id": id, "_index": self.idxnm}})
|
||||
acts.append(d["script"])
|
||||
logging.info("bulk upsert: %s" % id)
|
||||
es_logger.info("bulk upsert: %s" % id)
|
||||
|
||||
res = []
|
||||
for _ in range(10):
|
||||
@ -189,7 +177,7 @@ class HuEs:
|
||||
|
||||
return res
|
||||
except Exception as e:
|
||||
logging.warning("Fail to bulk: " + str(e))
|
||||
es_logger.warning("Fail to bulk: " + str(e))
|
||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
||||
time.sleep(3)
|
||||
continue
|
||||
@ -212,10 +200,10 @@ class HuEs:
|
||||
id=d["id"],
|
||||
refresh=True,
|
||||
doc_type="_doc")
|
||||
logging.info("Remove %s" % d["id"])
|
||||
es_logger.info("Remove %s" % d["id"])
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.warn("Fail to delete: " + str(d) + str(e))
|
||||
es_logger.warn("Fail to delete: " + str(d) + str(e))
|
||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
||||
time.sleep(3)
|
||||
continue
|
||||
@ -223,7 +211,7 @@ class HuEs:
|
||||
return True
|
||||
self.conn()
|
||||
|
||||
logging.error("Fail to delete: " + str(d))
|
||||
es_logger.error("Fail to delete: " + str(d))
|
||||
|
||||
return False
|
||||
|
||||
@ -242,7 +230,7 @@ class HuEs:
|
||||
raise Exception("Es Timeout.")
|
||||
return res
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
es_logger.error(
|
||||
"ES search exception: " +
|
||||
str(e) +
|
||||
"【Q】:" +
|
||||
@ -250,7 +238,7 @@ class HuEs:
|
||||
if str(e).find("Timeout") > 0:
|
||||
continue
|
||||
raise e
|
||||
logging.error("ES search timeout for 3 times!")
|
||||
es_logger.error("ES search timeout for 3 times!")
|
||||
raise Exception("ES search timeout.")
|
||||
|
||||
def updateByQuery(self, q, d):
|
||||
@ -267,8 +255,8 @@ class HuEs:
|
||||
r = ubq.execute()
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error("ES updateByQuery exception: " +
|
||||
str(e) + "【Q】:" + str(q.to_dict()))
|
||||
es_logger.error("ES updateByQuery exception: " +
|
||||
str(e) + "【Q】:" + str(q.to_dict()))
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
continue
|
||||
self.conn()
|
||||
@ -288,8 +276,8 @@ class HuEs:
|
||||
r = ubq.execute()
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error("ES updateByQuery exception: " +
|
||||
str(e) + "【Q】:" + str(q.to_dict()))
|
||||
es_logger.error("ES updateByQuery exception: " +
|
||||
str(e) + "【Q】:" + str(q.to_dict()))
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
continue
|
||||
self.conn()
|
||||
@ -304,8 +292,8 @@ class HuEs:
|
||||
body=Search().query(query).to_dict())
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error("ES updateByQuery deleteByQuery: " +
|
||||
str(e) + "【Q】:" + str(query.to_dict()))
|
||||
es_logger.error("ES updateByQuery deleteByQuery: " +
|
||||
str(e) + "【Q】:" + str(query.to_dict()))
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
continue
|
||||
|
||||
@ -329,8 +317,9 @@ class HuEs:
|
||||
routing=routing, refresh=False) # , doc_type="_doc")
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error("ES update exception: " + str(e) + " id:" + str(id) + ", version:" + str(self.version()) +
|
||||
json.dumps(script, ensure_ascii=False))
|
||||
es_logger.error(
|
||||
"ES update exception: " + str(e) + " id:" + str(id) + ", version:" + str(self.version()) +
|
||||
json.dumps(script, ensure_ascii=False))
|
||||
if str(e).find("Timeout") > 0:
|
||||
continue
|
||||
|
||||
@ -342,7 +331,7 @@ class HuEs:
|
||||
try:
|
||||
return s.exists()
|
||||
except Exception as e:
|
||||
logging.error("ES updateByQuery indexExist: " + str(e))
|
||||
es_logger.error("ES updateByQuery indexExist: " + str(e))
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
continue
|
||||
|
||||
@ -354,7 +343,7 @@ class HuEs:
|
||||
return self.es.exists(index=(idxnm if idxnm else self.idxnm),
|
||||
id=docid)
|
||||
except Exception as e:
|
||||
logging.error("ES Doc Exist: " + str(e))
|
||||
es_logger.error("ES Doc Exist: " + str(e))
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
continue
|
||||
return False
|
||||
@ -368,13 +357,13 @@ class HuEs:
|
||||
settings=mapping["settings"],
|
||||
mappings=mapping["mappings"])
|
||||
except Exception as e:
|
||||
logging.error("ES create index error %s ----%s" % (idxnm, str(e)))
|
||||
es_logger.error("ES create index error %s ----%s" % (idxnm, str(e)))
|
||||
|
||||
def deleteIdx(self, idxnm):
|
||||
try:
|
||||
return self.es.indices.delete(idxnm, allow_no_indices=True)
|
||||
except Exception as e:
|
||||
logging.error("ES delete index error %s ----%s" % (idxnm, str(e)))
|
||||
es_logger.error("ES delete index error %s ----%s" % (idxnm, str(e)))
|
||||
|
||||
def getTotal(self, res):
|
||||
if isinstance(res["hits"]["total"], type({})):
|
||||
@ -393,7 +382,7 @@ class HuEs:
|
||||
return rr
|
||||
|
||||
def scrollIter(self, pagesize=100, scroll_time='2m', q={
|
||||
"query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}):
|
||||
"query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}):
|
||||
for _ in range(100):
|
||||
try:
|
||||
page = self.es.search(
|
||||
@ -405,12 +394,12 @@ class HuEs:
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
logging.error("ES scrolling fail. " + str(e))
|
||||
es_logger.error("ES scrolling fail. " + str(e))
|
||||
time.sleep(3)
|
||||
|
||||
sid = page['_scroll_id']
|
||||
scroll_size = page['hits']['total']["value"]
|
||||
logging.info("[TOTAL]%d" % scroll_size)
|
||||
es_logger.info("[TOTAL]%d" % scroll_size)
|
||||
# Start scrolling
|
||||
while scroll_size > 0:
|
||||
yield page["hits"]["hits"]
|
||||
@ -419,10 +408,13 @@ class HuEs:
|
||||
page = self.es.scroll(scroll_id=sid, scroll=scroll_time)
|
||||
break
|
||||
except Exception as e:
|
||||
logging.error("ES scrolling fail. " + str(e))
|
||||
es_logger.error("ES scrolling fail. " + str(e))
|
||||
time.sleep(3)
|
||||
|
||||
# Update the scroll ID
|
||||
sid = page['_scroll_id']
|
||||
# Get the number of results that we returned in the last scroll
|
||||
scroll_size = len(page['hits']['hits'])
|
||||
|
||||
|
||||
ELASTICSEARCH = HuEs()
|
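The module now exposes a ready-made ELASTICSEARCH singleton instead of the old per-env instance() helper. A minimal sketch of how the rest of this commit queries it, mirroring build() in rag/svr/parse_user_docs.py; the document id below is a placeholder.

```python
# Sketch mirroring build() in rag/svr/parse_user_docs.py; the id is a placeholder.
from elasticsearch_dsl import Q

from rag.utils import ELASTICSEARCH

res = ELASTICSEARCH.search(Q("term", doc_id="doc-id-placeholder"))
if ELASTICSEARCH.getTotal(res) > 0:
    print("chunks for this document are already indexed")
```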
@ -1,13 +1,15 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from util import config
|
||||
from minio import Minio
|
||||
from io import BytesIO
|
||||
from rag import settings
|
||||
from rag.settings import minio_logger
|
||||
from rag.utils import singleton
|
||||
|
||||
|
||||
@singleton
|
||||
class HuMinio(object):
|
||||
def __init__(self, env):
|
||||
self.config = config.init(env)
|
||||
def __init__(self):
|
||||
self.conn = None
|
||||
self.__open__()
|
||||
|
||||
@ -19,15 +21,14 @@ class HuMinio(object):
|
||||
pass
|
||||
|
||||
try:
|
||||
self.conn = Minio(self.config.get("minio_host"),
|
||||
access_key=self.config.get("minio_user"),
|
||||
secret_key=self.config.get("minio_password"),
|
||||
self.conn = Minio(settings.MINIO["host"],
|
||||
access_key=settings.MINIO["user"],
|
||||
secret_key=settings.MINIO["passwd"],
|
||||
secure=False
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
"Fail to connect %s " %
|
||||
self.config.get("minio_host") + str(e))
|
||||
minio_logger.error(
|
||||
"Fail to connect %s " % settings.MINIO["host"] + str(e))
|
||||
|
||||
def __close__(self):
|
||||
del self.conn
|
||||
@ -45,34 +46,51 @@ class HuMinio(object):
|
||||
)
|
||||
return r
|
||||
except Exception as e:
|
||||
logging.error(f"Fail put {bucket}/{fnm}: " + str(e))
|
||||
minio_logger.error(f"Fail put {bucket}/{fnm}: " + str(e))
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
|
||||
def rm(self, bucket, fnm):
|
||||
try:
|
||||
self.conn.remove_object(bucket, fnm)
|
||||
except Exception as e:
|
||||
minio_logger.error(f"Fail rm {bucket}/{fnm}: " + str(e))
|
||||
|
||||
|
||||
def get(self, bucket, fnm):
|
||||
for _ in range(10):
|
||||
try:
|
||||
r = self.conn.get_object(bucket, fnm)
|
||||
return r.read()
|
||||
except Exception as e:
|
||||
logging.error(f"fail get {bucket}/{fnm}: " + str(e))
|
||||
minio_logger.error(f"fail get {bucket}/{fnm}: " + str(e))
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
return
|
||||
|
||||
def obj_exist(self, bucket, fnm):
|
||||
try:
|
||||
if self.conn.stat_object(bucket, fnm):return True
|
||||
return False
|
||||
except Exception as e:
|
||||
minio_logger.error(f"Fail put {bucket}/{fnm}: " + str(e))
|
||||
return False
|
||||
|
||||
|
||||
def get_presigned_url(self, bucket, fnm, expires):
|
||||
for _ in range(10):
|
||||
try:
|
||||
return self.conn.get_presigned_url("GET", bucket, fnm, expires)
|
||||
except Exception as e:
|
||||
logging.error(f"fail get {bucket}/{fnm}: " + str(e))
|
||||
minio_logger.error(f"fail get {bucket}/{fnm}: " + str(e))
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
return
|
||||
|
||||
MINIO = HuMinio()
|
||||
|
||||
if __name__ == "__main__":
|
||||
conn = HuMinio("infiniflow")
|
||||
conn = HuMinio()
|
||||
fnm = "/opt/home/kevinhu/docgpt/upload/13/11-408.jpg"
|
||||
from PIL import Image
|
||||
img = Image.open(fnm)
|
@ -1,178 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use actix_web::{ HttpResponse, post, web };
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use crate::api::JsonResponse;
|
||||
use crate::AppState;
|
||||
use crate::errors::AppError;
|
||||
use crate::service::dialog_info::Query;
|
||||
use crate::service::dialog_info::Mutation;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ListParams {
|
||||
pub uid: i64,
|
||||
pub dialog_id: Option<i64>,
|
||||
}
|
||||
#[post("/v1.0/dialogs")]
|
||||
async fn list(
|
||||
params: web::Json<ListParams>,
|
||||
data: web::Data<AppState>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
let mut result = HashMap::new();
|
||||
if let Some(dia_id) = params.dialog_id {
|
||||
let dia = Query::find_dialog_info_by_id(&data.conn, dia_id).await?.unwrap();
|
||||
let kb = crate::service::kb_info::Query
|
||||
::find_kb_info_by_id(&data.conn, dia.kb_id).await?
|
||||
.unwrap();
|
||||
print!("{:?}", dia.history);
|
||||
let hist: Value = serde_json::from_str(&dia.history)?;
|
||||
let detail =
|
||||
json!({
|
||||
"dialog_id": dia_id,
|
||||
"dialog_name": dia.dialog_name.to_owned(),
|
||||
"created_at": dia.created_at.to_string().to_owned(),
|
||||
"updated_at": dia.updated_at.to_string().to_owned(),
|
||||
"history": hist,
|
||||
"kb_info": kb
|
||||
});
|
||||
|
||||
result.insert("dialogs", vec![detail]);
|
||||
} else {
|
||||
let mut dias = Vec::<Value>::new();
|
||||
for dia in Query::find_dialog_infos_by_uid(&data.conn, params.uid).await? {
|
||||
let kb = crate::service::kb_info::Query
|
||||
::find_kb_info_by_id(&data.conn, dia.kb_id).await?
|
||||
.unwrap();
|
||||
let hist: Value = serde_json::from_str(&dia.history)?;
|
||||
dias.push(
|
||||
json!({
|
||||
"dialog_id": dia.dialog_id,
|
||||
"dialog_name": dia.dialog_name.to_owned(),
|
||||
"created_at": dia.created_at.to_string().to_owned(),
|
||||
"updated_at": dia.updated_at.to_string().to_owned(),
|
||||
"history": hist,
|
||||
"kb_info": kb
|
||||
})
|
||||
);
|
||||
}
|
||||
result.insert("dialogs", dias);
|
||||
}
|
||||
let json_response = JsonResponse {
|
||||
code: 200,
|
||||
err: "".to_owned(),
|
||||
data: result,
|
||||
};
|
||||
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct RmParams {
|
||||
pub uid: i64,
|
||||
pub dialog_id: i64,
|
||||
}
|
||||
#[post("/v1.0/delete_dialog")]
|
||||
async fn delete(
|
||||
params: web::Json<RmParams>,
|
||||
data: web::Data<AppState>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
let _ = Mutation::delete_dialog_info(&data.conn, params.dialog_id).await?;
|
||||
|
||||
let json_response = JsonResponse {
|
||||
code: 200,
|
||||
err: "".to_owned(),
|
||||
data: (),
|
||||
};
|
||||
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateParams {
|
||||
pub uid: i64,
|
||||
pub dialog_id: Option<i64>,
|
||||
pub kb_id: i64,
|
||||
pub name: String,
|
||||
}
|
||||
#[post("/v1.0/create_dialog")]
|
||||
async fn create(
|
||||
param: web::Json<CreateParams>,
|
||||
data: web::Data<AppState>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
let mut result = HashMap::new();
|
||||
if let Some(dia_id) = param.dialog_id {
|
||||
result.insert("dialog_id", dia_id);
|
||||
let dia = Query::find_dialog_info_by_id(&data.conn, dia_id).await?;
|
||||
let _ = Mutation::update_dialog_info_by_id(
|
||||
&data.conn,
|
||||
dia_id,
|
||||
¶m.name,
|
||||
&dia.unwrap().history
|
||||
).await?;
|
||||
} else {
|
||||
let dia = Mutation::create_dialog_info(
|
||||
&data.conn,
|
||||
param.uid,
|
||||
param.kb_id,
|
||||
¶m.name
|
||||
).await?;
|
||||
result.insert("dialog_id", dia.dialog_id.unwrap());
|
||||
}
|
||||
|
||||
let json_response = JsonResponse {
|
||||
code: 200,
|
||||
err: "".to_owned(),
|
||||
data: result,
|
||||
};
|
||||
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateHistoryParams {
|
||||
pub uid: i64,
|
||||
pub dialog_id: i64,
|
||||
pub history: Value,
|
||||
}
|
||||
#[post("/v1.0/update_history")]
|
||||
async fn update_history(
|
||||
param: web::Json<UpdateHistoryParams>,
|
||||
data: web::Data<AppState>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
let mut json_response = JsonResponse {
|
||||
code: 200,
|
||||
err: "".to_owned(),
|
||||
data: (),
|
||||
};
|
||||
|
||||
if let Some(dia) = Query::find_dialog_info_by_id(&data.conn, param.dialog_id).await? {
|
||||
let _ = Mutation::update_dialog_info_by_id(
|
||||
&data.conn,
|
||||
param.dialog_id,
|
||||
&dia.dialog_name,
|
||||
¶m.history.to_string()
|
||||
).await?;
|
||||
} else {
|
||||
json_response.code = 500;
|
||||
json_response.err = "Can't find dialog data!".to_owned();
|
||||
}
|
||||
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
)
|
||||
}
|
@ -1,314 +0,0 @@
|
||||
use std::collections::{HashMap};
|
||||
use std::io::BufReader;
|
||||
use actix_multipart_extract::{ File, Multipart, MultipartForm };
|
||||
use actix_web::web::Bytes;
|
||||
use actix_web::{ HttpResponse, post, web };
|
||||
use chrono::{ Utc, FixedOffset };
|
||||
use minio::s3::args::{ BucketExistsArgs, MakeBucketArgs, PutObjectArgs };
|
||||
use sea_orm::DbConn;
|
||||
use crate::api::JsonResponse;
|
||||
use crate::AppState;
|
||||
use crate::entity::doc_info::Model;
|
||||
use crate::errors::AppError;
|
||||
use crate::service::doc_info::{ Mutation, Query };
|
||||
use serde::Deserialize;
|
||||
use regex::Regex;
|
||||
|
||||
fn now() -> chrono::DateTime<FixedOffset> {
|
||||
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ListParams {
|
||||
pub uid: i64,
|
||||
pub filter: FilterParams,
|
||||
pub sortby: String,
|
||||
pub page: Option<u32>,
|
||||
pub per_page: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct FilterParams {
|
||||
pub keywords: Option<String>,
|
||||
pub folder_id: Option<i64>,
|
||||
pub tag_id: Option<i64>,
|
||||
pub kb_id: Option<i64>,
|
||||
}
|
||||
|
||||
#[post("/v1.0/docs")]
|
||||
async fn list(
|
||||
params: web::Json<ListParams>,
|
||||
data: web::Data<AppState>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
let docs = Query::find_doc_infos_by_params(&data.conn, params.into_inner()).await?;
|
||||
|
||||
let mut result = HashMap::new();
|
||||
result.insert("docs", docs);
|
||||
|
||||
let json_response = JsonResponse {
|
||||
code: 200,
|
||||
err: "".to_owned(),
|
||||
data: result,
|
||||
};
|
||||
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Deserialize, MultipartForm, Debug)]
|
||||
pub struct UploadForm {
|
||||
#[multipart(max_size = 512MB)]
|
||||
file_field: File,
|
||||
uid: i64,
|
||||
did: i64,
|
||||
}
|
||||
|
||||
fn file_type(filename: &String) -> String {
|
||||
let fnm = filename.to_lowercase();
|
||||
if
|
||||
let Some(_) = Regex::new(r"\.(mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$")
|
||||
.unwrap()
|
||||
.captures(&fnm)
|
||||
{
|
||||
return "Video".to_owned();
|
||||
}
|
||||
if
|
||||
let Some(_) = Regex::new(
|
||||
r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$"
|
||||
)
|
||||
.unwrap()
|
||||
.captures(&fnm)
|
||||
{
|
||||
return "Picture".to_owned();
|
||||
}
|
||||
if
|
||||
let Some(_) = Regex::new(r"\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$")
|
||||
.unwrap()
|
||||
.captures(&fnm)
|
||||
{
|
||||
return "Music".to_owned();
|
||||
}
|
||||
if
|
||||
let Some(_) = Regex::new(r"\.(pdf|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)$")
|
||||
.unwrap()
|
||||
.captures(&fnm)
|
||||
{
|
||||
return "Document".to_owned();
|
||||
}
|
||||
"Other".to_owned()
|
||||
}
|
||||
|
||||
|
||||
#[post("/v1.0/upload")]
|
||||
async fn upload(
|
||||
payload: Multipart<UploadForm>,
|
||||
data: web::Data<AppState>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
|
||||
|
||||
Ok(HttpResponse::Ok().body("File uploaded successfully"))
|
||||
}
|
||||
|
||||
pub(crate) async fn _upload_file(uid: i64, did: i64, file_name: &str, bytes: &[u8], data: &web::Data<AppState>) -> Result<(), AppError> {
|
||||
async fn add_number_to_filename(
|
||||
file_name: &str,
|
||||
conn: &DbConn,
|
||||
uid: i64,
|
||||
parent_id: i64
|
||||
) -> String {
|
||||
let mut i = 0;
|
||||
let mut new_file_name = file_name.to_string();
|
||||
let arr: Vec<&str> = file_name.split(".").collect();
|
||||
let suffix = String::from(arr[arr.len() - 1]);
|
||||
let preffix = arr[..arr.len() - 1].join(".");
|
||||
let mut docs = Query::find_doc_infos_by_name(
|
||||
conn,
|
||||
uid,
|
||||
&new_file_name,
|
||||
Some(parent_id)
|
||||
).await.unwrap();
|
||||
while docs.len() > 0 {
|
||||
i += 1;
|
||||
new_file_name = format!("{}_{}.{}", preffix, i, suffix);
|
||||
docs = Query::find_doc_infos_by_name(
|
||||
conn,
|
||||
uid,
|
||||
&new_file_name,
|
||||
Some(parent_id)
|
||||
).await.unwrap();
|
||||
}
|
||||
new_file_name
|
||||
}
|
||||
let fnm = add_number_to_filename(file_name, &data.conn, uid, did).await;
|
||||
|
||||
let bucket_name = format!("{}-upload", uid);
|
||||
let s3_client: &minio::s3::client::Client = &data.s3_client;
|
||||
let buckets_exists = s3_client
|
||||
.bucket_exists(&BucketExistsArgs::new(&bucket_name).unwrap()).await
|
||||
.unwrap();
|
||||
if !buckets_exists {
|
||||
print!("Create bucket: {}", bucket_name.clone());
|
||||
s3_client.make_bucket(&MakeBucketArgs::new(&bucket_name).unwrap()).await.unwrap();
|
||||
} else {
|
||||
print!("Existing bucket: {}", bucket_name.clone());
|
||||
}
|
||||
|
||||
let location = format!("/{}/{}", did, fnm)
|
||||
.as_bytes()
|
||||
.to_vec()
|
||||
.iter()
|
||||
.map(|b| format!("{:02x}", b).to_string())
|
||||
.collect::<Vec<String>>()
|
||||
.join("");
|
||||
print!("===>{}", location.clone());
|
||||
s3_client.put_object(
|
||||
&mut PutObjectArgs::new(
|
||||
&bucket_name,
|
||||
&location,
|
||||
&mut BufReader::new(bytes),
|
||||
Some(bytes.len()),
|
||||
None
|
||||
)?
|
||||
).await?;
|
||||
|
||||
let doc = Mutation::create_doc_info(&data.conn, Model {
|
||||
did: Default::default(),
|
||||
uid: uid,
|
||||
doc_name: fnm.clone(),
|
||||
size: bytes.len() as i64,
|
||||
location,
|
||||
r#type: file_type(&fnm),
|
||||
thumbnail_base64: Default::default(),
|
||||
created_at: now(),
|
||||
updated_at: now(),
|
||||
is_deleted: Default::default(),
|
||||
}).await?;
|
||||
|
||||
let _ = Mutation::place_doc(&data.conn, did, doc.did.unwrap()).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct RmDocsParam {
|
||||
uid: i64,
|
||||
dids: Vec<i64>,
|
||||
}
|
||||
#[post("/v1.0/delete_docs")]
|
||||
async fn delete(
|
||||
params: web::Json<RmDocsParam>,
|
||||
data: web::Data<AppState>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
let _ = Mutation::delete_doc_info(&data.conn, ¶ms.dids).await?;
|
||||
|
||||
let json_response = JsonResponse {
|
||||
code: 200,
|
||||
err: "".to_owned(),
|
||||
data: (),
|
||||
};
|
||||
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct MvParams {
|
||||
pub uid: i64,
|
||||
pub dids: Vec<i64>,
|
||||
pub dest_did: i64,
|
||||
}
|
||||
|
||||
#[post("/v1.0/mv_docs")]
|
||||
async fn mv(
|
||||
params: web::Json<MvParams>,
|
||||
data: web::Data<AppState>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
Mutation::mv_doc_info(&data.conn, params.dest_did, ¶ms.dids).await?;
|
||||
|
||||
let json_response = JsonResponse {
|
||||
code: 200,
|
||||
err: "".to_owned(),
|
||||
data: (),
|
||||
};
|
||||
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct NewFoldParams {
|
||||
pub uid: i64,
|
||||
pub parent_id: i64,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[post("/v1.0/new_folder")]
|
||||
async fn new_folder(
|
||||
params: web::Json<NewFoldParams>,
|
||||
data: web::Data<AppState>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
let doc = Mutation::create_doc_info(&data.conn, Model {
|
||||
did: Default::default(),
|
||||
uid: params.uid,
|
||||
doc_name: params.name.to_string(),
|
||||
size: 0,
|
||||
r#type: "folder".to_string(),
|
||||
location: "".to_owned(),
|
||||
thumbnail_base64: Default::default(),
|
||||
created_at: now(),
|
||||
updated_at: now(),
|
||||
is_deleted: Default::default(),
|
||||
}).await?;
|
||||
let _ = Mutation::place_doc(&data.conn, params.parent_id, doc.did.unwrap()).await?;
|
||||
|
||||
Ok(HttpResponse::Ok().body("Folder created successfully"))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct RenameParams {
|
||||
pub uid: i64,
|
||||
pub did: i64,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[post("/v1.0/rename")]
|
||||
async fn rename(
|
||||
params: web::Json<RenameParams>,
|
||||
data: web::Data<AppState>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
let docs = Query::find_doc_infos_by_name(&data.conn, params.uid, ¶ms.name, None).await?;
|
||||
if docs.len() > 0 {
|
||||
let json_response = JsonResponse {
|
||||
code: 500,
|
||||
err: "Name duplicated!".to_owned(),
|
||||
data: (),
|
||||
};
|
||||
return Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
);
|
||||
}
|
||||
let doc = Mutation::rename(&data.conn, params.did, ¶ms.name).await?;
|
||||
|
||||
let json_response = JsonResponse {
|
||||
code: 200,
|
||||
err: "".to_owned(),
|
||||
data: doc,
|
||||
};
|
||||
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
)
|
||||
}
|
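Note on the storage layout above: the MinIO object key is the string "/{did}/{fnm}" hex-encoded byte by byte, so keys contain only the characters [0-9a-f]. A minimal equivalent sketch (hypothetical helper, not part of the deleted sources):

fn object_key(did: i64, fnm: &str) -> String {
    // Hex-encode every byte of "/{did}/{fnm}", mirroring the map/collect chain above.
    format!("/{}/{}", did, fnm)
        .as_bytes()
        .iter()
        .map(|b| format!("{:02x}", b))
        .collect::<String>()
}

// e.g. object_key(7, "a.pdf") == "2f372f612e706466"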
@@ -1,166 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use actix_web::{ get, HttpResponse, post, web };
|
||||
use serde::Serialize;
|
||||
use crate::api::JsonResponse;
|
||||
use crate::AppState;
|
||||
use crate::entity::kb_info;
|
||||
use crate::errors::AppError;
|
||||
use crate::service::kb_info::Mutation;
|
||||
use crate::service::kb_info::Query;
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct AddDocs2KbParams {
|
||||
pub uid: i64,
|
||||
pub dids: Vec<i64>,
|
||||
pub kb_id: i64,
|
||||
}
|
||||
#[post("/v1.0/create_kb")]
|
||||
async fn create(
|
||||
model: web::Json<kb_info::Model>,
|
||||
data: web::Data<AppState>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
let mut docs = Query::find_kb_infos_by_name(
|
||||
&data.conn,
|
||||
model.kb_name.to_owned()
|
||||
).await.unwrap();
|
||||
if docs.len() > 0 {
|
||||
let json_response = JsonResponse {
|
||||
code: 201,
|
||||
err: "Duplicated name.".to_owned(),
|
||||
data: (),
|
||||
};
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
)
|
||||
} else {
|
||||
let model = Mutation::create_kb_info(&data.conn, model.into_inner()).await?;
|
||||
|
||||
let mut result = HashMap::new();
|
||||
result.insert("kb_id", model.kb_id.unwrap());
|
||||
|
||||
let json_response = JsonResponse {
|
||||
code: 200,
|
||||
err: "".to_owned(),
|
||||
data: result,
|
||||
};
|
||||
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/v1.0/add_docs_to_kb")]
|
||||
async fn add_docs_to_kb(
|
||||
param: web::Json<AddDocs2KbParams>,
|
||||
data: web::Data<AppState>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
let _ = Mutation::add_docs(&data.conn, param.kb_id, param.dids.to_owned()).await?;
|
||||
|
||||
let json_response = JsonResponse {
|
||||
code: 200,
|
||||
err: "".to_owned(),
|
||||
data: (),
|
||||
};
|
||||
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
)
|
||||
}
|
||||
|
||||
#[post("/v1.0/anti_kb_docs")]
|
||||
async fn anti_kb_docs(
|
||||
param: web::Json<AddDocs2KbParams>,
|
||||
data: web::Data<AppState>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
let _ = Mutation::remove_docs(&data.conn, param.dids.to_owned(), Some(param.kb_id)).await?;
|
||||
|
||||
let json_response = JsonResponse {
|
||||
code: 200,
|
||||
err: "".to_owned(),
|
||||
data: (),
|
||||
};
|
||||
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
)
|
||||
}
|
||||
#[get("/v1.0/kbs")]
|
||||
async fn list(
|
||||
model: web::Json<kb_info::Model>,
|
||||
data: web::Data<AppState>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
let kbs = Query::find_kb_infos_by_uid(&data.conn, model.uid).await?;
|
||||
|
||||
let mut result = HashMap::new();
|
||||
result.insert("kbs", kbs);
|
||||
|
||||
let json_response = JsonResponse {
|
||||
code: 200,
|
||||
err: "".to_owned(),
|
||||
data: result,
|
||||
};
|
||||
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
)
|
||||
}
|
||||
|
||||
#[post("/v1.0/delete_kb")]
|
||||
async fn delete(
|
||||
model: web::Json<kb_info::Model>,
|
||||
data: web::Data<AppState>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
let _ = Mutation::delete_kb_info(&data.conn, model.kb_id).await?;
|
||||
|
||||
let json_response = JsonResponse {
|
||||
code: 200,
|
||||
err: "".to_owned(),
|
||||
data: (),
|
||||
};
|
||||
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct DocIdsParams {
|
||||
pub uid: i64,
|
||||
pub dids: Vec<i64>,
|
||||
}
|
||||
|
||||
#[post("/v1.0/all_relevents")]
|
||||
async fn all_relevents(
|
||||
params: web::Json<DocIdsParams>,
|
||||
data: web::Data<AppState>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
let dids = crate::service::doc_info::Query::all_descendent_ids(&data.conn, ¶ms.dids).await?;
|
||||
let mut result = HashMap::new();
|
||||
let kbs = Query::find_kb_by_docs(&data.conn, dids).await?;
|
||||
result.insert("kbs", kbs);
|
||||
let json_response = JsonResponse {
|
||||
code: 200,
|
||||
err: "".to_owned(),
|
||||
data: result,
|
||||
};
|
||||
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
)
|
||||
}
|
@@ -1,14 +0,0 @@
use serde::{ Deserialize, Serialize };

pub(crate) mod tag_info;
pub(crate) mod kb_info;
pub(crate) mod dialog_info;
pub(crate) mod doc_info;
pub(crate) mod user_info;

#[derive(Debug, Deserialize, Serialize)]
struct JsonResponse<T> {
    code: u32,
    err: String,
    data: T,
}
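Every handler wraps its payload in this envelope. A small sketch with illustrative values (assuming serde_json is in scope inside the api module):

let resp = JsonResponse { code: 200, err: "".to_owned(), data: 42 };
// serde serializes struct fields in declaration order.
assert_eq!(serde_json::to_string(&resp).unwrap(), r#"{"code":200,"err":"","data":42}"#);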
@@ -1,81 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use actix_web::{ get, HttpResponse, post, web };
|
||||
use serde::Deserialize;
|
||||
use crate::api::JsonResponse;
|
||||
use crate::AppState;
|
||||
use crate::entity::tag_info;
|
||||
use crate::errors::AppError;
|
||||
use crate::service::tag_info::{ Mutation, Query };
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct TagListParams {
|
||||
pub uid: i64,
|
||||
}
|
||||
|
||||
#[post("/v1.0/create_tag")]
|
||||
async fn create(
|
||||
model: web::Json<tag_info::Model>,
|
||||
data: web::Data<AppState>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
let model = Mutation::create_tag(&data.conn, model.into_inner()).await?;
|
||||
|
||||
let mut result = HashMap::new();
|
||||
result.insert("tid", model.tid.unwrap());
|
||||
|
||||
let json_response = JsonResponse {
|
||||
code: 200,
|
||||
err: "".to_owned(),
|
||||
data: result,
|
||||
};
|
||||
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
)
|
||||
}
|
||||
|
||||
#[post("/v1.0/delete_tag")]
|
||||
async fn delete(
|
||||
model: web::Json<tag_info::Model>,
|
||||
data: web::Data<AppState>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
let _ = Mutation::delete_tag(&data.conn, model.tid).await?;
|
||||
|
||||
let json_response = JsonResponse {
|
||||
code: 200,
|
||||
err: "".to_owned(),
|
||||
data: (),
|
||||
};
|
||||
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
)
|
||||
}
|
||||
|
||||
//#[get("/v1.0/tags", wrap = "HttpAuthentication::bearer(validator)")]
|
||||
|
||||
#[post("/v1.0/tags")]
|
||||
async fn list(
|
||||
param: web::Json<TagListParams>,
|
||||
data: web::Data<AppState>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
let tags = Query::find_tags_by_uid(param.uid, &data.conn).await?;
|
||||
|
||||
let mut result = HashMap::new();
|
||||
result.insert("tags", tags);
|
||||
|
||||
let json_response = JsonResponse {
|
||||
code: 200,
|
||||
err: "".to_owned(),
|
||||
data: result,
|
||||
};
|
||||
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
)
|
||||
}
|
@@ -1,149 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::io::SeekFrom;
|
||||
use std::ptr::null;
|
||||
use actix_identity::Identity;
|
||||
use actix_web::{ HttpResponse, post, web };
|
||||
use chrono::{ FixedOffset, Utc };
|
||||
use sea_orm::ActiveValue::NotSet;
|
||||
use serde::{ Deserialize, Serialize };
|
||||
use crate::api::JsonResponse;
|
||||
use crate::AppState;
|
||||
use crate::entity::{ doc_info, tag_info };
|
||||
use crate::entity::user_info::Model;
|
||||
use crate::errors::{ AppError, UserError };
|
||||
use crate::service::user_info::Mutation;
|
||||
use crate::service::user_info::Query;
|
||||
|
||||
fn now() -> chrono::DateTime<FixedOffset> {
|
||||
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
|
||||
}
|
||||
|
||||
pub(crate) fn create_auth_token(user: &Model) -> u64 {
|
||||
use std::{ collections::hash_map::DefaultHasher, hash::{ Hash, Hasher } };
|
||||
|
||||
let mut hasher = DefaultHasher::new();
|
||||
user.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub(crate) struct LoginParams {
|
||||
pub(crate) email: String,
|
||||
pub(crate) password: String,
|
||||
}
|
||||
|
||||
#[post("/v1.0/login")]
|
||||
async fn login(
|
||||
data: web::Data<AppState>,
|
||||
identity: Identity,
|
||||
input: web::Json<LoginParams>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
match Query::login(&data.conn, &input.email, &input.password).await? {
|
||||
Some(user) => {
|
||||
let _ = Mutation::update_login_status(user.uid, &data.conn).await?;
|
||||
let token = create_auth_token(&user).to_string();
|
||||
|
||||
identity.remember(token.clone());
|
||||
|
||||
let json_response = JsonResponse {
|
||||
code: 200,
|
||||
err: "".to_owned(),
|
||||
data: token.clone(),
|
||||
};
|
||||
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.append_header(("X-Auth-Token", token))
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
)
|
||||
}
|
||||
None => Err(UserError::LoginFailed.into()),
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/v1.0/register")]
|
||||
async fn register(
|
||||
model: web::Json<Model>,
|
||||
data: web::Data<AppState>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
let mut result = HashMap::new();
|
||||
let u = Query::find_user_infos(&data.conn, &model.email).await?;
|
||||
if let Some(_) = u {
|
||||
let json_response = JsonResponse {
|
||||
code: 500,
|
||||
err: "Email registered!".to_owned(),
|
||||
data: (),
|
||||
};
|
||||
return Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
);
|
||||
}
|
||||
|
||||
let usr = Mutation::create_user(&data.conn, &model).await?;
|
||||
result.insert("uid", usr.uid.clone().unwrap());
|
||||
crate::service::doc_info::Mutation::create_doc_info(&data.conn, doc_info::Model {
|
||||
did: Default::default(),
|
||||
uid: usr.uid.clone().unwrap(),
|
||||
doc_name: "/".into(),
|
||||
size: 0,
|
||||
location: "".into(),
|
||||
thumbnail_base64: "".into(),
|
||||
r#type: "folder".to_string(),
|
||||
created_at: now(),
|
||||
updated_at: now(),
|
||||
is_deleted: Default::default(),
|
||||
}).await?;
|
||||
let tnm = vec!["Video", "Picture", "Music", "Document"];
|
||||
let tregx = vec![
|
||||
".*\\.(mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa)",
|
||||
".*\\.(png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng)",
|
||||
".*\\.(WAV|FLAC|APE|ALAC|WavPack|WV|MP3|AAC|Ogg|Vorbis|Opus)",
|
||||
".*\\.(pdf|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp)"
|
||||
];
|
||||
for i in 0..4 {
|
||||
crate::service::tag_info::Mutation::create_tag(&data.conn, tag_info::Model {
|
||||
tid: Default::default(),
|
||||
uid: usr.uid.clone().unwrap(),
|
||||
tag_name: tnm[i].to_owned(),
|
||||
regx: tregx[i].to_owned(),
|
||||
color: (i + 1).to_owned() as i16,
|
||||
icon: (i + 1).to_owned() as i16,
|
||||
folder_id: 0,
|
||||
created_at: Default::default(),
|
||||
updated_at: Default::default(),
|
||||
}).await?;
|
||||
}
|
||||
let json_response = JsonResponse {
|
||||
code: 200,
|
||||
err: "".to_owned(),
|
||||
data: result,
|
||||
};
|
||||
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
)
|
||||
}
|
||||
|
||||
#[post("/v1.0/setting")]
|
||||
async fn setting(
|
||||
model: web::Json<Model>,
|
||||
data: web::Data<AppState>
|
||||
) -> Result<HttpResponse, AppError> {
|
||||
let _ = Mutation::update_user_by_id(&data.conn, &model).await?;
|
||||
let json_response = JsonResponse {
|
||||
code: 200,
|
||||
err: "".to_owned(),
|
||||
data: (),
|
||||
};
|
||||
|
||||
Ok(
|
||||
HttpResponse::Ok()
|
||||
.content_type("application/json")
|
||||
.body(serde_json::to_string(&json_response)?)
|
||||
)
|
||||
}
|
@@ -1,38 +0,0 @@
use sea_orm::entity::prelude::*;
use serde::{ Deserialize, Serialize };

#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)]
#[sea_orm(table_name = "dialog2_kb")]
pub struct Model {
    #[sea_orm(primary_key, auto_increment = true)]
    pub id: i64,
    #[sea_orm(index)]
    pub dialog_id: i64,
    #[sea_orm(index)]
    pub kb_id: i64,
}

#[derive(Debug, Clone, Copy, EnumIter)]
pub enum Relation {
    DialogInfo,
    KbInfo,
}

impl RelationTrait for Relation {
    fn def(&self) -> RelationDef {
        match self {
            Self::DialogInfo =>
                Entity::belongs_to(super::dialog_info::Entity)
                    .from(Column::DialogId)
                    .to(super::dialog_info::Column::DialogId)
                    .into(),
            Self::KbInfo =>
                Entity::belongs_to(super::kb_info::Entity)
                    .from(Column::KbId)
                    .to(super::kb_info::Column::KbId)
                    .into(),
        }
    }
}

impl ActiveModelBehavior for ActiveModel {}
@@ -1,38 +0,0 @@
use chrono::{ DateTime, FixedOffset };
use sea_orm::entity::prelude::*;
use serde::{ Deserialize, Serialize };

#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)]
#[sea_orm(table_name = "dialog_info")]
pub struct Model {
    #[sea_orm(primary_key, auto_increment = false)]
    pub dialog_id: i64,
    #[sea_orm(index)]
    pub uid: i64,
    #[serde(skip_deserializing)]
    pub kb_id: i64,
    pub dialog_name: String,
    pub history: String,

    #[serde(skip_deserializing)]
    pub created_at: DateTime<FixedOffset>,
    #[serde(skip_deserializing)]
    pub updated_at: DateTime<FixedOffset>,
    #[serde(skip_deserializing)]
    pub is_deleted: bool,
}

#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}

impl Related<super::kb_info::Entity> for Entity {
    fn to() -> RelationDef {
        super::dialog2_kb::Relation::KbInfo.def()
    }

    fn via() -> Option<RelationDef> {
        Some(super::dialog2_kb::Relation::DialogInfo.def().rev())
    }
}

impl ActiveModelBehavior for ActiveModel {}
@@ -1,38 +0,0 @@
use sea_orm::entity::prelude::*;
use serde::{ Deserialize, Serialize };

#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)]
#[sea_orm(table_name = "doc2_doc")]
pub struct Model {
    #[sea_orm(primary_key, auto_increment = true)]
    pub id: i64,
    #[sea_orm(index)]
    pub parent_id: i64,
    #[sea_orm(index)]
    pub did: i64,
}

#[derive(Debug, Clone, Copy, EnumIter)]
pub enum Relation {
    Parent,
    Child,
}

impl RelationTrait for Relation {
    fn def(&self) -> RelationDef {
        match self {
            Self::Parent =>
                Entity::belongs_to(super::doc_info::Entity)
                    .from(Column::ParentId)
                    .to(super::doc_info::Column::Did)
                    .into(),
            Self::Child =>
                Entity::belongs_to(super::doc_info::Entity)
                    .from(Column::Did)
                    .to(super::doc_info::Column::Did)
                    .into(),
        }
    }
}

impl ActiveModelBehavior for ActiveModel {}
@@ -1,62 +0,0 @@
use sea_orm::entity::prelude::*;
use serde::{ Deserialize, Serialize };
use crate::entity::kb_info;
use chrono::{ DateTime, FixedOffset };

#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Deserialize, Serialize)]
#[sea_orm(table_name = "doc_info")]
pub struct Model {
    #[sea_orm(primary_key, auto_increment = false)]
    pub did: i64,
    #[sea_orm(index)]
    pub uid: i64,
    pub doc_name: String,
    pub size: i64,
    #[sea_orm(column_name = "type")]
    pub r#type: String,
    #[serde(skip_deserializing)]
    pub location: String,
    #[serde(skip_deserializing)]
    pub thumbnail_base64: String,
    #[serde(skip_deserializing)]
    pub created_at: DateTime<FixedOffset>,
    #[serde(skip_deserializing)]
    pub updated_at: DateTime<FixedOffset>,
    #[serde(skip_deserializing)]
    pub is_deleted: bool,
}

#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}

impl Related<super::tag_info::Entity> for Entity {
    fn to() -> RelationDef {
        super::tag2_doc::Relation::Tag.def()
    }

    fn via() -> Option<RelationDef> {
        Some(super::tag2_doc::Relation::DocInfo.def().rev())
    }
}

impl Related<super::kb_info::Entity> for Entity {
    fn to() -> RelationDef {
        super::kb2_doc::Relation::KbInfo.def()
    }

    fn via() -> Option<RelationDef> {
        Some(super::kb2_doc::Relation::DocInfo.def().rev())
    }
}

impl Related<super::doc2_doc::Entity> for Entity {
    fn to() -> RelationDef {
        super::doc2_doc::Relation::Parent.def()
    }

    fn via() -> Option<RelationDef> {
        Some(super::doc2_doc::Relation::Child.def().rev())
    }
}

impl ActiveModelBehavior for ActiveModel {}
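The Related impls above give doc_info many-to-many access to tags, knowledge bases, and child documents through the join tables. A minimal consumption sketch (hypothetical function, not part of the deleted sources):

use sea_orm::{DbConn, DbErr, EntityTrait};
use crate::entity::{doc_info, tag_info};

// For every document, load the tags linked through the tag2_doc join table.
async fn docs_with_tags(db: &DbConn) -> Result<Vec<(doc_info::Model, Vec<tag_info::Model>)>, DbErr> {
    doc_info::Entity::find()
        .find_with_related(tag_info::Entity)
        .all(db)
        .await
}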
@@ -1,47 +0,0 @@
use sea_orm::entity::prelude::*;
use serde::{ Deserialize, Serialize };
use chrono::{ DateTime, FixedOffset };

#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Deserialize, Serialize)]
#[sea_orm(table_name = "kb2_doc")]
pub struct Model {
    #[sea_orm(primary_key, auto_increment = true)]
    pub id: i64,
    #[sea_orm(index)]
    pub kb_id: i64,
    #[sea_orm(index)]
    pub did: i64,
    #[serde(skip_deserializing)]
    pub kb_progress: f32,
    #[serde(skip_deserializing)]
    pub kb_progress_msg: String,
    #[serde(skip_deserializing)]
    pub updated_at: DateTime<FixedOffset>,
    #[serde(skip_deserializing)]
    pub is_deleted: bool,
}

#[derive(Debug, Clone, Copy, EnumIter)]
pub enum Relation {
    DocInfo,
    KbInfo,
}

impl RelationTrait for Relation {
    fn def(&self) -> RelationDef {
        match self {
            Self::DocInfo =>
                Entity::belongs_to(super::doc_info::Entity)
                    .from(Column::Did)
                    .to(super::doc_info::Column::Did)
                    .into(),
            Self::KbInfo =>
                Entity::belongs_to(super::kb_info::Entity)
                    .from(Column::KbId)
                    .to(super::kb_info::Column::KbId)
                    .into(),
        }
    }
}

impl ActiveModelBehavior for ActiveModel {}
@@ -1,47 +0,0 @@
use sea_orm::entity::prelude::*;
use serde::{ Deserialize, Serialize };
use chrono::{ DateTime, FixedOffset };

#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)]
#[sea_orm(table_name = "kb_info")]
pub struct Model {
    #[sea_orm(primary_key, auto_increment = false)]
    #[serde(skip_deserializing)]
    pub kb_id: i64,
    #[sea_orm(index)]
    pub uid: i64,
    pub kb_name: String,
    pub icon: i16,

    #[serde(skip_deserializing)]
    pub created_at: DateTime<FixedOffset>,
    #[serde(skip_deserializing)]
    pub updated_at: DateTime<FixedOffset>,
    #[serde(skip_deserializing)]
    pub is_deleted: bool,
}

#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}

impl Related<super::doc_info::Entity> for Entity {
    fn to() -> RelationDef {
        super::kb2_doc::Relation::DocInfo.def()
    }

    fn via() -> Option<RelationDef> {
        Some(super::kb2_doc::Relation::KbInfo.def().rev())
    }
}

impl Related<super::dialog_info::Entity> for Entity {
    fn to() -> RelationDef {
        super::dialog2_kb::Relation::DialogInfo.def()
    }

    fn via() -> Option<RelationDef> {
        Some(super::dialog2_kb::Relation::KbInfo.def().rev())
    }
}

impl ActiveModelBehavior for ActiveModel {}
@@ -1,9 +0,0 @@
pub(crate) mod user_info;
pub(crate) mod tag_info;
pub(crate) mod tag2_doc;
pub(crate) mod kb2_doc;
pub(crate) mod dialog2_kb;
pub(crate) mod doc2_doc;
pub(crate) mod kb_info;
pub(crate) mod doc_info;
pub(crate) mod dialog_info;
@@ -1,38 +0,0 @@
use sea_orm::entity::prelude::*;
use serde::{ Deserialize, Serialize };

#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)]
#[sea_orm(table_name = "tag2_doc")]
pub struct Model {
    #[sea_orm(primary_key, auto_increment = true)]
    pub id: i64,
    #[sea_orm(index)]
    pub tag_id: i64,
    #[sea_orm(index)]
    pub did: i64,
}

#[derive(Debug, Clone, Copy, EnumIter)]
pub enum Relation {
    Tag,
    DocInfo,
}

impl RelationTrait for Relation {
    fn def(&self) -> sea_orm::RelationDef {
        match self {
            Self::Tag =>
                Entity::belongs_to(super::tag_info::Entity)
                    .from(Column::TagId)
                    .to(super::tag_info::Column::Tid)
                    .into(),
            Self::DocInfo =>
                Entity::belongs_to(super::doc_info::Entity)
                    .from(Column::Did)
                    .to(super::doc_info::Column::Did)
                    .into(),
        }
    }
}

impl ActiveModelBehavior for ActiveModel {}
@@ -1,40 +0,0 @@
use sea_orm::entity::prelude::*;
use serde::{ Deserialize, Serialize };
use chrono::{ DateTime, FixedOffset };

#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Deserialize, Serialize)]
#[sea_orm(table_name = "tag_info")]
pub struct Model {
    #[sea_orm(primary_key)]
    #[serde(skip_deserializing)]
    pub tid: i64,
    #[sea_orm(index)]
    pub uid: i64,
    pub tag_name: String,
    #[serde(skip_deserializing)]
    pub regx: String,
    pub color: i16,
    pub icon: i16,
    #[serde(skip_deserializing)]
    pub folder_id: i64,

    #[serde(skip_deserializing)]
    pub created_at: DateTime<FixedOffset>,
    #[serde(skip_deserializing)]
    pub updated_at: DateTime<FixedOffset>,
}

#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}

impl Related<super::doc_info::Entity> for Entity {
    fn to() -> RelationDef {
        super::tag2_doc::Relation::DocInfo.def()
    }

    fn via() -> Option<RelationDef> {
        Some(super::tag2_doc::Relation::Tag.def().rev())
    }
}

impl ActiveModelBehavior for ActiveModel {}
@@ -1,30 +0,0 @@
use sea_orm::entity::prelude::*;
use serde::{ Deserialize, Serialize };
use chrono::{ DateTime, FixedOffset };

#[derive(Clone, Debug, PartialEq, Eq, Hash, DeriveEntityModel, Deserialize, Serialize)]
#[sea_orm(table_name = "user_info")]
pub struct Model {
    #[sea_orm(primary_key)]
    #[serde(skip_deserializing)]
    pub uid: i64,
    pub email: String,
    pub nickname: String,
    pub avatar_base64: String,
    pub color_scheme: String,
    pub list_style: String,
    pub language: String,
    pub password: String,

    #[serde(skip_deserializing)]
    pub last_login_at: DateTime<FixedOffset>,
    #[serde(skip_deserializing)]
    pub created_at: DateTime<FixedOffset>,
    #[serde(skip_deserializing)]
    pub updated_at: DateTime<FixedOffset>,
}

#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}

impl ActiveModelBehavior for ActiveModel {}
@@ -1,83 +0,0 @@
use actix_web::{HttpResponse, ResponseError};
use thiserror::Error;

#[derive(Debug, Error)]
pub(crate) enum AppError {
    #[error("`{0}`")]
    User(#[from] UserError),

    #[error("`{0}`")]
    Json(#[from] serde_json::Error),

    #[error("`{0}`")]
    Actix(#[from] actix_web::Error),

    #[error("`{0}`")]
    Db(#[from] sea_orm::DbErr),

    #[error("`{0}`")]
    MinioS3(#[from] minio::s3::error::Error),

    #[error("`{0}`")]
    Std(#[from] std::io::Error),
}

#[derive(Debug, Error)]
pub(crate) enum UserError {
    #[error("`username` field of `User` cannot be empty!")]
    EmptyUsername,

    #[error("`username` field of `User` cannot contain whitespaces!")]
    UsernameInvalidCharacter,

    #[error("`password` field of `User` cannot be empty!")]
    EmptyPassword,

    #[error("`password` field of `User` cannot contain whitespaces!")]
    PasswordInvalidCharacter,

    #[error("Could not find any `User` for id: `{0}`!")]
    NotFound(i64),

    #[error("Failed to login user!")]
    LoginFailed,

    #[error("User is not logged in!")]
    NotLoggedIn,

    #[error("Invalid authorization token!")]
    InvalidToken,

    #[error("Could not find any `User`!")]
    Empty,
}

impl ResponseError for AppError {
    fn status_code(&self) -> actix_web::http::StatusCode {
        match self {
            AppError::User(user_error) => match user_error {
                UserError::EmptyUsername => actix_web::http::StatusCode::UNPROCESSABLE_ENTITY,
                UserError::UsernameInvalidCharacter => {
                    actix_web::http::StatusCode::UNPROCESSABLE_ENTITY
                }
                UserError::EmptyPassword => actix_web::http::StatusCode::UNPROCESSABLE_ENTITY,
                UserError::PasswordInvalidCharacter => {
                    actix_web::http::StatusCode::UNPROCESSABLE_ENTITY
                }
                UserError::NotFound(_) => actix_web::http::StatusCode::NOT_FOUND,
                UserError::NotLoggedIn => actix_web::http::StatusCode::UNAUTHORIZED,
                UserError::Empty => actix_web::http::StatusCode::NOT_FOUND,
                UserError::LoginFailed => actix_web::http::StatusCode::NOT_FOUND,
                UserError::InvalidToken => actix_web::http::StatusCode::UNAUTHORIZED,
            },
            AppError::Actix(fail) => fail.as_response_error().status_code(),
            _ => actix_web::http::StatusCode::INTERNAL_SERVER_ERROR,
        }
    }

    fn error_response(&self) -> HttpResponse {
        let status_code = self.status_code();
        let response = HttpResponse::build(status_code).body(self.to_string());
        response
    }
}
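Because AppError implements ResponseError, handlers can propagate failures with `?` and let actix-web render the status codes defined above. An illustrative sketch (hypothetical route, assuming AppState and the user_info service are in scope):

use actix_web::{post, web, HttpResponse};

#[post("/v1.0/example")]
async fn example(data: web::Data<AppState>) -> Result<HttpResponse, AppError> {
    // A failing query yields sea_orm::DbErr, converted by #[from] into AppError::Db,
    // which the impl above maps to 500 INTERNAL_SERVER_ERROR.
    let user = crate::service::user_info::Query::find_user_infos(&data.conn, &"x@y.z".to_owned()).await?;
    Ok(HttpResponse::Ok().json(user.is_some()))
}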
src/main.rs
@@ -1,145 +0,0 @@
mod api;
mod entity;
mod service;
mod errors;
mod web_socket;

use std::env;
use actix_files::Files;
use actix_identity::{ CookieIdentityPolicy, IdentityService, RequestIdentity };
use actix_session::CookieSession;
use actix_web::{ web, App, HttpServer, middleware, Error };
use actix_web::cookie::time::Duration;
use actix_web::dev::ServiceRequest;
use actix_web::error::ErrorUnauthorized;
use actix_web_httpauth::extractors::bearer::BearerAuth;
use listenfd::ListenFd;
use minio::s3::client::Client;
use minio::s3::creds::StaticProvider;
use minio::s3::http::BaseUrl;
use sea_orm::{ Database, DatabaseConnection };
use migration::{ Migrator, MigratorTrait };
use crate::errors::{ AppError, UserError };
use crate::web_socket::doc_info::upload_file_ws;

#[derive(Debug, Clone)]
struct AppState {
    conn: DatabaseConnection,
    s3_client: Client,
}

pub(crate) async fn validator(
    req: ServiceRequest,
    credentials: BearerAuth
) -> Result<ServiceRequest, Error> {
    if let Some(token) = req.get_identity() {
        println!("{}, {}", credentials.token(), token);
        (credentials.token() == token)
            .then(|| req)
            .ok_or(ErrorUnauthorized(UserError::InvalidToken))
    } else {
        Err(ErrorUnauthorized(UserError::NotLoggedIn))
    }
}

#[actix_web::main]
async fn main() -> Result<(), AppError> {
    std::env::set_var("RUST_LOG", "debug");
    tracing_subscriber::fmt::init();

    // get env vars
    dotenvy::dotenv().ok();
    let db_url = env::var("DATABASE_URL").expect("DATABASE_URL is not set in .env file");
    let host = env::var("HOST").expect("HOST is not set in .env file");
    let port = env::var("PORT").expect("PORT is not set in .env file");
    let server_url = format!("{host}:{port}");

    let mut s3_base_url = env::var("MINIO_HOST").expect("MINIO_HOST is not set in .env file");
    let s3_access_key = env::var("MINIO_USR").expect("MINIO_USR is not set in .env file");
    let s3_secret_key = env::var("MINIO_PWD").expect("MINIO_PWD is not set in .env file");
    if s3_base_url.find("http") != Some(0) {
        s3_base_url = format!("http://{}", s3_base_url);
    }

    // establish connection to database and apply migrations
    // -> create the application tables if they do not exist yet
    let conn = Database::connect(&db_url).await.unwrap();
    Migrator::up(&conn, None).await.unwrap();

    let static_provider = StaticProvider::new(s3_access_key.as_str(), s3_secret_key.as_str(), None);

    let s3_client = Client::new(
        s3_base_url.parse::<BaseUrl>()?,
        Some(Box::new(static_provider)),
        None,
        Some(true)
    )?;

    let state = AppState { conn, s3_client };

    // create server and try to serve over socket if possible
    let mut listenfd = ListenFd::from_env();
    let mut server = HttpServer::new(move || {
        App::new()
            .service(Files::new("/static", "./static"))
            .app_data(web::Data::new(state.clone()))
            .wrap(
                IdentityService::new(
                    CookieIdentityPolicy::new(&[0; 32])
                        .name("auth-cookie")
                        .login_deadline(Duration::seconds(120))
                        .secure(false)
                )
            )
            .wrap(
                CookieSession::signed(&[0; 32])
                    .name("session-cookie")
                    .secure(false)
                    // WARNING(alex): This uses the `time` crate, not `std::time`!
                    .expires_in_time(Duration::seconds(60))
            )
            .wrap(middleware::Logger::default())
            .configure(init)
    });

    server = match listenfd.take_tcp_listener(0)? {
        Some(listener) => server.listen(listener)?,
        None => server.bind(&server_url)?,
    };

    println!("Starting server at {server_url}");
    server.run().await?;

    Ok(())
}

fn init(cfg: &mut web::ServiceConfig) {
    cfg.service(api::tag_info::create);
    cfg.service(api::tag_info::delete);
    cfg.service(api::tag_info::list);

    cfg.service(api::kb_info::create);
    cfg.service(api::kb_info::delete);
    cfg.service(api::kb_info::list);
    cfg.service(api::kb_info::add_docs_to_kb);
    cfg.service(api::kb_info::anti_kb_docs);
    cfg.service(api::kb_info::all_relevents);

    cfg.service(api::doc_info::list);
    cfg.service(api::doc_info::delete);
    cfg.service(api::doc_info::mv);
    cfg.service(api::doc_info::upload);
    cfg.service(api::doc_info::new_folder);
    cfg.service(api::doc_info::rename);

    cfg.service(api::dialog_info::list);
    cfg.service(api::dialog_info::delete);
    cfg.service(api::dialog_info::create);
    cfg.service(api::dialog_info::update_history);

    cfg.service(api::user_info::login);
    cfg.service(api::user_info::register);
    cfg.service(api::user_info::setting);

    cfg.service(web::resource("/ws-upload-doc").route(web::get().to(upload_file_ws)));
}
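The bearer-token validator defined in main.rs is not wired to any route in init; the tag API keeps only a commented-out wrap = "HttpAuthentication::bearer(validator)" attribute. A sketch of how it could be attached (an assumption, not the project's configuration; the exact validator signature depends on the actix-web-httpauth version in use):

use actix_web_httpauth::middleware::HttpAuthentication;

// Guard a scope so every request must present the token issued at login.
fn init_protected(cfg: &mut web::ServiceConfig) {
    let auth = HttpAuthentication::bearer(validator);
    cfg.service(
        web::scope("/v1.0/protected")
            .wrap(auth)
            .route("/ping", web::get().to(|| async { "pong" })),
    );
}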
@@ -1,107 +0,0 @@
|
||||
use chrono::{ Local, FixedOffset, Utc };
|
||||
use migration::Expr;
|
||||
use sea_orm::{
|
||||
ActiveModelTrait,
|
||||
DbConn,
|
||||
DbErr,
|
||||
DeleteResult,
|
||||
EntityTrait,
|
||||
PaginatorTrait,
|
||||
QueryOrder,
|
||||
UpdateResult,
|
||||
};
|
||||
use sea_orm::ActiveValue::Set;
|
||||
use sea_orm::QueryFilter;
|
||||
use sea_orm::ColumnTrait;
|
||||
use crate::entity::dialog_info;
|
||||
use crate::entity::dialog_info::Entity;
|
||||
|
||||
fn now() -> chrono::DateTime<FixedOffset> {
|
||||
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
|
||||
}
|
||||
pub struct Query;
|
||||
|
||||
impl Query {
|
||||
pub async fn find_dialog_info_by_id(
|
||||
db: &DbConn,
|
||||
id: i64
|
||||
) -> Result<Option<dialog_info::Model>, DbErr> {
|
||||
Entity::find_by_id(id).one(db).await
|
||||
}
|
||||
|
||||
pub async fn find_dialog_infos(db: &DbConn) -> Result<Vec<dialog_info::Model>, DbErr> {
|
||||
Entity::find().all(db).await
|
||||
}
|
||||
|
||||
pub async fn find_dialog_infos_by_uid(
|
||||
db: &DbConn,
|
||||
uid: i64
|
||||
) -> Result<Vec<dialog_info::Model>, DbErr> {
|
||||
Entity::find()
|
||||
.filter(dialog_info::Column::Uid.eq(uid))
|
||||
.filter(dialog_info::Column::IsDeleted.eq(false))
|
||||
.all(db).await
|
||||
}
|
||||
|
||||
pub async fn find_dialog_infos_in_page(
|
||||
db: &DbConn,
|
||||
page: u64,
|
||||
posts_per_page: u64
|
||||
) -> Result<(Vec<dialog_info::Model>, u64), DbErr> {
|
||||
// Setup paginator
|
||||
let paginator = Entity::find()
|
||||
.order_by_asc(dialog_info::Column::DialogId)
|
||||
.paginate(db, posts_per_page);
|
||||
let num_pages = paginator.num_pages().await?;
|
||||
|
||||
// Fetch paginated posts
|
||||
paginator.fetch_page(page - 1).await.map(|p| (p, num_pages))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Mutation;
|
||||
|
||||
impl Mutation {
|
||||
pub async fn create_dialog_info(
|
||||
db: &DbConn,
|
||||
uid: i64,
|
||||
kb_id: i64,
|
||||
name: &String
|
||||
) -> Result<dialog_info::ActiveModel, DbErr> {
|
||||
(dialog_info::ActiveModel {
|
||||
dialog_id: Default::default(),
|
||||
uid: Set(uid),
|
||||
kb_id: Set(kb_id),
|
||||
dialog_name: Set(name.to_owned()),
|
||||
history: Set("".to_owned()),
|
||||
created_at: Set(now()),
|
||||
updated_at: Set(now()),
|
||||
is_deleted: Default::default(),
|
||||
}).save(db).await
|
||||
}
|
||||
|
||||
pub async fn update_dialog_info_by_id(
|
||||
db: &DbConn,
|
||||
dialog_id: i64,
|
||||
dialog_name: &String,
|
||||
history: &String
|
||||
) -> Result<UpdateResult, DbErr> {
|
||||
Entity::update_many()
|
||||
.col_expr(dialog_info::Column::DialogName, Expr::value(dialog_name))
|
||||
.col_expr(dialog_info::Column::History, Expr::value(history))
|
||||
.col_expr(dialog_info::Column::UpdatedAt, Expr::value(now()))
|
||||
.filter(dialog_info::Column::DialogId.eq(dialog_id))
|
||||
.exec(db).await
|
||||
}
|
||||
|
||||
pub async fn delete_dialog_info(db: &DbConn, dialog_id: i64) -> Result<UpdateResult, DbErr> {
|
||||
Entity::update_many()
|
||||
.col_expr(dialog_info::Column::IsDeleted, Expr::value(true))
|
||||
.filter(dialog_info::Column::DialogId.eq(dialog_id))
|
||||
.exec(db).await
|
||||
}
|
||||
|
||||
pub async fn delete_all_dialog_infos(db: &DbConn) -> Result<DeleteResult, DbErr> {
|
||||
Entity::delete_many().exec(db).await
|
||||
}
|
||||
}
|
@@ -1,335 +0,0 @@
|
||||
use chrono::{ Utc, FixedOffset };
|
||||
use sea_orm::{
|
||||
ActiveModelTrait,
|
||||
ColumnTrait,
|
||||
DbConn,
|
||||
DbErr,
|
||||
DeleteResult,
|
||||
EntityTrait,
|
||||
PaginatorTrait,
|
||||
QueryOrder,
|
||||
Unset,
|
||||
Unchanged,
|
||||
ConditionalStatement,
|
||||
QuerySelect,
|
||||
JoinType,
|
||||
RelationTrait,
|
||||
DbBackend,
|
||||
Statement,
|
||||
UpdateResult,
|
||||
};
|
||||
use sea_orm::ActiveValue::Set;
|
||||
use sea_orm::QueryFilter;
|
||||
use crate::api::doc_info::ListParams;
|
||||
use crate::entity::{ doc2_doc, doc_info };
|
||||
use crate::entity::doc_info::Entity;
|
||||
use crate::service;
|
||||
|
||||
fn now() -> chrono::DateTime<FixedOffset> {
|
||||
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
|
||||
}
|
||||
|
||||
pub struct Query;
|
||||
|
||||
impl Query {
|
||||
pub async fn find_doc_info_by_id(
|
||||
db: &DbConn,
|
||||
id: i64
|
||||
) -> Result<Option<doc_info::Model>, DbErr> {
|
||||
Entity::find_by_id(id).one(db).await
|
||||
}
|
||||
|
||||
pub async fn find_doc_infos(db: &DbConn) -> Result<Vec<doc_info::Model>, DbErr> {
|
||||
Entity::find().all(db).await
|
||||
}
|
||||
|
||||
pub async fn find_doc_infos_by_uid(
|
||||
db: &DbConn,
|
||||
uid: i64
|
||||
) -> Result<Vec<doc_info::Model>, DbErr> {
|
||||
Entity::find().filter(doc_info::Column::Uid.eq(uid)).all(db).await
|
||||
}
|
||||
|
||||
pub async fn find_doc_infos_by_name(
|
||||
db: &DbConn,
|
||||
uid: i64,
|
||||
name: &String,
|
||||
parent_id: Option<i64>
|
||||
) -> Result<Vec<doc_info::Model>, DbErr> {
|
||||
let mut dids = Vec::<i64>::new();
|
||||
if let Some(pid) = parent_id {
|
||||
for d2d in doc2_doc::Entity
|
||||
::find()
|
||||
.filter(doc2_doc::Column::ParentId.eq(pid))
|
||||
.all(db).await? {
|
||||
dids.push(d2d.did);
|
||||
}
|
||||
} else {
|
||||
let doc = Entity::find()
|
||||
.filter(doc_info::Column::DocName.eq(name.clone()))
|
||||
.filter(doc_info::Column::Uid.eq(uid))
|
||||
.all(db).await?;
|
||||
if doc.len() == 0 {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
assert!(doc.len() > 0);
|
||||
let d2d = doc2_doc::Entity
|
||||
::find()
|
||||
.filter(doc2_doc::Column::Did.eq(doc[0].did))
|
||||
.all(db).await?;
|
||||
assert!(d2d.len() <= 1, "Did: {}->{}", doc[0].did, d2d.len());
|
||||
if d2d.len() > 0 {
|
||||
for d2d_ in doc2_doc::Entity
|
||||
::find()
|
||||
.filter(doc2_doc::Column::ParentId.eq(d2d[0].parent_id))
|
||||
.all(db).await? {
|
||||
dids.push(d2d_.did);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Entity::find()
|
||||
.filter(doc_info::Column::DocName.eq(name.clone()))
|
||||
.filter(doc_info::Column::Uid.eq(uid))
|
||||
.filter(doc_info::Column::Did.is_in(dids))
|
||||
.filter(doc_info::Column::IsDeleted.eq(false))
|
||||
.all(db).await
|
||||
}
|
||||
|
||||
pub async fn all_descendent_ids(db: &DbConn, doc_ids: &Vec<i64>) -> Result<Vec<i64>, DbErr> {
|
||||
let mut dids = doc_ids.clone();
|
||||
let mut i: usize = 0;
|
||||
loop {
|
||||
if dids.len() == i {
|
||||
break;
|
||||
}
|
||||
|
||||
for d in doc2_doc::Entity
|
||||
::find()
|
||||
.filter(doc2_doc::Column::ParentId.eq(dids[i]))
|
||||
.all(db).await? {
|
||||
dids.push(d.did);
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
Ok(dids)
|
||||
}
|
||||
|
||||
pub async fn find_doc_infos_by_params(
|
||||
db: &DbConn,
|
||||
params: ListParams
|
||||
) -> Result<Vec<doc_info::Model>, DbErr> {
|
||||
// Setup paginator
|
||||
let mut sql: String =
|
||||
"
|
||||
select
|
||||
a.did,
|
||||
a.uid,
|
||||
a.doc_name,
|
||||
a.location,
|
||||
a.size,
|
||||
a.type,
|
||||
a.created_at,
|
||||
a.updated_at,
|
||||
a.is_deleted
|
||||
from
|
||||
doc_info as a
|
||||
".to_owned();
|
||||
|
||||
let mut cond: String = format!(" a.uid={} and a.is_deleted=False ", params.uid);
|
||||
|
||||
if let Some(kb_id) = params.filter.kb_id {
|
||||
sql.push_str(
|
||||
&format!(" inner join kb2_doc on kb2_doc.did = a.did and kb2_doc.kb_id={}", kb_id)
|
||||
);
|
||||
}
|
||||
if let Some(folder_id) = params.filter.folder_id {
|
||||
sql.push_str(
|
||||
&format!(" inner join doc2_doc on a.did = doc2_doc.did and doc2_doc.parent_id={}", folder_id)
|
||||
);
|
||||
}
|
||||
// Fetch paginated posts
|
||||
if let Some(tag_id) = params.filter.tag_id {
|
||||
let tag = service::tag_info::Query
|
||||
::find_tag_info_by_id(tag_id, &db).await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
if tag.folder_id > 0 {
|
||||
sql.push_str(
|
||||
&format!(
|
||||
" inner join doc2_doc on a.did = doc2_doc.did and doc2_doc.parent_id={}",
|
||||
tag.folder_id
|
||||
)
|
||||
);
|
||||
}
|
||||
if tag.regx.len() > 0 {
|
||||
cond.push_str(&format!(" and (type='{}' or doc_name ~ '{}') ", tag.tag_name, tag.regx));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(keywords) = params.filter.keywords {
|
||||
cond.push_str(&format!(" and doc_name like '%{}%'", keywords));
|
||||
}
|
||||
if cond.len() > 0 {
|
||||
sql.push_str(&" where ");
|
||||
sql.push_str(&cond);
|
||||
}
|
||||
let mut orderby = params.sortby.clone();
|
||||
if orderby.len() == 0 {
|
||||
orderby = "updated_at desc".to_owned();
|
||||
}
|
||||
sql.push_str(&format!(" order by {}", orderby));
|
||||
let mut page_size: u32 = 30;
|
||||
if let Some(pg_sz) = params.per_page {
|
||||
page_size = pg_sz;
|
||||
}
|
||||
let mut page: u32 = 0;
|
||||
if let Some(pg) = params.page {
|
||||
page = pg;
|
||||
}
|
||||
sql.push_str(&format!(" limit {} offset {} ;", page_size, page * page_size));
|
||||
|
||||
print!("{}", sql);
|
||||
Entity::find()
|
||||
.from_raw_sql(Statement::from_sql_and_values(DbBackend::Postgres, sql, vec![]))
|
||||
.all(db).await
|
||||
}
|
||||
|
||||
pub async fn find_doc_infos_in_page(
|
||||
db: &DbConn,
|
||||
page: u64,
|
||||
posts_per_page: u64
|
||||
) -> Result<(Vec<doc_info::Model>, u64), DbErr> {
|
||||
// Setup paginator
|
||||
let paginator = Entity::find()
|
||||
.order_by_asc(doc_info::Column::Did)
|
||||
.paginate(db, posts_per_page);
|
||||
let num_pages = paginator.num_pages().await?;
|
||||
|
||||
// Fetch paginated posts
|
||||
paginator.fetch_page(page - 1).await.map(|p| (p, num_pages))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Mutation;
|
||||
|
||||
impl Mutation {
|
||||
pub async fn mv_doc_info(db: &DbConn, dest_did: i64, dids: &[i64]) -> Result<(), DbErr> {
|
||||
for did in dids {
|
||||
let d = doc2_doc::Entity
|
||||
::find()
|
||||
.filter(doc2_doc::Column::Did.eq(did.to_owned()))
|
||||
.all(db).await?;
|
||||
|
||||
let _ = (doc2_doc::ActiveModel {
|
||||
id: Set(d[0].id),
|
||||
did: Set(did.to_owned()),
|
||||
parent_id: Set(dest_did),
|
||||
}).update(db).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn place_doc(
|
||||
db: &DbConn,
|
||||
dest_did: i64,
|
||||
did: i64
|
||||
) -> Result<doc2_doc::ActiveModel, DbErr> {
|
||||
(doc2_doc::ActiveModel {
|
||||
id: Default::default(),
|
||||
parent_id: Set(dest_did),
|
||||
did: Set(did),
|
||||
}).save(db).await
|
||||
}
|
||||
|
||||
pub async fn create_doc_info(
|
||||
db: &DbConn,
|
||||
form_data: doc_info::Model
|
||||
) -> Result<doc_info::ActiveModel, DbErr> {
|
||||
(doc_info::ActiveModel {
|
||||
did: Default::default(),
|
||||
uid: Set(form_data.uid.to_owned()),
|
||||
doc_name: Set(form_data.doc_name.to_owned()),
|
||||
size: Set(form_data.size.to_owned()),
|
||||
r#type: Set(form_data.r#type.to_owned()),
|
||||
location: Set(form_data.location.to_owned()),
|
||||
thumbnail_base64: Default::default(),
|
||||
created_at: Set(form_data.created_at.to_owned()),
|
||||
updated_at: Set(form_data.updated_at.to_owned()),
|
||||
is_deleted: Default::default(),
|
||||
}).save(db).await
|
||||
}
|
||||
|
||||
pub async fn update_doc_info_by_id(
|
||||
db: &DbConn,
|
||||
id: i64,
|
||||
form_data: doc_info::Model
|
||||
) -> Result<doc_info::Model, DbErr> {
|
||||
let doc_info: doc_info::ActiveModel = Entity::find_by_id(id)
|
||||
.one(db).await?
|
||||
.ok_or(DbErr::Custom("Cannot find.".to_owned()))
|
||||
.map(Into::into)?;
|
||||
|
||||
(doc_info::ActiveModel {
|
||||
did: doc_info.did,
|
||||
uid: Set(form_data.uid.to_owned()),
|
||||
doc_name: Set(form_data.doc_name.to_owned()),
|
||||
size: Set(form_data.size.to_owned()),
|
||||
r#type: Set(form_data.r#type.to_owned()),
|
||||
location: Set(form_data.location.to_owned()),
|
||||
thumbnail_base64: doc_info.thumbnail_base64,
|
||||
created_at: doc_info.created_at,
|
||||
updated_at: Set(now()),
|
||||
is_deleted: Default::default(),
|
||||
}).update(db).await
|
||||
}
|
||||
|
||||
pub async fn delete_doc_info(db: &DbConn, doc_ids: &Vec<i64>) -> Result<UpdateResult, DbErr> {
|
||||
let mut dids = doc_ids.clone();
|
||||
let mut i: usize = 0;
|
||||
loop {
|
||||
if dids.len() == i {
|
||||
break;
|
||||
}
|
||||
let mut doc: doc_info::ActiveModel = Entity::find_by_id(dids[i])
|
||||
.one(db).await?
|
||||
.ok_or(DbErr::Custom(format!("Can't find doc:{}", dids[i])))
|
||||
.map(Into::into)?;
|
||||
doc.updated_at = Set(now());
|
||||
doc.is_deleted = Set(true);
|
||||
let _ = doc.update(db).await?;
|
||||
|
||||
for d in doc2_doc::Entity
|
||||
::find()
|
||||
.filter(doc2_doc::Column::ParentId.eq(dids[i]))
|
||||
.all(db).await? {
|
||||
dids.push(d.did);
|
||||
}
|
||||
let _ = doc2_doc::Entity
|
||||
::delete_many()
|
||||
.filter(doc2_doc::Column::ParentId.eq(dids[i]))
|
||||
.exec(db).await?;
|
||||
let _ = doc2_doc::Entity
|
||||
::delete_many()
|
||||
.filter(doc2_doc::Column::Did.eq(dids[i]))
|
||||
.exec(db).await?;
|
||||
i += 1;
|
||||
}
|
||||
crate::service::kb_info::Mutation::remove_docs(&db, dids, None).await
|
||||
}
|
||||
|
||||
pub async fn rename(db: &DbConn, doc_id: i64, name: &String) -> Result<doc_info::Model, DbErr> {
|
||||
let mut doc: doc_info::ActiveModel = Entity::find_by_id(doc_id)
|
||||
.one(db).await?
|
||||
.ok_or(DbErr::Custom(format!("Can't find doc:{}", doc_id)))
|
||||
.map(Into::into)?;
|
||||
doc.updated_at = Set(now());
|
||||
doc.doc_name = Set(name.clone());
|
||||
doc.update(db).await
|
||||
}
|
||||
|
||||
pub async fn delete_all_doc_infos(db: &DbConn) -> Result<DeleteResult, DbErr> {
|
||||
Entity::delete_many().exec(db).await
|
||||
}
|
||||
}
|
@@ -1,168 +0,0 @@
|
||||
use chrono::{ Local, FixedOffset, Utc };
|
||||
use migration::Expr;
|
||||
use sea_orm::{
|
||||
ActiveModelTrait,
|
||||
ColumnTrait,
|
||||
DbConn,
|
||||
DbErr,
|
||||
DeleteResult,
|
||||
EntityTrait,
|
||||
PaginatorTrait,
|
||||
QueryFilter,
|
||||
QueryOrder,
|
||||
UpdateResult,
|
||||
};
|
||||
use sea_orm::ActiveValue::Set;
|
||||
use crate::entity::kb_info;
|
||||
use crate::entity::kb2_doc;
|
||||
use crate::entity::kb_info::Entity;
|
||||
|
||||
fn now() -> chrono::DateTime<FixedOffset> {
|
||||
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
|
||||
}
|
||||
pub struct Query;
|
||||
|
||||
impl Query {
|
||||
pub async fn find_kb_info_by_id(db: &DbConn, id: i64) -> Result<Option<kb_info::Model>, DbErr> {
|
||||
Entity::find_by_id(id).one(db).await
|
||||
}
|
||||
|
||||
pub async fn find_kb_infos(db: &DbConn) -> Result<Vec<kb_info::Model>, DbErr> {
|
||||
Entity::find().all(db).await
|
||||
}
|
||||
|
||||
pub async fn find_kb_infos_by_uid(db: &DbConn, uid: i64) -> Result<Vec<kb_info::Model>, DbErr> {
|
||||
Entity::find().filter(kb_info::Column::Uid.eq(uid)).all(db).await
|
||||
}
|
||||
|
||||
pub async fn find_kb_infos_by_name(
|
||||
db: &DbConn,
|
||||
name: String
|
||||
) -> Result<Vec<kb_info::Model>, DbErr> {
|
||||
Entity::find().filter(kb_info::Column::KbName.eq(name)).all(db).await
|
||||
}
|
||||
|
||||
pub async fn find_kb_by_docs(
|
||||
db: &DbConn,
|
||||
doc_ids: Vec<i64>
|
||||
) -> Result<Vec<kb_info::Model>, DbErr> {
|
||||
let mut kbids = Vec::<i64>::new();
|
||||
for k in kb2_doc::Entity
|
||||
::find()
|
||||
.filter(kb2_doc::Column::Did.is_in(doc_ids))
|
||||
.all(db).await? {
|
||||
kbids.push(k.kb_id);
|
||||
}
|
||||
Entity::find().filter(kb_info::Column::KbId.is_in(kbids)).all(db).await
|
||||
}
|
||||
|
||||
pub async fn find_kb_infos_in_page(
|
||||
db: &DbConn,
|
||||
page: u64,
|
||||
posts_per_page: u64
|
||||
) -> Result<(Vec<kb_info::Model>, u64), DbErr> {
|
||||
// Setup paginator
|
||||
let paginator = Entity::find()
|
||||
.order_by_asc(kb_info::Column::KbId)
|
||||
.paginate(db, posts_per_page);
|
||||
let num_pages = paginator.num_pages().await?;
|
||||
|
||||
// Fetch paginated posts
|
||||
paginator.fetch_page(page - 1).await.map(|p| (p, num_pages))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Mutation;
|
||||
|
||||
impl Mutation {
|
||||
pub async fn create_kb_info(
|
||||
db: &DbConn,
|
||||
form_data: kb_info::Model
|
||||
) -> Result<kb_info::ActiveModel, DbErr> {
|
||||
(kb_info::ActiveModel {
|
||||
kb_id: Default::default(),
|
||||
uid: Set(form_data.uid.to_owned()),
|
||||
kb_name: Set(form_data.kb_name.to_owned()),
|
||||
icon: Set(form_data.icon.to_owned()),
|
||||
created_at: Set(now()),
|
||||
updated_at: Set(now()),
|
||||
is_deleted: Default::default(),
|
||||
}).save(db).await
|
||||
}
|
||||
|
||||
pub async fn add_docs(db: &DbConn, kb_id: i64, doc_ids: Vec<i64>) -> Result<(), DbErr> {
|
||||
for did in doc_ids {
|
||||
let res = kb2_doc::Entity
|
||||
::find()
|
||||
.filter(kb2_doc::Column::KbId.eq(kb_id))
|
||||
.filter(kb2_doc::Column::Did.eq(did))
|
||||
.all(db).await?;
|
||||
if res.len() > 0 {
|
||||
continue;
|
||||
}
|
||||
let _ = (kb2_doc::ActiveModel {
|
||||
id: Default::default(),
|
||||
kb_id: Set(kb_id),
|
||||
did: Set(did),
|
||||
kb_progress: Set(0.0),
|
||||
kb_progress_msg: Set("".to_owned()),
|
||||
updated_at: Set(now()),
|
||||
is_deleted: Default::default(),
|
||||
}).save(db).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn remove_docs(
|
||||
db: &DbConn,
|
||||
doc_ids: Vec<i64>,
|
||||
kb_id: Option<i64>
|
||||
) -> Result<UpdateResult, DbErr> {
|
||||
let update = kb2_doc::Entity
|
||||
::update_many()
|
||||
.col_expr(kb2_doc::Column::IsDeleted, Expr::value(true))
|
||||
.col_expr(kb2_doc::Column::KbProgress, Expr::value(0))
|
||||
.col_expr(kb2_doc::Column::KbProgressMsg, Expr::value(""))
|
||||
.filter(kb2_doc::Column::Did.is_in(doc_ids));
|
||||
if let Some(kbid) = kb_id {
|
||||
update.filter(kb2_doc::Column::KbId.eq(kbid)).exec(db).await
|
||||
} else {
|
||||
update.exec(db).await
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn update_kb_info_by_id(
|
||||
db: &DbConn,
|
||||
id: i64,
|
||||
form_data: kb_info::Model
|
||||
) -> Result<kb_info::Model, DbErr> {
|
||||
let kb_info: kb_info::ActiveModel = Entity::find_by_id(id)
|
||||
.one(db).await?
|
||||
.ok_or(DbErr::Custom("Cannot find.".to_owned()))
|
||||
.map(Into::into)?;
|
||||
|
||||
(kb_info::ActiveModel {
|
||||
kb_id: kb_info.kb_id,
|
||||
uid: kb_info.uid,
|
||||
kb_name: Set(form_data.kb_name.to_owned()),
|
||||
icon: Set(form_data.icon.to_owned()),
|
||||
created_at: kb_info.created_at,
|
||||
updated_at: Set(now()),
|
||||
is_deleted: Default::default(),
|
||||
}).update(db).await
|
||||
}
|
||||
|
||||
pub async fn delete_kb_info(db: &DbConn, kb_id: i64) -> Result<DeleteResult, DbErr> {
|
||||
let kb: kb_info::ActiveModel = Entity::find_by_id(kb_id)
|
||||
.one(db).await?
|
||||
.ok_or(DbErr::Custom("Cannot find.".to_owned()))
|
||||
.map(Into::into)?;
|
||||
|
||||
kb.delete(db).await
|
||||
}
|
||||
|
||||
pub async fn delete_all_kb_infos(db: &DbConn) -> Result<DeleteResult, DbErr> {
|
||||
Entity::delete_many().exec(db).await
|
||||
}
|
||||
}
|
@@ -1,5 +0,0 @@
pub(crate) mod dialog_info;
pub(crate) mod tag_info;
pub(crate) mod kb_info;
pub(crate) mod doc_info;
pub(crate) mod user_info;
@@ -1,108 +0,0 @@
|
||||
use chrono::{ FixedOffset, Utc };
|
||||
use sea_orm::{
|
||||
ActiveModelTrait,
|
||||
DbConn,
|
||||
DbErr,
|
||||
DeleteResult,
|
||||
EntityTrait,
|
||||
PaginatorTrait,
|
||||
QueryOrder,
|
||||
ColumnTrait,
|
||||
QueryFilter,
|
||||
};
|
||||
use sea_orm::ActiveValue::{ Set, NotSet };
|
||||
use crate::entity::tag_info;
|
||||
use crate::entity::tag_info::Entity;
|
||||
|
||||
fn now() -> chrono::DateTime<FixedOffset> {
|
||||
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
|
||||
}
|
||||
pub struct Query;
|
||||
|
||||
impl Query {
|
||||
pub async fn find_tag_info_by_id(
|
||||
id: i64,
|
||||
db: &DbConn
|
||||
) -> Result<Option<tag_info::Model>, DbErr> {
|
||||
Entity::find_by_id(id).one(db).await
|
||||
}
|
||||
|
||||
pub async fn find_tags_by_uid(uid: i64, db: &DbConn) -> Result<Vec<tag_info::Model>, DbErr> {
|
||||
Entity::find().filter(tag_info::Column::Uid.eq(uid)).all(db).await
|
||||
}
|
||||
|
||||
pub async fn find_tag_infos_in_page(
|
||||
db: &DbConn,
|
||||
page: u64,
|
||||
posts_per_page: u64
|
||||
) -> Result<(Vec<tag_info::Model>, u64), DbErr> {
|
||||
// Setup paginator
|
||||
let paginator = Entity::find()
|
||||
.order_by_asc(tag_info::Column::Tid)
|
||||
.paginate(db, posts_per_page);
|
||||
let num_pages = paginator.num_pages().await?;
|
||||
|
||||
// Fetch paginated posts
|
||||
paginator.fetch_page(page - 1).await.map(|p| (p, num_pages))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Mutation;
|
||||
|
||||
impl Mutation {
|
||||
pub async fn create_tag(
|
||||
db: &DbConn,
|
||||
form_data: tag_info::Model
|
||||
) -> Result<tag_info::ActiveModel, DbErr> {
|
||||
(tag_info::ActiveModel {
|
||||
tid: Default::default(),
|
||||
uid: Set(form_data.uid.to_owned()),
|
||||
tag_name: Set(form_data.tag_name.to_owned()),
|
||||
regx: Set(form_data.regx.to_owned()),
|
||||
color: Set(form_data.color.to_owned()),
|
||||
icon: Set(form_data.icon.to_owned()),
|
||||
folder_id: match form_data.folder_id {
|
||||
0 => NotSet,
|
||||
_ => Set(form_data.folder_id.to_owned()),
|
||||
},
|
||||
created_at: Set(now()),
|
||||
updated_at: Set(now()),
|
||||
}).save(db).await
|
||||
}
|
||||
|
||||
pub async fn update_tag_by_id(
|
||||
db: &DbConn,
|
||||
id: i64,
|
||||
form_data: tag_info::Model
|
||||
) -> Result<tag_info::Model, DbErr> {
|
||||
let tag: tag_info::ActiveModel = Entity::find_by_id(id)
|
||||
.one(db).await?
|
||||
.ok_or(DbErr::Custom("Cannot find tag.".to_owned()))
|
||||
.map(Into::into)?;
|
||||
|
||||
(tag_info::ActiveModel {
|
||||
tid: tag.tid,
|
||||
uid: tag.uid,
|
||||
tag_name: Set(form_data.tag_name.to_owned()),
|
||||
regx: Set(form_data.regx.to_owned()),
|
||||
color: Set(form_data.color.to_owned()),
|
||||
icon: Set(form_data.icon.to_owned()),
|
||||
folder_id: Set(form_data.folder_id.to_owned()),
|
||||
created_at: Default::default(),
|
||||
updated_at: Set(now()),
|
||||
}).update(db).await
|
||||
}
|
||||
|
||||
pub async fn delete_tag(db: &DbConn, tid: i64) -> Result<DeleteResult, DbErr> {
|
||||
let tag: tag_info::ActiveModel = Entity::find_by_id(tid)
|
||||
.one(db).await?
|
||||
.ok_or(DbErr::Custom("Cannot find tag.".to_owned()))
|
||||
.map(Into::into)?;
|
||||
|
||||
tag.delete(db).await
|
||||
}
|
||||
|
||||
pub async fn delete_all_tags(db: &DbConn) -> Result<DeleteResult, DbErr> {
|
||||
Entity::delete_many().exec(db).await
|
||||
}
|
||||
}
|
@@ -1,131 +0,0 @@
use chrono::{ FixedOffset, Utc };
use migration::Expr;
use sea_orm::{
    ActiveModelTrait,
    ColumnTrait,
    DbConn,
    DbErr,
    DeleteResult,
    EntityTrait,
    PaginatorTrait,
    QueryFilter,
    QueryOrder,
    UpdateResult,
};
use sea_orm::ActiveValue::Set;
use crate::entity::user_info;
use crate::entity::user_info::Entity;

fn now() -> chrono::DateTime<FixedOffset> {
    Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
}
pub struct Query;

impl Query {
    pub async fn find_user_info_by_id(
        db: &DbConn,
        id: i64
    ) -> Result<Option<user_info::Model>, DbErr> {
        Entity::find_by_id(id).one(db).await
    }

    pub async fn login(
        db: &DbConn,
        email: &str,
        password: &str
    ) -> Result<Option<user_info::Model>, DbErr> {
        Entity::find()
            .filter(user_info::Column::Email.eq(email))
            .filter(user_info::Column::Password.eq(password))
            .one(db).await
    }

    pub async fn find_user_infos(
        db: &DbConn,
        email: &String
    ) -> Result<Option<user_info::Model>, DbErr> {
        Entity::find().filter(user_info::Column::Email.eq(email)).one(db).await
    }

    pub async fn find_user_infos_in_page(
        db: &DbConn,
        page: u64,
        posts_per_page: u64
    ) -> Result<(Vec<user_info::Model>, u64), DbErr> {
        // Setup paginator
        let paginator = Entity::find()
            .order_by_asc(user_info::Column::Uid)
            .paginate(db, posts_per_page);
        let num_pages = paginator.num_pages().await?;

        // Fetch paginated posts
        paginator.fetch_page(page - 1).await.map(|p| (p, num_pages))
    }
}

pub struct Mutation;

impl Mutation {
    pub async fn create_user(
        db: &DbConn,
        form_data: &user_info::Model
    ) -> Result<user_info::ActiveModel, DbErr> {
        (user_info::ActiveModel {
            uid: Default::default(),
            email: Set(form_data.email.to_owned()),
            nickname: Set(form_data.nickname.to_owned()),
            avatar_base64: Set(form_data.avatar_base64.to_owned()),
            color_scheme: Set(form_data.color_scheme.to_owned()),
            list_style: Set(form_data.list_style.to_owned()),
            language: Set(form_data.language.to_owned()),
            password: Set(form_data.password.to_owned()),
            last_login_at: Set(now()),
            created_at: Set(now()),
            updated_at: Set(now()),
        }).save(db).await
    }

    pub async fn update_user_by_id(
        db: &DbConn,
        form_data: &user_info::Model
    ) -> Result<user_info::Model, DbErr> {
        let usr: user_info::ActiveModel = Entity::find_by_id(form_data.uid)
            .one(db).await?
            .ok_or(DbErr::Custom("Cannot find user.".to_owned()))
            .map(Into::into)?;

        (user_info::ActiveModel {
            uid: Set(form_data.uid),
            email: Set(form_data.email.to_owned()),
            nickname: Set(form_data.nickname.to_owned()),
            avatar_base64: Set(form_data.avatar_base64.to_owned()),
            color_scheme: Set(form_data.color_scheme.to_owned()),
            list_style: Set(form_data.list_style.to_owned()),
            language: Set(form_data.language.to_owned()),
            password: Set(form_data.password.to_owned()),
            updated_at: Set(now()),
            last_login_at: usr.last_login_at,
            created_at: usr.created_at,
        }).update(db).await
    }

    pub async fn update_login_status(uid: i64, db: &DbConn) -> Result<UpdateResult, DbErr> {
        Entity::update_many()
            .col_expr(user_info::Column::LastLoginAt, Expr::value(now()))
            .filter(user_info::Column::Uid.eq(uid))
            .exec(db).await
    }

    pub async fn delete_user(db: &DbConn, tid: i64) -> Result<DeleteResult, DbErr> {
        let tag: user_info::ActiveModel = Entity::find_by_id(tid)
            .one(db).await?
            .ok_or(DbErr::Custom("Cannot find user.".to_owned()))
            .map(Into::into)?;

        tag.delete(db).await
    }

    pub async fn delete_all(db: &DbConn) -> Result<DeleteResult, DbErr> {
        Entity::delete_many().exec(db).await
    }
}
@ -1,97 +0,0 @@
use std::io::{Cursor, Write};
use std::time::{Duration, Instant};
use actix_rt::time::interval;
use actix_web::{HttpRequest, HttpResponse, rt, web};
use actix_web::web::Buf;
use actix_ws::Message;
use futures_util::{future, StreamExt};
use futures_util::future::Either;
use uuid::Uuid;
use crate::api::doc_info::_upload_file;
use crate::AppState;
use crate::errors::AppError;

const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);

/// How long before lack of client response causes a timeout.
const CLIENT_TIMEOUT: Duration = Duration::from_secs(10);

pub async fn upload_file_ws(req: HttpRequest, stream: web::Payload, data: web::Data<AppState>) -> Result<HttpResponse, AppError> {
    let (res, session, msg_stream) = actix_ws::handle(&req, stream)?;

    // spawn websocket handler (and don't await it) so that the response is returned immediately
    rt::spawn(upload_file_handler(data, session, msg_stream));

    Ok(res)
}

async fn upload_file_handler(
    data: web::Data<AppState>,
    mut session: actix_ws::Session,
    mut msg_stream: actix_ws::MessageStream,
) {
    let mut bytes = Cursor::new(vec![]);
    let mut last_heartbeat = Instant::now();
    let mut interval = interval(HEARTBEAT_INTERVAL);

    let reason = loop {
        let tick = interval.tick();
        tokio::pin!(tick);

        match future::select(msg_stream.next(), tick).await {
            // received message from WebSocket client
            Either::Left((Some(Ok(msg)), _)) => {
                match msg {
                    Message::Text(text) => {
                        session.text(text).await.unwrap();
                    }

                    Message::Binary(bin) => {
                        let mut pos = 0; // append the received binary chunk to the in-memory buffer
                        while pos < bin.len() {
                            let bytes_written = bytes.write(&bin[pos..]).unwrap();
                            pos += bytes_written
                        };
                        session.binary(bin).await.unwrap();
                    }

                    Message::Close(reason) => {
                        break reason;
                    }

                    Message::Ping(bytes) => {
                        last_heartbeat = Instant::now();
                        let _ = session.pong(&bytes).await;
                    }

                    Message::Pong(_) => {
                        last_heartbeat = Instant::now();
                    }

                    Message::Continuation(_) | Message::Nop => {}
                };
            }
            Either::Left((Some(Err(_)), _)) => {
                break None;
            }
            Either::Left((None, _)) => break None,
            Either::Right((_inst, _)) => {
                if Instant::now().duration_since(last_heartbeat) > CLIENT_TIMEOUT {
                    break None;
                }

                let _ = session.ping(b"").await;
            }
        }
    };
    let _ = session.close(reason).await;

    if !bytes.has_remaining() {
        return;
    }

    let uid = bytes.get_i64();
    let did = bytes.get_i64();

    _upload_file(uid, did, &Uuid::new_v4().to_string(), &bytes.into_inner(), &data).await.unwrap();
}
@ -1 +0,0 @@
pub mod doc_info;
0
web_server/__init__.py
Normal file
147
web_server/apps/__init__.py
Normal file
@ -0,0 +1,147 @@
|
||||
#
|
||||
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import sys
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
from flask import Blueprint, Flask, request
|
||||
from werkzeug.wrappers.request import Request
|
||||
from flask_cors import CORS
|
||||
|
||||
from web_server.db import StatusEnum
|
||||
from web_server.db.services import UserService
|
||||
from web_server.utils import CustomJSONEncoder
|
||||
|
||||
from flask_session import Session
|
||||
from flask_login import LoginManager
|
||||
from web_server.settings import RetCode, SECRET_KEY, stat_logger
|
||||
from web_server.hook import HookManager
|
||||
from web_server.hook.common.parameters import AuthenticationParameters, ClientAuthenticationParameters
|
||||
from web_server.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger
|
||||
from web_server.utils.api_utils import get_json_result, server_error_response
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
|
||||
__all__ = ['app']
|
||||
|
||||
|
||||
logger = logging.getLogger('flask.app')
|
||||
for h in access_logger.handlers:
|
||||
logger.addHandler(h)
|
||||
|
||||
Request.json = property(lambda self: self.get_json(force=True, silent=True))
|
||||
|
||||
app = Flask(__name__)
|
||||
CORS(app, supports_credentials=True,max_age = 2592000)
|
||||
app.url_map.strict_slashes = False
|
||||
app.json_encoder = CustomJSONEncoder
|
||||
app.errorhandler(Exception)(server_error_response)
|
||||
|
||||
|
||||
## convenience settings for dev and debug
|
||||
#app.config["LOGIN_DISABLED"] = True
|
||||
app.config["SESSION_PERMANENT"] = False
|
||||
app.config["SESSION_TYPE"] = "filesystem"
|
||||
app.config['MAX_CONTENT_LENGTH'] = 64 * 1024 * 1024
|
||||
|
||||
Session(app)
|
||||
login_manager = LoginManager()
|
||||
login_manager.init_app(app)
|
||||
|
||||
|
||||
|
||||
def search_pages_path(pages_dir):
|
||||
return [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')]
|
||||
|
||||
|
||||
def register_page(page_path):
|
||||
page_name = page_path.stem.rstrip('_app')
|
||||
module_name = '.'.join(page_path.parts[page_path.parts.index('web_server'):-1] + (page_name, ))
|
||||
|
||||
spec = spec_from_file_location(module_name, page_path)
|
||||
page = module_from_spec(spec)
|
||||
page.app = app
|
||||
page.manager = Blueprint(page_name, module_name)
|
||||
sys.modules[module_name] = page
|
||||
spec.loader.exec_module(page)
|
||||
|
||||
page_name = getattr(page, 'page_name', page_name)
|
||||
url_prefix = f'/{API_VERSION}/{page_name}'
|
||||
|
||||
app.register_blueprint(page.manager, url_prefix=url_prefix)
|
||||
return url_prefix
|
||||
|
||||
|
||||
pages_dir = [
|
||||
Path(__file__).parent,
|
||||
Path(__file__).parent.parent / 'web_server' / 'apps',
|
||||
]
|
||||
|
||||
client_urls_prefix = [
|
||||
register_page(path)
|
||||
for dir in pages_dir
|
||||
for path in search_pages_path(dir)
|
||||
]
|
||||
|
||||
|
||||
def client_authentication_before_request():
|
||||
result = HookManager.client_authentication(ClientAuthenticationParameters(
|
||||
request.full_path, request.headers,
|
||||
request.form, request.data, request.json,
|
||||
))
|
||||
|
||||
if result.code != RetCode.SUCCESS:
|
||||
return get_json_result(result.code, result.message)
|
||||
|
||||
|
||||
def site_authentication_before_request():
|
||||
for url_prefix in client_urls_prefix:
|
||||
if request.path.startswith(url_prefix):
|
||||
return
|
||||
|
||||
result = HookManager.site_authentication(AuthenticationParameters(
|
||||
request.headers.get('site_signature'),
|
||||
request.json,
|
||||
))
|
||||
|
||||
if result.code != RetCode.SUCCESS:
|
||||
return get_json_result(result.code, result.message)
|
||||
|
||||
|
||||
@app.before_request
|
||||
def authentication_before_request():
|
||||
if CLIENT_AUTHENTICATION:
|
||||
return client_authentication_before_request()
|
||||
|
||||
if SITE_AUTHENTICATION:
|
||||
return site_authentication_before_request()
|
||||
|
||||
@login_manager.request_loader
|
||||
def load_user(web_request):
|
||||
jwt = Serializer(secret_key=SECRET_KEY)
|
||||
authorization = web_request.headers.get("Authorization")
|
||||
if authorization:
|
||||
try:
|
||||
access_token = str(jwt.loads(authorization))
|
||||
user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value)
|
||||
if user:
|
||||
return user[0]
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
stat_logger.exception(e)
|
||||
return None
|
||||
else:
|
||||
return None
|
235
web_server/apps/document_app.py
Normal file
@ -0,0 +1,235 @@
|
||||
#
|
||||
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import pathlib
|
||||
|
||||
from elasticsearch_dsl import Q
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from rag.nlp import search
|
||||
from rag.utils import ELASTICSEARCH
|
||||
from web_server.db.services import duplicate_name
|
||||
from web_server.db.services.kb_service import KnowledgebaseService
|
||||
from web_server.db.services.user_service import TenantService
|
||||
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from web_server.utils import get_uuid, get_format_time
|
||||
from web_server.db import StatusEnum, FileType
|
||||
from web_server.db.services.document_service import DocumentService
|
||||
from web_server.settings import RetCode
|
||||
from web_server.utils.api_utils import get_json_result
|
||||
from rag.utils.minio_conn import MINIO
|
||||
from web_server.utils.file_utils import filename_type
|
||||
|
||||
|
||||
@manager.route('/upload', methods=['POST'])
|
||||
@login_required
|
||||
@validate_request("kb_id")
|
||||
def upload():
|
||||
kb_id = request.form.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(
|
||||
data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
|
||||
if 'file' not in request.files:
|
||||
return get_json_result(
|
||||
data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
|
||||
file = request.files['file']
|
||||
if file.filename == '':
|
||||
return get_json_result(
|
||||
data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
try:
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
retmsg="Can't find this knowledgebase!")
|
||||
|
||||
filename = duplicate_name(
|
||||
DocumentService.query,
|
||||
name=file.filename,
|
||||
kb_id=kb.id)
|
||||
location = filename
|
||||
while MINIO.obj_exist(kb_id, location):
|
||||
location += "_"
|
||||
blob = request.files['file'].read()
|
||||
MINIO.put(kb_id, filename, blob)
|
||||
doc = DocumentService.insert({
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb.id,
|
||||
"parser_id": kb.parser_id,
|
||||
"created_by": current_user.id,
|
||||
"type": filename_type(filename),
|
||||
"name": filename,
|
||||
"location": location,
|
||||
"size": len(blob)
|
||||
})
|
||||
return get_json_result(data=doc.to_json())
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/create', methods=['POST'])
|
||||
@login_required
|
||||
@validate_request("name", "kb_id")
|
||||
def create():
|
||||
req = request.json
|
||||
kb_id = req["kb_id"]
|
||||
if not kb_id:
|
||||
return get_json_result(
|
||||
data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
try:
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
retmsg="Can't find this knowledgebase!")
|
||||
|
||||
if DocumentService.query(name=req["name"], kb_id=kb_id):
|
||||
return get_data_error_result(
|
||||
retmsg="Duplicated document name in the same knowledgebase.")
|
||||
|
||||
doc = DocumentService.insert({
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb.id,
|
||||
"parser_id": kb.parser_id,
|
||||
"created_by": current_user.id,
|
||||
"type": FileType.VIRTUAL,
|
||||
"name": req["name"],
|
||||
"location": "",
|
||||
"size": 0
|
||||
})
|
||||
return get_json_result(data=doc.to_json())
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/list', methods=['GET'])
|
||||
@login_required
|
||||
def list():
|
||||
kb_id = request.args.get("kb_id")
|
||||
if not kb_id:
|
||||
return get_json_result(
|
||||
data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
|
||||
keywords = request.args.get("keywords", "")
|
||||
|
||||
page_number = request.args.get("page", 1)
|
||||
items_per_page = request.args.get("page_size", 15)
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
desc = request.args.get("desc", True)
|
||||
try:
|
||||
docs = DocumentService.get_by_kb_id(
|
||||
kb_id, page_number, items_per_page, orderby, desc, keywords)
|
||||
return get_json_result(data=docs)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/change_status', methods=['POST'])
|
||||
@login_required
|
||||
@validate_request("doc_id", "status")
|
||||
def change_status():
|
||||
req = request.json
|
||||
if str(req["status"]) not in ["0", "1"]:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
retmsg='"Status" must be either 0 or 1!',
|
||||
retcode=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Document not found!")
|
||||
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
retmsg="Can't find this knowledgebase!")
|
||||
|
||||
if not DocumentService.update_by_id(
|
||||
req["doc_id"], {"status": str(req["status"])}):
|
||||
return get_data_error_result(
|
||||
retmsg="Database error (Document update)!")
|
||||
|
||||
if str(req["status"]) == "0":
|
||||
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
|
||||
scripts="""
|
||||
if(ctx._source.kb_id.contains('%s'))
|
||||
ctx._source.kb_id.remove(
|
||||
ctx._source.kb_id.indexOf('%s')
|
||||
);
|
||||
""" % (doc.kb_id, doc.kb_id),
|
||||
idxnm=search.index_name(
|
||||
kb.tenant_id)
|
||||
)
|
||||
else:
|
||||
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
|
||||
scripts="""
|
||||
if(!ctx._source.kb_id.contains('%s'))
|
||||
ctx._source.kb_id.add('%s');
|
||||
""" % (doc.kb_id, doc.kb_id),
|
||||
idxnm=search.index_name(
|
||||
kb.tenant_id)
|
||||
)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/rm', methods=['POST'])
|
||||
@login_required
|
||||
@validate_request("doc_id")
|
||||
def rm():
|
||||
req = request.json
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Document not found!")
|
||||
if not DocumentService.delete_by_id(req["doc_id"]):
|
||||
return get_data_error_result(
|
||||
retmsg="Database error (Document removal)!")
|
||||
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
||||
MINIO.rm(kb.id, doc.location)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/rename', methods=['POST'])
|
||||
@login_required
|
||||
@validate_request("doc_id", "name", "old_name")
|
||||
def rename():
|
||||
req = request.json
|
||||
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
|
||||
req["old_name"].lower()).suffix:
|
||||
return get_json_result(
|
||||
data=False,
|
||||
retmsg="The extension of file can't be changed",
|
||||
retcode=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Document not found!")
|
||||
if DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
||||
return get_data_error_result(
|
||||
retmsg="Duplicated document name in the same knowledgebase.")
|
||||
|
||||
if not DocumentService.update_by_id(
|
||||
req["doc_id"], {"name": req["name"]}):
|
||||
return get_data_error_result(
|
||||
retmsg="Database error (Document rename)!")
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
102
web_server/apps/kb_app.py
Normal file
@ -0,0 +1,102 @@
|
||||
#
|
||||
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from web_server.db.services import duplicate_name
|
||||
from web_server.db.services.user_service import TenantService, UserTenantService
|
||||
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from web_server.utils import get_uuid, get_format_time
|
||||
from web_server.db import StatusEnum, UserTenantRole
|
||||
from web_server.db.services.kb_service import KnowledgebaseService
|
||||
from web_server.db.db_models import Knowledgebase
|
||||
from web_server.settings import stat_logger, RetCode
|
||||
from web_server.utils.api_utils import get_json_result
|
||||
|
||||
|
||||
@manager.route('/create', methods=['post'])
|
||||
@login_required
|
||||
@validate_request("name", "description", "permission", "embd_id", "parser_id")
|
||||
def create():
|
||||
req = request.json
|
||||
req["name"] = req["name"].strip()
|
||||
req["name"] = duplicate_name(KnowledgebaseService.query, name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)
|
||||
try:
|
||||
req["id"] = get_uuid()
|
||||
req["tenant_id"] = current_user.id
|
||||
req["created_by"] = current_user.id
|
||||
if not KnowledgebaseService.save(**req): return get_data_error_result()
|
||||
return get_json_result(data={"kb_id": req["id"]})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/update', methods=['post'])
|
||||
@login_required
|
||||
@validate_request("kb_id", "name", "description", "permission", "embd_id", "parser_id")
|
||||
def update():
|
||||
req = request.json
|
||||
req["name"] = req["name"].strip()
|
||||
try:
|
||||
if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]):
|
||||
return get_json_result(data=False, retmsg='Only the owner of the knowledgebase is authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
|
||||
if not e: return get_data_error_result(retmsg="Can't find this knowledgebase!")
|
||||
|
||||
if req["name"].lower() != kb.name.lower() \
|
||||
and len(KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value))>1:
|
||||
return get_data_error_result(retmsg="Duplicated knowledgebase name.")
|
||||
|
||||
del req["kb_id"]
|
||||
if not KnowledgebaseService.update_by_id(kb.id, req): return get_data_error_result()
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb.id)
|
||||
if not e: return get_data_error_result(retmsg="Database error (Knowledgebase rename)!")
|
||||
|
||||
return get_json_result(data=kb.to_json())
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/list', methods=['GET'])
|
||||
@login_required
|
||||
def list():
|
||||
page_number = request.args.get("page", 1)
|
||||
items_per_page = request.args.get("page_size", 15)
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
desc = request.args.get("desc", True)
|
||||
try:
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
||||
kbs = KnowledgebaseService.get_by_tenant_ids([m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc)
|
||||
return get_json_result(data=kbs)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/rm', methods=['post'])
|
||||
@login_required
|
||||
@validate_request("kb_id")
|
||||
def rm():
|
||||
req = request.json
|
||||
try:
|
||||
if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]):
|
||||
return get_json_result(data=False, retmsg='Only the owner of the knowledgebase is authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
|
||||
|
||||
if not KnowledgebaseService.update_by_id(req["kb_id"], {"status": StatusEnum.IN_VALID.value}): return get_data_error_result(retmsg="Database error (Knowledgebase removal)!")
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
226
web_server/apps/user_app.py
Normal file
@ -0,0 +1,226 @@
|
||||
#
|
||||
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from flask import request, session, redirect, url_for
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
from flask_login import login_required, current_user, login_user, logout_user
|
||||
from web_server.utils.api_utils import server_error_response, validate_request
|
||||
from web_server.utils import get_uuid, get_format_time, decrypt, download_img
|
||||
from web_server.db import UserTenantRole
|
||||
from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
|
||||
from web_server.db.services.user_service import UserService, TenantService, UserTenantService
|
||||
from web_server.settings import stat_logger
|
||||
from web_server.utils.api_utils import get_json_result, cors_reponse
|
||||
|
||||
|
||||
@manager.route('/login', methods=['POST', 'GET'])
|
||||
def login():
|
||||
userinfo = None
|
||||
login_channel = "password"
|
||||
if session.get("access_token"):
|
||||
login_channel = session["access_token_from"]
|
||||
if session["access_token_from"] == "github":
|
||||
userinfo = user_info_from_github(session["access_token"])
|
||||
elif not request.json:
|
||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
|
||||
retmsg='Unauthorized!')
|
||||
|
||||
email = request.json.get('email') if not userinfo else userinfo["email"]
|
||||
users = UserService.query(email=email)
|
||||
if not users:
|
||||
if request.json is not None:
|
||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
|
||||
avatar = ""
|
||||
try:
|
||||
avatar = download_img(userinfo["avatar_url"])
|
||||
except Exception as e:
|
||||
stat_logger.exception(e)
|
||||
try:
|
||||
users = user_register({
|
||||
"access_token": session["access_token"],
|
||||
"email": userinfo["email"],
|
||||
"avatar": avatar,
|
||||
"nickname": userinfo["login"],
|
||||
"login_channel": login_channel,
|
||||
"last_login_time": get_format_time(),
|
||||
"is_superuser": False,
|
||||
})
|
||||
if not users: raise Exception('Register user failure.')
|
||||
if len(users) > 1: raise Exception('Same E-mail exist!')
|
||||
user = users[0]
|
||||
login_user(user)
|
||||
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
|
||||
except Exception as e:
|
||||
stat_logger.exception(e)
|
||||
return server_error_response(e)
|
||||
elif not request.json:
|
||||
login_user(users[0])
|
||||
return cors_reponse(data=users[0].to_json(), auth=users[0].get_id(), retmsg="Welcome back!")
|
||||
|
||||
password = request.json.get('password')
|
||||
try:
|
||||
password = decrypt(password)
|
||||
except:
|
||||
return get_json_result(data=False, retcode=RetCode.SERVER_ERROR, retmsg='Failed to decrypt password')
|
||||
|
||||
user = UserService.query_user(email, password)
|
||||
if user:
|
||||
response_data = user.to_json()
|
||||
user.access_token = get_uuid()
|
||||
login_user(user)
|
||||
user.save()
|
||||
msg = "Welcome back!"
|
||||
return cors_reponse(data=response_data, auth=user.get_id(), retmsg=msg)
|
||||
else:
|
||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Email and Password do not match!')
|
||||
|
||||
|
||||
@manager.route('/github_callback', methods=['GET'])
|
||||
def github_callback():
|
||||
try:
|
||||
import requests
|
||||
res = requests.post(GITHUB_OAUTH.get("url"), data={
|
||||
"client_id": GITHUB_OAUTH.get("client_id"),
|
||||
"client_secret": GITHUB_OAUTH.get("secret_key"),
|
||||
"code": request.args.get('code')
|
||||
},headers={"Accept": "application/json"})
|
||||
res = res.json()
|
||||
if "error" in res:
|
||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
|
||||
retmsg=res["error_description"])
|
||||
|
||||
if "user:email" not in res["scope"].split(","):
|
||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
|
||||
|
||||
session["access_token"] = res["access_token"]
|
||||
session["access_token_from"] = "github"
|
||||
return redirect(url_for("user.login"), code=307)
|
||||
|
||||
except Exception as e:
|
||||
stat_logger.exception(e)
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
def user_info_from_github(access_token):
|
||||
import requests
|
||||
headers = {"Accept": "application/json", 'Authorization': f"token {access_token}"}
|
||||
res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
|
||||
user_info = res.json()
|
||||
email_info = requests.get(f"https://api.github.com/user/emails?access_token={access_token}", headers=headers).json()
|
||||
user_info["email"] = next((email for email in email_info if email['primary'] == True), None)["email"]
|
||||
return user_info
|
||||
|
||||
|
||||
@manager.route("/logout", methods=['GET'])
|
||||
@login_required
|
||||
def log_out():
|
||||
current_user.access_token = ""
|
||||
current_user.save()
|
||||
logout_user()
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route("/setting", methods=["POST"])
|
||||
@login_required
|
||||
def setting_user():
|
||||
update_dict = {}
|
||||
request_data = request.json
|
||||
if request_data.get("password"):
|
||||
new_password = request_data.get("new_password")
|
||||
if not check_password_hash(current_user.password, decrypt(request_data["password"])):
|
||||
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!')
|
||||
|
||||
if new_password: update_dict["password"] = generate_password_hash(decrypt(new_password))
|
||||
|
||||
for k in request_data.keys():
|
||||
if k in ["password", "new_password"]:continue
|
||||
update_dict[k] = request_data[k]
|
||||
|
||||
try:
|
||||
UserService.update_by_id(current_user.id, update_dict)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
stat_logger.exception(e)
|
||||
return get_json_result(data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR)
|
||||
|
||||
|
||||
@manager.route("/info", methods=["GET"])
|
||||
@login_required
|
||||
def user_info():
|
||||
return get_json_result(data=current_user.to_dict())
|
||||
|
||||
|
||||
def user_register(user):
|
||||
user_id = get_uuid()
|
||||
user["id"] = user_id
|
||||
tenant = {
|
||||
"id": user_id,
|
||||
"name": user["nickname"] + "‘s Kingdom",
|
||||
"llm_id": CHAT_MDL,
|
||||
"embd_id": EMBEDDING_MDL,
|
||||
"asr_id": ASR_MDL,
|
||||
"parser_ids": PARSERS,
|
||||
"img2txt_id": IMAGE2TEXT_MDL
|
||||
}
|
||||
usr_tenant = {
|
||||
"tenant_id": user_id,
|
||||
"user_id": user_id,
|
||||
"invited_by": user_id,
|
||||
"role": UserTenantRole.OWNER
|
||||
}
|
||||
|
||||
if not UserService.save(**user):return
|
||||
TenantService.save(**tenant)
|
||||
UserTenantService.save(**usr_tenant)
|
||||
return UserService.query(email=user["email"])
|
||||
|
||||
|
||||
@manager.route("/register", methods=["POST"])
|
||||
@validate_request("nickname", "email", "password")
|
||||
def user_add():
|
||||
req = request.json
|
||||
if UserService.query(email=req["email"]):
|
||||
return get_json_result(data=False, retmsg=f'Email: {req["email"]} has already been registered!', retcode=RetCode.OPERATING_ERROR)
|
||||
|
||||
user_dict = {
|
||||
"access_token": get_uuid(),
|
||||
"email": req["email"],
|
||||
"nickname": req["nickname"],
|
||||
"password": decrypt(req["password"]),
|
||||
"login_channel": "password",
|
||||
"last_login_time": get_format_time(),
|
||||
"is_superuser": False,
|
||||
}
|
||||
try:
|
||||
users = user_register(user_dict)
|
||||
if not users: raise Exception('Register user failure.')
|
||||
if len(users) > 1: raise Exception('Same E-mail exist!')
|
||||
user = users[0]
|
||||
login_user(user)
|
||||
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
|
||||
except Exception as e:
|
||||
stat_logger.exception(e)
|
||||
return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
|
||||
|
||||
|
||||
|
||||
@manager.route("/tenant_info", methods=["GET"])
|
||||
@login_required
|
||||
def tenant_info():
|
||||
try:
|
||||
tenants = TenantService.get_by_user_id(current_user.id)
|
||||
return get_json_result(data=tenants)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
54
web_server/db/__init__.py
Normal file
@ -0,0 +1,54 @@
|
||||
#
|
||||
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from enum import Enum
|
||||
from enum import IntEnum
|
||||
from strenum import StrEnum
|
||||
|
||||
|
||||
class StatusEnum(Enum):
|
||||
VALID = "1"
|
||||
IN_VALID = "0"
|
||||
|
||||
|
||||
class UserTenantRole(StrEnum):
|
||||
OWNER = 'owner'
|
||||
ADMIN = 'admin'
|
||||
NORMAL = 'normal'
|
||||
|
||||
|
||||
class TenantPermission(StrEnum):
|
||||
ME = 'me'
|
||||
TEAM = 'team'
|
||||
|
||||
|
||||
class SerializedType(IntEnum):
|
||||
PICKLE = 1
|
||||
JSON = 2
|
||||
|
||||
|
||||
class FileType(StrEnum):
|
||||
PDF = 'pdf'
|
||||
DOC = 'doc'
|
||||
VISUAL = 'visual'
|
||||
AURAL = 'aural'
|
||||
VIRTUAL = 'virtual'
|
||||
|
||||
|
||||
class LLMType(StrEnum):
|
||||
CHAT = 'chat'
|
||||
EMBEDDING = 'embedding'
|
||||
SPEECH2TEXT = 'speech2text'
|
||||
IMAGE2TEXT = 'image2text'
|
616
web_server/db/db_models.py
Normal file
@ -0,0 +1,616 @@
|
||||
#
|
||||
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import typing
|
||||
import operator
|
||||
from functools import wraps
|
||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||
from flask_login import UserMixin
|
||||
|
||||
from peewee import (
|
||||
BigAutoField, BigIntegerField, BooleanField, CharField,
|
||||
CompositeKey, Insert, IntegerField, TextField, FloatField, DateTimeField,
|
||||
Field, Model, Metadata
|
||||
)
|
||||
from playhouse.pool import PooledMySQLDatabase
|
||||
|
||||
from web_server.db import SerializedType
|
||||
from web_server.settings import DATABASE, stat_logger, SECRET_KEY
|
||||
from web_server.utils.log_utils import getLogger
|
||||
from web_server import utils
|
||||
|
||||
LOGGER = getLogger()
|
||||
|
||||
|
||||
def singleton(cls, *args, **kw):
|
||||
instances = {}
|
||||
|
||||
def _singleton():
|
||||
key = str(cls) + str(os.getpid())
|
||||
if key not in instances:
|
||||
instances[key] = cls(*args, **kw)
|
||||
return instances[key]
|
||||
|
||||
return _singleton
|
||||
|
||||
|
||||
CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}
|
||||
AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {"create", "start", "end", "update", "read_access", "write_access"}
|
||||
|
||||
|
||||
class LongTextField(TextField):
|
||||
field_type = 'LONGTEXT'
|
||||
|
||||
|
||||
class JSONField(LongTextField):
|
||||
default_value = {}
|
||||
|
||||
def __init__(self, object_hook=None, object_pairs_hook=None, **kwargs):
|
||||
self._object_hook = object_hook
|
||||
self._object_pairs_hook = object_pairs_hook
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def db_value(self, value):
|
||||
if value is None:
|
||||
value = self.default_value
|
||||
return utils.json_dumps(value)
|
||||
|
||||
def python_value(self, value):
|
||||
if not value:
|
||||
return self.default_value
|
||||
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||
|
||||
|
||||
class ListField(JSONField):
|
||||
default_value = []
|
||||
|
||||
|
||||
class SerializedField(LongTextField):
|
||||
def __init__(self, serialized_type=SerializedType.PICKLE, object_hook=None, object_pairs_hook=None, **kwargs):
|
||||
self._serialized_type = serialized_type
|
||||
self._object_hook = object_hook
|
||||
self._object_pairs_hook = object_pairs_hook
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def db_value(self, value):
|
||||
if self._serialized_type == SerializedType.PICKLE:
|
||||
return utils.serialize_b64(value, to_str=True)
|
||||
elif self._serialized_type == SerializedType.JSON:
|
||||
if value is None:
|
||||
return None
|
||||
return utils.json_dumps(value, with_type=True)
|
||||
else:
|
||||
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
||||
|
||||
def python_value(self, value):
|
||||
if self._serialized_type == SerializedType.PICKLE:
|
||||
return utils.deserialize_b64(value)
|
||||
elif self._serialized_type == SerializedType.JSON:
|
||||
if value is None:
|
||||
return {}
|
||||
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||
else:
|
||||
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
||||
|
||||
|
||||
def is_continuous_field(cls: typing.Type) -> bool:
|
||||
if cls in CONTINUOUS_FIELD_TYPE:
|
||||
return True
|
||||
for p in cls.__bases__:
|
||||
if p in CONTINUOUS_FIELD_TYPE:
|
||||
return True
|
||||
elif p != Field and p != object:
|
||||
if is_continuous_field(p):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def auto_date_timestamp_field():
|
||||
return {f"{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
|
||||
|
||||
|
||||
def auto_date_timestamp_db_field():
|
||||
return {f"f_{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
|
||||
|
||||
|
||||
def remove_field_name_prefix(field_name):
|
||||
return field_name[2:] if field_name.startswith('f_') else field_name
|
||||
|
||||
|
||||
class BaseModel(Model):
|
||||
create_time = BigIntegerField(null=True)
|
||||
create_date = DateTimeField(null=True)
|
||||
update_time = BigIntegerField(null=True)
|
||||
update_date = DateTimeField(null=True)
|
||||
|
||||
def to_json(self):
|
||||
# This function is obsolete
|
||||
return self.to_dict()
|
||||
|
||||
def to_dict(self):
|
||||
return self.__dict__['__data__']
|
||||
|
||||
def to_human_model_dict(self, only_primary_with: list = None):
|
||||
model_dict = self.__dict__['__data__']
|
||||
|
||||
if not only_primary_with:
|
||||
return {remove_field_name_prefix(k): v for k, v in model_dict.items()}
|
||||
|
||||
human_model_dict = {}
|
||||
for k in self._meta.primary_key.field_names:
|
||||
human_model_dict[remove_field_name_prefix(k)] = model_dict[k]
|
||||
for k in only_primary_with:
|
||||
human_model_dict[k] = model_dict[f'f_{k}']
|
||||
return human_model_dict
|
||||
|
||||
@property
|
||||
def meta(self) -> Metadata:
|
||||
return self._meta
|
||||
|
||||
@classmethod
|
||||
def get_primary_keys_name(cls):
|
||||
return cls._meta.primary_key.field_names if isinstance(cls._meta.primary_key, CompositeKey) else [
|
||||
cls._meta.primary_key.name]
|
||||
|
||||
@classmethod
|
||||
def getter_by(cls, attr):
|
||||
return operator.attrgetter(attr)(cls)
|
||||
|
||||
@classmethod
|
||||
def query(cls, reverse=None, order_by=None, **kwargs):
|
||||
filters = []
|
||||
for f_n, f_v in kwargs.items():
|
||||
attr_name = '%s' % f_n
|
||||
if not hasattr(cls, attr_name) or f_v is None:
|
||||
continue
|
||||
if type(f_v) in {list, set}:
|
||||
f_v = list(f_v)
|
||||
if is_continuous_field(type(getattr(cls, attr_name))):
|
||||
if len(f_v) == 2:
|
||||
for i, v in enumerate(f_v):
|
||||
if isinstance(v, str) and f_n in auto_date_timestamp_field():
|
||||
# time type: %Y-%m-%d %H:%M:%S
|
||||
f_v[i] = utils.date_string_to_timestamp(v)
|
||||
lt_value = f_v[0]
|
||||
gt_value = f_v[1]
|
||||
if lt_value is not None and gt_value is not None:
|
||||
filters.append(cls.getter_by(attr_name).between(lt_value, gt_value))
|
||||
elif lt_value is not None:
|
||||
filters.append(operator.attrgetter(attr_name)(cls) >= lt_value)
|
||||
elif gt_value is not None:
|
||||
filters.append(operator.attrgetter(attr_name)(cls) <= gt_value)
|
||||
else:
|
||||
filters.append(operator.attrgetter(attr_name)(cls) << f_v)
|
||||
else:
|
||||
filters.append(operator.attrgetter(attr_name)(cls) == f_v)
|
||||
if filters:
|
||||
query_records = cls.select().where(*filters)
|
||||
if reverse is not None:
|
||||
if not order_by or not hasattr(cls, f"{order_by}"):
|
||||
order_by = "create_time"
|
||||
if reverse is True:
|
||||
query_records = query_records.order_by(cls.getter_by(f"{order_by}").desc())
|
||||
elif reverse is False:
|
||||
query_records = query_records.order_by(cls.getter_by(f"{order_by}").asc())
|
||||
return [query_record for query_record in query_records]
|
||||
else:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def insert(cls, __data=None, **insert):
|
||||
if isinstance(__data, dict) and __data:
|
||||
__data[cls._meta.combined["create_time"]] = utils.current_timestamp()
|
||||
if insert:
|
||||
insert["create_time"] = utils.current_timestamp()
|
||||
|
||||
return super().insert(__data, **insert)
|
||||
|
||||
# update and insert will call this method
|
||||
@classmethod
|
||||
def _normalize_data(cls, data, kwargs):
|
||||
normalized = super()._normalize_data(data, kwargs)
|
||||
if not normalized:
|
||||
return {}
|
||||
|
||||
normalized[cls._meta.combined["update_time"]] = utils.current_timestamp()
|
||||
|
||||
for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
|
||||
if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
|
||||
cls._meta.combined[f"{f_n}_time"] in normalized and \
|
||||
normalized[cls._meta.combined[f"{f_n}_time"]] is not None:
|
||||
normalized[cls._meta.combined[f"{f_n}_date"]] = utils.timestamp_to_date(
|
||||
normalized[cls._meta.combined[f"{f_n}_time"]])
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
class JsonSerializedField(SerializedField):
|
||||
def __init__(self, object_hook=utils.from_dict_hook, object_pairs_hook=None, **kwargs):
|
||||
super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
|
||||
object_pairs_hook=object_pairs_hook, **kwargs)
|
||||
|
||||
|
||||
@singleton
|
||||
class BaseDataBase:
|
||||
def __init__(self):
|
||||
database_config = DATABASE.copy()
|
||||
db_name = database_config.pop("name")
|
||||
self.database_connection = PooledMySQLDatabase(db_name, **database_config)
|
||||
stat_logger.info('init mysql database on cluster mode successfully')
|
||||
|
||||
|
||||
class DatabaseLock:
|
||||
def __init__(self, lock_name, timeout=10, db=None):
|
||||
self.lock_name = lock_name
|
||||
self.timeout = int(timeout)
|
||||
self.db = db if db else DB
|
||||
|
||||
def lock(self):
|
||||
# SQL parameters only support %s format placeholders
|
||||
cursor = self.db.execute_sql("SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
|
||||
ret = cursor.fetchone()
|
||||
if ret[0] == 0:
|
||||
raise Exception(f'acquire mysql lock {self.lock_name} timeout')
|
||||
elif ret[0] == 1:
|
||||
return True
|
||||
else:
|
||||
raise Exception(f'failed to acquire lock {self.lock_name}')
|
||||
|
||||
def unlock(self):
|
||||
cursor = self.db.execute_sql("SELECT RELEASE_LOCK(%s)", (self.lock_name,))
|
||||
ret = cursor.fetchone()
|
||||
if ret[0] == 0:
|
||||
raise Exception(f'mysql lock {self.lock_name} was not established by this thread')
|
||||
elif ret[0] == 1:
|
||||
return True
|
||||
else:
|
||||
raise Exception(f'mysql lock {self.lock_name} does not exist')
|
||||
|
||||
def __enter__(self):
|
||||
if isinstance(self.db, PooledMySQLDatabase):
|
||||
self.lock()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if isinstance(self.db, PooledMySQLDatabase):
|
||||
self.unlock()
|
||||
|
||||
def __call__(self, func):
|
||||
@wraps(func)
|
||||
def magic(*args, **kwargs):
|
||||
with self:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return magic
|
||||
|
||||
|
||||
DB = BaseDataBase().database_connection
|
||||
DB.lock = DatabaseLock
|
||||
|
||||
|
||||
def close_connection():
|
||||
try:
|
||||
if DB:
|
||||
DB.close()
|
||||
except Exception as e:
|
||||
LOGGER.exception(e)
|
||||
|
||||
|
||||
class DataBaseModel(BaseModel):
|
||||
class Meta:
|
||||
database = DB
|
||||
|
||||
|
||||
@DB.connection_context()
|
||||
def init_database_tables():
|
||||
members = inspect.getmembers(sys.modules[__name__], inspect.isclass)
|
||||
table_objs = []
|
||||
create_failed_list = []
|
||||
for name, obj in members:
|
||||
if obj != DataBaseModel and issubclass(obj, DataBaseModel):
|
||||
table_objs.append(obj)
|
||||
LOGGER.info(f"start create table {obj.__name__}")
|
||||
try:
|
||||
obj.create_table()
|
||||
LOGGER.info(f"create table success: {obj.__name__}")
|
||||
except Exception as e:
|
||||
LOGGER.exception(e)
|
||||
create_failed_list.append(obj.__name__)
|
||||
if create_failed_list:
|
||||
LOGGER.info(f"create tables failed: {create_failed_list}")
|
||||
raise Exception(f"create tables failed: {create_failed_list}")
|
||||
|
||||
|
||||
def fill_db_model_object(model_object, human_model_dict):
|
||||
for k, v in human_model_dict.items():
|
||||
attr_name = '%s' % k
|
||||
if hasattr(model_object.__class__, attr_name):
|
||||
setattr(model_object, attr_name, v)
|
||||
return model_object
|
||||
|
||||
|
||||
class User(DataBaseModel, UserMixin):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
access_token = CharField(max_length=255, null=True)
|
||||
nickname = CharField(max_length=100, null=False, help_text="nickname")
|
||||
password = CharField(max_length=255, null=True, help_text="password")
|
||||
email = CharField(max_length=255, null=False, help_text="email", index=True)
|
||||
avatar = TextField(null=True, help_text="avatar base64 string")
|
||||
language = CharField(max_length=32, null=True, help_text="English|Chinese", default="Chinese")
|
||||
color_schema = CharField(max_length=32, null=True, help_text="Bright|Dark", default="Dark")
|
||||
last_login_time = DateTimeField(null=True)
|
||||
is_authenticated = CharField(max_length=1, null=False, default="1")
|
||||
is_active = CharField(max_length=1, null=False, default="1")
|
||||
is_anonymous = CharField(max_length=1, null=False, default="0")
|
||||
login_channel = CharField(null=True, help_text="from which user login")
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
is_superuser = BooleanField(null=True, help_text="is root", default=False)
|
||||
|
||||
def __str__(self):
|
||||
return self.email
|
||||
|
||||
def get_id(self):
|
||||
jwt = Serializer(secret_key=SECRET_KEY)
|
||||
return jwt.dumps(str(self.access_token))
|
||||
|
||||
class Meta:
|
||||
db_table = "user"
|
||||
|
||||
|
||||
class Tenant(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
name = CharField(max_length=100, null=True, help_text="Tenant name")
|
||||
public_key = CharField(max_length=255, null=True)
|
||||
llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
|
||||
embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
|
||||
asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID")
|
||||
img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID")
|
||||
parser_ids = CharField(max_length=128, null=False, help_text="default document parser IDs")
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
|
||||
class Meta:
|
||||
db_table = "tenant"
|
||||
|
||||
|
||||
class UserTenant(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
user_id = CharField(max_length=32, null=False)
|
||||
tenant_id = CharField(max_length=32, null=False)
|
||||
role = CharField(max_length=32, null=False, help_text="UserTenantRole")
|
||||
invited_by = CharField(max_length=32, null=False)
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
|
||||
class Meta:
|
||||
db_table = "user_tenant"
|
||||
|
||||
|
||||
class InvitationCode(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
code = CharField(max_length=32, null=False)
|
||||
visit_time = DateTimeField(null=True)
|
||||
user_id = CharField(max_length=32, null=True)
|
||||
tenant_id = CharField(max_length=32, null=True)
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
|
||||
class Meta:
|
||||
db_table = "invitation_code"
|
||||
|
||||
|
||||
class LLMFactories(DataBaseModel):
|
||||
name = CharField(max_length=128, null=False, help_text="LLM factory name", primary_key=True)
|
||||
logo = TextField(null=True, help_text="llm logo base64")
|
||||
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
class Meta:
|
||||
db_table = "llm_factories"
|
||||
|
||||
|
||||
class LLM(DataBaseModel):
|
||||
# default LLMs for every user
|
||||
llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
|
||||
fid = CharField(max_length=128, null=False, help_text="LLM factory id")
|
||||
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
|
||||
def __str__(self):
|
||||
return self.llm_name
|
||||
|
||||
class Meta:
|
||||
db_table = "llm"
|
||||
|
||||
|
||||
class TenantLLM(DataBaseModel):
|
||||
tenant_id = CharField(max_length=32, null=False)
|
||||
llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
|
||||
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
|
||||
llm_name = CharField(max_length=128, null=False, help_text="LLM name")
|
||||
api_key = CharField(max_length=255, null=True, help_text="API KEY")
|
||||
api_base = CharField(max_length=255, null=True, help_text="API Base")
|
||||
|
||||
def __str__(self):
|
||||
return self.llm_name
|
||||
|
||||
class Meta:
|
||||
db_table = "tenant_llm"
|
||||
primary_key = CompositeKey('tenant_id', 'llm_factory')
|
||||
|
||||
|
||||
class Knowledgebase(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
avatar = TextField(null=True, help_text="avatar base64 string")
|
||||
tenant_id = CharField(max_length=32, null=False)
|
||||
name = CharField(max_length=128, null=False, help_text="KB name", index=True)
|
||||
description = TextField(null=True, help_text="KB description")
|
||||
permission = CharField(max_length=16, null=False, help_text="me|team")
|
||||
created_by = CharField(max_length=32, null=False)
|
||||
doc_num = IntegerField(default=0)
|
||||
embd_id = CharField(max_length=32, null=False, help_text="default embedding model ID")
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
class Meta:
|
||||
db_table = "knowledgebase"
|
||||
|
||||
|
||||
class Document(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
thumbnail = TextField(null=True, help_text="thumbnail base64 string")
|
||||
kb_id = CharField(max_length=256, null=False, index=True)
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
|
||||
source_type = CharField(max_length=128, null=False, default="local", help_text="where does this document come from")
|
||||
type = CharField(max_length=32, null=False, help_text="file extension")
|
||||
created_by = CharField(max_length=32, null=False, help_text="who created it")
|
||||
name = CharField(max_length=255, null=True, help_text="file name", index=True)
|
||||
location = CharField(max_length=255, null=True, help_text="where is it stored")
|
||||
size = IntegerField(default=0)
|
||||
token_num = IntegerField(default=0)
|
||||
chunk_num = IntegerField(default=0)
|
||||
progress = FloatField(default=0)
|
||||
progress_msg = CharField(max_length=255, null=True, help_text="process message", default="")
|
||||
process_begin_at = DateTimeField(null=True)
|
||||
process_duation = FloatField(default=0)
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
|
||||
class Meta:
|
||||
db_table = "document"
|
||||
|
||||
|
||||
class Dialog(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
tenant_id = CharField(max_length=32, null=False)
|
||||
name = CharField(max_length=255, null=True, help_text="dialog application name")
|
||||
description = TextField(null=True, help_text="Dialog description")
|
||||
icon = CharField(max_length=16, null=False, help_text="dialog icon")
|
||||
language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese")
|
||||
llm_id = CharField(max_length=32, null=False, help_text="default llm ID")
|
||||
llm_setting_type = CharField(max_length=8, null=False, help_text="Creative|Precise|Evenly|Custom",
|
||||
default="Creative")
|
||||
llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
|
||||
"presence_penalty": 0.4, "max_tokens": 215})
|
||||
prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced")
|
||||
prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?",
|
||||
"parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
|
||||
class Meta:
|
||||
db_table = "dialog"
|
||||
|
||||
|
||||
class DialogKb(DataBaseModel):
|
||||
dialog_id = CharField(max_length=32, null=False, index=True)
|
||||
kb_id = CharField(max_length=32, null=False)
|
||||
|
||||
class Meta:
|
||||
db_table = "dialog_kb"
|
||||
primary_key = CompositeKey('dialog_id', 'kb_id')
|
||||
|
||||
|
||||
class Conversation(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
dialog_id = CharField(max_length=32, null=False, index=True)
|
||||
name = CharField(max_length=255, null=True, help_text="converastion name")
|
||||
message = JSONField(null=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "conversation"
|
||||
|
||||
|
||||
"""
|
||||
class Job(DataBaseModel):
|
||||
# multi-party common configuration
|
||||
f_user_id = CharField(max_length=25, null=True)
|
||||
f_job_id = CharField(max_length=25, index=True)
|
||||
f_name = CharField(max_length=500, null=True, default='')
|
||||
f_description = TextField(null=True, default='')
|
||||
f_tag = CharField(max_length=50, null=True, default='')
|
||||
f_dsl = JSONField()
|
||||
f_runtime_conf = JSONField()
|
||||
f_runtime_conf_on_party = JSONField()
|
||||
f_train_runtime_conf = JSONField(null=True)
|
||||
f_roles = JSONField()
|
||||
f_initiator_role = CharField(max_length=50)
|
||||
f_initiator_party_id = CharField(max_length=50)
|
||||
f_status = CharField(max_length=50)
|
||||
f_status_code = IntegerField(null=True)
|
||||
f_user = JSONField()
|
||||
# this party configuration
|
||||
f_role = CharField(max_length=50, index=True)
|
||||
f_party_id = CharField(max_length=10, index=True)
|
||||
f_is_initiator = BooleanField(null=True, default=False)
|
||||
f_progress = IntegerField(null=True, default=0)
|
||||
f_ready_signal = BooleanField(default=False)
|
||||
f_ready_time = BigIntegerField(null=True)
|
||||
f_cancel_signal = BooleanField(default=False)
|
||||
f_cancel_time = BigIntegerField(null=True)
|
||||
f_rerun_signal = BooleanField(default=False)
|
||||
f_end_scheduling_updates = IntegerField(null=True, default=0)
|
||||
|
||||
f_engine_name = CharField(max_length=50, null=True)
|
||||
f_engine_type = CharField(max_length=10, null=True)
|
||||
f_cores = IntegerField(default=0)
|
||||
f_memory = IntegerField(default=0) # MB
|
||||
f_remaining_cores = IntegerField(default=0)
|
||||
f_remaining_memory = IntegerField(default=0) # MB
|
||||
f_resource_in_use = BooleanField(default=False)
|
||||
f_apply_resource_time = BigIntegerField(null=True)
|
||||
f_return_resource_time = BigIntegerField(null=True)
|
||||
|
||||
f_inheritance_info = JSONField(null=True)
|
||||
f_inheritance_status = CharField(max_length=50, null=True)
|
||||
|
||||
f_start_time = BigIntegerField(null=True)
|
||||
f_start_date = DateTimeField(null=True)
|
||||
f_end_time = BigIntegerField(null=True)
|
||||
f_end_date = DateTimeField(null=True)
|
||||
f_elapsed = BigIntegerField(null=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "t_job"
|
||||
primary_key = CompositeKey('f_job_id', 'f_role', 'f_party_id')
|
||||
|
||||
|
||||
|
||||
class PipelineComponentMeta(DataBaseModel):
|
||||
f_model_id = CharField(max_length=100, index=True)
|
||||
f_model_version = CharField(max_length=100, index=True)
|
||||
f_role = CharField(max_length=50, index=True)
|
||||
f_party_id = CharField(max_length=10, index=True)
|
||||
f_component_name = CharField(max_length=100, index=True)
|
||||
f_component_module_name = CharField(max_length=100)
|
||||
f_model_alias = CharField(max_length=100, index=True)
|
||||
f_model_proto_index = JSONField(null=True)
|
||||
f_run_parameters = JSONField(null=True)
|
||||
f_archive_sha256 = CharField(max_length=100, null=True)
|
||||
f_archive_from_ip = CharField(max_length=100, null=True)
|
||||
|
||||
class Meta:
|
||||
db_table = 't_pipeline_component_meta'
|
||||
indexes = (
|
||||
(('f_model_id', 'f_model_version', 'f_role', 'f_party_id', 'f_component_name'), True),
|
||||
)
|
||||
|
||||
|
||||
"""
|
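The classes above are plain Peewee models, so the usual Peewee query API applies once they are imported from web_server/db/db_models.py. A minimal sketch of how a caller might read them back (the knowledge-base name and printed fields are made up for illustration; this is not code from the commit):

from web_server.db.db_models import DB, Knowledgebase, Document

with DB.connection_context():
    # hypothetical lookup: list the documents of one knowledge base
    kb = Knowledgebase.get_or_none(Knowledgebase.name == "demo-kb")
    if kb:
        docs = (Document
                .select(Document.name, Document.chunk_num)
                .where(Document.kb_id == kb.id, Document.status == "1"))
        for d in docs:
            print(d.name, d.chunk_num)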
157 web_server/db/db_services.py Normal file
@ -0,0 +1,157 @@
#
# Copyright 2021 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import abc
import json
import time
from functools import wraps
from shortuuid import ShortUUID

from web_server.versions import get_fate_version

from web_server.errors.error_services import *
from web_server.settings import (
    GRPC_PORT, HOST, HTTP_PORT,
    RANDOM_INSTANCE_ID, stat_logger,
)


instance_id = ShortUUID().random(length=8) if RANDOM_INSTANCE_ID else f'flow-{HOST}-{HTTP_PORT}'
server_instance = (
    f'{HOST}:{GRPC_PORT}',
    json.dumps({
        'instance_id': instance_id,
        'timestamp': round(time.time() * 1000),
        'version': get_fate_version() or '',
        'host': HOST,
        'grpc_port': GRPC_PORT,
        'http_port': HTTP_PORT,
    }),
)


def check_service_supported(method):
    """Decorator to check whether `service_name` is supported.
    The attribute `supported_services` MUST be defined in the class.
    The first and second arguments of `method` MUST be `self` and `service_name`.

    :param Callable method: The class method.
    :return: The inner wrapper function.
    :rtype: Callable
    """
    @wraps(method)
    def magic(self, service_name, *args, **kwargs):
        if service_name not in self.supported_services:
            raise ServiceNotSupported(service_name=service_name)
        return method(self, service_name, *args, **kwargs)
    return magic


class ServicesDB(abc.ABC):
    """Database for storing service urls.
    Abstract base class for the real backends.
    """
    @property
    @abc.abstractmethod
    def supported_services(self):
        """The names of supported services.
        The returned list SHOULD contain `fateflow` (model download) and `servings` (FATE-Serving).

        :return: The service names.
        :rtype: list
        """
        pass

    @abc.abstractmethod
    def _get_serving(self):
        pass

    def get_serving(self):
        try:
            return self._get_serving()
        except ServicesError as e:
            stat_logger.exception(e)
            return []

    @abc.abstractmethod
    def _insert(self, service_name, service_url, value=''):
        pass

    @check_service_supported
    def insert(self, service_name, service_url, value=''):
        """Insert a service url into the database.

        :param str service_name: The service name.
        :param str service_url: The service url.
        :return: None
        """
        try:
            self._insert(service_name, service_url, value)
        except ServicesError as e:
            stat_logger.exception(e)

    @abc.abstractmethod
    def _delete(self, service_name, service_url):
        pass

    @check_service_supported
    def delete(self, service_name, service_url):
        """Delete a service url from the database.

        :param str service_name: The service name.
        :param str service_url: The service url.
        :return: None
        """
        try:
            self._delete(service_name, service_url)
        except ServicesError as e:
            stat_logger.exception(e)

    def register_flow(self):
        """Call `self.insert` to insert the flow server address into the database.

        :return: None
        """
        self.insert('flow-server', *server_instance)

    def unregister_flow(self):
        """Call `self.delete` to delete the flow server address from the database.

        :return: None
        """
        self.delete('flow-server', server_instance[0])

    @abc.abstractmethod
    def _get_urls(self, service_name, with_values=False):
        pass

    @check_service_supported
    def get_urls(self, service_name, with_values=False):
        """Query service urls from the database. The urls may belong to other nodes.
        Currently, only `fateflow` (model download) urls and `servings` (FATE-Serving) urls are supported.
        `fateflow` is a url containing scheme, host, port and path,
        while `servings` only contains host and port.

        :param str service_name: The service name.
        :return: The service urls.
        :rtype: list
        """
        try:
            return self._get_urls(service_name, with_values)
        except ServicesError as e:
            stat_logger.exception(e)
            return []
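ServicesDB only defines the contract; a backend supplies the storage. A minimal in-memory sketch of a concrete backend (the class name and dict storage are hypothetical and only for illustration; a real backend would persist the urls somewhere durable):

class InMemoryServicesDB(ServicesDB):
    """Hypothetical backend that keeps service urls in a dict."""

    def __init__(self):
        self._urls = {}

    @property
    def supported_services(self):
        return ['fateflow', 'servings', 'flow-server']

    def _insert(self, service_name, service_url, value=''):
        self._urls.setdefault(service_name, {})[service_url] = value

    def _delete(self, service_name, service_url):
        self._urls.get(service_name, {}).pop(service_url, None)

    def _get_urls(self, service_name, with_values=False):
        urls = self._urls.get(service_name, {})
        return list(urls.items()) if with_values else list(urls)

    def _get_serving(self):
        return self._get_urls('servings')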
131 web_server/db/db_utils.py Normal file
@ -0,0 +1,131 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import operator
from functools import reduce
from typing import Dict, Type, Union

from web_server.utils import current_timestamp, timestamp_to_date

from web_server.db.db_models import DB, DataBaseModel
from web_server.db.runtime_config import RuntimeConfig
from web_server.utils.log_utils import getLogger
from enum import Enum


LOGGER = getLogger()


@DB.connection_context()
def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
    DB.create_tables([model])

    current_time = current_timestamp()
    current_date = timestamp_to_date(current_time)

    for data in data_source:
        if 'f_create_time' not in data:
            data['f_create_time'] = current_time
        data['f_create_date'] = timestamp_to_date(data['f_create_time'])
        data['f_update_time'] = current_time
        data['f_update_date'] = current_date

    preserve = tuple(data_source[0].keys() - {'f_create_time', 'f_create_date'})

    batch_size = 50 if RuntimeConfig.USE_LOCAL_DATABASE else 1000

    for i in range(0, len(data_source), batch_size):
        with DB.atomic():
            query = model.insert_many(data_source[i:i + batch_size])
            if replace_on_conflict:
                query = query.on_conflict(preserve=preserve)
            query.execute()


def get_dynamic_db_model(base, job_id):
    return type(base.model(table_index=get_dynamic_tracking_table_index(job_id=job_id)))


def get_dynamic_tracking_table_index(job_id):
    return job_id[:8]


def fill_db_model_object(model_object, human_model_dict):
    for k, v in human_model_dict.items():
        attr_name = 'f_%s' % k
        if hasattr(model_object.__class__, attr_name):
            setattr(model_object, attr_name, v)
    return model_object


# https://docs.peewee-orm.com/en/latest/peewee/query_operators.html
supported_operators = {
    '==': operator.eq,
    '<': operator.lt,
    '<=': operator.le,
    '>': operator.gt,
    '>=': operator.ge,
    '!=': operator.ne,
    '<<': operator.lshift,
    '>>': operator.rshift,
    '%': operator.mod,
    '**': operator.pow,
    '^': operator.xor,
    '~': operator.inv,
}


def query_dict2expression(model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]):
    expression = []

    for field, value in query.items():
        if not isinstance(value, (list, tuple)):
            value = ('==', value)
        op, *val = value

        field = getattr(model, f'f_{field}')
        value = supported_operators[op](field, val[0]) if op in supported_operators else getattr(field, op)(*val)
        expression.append(value)

    return reduce(operator.iand, expression)


def query_db(model: Type[DataBaseModel], limit: int = 0, offset: int = 0,
             query: dict = None, order_by: Union[str, list, tuple] = None):
    data = model.select()
    if query:
        data = data.where(query_dict2expression(model, query))
    count = data.count()

    if not order_by:
        order_by = 'create_time'
    if not isinstance(order_by, (list, tuple)):
        order_by = (order_by, 'asc')
    order_by, order = order_by
    order_by = getattr(model, f'f_{order_by}')
    order_by = getattr(order_by, order)()
    data = data.order_by(order_by)

    if limit > 0:
        data = data.limit(limit)
    if offset > 0:
        data = data.offset(offset)

    return list(data), count


class StatusEnum(Enum):
    # sample availability status
    VALID = "1"
    IN_VALID = "0"
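query_dict2expression turns a plain dict into a combined Peewee where-clause, so callers can filter on f_-prefixed columns without building expressions by hand, and query_db adds ordering and paging on top of it. A small usage sketch under stated assumptions (Task is a hypothetical model with f_status, f_party_id and f_create_time columns; the values are made up):

with DB.connection_context():
    tasks, total = query_db(
        Task,
        limit=20,
        offset=0,
        query={
            "status": "running",             # shorthand for ("==", "running")
            "party_id": ("!=", "9999"),      # explicit operator from supported_operators
            "create_time": ("in_", [1, 2]),  # unknown op falls back to the Peewee field method `in_`
        },
        order_by=("create_time", "desc"),
    )
    print(total, [t.f_status for t in tasks])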
141 web_server/db/init_data.py Normal file
@ -0,0 +1,141 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import uuid

from web_server.db import LLMType
from web_server.db.db_models import init_database_tables as init_web_db
from web_server.db.services import UserService
from web_server.db.services.llm_service import LLMFactoriesService, LLMService


def init_superuser():
    user_info = {
        "id": uuid.uuid1().hex,
        "password": "admin",
        "nickname": "admin",
        "is_superuser": True,
        "email": "kai.hu@infiniflow.org",
        "creator": "system",
        "status": "1",
    }
    UserService.save(**user_info)


def init_llm_factory():
    factory_infos = [{
        "name": "OpenAI",
        "logo": "",
        "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
        "status": "1",
    }, {
        "name": "通义千问",
        "logo": "",
        "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
        "status": "1",
    }, {
        "name": "智普AI",
        "logo": "",
        "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
        "status": "1",
    }, {
        "name": "文心一言",
        "logo": "",
        "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
        "status": "1",
    },
    ]
    llm_infos = [{
        "fid": factory_infos[0]["name"],
        "llm_name": "gpt-3.5-turbo",
        "tags": "LLM,CHAT,4K",
        "model_type": LLMType.CHAT.value
    }, {
        "fid": factory_infos[0]["name"],
        "llm_name": "gpt-3.5-turbo-16k-0613",
        "tags": "LLM,CHAT,16k",
        "model_type": LLMType.CHAT.value
    }, {
        "fid": factory_infos[0]["name"],
        "llm_name": "text-embedding-ada-002",
        "tags": "TEXT EMBEDDING,8K",
        "model_type": LLMType.EMBEDDING.value
    }, {
        "fid": factory_infos[0]["name"],
        "llm_name": "whisper-1",
        "tags": "SPEECH2TEXT",
        "model_type": LLMType.SPEECH2TEXT.value
    }, {
        "fid": factory_infos[0]["name"],
        "llm_name": "gpt-4",
        "tags": "LLM,CHAT,8K",
        "model_type": LLMType.CHAT.value
    }, {
        "fid": factory_infos[0]["name"],
        "llm_name": "gpt-4-32k",
        "tags": "LLM,CHAT,32K",
        "model_type": LLMType.CHAT.value
    }, {
        "fid": factory_infos[0]["name"],
        "llm_name": "gpt-4-vision-preview",
        "tags": "LLM,CHAT,IMAGE2TEXT",
        "model_type": LLMType.IMAGE2TEXT.value
    }, {
        "fid": factory_infos[1]["name"],
        "llm_name": "qwen-turbo",
        "tags": "LLM,CHAT,8K",
        "model_type": LLMType.CHAT.value
    }, {
        "fid": factory_infos[1]["name"],
        "llm_name": "qwen-plus",
        "tags": "LLM,CHAT,32K",
        "model_type": LLMType.CHAT.value
    }, {
        "fid": factory_infos[1]["name"],
        "llm_name": "text-embedding-v2",
        "tags": "TEXT EMBEDDING,2K",
        "model_type": LLMType.EMBEDDING.value
    }, {
        "fid": factory_infos[1]["name"],
        "llm_name": "paraformer-realtime-8k-v1",
        "tags": "SPEECH2TEXT",
        "model_type": LLMType.SPEECH2TEXT.value
    }, {
        "fid": factory_infos[1]["name"],
        "llm_name": "qwen_vl_chat_v1",
        "tags": "LLM,CHAT,IMAGE2TEXT",
        "model_type": LLMType.IMAGE2TEXT.value
    },
    ]
    for info in factory_infos:
        LLMFactoriesService.save(**info)
    for info in llm_infos:
        LLMService.save(**info)


def init_web_data():
    start_time = time.time()
    if not UserService.get_all().count():
        init_superuser()

    if not LLMService.get_all().count():
        init_llm_factory()

    print("init web data success:{}".format(time.time() - start_time))


if __name__ == '__main__':
    init_web_db()
    init_web_data()
21 web_server/db/operatioins.py Normal file
@ -0,0 +1,21 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import operator
import time
import typing
from web_server.utils.log_utils import sql_logger
import peewee
27 web_server/db/reload_config_base.py Normal file
@ -0,0 +1,27 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class ReloadConfigBase:
    @classmethod
    def get_all(cls):
        configs = {}
        for k, v in cls.__dict__.items():
            if not callable(getattr(cls, k)) and not k.startswith("__") and not k.startswith("_"):
                configs[k] = v
        return configs

    @classmethod
    def get(cls, config_name):
        return getattr(cls, config_name) if hasattr(cls, config_name) else None
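ReloadConfigBase turns plain class attributes into a readable config mapping. A quick sketch of the intended usage (the subclass name and values are hypothetical, for illustration only):

class JobDefaultConfig(ReloadConfigBase):
    # plain class attributes act as config entries
    remote_request_timeout = 30
    task_cores = 4

print(JobDefaultConfig.get_all())          # {'remote_request_timeout': 30, 'task_cores': 4}
print(JobDefaultConfig.get("task_cores"))  # 4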
54 web_server/db/runtime_config.py Normal file
@ -0,0 +1,54 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from web_server.versions import get_versions
from .reload_config_base import ReloadConfigBase


class RuntimeConfig(ReloadConfigBase):
    DEBUG = None
    WORK_MODE = None
    HTTP_PORT = None
    JOB_SERVER_HOST = None
    JOB_SERVER_VIP = None
    ENV = dict()
    SERVICE_DB = None
    LOAD_CONFIG_MANAGER = False

    @classmethod
    def init_config(cls, **kwargs):
        for k, v in kwargs.items():
            if hasattr(cls, k):
                setattr(cls, k, v)

    @classmethod
    def init_env(cls):
        cls.ENV.update(get_versions())

    @classmethod
    def load_config_manager(cls):
        cls.LOAD_CONFIG_MANAGER = True

    @classmethod
    def get_env(cls, key):
        return cls.ENV.get(key, None)

    @classmethod
    def get_all_env(cls):
        return cls.ENV

    @classmethod
    def set_service_db(cls, service_db):
        cls.SERVICE_DB = service_db
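RuntimeConfig is a process-wide settings holder: init_config only overwrites attributes that already exist on the class, so unknown keys are silently dropped. A short sketch (the port and host values are made up):

RuntimeConfig.init_config(HTTP_PORT=9380, JOB_SERVER_HOST="127.0.0.1")
RuntimeConfig.init_config(UNKNOWN_KEY="ignored")  # no such attribute, so nothing is set

print(RuntimeConfig.HTTP_PORT)  # 9380
print(RuntimeConfig.get_all())  # config dict, via ReloadConfigBase.get_all()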
164 web_server/db/service_registry.py Normal file
@ -0,0 +1,164 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import socket
from pathlib import Path
from web_server import utils
from .db_models import DB, ServiceRegistryInfo, ServerRegistryInfo
from .reload_config_base import ReloadConfigBase


class ServiceRegistry(ReloadConfigBase):
    @classmethod
    @DB.connection_context()
    def load_service(cls, **kwargs) -> [ServiceRegistryInfo]:
        service_registry_list = ServiceRegistryInfo.query(**kwargs)
        return [service for service in service_registry_list]

    @classmethod
    @DB.connection_context()
    def save_service_info(cls, server_name, service_name, uri, method="POST", server_info=None, params=None, data=None, headers=None, protocol="http"):
        if not server_info:
            server_list = ServerRegistry.query_server_info_from_db(server_name=server_name)
            if not server_list:
                raise Exception(f"server {server_name} not found")
            server_info = server_list[0]
            url = f"{server_info.f_protocol}://{server_info.f_host}:{server_info.f_port}{uri}"
        else:
            url = f"{server_info.get('protocol', protocol)}://{server_info.get('host')}:{server_info.get('port')}{uri}"
        service_info = {
            "f_server_name": server_name,
            "f_service_name": service_name,
            "f_url": url,
            "f_method": method,
            "f_params": params if params else {},
            "f_data": data if data else {},
            "f_headers": headers if headers else {}
        }
        entity_model, status = ServiceRegistryInfo.get_or_create(
            f_server_name=server_name,
            f_service_name=service_name,
            defaults=service_info)
        if status is False:
            for key in service_info:
                setattr(entity_model, key, service_info[key])
            entity_model.save(force_insert=False)


class ServerRegistry(ReloadConfigBase):
    FATEBOARD = None
    FATE_ON_STANDALONE = None
    FATE_ON_EGGROLL = None
    FATE_ON_SPARK = None
    MODEL_STORE_ADDRESS = None
    SERVINGS = None
    FATEMANAGER = None
    STUDIO = None

    @classmethod
    def load(cls):
        cls.load_server_info_from_conf()
        cls.load_server_info_from_db()

    @classmethod
    def load_server_info_from_conf(cls):
        path = Path(utils.file_utils.get_project_base_directory()) / 'conf' / utils.SERVICE_CONF
        conf = utils.file_utils.load_yaml_conf(path)
        if not isinstance(conf, dict):
            raise ValueError('invalid config file')

        local_path = path.with_name(f'local.{utils.SERVICE_CONF}')
        if local_path.exists():
            local_conf = utils.file_utils.load_yaml_conf(local_path)
            if not isinstance(local_conf, dict):
                raise ValueError('invalid local config file')
            conf.update(local_conf)
        for k, v in conf.items():
            if isinstance(v, dict):
                setattr(cls, k.upper(), v)

    @classmethod
    def register(cls, server_name, server_info):
        cls.save_server_info_to_db(server_name, server_info.get("host"), server_info.get("port"), protocol=server_info.get("protocol", "http"))
        setattr(cls, server_name, server_info)

    @classmethod
    def save(cls, service_config):
        update_server = {}
        for server_name, server_info in service_config.items():
            cls.parameter_check(server_info)
            api_info = server_info.pop("api", {})
            for service_name, info in api_info.items():
                ServiceRegistry.save_service_info(server_name, service_name, uri=info.get('uri'), method=info.get('method', 'POST'), server_info=server_info)
            cls.save_server_info_to_db(server_name, server_info.get("host"), server_info.get("port"), protocol="http")
            setattr(cls, server_name.upper(), server_info)
        return update_server

    @classmethod
    def parameter_check(cls, service_info):
        if "host" in service_info and "port" in service_info:
            cls.connection_test(service_info.get("host"), service_info.get("port"))

    @classmethod
    def connection_test(cls, ip, port):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        result = s.connect_ex((ip, port))
        if result != 0:
            raise ConnectionRefusedError(f"connection refused: host {ip}, port {port}")

    @classmethod
    def query(cls, service_name, default=None):
        service_info = getattr(cls, service_name, default)
        if not service_info:
            service_info = utils.get_base_config(service_name, default)
        return service_info

    @classmethod
    @DB.connection_context()
    def query_server_info_from_db(cls, server_name=None) -> [ServerRegistryInfo]:
        if server_name:
            server_list = ServerRegistryInfo.select().where(ServerRegistryInfo.f_server_name == server_name.upper())
        else:
            server_list = ServerRegistryInfo.select()
        return [server for server in server_list]

    @classmethod
    @DB.connection_context()
    def load_server_info_from_db(cls):
        for server in cls.query_server_info_from_db():
            server_info = {
                "host": server.f_host,
                "port": server.f_port,
                "protocol": server.f_protocol
            }
            setattr(cls, server.f_server_name.upper(), server_info)

    @classmethod
    @DB.connection_context()
    def save_server_info_to_db(cls, server_name, host, port, protocol="http"):
        server_info = {
            "f_server_name": server_name,
            "f_host": host,
            "f_port": port,
            "f_protocol": protocol
        }
        entity_model, status = ServerRegistryInfo.get_or_create(
            f_server_name=server_name,
            defaults=server_info)
        if status is False:
            for key in server_info:
                setattr(entity_model, key, server_info[key])
            entity_model.save(force_insert=False)
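ServerRegistry.save expects a mapping of server names to connection info, optionally with an "api" section whose entries are persisted through ServiceRegistry.save_service_info. A hedged sketch of the expected shape (the server name, port and uri are made up; note that parameter_check probes the socket, so the host/port must actually be reachable for this to succeed):

service_config = {
    "fateboard": {
        "host": "127.0.0.1",
        "port": 8080,
        "api": {
            "version": {"uri": "/v1/version", "method": "GET"},
        },
    },
}
ServerRegistry.save(service_config)
print(ServerRegistry.query("FATEBOARD"))  # {'host': '127.0.0.1', 'port': 8080}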
38 web_server/db/services/__init__.py Normal file
@ -0,0 +1,38 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pathlib
import re
from .user_service import UserService


def duplicate_name(query_func, **kwargs):
    fnm = kwargs["name"]
    objs = query_func(**kwargs)
    if not objs:
        return fnm
    ext = pathlib.Path(fnm).suffix  # e.g. ".jpg"
    nm = re.sub(r"%s$" % ext, "", fnm)
    r = re.search(r"\(([0-9]+)\)$", nm)
    c = 0
    if r:
        c = int(r.group(1))
        nm = re.sub(r"\([0-9]+\)$", "", nm)
    c += 1
    nm = f"{nm}({c})"
    if ext:
        nm += f"{ext}"

    kwargs["name"] = nm
    return duplicate_name(query_func, **kwargs)
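duplicate_name keeps appending or incrementing a "(n)" suffix until query_func reports no name clash. A small self-contained sketch with a stubbed query function (the stub and the sample names are made up; in the application the query_func would be a service lookup against the database):

existing = {"report.pdf", "report(1).pdf"}

def fake_query(**kwargs):
    # stand-in for a DB lookup; a truthy result means the name is taken
    return kwargs["name"] in existing

print(duplicate_name(fake_query, name="report.pdf"))  # -> "report(2).pdf"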
153 web_server/db/services/common_service.py Normal file
@ -0,0 +1,153 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime

import peewee

from web_server.db.db_models import DB
from web_server.utils import datetime_format


class CommonService:
    model = None

    @classmethod
    @DB.connection_context()
    def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
        return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)

    @classmethod
    @DB.connection_context()
    def get_all(cls, cols=None, reverse=None, order_by=None):
        if cols:
            query_records = cls.model.select(*cols)
        else:
            query_records = cls.model.select()
        if reverse is not None:
            if not order_by or not hasattr(cls, order_by):
                order_by = "create_time"
            if reverse is True:
                query_records = query_records.order_by(cls.model.getter_by(order_by).desc())
            elif reverse is False:
                query_records = query_records.order_by(cls.model.getter_by(order_by).asc())
        return query_records

    @classmethod
    @DB.connection_context()
    def get(cls, **kwargs):
        return cls.model.get(**kwargs)

    @classmethod
    @DB.connection_context()
    def get_or_none(cls, **kwargs):
        try:
            return cls.model.get(**kwargs)
        except peewee.DoesNotExist:
            return None

    @classmethod
    @DB.connection_context()
    def save(cls, **kwargs):
        # if "id" not in kwargs:
        #     kwargs["id"] = get_uuid()
        sample_obj = cls.model(**kwargs).save(force_insert=True)
        return sample_obj

    @classmethod
    @DB.connection_context()
    def insert_many(cls, data_list, batch_size=100):
        with DB.atomic():
            for i in range(0, len(data_list), batch_size):
                cls.model.insert_many(data_list[i:i + batch_size]).execute()

    @classmethod
    @DB.connection_context()
    def update_many_by_id(cls, data_list):
        cur = datetime_format(datetime.now())
        with DB.atomic():
            for data in data_list:
                data["update_time"] = cur
                cls.model.update(data).where(cls.model.id == data["id"]).execute()

    @classmethod
    @DB.connection_context()
    def update_by_id(cls, pid, data):
        data["update_time"] = datetime_format(datetime.now())
        num = cls.model.update(data).where(cls.model.id == pid).execute()
        return num

    @classmethod
    @DB.connection_context()
    def get_by_id(cls, pid):
        try:
            obj = cls.model.query(id=pid)[0]
            return True, obj
        except Exception as e:
            return False, None

    @classmethod
    @DB.connection_context()
    def get_by_ids(cls, pids, cols=None):
        if cols:
            objs = cls.model.select(*cols)
        else:
            objs = cls.model.select()
        return objs.where(cls.model.id.in_(pids))

    @classmethod
    @DB.connection_context()
    def delete_by_id(cls, pid):
        return cls.model.delete().where(cls.model.id == pid).execute()

    @classmethod
    @DB.connection_context()
    def filter_delete(cls, filters):
        with DB.atomic():
            num = cls.model.delete().where(*filters).execute()
            return num

    @classmethod
    @DB.connection_context()
    def filter_update(cls, filters, update_data):
        with DB.atomic():
            cls.model.update(update_data).where(*filters).execute()

    @staticmethod
    def cut_list(tar_list, n):
        length = len(tar_list)
        arr = range(length)
        result = [tuple(tar_list[x:(x + n)]) for x in arr[::n]]
        return result

    @classmethod
    @DB.connection_context()
    def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None):
        in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
        if not filters:
            filters = []
        res_list = []
        if cols:
            for i in in_filters_tuple_list:
                query_records = cls.model.select(*cols).where(getattr(cls.model, in_key).in_(i), *filters)
                if query_records:
                    res_list.extend([query_record for query_record in query_records])
        else:
            for i in in_filters_tuple_list:
                query_records = cls.model.select().where(getattr(cls.model, in_key).in_(i), *filters)
                if query_records:
                    res_list.extend([query_record for query_record in query_records])
        return res_list
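Concrete services only need to bind `model`; the CRUD helpers come from CommonService. A minimal hedged sketch (the service class and field values below are illustrative; the actual per-model service modules are not part of this diff):

from web_server.db.db_models import Knowledgebase
from web_server.db.services.common_service import CommonService


class KnowledgebaseService(CommonService):
    model = Knowledgebase


# usage sketch: create one record, update it, then read it back
KnowledgebaseService.save(id="kb-0001", tenant_id="t-0001", name="demo-kb",
                          permission="me", created_by="u-0001",
                          embd_id="text-embedding-ada-002", parser_id="general")
KnowledgebaseService.update_by_id("kb-0001", {"description": "demo knowledge base"})
ok, kb = KnowledgebaseService.get_by_id("kb-0001")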
Some files were not shown because too many files have changed in this diff.