Mirror of https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git (synced 2025-04-24 07:00:02 +08:00)

build python version rag-flow (#21)

* clean rust version project
* clean rust version project
* build python version rag-flow

This commit is contained in:
parent db8cae3f1e
commit 30791976d5
@@ -1,9 +0,0 @@
# Database
HOST=127.0.0.1
PORT=8000
DATABASE_URL="postgresql://infiniflow:infiniflow@localhost/docgpt"

# S3 Storage
MINIO_HOST="127.0.0.1:9000"
MINIO_USR="infiniflow"
MINIO_PWD="infiniflow_docgpt"
42 Cargo.toml
@@ -1,42 +0,0 @@
[package]
name = "doc_gpt"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
actix-web = "4.3.1"
actix-rt = "2.8.0"
actix-files = "0.6.2"
actix-multipart = "0.4"
actix-session = { version = "0.5" }
actix-identity = { version = "0.4" }
actix-web-httpauth = { version = "0.6" }
actix-ws = "0.2.5"
uuid = { version = "1.6.1", features = [
    "v4",
    "fast-rng",
    "macro-diagnostics",
] }
thiserror = "1.0"
postgres = "0.19.7"
sea-orm = { version = "0.12.9", features = ["sqlx-postgres", "runtime-tokio-native-tls", "macros"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1.0"
tracing-subscriber = "0.3.18"
dotenvy = "0.15.7"
listenfd = "1.0.1"
chrono = "0.4.31"
migration = { path = "./migration" }
minio = "0.1.0"
futures-util = "0.3.29"
actix-multipart-extract = "0.1.5"
regex = "1.10.2"
tokio = { version = "1.35.1", features = ["rt", "time", "macros"] }

[[bin]]
name = "doc_gpt"

[workspace]
members = [".", "migration"]
0 python/conf/mapping.json → conf/mapping.json (Executable file → Normal file)
30 conf/private.pem Normal file
@@ -0,0 +1,30 @@
-----BEGIN RSA PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: DES-EDE3-CBC,EFF8327C41E531AD

7jdPFDAA6fiTzOIU7XGzKuT324JKZEcK5vBRJqBkA5XO6ENN1wLdhh3zQbl1Ejfv
KMSUIgbtQEJB4bvOzS//okbZa1vCNYuTS/NGcpKUnhqdOmAL3hl/kOtOLLjTZrwo
3KX8iujLH7wQ64GxArtpUuaFq1k0whN1BB5RGJp3IO/L6pMpSWVRKO+JPUrD1Ujr
XA/LUKQJaZtXVUVOYPtIwbyqPsh93QBetJnRwwV3gNOwGpcX2jDpyTxDUkLJCPPg
6Hw0pwlQEd8A11sjxCBbASwLeJO1L0w69QiX9chyOkZ+sfDsVpPt/wf1NexA7Cdj
9uifJ4JGbby39QD6mInZGtnRzQRdafjuXlBR2I0Qa7fBRu8QsfhmLbWZfWno7j08
4bAAoqB1vRNfSu8LVJXdEEh/HKuwu11pgRr5eH8WQ3hJg+Y2k7zDHpp1VaHL7/Kn
S+aN5bhQ4Xt0Ujdi1+rsmNchnF6LWsDezHWJeWUM6X7dJnqIBl8oCyghbghT8Tyw
aEKWXc2+7FsP5yd0NfG3PFYOLdLgfI43pHTAv5PEQ47w9r1XOwfblKKBUDEzaput
T3t5wQ6wxdyhRxeO4arCHfe/i+j3fzvhlwgbuwrmrkWGWSS86eMTaoGM8+uUrHv0
6TbU0tj6DKKUslVk1dCHh9TnmNsXZuLJkceZF38PSKNxhzudU8OTtzhS0tFL91HX
vo7N+XdiGMs8oOSpjE6RPlhFhVAKGJpXwBj/vXLLcmzesA7ZB2kYtFKMIdsUQpls
PE/4K5PEX2d8pxA5zxo0HleA1YjW8i5WEcDQThZQzj2sWvg06zSjenVFrbCm9Bro
hFpAB/3zJHxdRN2MpNpvK35WITy1aDUdX1WdyrlcRtIE5ssFTSoxSj9ibbDZ78+z
gtbw/MUi6vU6Yz1EjvoYu/bmZAHt9Aagcxw6k58fjO2cEB9njK7xbbiZUSwpJhEe
U/PxK+SdOU/MmGKeqdgqSfhJkq0vhacvsEjFGRAfivSCHkL0UjhObU+rSJ3g1RMO
oukAev6TOAwbTKVWjg3/EX+pl/zorAgaPNYFX64TSH4lE3VjeWApITb9Z5C/sVxR
xW6hU9qyjzWYWY+91y16nkw1l7VQvWHUZwV7QzTScC2BOzDVpeqY1KiYJxgoo6sX
ZCqR5oh4vToG4W8ZrRyauwUaZJ3r+zhAgm+6n6TJQNwFEl0muji+1nPl32EiFsRs
qR6CtuhUOVQM4VnILDwFJfuGYRFtKzQgvseLNU4ZqAVqQj8l4ARGAP2P1Au/uUKy
oGzI7a+b5MvRHuvkxPAclOgXgX/8yyOLaBg+mgaqv9h2JIJD28PzouFl3BajRaVB
7GWTnROJYhX5SuX/g585SLRKoQUtK0WhdJCjTRfyRJPwfdppgdTbWO99R4G+ir02
JQdSkZf2vmZRXenPNTEPDOUY6nVN6sUuBjmtOwoUF194ODgpYB6IaHqK08sa1pUh
1mZyxitHdPbygePTe20XWMZFoK2knAqN0JPPbbNjCqiVV+7oqQAnkDIutspu9t2m
ny3jefFmNozbblQMghLUrq+x9wOEgvS76Sqvq3DG/2BkLzJF3MNkvw==
-----END RSA PRIVATE KEY-----
9 conf/public.pem Normal file
@@ -0,0 +1,9 @@
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/
z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp
2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOO
UEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVK
RNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK
6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs
2wIDAQAB
-----END PUBLIC KEY-----
28 conf/service_conf.yaml Normal file
@@ -0,0 +1,28 @@
authentication:
  client:
    switch: false
    http_app_key:
    http_secret_key:
  site:
    switch: false
permission:
  switch: false
  component: false
  dataset: false
ragflow:
  # you must set real ip address, 127.0.0.1 and 0.0.0.0 is not supported
  host: 127.0.0.1
  http_port: 9380
database:
  name: 'rag_flow'
  user: 'root'
  passwd: 'infini_rag_flow'
  host: '123.60.95.134'
  port: 5455
  max_connections: 100
  stale_timeout: 30
oauth:
  github:
    client_id: 302129228f0d96055bee
    secret_key: e518e55ccfcdfcae8996afc40f110e9c95f14fc4
    url: https://github.com/login/oauth/access_token
21 docker/.env
@@ -1,21 +0,0 @@
# Version of Elastic products
STACK_VERSION=8.11.3

# Set the cluster name
CLUSTER_NAME=docgpt

# Port to expose Elasticsearch HTTP API to the host
ES_PORT=9200

# Port to expose Kibana to the host
KIBANA_PORT=6601

# Increase or decrease based on the available host memory (in bytes)
MEM_LIMIT=4073741824

POSTGRES_USER=root
POSTGRES_PASSWORD=infiniflow_docgpt
POSTGRES_DB=docgpt

MINIO_USER=infiniflow
MINIO_PASSWORD=infiniflow_docgpt
@@ -1,7 +1,7 @@
 version: '2.2'
 services:
   es01:
-    container_name: docgpt-es-01
+    container_name: ragflow-es-01
     image: docker.elastic.co/elasticsearch/elasticsearch:${STACK_VERSION}
     volumes:
       - esdata01:/usr/share/elasticsearch/data
@@ -20,14 +20,14 @@ services:
         soft: -1
         hard: -1
     networks:
-      - docgpt
+      - ragflow
     restart: always

   kibana:
     depends_on:
       - es01
     image: docker.elastic.co/kibana/kibana:${STACK_VERSION}
-    container_name: docgpt-kibana
+    container_name: ragflow-kibana
     volumes:
       - kibanadata:/usr/share/kibana/data
     ports:
@@ -37,26 +37,39 @@ services:
       - ELASTICSEARCH_HOSTS=http://es01:9200
     mem_limit: ${MEM_LIMIT}
     networks:
-      - docgpt
+      - ragflow

-  postgres:
-    image: postgres
-    container_name: docgpt-postgres
+  mysql:
+    image: mysql:5.7.18
+    container_name: ragflow-mysql
     environment:
-      - POSTGRES_USER=${POSTGRES_USER}
-      - POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
-      - POSTGRES_DB=${POSTGRES_DB}
+      - MYSQL_ROOT_PASSWORD=${MYSQL_PASSWORD}
+      - TZ="Asia/Shanghai"
+    command:
+      --max_connections=1000
+      --character-set-server=utf8mb4
+      --collation-server=utf8mb4_general_ci
+      --default-authentication-plugin=mysql_native_password
+      --tls_version="TLSv1.2,TLSv1.3"
+      --init-file /data/application/init.sql
     ports:
-      - 5455:5432
+      - ${MYSQL_PORT}:3306
     volumes:
-      - pg_data:/var/lib/postgresql/data
+      - mysql_data:/var/lib/mysql
+      - ./init.sql:/data/application/init.sql
     networks:
-      - docgpt
+      - ragflow
+    healthcheck:
+      test: [ "CMD-SHELL", "curl --silent localhost:3306 >/dev/null || exit 1" ]
+      interval: 10s
+      timeout: 10s
+      retries: 3
     restart: always

   minio:
     image: quay.io/minio/minio:RELEASE.2023-12-20T01-00-02Z
-    container_name: docgpt-minio
+    container_name: ragflow-minio
     command: server --console-address ":9001" /data
     ports:
       - 9000:9000
@@ -67,7 +80,7 @@ services:
     volumes:
       - minio_data:/data
     networks:
-      - docgpt
+      - ragflow
     restart: always

@@ -76,11 +89,11 @@ volumes:
     driver: local
   kibanadata:
     driver: local
-  pg_data:
+  mysql_data:
     driver: local
   minio_data:
     driver: local

 networks:
-  docgpt:
+  ragflow:
     driver: bridge
2 docker/init.sql Normal file
@@ -0,0 +1,2 @@
CREATE DATABASE IF NOT EXISTS rag_flow;
USE rag_flow;
@@ -1,20 +0,0 @@
[package]
name = "migration"
version = "0.1.0"
edition = "2021"
publish = false

[lib]
name = "migration"
path = "src/lib.rs"

[dependencies]
async-std = { version = "1", features = ["attributes", "tokio1"] }
chrono = "0.4.31"

[dependencies.sea-orm-migration]
version = "0.12.0"
features = [
    "runtime-tokio-rustls", # `ASYNC_RUNTIME` feature
    "sqlx-postgres",        # `DATABASE_DRIVER` feature
]
@@ -1,41 +0,0 @@
# Running Migrator CLI

- Generate a new migration file
```sh
cargo run -- generate MIGRATION_NAME
```
- Apply all pending migrations
```sh
cargo run
```
```sh
cargo run -- up
```
- Apply first 10 pending migrations
```sh
cargo run -- up -n 10
```
- Rollback last applied migrations
```sh
cargo run -- down
```
- Rollback last 10 applied migrations
```sh
cargo run -- down -n 10
```
- Drop all tables from the database, then reapply all migrations
```sh
cargo run -- fresh
```
- Rollback all applied migrations, then reapply all migrations
```sh
cargo run -- refresh
```
- Rollback all applied migrations
```sh
cargo run -- reset
```
- Check the status of all migrations
```sh
cargo run -- status
```
@@ -1,12 +0,0 @@
pub use sea_orm_migration::prelude::*;

mod m20220101_000001_create_table;

pub struct Migrator;

#[async_trait::async_trait]
impl MigratorTrait for Migrator {
    fn migrations() -> Vec<Box<dyn MigrationTrait>> {
        vec![Box::new(m20220101_000001_create_table::Migration)]
    }
}
@@ -1,440 +0,0 @@
use sea_orm_migration::prelude::*;
use chrono::{ FixedOffset, Utc };

#[allow(dead_code)]
fn now() -> chrono::DateTime<FixedOffset> {
    Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
}
#[derive(DeriveMigrationName)]
pub struct Migration;

#[async_trait::async_trait]
impl MigrationTrait for Migration {
    async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
        manager.create_table(
            Table::create()
                .table(UserInfo::Table)
                .if_not_exists()
                .col(
                    ColumnDef::new(UserInfo::Uid)
                        .big_integer()
                        .not_null()
                        .auto_increment()
                        .primary_key()
                )
                .col(ColumnDef::new(UserInfo::Email).string().not_null())
                .col(ColumnDef::new(UserInfo::Nickname).string().not_null())
                .col(ColumnDef::new(UserInfo::AvatarBase64).string())
                .col(ColumnDef::new(UserInfo::ColorScheme).string().default("dark"))
                .col(ColumnDef::new(UserInfo::ListStyle).string().default("list"))
                .col(ColumnDef::new(UserInfo::Language).string().default("chinese"))
                .col(ColumnDef::new(UserInfo::Password).string().not_null())
                .col(
                    ColumnDef::new(UserInfo::LastLoginAt)
                        .timestamp_with_time_zone()
                        .default(Expr::current_timestamp())
                )
                .col(
                    ColumnDef::new(UserInfo::CreatedAt)
                        .timestamp_with_time_zone()
                        .default(Expr::current_timestamp())
                        .not_null()
                )
                .col(
                    ColumnDef::new(UserInfo::UpdatedAt)
                        .timestamp_with_time_zone()
                        .default(Expr::current_timestamp())
                        .not_null()
                )
                .col(ColumnDef::new(UserInfo::IsDeleted).boolean().default(false))
                .to_owned()
        ).await?;

        manager.create_table(
            Table::create()
                .table(TagInfo::Table)
                .if_not_exists()
                .col(
                    ColumnDef::new(TagInfo::Tid)
                        .big_integer()
                        .not_null()
                        .auto_increment()
                        .primary_key()
                )
                .col(ColumnDef::new(TagInfo::Uid).big_integer().not_null())
                .col(ColumnDef::new(TagInfo::TagName).string().not_null())
                .col(ColumnDef::new(TagInfo::Regx).string())
                .col(ColumnDef::new(TagInfo::Color).tiny_unsigned().default(1))
                .col(ColumnDef::new(TagInfo::Icon).tiny_unsigned().default(1))
                .col(ColumnDef::new(TagInfo::FolderId).big_integer())
                .col(
                    ColumnDef::new(TagInfo::CreatedAt)
                        .timestamp_with_time_zone()
                        .default(Expr::current_timestamp())
                        .not_null()
                )
                .col(
                    ColumnDef::new(TagInfo::UpdatedAt)
                        .timestamp_with_time_zone()
                        .default(Expr::current_timestamp())
                        .not_null()
                )
                .col(ColumnDef::new(TagInfo::IsDeleted).boolean().default(false))
                .to_owned()
        ).await?;

        manager.create_table(
            Table::create()
                .table(Tag2Doc::Table)
                .if_not_exists()
                .col(
                    ColumnDef::new(Tag2Doc::Id)
                        .big_integer()
                        .not_null()
                        .auto_increment()
                        .primary_key()
                )
                .col(ColumnDef::new(Tag2Doc::TagId).big_integer())
                .col(ColumnDef::new(Tag2Doc::Did).big_integer())
                .to_owned()
        ).await?;

        manager.create_table(
            Table::create()
                .table(Kb2Doc::Table)
                .if_not_exists()
                .col(
                    ColumnDef::new(Kb2Doc::Id)
                        .big_integer()
                        .not_null()
                        .auto_increment()
                        .primary_key()
                )
                .col(ColumnDef::new(Kb2Doc::KbId).big_integer())
                .col(ColumnDef::new(Kb2Doc::Did).big_integer())
                .col(ColumnDef::new(Kb2Doc::KbProgress).float().default(0))
                .col(ColumnDef::new(Kb2Doc::KbProgressMsg).string().default(""))
                .col(
                    ColumnDef::new(Kb2Doc::UpdatedAt)
                        .timestamp_with_time_zone()
                        .default(Expr::current_timestamp())
                        .not_null()
                )
                .col(ColumnDef::new(Kb2Doc::IsDeleted).boolean().default(false))
                .to_owned()
        ).await?;

        manager.create_table(
            Table::create()
                .table(Dialog2Kb::Table)
                .if_not_exists()
                .col(
                    ColumnDef::new(Dialog2Kb::Id)
                        .big_integer()
                        .not_null()
                        .auto_increment()
                        .primary_key()
                )
                .col(ColumnDef::new(Dialog2Kb::DialogId).big_integer())
                .col(ColumnDef::new(Dialog2Kb::KbId).big_integer())
                .to_owned()
        ).await?;

        manager.create_table(
            Table::create()
                .table(Doc2Doc::Table)
                .if_not_exists()
                .col(
                    ColumnDef::new(Doc2Doc::Id)
                        .big_integer()
                        .not_null()
                        .auto_increment()
                        .primary_key()
                )
                .col(ColumnDef::new(Doc2Doc::ParentId).big_integer())
                .col(ColumnDef::new(Doc2Doc::Did).big_integer())
                .to_owned()
        ).await?;

        manager.create_table(
            Table::create()
                .table(KbInfo::Table)
                .if_not_exists()
                .col(
                    ColumnDef::new(KbInfo::KbId)
                        .big_integer()
                        .auto_increment()
                        .not_null()
                        .primary_key()
                )
                .col(ColumnDef::new(KbInfo::Uid).big_integer().not_null())
                .col(ColumnDef::new(KbInfo::KbName).string().not_null())
                .col(ColumnDef::new(KbInfo::Icon).tiny_unsigned().default(1))
                .col(
                    ColumnDef::new(KbInfo::CreatedAt)
                        .timestamp_with_time_zone()
                        .default(Expr::current_timestamp())
                        .not_null()
                )
                .col(
                    ColumnDef::new(KbInfo::UpdatedAt)
                        .timestamp_with_time_zone()
                        .default(Expr::current_timestamp())
                        .not_null()
                )
                .col(ColumnDef::new(KbInfo::IsDeleted).boolean().default(false))
                .to_owned()
        ).await?;

        manager.create_table(
            Table::create()
                .table(DocInfo::Table)
                .if_not_exists()
                .col(
                    ColumnDef::new(DocInfo::Did)
                        .big_integer()
                        .not_null()
                        .auto_increment()
                        .primary_key()
                )
                .col(ColumnDef::new(DocInfo::Uid).big_integer().not_null())
                .col(ColumnDef::new(DocInfo::DocName).string().not_null())
                .col(ColumnDef::new(DocInfo::Location).string().not_null())
                .col(ColumnDef::new(DocInfo::Size).big_integer().not_null())
                .col(ColumnDef::new(DocInfo::Type).string().not_null())
                .col(ColumnDef::new(DocInfo::ThumbnailBase64).string().default(""))
                .comment("doc type|folder")
                .col(
                    ColumnDef::new(DocInfo::CreatedAt)
                        .timestamp_with_time_zone()
                        .default(Expr::current_timestamp())
                        .not_null()
                )
                .col(
                    ColumnDef::new(DocInfo::UpdatedAt)
                        .timestamp_with_time_zone()
                        .default(Expr::current_timestamp())
                        .not_null()
                )
                .col(ColumnDef::new(DocInfo::IsDeleted).boolean().default(false))
                .to_owned()
        ).await?;

        manager.create_table(
            Table::create()
                .table(DialogInfo::Table)
                .if_not_exists()
                .col(
                    ColumnDef::new(DialogInfo::DialogId)
                        .big_integer()
                        .not_null()
                        .auto_increment()
                        .primary_key()
                )
                .col(ColumnDef::new(DialogInfo::Uid).big_integer().not_null())
                .col(ColumnDef::new(DialogInfo::KbId).big_integer().not_null())
                .col(ColumnDef::new(DialogInfo::DialogName).string().not_null())
                .col(ColumnDef::new(DialogInfo::History).string().comment("json"))
                .col(
                    ColumnDef::new(DialogInfo::CreatedAt)
                        .timestamp_with_time_zone()
                        .default(Expr::current_timestamp())
                        .not_null()
                )
                .col(
                    ColumnDef::new(DialogInfo::UpdatedAt)
                        .timestamp_with_time_zone()
                        .default(Expr::current_timestamp())
                        .not_null()
                )
                .col(ColumnDef::new(DialogInfo::IsDeleted).boolean().default(false))
                .to_owned()
        ).await?;

        let root_insert = Query::insert()
            .into_table(UserInfo::Table)
            .columns([UserInfo::Email, UserInfo::Nickname, UserInfo::Password])
            .values_panic(["kai.hu@infiniflow.org".into(), "root".into(), "123456".into()])
            .to_owned();

        let doc_insert = Query::insert()
            .into_table(DocInfo::Table)
            .columns([
                DocInfo::Uid,
                DocInfo::DocName,
                DocInfo::Size,
                DocInfo::Type,
                DocInfo::Location,
            ])
            .values_panic([(1).into(), "/".into(), (0).into(), "folder".into(), "".into()])
            .to_owned();

        let tag_insert = Query::insert()
            .into_table(TagInfo::Table)
            .columns([TagInfo::Uid, TagInfo::TagName, TagInfo::Regx, TagInfo::Color, TagInfo::Icon])
            .values_panic([
                (1).into(),
                "Video".into(),
                ".*\\.(mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)".into(),
                (1).into(),
                (1).into(),
            ])
            .values_panic([
                (1).into(),
                "Picture".into(),
                ".*\\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)".into(),
                (2).into(),
                (2).into(),
            ])
            .values_panic([
                (1).into(),
                "Music".into(),
                ".*\\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)".into(),
                (3).into(),
                (3).into(),
            ])
            .values_panic([
                (1).into(),
                "Document".into(),
                ".*\\.(pdf|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)".into(),
                (3).into(),
                (3).into(),
            ])
            .to_owned();

        manager.exec_stmt(root_insert).await?;
        manager.exec_stmt(doc_insert).await?;
        manager.exec_stmt(tag_insert).await?;
        Ok(())
    }

    async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
        manager.drop_table(Table::drop().table(UserInfo::Table).to_owned()).await?;

        manager.drop_table(Table::drop().table(TagInfo::Table).to_owned()).await?;

        manager.drop_table(Table::drop().table(Tag2Doc::Table).to_owned()).await?;

        manager.drop_table(Table::drop().table(Kb2Doc::Table).to_owned()).await?;

        manager.drop_table(Table::drop().table(Dialog2Kb::Table).to_owned()).await?;

        manager.drop_table(Table::drop().table(Doc2Doc::Table).to_owned()).await?;

        manager.drop_table(Table::drop().table(KbInfo::Table).to_owned()).await?;

        manager.drop_table(Table::drop().table(DocInfo::Table).to_owned()).await?;

        manager.drop_table(Table::drop().table(DialogInfo::Table).to_owned()).await?;

        Ok(())
    }
}

#[derive(DeriveIden)]
enum UserInfo {
    Table,
    Uid,
    Email,
    Nickname,
    AvatarBase64,
    ColorScheme,
    ListStyle,
    Language,
    Password,
    LastLoginAt,
    CreatedAt,
    UpdatedAt,
    IsDeleted,
}

#[derive(DeriveIden)]
enum TagInfo {
    Table,
    Tid,
    Uid,
    TagName,
    Regx,
    Color,
    Icon,
    FolderId,
    CreatedAt,
    UpdatedAt,
    IsDeleted,
}

#[derive(DeriveIden)]
enum Tag2Doc {
    Table,
    Id,
    TagId,
    Did,
}

#[derive(DeriveIden)]
enum Kb2Doc {
    Table,
    Id,
    KbId,
    Did,
    KbProgress,
    KbProgressMsg,
    UpdatedAt,
    IsDeleted,
}

#[derive(DeriveIden)]
enum Dialog2Kb {
    Table,
    Id,
    DialogId,
    KbId,
}

#[derive(DeriveIden)]
enum Doc2Doc {
    Table,
    Id,
    ParentId,
    Did,
}

#[derive(DeriveIden)]
enum KbInfo {
    Table,
    KbId,
    Uid,
    KbName,
    Icon,
    CreatedAt,
    UpdatedAt,
    IsDeleted,
}

#[derive(DeriveIden)]
enum DocInfo {
    Table,
    Did,
    Uid,
    DocName,
    Location,
    Size,
    Type,
    ThumbnailBase64,
    CreatedAt,
    UpdatedAt,
    IsDeleted,
}

#[derive(DeriveIden)]
enum DialogInfo {
    Table,
    Uid,
    KbId,
    DialogId,
    DialogName,
    History,
    CreatedAt,
    UpdatedAt,
    IsDeleted,
}
@@ -1,6 +0,0 @@
use sea_orm_migration::prelude::*;

#[async_std::main]
async fn main() {
    cli::run_cli(migration::Migrator).await;
}
29 python/Dockerfile Normal file
@@ -0,0 +1,29 @@
FROM ubuntu:22.04 as base

RUN apt-get update

ENV TZ="Asia/Taipei"
RUN apt-get install -yq \
    build-essential \
    curl \
    libncursesw5-dev \
    libssl-dev \
    libsqlite3-dev \
    libgdbm-dev \
    libc6-dev \
    libbz2-dev \
    software-properties-common \
    python3.11 python3.11-dev python3-pip

RUN apt-get install -yq git
RUN pip3 config set global.index-url https://mirror.baidu.com/pypi/simple
RUN pip3 config set global.trusted-host mirror.baidu.com
RUN pip3 install --upgrade pip
RUN pip3 install torch==2.0.1
RUN pip3 install torch-model-archiver==0.8.2
RUN pip3 install torchvision==0.15.2
COPY requirements.txt .

WORKDIR /docgpt
ENV PYTHONPATH=/docgpt/
@@ -1,22 +0,0 @@

```shell

docker pull postgres

LOCAL_POSTGRES_DATA=./postgres-data

docker run
--name docass-postgres
-p 5455:5432
-v $LOCAL_POSTGRES_DATA:/var/lib/postgresql/data
-e POSTGRES_USER=root
-e POSTGRES_PASSWORD=infiniflow_docass
-e POSTGRES_DB=docass
-d
postgres

docker network create elastic
docker pull elasticsearch:8.11.3;
docker pull docker.elastic.co/kibana/kibana:8.11.3

```
63 python/] Normal file
@@ -0,0 +1,63 @@
from abc import ABC
from openai import OpenAI
import os
import base64
from io import BytesIO

class Base(ABC):
    def describe(self, image, max_tokens=300):
        raise NotImplementedError("Please implement encode method!")


class GptV4(Base):
    def __init__(self):
        import openapi
        openapi.api_key = os.environ["OPENAPI_KEY"]
        self.client = OpenAI()

    def describe(self, image, max_tokens=300):
        buffered = BytesIO()
        try:
            image.save(buffered, format="JPEG")
        except Exception as e:
            image.save(buffered, format="PNG")
        b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        res = self.client.chat.completions.create(
            model="gpt-4-vision-preview",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等。",
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{b64}"
                            },
                        },
                    ],
                }
            ],
            max_tokens=max_tokens,
        )
        return res.choices[0].message.content.strip()


class QWen(Base):
    def chat(self, system, history, gen_conf):
        from http import HTTPStatus
        from dashscope import Generation
        from dashscope.api_entities.dashscope_response import Role
        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
        response = Generation.call(
            Generation.Models.qwen_turbo,
            messages=messages,
            result_format='message'
        )
        if response.status_code == HTTPStatus.OK:
            return response.output.choices[0]['message']['content']
        return response.message
@@ -1,41 +0,0 @@
{
    "version":1,
    "disable_existing_loggers":false,
    "formatters":{
        "simple":{
            "format":"%(asctime)s - %(name)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s"
        }
    },
    "handlers":{
        "console":{
            "class":"logging.StreamHandler",
            "level":"DEBUG",
            "formatter":"simple",
            "stream":"ext://sys.stdout"
        },
        "info_file_handler":{
            "class":"logging.handlers.TimedRotatingFileHandler",
            "level":"INFO",
            "formatter":"simple",
            "filename":"log/info.log",
            "when": "MIDNIGHT",
            "interval":1,
            "backupCount":30,
            "encoding":"utf8"
        },
        "error_file_handler":{
            "class":"logging.handlers.TimedRotatingFileHandler",
            "level":"ERROR",
            "formatter":"simple",
            "filename":"log/errors.log",
            "when": "MIDNIGHT",
            "interval":1,
            "backupCount":30,
            "encoding":"utf8"
        }
    },
    "root":{
        "level":"DEBUG",
        "handlers":["console","info_file_handler","error_file_handler"]
    }
}
@@ -1,9 +0,0 @@
[infiniflow]
es=http://es01:9200
postgres_user=root
postgres_password=infiniflow_docgpt
postgres_host=postgres
postgres_port=5432
minio_host=minio:9000
minio_user=infiniflow
minio_password=infiniflow_docgpt
@@ -1,21 +0,0 @@
import os
from .embedding_model import *
from .chat_model import *
from .cv_model import *

EmbeddingModel = None
ChatModel = None
CvModel = None


if os.environ.get("OPENAI_API_KEY"):
    EmbeddingModel = GptEmbed()
    ChatModel = GptTurbo()
    CvModel = GptV4()

elif os.environ.get("DASHSCOPE_API_KEY"):
    EmbeddingModel = QWenEmbd()
    ChatModel = QWenChat()
    CvModel = QWenCV()
else:
    EmbeddingModel = HuEmbedding()
@@ -1,61 +0,0 @@
from abc import ABC
from openai import OpenAI
from FlagEmbedding import FlagModel
import torch
import os
import numpy as np


class Base(ABC):
    def encode(self, texts: list, batch_size=32):
        raise NotImplementedError("Please implement encode method!")


class HuEmbedding(Base):
    def __init__(self):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!

        For Linux:
        export HF_ENDPOINT=https://hf-mirror.com

        For Windows:
        Good luck
        ^_-

        """
        self.model = FlagModel("BAAI/bge-large-zh-v1.5",
                               query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                               use_fp16=torch.cuda.is_available())

    def encode(self, texts: list, batch_size=32):
        res = []
        for i in range(0, len(texts), batch_size):
            res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
        return np.array(res)


class GptEmbed(Base):
    def __init__(self):
        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

    def encode(self, texts: list, batch_size=32):
        res = self.client.embeddings.create(input=texts,
                                            model="text-embedding-ada-002")
        return [d["embedding"] for d in res["data"]]


class QWenEmbd(Base):
    def encode(self, texts: list, batch_size=32, text_type="document"):
        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
        import dashscope
        from http import HTTPStatus
        res = []
        for txt in texts:
            resp = dashscope.TextEmbedding.call(
                model=dashscope.TextEmbedding.Models.text_embedding_v2,
                input=txt[:2048],
                text_type=text_type
            )
            res.append(resp["output"]["embeddings"][0]["embedding"])
        return res
0 python/output/ToPDF.pdf Normal file
@@ -1,25 +0,0 @@
from openpyxl import load_workbook
import sys
from io import BytesIO


class HuExcelParser:
    def __call__(self, fnm):
        if isinstance(fnm, str):
            wb = load_workbook(fnm)
        else:
            wb = load_workbook(BytesIO(fnm))
        res = []
        for sheetname in wb.sheetnames:
            ws = wb[sheetname]
            lines = []
            for r in ws.rows:
                lines.append(
                    "\t".join([str(c.value) if c.value is not None else "" for c in r]))
            res.append(f"《{sheetname}》\n" + "\n".join(lines))
        return res


if __name__ == "__main__":
    psr = HuExcelParser()
    psr(sys.argv[1])
@@ -1,194 +0,0 @@
accelerate==0.24.1
addict==2.4.0
aiobotocore==2.7.0
aiofiles==23.2.1
aiohttp==3.8.6
aioitertools==0.11.0
aiosignal==1.3.1
aliyun-python-sdk-core==2.14.0
aliyun-python-sdk-kms==2.16.2
altair==5.1.2
anyio==3.7.1
astor==0.8.1
async-timeout==4.0.3
attrdict==2.0.1
attrs==23.1.0
Babel==2.13.1
bce-python-sdk==0.8.92
beautifulsoup4==4.12.2
bitsandbytes==0.41.1
blinker==1.7.0
botocore==1.31.64
cachetools==5.3.2
certifi==2023.7.22
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
contourpy==1.2.0
crcmod==1.7
cryptography==41.0.5
cssselect==1.2.0
cssutils==2.9.0
cycler==0.12.1
Cython==3.0.5
datasets==2.13.0
datrie==0.8.2
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.6
einops==0.7.0
elastic-transport==8.10.0
elasticsearch==8.10.1
elasticsearch-dsl==8.9.0
et-xmlfile==1.1.0
fastapi==0.104.1
ffmpy==0.3.1
filelock==3.13.1
fire==0.5.0
FlagEmbedding==1.1.5
Flask==3.0.0
flask-babel==4.0.0
fonttools==4.44.0
frozenlist==1.4.0
fsspec==2023.10.0
future==0.18.3
gast==0.5.4
-e git+https://github.com/ggerganov/llama.cpp.git@5f6e0c0dff1e7a89331e6b25eca9a9fd71324069#egg=gguf&subdirectory=gguf-py
gradio==3.50.2
gradio_client==0.6.1
greenlet==3.0.1
h11==0.14.0
hanziconv==0.3.2
httpcore==1.0.1
httpx==0.25.1
huggingface-hub==0.17.3
idna==3.4
imageio==2.31.6
imgaug==0.4.0
importlib-metadata==6.8.0
importlib-resources==6.1.0
install==1.3.5
itsdangerous==2.1.2
Jinja2==3.1.2
jmespath==0.10.0
joblib==1.3.2
jsonschema==4.19.2
jsonschema-specifications==2023.7.1
kiwisolver==1.4.5
lazy_loader==0.3
lmdb==1.4.1
lxml==4.9.3
MarkupSafe==2.1.3
matplotlib==3.8.1
modelscope==1.9.4
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.14
networkx==3.2.1
nltk==3.8.1
numpy==1.24.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.52
nvidia-nvtx-cu12==12.1.105
opencv-contrib-python==4.6.0.66
opencv-python==4.6.0.66
openpyxl==3.1.2
opt-einsum==3.3.0
orjson==3.9.10
oss2==2.18.3
packaging==23.2
paddleocr==2.7.0.3
paddlepaddle-gpu==2.5.2.post120
pandas==2.1.2
pdf2docx==0.5.5
pdfminer.six==20221105
pdfplumber==0.10.3
Pillow==10.0.1
platformdirs==3.11.0
premailer==3.10.0
protobuf==4.25.0
psutil==5.9.6
pyarrow==14.0.0
pyclipper==1.3.0.post5
pycocotools==2.0.7
pycparser==2.21
pycryptodome==3.19.0
pydantic==1.10.13
pydub==0.25.1
PyMuPDF==1.20.2
pyparsing==3.1.1
pypdfium2==4.23.1
python-dateutil==2.8.2
python-docx==1.1.0
python-multipart==0.0.6
pytz==2023.3.post1
PyYAML==6.0.1
rapidfuzz==3.5.2
rarfile==4.1
referencing==0.30.2
regex==2023.10.3
requests==2.31.0
rpds-py==0.12.0
s3fs==2023.10.0
safetensors==0.4.0
scikit-image==0.22.0
scikit-learn==1.3.2
scipy==1.11.3
semantic-version==2.10.0
sentence-transformers==2.2.2
sentencepiece==0.1.98
shapely==2.0.2
simplejson==3.19.2
six==1.16.0
sniffio==1.3.0
sortedcontainers==2.4.0
soupsieve==2.5
SQLAlchemy==2.0.23
starlette==0.27.0
sympy==1.12
tabulate==0.9.0
tblib==3.0.0
termcolor==2.3.0
threadpoolctl==3.2.0
tifffile==2023.9.26
tiktoken==0.5.1
timm==0.9.10
tokenizers==0.13.3
tomli==2.0.1
toolz==0.12.0
torch==2.1.0
torchaudio==2.1.0
torchvision==0.16.0
tornado==6.3.3
tqdm==4.66.1
transformers==4.33.0
transformers-stream-generator==0.0.4
triton==2.1.0
typing_extensions==4.8.0
tzdata==2023.3
urllib3==2.0.7
uvicorn==0.24.0
uvloop==0.19.0
visualdl==2.5.3
websockets==11.0.3
Werkzeug==3.0.1
wrapt==1.15.0
xgboost==2.0.1
xinference==0.6.0
xorbits==0.7.0
xoscar==0.1.3
xxhash==3.4.1
yapf==0.40.2
yarl==1.9.2
zipp==3.17.0
8 python/res/1-0.tm Normal file
@@ -0,0 +1,8 @@
2023-12-20 11:44:08.791336+00:00
2023-12-20 11:44:08.853249+00:00
2023-12-20 11:44:08.909933+00:00
2023-12-21 00:47:09.996757+00:00
2023-12-20 11:44:08.965855+00:00
2023-12-20 11:44:09.011682+00:00
2023-12-21 00:47:10.063326+00:00
2023-12-20 11:44:09.069486+00:00
3 python/res/thumbnail-1-0.tm Normal file
@@ -0,0 +1,3 @@
2023-12-27 08:21:49.309802+00:00
2023-12-27 08:37:22.407772+00:00
2023-12-27 08:59:18.845627+00:00
@@ -1,118 +0,0 @@
import sys, datetime, random, re, cv2
from os.path import dirname, realpath
sys.path.append(dirname(realpath(__file__)) + "/../")
from util.db_conn import Postgres
from util.minio_conn import HuMinio
from util import findMaxDt
import base64
from io import BytesIO
import pandas as pd
from PIL import Image
import pdfplumber


PG = Postgres("infiniflow", "docgpt")
MINIO = HuMinio("infiniflow")
def set_thumbnail(did, base64):
    sql = f"""
    update doc_info set thumbnail_base64='{base64}'
    where
    did={did}
    """
    PG.update(sql)


def collect(comm, mod, tm):
    sql = f"""
    select
    did, uid, doc_name, location, updated_at
    from doc_info
    where
    updated_at >= '{tm}'
    and MOD(did, {comm}) = {mod}
    and is_deleted=false
    and type <> 'folder'
    and thumbnail_base64=''
    order by updated_at asc
    limit 10
    """
    docs = PG.select(sql)
    if len(docs) == 0:return pd.DataFrame()

    mtm = str(docs["updated_at"].max())[:19]
    print("TOTAL:", len(docs), "To: ", mtm)
    return docs


def build(row):
    if not re.search(r"\.(pdf|jpg|jpeg|png|gif|svg|apng|icon|ico|webp|mpg|mpeg|avi|rm|rmvb|mov|wmv|mp4)$",
                     row["doc_name"].lower().strip()):
        set_thumbnail(row["did"], "_")
        return

    def thumbnail(img, SIZE=128):
        w,h = img.size
        p = SIZE/max(w, h)
        w, h = int(w*p), int(h*p)
        img.thumbnail((w, h))
        buffered = BytesIO()
        try:
            img.save(buffered, format="JPEG")
        except Exception as e:
            try:
                img.save(buffered, format="PNG")
            except Exception as ee:
                pass
        return base64.b64encode(buffered.getvalue()).decode("utf-8")


    iobytes = BytesIO(MINIO.get("%s-upload"%str(row["uid"]), row["location"]))
    if re.search(r"\.pdf$", row["doc_name"].lower().strip()):
        pdf = pdfplumber.open(iobytes)
        img = pdf.pages[0].to_image().annotated
        set_thumbnail(row["did"], thumbnail(img))

    if re.search(r"\.(jpg|jpeg|png|gif|svg|apng|webp|icon|ico)$", row["doc_name"].lower().strip()):
        img = Image.open(iobytes)
        set_thumbnail(row["did"], thumbnail(img))

    if re.search(r"\.(mpg|mpeg|avi|rm|rmvb|mov|wmv|mp4)$", row["doc_name"].lower().strip()):
        url = MINIO.get_presigned_url("%s-upload"%str(row["uid"]),
                                      row["location"],
                                      expires=datetime.timedelta(seconds=60)
                                      )
        cap = cv2.VideoCapture(url)
        succ = cap.isOpened()
        i = random.randint(1, 11)
        while succ:
            ret, frame = cap.read()
            if not ret: break
            if i > 0:
                i -= 1
                continue
            img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            print(img.size)
            set_thumbnail(row["did"], thumbnail(img))
            cap.release()
            cv2.destroyAllWindows()


def main(comm, mod):
    global model
    tm_fnm = f"res/thumbnail-{comm}-{mod}.tm"
    tm = findMaxDt(tm_fnm)
    rows = collect(comm, mod, tm)
    if len(rows) == 0:return

    tmf = open(tm_fnm, "a+")
    for _, r in rows.iterrows():
        build(r)
        tmf.write(str(r["updated_at"]) + "\n")
    tmf.close()


if __name__ == "__main__":
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    main(comm.Get_size(), comm.Get_rank())
@@ -1,165 +0,0 @@
#-*- coding:utf-8 -*-
import sys, os, re,inspect,json,traceback,logging,argparse, copy
sys.path.append(os.path.realpath(os.path.dirname(inspect.getfile(inspect.currentframe())))+"/../")
from tornado.web import RequestHandler,Application
from tornado.ioloop import IOLoop
from tornado.httpserver import HTTPServer
from tornado.options import define,options
from util import es_conn, setup_logging
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
from nlp import huqie
from nlp import query as Query
from nlp import search
from llm import HuEmbedding, GptTurbo
import numpy as np
from io import BytesIO
from util import config
from timeit import default_timer as timer
from collections import OrderedDict
from llm import ChatModel, EmbeddingModel

SE = None
CFIELD="content_ltks"
EMBEDDING = EmbeddingModel
LLM = ChatModel

def get_QA_pairs(hists):
    pa = []
    for h in hists:
        for k in ["user", "assistant"]:
            if h.get(k):
                pa.append({
                    "content": h[k],
                    "role": k,
                })

    for p in pa[:-1]: assert len(p) == 2, p
    return pa


def get_instruction(sres, top_i, max_len=8096, fld="content_ltks"):
    max_len //= len(top_i)
    # add instruction to prompt
    instructions = [re.sub(r"[\r\n]+", " ", sres.field[sres.ids[i]][fld]) for i in top_i]
    if len(instructions)>2:
        # Said that LLM is sensitive to the first and the last one, so
        # rearrange the order of references
        instructions.append(copy.deepcopy(instructions[1]))
        instructions.pop(1)

    def token_num(txt):
        c = 0
        for tk in re.split(r"[,。/?‘’”“:;:;!!]", txt):
            if re.match(r"[a-zA-Z-]+$", tk):
                c += 1
                continue
            c += len(tk)
        return c

    _inst = ""
    for ins in instructions:
        if token_num(_inst) > 4096:
            _inst += "\n知识库:" + instructions[-1][:max_len]
            break
        _inst += "\n知识库:" + ins[:max_len]
    return _inst


def prompt_and_answer(history, inst):
    hist = get_QA_pairs(history)
    chks = []
    for s in re.split(r"[::;;。\n\r]+", inst):
        if s: chks.append(s)
    chks = len(set(chks))/(0.1+len(chks))
    print("Duplication portion:", chks)

    system = """
你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答%s。当所有知识库内容都与问题无关时,你的回答必须包括"知识库中未找到您要的答案!这是我所知道的,仅作参考。"这句话。回答需要考虑聊天历史。
以下是知识库:
%s
以上是知识库。
"""%((",最好总结成表格" if chks<0.6 and chks>0 else ""), inst)

    print("【PROMPT】:", system)
    start = timer()
    response = LLM.chat(system, hist, {"temperature": 0.2, "max_tokens": 512})
    print("GENERATE: ", timer()-start)
    print("===>>", response)
    return response


class Handler(RequestHandler):
    def post(self):
        global SE,MUST_TK_NUM
        param = json.loads(self.request.body.decode('utf-8'))
        try:
            question = param.get("history",[{"user": "Hi!"}])[-1]["user"]
            res = SE.search({
                "question": question,
                "kb_ids": param.get("kb_ids", []),
                "size": param.get("topn", 15)},
                search.index_name(param["uid"])
            )

            sim = SE.rerank(res, question)
            rk_idx = np.argsort(sim*-1)
            topidx = [i for i in rk_idx if sim[i] >= param.get("similarity", 0.5)][:param.get("topn",12)]
            inst = get_instruction(res, topidx)

            ans, topidx = prompt_and_answer(param["history"], inst)
            ans = SE.insert_citations(ans, topidx, res)

            refer = OrderedDict()
            docnms = {}
            for i in rk_idx:
                did = res.field[res.ids[i]]["doc_id"]
                if did not in docnms: docnms[did] = res.field[res.ids[i]]["docnm_kwd"]
                if did not in refer: refer[did] = []
                refer[did].append({
                    "chunk_id": res.ids[i],
                    "content": res.field[res.ids[i]]["content_ltks"],
                    "image": ""
                })

            print("::::::::::::::", ans)
            self.write(json.dumps({
                "code":0,
                "msg":"success",
                "data":{
                    "uid": param["uid"],
                    "dialog_id": param["dialog_id"],
                    "assistant": ans,
                    "refer": [{
                        "did": did,
                        "doc_name": docnms[did],
                        "chunks": chunks
                    } for did, chunks in refer.items()]
                }
            }))
            logging.info("SUCCESS[%d]"%(res.total)+json.dumps(param, ensure_ascii=False))

        except Exception as e:
            logging.error("Request 500: "+str(e))
            self.write(json.dumps({
                "code":500,
                "msg":str(e),
                "data":{}
            }))
            print(traceback.format_exc())


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", default=4455, type=int, help="Port used for service")
    ARGS = parser.parse_args()

    SE = search.Dealer(es_conn.HuEs("infiniflow"), EMBEDDING)

    app = Application([(r'/v1/chat/completions', Handler)],debug=False)
    http_server = HTTPServer(app)
    http_server.bind(ARGS.port)
    http_server.start(3)

    IOLoop.current().start()
@@ -1,258 +0,0 @@
import json, os, sys, hashlib, copy, time, random, re
from os.path import dirname, realpath
sys.path.append(dirname(realpath(__file__)) + "/../")
from util.es_conn import HuEs
from util.db_conn import Postgres
from util.minio_conn import HuMinio
from util import rmSpace, findMaxDt
from FlagEmbedding import FlagModel
from nlp import huchunk, huqie, search
from io import BytesIO
import pandas as pd
from elasticsearch_dsl import Q
from PIL import Image
from parser import (
    PdfParser,
    DocxParser,
    ExcelParser
)
from nlp.huchunk import (
    PdfChunker,
    DocxChunker,
    ExcelChunker,
    PptChunker,
    TextChunker
)

ES = HuEs("infiniflow")
BATCH_SIZE = 64
PG = Postgres("infiniflow", "docgpt")
MINIO = HuMinio("infiniflow")

PDF = PdfChunker(PdfParser())
DOC = DocxChunker(DocxParser())
EXC = ExcelChunker(ExcelParser())
PPT = PptChunker()

def chuck_doc(name, binary):
    suff = os.path.split(name)[-1].lower().split(".")[-1]
    if suff.find("pdf") >= 0: return PDF(binary)
    if suff.find("doc") >= 0: return DOC(binary)
    if re.match(r"(xlsx|xlsm|xltx|xltm)", suff): return EXC(binary)
    if suff.find("ppt") >= 0: return PPT(binary)
    if os.environ.get("PARSE_IMAGE") \
        and re.search(r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$",
                      name.lower()):
        from llm import CvModel
        txt = CvModel.describe(binary)
        field = TextChunker.Fields()
        field.text_chunks = [(txt, binary)]
        field.table_chunks = []


    return TextChunker()(binary)


def collect(comm, mod, tm):
    sql = f"""
    select
    id as kb2doc_id,
    kb_id,
    did,
    updated_at,
    is_deleted
    from kb2_doc
    where
    updated_at >= '{tm}'
    and kb_progress = 0
    and MOD(did, {comm}) = {mod}
    order by updated_at asc
    limit 1000
    """
    kb2doc = PG.select(sql)
    if len(kb2doc) == 0:return pd.DataFrame()

    sql = """
    select
    did,
    uid,
    doc_name,
    location,
    size
    from doc_info
    where
    did in (%s)
    """%",".join([str(i) for i in kb2doc["did"].unique()])
    docs = PG.select(sql)
    docs = docs.fillna("")
    docs = docs.join(kb2doc.set_index("did"), on="did", how="left")

    mtm = str(docs["updated_at"].max())[:19]
    print("TOTAL:", len(docs), "To: ", mtm)
    return docs


def set_progress(kb2doc_id, prog, msg="Processing..."):
    sql = f"""
    update kb2_doc set kb_progress={prog}, kb_progress_msg='{msg}'
    where
    id={kb2doc_id}
    """
    PG.update(sql)


def build(row):
    if row["size"] > 256000000:
        set_progress(row["kb2doc_id"], -1, "File size exceeds( <= 256Mb )")
        return []
    res = ES.search(Q("term", doc_id=row["did"]))
    if ES.getTotal(res) > 0:
        ES.updateScriptByQuery(Q("term", doc_id=row["did"]),
                               scripts="""
                               if(!ctx._source.kb_id.contains('%s'))
                                 ctx._source.kb_id.add('%s');
                               """%(str(row["kb_id"]), str(row["kb_id"])),
                               idxnm = search.index_name(row["uid"])
                               )
        set_progress(row["kb2doc_id"], 1, "Done")
        return []

    random.seed(time.time())
    set_progress(row["kb2doc_id"], random.randint(0, 20)/100., "Finished preparing! Start to slice file!")
    try:
        obj = chuck_doc(row["doc_name"], MINIO.get("%s-upload"%str(row["uid"]), row["location"]))
    except Exception as e:
        if re.search("(No such file|not found)", str(e)):
            set_progress(row["kb2doc_id"], -1, "Can not find file <%s>"%row["doc_name"])
        else:
            set_progress(row["kb2doc_id"], -1, f"Internal system error: %s"%str(e).replace("'", ""))
        return []

    if not obj.text_chunks and not obj.table_chunks:
        set_progress(row["kb2doc_id"], 1, "Nothing added! Mostly, file type unsupported yet.")
        return []

    set_progress(row["kb2doc_id"], random.randint(20, 60)/100., "Finished slicing files. Start to embedding the content.")

    doc = {
        "doc_id": row["did"],
        "kb_id": [str(row["kb_id"])],
        "docnm_kwd": os.path.split(row["location"])[-1],
        "title_tks": huqie.qie(os.path.split(row["location"])[-1]),
        "updated_at": str(row["updated_at"]).replace("T", " ")[:19]
    }
    doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
    output_buffer = BytesIO()
    docs = []
    md5 = hashlib.md5()
    for txt, img in obj.text_chunks:
        d = copy.deepcopy(doc)
        md5.update((txt + str(d["doc_id"])).encode("utf-8"))
        d["_id"] = md5.hexdigest()
        d["content_ltks"] = huqie.qie(txt)
        d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
        if not img:
            docs.append(d)
            continue

        if isinstance(img, Image): img.save(output_buffer, format='JPEG')
        else: output_buffer = BytesIO(img)

        MINIO.put("{}-{}".format(row["uid"], row["kb_id"]), d["_id"],
                  output_buffer.getvalue())
        d["img_id"] = "{}-{}".format(row["uid"], row["kb_id"])
        docs.append(d)

    for arr, img in obj.table_chunks:
        for i, txt in enumerate(arr):
            d = copy.deepcopy(doc)
            d["content_ltks"] = huqie.qie(txt)
            md5.update((txt + str(d["doc_id"])).encode("utf-8"))
            d["_id"] = md5.hexdigest()
            if not img:
                docs.append(d)
                continue
            img.save(output_buffer, format='JPEG')
            MINIO.put("{}-{}".format(row["uid"], row["kb_id"]), d["_id"],
                      output_buffer.getvalue())
            d["img_id"] = "{}-{}".format(row["uid"], row["kb_id"])
            docs.append(d)
    set_progress(row["kb2doc_id"], random.randint(60, 70)/100., "Continue embedding the content.")

    return docs


def init_kb(row):
    idxnm = search.index_name(row["uid"])
    if ES.indexExist(idxnm): return
    return ES.createIdx(idxnm, json.load(open("conf/mapping.json", "r")))


model = None
def embedding(docs):
    global model
    tts = model.encode([rmSpace(d["title_tks"]) for d in docs])
    cnts = model.encode([rmSpace(d["content_ltks"]) for d in docs])
    vects = 0.1 * tts + 0.9 * cnts
    assert len(vects) == len(docs)
    for i,d in enumerate(docs):d["q_vec"] = vects[i].tolist()


def rm_doc_from_kb(df):
    if len(df) == 0:return
    for _,r in df.iterrows():
        ES.updateScriptByQuery(Q("term", doc_id=r["did"]),
                               scripts="""
                               if(ctx._source.kb_id.contains('%s'))
|
|
||||||
ctx._source.kb_id.remove(
|
|
||||||
ctx._source.kb_id.indexOf('%s')
|
|
||||||
);
|
|
||||||
"""%(str(r["kb_id"]),str(r["kb_id"])),
|
|
||||||
idxnm = search.index_name(r["uid"])
|
|
||||||
)
|
|
||||||
if len(df) == 0:return
|
|
||||||
sql = """
|
|
||||||
delete from kb2_doc where id in (%s)
|
|
||||||
"""%",".join([str(i) for i in df["kb2doc_id"]])
|
|
||||||
PG.update(sql)
|
|
||||||
|
|
||||||
|
|
||||||
def main(comm, mod):
|
|
||||||
global model
|
|
||||||
from llm import HuEmbedding
|
|
||||||
model = HuEmbedding()
|
|
||||||
tm_fnm = f"res/{comm}-{mod}.tm"
|
|
||||||
tm = findMaxDt(tm_fnm)
|
|
||||||
rows = collect(comm, mod, tm)
|
|
||||||
if len(rows) == 0:return
|
|
||||||
|
|
||||||
rm_doc_from_kb(rows.loc[rows.is_deleted == True])
|
|
||||||
rows = rows.loc[rows.is_deleted == False].reset_index(drop=True)
|
|
||||||
if len(rows) == 0:return
|
|
||||||
tmf = open(tm_fnm, "a+")
|
|
||||||
for _, r in rows.iterrows():
|
|
||||||
cks = build(r)
|
|
||||||
if not cks:
|
|
||||||
tmf.write(str(r["updated_at"]) + "\n")
|
|
||||||
continue
|
|
||||||
## TODO: exception handler
|
|
||||||
## set_progress(r["did"], -1, "ERROR: ")
|
|
||||||
embedding(cks)
|
|
||||||
|
|
||||||
set_progress(r["kb2doc_id"], random.randint(70, 95)/100.,
|
|
||||||
"Finished embedding! Start to build index!")
|
|
||||||
init_kb(r)
|
|
||||||
es_r = ES.bulk(cks, search.index_name(r["uid"]))
|
|
||||||
if es_r:
|
|
||||||
set_progress(r["kb2doc_id"], -1, "Index failure!")
|
|
||||||
print(es_r)
|
|
||||||
else: set_progress(r["kb2doc_id"], 1., "Done!")
|
|
||||||
tmf.write(str(r["updated_at"]) + "\n")
|
|
||||||
tmf.close()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
from mpi4py import MPI
|
|
||||||
comm = MPI.COMM_WORLD
|
|
||||||
main(comm.Get_size(), comm.Get_rank())
|
|
||||||
|
|
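Note: the removed worker above shards the backlog across MPI ranks with the SQL predicate MOD(did, {comm}) = {mod}. A minimal sketch of that residue-class sharding, purely illustrative and not part of the commit (the helper name is made up):

def belongs_to_worker(doc_id: int, world_size: int, rank: int) -> bool:
    # Mirrors the SQL predicate: worker `rank` owns the ids in its residue class.
    return doc_id % world_size == rank

# e.g. with 4 workers, rank 1 picks up documents 1, 5, 9, ...
print([i for i in range(10) if belongs_to_worker(i, 4, 1)])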
15 python/tmp.log Normal file
@ -0,0 +1,15 @@
Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]
Fetching 6 files: 100%|██████████| 6/6 [00:00<00:00, 106184.91it/s]
----------- Model Configuration -----------
Model Arch: GFL
Transform Order:
--transform op: Resize
--transform op: NormalizeImage
--transform op: Permute
--transform op: PadStride
--------------------------------------------
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
The `max_size` parameter is deprecated and will be removed in v4.26. Please specify in `size['longest_edge'] instead`.
Some weights of the model checkpoint at microsoft/table-transformer-structure-recognition were not used when initializing TableTransformerForObjectDetection: ['model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked']
- This IS expected if you are initializing TableTransformerForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TableTransformerForObjectDetection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
WARNING:root:The files are stored in /opt/home/kevinhu/docgpt/, please check it!
@ -1,31 +0,0 @@
|
|||||||
from configparser import ConfigParser
|
|
||||||
import os
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
CF = ConfigParser()
|
|
||||||
__fnm = os.path.join(os.path.dirname(__file__), '../conf/sys.cnf')
|
|
||||||
if not os.path.exists(__fnm):
|
|
||||||
__fnm = os.path.join(os.path.dirname(__file__), '../../conf/sys.cnf')
|
|
||||||
assert os.path.exists(
|
|
||||||
__fnm), f"【EXCEPTION】can't find {__fnm}." + os.path.dirname(__file__)
|
|
||||||
if not os.path.exists(__fnm):
|
|
||||||
__fnm = "./sys.cnf"
|
|
||||||
|
|
||||||
CF.read(__fnm)
|
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
def __init__(self, env):
|
|
||||||
self.env = env
|
|
||||||
if env == "spark":
|
|
||||||
CF.read("./cv.cnf")
|
|
||||||
|
|
||||||
def get(self, key, default=None):
|
|
||||||
global CF
|
|
||||||
return os.environ.get(key.upper(),
|
|
||||||
CF[self.env].get(key, default)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def init(env):
|
|
||||||
return Config(env)
|
|
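The removed Config.get above resolves a key by checking the process environment first and only then the sys.cnf section. A standalone sketch of that lookup order, with hypothetical values:

import os

def lookup(key, section, default=None):
    # The upper-cased environment variable wins over the config-file value.
    return os.environ.get(key.upper(), section.get(key, default))

section = {"postgres_host": "127.0.0.1"}
os.environ["POSTGRES_HOST"] = "db.internal"
print(lookup("postgres_host", section))  # -> db.internal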
@ -1,70 +0,0 @@
import logging
import time
from util import config
import pandas as pd


class Postgres(object):
    def __init__(self, env, dbnm):
        self.config = config.init(env)
        self.conn = None
        self.dbnm = dbnm
        self.__open__()

    def __open__(self):
        import psycopg2
        try:
            if self.conn:
                self.__close__()
                del self.conn
        except Exception as e:
            pass

        try:
            self.conn = psycopg2.connect(f"""dbname={self.dbnm}
                user={self.config.get('postgres_user')}
                password={self.config.get('postgres_password')}
                host={self.config.get('postgres_host')}
                port={self.config.get('postgres_port')}""")
        except Exception as e:
            logging.error(
                "Fail to connect %s " %
                self.config.get("pgdb_host") + str(e))

    def __close__(self):
        try:
            self.conn.close()
        except Exception as e:
            logging.error(
                "Fail to close %s " %
                self.config.get("pgdb_host") + str(e))

    def select(self, sql):
        for _ in range(10):
            try:
                return pd.read_sql(sql, self.conn)
            except Exception as e:
                logging.error(f"Fail to exec {sql} " + str(e))
                self.__open__()
                time.sleep(1)

        return pd.DataFrame()

    def update(self, sql):
        for _ in range(10):
            try:
                cur = self.conn.cursor()
                cur.execute(sql)
                updated_rows = cur.rowcount
                self.conn.commit()
                cur.close()
                return updated_rows
            except Exception as e:
                logging.error(f"Fail to exec {sql} " + str(e))
                self.__open__()
                time.sleep(1)
        return 0


if __name__ == "__main__":
    Postgres("infiniflow", "docgpt")
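select() and update() above share one pattern: try up to ten times, reopen the connection on failure, and fall back to an empty result. A generic version of that retry loop, illustrative only and not the class's actual API:

import time

def with_retry(fn, retries=10, delay=1.0, fallback=None):
    # Call fn(); on any exception wait and try again, as Postgres.select()/update() do.
    for _ in range(retries):
        try:
            return fn()
        except Exception as e:
            print("retrying after error:", e)
            time.sleep(delay)
    return fallback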
@ -1,36 +0,0 @@
import json
import logging.config
import os


def log_dir():
    fnm = os.path.join(os.path.dirname(__file__), '../log/')
    if not os.path.exists(fnm):
        fnm = os.path.join(os.path.dirname(__file__), '../../log/')
    assert os.path.exists(fnm), f"Can't locate log dir: {fnm}"
    return fnm


def setup_logging(default_path="conf/logging.json",
                  default_level=logging.INFO,
                  env_key="LOG_CFG"):
    path = default_path
    value = os.getenv(env_key, None)
    if value:
        path = value
    if os.path.exists(path):
        with open(path, "r") as f:
            config = json.load(f)
            fnm = log_dir()

            config["handlers"]["info_file_handler"]["filename"] = fnm + "info.log"
            config["handlers"]["error_file_handler"]["filename"] = fnm + "error.log"
            logging.config.dictConfig(config)
    else:
        logging.basicConfig(level=default_level)


__fnm = os.path.join(os.path.dirname(__file__), 'conf/logging.json')
if not os.path.exists(__fnm):
    __fnm = os.path.join(os.path.dirname(__file__), '../../conf/logging.json')
setup_logging(__fnm)
0 rag/__init__.py Normal file
32 rag/llm/__init__.py Normal file
@ -0,0 +1,32 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .embedding_model import *
from .chat_model import *
from .cv_model import *


EmbeddingModel = {
    "local": HuEmbedding,
    "OpenAI": OpenAIEmbed,
    "通义千问": QWenEmbed,
}


CvModel = {
    "OpenAI": GptV4,
    "通义千问": QWenCV,
}
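EmbeddingModel and CvModel above are plain factory dictionaries keyed by the provider name configured for a tenant. A usage sketch (only the mapping itself comes from this file; the key and API key below are made up):

from rag.llm import EmbeddingModel

factory, api_key = "OpenAI", "sk-..."       # hypothetical tenant configuration
embd_cls = EmbeddingModel.get(factory)      # e.g. OpenAIEmbed, or None if unknown
if embd_cls is not None:
    embd_mdl = embd_cls(api_key)
    vectors, token_count = embd_mdl.encode(["hello world"])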
@ -1,3 +1,18 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
from openai import OpenAI
import os
@ -1,3 +1,18 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
from openai import OpenAI
import os
@ -6,6 +21,9 @@ from io import BytesIO

class Base(ABC):
    def __init__(self, key, model_name):
        pass

    def describe(self, image, max_tokens=300):
        raise NotImplementedError("Please implement encode method!")

@ -40,14 +58,15 @@ class Base(ABC):

class GptV4(Base):
    def __init__(self):
    def __init__(self, key, model_name="gpt-4-vision-preview"):
        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
        self.client = OpenAI(key)
        self.model_name = model_name

    def describe(self, image, max_tokens=300):
        b64 = self.image2base64(image)

        res = self.client.chat.completions.create(
            model="gpt-4-vision-preview",
            model=self.model_name,
            messages=self.prompt(b64),
            max_tokens=max_tokens,
        )
@ -55,11 +74,15 @@ class GptV4(Base):

class QWenCV(Base):
    def __init__(self, key, model_name="qwen-vl-chat-v1"):
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name

    def describe(self, image, max_tokens=300):
        from http import HTTPStatus
        from dashscope import MultiModalConversation
        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
        response = MultiModalConversation.call(model=self.model_name,
        response = MultiModalConversation.call(model=MultiModalConversation.Models.qwen_vl_chat_v1,
                                               messages=self.prompt(self.image2base64(image)))
        if response.status_code == HTTPStatus.OK:
            return response.output.choices[0]['message']['content']
94 rag/llm/embedding_model.py Normal file
@ -0,0 +1,94 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC

import dashscope
from openai import OpenAI
from FlagEmbedding import FlagModel
import torch
import os
import numpy as np

from rag.utils import num_tokens_from_string


class Base(ABC):
    def __init__(self, key, model_name):
        pass

    def encode(self, texts: list, batch_size=32):
        raise NotImplementedError("Please implement encode method!")


class HuEmbedding(Base):
    def __init__(self):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!

        For Linux:
        export HF_ENDPOINT=https://hf-mirror.com

        For Windows:
        Good luck
        ^_-

        """
        self.model = FlagModel("BAAI/bge-large-zh-v1.5",
                               query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                               use_fp16=torch.cuda.is_available())

    def encode(self, texts: list, batch_size=32):
        token_count = 0
        for t in texts: token_count += num_tokens_from_string(t)
        res = []
        for i in range(0, len(texts), batch_size):
            res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
        return np.array(res), token_count


class OpenAIEmbed(Base):
    def __init__(self, key, model_name="text-embedding-ada-002"):
        self.client = OpenAI(key)
        self.model_name = model_name

    def encode(self, texts: list, batch_size=32):
        token_count = 0
        for t in texts: token_count += num_tokens_from_string(t)
        res = self.client.embeddings.create(input=texts,
                                            model=self.model_name)
        return [d["embedding"] for d in res["data"]], token_count


class QWenEmbed(Base):
    def __init__(self, key, model_name="text_embedding_v2"):
        dashscope.api_key = key
        self.model_name = model_name

    def encode(self, texts: list, batch_size=32, text_type="document"):
        import dashscope
        res = []
        token_count = 0
        for txt in texts:
            resp = dashscope.TextEmbedding.call(
                model=self.model_name,
                input=txt[:2048],
                text_type=text_type
            )
            res.append(resp["output"]["embeddings"][0]["embedding"])
            token_count += resp["usage"]["total_tokens"]
        return res, token_count
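Every backend above returns the pair (vectors, token_count) from encode(), so callers can account for tokens uniformly regardless of provider. A toy stand-in that honours the same contract, purely illustrative:

import numpy as np

class FakeEmbed:
    def encode(self, texts: list, batch_size=32):
        vecs = np.zeros((len(texts), 4))                    # pretend 4-dim embeddings
        token_count = sum(len(t.split()) for t in texts)    # crude token estimate
        return vecs, token_count

vecs, n_tokens = FakeEmbed().encode(["title of a doc", "body of a doc"])
assert vecs.shape == (2, 4) and n_tokens > 0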
0 rag/nlp/__init__.py Normal file
11 python/nlp/huqie.py → rag/nlp/huqie.py Executable file → Normal file
@ -9,6 +9,8 @@ import string
import sys
from hanziconv import HanziConv

from web_server.utils.file_utils import get_project_base_directory


class Huqie:
    def key_(self, line):
@ -41,14 +43,7 @@ class Huqie:
        self.DEBUG = debug
        self.DENOMINATOR = 1000000
        self.trie_ = datrie.Trie(string.printable)
        self.DIR_ = ""
        self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")
        if os.path.exists("../res/huqie.txt"):
            self.DIR_ = "../res/huqie"
        if os.path.exists("./res/huqie.txt"):
            self.DIR_ = "./res/huqie"
        if os.path.exists("./huqie.txt"):
            self.DIR_ = "./huqie"
        assert self.DIR_, f"【Can't find huqie】"

        self.SPLIT_CHAR = r"([ ,\.<>/?;'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-z\.-]+|[0-9,\.-]+)"
        try:
6 python/nlp/query.py → rag/nlp/query.py Executable file → Normal file
@ -1,12 +1,12 @@
# -*- coding: utf-8 -*-
import json
import re
import sys
import os
import logging
import copy
import math
from elasticsearch_dsl import Q, Search
from nlp import huqie, term_weight, synonym
from rag.nlp import huqie, term_weight, synonym


class EsQueryer:
@ -1,13 +1,11 @@
# -*- coding: utf-8 -*-
import re
from elasticsearch_dsl import Q, Search, A
from typing import List, Optional, Tuple, Dict, Union
from dataclasses import dataclass
from util import setup_logging, rmSpace
from rag.utils import rmSpace
from nlp import huqie, query
from rag.nlp import huqie, query
from datetime import datetime
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
import numpy as np
from copy import deepcopy


def index_name(uid): return f"docgpt_{uid}"
13 python/nlp/synonym.py → rag/nlp/synonym.py Executable file → Normal file
@ -1,8 +1,11 @@
import json
import os
import time
import logging
import re

from web_server.utils.file_utils import get_project_base_directory


class Dealer:
    def __init__(self, redis=None):
@ -10,15 +13,9 @@ class Dealer:
        self.lookup_num = 100000000
        self.load_tm = time.time() - 1000000
        self.dictionary = None
        path = os.path.join(get_project_base_directory(), "rag/res", "synonym.json")
        try:
            self.dictionary = json.load(open("./synonym.json", 'r'))
            self.dictionary = json.load(open(path, 'r'))
        except Exception as e:
            pass
        try:
            self.dictionary = json.load(open("./res/synonym.json", 'r'))
        except Exception as e:
            try:
                self.dictionary = json.load(open("../res/synonym.json", 'r'))
            except Exception as e:
                logging.warn("Miss synonym.json")
                self.dictionary = {}
12 python/nlp/term_weight.py → rag/nlp/term_weight.py Executable file → Normal file
@ -1,9 +1,11 @@
# -*- coding: utf-8 -*-
import math
import json
import re
import os
import numpy as np
from nlp import huqie
from rag.nlp import huqie
from web_server.utils.file_utils import get_project_base_directory


class Dealer:
@ -60,16 +62,14 @@ class Dealer:
                return set(res.keys())
            return res

        fnm = os.path.join(os.path.dirname(__file__), '../res/')
        fnm = os.path.join(get_project_base_directory(), "res")
        if not os.path.exists(fnm):
            fnm = os.path.join(os.path.dirname(__file__), '../../res/')
        self.ne, self.df = {}, {}
        try:
            self.ne = json.load(open(fnm + "ner.json", "r"))
            self.ne = json.load(open(os.path.join(fnm, "ner.json"), "r"))
        except Exception as e:
            print("[WARNING] Load ner.json FAIL!")
        try:
            self.df = load_dict(fnm + "term.freq")
            self.df = load_dict(os.path.join(fnm, "term.freq"))
        except Exception as e:
            print("[WARNING] Load term.freq FAIL!")
@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
from docx import Document
import re
import pandas as pd
from collections import Counter
from nlp import huqie
from rag.nlp import huqie
from io import BytesIO
33 rag/parser/excel_parser.py Normal file
@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
from openpyxl import load_workbook
import sys
from io import BytesIO


class HuExcelParser:
    def __call__(self, fnm):
        if isinstance(fnm, str):
            wb = load_workbook(fnm)
        else:
            wb = load_workbook(BytesIO(fnm))
        res = []
        for sheetname in wb.sheetnames:
            ws = wb[sheetname]
            rows = list(ws.rows)
            ti = list(rows[0])
            for r in list(rows[1:]):
                l = []
                for i,c in enumerate(r):
                    if not c.value:continue
                    t = str(ti[i].value) if i < len(ti) else ""
                    t += (":" if t else "") + str(c.value)
                    l.append(t)
                l = "; ".join(l)
                if sheetname.lower().find("sheet") <0: l += " ——"+sheetname
                res.append(l)
        return res


if __name__ == "__main__":
    psr = HuExcelParser()
    psr(sys.argv[1])
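HuExcelParser above flattens each data row into one "header:value; header:value" string, skipping empty cells and appending the sheet name when it is not a default "Sheet*". The same transformation on plain lists, with hypothetical data:

headers = ["name", "price", "stock"]
row = ["apple", 3, None]
cells = [f"{h}:{v}" for h, v in zip(headers, row) if v]
print("; ".join(cells))  # -> name:apple; price:3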
@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
import xgboost as xgb
from io import BytesIO
import torch
@ -6,11 +7,11 @@ import pdfplumber
import logging
from PIL import Image
import numpy as np
from nlp import huqie
from rag.nlp import huqie
from collections import Counter
from copy import deepcopy
from cv.table_recognize import TableTransformer
from rag.cv.table_recognize import TableTransformer
from cv.ppdetection import PPDet
from rag.cv.ppdetection import PPDet
from huggingface_hub import hf_hub_download
logging.getLogger("pdfminer").setLevel(logging.WARNING)
0 python/res/ner.json → rag/res/ner.json Executable file → Normal file
37 rag/settings.py Normal file
@ -0,0 +1,37 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from web_server.utils import get_base_config,decrypt_database_config
from web_server.utils.file_utils import get_project_base_directory
from web_server.utils.log_utils import LoggerFactory, getLogger


# Server
RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
SUBPROCESS_STD_LOG_NAME = "std.log"

ES = get_base_config("es", {})
MINIO = decrypt_database_config(name="minio")
DOC_MAXIMUM_SIZE = 64 * 1024 * 1024

# Logger
LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "rag"))
# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
LoggerFactory.LEVEL = 10

es_logger = getLogger("es")
minio_logger = getLogger("minio")
cron_logger = getLogger("cron_logger")
279 rag/svr/parse_user_docs.py Normal file
@ -0,0 +1,279 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import hashlib
import copy
import time
import random
import re
from timeit import default_timer as timer

from rag.llm import EmbeddingModel, CvModel
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
from rag.utils import ELASTICSEARCH, num_tokens_from_string
from rag.utils import MINIO
from rag.utils import rmSpace, findMaxDt
from rag.nlp import huchunk, huqie, search
from io import BytesIO
import pandas as pd
from elasticsearch_dsl import Q
from PIL import Image
from rag.parser import (
    PdfParser,
    DocxParser,
    ExcelParser
)
from rag.nlp.huchunk import (
    PdfChunker,
    DocxChunker,
    ExcelChunker,
    PptChunker,
    TextChunker
)
from web_server.db import LLMType
from web_server.db.services.document_service import DocumentService
from web_server.db.services.llm_service import TenantLLMService
from web_server.utils import get_format_time
from web_server.utils.file_utils import get_project_base_directory

BATCH_SIZE = 64

PDF = PdfChunker(PdfParser())
DOC = DocxChunker(DocxParser())
EXC = ExcelChunker(ExcelParser())
PPT = PptChunker()


def chuck_doc(name, binary, cvmdl=None):
    suff = os.path.split(name)[-1].lower().split(".")[-1]
    if suff.find("pdf") >= 0:
        return PDF(binary)
    if suff.find("doc") >= 0:
        return DOC(binary)
    if re.match(r"(xlsx|xlsm|xltx|xltm)", suff):
        return EXC(binary)
    if suff.find("ppt") >= 0:
        return PPT(binary)
    if cvmdl and re.search(r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$",
                           name.lower()):
        txt = cvmdl.describe(binary)
        field = TextChunker.Fields()
        field.text_chunks = [(txt, binary)]
        field.table_chunks = []

    return TextChunker()(binary)


def collect(comm, mod, tm):
    docs = DocumentService.get_newly_uploaded(tm, mod, comm)
    if len(docs) == 0:
        return pd.DataFrame()
    docs = pd.DataFrame(docs)
    mtm = str(docs["update_time"].max())[:19]
    cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
    return docs


def set_progress(docid, prog, msg="Processing...", begin=False):
    d = {"progress": prog, "progress_msg": msg}
    if begin:
        d["process_begin_at"] = get_format_time()
    try:
        DocumentService.update_by_id(
            docid, {"progress": prog, "progress_msg": msg})
    except Exception as e:
        cron_logger.error("set_progress:({}), {}".format(docid, str(e)))


def build(row):
    if row["size"] > DOC_MAXIMUM_SIZE:
        set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
                     (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
        return []
    res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
    if ELASTICSEARCH.getTotal(res) > 0:
        ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
                                          scripts="""
                                          if(!ctx._source.kb_id.contains('%s'))
                                              ctx._source.kb_id.add('%s');
                                          """ % (str(row["kb_id"]), str(row["kb_id"])),
                                          idxnm=search.index_name(row["tenant_id"])
                                          )
        set_progress(row["id"], 1, "Done")
        return []

    random.seed(time.time())
    set_progress(row["id"], random.randint(0, 20) /
                 100., "Finished preparing! Start to slice file!", True)
    try:
        obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]))
    except Exception as e:
        if re.search("(No such file|not found)", str(e)):
            set_progress(
                row["id"], -1, "Can not find file <%s>" %
                row["doc_name"])
        else:
            set_progress(
                row["id"], -1, f"Internal server error: %s" %
                str(e).replace(
                    "'", ""))
        return []

    if not obj.text_chunks and not obj.table_chunks:
        set_progress(
            row["id"],
            1,
            "Nothing added! Mostly, file type unsupported yet.")
        return []

    set_progress(row["id"], random.randint(20, 60) / 100.,
                 "Finished slicing files. Start to embedding the content.")

    doc = {
        "doc_id": row["did"],
        "kb_id": [str(row["kb_id"])],
        "docnm_kwd": os.path.split(row["location"])[-1],
        "title_tks": huqie.qie(row["name"]),
        "updated_at": str(row["update_time"]).replace("T", " ")[:19]
    }
    doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
    output_buffer = BytesIO()
    docs = []
    md5 = hashlib.md5()
    for txt, img in obj.text_chunks:
        d = copy.deepcopy(doc)
        md5.update((txt + str(d["doc_id"])).encode("utf-8"))
        d["_id"] = md5.hexdigest()
        d["content_ltks"] = huqie.qie(txt)
        d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
        if not img:
            docs.append(d)
            continue

        if isinstance(img, Image):
            img.save(output_buffer, format='JPEG')
        else:
            output_buffer = BytesIO(img)

        MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
        d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
        docs.append(d)

    for arr, img in obj.table_chunks:
        for i, txt in enumerate(arr):
            d = copy.deepcopy(doc)
            d["content_ltks"] = huqie.qie(txt)
            md5.update((txt + str(d["doc_id"])).encode("utf-8"))
            d["_id"] = md5.hexdigest()
            if not img:
                docs.append(d)
                continue
            img.save(output_buffer, format='JPEG')
            MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
            d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
            docs.append(d)
    set_progress(row["id"], random.randint(60, 70) /
                 100., "Continue embedding the content.")

    return docs


def init_kb(row):
    idxnm = search.index_name(row["tenant_id"])
    if ELASTICSEARCH.indexExist(idxnm):
        return
    return ELASTICSEARCH.createIdx(idxnm, json.load(
        open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))


def embedding(docs, mdl):
    tts, cnts = [rmSpace(d["title_tks"]) for d in docs], [rmSpace(d["content_ltks"]) for d in docs]
    tk_count = 0
    tts, c = mdl.encode(tts)
    tk_count += c
    cnts, c = mdl.encode(cnts)
    tk_count += c
    vects = 0.1 * tts + 0.9 * cnts
    assert len(vects) == len(docs)
    for i, d in enumerate(docs):
        d["q_vec"] = vects[i].tolist()
    return tk_count


def model_instance(tenant_id, llm_type):
    model_config = TenantLLMService.query(tenant_id=tenant_id, model_type=LLMType.EMBEDDING)
    if not model_config:return
    model_config = model_config[0]
    if llm_type == LLMType.EMBEDDING:
        if model_config.llm_factory not in EmbeddingModel: return
        return EmbeddingModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)
    if llm_type == LLMType.IMAGE2TEXT:
        if model_config.llm_factory not in CvModel: return
        return CvModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)


def main(comm, mod):
    global model
    from rag.llm import HuEmbedding
    model = HuEmbedding()
    tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
    tm = findMaxDt(tm_fnm)
    rows = collect(comm, mod, tm)
    if len(rows) == 0:
        return

    tmf = open(tm_fnm, "a+")
    for _, r in rows.iterrows():
        embd_mdl = model_instance(r["tenant_id"], LLMType.EMBEDDING)
        if not embd_mdl:
            set_progress(r["id"], -1, "Can't find embedding model!")
            cron_logger.error("Tenant({}) can't find embedding model!".format(r["tenant_id"]))
            continue
        cv_mdl = model_instance(r["tenant_id"], LLMType.IMAGE2TEXT)
        st_tm = timer()
        cks = build(r, cv_mdl)
        if not cks:
            tmf.write(str(r["updated_at"]) + "\n")
            continue
        # TODO: exception handler
        ## set_progress(r["did"], -1, "ERROR: ")
        try:
            tk_count = embedding(cks, embd_mdl)
        except Exception as e:
            set_progress(r["id"], -1, "Embedding error:{}".format(str(e)))
            cron_logger.error(str(e))
            continue

        set_progress(r["id"], random.randint(70, 95) / 100.,
                     "Finished embedding! Start to build index!")
        init_kb(r)
        es_r = ELASTICSEARCH.bulk(cks, search.index_name(r["tenant_id"]))
        if es_r:
            set_progress(r["id"], -1, "Index failure!")
            cron_logger.error(str(es_r))
        else:
            set_progress(r["id"], 1., "Done!")
        DocumentService.update_by_id(r["id"], {"token_num": tk_count, "chunk_num": len(cks), "process_duation": timer()-st_tm})
        tmf.write(str(r["update_time"]) + "\n")
    tmf.close()


if __name__ == "__main__":
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    main(comm.Get_size(), comm.Get_rank())
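The embedding() step above mixes title and body vectors 1:9 before indexing. A minimal numpy illustration of that weighting, with made-up values:

import numpy as np

title_vecs = np.array([[1.0, 0.0], [0.0, 1.0]])
body_vecs = np.array([[0.0, 1.0], [1.0, 0.0]])
q_vec = 0.1 * title_vecs + 0.9 * body_vecs
print(q_vec.tolist())  # -> [[0.1, 0.9], [0.9, 0.1]]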
@ -1,6 +1,23 @@
import os
import re
import tiktoken


def singleton(cls, *args, **kw):
    instances = {}

    def _singleton():
        key = str(cls) + str(os.getpid())
        if key not in instances:
            instances[key] = cls(*args, **kw)
        return instances[key]

    return _singleton


from .minio_conn import MINIO
from .es_conn import ELASTICSEARCH

def rmSpace(txt):
    txt = re.sub(r"([^a-z0-9.,]) +([^ ])", r"\1\2", txt)
    return re.sub(r"([^ ]) +([^a-z0-9.,])", r"\1\2", txt)
@ -22,3 +39,9 @@ def findMaxDt(fnm):
    except Exception as e:
        print("WARNING: can't find " + fnm)
    return m


def num_tokens_from_string(string: str) -> int:
    """Returns the number of tokens in a text string."""
    encoding = tiktoken.get_encoding('cl100k_base')
    num_tokens = len(encoding.encode(string))
    return num_tokens
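The singleton decorator added above keys instances by class and process id, so connection wrappers such as HuEs and HuMinio are constructed once per worker process. A self-contained behaviour sketch (the Conn class is hypothetical):

import os

def singleton(cls, *args, **kw):
    instances = {}
    def _singleton():
        key = str(cls) + str(os.getpid())
        if key not in instances:
            instances[key] = cls(*args, **kw)   # build once per (class, pid)
        return instances[key]
    return _singleton

@singleton
class Conn:
    pass

assert Conn() is Conn()   # same object within one process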
92 python/util/es_conn.py → rag/utils/es_conn.py Executable file → Normal file
@ -1,51 +1,39 @@
import re
import logging
import json
import time
import copy
import elasticsearch
from elasticsearch import Elasticsearch
from elasticsearch_dsl import UpdateByQuery, Search, Index, Q
from elasticsearch_dsl import UpdateByQuery, Search, Index
from util import config
from rag.settings import es_logger
from rag import settings
from rag.utils import singleton

logging.info("Elasticsearch version: ", elasticsearch.__version__)
es_logger.info("Elasticsearch version: "+ str(elasticsearch.__version__))


def instance(env):
    CF = config.init(env)
    ES_DRESS = CF.get("es").split(",")

    ES = Elasticsearch(
        ES_DRESS,
        timeout=600
    )

    logging.info("ES: ", ES_DRESS, ES.info())

    return ES


@singleton
class HuEs:
    def __init__(self, env):
    def __init__(self):
        self.env = env
        self.info = {}
        self.config = config.init(env)
        self.conn()
        self.idxnm = self.config.get("idx_nm", "")
        self.idxnm = settings.ES.get("index_name", "")
        if not self.es.ping():
            raise Exception("Can't connect to ES cluster")

    def conn(self):
        for _ in range(10):
            try:
                c = instance(self.env)
                if c:
                    self.es = c
                    self.info = c.info()
                logging.info("Connect to es.")
                self.es = Elasticsearch(
                    settings.ES["hosts"].split(","),
                    timeout=600
                )
                if self.es:
                    self.info = self.es.info()
                    es_logger.info("Connect to es.")
                break
            except Exception as e:
                logging.error("Fail to connect to es: " + str(e))
                es_logger.error("Fail to connect to es: " + str(e))
                time.sleep(1)

    def version(self):
@ -80,11 +68,11 @@ class HuEs:
                    refresh=False,
                    doc_type="_doc",
                    retry_on_conflict=100)
                logging.info("Successfully upsert: %s" % id)
                es_logger.info("Successfully upsert: %s" % id)
                T = True
                break
            except Exception as e:
                logging.warning("Fail to index: " +
                es_logger.warning("Fail to index: " +
                                  json.dumps(d, ensure_ascii=False) + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
@ -94,7 +82,7 @@ class HuEs:

        if not T:
            res.append(d)
            logging.error(
            es_logger.error(
                "Fail to index: " +
                re.sub(
                    "[\r\n]",
@ -147,7 +135,7 @@ class HuEs:

                return res
            except Exception as e:
                logging.warn("Fail to bulk: " + str(e))
                es_logger.warn("Fail to bulk: " + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
@ -162,7 +150,7 @@ class HuEs:
            ids[id] = copy.deepcopy(d["raw"])
            acts.append({"update": {"_id": id, "_index": self.idxnm}})
            acts.append(d["script"])
            logging.info("bulk upsert: %s" % id)
            es_logger.info("bulk upsert: %s" % id)

        res = []
        for _ in range(10):
@ -189,7 +177,7 @@ class HuEs:

                return res
            except Exception as e:
                logging.warning("Fail to bulk: " + str(e))
                es_logger.warning("Fail to bulk: " + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
@ -212,10 +200,10 @@ class HuEs:
                    id=d["id"],
                    refresh=True,
                    doc_type="_doc")
                logging.info("Remove %s" % d["id"])
                es_logger.info("Remove %s" % d["id"])
                return True
            except Exception as e:
                logging.warn("Fail to delete: " + str(d) + str(e))
                es_logger.warn("Fail to delete: " + str(d) + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
@ -223,7 +211,7 @@ class HuEs:
                return True
            self.conn()

        logging.error("Fail to delete: " + str(d))
        es_logger.error("Fail to delete: " + str(d))

        return False

@ -242,7 +230,7 @@ class HuEs:
                    raise Exception("Es Timeout.")
                return res
            except Exception as e:
                logging.error(
                es_logger.error(
                    "ES search exception: " +
                    str(e) +
                    "【Q】:" +
@ -250,7 +238,7 @@ class HuEs:
                if str(e).find("Timeout") > 0:
                    continue
                raise e
        logging.error("ES search timeout for 3 times!")
        es_logger.error("ES search timeout for 3 times!")
        raise Exception("ES search timeout.")

    def updateByQuery(self, q, d):
@ -267,7 +255,7 @@ class HuEs:
                r = ubq.execute()
                return True
            except Exception as e:
                logging.error("ES updateByQuery exception: " +
                es_logger.error("ES updateByQuery exception: " +
                                str(e) + "【Q】:" + str(q.to_dict()))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
@ -288,7 +276,7 @@ class HuEs:
                r = ubq.execute()
                return True
            except Exception as e:
                logging.error("ES updateByQuery exception: " +
                es_logger.error("ES updateByQuery exception: " +
                                str(e) + "【Q】:" + str(q.to_dict()))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
@ -304,7 +292,7 @@ class HuEs:
                    body=Search().query(query).to_dict())
                return True
            except Exception as e:
                logging.error("ES updateByQuery deleteByQuery: " +
                es_logger.error("ES updateByQuery deleteByQuery: " +
                                str(e) + "【Q】:" + str(query.to_dict()))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
@ -329,7 +317,8 @@ class HuEs:
                    routing=routing, refresh=False)  # , doc_type="_doc")
                return True
            except Exception as e:
                logging.error("ES update exception: " + str(e) + " id:" + str(id) + ", version:" + str(self.version()) +
                es_logger.error(
                    "ES update exception: " + str(e) + " id:" + str(id) + ", version:" + str(self.version()) +
                    json.dumps(script, ensure_ascii=False))
                if str(e).find("Timeout") > 0:
                    continue
@ -342,7 +331,7 @@ class HuEs:
        try:
            return s.exists()
        except Exception as e:
            logging.error("ES updateByQuery indexExist: " + str(e))
            es_logger.error("ES updateByQuery indexExist: " + str(e))
            if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                continue

@ -354,7 +343,7 @@ class HuEs:
            return self.es.exists(index=(idxnm if idxnm else self.idxnm),
                                  id=docid)
        except Exception as e:
            logging.error("ES Doc Exist: " + str(e))
            es_logger.error("ES Doc Exist: " + str(e))
            if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                continue
        return False
@ -368,13 +357,13 @@ class HuEs:
                settings=mapping["settings"],
                mappings=mapping["mappings"])
        except Exception as e:
            logging.error("ES create index error %s ----%s" % (idxnm, str(e)))
            es_logger.error("ES create index error %s ----%s" % (idxnm, str(e)))

    def deleteIdx(self, idxnm):
        try:
            return self.es.indices.delete(idxnm, allow_no_indices=True)
        except Exception as e:
            logging.error("ES delete index error %s ----%s" % (idxnm, str(e)))
            es_logger.error("ES delete index error %s ----%s" % (idxnm, str(e)))

    def getTotal(self, res):
        if isinstance(res["hits"]["total"], type({})):
@ -405,12 +394,12 @@ class HuEs:
                )
                break
            except Exception as e:
                logging.error("ES scrolling fail. " + str(e))
                es_logger.error("ES scrolling fail. " + str(e))
                time.sleep(3)

        sid = page['_scroll_id']
        scroll_size = page['hits']['total']["value"]
        logging.info("[TOTAL]%d" % scroll_size)
        es_logger.info("[TOTAL]%d" % scroll_size)
        # Start scrolling
        while scroll_size > 0:
            yield page["hits"]["hits"]
@ -419,10 +408,13 @@ class HuEs:
                    page = self.es.scroll(scroll_id=sid, scroll=scroll_time)
                    break
                except Exception as e:
                    logging.error("ES scrolling fail. " + str(e))
                    es_logger.error("ES scrolling fail. " + str(e))
                    time.sleep(3)

            # Update the scroll ID
            sid = page['_scroll_id']
            # Get the number of results that we returned in the last scroll
            scroll_size = len(page['hits']['hits'])


ELASTICSEARCH = HuEs()
@@ -1,13 +1,15 @@
-import logging
+import os
 import time
-from util import config
 from minio import Minio
 from io import BytesIO
+from rag import settings
+from rag.settings import minio_logger
+from rag.utils import singleton
 
 
+@singleton
 class HuMinio(object):
-    def __init__(self, env):
-        self.config = config.init(env)
+    def __init__(self):
         self.conn = None
         self.__open__()
 
@@ -19,15 +21,14 @@ class HuMinio(object):
             pass
 
         try:
-            self.conn = Minio(self.config.get("minio_host"),
-                              access_key=self.config.get("minio_user"),
-                              secret_key=self.config.get("minio_password"),
+            self.conn = Minio(settings.MINIO["host"],
+                              access_key=settings.MINIO["user"],
+                              secret_key=settings.MINIO["passwd"],
                               secure=False
                               )
         except Exception as e:
-            logging.error(
-                "Fail to connect %s " %
-                self.config.get("minio_host") + str(e))
+            minio_logger.error(
+                "Fail to connect %s " % settings.MINIO["host"] + str(e))
 
     def __close__(self):
         del self.conn
@@ -45,34 +46,51 @@ class HuMinio(object):
                 )
                 return r
             except Exception as e:
-                logging.error(f"Fail put {bucket}/{fnm}: " + str(e))
+                minio_logger.error(f"Fail put {bucket}/{fnm}: " + str(e))
                 self.__open__()
                 time.sleep(1)
 
+    def rm(self, bucket, fnm):
+        try:
+            self.conn.remove_object(bucket, fnm)
+        except Exception as e:
+            minio_logger.error(f"Fail rm {bucket}/{fnm}: " + str(e))
+
 
     def get(self, bucket, fnm):
         for _ in range(10):
             try:
                 r = self.conn.get_object(bucket, fnm)
                 return r.read()
             except Exception as e:
-                logging.error(f"fail get {bucket}/{fnm}: " + str(e))
+                minio_logger.error(f"fail get {bucket}/{fnm}: " + str(e))
                 self.__open__()
                 time.sleep(1)
         return
 
+    def obj_exist(self, bucket, fnm):
+        try:
+            if self.conn.stat_object(bucket, fnm):return True
+            return False
+        except Exception as e:
+            minio_logger.error(f"Fail put {bucket}/{fnm}: " + str(e))
+            return False
+
 
     def get_presigned_url(self, bucket, fnm, expires):
         for _ in range(10):
             try:
                 return self.conn.get_presigned_url("GET", bucket, fnm, expires)
             except Exception as e:
-                logging.error(f"fail get {bucket}/{fnm}: " + str(e))
+                minio_logger.error(f"fail get {bucket}/{fnm}: " + str(e))
                 self.__open__()
                 time.sleep(1)
         return
 
+
+MINIO = HuMinio()
 
 if __name__ == "__main__":
-    conn = HuMinio("infiniflow")
+    conn = HuMinio()
     fnm = "/opt/home/kevinhu/docgpt/upload/13/11-408.jpg"
     from PIL import Image
     img = Image.open(fnm)
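The MinIO wrapper gets the same treatment: configuration now comes from rag.settings.MINIO, errors go to minio_logger, the class is wrapped in the project's @singleton decorator, and a module-level MINIO = HuMinio() instance is exported together with the new rm and obj_exist helpers. A minimal usage sketch follows; it is hypothetical: the import path and the bucket/object names are assumptions, while the method names and signatures (rm, get, obj_exist, get_presigned_url) are taken from the diff above.

    # Hypothetical usage sketch, not part of this commit.
    from datetime import timedelta
    from rag.utils.minio_conn import MINIO            # assumed module path

    bucket, fnm = "ragflow-demo", "docs/sample.pdf"   # placeholder bucket/object
    if MINIO.obj_exist(bucket, fnm):                  # stat_object-based existence check
        blob = MINIO.get(bucket, fnm)                 # raw bytes, with the wrapper's retry loop
        url = MINIO.get_presigned_url(bucket, fnm, timedelta(hours=1))
        print(len(blob), url)
        MINIO.rm(bucket, fnm)                         # best-effort delete; failures are only logged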
@ -1,178 +0,0 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use actix_web::{ HttpResponse, post, web };
|
|
||||||
use serde::Deserialize;
|
|
||||||
use serde_json::Value;
|
|
||||||
use serde_json::json;
|
|
||||||
use crate::api::JsonResponse;
|
|
||||||
use crate::AppState;
|
|
||||||
use crate::errors::AppError;
|
|
||||||
use crate::service::dialog_info::Query;
|
|
||||||
use crate::service::dialog_info::Mutation;
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
pub struct ListParams {
|
|
||||||
pub uid: i64,
|
|
||||||
pub dialog_id: Option<i64>,
|
|
||||||
}
|
|
||||||
#[post("/v1.0/dialogs")]
|
|
||||||
async fn list(
|
|
||||||
params: web::Json<ListParams>,
|
|
||||||
data: web::Data<AppState>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
let mut result = HashMap::new();
|
|
||||||
if let Some(dia_id) = params.dialog_id {
|
|
||||||
let dia = Query::find_dialog_info_by_id(&data.conn, dia_id).await?.unwrap();
|
|
||||||
let kb = crate::service::kb_info::Query
|
|
||||||
::find_kb_info_by_id(&data.conn, dia.kb_id).await?
|
|
||||||
.unwrap();
|
|
||||||
print!("{:?}", dia.history);
|
|
||||||
let hist: Value = serde_json::from_str(&dia.history)?;
|
|
||||||
let detail =
|
|
||||||
json!({
|
|
||||||
"dialog_id": dia_id,
|
|
||||||
"dialog_name": dia.dialog_name.to_owned(),
|
|
||||||
"created_at": dia.created_at.to_string().to_owned(),
|
|
||||||
"updated_at": dia.updated_at.to_string().to_owned(),
|
|
||||||
"history": hist,
|
|
||||||
"kb_info": kb
|
|
||||||
});
|
|
||||||
|
|
||||||
result.insert("dialogs", vec![detail]);
|
|
||||||
} else {
|
|
||||||
let mut dias = Vec::<Value>::new();
|
|
||||||
for dia in Query::find_dialog_infos_by_uid(&data.conn, params.uid).await? {
|
|
||||||
let kb = crate::service::kb_info::Query
|
|
||||||
::find_kb_info_by_id(&data.conn, dia.kb_id).await?
|
|
||||||
.unwrap();
|
|
||||||
let hist: Value = serde_json::from_str(&dia.history)?;
|
|
||||||
dias.push(
|
|
||||||
json!({
|
|
||||||
"dialog_id": dia.dialog_id,
|
|
||||||
"dialog_name": dia.dialog_name.to_owned(),
|
|
||||||
"created_at": dia.created_at.to_string().to_owned(),
|
|
||||||
"updated_at": dia.updated_at.to_string().to_owned(),
|
|
||||||
"history": hist,
|
|
||||||
"kb_info": kb
|
|
||||||
})
|
|
||||||
);
|
|
||||||
}
|
|
||||||
result.insert("dialogs", dias);
|
|
||||||
}
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 200,
|
|
||||||
err: "".to_owned(),
|
|
||||||
data: result,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
pub struct RmParams {
|
|
||||||
pub uid: i64,
|
|
||||||
pub dialog_id: i64,
|
|
||||||
}
|
|
||||||
#[post("/v1.0/delete_dialog")]
|
|
||||||
async fn delete(
|
|
||||||
params: web::Json<RmParams>,
|
|
||||||
data: web::Data<AppState>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
let _ = Mutation::delete_dialog_info(&data.conn, params.dialog_id).await?;
|
|
||||||
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 200,
|
|
||||||
err: "".to_owned(),
|
|
||||||
data: (),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
pub struct CreateParams {
|
|
||||||
pub uid: i64,
|
|
||||||
pub dialog_id: Option<i64>,
|
|
||||||
pub kb_id: i64,
|
|
||||||
pub name: String,
|
|
||||||
}
|
|
||||||
#[post("/v1.0/create_dialog")]
|
|
||||||
async fn create(
|
|
||||||
param: web::Json<CreateParams>,
|
|
||||||
data: web::Data<AppState>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
let mut result = HashMap::new();
|
|
||||||
if let Some(dia_id) = param.dialog_id {
|
|
||||||
result.insert("dialog_id", dia_id);
|
|
||||||
let dia = Query::find_dialog_info_by_id(&data.conn, dia_id).await?;
|
|
||||||
let _ = Mutation::update_dialog_info_by_id(
|
|
||||||
&data.conn,
|
|
||||||
dia_id,
|
|
||||||
¶m.name,
|
|
||||||
&dia.unwrap().history
|
|
||||||
).await?;
|
|
||||||
} else {
|
|
||||||
let dia = Mutation::create_dialog_info(
|
|
||||||
&data.conn,
|
|
||||||
param.uid,
|
|
||||||
param.kb_id,
|
|
||||||
¶m.name
|
|
||||||
).await?;
|
|
||||||
result.insert("dialog_id", dia.dialog_id.unwrap());
|
|
||||||
}
|
|
||||||
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 200,
|
|
||||||
err: "".to_owned(),
|
|
||||||
data: result,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
pub struct UpdateHistoryParams {
|
|
||||||
pub uid: i64,
|
|
||||||
pub dialog_id: i64,
|
|
||||||
pub history: Value,
|
|
||||||
}
|
|
||||||
#[post("/v1.0/update_history")]
|
|
||||||
async fn update_history(
|
|
||||||
param: web::Json<UpdateHistoryParams>,
|
|
||||||
data: web::Data<AppState>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
let mut json_response = JsonResponse {
|
|
||||||
code: 200,
|
|
||||||
err: "".to_owned(),
|
|
||||||
data: (),
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(dia) = Query::find_dialog_info_by_id(&data.conn, param.dialog_id).await? {
|
|
||||||
let _ = Mutation::update_dialog_info_by_id(
|
|
||||||
&data.conn,
|
|
||||||
param.dialog_id,
|
|
||||||
&dia.dialog_name,
|
|
||||||
¶m.history.to_string()
|
|
||||||
).await?;
|
|
||||||
} else {
|
|
||||||
json_response.code = 500;
|
|
||||||
json_response.err = "Can't find dialog data!".to_owned();
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
)
|
|
||||||
}
|
|
@ -1,314 +0,0 @@
|
|||||||
use std::collections::{HashMap};
|
|
||||||
use std::io::BufReader;
|
|
||||||
use actix_multipart_extract::{ File, Multipart, MultipartForm };
|
|
||||||
use actix_web::web::Bytes;
|
|
||||||
use actix_web::{ HttpResponse, post, web };
|
|
||||||
use chrono::{ Utc, FixedOffset };
|
|
||||||
use minio::s3::args::{ BucketExistsArgs, MakeBucketArgs, PutObjectArgs };
|
|
||||||
use sea_orm::DbConn;
|
|
||||||
use crate::api::JsonResponse;
|
|
||||||
use crate::AppState;
|
|
||||||
use crate::entity::doc_info::Model;
|
|
||||||
use crate::errors::AppError;
|
|
||||||
use crate::service::doc_info::{ Mutation, Query };
|
|
||||||
use serde::Deserialize;
|
|
||||||
use regex::Regex;
|
|
||||||
|
|
||||||
fn now() -> chrono::DateTime<FixedOffset> {
|
|
||||||
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
pub struct ListParams {
|
|
||||||
pub uid: i64,
|
|
||||||
pub filter: FilterParams,
|
|
||||||
pub sortby: String,
|
|
||||||
pub page: Option<u32>,
|
|
||||||
pub per_page: Option<u32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
pub struct FilterParams {
|
|
||||||
pub keywords: Option<String>,
|
|
||||||
pub folder_id: Option<i64>,
|
|
||||||
pub tag_id: Option<i64>,
|
|
||||||
pub kb_id: Option<i64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/v1.0/docs")]
|
|
||||||
async fn list(
|
|
||||||
params: web::Json<ListParams>,
|
|
||||||
data: web::Data<AppState>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
let docs = Query::find_doc_infos_by_params(&data.conn, params.into_inner()).await?;
|
|
||||||
|
|
||||||
let mut result = HashMap::new();
|
|
||||||
result.insert("docs", docs);
|
|
||||||
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 200,
|
|
||||||
err: "".to_owned(),
|
|
||||||
data: result,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, MultipartForm, Debug)]
|
|
||||||
pub struct UploadForm {
|
|
||||||
#[multipart(max_size = 512MB)]
|
|
||||||
file_field: File,
|
|
||||||
uid: i64,
|
|
||||||
did: i64,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn file_type(filename: &String) -> String {
|
|
||||||
let fnm = filename.to_lowercase();
|
|
||||||
if
|
|
||||||
let Some(_) = Regex::new(r"\.(mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$")
|
|
||||||
.unwrap()
|
|
||||||
.captures(&fnm)
|
|
||||||
{
|
|
||||||
return "Video".to_owned();
|
|
||||||
}
|
|
||||||
if
|
|
||||||
let Some(_) = Regex::new(
|
|
||||||
r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$"
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
.captures(&fnm)
|
|
||||||
{
|
|
||||||
return "Picture".to_owned();
|
|
||||||
}
|
|
||||||
if
|
|
||||||
let Some(_) = Regex::new(r"\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$")
|
|
||||||
.unwrap()
|
|
||||||
.captures(&fnm)
|
|
||||||
{
|
|
||||||
return "Music".to_owned();
|
|
||||||
}
|
|
||||||
if
|
|
||||||
let Some(_) = Regex::new(r"\.(pdf|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)$")
|
|
||||||
.unwrap()
|
|
||||||
.captures(&fnm)
|
|
||||||
{
|
|
||||||
return "Document".to_owned();
|
|
||||||
}
|
|
||||||
"Other".to_owned()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
#[post("/v1.0/upload")]
|
|
||||||
async fn upload(
|
|
||||||
payload: Multipart<UploadForm>,
|
|
||||||
data: web::Data<AppState>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
|
|
||||||
|
|
||||||
Ok(HttpResponse::Ok().body("File uploaded successfully"))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) async fn _upload_file(uid: i64, did: i64, file_name: &str, bytes: &[u8], data: &web::Data<AppState>) -> Result<(), AppError> {
|
|
||||||
async fn add_number_to_filename(
|
|
||||||
file_name: &str,
|
|
||||||
conn: &DbConn,
|
|
||||||
uid: i64,
|
|
||||||
parent_id: i64
|
|
||||||
) -> String {
|
|
||||||
let mut i = 0;
|
|
||||||
let mut new_file_name = file_name.to_string();
|
|
||||||
let arr: Vec<&str> = file_name.split(".").collect();
|
|
||||||
let suffix = String::from(arr[arr.len() - 1]);
|
|
||||||
let preffix = arr[..arr.len() - 1].join(".");
|
|
||||||
let mut docs = Query::find_doc_infos_by_name(
|
|
||||||
conn,
|
|
||||||
uid,
|
|
||||||
&new_file_name,
|
|
||||||
Some(parent_id)
|
|
||||||
).await.unwrap();
|
|
||||||
while docs.len() > 0 {
|
|
||||||
i += 1;
|
|
||||||
new_file_name = format!("{}_{}.{}", preffix, i, suffix);
|
|
||||||
docs = Query::find_doc_infos_by_name(
|
|
||||||
conn,
|
|
||||||
uid,
|
|
||||||
&new_file_name,
|
|
||||||
Some(parent_id)
|
|
||||||
).await.unwrap();
|
|
||||||
}
|
|
||||||
new_file_name
|
|
||||||
}
|
|
||||||
let fnm = add_number_to_filename(file_name, &data.conn, uid, did).await;
|
|
||||||
|
|
||||||
let bucket_name = format!("{}-upload", uid);
|
|
||||||
let s3_client: &minio::s3::client::Client = &data.s3_client;
|
|
||||||
let buckets_exists = s3_client
|
|
||||||
.bucket_exists(&BucketExistsArgs::new(&bucket_name).unwrap()).await
|
|
||||||
.unwrap();
|
|
||||||
if !buckets_exists {
|
|
||||||
print!("Create bucket: {}", bucket_name.clone());
|
|
||||||
s3_client.make_bucket(&MakeBucketArgs::new(&bucket_name).unwrap()).await.unwrap();
|
|
||||||
} else {
|
|
||||||
print!("Existing bucket: {}", bucket_name.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
let location = format!("/{}/{}", did, fnm)
|
|
||||||
.as_bytes()
|
|
||||||
.to_vec()
|
|
||||||
.iter()
|
|
||||||
.map(|b| format!("{:02x}", b).to_string())
|
|
||||||
.collect::<Vec<String>>()
|
|
||||||
.join("");
|
|
||||||
print!("===>{}", location.clone());
|
|
||||||
s3_client.put_object(
|
|
||||||
&mut PutObjectArgs::new(
|
|
||||||
&bucket_name,
|
|
||||||
&location,
|
|
||||||
&mut BufReader::new(bytes),
|
|
||||||
Some(bytes.len()),
|
|
||||||
None
|
|
||||||
)?
|
|
||||||
).await?;
|
|
||||||
|
|
||||||
let doc = Mutation::create_doc_info(&data.conn, Model {
|
|
||||||
did: Default::default(),
|
|
||||||
uid: uid,
|
|
||||||
doc_name: fnm.clone(),
|
|
||||||
size: bytes.len() as i64,
|
|
||||||
location,
|
|
||||||
r#type: file_type(&fnm),
|
|
||||||
thumbnail_base64: Default::default(),
|
|
||||||
created_at: now(),
|
|
||||||
updated_at: now(),
|
|
||||||
is_deleted: Default::default(),
|
|
||||||
}).await?;
|
|
||||||
|
|
||||||
let _ = Mutation::place_doc(&data.conn, did, doc.did.unwrap()).await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Debug)]
|
|
||||||
pub struct RmDocsParam {
|
|
||||||
uid: i64,
|
|
||||||
dids: Vec<i64>,
|
|
||||||
}
|
|
||||||
#[post("/v1.0/delete_docs")]
|
|
||||||
async fn delete(
|
|
||||||
params: web::Json<RmDocsParam>,
|
|
||||||
data: web::Data<AppState>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
let _ = Mutation::delete_doc_info(&data.conn, ¶ms.dids).await?;
|
|
||||||
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 200,
|
|
||||||
err: "".to_owned(),
|
|
||||||
data: (),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
pub struct MvParams {
|
|
||||||
pub uid: i64,
|
|
||||||
pub dids: Vec<i64>,
|
|
||||||
pub dest_did: i64,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/v1.0/mv_docs")]
|
|
||||||
async fn mv(
|
|
||||||
params: web::Json<MvParams>,
|
|
||||||
data: web::Data<AppState>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
Mutation::mv_doc_info(&data.conn, params.dest_did, ¶ms.dids).await?;
|
|
||||||
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 200,
|
|
||||||
err: "".to_owned(),
|
|
||||||
data: (),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
pub struct NewFoldParams {
|
|
||||||
pub uid: i64,
|
|
||||||
pub parent_id: i64,
|
|
||||||
pub name: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/v1.0/new_folder")]
|
|
||||||
async fn new_folder(
|
|
||||||
params: web::Json<NewFoldParams>,
|
|
||||||
data: web::Data<AppState>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
let doc = Mutation::create_doc_info(&data.conn, Model {
|
|
||||||
did: Default::default(),
|
|
||||||
uid: params.uid,
|
|
||||||
doc_name: params.name.to_string(),
|
|
||||||
size: 0,
|
|
||||||
r#type: "folder".to_string(),
|
|
||||||
location: "".to_owned(),
|
|
||||||
thumbnail_base64: Default::default(),
|
|
||||||
created_at: now(),
|
|
||||||
updated_at: now(),
|
|
||||||
is_deleted: Default::default(),
|
|
||||||
}).await?;
|
|
||||||
let _ = Mutation::place_doc(&data.conn, params.parent_id, doc.did.unwrap()).await?;
|
|
||||||
|
|
||||||
Ok(HttpResponse::Ok().body("Folder created successfully"))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
pub struct RenameParams {
|
|
||||||
pub uid: i64,
|
|
||||||
pub did: i64,
|
|
||||||
pub name: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/v1.0/rename")]
|
|
||||||
async fn rename(
|
|
||||||
params: web::Json<RenameParams>,
|
|
||||||
data: web::Data<AppState>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
let docs = Query::find_doc_infos_by_name(&data.conn, params.uid, ¶ms.name, None).await?;
|
|
||||||
if docs.len() > 0 {
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 500,
|
|
||||||
err: "Name duplicated!".to_owned(),
|
|
||||||
data: (),
|
|
||||||
};
|
|
||||||
return Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
let doc = Mutation::rename(&data.conn, params.did, ¶ms.name).await?;
|
|
||||||
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 200,
|
|
||||||
err: "".to_owned(),
|
|
||||||
data: doc,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
)
|
|
||||||
}
|
|
@ -1,166 +0,0 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use actix_web::{ get, HttpResponse, post, web };
|
|
||||||
use serde::Serialize;
|
|
||||||
use crate::api::JsonResponse;
|
|
||||||
use crate::AppState;
|
|
||||||
use crate::entity::kb_info;
|
|
||||||
use crate::errors::AppError;
|
|
||||||
use crate::service::kb_info::Mutation;
|
|
||||||
use crate::service::kb_info::Query;
|
|
||||||
use serde::Deserialize;
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
||||||
pub struct AddDocs2KbParams {
|
|
||||||
pub uid: i64,
|
|
||||||
pub dids: Vec<i64>,
|
|
||||||
pub kb_id: i64,
|
|
||||||
}
|
|
||||||
#[post("/v1.0/create_kb")]
|
|
||||||
async fn create(
|
|
||||||
model: web::Json<kb_info::Model>,
|
|
||||||
data: web::Data<AppState>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
let mut docs = Query::find_kb_infos_by_name(
|
|
||||||
&data.conn,
|
|
||||||
model.kb_name.to_owned()
|
|
||||||
).await.unwrap();
|
|
||||||
if docs.len() > 0 {
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 201,
|
|
||||||
err: "Duplicated name.".to_owned(),
|
|
||||||
data: (),
|
|
||||||
};
|
|
||||||
Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
let model = Mutation::create_kb_info(&data.conn, model.into_inner()).await?;
|
|
||||||
|
|
||||||
let mut result = HashMap::new();
|
|
||||||
result.insert("kb_id", model.kb_id.unwrap());
|
|
||||||
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 200,
|
|
||||||
err: "".to_owned(),
|
|
||||||
data: result,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/v1.0/add_docs_to_kb")]
|
|
||||||
async fn add_docs_to_kb(
|
|
||||||
param: web::Json<AddDocs2KbParams>,
|
|
||||||
data: web::Data<AppState>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
let _ = Mutation::add_docs(&data.conn, param.kb_id, param.dids.to_owned()).await?;
|
|
||||||
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 200,
|
|
||||||
err: "".to_owned(),
|
|
||||||
data: (),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/v1.0/anti_kb_docs")]
|
|
||||||
async fn anti_kb_docs(
|
|
||||||
param: web::Json<AddDocs2KbParams>,
|
|
||||||
data: web::Data<AppState>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
let _ = Mutation::remove_docs(&data.conn, param.dids.to_owned(), Some(param.kb_id)).await?;
|
|
||||||
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 200,
|
|
||||||
err: "".to_owned(),
|
|
||||||
data: (),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
#[get("/v1.0/kbs")]
|
|
||||||
async fn list(
|
|
||||||
model: web::Json<kb_info::Model>,
|
|
||||||
data: web::Data<AppState>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
let kbs = Query::find_kb_infos_by_uid(&data.conn, model.uid).await?;
|
|
||||||
|
|
||||||
let mut result = HashMap::new();
|
|
||||||
result.insert("kbs", kbs);
|
|
||||||
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 200,
|
|
||||||
err: "".to_owned(),
|
|
||||||
data: result,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/v1.0/delete_kb")]
|
|
||||||
async fn delete(
|
|
||||||
model: web::Json<kb_info::Model>,
|
|
||||||
data: web::Data<AppState>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
let _ = Mutation::delete_kb_info(&data.conn, model.kb_id).await?;
|
|
||||||
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 200,
|
|
||||||
err: "".to_owned(),
|
|
||||||
data: (),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
||||||
pub struct DocIdsParams {
|
|
||||||
pub uid: i64,
|
|
||||||
pub dids: Vec<i64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/v1.0/all_relevents")]
|
|
||||||
async fn all_relevents(
|
|
||||||
params: web::Json<DocIdsParams>,
|
|
||||||
data: web::Data<AppState>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
let dids = crate::service::doc_info::Query::all_descendent_ids(&data.conn, ¶ms.dids).await?;
|
|
||||||
let mut result = HashMap::new();
|
|
||||||
let kbs = Query::find_kb_by_docs(&data.conn, dids).await?;
|
|
||||||
result.insert("kbs", kbs);
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 200,
|
|
||||||
err: "".to_owned(),
|
|
||||||
data: result,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
)
|
|
||||||
}
|
|
@@ -1,14 +0,0 @@
-use serde::{ Deserialize, Serialize };
-
-pub(crate) mod tag_info;
-pub(crate) mod kb_info;
-pub(crate) mod dialog_info;
-pub(crate) mod doc_info;
-pub(crate) mod user_info;
-
-#[derive(Debug, Deserialize, Serialize)]
-struct JsonResponse<T> {
-    code: u32,
-    err: String,
-    data: T,
-}
@ -1,81 +0,0 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use actix_web::{ get, HttpResponse, post, web };
|
|
||||||
use serde::Deserialize;
|
|
||||||
use crate::api::JsonResponse;
|
|
||||||
use crate::AppState;
|
|
||||||
use crate::entity::tag_info;
|
|
||||||
use crate::errors::AppError;
|
|
||||||
use crate::service::tag_info::{ Mutation, Query };
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
pub struct TagListParams {
|
|
||||||
pub uid: i64,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/v1.0/create_tag")]
|
|
||||||
async fn create(
|
|
||||||
model: web::Json<tag_info::Model>,
|
|
||||||
data: web::Data<AppState>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
let model = Mutation::create_tag(&data.conn, model.into_inner()).await?;
|
|
||||||
|
|
||||||
let mut result = HashMap::new();
|
|
||||||
result.insert("tid", model.tid.unwrap());
|
|
||||||
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 200,
|
|
||||||
err: "".to_owned(),
|
|
||||||
data: result,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/v1.0/delete_tag")]
|
|
||||||
async fn delete(
|
|
||||||
model: web::Json<tag_info::Model>,
|
|
||||||
data: web::Data<AppState>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
let _ = Mutation::delete_tag(&data.conn, model.tid).await?;
|
|
||||||
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 200,
|
|
||||||
err: "".to_owned(),
|
|
||||||
data: (),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
//#[get("/v1.0/tags", wrap = "HttpAuthentication::bearer(validator)")]
|
|
||||||
|
|
||||||
#[post("/v1.0/tags")]
|
|
||||||
async fn list(
|
|
||||||
param: web::Json<TagListParams>,
|
|
||||||
data: web::Data<AppState>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
let tags = Query::find_tags_by_uid(param.uid, &data.conn).await?;
|
|
||||||
|
|
||||||
let mut result = HashMap::new();
|
|
||||||
result.insert("tags", tags);
|
|
||||||
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 200,
|
|
||||||
err: "".to_owned(),
|
|
||||||
data: result,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
)
|
|
||||||
}
|
|
@ -1,149 +0,0 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use std::io::SeekFrom;
|
|
||||||
use std::ptr::null;
|
|
||||||
use actix_identity::Identity;
|
|
||||||
use actix_web::{ HttpResponse, post, web };
|
|
||||||
use chrono::{ FixedOffset, Utc };
|
|
||||||
use sea_orm::ActiveValue::NotSet;
|
|
||||||
use serde::{ Deserialize, Serialize };
|
|
||||||
use crate::api::JsonResponse;
|
|
||||||
use crate::AppState;
|
|
||||||
use crate::entity::{ doc_info, tag_info };
|
|
||||||
use crate::entity::user_info::Model;
|
|
||||||
use crate::errors::{ AppError, UserError };
|
|
||||||
use crate::service::user_info::Mutation;
|
|
||||||
use crate::service::user_info::Query;
|
|
||||||
|
|
||||||
fn now() -> chrono::DateTime<FixedOffset> {
|
|
||||||
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn create_auth_token(user: &Model) -> u64 {
|
|
||||||
use std::{ collections::hash_map::DefaultHasher, hash::{ Hash, Hasher } };
|
|
||||||
|
|
||||||
let mut hasher = DefaultHasher::new();
|
|
||||||
user.hash(&mut hasher);
|
|
||||||
hasher.finish()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
||||||
pub(crate) struct LoginParams {
|
|
||||||
pub(crate) email: String,
|
|
||||||
pub(crate) password: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/v1.0/login")]
|
|
||||||
async fn login(
|
|
||||||
data: web::Data<AppState>,
|
|
||||||
identity: Identity,
|
|
||||||
input: web::Json<LoginParams>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
match Query::login(&data.conn, &input.email, &input.password).await? {
|
|
||||||
Some(user) => {
|
|
||||||
let _ = Mutation::update_login_status(user.uid, &data.conn).await?;
|
|
||||||
let token = create_auth_token(&user).to_string();
|
|
||||||
|
|
||||||
identity.remember(token.clone());
|
|
||||||
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 200,
|
|
||||||
err: "".to_owned(),
|
|
||||||
data: token.clone(),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.append_header(("X-Auth-Token", token))
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
None => Err(UserError::LoginFailed.into()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/v1.0/register")]
|
|
||||||
async fn register(
|
|
||||||
model: web::Json<Model>,
|
|
||||||
data: web::Data<AppState>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
let mut result = HashMap::new();
|
|
||||||
let u = Query::find_user_infos(&data.conn, &model.email).await?;
|
|
||||||
if let Some(_) = u {
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 500,
|
|
||||||
err: "Email registered!".to_owned(),
|
|
||||||
data: (),
|
|
||||||
};
|
|
||||||
return Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let usr = Mutation::create_user(&data.conn, &model).await?;
|
|
||||||
result.insert("uid", usr.uid.clone().unwrap());
|
|
||||||
crate::service::doc_info::Mutation::create_doc_info(&data.conn, doc_info::Model {
|
|
||||||
did: Default::default(),
|
|
||||||
uid: usr.uid.clone().unwrap(),
|
|
||||||
doc_name: "/".into(),
|
|
||||||
size: 0,
|
|
||||||
location: "".into(),
|
|
||||||
thumbnail_base64: "".into(),
|
|
||||||
r#type: "folder".to_string(),
|
|
||||||
created_at: now(),
|
|
||||||
updated_at: now(),
|
|
||||||
is_deleted: Default::default(),
|
|
||||||
}).await?;
|
|
||||||
let tnm = vec!["Video", "Picture", "Music", "Document"];
|
|
||||||
let tregx = vec![
|
|
||||||
".*\\.(mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa)",
|
|
||||||
".*\\.(png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng)",
|
|
||||||
".*\\.(WAV|FLAC|APE|ALAC|WavPack|WV|MP3|AAC|Ogg|Vorbis|Opus)",
|
|
||||||
".*\\.(pdf|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp)"
|
|
||||||
];
|
|
||||||
for i in 0..4 {
|
|
||||||
crate::service::tag_info::Mutation::create_tag(&data.conn, tag_info::Model {
|
|
||||||
tid: Default::default(),
|
|
||||||
uid: usr.uid.clone().unwrap(),
|
|
||||||
tag_name: tnm[i].to_owned(),
|
|
||||||
regx: tregx[i].to_owned(),
|
|
||||||
color: (i + 1).to_owned() as i16,
|
|
||||||
icon: (i + 1).to_owned() as i16,
|
|
||||||
folder_id: 0,
|
|
||||||
created_at: Default::default(),
|
|
||||||
updated_at: Default::default(),
|
|
||||||
}).await?;
|
|
||||||
}
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 200,
|
|
||||||
err: "".to_owned(),
|
|
||||||
data: result,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/v1.0/setting")]
|
|
||||||
async fn setting(
|
|
||||||
model: web::Json<Model>,
|
|
||||||
data: web::Data<AppState>
|
|
||||||
) -> Result<HttpResponse, AppError> {
|
|
||||||
let _ = Mutation::update_user_by_id(&data.conn, &model).await?;
|
|
||||||
let json_response = JsonResponse {
|
|
||||||
code: 200,
|
|
||||||
err: "".to_owned(),
|
|
||||||
data: (),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(
|
|
||||||
HttpResponse::Ok()
|
|
||||||
.content_type("application/json")
|
|
||||||
.body(serde_json::to_string(&json_response)?)
|
|
||||||
)
|
|
||||||
}
|
|
@ -1,38 +0,0 @@
|
|||||||
use sea_orm::entity::prelude::*;
|
|
||||||
use serde::{ Deserialize, Serialize };
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)]
|
|
||||||
#[sea_orm(table_name = "dialog2_kb")]
|
|
||||||
pub struct Model {
|
|
||||||
#[sea_orm(primary_key, auto_increment = true)]
|
|
||||||
pub id: i64,
|
|
||||||
#[sea_orm(index)]
|
|
||||||
pub dialog_id: i64,
|
|
||||||
#[sea_orm(index)]
|
|
||||||
pub kb_id: i64,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, EnumIter)]
|
|
||||||
pub enum Relation {
|
|
||||||
DialogInfo,
|
|
||||||
KbInfo,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RelationTrait for Relation {
|
|
||||||
fn def(&self) -> RelationDef {
|
|
||||||
match self {
|
|
||||||
Self::DialogInfo =>
|
|
||||||
Entity::belongs_to(super::dialog_info::Entity)
|
|
||||||
.from(Column::DialogId)
|
|
||||||
.to(super::dialog_info::Column::DialogId)
|
|
||||||
.into(),
|
|
||||||
Self::KbInfo =>
|
|
||||||
Entity::belongs_to(super::kb_info::Entity)
|
|
||||||
.from(Column::KbId)
|
|
||||||
.to(super::kb_info::Column::KbId)
|
|
||||||
.into(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ActiveModelBehavior for ActiveModel {}
|
|
@ -1,38 +0,0 @@
|
|||||||
use chrono::{ DateTime, FixedOffset };
|
|
||||||
use sea_orm::entity::prelude::*;
|
|
||||||
use serde::{ Deserialize, Serialize };
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)]
|
|
||||||
#[sea_orm(table_name = "dialog_info")]
|
|
||||||
pub struct Model {
|
|
||||||
#[sea_orm(primary_key, auto_increment = false)]
|
|
||||||
pub dialog_id: i64,
|
|
||||||
#[sea_orm(index)]
|
|
||||||
pub uid: i64,
|
|
||||||
#[serde(skip_deserializing)]
|
|
||||||
pub kb_id: i64,
|
|
||||||
pub dialog_name: String,
|
|
||||||
pub history: String,
|
|
||||||
|
|
||||||
#[serde(skip_deserializing)]
|
|
||||||
pub created_at: DateTime<FixedOffset>,
|
|
||||||
#[serde(skip_deserializing)]
|
|
||||||
pub updated_at: DateTime<FixedOffset>,
|
|
||||||
#[serde(skip_deserializing)]
|
|
||||||
pub is_deleted: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
|
||||||
pub enum Relation {}
|
|
||||||
|
|
||||||
impl Related<super::kb_info::Entity> for Entity {
|
|
||||||
fn to() -> RelationDef {
|
|
||||||
super::dialog2_kb::Relation::KbInfo.def()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn via() -> Option<RelationDef> {
|
|
||||||
Some(super::dialog2_kb::Relation::DialogInfo.def().rev())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ActiveModelBehavior for ActiveModel {}
|
|
@ -1,38 +0,0 @@
|
|||||||
use sea_orm::entity::prelude::*;
|
|
||||||
use serde::{ Deserialize, Serialize };
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)]
|
|
||||||
#[sea_orm(table_name = "doc2_doc")]
|
|
||||||
pub struct Model {
|
|
||||||
#[sea_orm(primary_key, auto_increment = true)]
|
|
||||||
pub id: i64,
|
|
||||||
#[sea_orm(index)]
|
|
||||||
pub parent_id: i64,
|
|
||||||
#[sea_orm(index)]
|
|
||||||
pub did: i64,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, EnumIter)]
|
|
||||||
pub enum Relation {
|
|
||||||
Parent,
|
|
||||||
Child,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RelationTrait for Relation {
|
|
||||||
fn def(&self) -> RelationDef {
|
|
||||||
match self {
|
|
||||||
Self::Parent =>
|
|
||||||
Entity::belongs_to(super::doc_info::Entity)
|
|
||||||
.from(Column::ParentId)
|
|
||||||
.to(super::doc_info::Column::Did)
|
|
||||||
.into(),
|
|
||||||
Self::Child =>
|
|
||||||
Entity::belongs_to(super::doc_info::Entity)
|
|
||||||
.from(Column::Did)
|
|
||||||
.to(super::doc_info::Column::Did)
|
|
||||||
.into(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ActiveModelBehavior for ActiveModel {}
|
|
@ -1,62 +0,0 @@
|
|||||||
use sea_orm::entity::prelude::*;
|
|
||||||
use serde::{ Deserialize, Serialize };
|
|
||||||
use crate::entity::kb_info;
|
|
||||||
use chrono::{ DateTime, FixedOffset };
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Deserialize, Serialize)]
|
|
||||||
#[sea_orm(table_name = "doc_info")]
|
|
||||||
pub struct Model {
|
|
||||||
#[sea_orm(primary_key, auto_increment = false)]
|
|
||||||
pub did: i64,
|
|
||||||
#[sea_orm(index)]
|
|
||||||
pub uid: i64,
|
|
||||||
pub doc_name: String,
|
|
||||||
pub size: i64,
|
|
||||||
#[sea_orm(column_name = "type")]
|
|
||||||
pub r#type: String,
|
|
||||||
#[serde(skip_deserializing)]
|
|
||||||
pub location: String,
|
|
||||||
#[serde(skip_deserializing)]
|
|
||||||
pub thumbnail_base64: String,
|
|
||||||
#[serde(skip_deserializing)]
|
|
||||||
pub created_at: DateTime<FixedOffset>,
|
|
||||||
#[serde(skip_deserializing)]
|
|
||||||
pub updated_at: DateTime<FixedOffset>,
|
|
||||||
#[serde(skip_deserializing)]
|
|
||||||
pub is_deleted: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
|
||||||
pub enum Relation {}
|
|
||||||
|
|
||||||
impl Related<super::tag_info::Entity> for Entity {
|
|
||||||
fn to() -> RelationDef {
|
|
||||||
super::tag2_doc::Relation::Tag.def()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn via() -> Option<RelationDef> {
|
|
||||||
Some(super::tag2_doc::Relation::DocInfo.def().rev())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Related<super::kb_info::Entity> for Entity {
|
|
||||||
fn to() -> RelationDef {
|
|
||||||
super::kb2_doc::Relation::KbInfo.def()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn via() -> Option<RelationDef> {
|
|
||||||
Some(super::kb2_doc::Relation::DocInfo.def().rev())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Related<super::doc2_doc::Entity> for Entity {
|
|
||||||
fn to() -> RelationDef {
|
|
||||||
super::doc2_doc::Relation::Parent.def()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn via() -> Option<RelationDef> {
|
|
||||||
Some(super::doc2_doc::Relation::Child.def().rev())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ActiveModelBehavior for ActiveModel {}
|
|
@ -1,47 +0,0 @@
|
|||||||
use sea_orm::entity::prelude::*;
|
|
||||||
use serde::{ Deserialize, Serialize };
|
|
||||||
use chrono::{ DateTime, FixedOffset };
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Deserialize, Serialize)]
|
|
||||||
#[sea_orm(table_name = "kb2_doc")]
|
|
||||||
pub struct Model {
|
|
||||||
#[sea_orm(primary_key, auto_increment = true)]
|
|
||||||
pub id: i64,
|
|
||||||
#[sea_orm(index)]
|
|
||||||
pub kb_id: i64,
|
|
||||||
#[sea_orm(index)]
|
|
||||||
pub did: i64,
|
|
||||||
#[serde(skip_deserializing)]
|
|
||||||
pub kb_progress: f32,
|
|
||||||
#[serde(skip_deserializing)]
|
|
||||||
pub kb_progress_msg: String,
|
|
||||||
#[serde(skip_deserializing)]
|
|
||||||
pub updated_at: DateTime<FixedOffset>,
|
|
||||||
#[serde(skip_deserializing)]
|
|
||||||
pub is_deleted: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, EnumIter)]
|
|
||||||
pub enum Relation {
|
|
||||||
DocInfo,
|
|
||||||
KbInfo,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RelationTrait for Relation {
|
|
||||||
fn def(&self) -> RelationDef {
|
|
||||||
match self {
|
|
||||||
Self::DocInfo =>
|
|
||||||
Entity::belongs_to(super::doc_info::Entity)
|
|
||||||
.from(Column::Did)
|
|
||||||
.to(super::doc_info::Column::Did)
|
|
||||||
.into(),
|
|
||||||
Self::KbInfo =>
|
|
||||||
Entity::belongs_to(super::kb_info::Entity)
|
|
||||||
.from(Column::KbId)
|
|
||||||
.to(super::kb_info::Column::KbId)
|
|
||||||
.into(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ActiveModelBehavior for ActiveModel {}
|
|
@ -1,47 +0,0 @@
|
|||||||
use sea_orm::entity::prelude::*;
|
|
||||||
use serde::{ Deserialize, Serialize };
|
|
||||||
use chrono::{ DateTime, FixedOffset };
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)]
|
|
||||||
#[sea_orm(table_name = "kb_info")]
|
|
||||||
pub struct Model {
|
|
||||||
#[sea_orm(primary_key, auto_increment = false)]
|
|
||||||
#[serde(skip_deserializing)]
|
|
||||||
pub kb_id: i64,
|
|
||||||
#[sea_orm(index)]
|
|
||||||
pub uid: i64,
|
|
||||||
pub kb_name: String,
|
|
||||||
pub icon: i16,
|
|
||||||
|
|
||||||
#[serde(skip_deserializing)]
|
|
||||||
pub created_at: DateTime<FixedOffset>,
|
|
||||||
#[serde(skip_deserializing)]
|
|
||||||
pub updated_at: DateTime<FixedOffset>,
|
|
||||||
#[serde(skip_deserializing)]
|
|
||||||
pub is_deleted: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
|
||||||
pub enum Relation {}
|
|
||||||
|
|
||||||
impl Related<super::doc_info::Entity> for Entity {
|
|
||||||
fn to() -> RelationDef {
|
|
||||||
super::kb2_doc::Relation::DocInfo.def()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn via() -> Option<RelationDef> {
|
|
||||||
Some(super::kb2_doc::Relation::KbInfo.def().rev())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Related<super::dialog_info::Entity> for Entity {
|
|
||||||
fn to() -> RelationDef {
|
|
||||||
super::dialog2_kb::Relation::DialogInfo.def()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn via() -> Option<RelationDef> {
|
|
||||||
Some(super::dialog2_kb::Relation::KbInfo.def().rev())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ActiveModelBehavior for ActiveModel {}
|
|
@@ -1,9 +0,0 @@
-pub(crate) mod user_info;
-pub(crate) mod tag_info;
-pub(crate) mod tag2_doc;
-pub(crate) mod kb2_doc;
-pub(crate) mod dialog2_kb;
-pub(crate) mod doc2_doc;
-pub(crate) mod kb_info;
-pub(crate) mod doc_info;
-pub(crate) mod dialog_info;
@@ -1,38 +0,0 @@
-use sea_orm::entity::prelude::*;
-use serde::{ Deserialize, Serialize };
-
-#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Deserialize, Serialize)]
-#[sea_orm(table_name = "tag2_doc")]
-pub struct Model {
-    #[sea_orm(primary_key, auto_increment = true)]
-    pub id: i64,
-    #[sea_orm(index)]
-    pub tag_id: i64,
-    #[sea_orm(index)]
-    pub did: i64,
-}
-
-#[derive(Debug, Clone, Copy, EnumIter)]
-pub enum Relation {
-    Tag,
-    DocInfo,
-}
-
-impl RelationTrait for Relation {
-    fn def(&self) -> sea_orm::RelationDef {
-        match self {
-            Self::Tag =>
-                Entity::belongs_to(super::tag_info::Entity)
-                    .from(Column::TagId)
-                    .to(super::tag_info::Column::Tid)
-                    .into(),
-            Self::DocInfo =>
-                Entity::belongs_to(super::doc_info::Entity)
-                    .from(Column::Did)
-                    .to(super::doc_info::Column::Did)
-                    .into(),
-        }
-    }
-}
-
-impl ActiveModelBehavior for ActiveModel {}
@@ -1,40 +0,0 @@
-use sea_orm::entity::prelude::*;
-use serde::{ Deserialize, Serialize };
-use chrono::{ DateTime, FixedOffset };
-
-#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Deserialize, Serialize)]
-#[sea_orm(table_name = "tag_info")]
-pub struct Model {
-    #[sea_orm(primary_key)]
-    #[serde(skip_deserializing)]
-    pub tid: i64,
-    #[sea_orm(index)]
-    pub uid: i64,
-    pub tag_name: String,
-    #[serde(skip_deserializing)]
-    pub regx: String,
-    pub color: i16,
-    pub icon: i16,
-    #[serde(skip_deserializing)]
-    pub folder_id: i64,
-
-    #[serde(skip_deserializing)]
-    pub created_at: DateTime<FixedOffset>,
-    #[serde(skip_deserializing)]
-    pub updated_at: DateTime<FixedOffset>,
-}
-
-#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
-pub enum Relation {}
-
-impl Related<super::doc_info::Entity> for Entity {
-    fn to() -> RelationDef {
-        super::tag2_doc::Relation::DocInfo.def()
-    }
-
-    fn via() -> Option<RelationDef> {
-        Some(super::tag2_doc::Relation::Tag.def().rev())
-    }
-}
-
-impl ActiveModelBehavior for ActiveModel {}
@@ -1,30 +0,0 @@
-use sea_orm::entity::prelude::*;
-use serde::{ Deserialize, Serialize };
-use chrono::{ DateTime, FixedOffset };
-
-#[derive(Clone, Debug, PartialEq, Eq, Hash, DeriveEntityModel, Deserialize, Serialize)]
-#[sea_orm(table_name = "user_info")]
-pub struct Model {
-    #[sea_orm(primary_key)]
-    #[serde(skip_deserializing)]
-    pub uid: i64,
-    pub email: String,
-    pub nickname: String,
-    pub avatar_base64: String,
-    pub color_scheme: String,
-    pub list_style: String,
-    pub language: String,
-    pub password: String,
-
-    #[serde(skip_deserializing)]
-    pub last_login_at: DateTime<FixedOffset>,
-    #[serde(skip_deserializing)]
-    pub created_at: DateTime<FixedOffset>,
-    #[serde(skip_deserializing)]
-    pub updated_at: DateTime<FixedOffset>,
-}
-
-#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
-pub enum Relation {}
-
-impl ActiveModelBehavior for ActiveModel {}
@@ -1,83 +0,0 @@
-use actix_web::{HttpResponse, ResponseError};
-use thiserror::Error;
-
-#[derive(Debug, Error)]
-pub(crate) enum AppError {
-    #[error("`{0}`")]
-    User(#[from] UserError),
-
-    #[error("`{0}`")]
-    Json(#[from] serde_json::Error),
-
-    #[error("`{0}`")]
-    Actix(#[from] actix_web::Error),
-
-    #[error("`{0}`")]
-    Db(#[from] sea_orm::DbErr),
-
-    #[error("`{0}`")]
-    MinioS3(#[from] minio::s3::error::Error),
-
-    #[error("`{0}`")]
-    Std(#[from] std::io::Error),
-}
-
-#[derive(Debug, Error)]
-pub(crate) enum UserError {
-    #[error("`username` field of `User` cannot be empty!")]
-    EmptyUsername,
-
-    #[error("`username` field of `User` cannot contain whitespaces!")]
-    UsernameInvalidCharacter,
-
-    #[error("`password` field of `User` cannot be empty!")]
-    EmptyPassword,
-
-    #[error("`password` field of `User` cannot contain whitespaces!")]
-    PasswordInvalidCharacter,
-
-    #[error("Could not find any `User` for id: `{0}`!")]
-    NotFound(i64),
-
-    #[error("Failed to login user!")]
-    LoginFailed,
-
-    #[error("User is not logged in!")]
-    NotLoggedIn,
-
-    #[error("Invalid authorization token!")]
-    InvalidToken,
-
-    #[error("Could not find any `User`!")]
-    Empty,
-}
-
-impl ResponseError for AppError {
-    fn status_code(&self) -> actix_web::http::StatusCode {
-        match self {
-            AppError::User(user_error) => match user_error {
-                UserError::EmptyUsername => actix_web::http::StatusCode::UNPROCESSABLE_ENTITY,
-                UserError::UsernameInvalidCharacter => {
-                    actix_web::http::StatusCode::UNPROCESSABLE_ENTITY
-                }
-                UserError::EmptyPassword => actix_web::http::StatusCode::UNPROCESSABLE_ENTITY,
-                UserError::PasswordInvalidCharacter => {
-                    actix_web::http::StatusCode::UNPROCESSABLE_ENTITY
-                }
-                UserError::NotFound(_) => actix_web::http::StatusCode::NOT_FOUND,
-                UserError::NotLoggedIn => actix_web::http::StatusCode::UNAUTHORIZED,
-                UserError::Empty => actix_web::http::StatusCode::NOT_FOUND,
-                UserError::LoginFailed => actix_web::http::StatusCode::NOT_FOUND,
-                UserError::InvalidToken => actix_web::http::StatusCode::UNAUTHORIZED,
-            },
-            AppError::Actix(fail) => fail.as_response_error().status_code(),
-            _ => actix_web::http::StatusCode::INTERNAL_SERVER_ERROR,
-        }
-    }
-
-    fn error_response(&self) -> HttpResponse {
-        let status_code = self.status_code();
-        let response = HttpResponse::build(status_code).body(self.to_string());
-        response
-    }
-}
145 src/main.rs
@ -1,145 +0,0 @@
|
|||||||
mod api;
|
|
||||||
mod entity;
|
|
||||||
mod service;
|
|
||||||
mod errors;
|
|
||||||
mod web_socket;
|
|
||||||
|
|
||||||
use std::env;
|
|
||||||
use actix_files::Files;
|
|
||||||
use actix_identity::{ CookieIdentityPolicy, IdentityService, RequestIdentity };
|
|
||||||
use actix_session::CookieSession;
|
|
||||||
use actix_web::{ web, App, HttpServer, middleware, Error };
|
|
||||||
use actix_web::cookie::time::Duration;
|
|
||||||
use actix_web::dev::ServiceRequest;
|
|
||||||
use actix_web::error::ErrorUnauthorized;
|
|
||||||
use actix_web_httpauth::extractors::bearer::BearerAuth;
|
|
||||||
use listenfd::ListenFd;
|
|
||||||
use minio::s3::client::Client;
|
|
||||||
use minio::s3::creds::StaticProvider;
|
|
||||||
use minio::s3::http::BaseUrl;
|
|
||||||
use sea_orm::{ Database, DatabaseConnection };
|
|
||||||
use migration::{ Migrator, MigratorTrait };
|
|
||||||
use crate::errors::{ AppError, UserError };
|
|
||||||
use crate::web_socket::doc_info::upload_file_ws;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct AppState {
|
|
||||||
conn: DatabaseConnection,
|
|
||||||
s3_client: Client,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) async fn validator(
|
|
||||||
req: ServiceRequest,
|
|
||||||
credentials: BearerAuth
|
|
||||||
) -> Result<ServiceRequest, Error> {
|
|
||||||
if let Some(token) = req.get_identity() {
|
|
||||||
println!("{}, {}", credentials.token(), token);
|
|
||||||
(credentials.token() == token)
|
|
||||||
.then(|| req)
|
|
||||||
.ok_or(ErrorUnauthorized(UserError::InvalidToken))
|
|
||||||
} else {
|
|
||||||
Err(ErrorUnauthorized(UserError::NotLoggedIn))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[actix_web::main]
|
|
||||||
async fn main() -> Result<(), AppError> {
|
|
||||||
std::env::set_var("RUST_LOG", "debug");
|
|
||||||
tracing_subscriber::fmt::init();
|
|
||||||
|
|
||||||
// get env vars
|
|
||||||
dotenvy::dotenv().ok();
|
|
||||||
let db_url = env::var("DATABASE_URL").expect("DATABASE_URL is not set in .env file");
|
|
||||||
let host = env::var("HOST").expect("HOST is not set in .env file");
|
|
||||||
let port = env::var("PORT").expect("PORT is not set in .env file");
|
|
||||||
let server_url = format!("{host}:{port}");
|
|
||||||
|
|
||||||
let mut s3_base_url = env::var("MINIO_HOST").expect("MINIO_HOST is not set in .env file");
|
|
||||||
let s3_access_key = env::var("MINIO_USR").expect("MINIO_USR is not set in .env file");
|
|
||||||
let s3_secret_key = env::var("MINIO_PWD").expect("MINIO_PWD is not set in .env file");
|
|
||||||
if s3_base_url.find("http") != Some(0) {
|
|
||||||
s3_base_url = format!("http://{}", s3_base_url);
|
|
||||||
}
|
|
||||||
|
|
||||||
// establish connection to database and apply migrations
|
|
||||||
// -> create post table if not exists
|
|
||||||
let conn = Database::connect(&db_url).await.unwrap();
|
|
||||||
Migrator::up(&conn, None).await.unwrap();
|
|
||||||
|
|
||||||
let static_provider = StaticProvider::new(s3_access_key.as_str(), s3_secret_key.as_str(), None);
|
|
||||||
|
|
||||||
let s3_client = Client::new(
|
|
||||||
s3_base_url.parse::<BaseUrl>()?,
|
|
||||||
Some(Box::new(static_provider)),
|
|
||||||
None,
|
|
||||||
Some(true)
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let state = AppState { conn, s3_client };
|
|
||||||
|
|
||||||
// create server and try to serve over socket if possible
|
|
||||||
let mut listenfd = ListenFd::from_env();
|
|
||||||
let mut server = HttpServer::new(move || {
|
|
||||||
App::new()
|
|
||||||
.service(Files::new("/static", "./static"))
|
|
||||||
.app_data(web::Data::new(state.clone()))
|
|
||||||
.wrap(
|
|
||||||
IdentityService::new(
|
|
||||||
CookieIdentityPolicy::new(&[0; 32])
|
|
||||||
.name("auth-cookie")
|
|
||||||
.login_deadline(Duration::seconds(120))
|
|
||||||
.secure(false)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.wrap(
|
|
||||||
CookieSession::signed(&[0; 32])
|
|
||||||
.name("session-cookie")
|
|
||||||
.secure(false)
|
|
||||||
// WARNING(alex): This uses the `time` crate, not `std::time`!
|
|
||||||
.expires_in_time(Duration::seconds(60))
|
|
||||||
)
|
|
||||||
.wrap(middleware::Logger::default())
|
|
||||||
.configure(init)
|
|
||||||
});
|
|
||||||
|
|
||||||
server = match listenfd.take_tcp_listener(0)? {
|
|
||||||
Some(listener) => server.listen(listener)?,
|
|
||||||
None => server.bind(&server_url)?,
|
|
||||||
};
|
|
||||||
|
|
||||||
println!("Starting server at {server_url}");
|
|
||||||
server.run().await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn init(cfg: &mut web::ServiceConfig) {
|
|
||||||
cfg.service(api::tag_info::create);
|
|
||||||
cfg.service(api::tag_info::delete);
|
|
||||||
cfg.service(api::tag_info::list);
|
|
||||||
|
|
||||||
cfg.service(api::kb_info::create);
|
|
||||||
cfg.service(api::kb_info::delete);
|
|
||||||
cfg.service(api::kb_info::list);
|
|
||||||
cfg.service(api::kb_info::add_docs_to_kb);
|
|
||||||
cfg.service(api::kb_info::anti_kb_docs);
|
|
||||||
cfg.service(api::kb_info::all_relevents);
|
|
||||||
|
|
||||||
cfg.service(api::doc_info::list);
|
|
||||||
cfg.service(api::doc_info::delete);
|
|
||||||
cfg.service(api::doc_info::mv);
|
|
||||||
cfg.service(api::doc_info::upload);
|
|
||||||
cfg.service(api::doc_info::new_folder);
|
|
||||||
cfg.service(api::doc_info::rename);
|
|
||||||
|
|
||||||
cfg.service(api::dialog_info::list);
|
|
||||||
cfg.service(api::dialog_info::delete);
|
|
||||||
cfg.service(api::dialog_info::create);
|
|
||||||
cfg.service(api::dialog_info::update_history);
|
|
||||||
|
|
||||||
cfg.service(api::user_info::login);
|
|
||||||
cfg.service(api::user_info::register);
|
|
||||||
cfg.service(api::user_info::setting);
|
|
||||||
|
|
||||||
cfg.service(web::resource("/ws-upload-doc").route(web::get().to(upload_file_ws)));
|
|
||||||
}
|
|
@ -1,107 +0,0 @@
|
|||||||
use chrono::{ Local, FixedOffset, Utc };
|
|
||||||
use migration::Expr;
|
|
||||||
use sea_orm::{
|
|
||||||
ActiveModelTrait,
|
|
||||||
DbConn,
|
|
||||||
DbErr,
|
|
||||||
DeleteResult,
|
|
||||||
EntityTrait,
|
|
||||||
PaginatorTrait,
|
|
||||||
QueryOrder,
|
|
||||||
UpdateResult,
|
|
||||||
};
|
|
||||||
use sea_orm::ActiveValue::Set;
|
|
||||||
use sea_orm::QueryFilter;
|
|
||||||
use sea_orm::ColumnTrait;
|
|
||||||
use crate::entity::dialog_info;
|
|
||||||
use crate::entity::dialog_info::Entity;
|
|
||||||
|
|
||||||
fn now() -> chrono::DateTime<FixedOffset> {
|
|
||||||
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
|
|
||||||
}
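// Helper shared across the service modules: timestamps are pinned to UTC+8 via FixedOffset::east_opt(3600 * 8).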
|
|
||||||
pub struct Query;
|
|
||||||
|
|
||||||
impl Query {
|
|
||||||
pub async fn find_dialog_info_by_id(
|
|
||||||
db: &DbConn,
|
|
||||||
id: i64
|
|
||||||
) -> Result<Option<dialog_info::Model>, DbErr> {
|
|
||||||
Entity::find_by_id(id).one(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn find_dialog_infos(db: &DbConn) -> Result<Vec<dialog_info::Model>, DbErr> {
|
|
||||||
Entity::find().all(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn find_dialog_infos_by_uid(
|
|
||||||
db: &DbConn,
|
|
||||||
uid: i64
|
|
||||||
) -> Result<Vec<dialog_info::Model>, DbErr> {
|
|
||||||
Entity::find()
|
|
||||||
.filter(dialog_info::Column::Uid.eq(uid))
|
|
||||||
.filter(dialog_info::Column::IsDeleted.eq(false))
|
|
||||||
.all(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn find_dialog_infos_in_page(
|
|
||||||
db: &DbConn,
|
|
||||||
page: u64,
|
|
||||||
posts_per_page: u64
|
|
||||||
) -> Result<(Vec<dialog_info::Model>, u64), DbErr> {
|
|
||||||
// Setup paginator
|
|
||||||
let paginator = Entity::find()
|
|
||||||
.order_by_asc(dialog_info::Column::DialogId)
|
|
||||||
.paginate(db, posts_per_page);
|
|
||||||
let num_pages = paginator.num_pages().await?;
|
|
||||||
|
|
||||||
// Fetch paginated posts
|
|
||||||
paginator.fetch_page(page - 1).await.map(|p| (p, num_pages))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct Mutation;
|
|
||||||
|
|
||||||
impl Mutation {
|
|
||||||
pub async fn create_dialog_info(
|
|
||||||
db: &DbConn,
|
|
||||||
uid: i64,
|
|
||||||
kb_id: i64,
|
|
||||||
name: &String
|
|
||||||
) -> Result<dialog_info::ActiveModel, DbErr> {
|
|
||||||
(dialog_info::ActiveModel {
|
|
||||||
dialog_id: Default::default(),
|
|
||||||
uid: Set(uid),
|
|
||||||
kb_id: Set(kb_id),
|
|
||||||
dialog_name: Set(name.to_owned()),
|
|
||||||
history: Set("".to_owned()),
|
|
||||||
created_at: Set(now()),
|
|
||||||
updated_at: Set(now()),
|
|
||||||
is_deleted: Default::default(),
|
|
||||||
}).save(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn update_dialog_info_by_id(
|
|
||||||
db: &DbConn,
|
|
||||||
dialog_id: i64,
|
|
||||||
dialog_name: &String,
|
|
||||||
history: &String
|
|
||||||
) -> Result<UpdateResult, DbErr> {
|
|
||||||
Entity::update_many()
|
|
||||||
.col_expr(dialog_info::Column::DialogName, Expr::value(dialog_name))
|
|
||||||
.col_expr(dialog_info::Column::History, Expr::value(history))
|
|
||||||
.col_expr(dialog_info::Column::UpdatedAt, Expr::value(now()))
|
|
||||||
.filter(dialog_info::Column::DialogId.eq(dialog_id))
|
|
||||||
.exec(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn delete_dialog_info(db: &DbConn, dialog_id: i64) -> Result<UpdateResult, DbErr> {
|
|
||||||
Entity::update_many()
|
|
||||||
.col_expr(dialog_info::Column::IsDeleted, Expr::value(true))
|
|
||||||
.filter(dialog_info::Column::DialogId.eq(dialog_id))
|
|
||||||
.exec(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn delete_all_dialog_infos(db: &DbConn) -> Result<DeleteResult, DbErr> {
|
|
||||||
Entity::delete_many().exec(db).await
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,335 +0,0 @@
|
|||||||
use chrono::{ Utc, FixedOffset };
|
|
||||||
use sea_orm::{
|
|
||||||
ActiveModelTrait,
|
|
||||||
ColumnTrait,
|
|
||||||
DbConn,
|
|
||||||
DbErr,
|
|
||||||
DeleteResult,
|
|
||||||
EntityTrait,
|
|
||||||
PaginatorTrait,
|
|
||||||
QueryOrder,
|
|
||||||
Unset,
|
|
||||||
Unchanged,
|
|
||||||
ConditionalStatement,
|
|
||||||
QuerySelect,
|
|
||||||
JoinType,
|
|
||||||
RelationTrait,
|
|
||||||
DbBackend,
|
|
||||||
Statement,
|
|
||||||
UpdateResult,
|
|
||||||
};
|
|
||||||
use sea_orm::ActiveValue::Set;
|
|
||||||
use sea_orm::QueryFilter;
|
|
||||||
use crate::api::doc_info::ListParams;
|
|
||||||
use crate::entity::{ doc2_doc, doc_info };
|
|
||||||
use crate::entity::doc_info::Entity;
|
|
||||||
use crate::service;
|
|
||||||
|
|
||||||
fn now() -> chrono::DateTime<FixedOffset> {
|
|
||||||
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct Query;
|
|
||||||
|
|
||||||
impl Query {
|
|
||||||
pub async fn find_doc_info_by_id(
|
|
||||||
db: &DbConn,
|
|
||||||
id: i64
|
|
||||||
) -> Result<Option<doc_info::Model>, DbErr> {
|
|
||||||
Entity::find_by_id(id).one(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn find_doc_infos(db: &DbConn) -> Result<Vec<doc_info::Model>, DbErr> {
|
|
||||||
Entity::find().all(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn find_doc_infos_by_uid(
|
|
||||||
db: &DbConn,
|
|
||||||
uid: i64
|
|
||||||
) -> Result<Vec<doc_info::Model>, DbErr> {
|
|
||||||
Entity::find().filter(doc_info::Column::Uid.eq(uid)).all(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn find_doc_infos_by_name(
|
|
||||||
db: &DbConn,
|
|
||||||
uid: i64,
|
|
||||||
name: &String,
|
|
||||||
parent_id: Option<i64>
|
|
||||||
) -> Result<Vec<doc_info::Model>, DbErr> {
|
|
||||||
let mut dids = Vec::<i64>::new();
|
|
||||||
if let Some(pid) = parent_id {
|
|
||||||
for d2d in doc2_doc::Entity
|
|
||||||
::find()
|
|
||||||
.filter(doc2_doc::Column::ParentId.eq(pid))
|
|
||||||
.all(db).await? {
|
|
||||||
dids.push(d2d.did);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
let doc = Entity::find()
|
|
||||||
.filter(doc_info::Column::DocName.eq(name.clone()))
|
|
||||||
.filter(doc_info::Column::Uid.eq(uid))
|
|
||||||
.all(db).await?;
|
|
||||||
if doc.len() == 0 {
|
|
||||||
return Ok(vec![]);
|
|
||||||
}
|
|
||||||
assert!(doc.len() > 0);
|
|
||||||
let d2d = doc2_doc::Entity
|
|
||||||
::find()
|
|
||||||
.filter(doc2_doc::Column::Did.eq(doc[0].did))
|
|
||||||
.all(db).await?;
|
|
||||||
assert!(d2d.len() <= 1, "Did: {}->{}", doc[0].did, d2d.len());
|
|
||||||
if d2d.len() > 0 {
|
|
||||||
for d2d_ in doc2_doc::Entity
|
|
||||||
::find()
|
|
||||||
.filter(doc2_doc::Column::ParentId.eq(d2d[0].parent_id))
|
|
||||||
.all(db).await? {
|
|
||||||
dids.push(d2d_.did);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Entity::find()
|
|
||||||
.filter(doc_info::Column::DocName.eq(name.clone()))
|
|
||||||
.filter(doc_info::Column::Uid.eq(uid))
|
|
||||||
.filter(doc_info::Column::Did.is_in(dids))
|
|
||||||
.filter(doc_info::Column::IsDeleted.eq(false))
|
|
||||||
.all(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn all_descendent_ids(db: &DbConn, doc_ids: &Vec<i64>) -> Result<Vec<i64>, DbErr> {
|
|
||||||
let mut dids = doc_ids.clone();
|
|
||||||
let mut i: usize = 0;
|
|
||||||
loop {
|
|
||||||
if dids.len() == i {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
for d in doc2_doc::Entity
|
|
||||||
::find()
|
|
||||||
.filter(doc2_doc::Column::ParentId.eq(dids[i]))
|
|
||||||
.all(db).await? {
|
|
||||||
dids.push(d.did);
|
|
||||||
}
|
|
||||||
i += 1;
|
|
||||||
}
|
|
||||||
Ok(dids)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn find_doc_infos_by_params(
|
|
||||||
db: &DbConn,
|
|
||||||
params: ListParams
|
|
||||||
) -> Result<Vec<doc_info::Model>, DbErr> {
|
|
||||||
// Setup paginator
|
|
||||||
let mut sql: String =
|
|
||||||
"
|
|
||||||
select
|
|
||||||
a.did,
|
|
||||||
a.uid,
|
|
||||||
a.doc_name,
|
|
||||||
a.location,
|
|
||||||
a.size,
|
|
||||||
a.type,
|
|
||||||
a.created_at,
|
|
||||||
a.updated_at,
|
|
||||||
a.is_deleted
|
|
||||||
from
|
|
||||||
doc_info as a
|
|
||||||
".to_owned();
|
|
||||||
|
|
||||||
let mut cond: String = format!(" a.uid={} and a.is_deleted=False ", params.uid);
|
|
||||||
|
|
||||||
if let Some(kb_id) = params.filter.kb_id {
|
|
||||||
sql.push_str(
|
|
||||||
&format!(" inner join kb2_doc on kb2_doc.did = a.did and kb2_doc.kb_id={}", kb_id)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if let Some(folder_id) = params.filter.folder_id {
|
|
||||||
sql.push_str(
|
|
||||||
&format!(" inner join doc2_doc on a.did = doc2_doc.did and doc2_doc.parent_id={}", folder_id)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
// Fetch paginated posts
|
|
||||||
if let Some(tag_id) = params.filter.tag_id {
|
|
||||||
let tag = service::tag_info::Query
|
|
||||||
::find_tag_info_by_id(tag_id, &db).await
|
|
||||||
.unwrap()
|
|
||||||
.unwrap();
|
|
||||||
if tag.folder_id > 0 {
|
|
||||||
sql.push_str(
|
|
||||||
&format!(
|
|
||||||
" inner join doc2_doc on a.did = doc2_doc.did and doc2_doc.parent_id={}",
|
|
||||||
tag.folder_id
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if tag.regx.len() > 0 {
|
|
||||||
cond.push_str(&format!(" and (type='{}' or doc_name ~ '{}') ", tag.tag_name, tag.regx));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(keywords) = params.filter.keywords {
|
|
||||||
cond.push_str(&format!(" and doc_name like '%{}%'", keywords));
|
|
||||||
}
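// NOTE: keywords and the tag regex are interpolated directly into the SQL string; bound parameters would be needed to make this query safe against injection.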
|
|
||||||
if cond.len() > 0 {
|
|
||||||
sql.push_str(&" where ");
|
|
||||||
sql.push_str(&cond);
|
|
||||||
}
|
|
||||||
let mut orderby = params.sortby.clone();
|
|
||||||
if orderby.len() == 0 {
|
|
||||||
orderby = "updated_at desc".to_owned();
|
|
||||||
}
|
|
||||||
sql.push_str(&format!(" order by {}", orderby));
|
|
||||||
let mut page_size: u32 = 30;
|
|
||||||
if let Some(pg_sz) = params.per_page {
|
|
||||||
page_size = pg_sz;
|
|
||||||
}
|
|
||||||
let mut page: u32 = 0;
|
|
||||||
if let Some(pg) = params.page {
|
|
||||||
page = pg;
|
|
||||||
}
|
|
||||||
sql.push_str(&format!(" limit {} offset {} ;", page_size, page * page_size));
|
|
||||||
|
|
||||||
print!("{}", sql);
|
|
||||||
Entity::find()
|
|
||||||
.from_raw_sql(Statement::from_sql_and_values(DbBackend::Postgres, sql, vec![]))
|
|
||||||
.all(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn find_doc_infos_in_page(
|
|
||||||
db: &DbConn,
|
|
||||||
page: u64,
|
|
||||||
posts_per_page: u64
|
|
||||||
) -> Result<(Vec<doc_info::Model>, u64), DbErr> {
|
|
||||||
// Setup paginator
|
|
||||||
let paginator = Entity::find()
|
|
||||||
.order_by_asc(doc_info::Column::Did)
|
|
||||||
.paginate(db, posts_per_page);
|
|
||||||
let num_pages = paginator.num_pages().await?;
|
|
||||||
|
|
||||||
// Fetch paginated posts
|
|
||||||
paginator.fetch_page(page - 1).await.map(|p| (p, num_pages))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct Mutation;
|
|
||||||
|
|
||||||
impl Mutation {
|
|
||||||
pub async fn mv_doc_info(db: &DbConn, dest_did: i64, dids: &[i64]) -> Result<(), DbErr> {
|
|
||||||
for did in dids {
|
|
||||||
let d = doc2_doc::Entity
|
|
||||||
::find()
|
|
||||||
.filter(doc2_doc::Column::Did.eq(did.to_owned()))
|
|
||||||
.all(db).await?;
|
|
||||||
|
|
||||||
let _ = (doc2_doc::ActiveModel {
|
|
||||||
id: Set(d[0].id),
|
|
||||||
did: Set(did.to_owned()),
|
|
||||||
parent_id: Set(dest_did),
|
|
||||||
}).update(db).await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn place_doc(
|
|
||||||
db: &DbConn,
|
|
||||||
dest_did: i64,
|
|
||||||
did: i64
|
|
||||||
) -> Result<doc2_doc::ActiveModel, DbErr> {
|
|
||||||
(doc2_doc::ActiveModel {
|
|
||||||
id: Default::default(),
|
|
||||||
parent_id: Set(dest_did),
|
|
||||||
did: Set(did),
|
|
||||||
}).save(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn create_doc_info(
|
|
||||||
db: &DbConn,
|
|
||||||
form_data: doc_info::Model
|
|
||||||
) -> Result<doc_info::ActiveModel, DbErr> {
|
|
||||||
(doc_info::ActiveModel {
|
|
||||||
did: Default::default(),
|
|
||||||
uid: Set(form_data.uid.to_owned()),
|
|
||||||
doc_name: Set(form_data.doc_name.to_owned()),
|
|
||||||
size: Set(form_data.size.to_owned()),
|
|
||||||
r#type: Set(form_data.r#type.to_owned()),
|
|
||||||
location: Set(form_data.location.to_owned()),
|
|
||||||
thumbnail_base64: Default::default(),
|
|
||||||
created_at: Set(form_data.created_at.to_owned()),
|
|
||||||
updated_at: Set(form_data.updated_at.to_owned()),
|
|
||||||
is_deleted: Default::default(),
|
|
||||||
}).save(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn update_doc_info_by_id(
|
|
||||||
db: &DbConn,
|
|
||||||
id: i64,
|
|
||||||
form_data: doc_info::Model
|
|
||||||
) -> Result<doc_info::Model, DbErr> {
|
|
||||||
let doc_info: doc_info::ActiveModel = Entity::find_by_id(id)
|
|
||||||
.one(db).await?
|
|
||||||
.ok_or(DbErr::Custom("Cannot find.".to_owned()))
|
|
||||||
.map(Into::into)?;
|
|
||||||
|
|
||||||
(doc_info::ActiveModel {
|
|
||||||
did: doc_info.did,
|
|
||||||
uid: Set(form_data.uid.to_owned()),
|
|
||||||
doc_name: Set(form_data.doc_name.to_owned()),
|
|
||||||
size: Set(form_data.size.to_owned()),
|
|
||||||
r#type: Set(form_data.r#type.to_owned()),
|
|
||||||
location: Set(form_data.location.to_owned()),
|
|
||||||
thumbnail_base64: doc_info.thumbnail_base64,
|
|
||||||
created_at: doc_info.created_at,
|
|
||||||
updated_at: Set(now()),
|
|
||||||
is_deleted: Default::default(),
|
|
||||||
}).update(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn delete_doc_info(db: &DbConn, doc_ids: &Vec<i64>) -> Result<UpdateResult, DbErr> {
|
|
||||||
let mut dids = doc_ids.clone();
|
|
||||||
let mut i: usize = 0;
|
|
||||||
loop {
|
|
||||||
if dids.len() == i {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
let mut doc: doc_info::ActiveModel = Entity::find_by_id(dids[i])
|
|
||||||
.one(db).await?
|
|
||||||
.ok_or(DbErr::Custom(format!("Can't find doc:{}", dids[i])))
|
|
||||||
.map(Into::into)?;
|
|
||||||
doc.updated_at = Set(now());
|
|
||||||
doc.is_deleted = Set(true);
|
|
||||||
let _ = doc.update(db).await?;
|
|
||||||
|
|
||||||
for d in doc2_doc::Entity
|
|
||||||
::find()
|
|
||||||
.filter(doc2_doc::Column::ParentId.eq(dids[i]))
|
|
||||||
.all(db).await? {
|
|
||||||
dids.push(d.did);
|
|
||||||
}
|
|
||||||
let _ = doc2_doc::Entity
|
|
||||||
::delete_many()
|
|
||||||
.filter(doc2_doc::Column::ParentId.eq(dids[i]))
|
|
||||||
.exec(db).await?;
|
|
||||||
let _ = doc2_doc::Entity
|
|
||||||
::delete_many()
|
|
||||||
.filter(doc2_doc::Column::Did.eq(dids[i]))
|
|
||||||
.exec(db).await?;
|
|
||||||
i += 1;
|
|
||||||
}
|
|
||||||
crate::service::kb_info::Mutation::remove_docs(&db, dids, None).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn rename(db: &DbConn, doc_id: i64, name: &String) -> Result<doc_info::Model, DbErr> {
|
|
||||||
let mut doc: doc_info::ActiveModel = Entity::find_by_id(doc_id)
|
|
||||||
.one(db).await?
|
|
||||||
.ok_or(DbErr::Custom(format!("Can't find doc:{}", doc_id)))
|
|
||||||
.map(Into::into)?;
|
|
||||||
doc.updated_at = Set(now());
|
|
||||||
doc.doc_name = Set(name.clone());
|
|
||||||
doc.update(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn delete_all_doc_infos(db: &DbConn) -> Result<DeleteResult, DbErr> {
|
|
||||||
Entity::delete_many().exec(db).await
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,168 +0,0 @@
|
|||||||
use chrono::{ Local, FixedOffset, Utc };
|
|
||||||
use migration::Expr;
|
|
||||||
use sea_orm::{
|
|
||||||
ActiveModelTrait,
|
|
||||||
ColumnTrait,
|
|
||||||
DbConn,
|
|
||||||
DbErr,
|
|
||||||
DeleteResult,
|
|
||||||
EntityTrait,
|
|
||||||
PaginatorTrait,
|
|
||||||
QueryFilter,
|
|
||||||
QueryOrder,
|
|
||||||
UpdateResult,
|
|
||||||
};
|
|
||||||
use sea_orm::ActiveValue::Set;
|
|
||||||
use crate::entity::kb_info;
|
|
||||||
use crate::entity::kb2_doc;
|
|
||||||
use crate::entity::kb_info::Entity;
|
|
||||||
|
|
||||||
fn now() -> chrono::DateTime<FixedOffset> {
|
|
||||||
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
|
|
||||||
}
|
|
||||||
pub struct Query;
|
|
||||||
|
|
||||||
impl Query {
|
|
||||||
pub async fn find_kb_info_by_id(db: &DbConn, id: i64) -> Result<Option<kb_info::Model>, DbErr> {
|
|
||||||
Entity::find_by_id(id).one(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn find_kb_infos(db: &DbConn) -> Result<Vec<kb_info::Model>, DbErr> {
|
|
||||||
Entity::find().all(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn find_kb_infos_by_uid(db: &DbConn, uid: i64) -> Result<Vec<kb_info::Model>, DbErr> {
|
|
||||||
Entity::find().filter(kb_info::Column::Uid.eq(uid)).all(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn find_kb_infos_by_name(
|
|
||||||
db: &DbConn,
|
|
||||||
name: String
|
|
||||||
) -> Result<Vec<kb_info::Model>, DbErr> {
|
|
||||||
Entity::find().filter(kb_info::Column::KbName.eq(name)).all(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn find_kb_by_docs(
|
|
||||||
db: &DbConn,
|
|
||||||
doc_ids: Vec<i64>
|
|
||||||
) -> Result<Vec<kb_info::Model>, DbErr> {
|
|
||||||
let mut kbids = Vec::<i64>::new();
|
|
||||||
for k in kb2_doc::Entity
|
|
||||||
::find()
|
|
||||||
.filter(kb2_doc::Column::Did.is_in(doc_ids))
|
|
||||||
.all(db).await? {
|
|
||||||
kbids.push(k.kb_id);
|
|
||||||
}
|
|
||||||
Entity::find().filter(kb_info::Column::KbId.is_in(kbids)).all(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn find_kb_infos_in_page(
|
|
||||||
db: &DbConn,
|
|
||||||
page: u64,
|
|
||||||
posts_per_page: u64
|
|
||||||
) -> Result<(Vec<kb_info::Model>, u64), DbErr> {
|
|
||||||
// Setup paginator
|
|
||||||
let paginator = Entity::find()
|
|
||||||
.order_by_asc(kb_info::Column::KbId)
|
|
||||||
.paginate(db, posts_per_page);
|
|
||||||
let num_pages = paginator.num_pages().await?;
|
|
||||||
|
|
||||||
// Fetch paginated posts
|
|
||||||
paginator.fetch_page(page - 1).await.map(|p| (p, num_pages))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct Mutation;
|
|
||||||
|
|
||||||
impl Mutation {
|
|
||||||
pub async fn create_kb_info(
|
|
||||||
db: &DbConn,
|
|
||||||
form_data: kb_info::Model
|
|
||||||
) -> Result<kb_info::ActiveModel, DbErr> {
|
|
||||||
(kb_info::ActiveModel {
|
|
||||||
kb_id: Default::default(),
|
|
||||||
uid: Set(form_data.uid.to_owned()),
|
|
||||||
kb_name: Set(form_data.kb_name.to_owned()),
|
|
||||||
icon: Set(form_data.icon.to_owned()),
|
|
||||||
created_at: Set(now()),
|
|
||||||
updated_at: Set(now()),
|
|
||||||
is_deleted: Default::default(),
|
|
||||||
}).save(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn add_docs(db: &DbConn, kb_id: i64, doc_ids: Vec<i64>) -> Result<(), DbErr> {
|
|
||||||
for did in doc_ids {
|
|
||||||
let res = kb2_doc::Entity
|
|
||||||
::find()
|
|
||||||
.filter(kb2_doc::Column::KbId.eq(kb_id))
|
|
||||||
.filter(kb2_doc::Column::Did.eq(did))
|
|
||||||
.all(db).await?;
|
|
||||||
if res.len() > 0 {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
let _ = (kb2_doc::ActiveModel {
|
|
||||||
id: Default::default(),
|
|
||||||
kb_id: Set(kb_id),
|
|
||||||
did: Set(did),
|
|
||||||
kb_progress: Set(0.0),
|
|
||||||
kb_progress_msg: Set("".to_owned()),
|
|
||||||
updated_at: Set(now()),
|
|
||||||
is_deleted: Default::default(),
|
|
||||||
}).save(db).await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn remove_docs(
|
|
||||||
db: &DbConn,
|
|
||||||
doc_ids: Vec<i64>,
|
|
||||||
kb_id: Option<i64>
|
|
||||||
) -> Result<UpdateResult, DbErr> {
|
|
||||||
let update = kb2_doc::Entity
|
|
||||||
::update_many()
|
|
||||||
.col_expr(kb2_doc::Column::IsDeleted, Expr::value(true))
|
|
||||||
.col_expr(kb2_doc::Column::KbProgress, Expr::value(0))
|
|
||||||
.col_expr(kb2_doc::Column::KbProgressMsg, Expr::value(""))
|
|
||||||
.filter(kb2_doc::Column::Did.is_in(doc_ids));
|
|
||||||
if let Some(kbid) = kb_id {
|
|
||||||
update.filter(kb2_doc::Column::KbId.eq(kbid)).exec(db).await
|
|
||||||
} else {
|
|
||||||
update.exec(db).await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn update_kb_info_by_id(
|
|
||||||
db: &DbConn,
|
|
||||||
id: i64,
|
|
||||||
form_data: kb_info::Model
|
|
||||||
) -> Result<kb_info::Model, DbErr> {
|
|
||||||
let kb_info: kb_info::ActiveModel = Entity::find_by_id(id)
|
|
||||||
.one(db).await?
|
|
||||||
.ok_or(DbErr::Custom("Cannot find.".to_owned()))
|
|
||||||
.map(Into::into)?;
|
|
||||||
|
|
||||||
(kb_info::ActiveModel {
|
|
||||||
kb_id: kb_info.kb_id,
|
|
||||||
uid: kb_info.uid,
|
|
||||||
kb_name: Set(form_data.kb_name.to_owned()),
|
|
||||||
icon: Set(form_data.icon.to_owned()),
|
|
||||||
created_at: kb_info.created_at,
|
|
||||||
updated_at: Set(now()),
|
|
||||||
is_deleted: Default::default(),
|
|
||||||
}).update(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn delete_kb_info(db: &DbConn, kb_id: i64) -> Result<DeleteResult, DbErr> {
|
|
||||||
let kb: kb_info::ActiveModel = Entity::find_by_id(kb_id)
|
|
||||||
.one(db).await?
|
|
||||||
.ok_or(DbErr::Custom("Cannot find.".to_owned()))
|
|
||||||
.map(Into::into)?;
|
|
||||||
|
|
||||||
kb.delete(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn delete_all_kb_infos(db: &DbConn) -> Result<DeleteResult, DbErr> {
|
|
||||||
Entity::delete_many().exec(db).await
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,5 +0,0 @@
pub(crate) mod dialog_info;
pub(crate) mod tag_info;
pub(crate) mod kb_info;
pub(crate) mod doc_info;
pub(crate) mod user_info;
@ -1,108 +0,0 @@
|
|||||||
use chrono::{ FixedOffset, Utc };
|
|
||||||
use sea_orm::{
|
|
||||||
ActiveModelTrait,
|
|
||||||
DbConn,
|
|
||||||
DbErr,
|
|
||||||
DeleteResult,
|
|
||||||
EntityTrait,
|
|
||||||
PaginatorTrait,
|
|
||||||
QueryOrder,
|
|
||||||
ColumnTrait,
|
|
||||||
QueryFilter,
|
|
||||||
};
|
|
||||||
use sea_orm::ActiveValue::{ Set, NotSet };
|
|
||||||
use crate::entity::tag_info;
|
|
||||||
use crate::entity::tag_info::Entity;
|
|
||||||
|
|
||||||
fn now() -> chrono::DateTime<FixedOffset> {
|
|
||||||
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
|
|
||||||
}
|
|
||||||
pub struct Query;
|
|
||||||
|
|
||||||
impl Query {
|
|
||||||
pub async fn find_tag_info_by_id(
|
|
||||||
id: i64,
|
|
||||||
db: &DbConn
|
|
||||||
) -> Result<Option<tag_info::Model>, DbErr> {
|
|
||||||
Entity::find_by_id(id).one(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn find_tags_by_uid(uid: i64, db: &DbConn) -> Result<Vec<tag_info::Model>, DbErr> {
|
|
||||||
Entity::find().filter(tag_info::Column::Uid.eq(uid)).all(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn find_tag_infos_in_page(
|
|
||||||
db: &DbConn,
|
|
||||||
page: u64,
|
|
||||||
posts_per_page: u64
|
|
||||||
) -> Result<(Vec<tag_info::Model>, u64), DbErr> {
|
|
||||||
// Setup paginator
|
|
||||||
let paginator = Entity::find()
|
|
||||||
.order_by_asc(tag_info::Column::Tid)
|
|
||||||
.paginate(db, posts_per_page);
|
|
||||||
let num_pages = paginator.num_pages().await?;
|
|
||||||
|
|
||||||
// Fetch paginated posts
|
|
||||||
paginator.fetch_page(page - 1).await.map(|p| (p, num_pages))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct Mutation;
|
|
||||||
|
|
||||||
impl Mutation {
|
|
||||||
pub async fn create_tag(
|
|
||||||
db: &DbConn,
|
|
||||||
form_data: tag_info::Model
|
|
||||||
) -> Result<tag_info::ActiveModel, DbErr> {
|
|
||||||
(tag_info::ActiveModel {
|
|
||||||
tid: Default::default(),
|
|
||||||
uid: Set(form_data.uid.to_owned()),
|
|
||||||
tag_name: Set(form_data.tag_name.to_owned()),
|
|
||||||
regx: Set(form_data.regx.to_owned()),
|
|
||||||
color: Set(form_data.color.to_owned()),
|
|
||||||
icon: Set(form_data.icon.to_owned()),
|
|
||||||
folder_id: match form_data.folder_id {
|
|
||||||
0 => NotSet,
|
|
||||||
_ => Set(form_data.folder_id.to_owned()),
|
|
||||||
},
|
|
||||||
created_at: Set(now()),
|
|
||||||
updated_at: Set(now()),
|
|
||||||
}).save(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn update_tag_by_id(
|
|
||||||
db: &DbConn,
|
|
||||||
id: i64,
|
|
||||||
form_data: tag_info::Model
|
|
||||||
) -> Result<tag_info::Model, DbErr> {
|
|
||||||
let tag: tag_info::ActiveModel = Entity::find_by_id(id)
|
|
||||||
.one(db).await?
|
|
||||||
.ok_or(DbErr::Custom("Cannot find tag.".to_owned()))
|
|
||||||
.map(Into::into)?;
|
|
||||||
|
|
||||||
(tag_info::ActiveModel {
|
|
||||||
tid: tag.tid,
|
|
||||||
uid: tag.uid,
|
|
||||||
tag_name: Set(form_data.tag_name.to_owned()),
|
|
||||||
regx: Set(form_data.regx.to_owned()),
|
|
||||||
color: Set(form_data.color.to_owned()),
|
|
||||||
icon: Set(form_data.icon.to_owned()),
|
|
||||||
folder_id: Set(form_data.folder_id.to_owned()),
|
|
||||||
created_at: Default::default(),
|
|
||||||
updated_at: Set(now()),
|
|
||||||
}).update(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn delete_tag(db: &DbConn, tid: i64) -> Result<DeleteResult, DbErr> {
|
|
||||||
let tag: tag_info::ActiveModel = Entity::find_by_id(tid)
|
|
||||||
.one(db).await?
|
|
||||||
.ok_or(DbErr::Custom("Cannot find tag.".to_owned()))
|
|
||||||
.map(Into::into)?;
|
|
||||||
|
|
||||||
tag.delete(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn delete_all_tags(db: &DbConn) -> Result<DeleteResult, DbErr> {
|
|
||||||
Entity::delete_many().exec(db).await
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,131 +0,0 @@
|
|||||||
use chrono::{ FixedOffset, Utc };
|
|
||||||
use migration::Expr;
|
|
||||||
use sea_orm::{
|
|
||||||
ActiveModelTrait,
|
|
||||||
ColumnTrait,
|
|
||||||
DbConn,
|
|
||||||
DbErr,
|
|
||||||
DeleteResult,
|
|
||||||
EntityTrait,
|
|
||||||
PaginatorTrait,
|
|
||||||
QueryFilter,
|
|
||||||
QueryOrder,
|
|
||||||
UpdateResult,
|
|
||||||
};
|
|
||||||
use sea_orm::ActiveValue::Set;
|
|
||||||
use crate::entity::user_info;
|
|
||||||
use crate::entity::user_info::Entity;
|
|
||||||
|
|
||||||
fn now() -> chrono::DateTime<FixedOffset> {
|
|
||||||
Utc::now().with_timezone(&FixedOffset::east_opt(3600 * 8).unwrap())
|
|
||||||
}
|
|
||||||
pub struct Query;
|
|
||||||
|
|
||||||
impl Query {
|
|
||||||
pub async fn find_user_info_by_id(
|
|
||||||
db: &DbConn,
|
|
||||||
id: i64
|
|
||||||
) -> Result<Option<user_info::Model>, DbErr> {
|
|
||||||
Entity::find_by_id(id).one(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn login(
|
|
||||||
db: &DbConn,
|
|
||||||
email: &str,
|
|
||||||
password: &str
|
|
||||||
) -> Result<Option<user_info::Model>, DbErr> {
|
|
||||||
Entity::find()
|
|
||||||
.filter(user_info::Column::Email.eq(email))
|
|
||||||
.filter(user_info::Column::Password.eq(password))
|
|
||||||
.one(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn find_user_infos(
|
|
||||||
db: &DbConn,
|
|
||||||
email: &String
|
|
||||||
) -> Result<Option<user_info::Model>, DbErr> {
|
|
||||||
Entity::find().filter(user_info::Column::Email.eq(email)).one(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn find_user_infos_in_page(
|
|
||||||
db: &DbConn,
|
|
||||||
page: u64,
|
|
||||||
posts_per_page: u64
|
|
||||||
) -> Result<(Vec<user_info::Model>, u64), DbErr> {
|
|
||||||
// Setup paginator
|
|
||||||
let paginator = Entity::find()
|
|
||||||
.order_by_asc(user_info::Column::Uid)
|
|
||||||
.paginate(db, posts_per_page);
|
|
||||||
let num_pages = paginator.num_pages().await?;
|
|
||||||
|
|
||||||
// Fetch paginated posts
|
|
||||||
paginator.fetch_page(page - 1).await.map(|p| (p, num_pages))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct Mutation;
|
|
||||||
|
|
||||||
impl Mutation {
|
|
||||||
pub async fn create_user(
|
|
||||||
db: &DbConn,
|
|
||||||
form_data: &user_info::Model
|
|
||||||
) -> Result<user_info::ActiveModel, DbErr> {
|
|
||||||
(user_info::ActiveModel {
|
|
||||||
uid: Default::default(),
|
|
||||||
email: Set(form_data.email.to_owned()),
|
|
||||||
nickname: Set(form_data.nickname.to_owned()),
|
|
||||||
avatar_base64: Set(form_data.avatar_base64.to_owned()),
|
|
||||||
color_scheme: Set(form_data.color_scheme.to_owned()),
|
|
||||||
list_style: Set(form_data.list_style.to_owned()),
|
|
||||||
language: Set(form_data.language.to_owned()),
|
|
||||||
password: Set(form_data.password.to_owned()),
|
|
||||||
last_login_at: Set(now()),
|
|
||||||
created_at: Set(now()),
|
|
||||||
updated_at: Set(now()),
|
|
||||||
}).save(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn update_user_by_id(
|
|
||||||
db: &DbConn,
|
|
||||||
form_data: &user_info::Model
|
|
||||||
) -> Result<user_info::Model, DbErr> {
|
|
||||||
let usr: user_info::ActiveModel = Entity::find_by_id(form_data.uid)
|
|
||||||
.one(db).await?
|
|
||||||
.ok_or(DbErr::Custom("Cannot find user.".to_owned()))
|
|
||||||
.map(Into::into)?;
|
|
||||||
|
|
||||||
(user_info::ActiveModel {
|
|
||||||
uid: Set(form_data.uid),
|
|
||||||
email: Set(form_data.email.to_owned()),
|
|
||||||
nickname: Set(form_data.nickname.to_owned()),
|
|
||||||
avatar_base64: Set(form_data.avatar_base64.to_owned()),
|
|
||||||
color_scheme: Set(form_data.color_scheme.to_owned()),
|
|
||||||
list_style: Set(form_data.list_style.to_owned()),
|
|
||||||
language: Set(form_data.language.to_owned()),
|
|
||||||
password: Set(form_data.password.to_owned()),
|
|
||||||
updated_at: Set(now()),
|
|
||||||
last_login_at: usr.last_login_at,
|
|
||||||
created_at: usr.created_at,
|
|
||||||
}).update(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn update_login_status(uid: i64, db: &DbConn) -> Result<UpdateResult, DbErr> {
|
|
||||||
Entity::update_many()
|
|
||||||
.col_expr(user_info::Column::LastLoginAt, Expr::value(now()))
|
|
||||||
.filter(user_info::Column::Uid.eq(uid))
|
|
||||||
.exec(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn delete_user(db: &DbConn, tid: i64) -> Result<DeleteResult, DbErr> {
|
|
||||||
let tag: user_info::ActiveModel = Entity::find_by_id(tid)
|
|
||||||
.one(db).await?
|
|
||||||
.ok_or(DbErr::Custom("Cannot find tag.".to_owned()))
|
|
||||||
.map(Into::into)?;
|
|
||||||
|
|
||||||
tag.delete(db).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn delete_all(db: &DbConn) -> Result<DeleteResult, DbErr> {
|
|
||||||
Entity::delete_many().exec(db).await
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,97 +0,0 @@
|
|||||||
use std::io::{Cursor, Write};
|
|
||||||
use std::time::{Duration, Instant};
|
|
||||||
use actix_rt::time::interval;
|
|
||||||
use actix_web::{HttpRequest, HttpResponse, rt, web};
|
|
||||||
use actix_web::web::Buf;
|
|
||||||
use actix_ws::Message;
|
|
||||||
use futures_util::{future, StreamExt};
|
|
||||||
use futures_util::future::Either;
|
|
||||||
use uuid::Uuid;
|
|
||||||
use crate::api::doc_info::_upload_file;
|
|
||||||
use crate::AppState;
|
|
||||||
use crate::errors::AppError;
|
|
||||||
|
|
||||||
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
|
|
||||||
|
|
||||||
/// How long before lack of client response causes a timeout.
|
|
||||||
const CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
|
|
||||||
|
|
||||||
pub async fn upload_file_ws(req: HttpRequest, stream: web::Payload, data: web::Data<AppState>) -> Result<HttpResponse, AppError> {
|
|
||||||
let (res, session, msg_stream) = actix_ws::handle(&req, stream)?;
|
|
||||||
|
|
||||||
// spawn websocket handler (and don't await it) so that the response is returned immediately
|
|
||||||
rt::spawn(upload_file_handler(data, session, msg_stream));
|
|
||||||
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn upload_file_handler(
|
|
||||||
data: web::Data<AppState>,
|
|
||||||
mut session: actix_ws::Session,
|
|
||||||
mut msg_stream: actix_ws::MessageStream,
|
|
||||||
) {
|
|
||||||
let mut bytes = Cursor::new(vec![]);
|
|
||||||
let mut last_heartbeat = Instant::now();
|
|
||||||
let mut interval = interval(HEARTBEAT_INTERVAL);
|
|
||||||
|
|
||||||
let reason = loop {
|
|
||||||
let tick = interval.tick();
|
|
||||||
tokio::pin!(tick);
|
|
||||||
|
|
||||||
match future::select(msg_stream.next(), tick).await {
|
|
||||||
// received message from WebSocket client
|
|
||||||
Either::Left((Some(Ok(msg)), _)) => {
|
|
||||||
match msg {
|
|
||||||
Message::Text(text) => {
|
|
||||||
session.text(text).await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
Message::Binary(bin) => {
|
|
||||||
let mut pos = 0; // offset into the received binary chunk; every chunk is appended to the in-memory buffer
|
|
||||||
while pos < bin.len() {
|
|
||||||
let bytes_written = bytes.write(&bin[pos..]).unwrap();
|
|
||||||
pos += bytes_written
|
|
||||||
};
|
|
||||||
session.binary(bin).await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
Message::Close(reason) => {
|
|
||||||
break reason;
|
|
||||||
}
|
|
||||||
|
|
||||||
Message::Ping(bytes) => {
|
|
||||||
last_heartbeat = Instant::now();
|
|
||||||
let _ = session.pong(&bytes).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
Message::Pong(_) => {
|
|
||||||
last_heartbeat = Instant::now();
|
|
||||||
}
|
|
||||||
|
|
||||||
Message::Continuation(_) | Message::Nop => {}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
Either::Left((Some(Err(_)), _)) => {
|
|
||||||
break None;
|
|
||||||
}
|
|
||||||
Either::Left((None, _)) => break None,
|
|
||||||
Either::Right((_inst, _)) => {
|
|
||||||
if Instant::now().duration_since(last_heartbeat) > CLIENT_TIMEOUT {
|
|
||||||
break None;
|
|
||||||
}
|
|
||||||
|
|
||||||
let _ = session.ping(b"").await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let _ = session.close(reason).await;
|
|
||||||
|
|
||||||
if !bytes.has_remaining() {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let uid = bytes.get_i64();
|
|
||||||
let did = bytes.get_i64();
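// The client is expected to prefix the upload with two big-endian i64 values (uid, then did); get_i64 consumes them, leaving only the raw file bytes in the buffer.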
|
|
||||||
|
|
||||||
_upload_file(uid, did, &Uuid::new_v4().to_string(), &bytes.into_inner(), &data).await.unwrap();
|
|
||||||
}
|
|
@ -1 +0,0 @@
pub mod doc_info;
0
web_server/__init__.py
Normal file
147
web_server/apps/__init__.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from importlib.util import module_from_spec, spec_from_file_location
|
||||||
|
from pathlib import Path
|
||||||
|
from flask import Blueprint, Flask, request
|
||||||
|
from werkzeug.wrappers.request import Request
|
||||||
|
from flask_cors import CORS
|
||||||
|
|
||||||
|
from web_server.db import StatusEnum
|
||||||
|
from web_server.db.services import UserService
|
||||||
|
from web_server.utils import CustomJSONEncoder
|
||||||
|
|
||||||
|
from flask_session import Session
|
||||||
|
from flask_login import LoginManager
|
||||||
|
from web_server.settings import RetCode, SECRET_KEY, stat_logger
|
||||||
|
from web_server.hook import HookManager
|
||||||
|
from web_server.hook.common.parameters import AuthenticationParameters, ClientAuthenticationParameters
|
||||||
|
from web_server.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger
|
||||||
|
from web_server.utils.api_utils import get_json_result, server_error_response
|
||||||
|
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||||
|
|
||||||
|
__all__ = ['app']
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger('flask.app')
|
||||||
|
for h in access_logger.handlers:
|
||||||
|
logger.addHandler(h)
|
||||||
|
|
||||||
|
Request.json = property(lambda self: self.get_json(force=True, silent=True))
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
CORS(app, supports_credentials=True, max_age=2592000)
|
||||||
|
app.url_map.strict_slashes = False
|
||||||
|
app.json_encoder = CustomJSONEncoder
|
||||||
|
app.errorhandler(Exception)(server_error_response)
|
||||||
|
|
||||||
|
|
||||||
|
## convenient for dev and debug
|
||||||
|
#app.config["LOGIN_DISABLED"] = True
|
||||||
|
app.config["SESSION_PERMANENT"] = False
|
||||||
|
app.config["SESSION_TYPE"] = "filesystem"
|
||||||
|
app.config['MAX_CONTENT_LENGTH'] = 64 * 1024 * 1024
|
||||||
|
|
||||||
|
Session(app)
|
||||||
|
login_manager = LoginManager()
|
||||||
|
login_manager.init_app(app)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def search_pages_path(pages_dir):
|
||||||
|
return [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')]
|
||||||
|
|
||||||
|
|
||||||
|
def register_page(page_path):
|
||||||
|
page_name = page_path.stem.rstrip('_app')
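# NOTE: str.rstrip strips any trailing 'a', 'p' or '_' characters rather than the literal suffix, so e.g. "data_app" becomes "dat"; removesuffix('_app') would be the safer call.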
|
||||||
|
module_name = '.'.join(page_path.parts[page_path.parts.index('web_server'):-1] + (page_name, ))
|
||||||
|
|
||||||
|
spec = spec_from_file_location(module_name, page_path)
|
||||||
|
page = module_from_spec(spec)
|
||||||
|
page.app = app
|
||||||
|
page.manager = Blueprint(page_name, module_name)
|
||||||
|
sys.modules[module_name] = page
|
||||||
|
spec.loader.exec_module(page)
|
||||||
|
|
||||||
|
page_name = getattr(page, 'page_name', page_name)
|
||||||
|
url_prefix = f'/{API_VERSION}/{page_name}'
|
||||||
|
|
||||||
|
app.register_blueprint(page.manager, url_prefix=url_prefix)
|
||||||
|
return url_prefix
|
||||||
|
|
||||||
|
|
||||||
|
pages_dir = [
|
||||||
|
Path(__file__).parent,
|
||||||
|
Path(__file__).parent.parent / 'web_server' / 'apps',
|
||||||
|
]
|
||||||
|
|
||||||
|
client_urls_prefix = [
|
||||||
|
register_page(path)
|
||||||
|
for dir in pages_dir
|
||||||
|
for path in search_pages_path(dir)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def client_authentication_before_request():
|
||||||
|
result = HookManager.client_authentication(ClientAuthenticationParameters(
|
||||||
|
request.full_path, request.headers,
|
||||||
|
request.form, request.data, request.json,
|
||||||
|
))
|
||||||
|
|
||||||
|
if result.code != RetCode.SUCCESS:
|
||||||
|
return get_json_result(result.code, result.message)
|
||||||
|
|
||||||
|
|
||||||
|
def site_authentication_before_request():
|
||||||
|
for url_prefix in client_urls_prefix:
|
||||||
|
if request.path.startswith(url_prefix):
|
||||||
|
return
|
||||||
|
|
||||||
|
result = HookManager.site_authentication(AuthenticationParameters(
|
||||||
|
request.headers.get('site_signature'),
|
||||||
|
request.json,
|
||||||
|
))
|
||||||
|
|
||||||
|
if result.code != RetCode.SUCCESS:
|
||||||
|
return get_json_result(result.code, result.message)
|
||||||
|
|
||||||
|
|
||||||
|
@app.before_request
|
||||||
|
def authentication_before_request():
|
||||||
|
if CLIENT_AUTHENTICATION:
|
||||||
|
return client_authentication_before_request()
|
||||||
|
|
||||||
|
if SITE_AUTHENTICATION:
|
||||||
|
return site_authentication_before_request()
|
||||||
|
|
||||||
|
@login_manager.request_loader
|
||||||
|
def load_user(web_request):
|
||||||
|
jwt = Serializer(secret_key=SECRET_KEY)
|
||||||
|
authorization = web_request.headers.get("Authorization")
|
||||||
|
if authorization:
|
||||||
|
try:
|
||||||
|
access_token = str(jwt.loads(authorization))
|
||||||
|
user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value)
|
||||||
|
if user:
|
||||||
|
return user[0]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
stat_logger.exception(e)
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return None
|
235
web_server/apps/document_app.py
Normal file
@ -0,0 +1,235 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
from elasticsearch_dsl import Q
|
||||||
|
from flask import request
|
||||||
|
from flask_login import login_required, current_user
|
||||||
|
|
||||||
|
from rag.nlp import search
|
||||||
|
from rag.utils import ELASTICSEARCH
|
||||||
|
from web_server.db.services import duplicate_name
|
||||||
|
from web_server.db.services.kb_service import KnowledgebaseService
|
||||||
|
from web_server.db.services.user_service import TenantService
|
||||||
|
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||||
|
from web_server.utils import get_uuid, get_format_time
|
||||||
|
from web_server.db import StatusEnum, FileType
|
||||||
|
from web_server.db.services.document_service import DocumentService
|
||||||
|
from web_server.settings import RetCode
|
||||||
|
from web_server.utils.api_utils import get_json_result
|
||||||
|
from rag.utils.minio_conn import MINIO
|
||||||
|
from web_server.utils.file_utils import filename_type
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/upload', methods=['POST'])
|
||||||
|
@login_required
|
||||||
|
@validate_request("kb_id")
|
||||||
|
def upload():
|
||||||
|
kb_id = request.form.get("kb_id")
|
||||||
|
if not kb_id:
|
||||||
|
return get_json_result(
|
||||||
|
data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
|
||||||
|
if 'file' not in request.files:
|
||||||
|
return get_json_result(
|
||||||
|
data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
|
||||||
|
file = request.files['file']
|
||||||
|
if file.filename == '':
|
||||||
|
return get_json_result(
|
||||||
|
data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
|
try:
|
||||||
|
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||||
|
if not e:
|
||||||
|
return get_data_error_result(
|
||||||
|
retmsg="Can't find this knowledgebase!")
|
||||||
|
|
||||||
|
filename = duplicate_name(
|
||||||
|
DocumentService.query,
|
||||||
|
name=file.filename,
|
||||||
|
kb_id=kb.id)
|
||||||
|
location = filename
|
||||||
|
while MINIO.obj_exist(kb_id, location):
|
||||||
|
location += "_"
|
||||||
|
blob = request.files['file'].read()
|
||||||
|
MINIO.put(kb_id, location, blob)  # store under the collision-free object name computed above
|
||||||
|
doc = DocumentService.insert({
|
||||||
|
"id": get_uuid(),
|
||||||
|
"kb_id": kb.id,
|
||||||
|
"parser_id": kb.parser_id,
|
||||||
|
"created_by": current_user.id,
|
||||||
|
"type": filename_type(filename),
|
||||||
|
"name": filename,
|
||||||
|
"location": location,
|
||||||
|
"size": len(blob)
|
||||||
|
})
|
||||||
|
return get_json_result(data=doc.to_json())
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/create', methods=['POST'])
|
||||||
|
@login_required
|
||||||
|
@validate_request("name", "kb_id")
|
||||||
|
def create():
|
||||||
|
req = request.json
|
||||||
|
kb_id = req["kb_id"]
|
||||||
|
if not kb_id:
|
||||||
|
return get_json_result(
|
||||||
|
data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
|
try:
|
||||||
|
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||||
|
if not e:
|
||||||
|
return get_data_error_result(
|
||||||
|
retmsg="Can't find this knowledgebase!")
|
||||||
|
|
||||||
|
if DocumentService.query(name=req["name"], kb_id=kb_id):
|
||||||
|
return get_data_error_result(
|
||||||
|
retmsg="Duplicated document name in the same knowledgebase.")
|
||||||
|
|
||||||
|
doc = DocumentService.insert({
|
||||||
|
"id": get_uuid(),
|
||||||
|
"kb_id": kb.id,
|
||||||
|
"parser_id": kb.parser_id,
|
||||||
|
"created_by": current_user.id,
|
||||||
|
"type": FileType.VIRTUAL,
|
||||||
|
"name": req["name"],
|
||||||
|
"location": "",
|
||||||
|
"size": 0
|
||||||
|
})
|
||||||
|
return get_json_result(data=doc.to_json())
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/list', methods=['GET'])
|
||||||
|
@login_required
|
||||||
|
def list():
|
||||||
|
kb_id = request.args.get("kb_id")
|
||||||
|
if not kb_id:
|
||||||
|
return get_json_result(
|
||||||
|
data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
|
||||||
|
keywords = request.args.get("keywords", "")
|
||||||
|
|
||||||
|
page_number = request.args.get("page", 1)
|
||||||
|
items_per_page = request.args.get("page_size", 15)
|
||||||
|
orderby = request.args.get("orderby", "create_time")
|
||||||
|
desc = request.args.get("desc", True)
|
||||||
|
try:
|
||||||
|
docs = DocumentService.get_by_kb_id(
|
||||||
|
kb_id, page_number, items_per_page, orderby, desc, keywords)
|
||||||
|
return get_json_result(data=docs)
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/change_status', methods=['POST'])
|
||||||
|
@login_required
|
||||||
|
@validate_request("doc_id", "status")
|
||||||
|
def change_status():
|
||||||
|
req = request.json
|
||||||
|
if str(req["status"]) not in ["0", "1"]:
|
||||||
|
return get_json_result(
|
||||||
|
data=False,
|
||||||
|
retmsg='"Status" must be either 0 or 1!',
|
||||||
|
retcode=RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
|
try:
|
||||||
|
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||||
|
if not e:
|
||||||
|
return get_data_error_result(retmsg="Document not found!")
|
||||||
|
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
||||||
|
if not e:
|
||||||
|
return get_data_error_result(
|
||||||
|
retmsg="Can't find this knowledgebase!")
|
||||||
|
|
||||||
|
if not DocumentService.update_by_id(
|
||||||
|
req["doc_id"], {"status": str(req["status"])}):
|
||||||
|
return get_data_error_result(
|
||||||
|
retmsg="Database error (Document update)!")
|
||||||
|
|
||||||
|
if str(req["status"]) == "0":
|
||||||
|
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
|
||||||
|
scripts="""
|
||||||
|
if(ctx._source.kb_id.contains('%s'))
|
||||||
|
ctx._source.kb_id.remove(
|
||||||
|
ctx._source.kb_id.indexOf('%s')
|
||||||
|
);
|
||||||
|
""" % (doc.kb_id, doc.kb_id),
|
||||||
|
idxnm=search.index_name(
|
||||||
|
kb.tenant_id)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]),
|
||||||
|
scripts="""
|
||||||
|
if(!ctx._source.kb_id.contains('%s'))
|
||||||
|
ctx._source.kb_id.add('%s');
|
||||||
|
""" % (doc.kb_id, doc.kb_id),
|
||||||
|
idxnm=search.index_name(
|
||||||
|
kb.tenant_id)
|
||||||
|
)
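# Both branches patch the kb_id array of already-indexed chunks with an Elasticsearch script, so toggling the status takes effect at retrieval time without re-indexing the document.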
|
||||||
|
return get_json_result(data=True)
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/rm', methods=['POST'])
|
||||||
|
@login_required
|
||||||
|
@validate_request("doc_id")
|
||||||
|
def rm():
|
||||||
|
req = request.json
|
||||||
|
try:
|
||||||
|
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||||
|
if not e:
|
||||||
|
return get_data_error_result(retmsg="Document not found!")
|
||||||
|
if not DocumentService.delete_by_id(req["doc_id"]):
|
||||||
|
return get_data_error_result(
|
||||||
|
retmsg="Database error (Document removal)!")
|
||||||
|
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
|
||||||
|
MINIO.rm(kb.id, doc.location)
|
||||||
|
return get_json_result(data=True)
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/rename', methods=['POST'])
|
||||||
|
@login_required
|
||||||
|
@validate_request("doc_id", "name", "old_name")
|
||||||
|
def rename():
|
||||||
|
req = request.json
|
||||||
|
if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
|
||||||
|
req["old_name"].lower()).suffix:
|
||||||
|
return get_json_result(
|
||||||
|
data=False,
|
||||||
|
retmsg="The extension of file can't be changed",
|
||||||
|
retcode=RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
|
try:
|
||||||
|
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||||
|
if not e:
|
||||||
|
return get_data_error_result(retmsg="Document not found!")
|
||||||
|
if DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
||||||
|
return get_data_error_result(
|
||||||
|
retmsg="Duplicated document name in the same knowledgebase.")
|
||||||
|
|
||||||
|
if not DocumentService.update_by_id(
|
||||||
|
req["doc_id"], {"name": req["name"]}):
|
||||||
|
return get_data_error_result(
|
||||||
|
retmsg="Database error (Document rename)!")
|
||||||
|
|
||||||
|
return get_json_result(data=True)
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
102
web_server/apps/kb_app.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
from flask import request
|
||||||
|
from flask_login import login_required, current_user
|
||||||
|
|
||||||
|
from web_server.db.services import duplicate_name
|
||||||
|
from web_server.db.services.user_service import TenantService, UserTenantService
|
||||||
|
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||||
|
from web_server.utils import get_uuid, get_format_time
|
||||||
|
from web_server.db import StatusEnum, UserTenantRole
|
||||||
|
from web_server.db.services.kb_service import KnowledgebaseService
|
||||||
|
from web_server.db.db_models import Knowledgebase
|
||||||
|
from web_server.settings import stat_logger, RetCode
|
||||||
|
from web_server.utils.api_utils import get_json_result
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/create', methods=['post'])
|
||||||
|
@login_required
|
||||||
|
@validate_request("name", "description", "permission", "embd_id", "parser_id")
|
||||||
|
def create():
|
||||||
|
req = request.json
|
||||||
|
req["name"] = req["name"].strip()
|
||||||
|
req["name"] = duplicate_name(KnowledgebaseService.query, name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)
|
||||||
|
try:
|
||||||
|
req["id"] = get_uuid()
|
||||||
|
req["tenant_id"] = current_user.id
|
||||||
|
req["created_by"] = current_user.id
|
||||||
|
if not KnowledgebaseService.save(**req): return get_data_error_result()
|
||||||
|
return get_json_result(data={"kb_id": req["id"]})
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/update', methods=['post'])
|
||||||
|
@login_required
|
||||||
|
@validate_request("kb_id", "name", "description", "permission", "embd_id", "parser_id")
|
||||||
|
def update():
|
||||||
|
req = request.json
|
||||||
|
req["name"] = req["name"].strip()
|
||||||
|
try:
|
||||||
|
if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]):
|
||||||
|
return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
|
||||||
|
|
||||||
|
e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
|
||||||
|
if not e: return get_data_error_result(retmsg="Can't find this knowledgebase!")
|
||||||
|
|
||||||
|
if req["name"].lower() != kb.name.lower() \
|
||||||
|
and len(KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value))>1:
|
||||||
|
return get_data_error_result(retmsg="Duplicated knowledgebase name.")
|
||||||
|
|
||||||
|
del req["kb_id"]
|
||||||
|
if not KnowledgebaseService.update_by_id(kb.id, req): return get_data_error_result()
|
||||||
|
|
||||||
|
e, kb = KnowledgebaseService.get_by_id(kb.id)
|
||||||
|
if not e: return get_data_error_result(retmsg="Database error (Knowledgebase rename)!")
|
||||||
|
|
||||||
|
return get_json_result(data=kb.to_json())
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/list', methods=['GET'])
|
||||||
|
@login_required
|
||||||
|
def list():
|
||||||
|
page_number = request.args.get("page", 1)
|
||||||
|
items_per_page = request.args.get("page_size", 15)
|
||||||
|
orderby = request.args.get("orderby", "create_time")
|
||||||
|
desc = request.args.get("desc", True)
|
||||||
|
try:
|
||||||
|
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
||||||
|
kbs = KnowledgebaseService.get_by_tenant_ids([m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc)
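# A user sees knowledgebases from every tenant they have joined, not only the ones they created themselves.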
|
||||||
|
return get_json_result(data=kbs)
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/rm', methods=['post'])
|
||||||
|
@login_required
|
||||||
|
@validate_request("kb_id")
|
||||||
|
def rm():
|
||||||
|
req = request.json
|
||||||
|
try:
|
||||||
|
if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]):
|
||||||
|
return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
|
||||||
|
|
||||||
|
if not KnowledgebaseService.update_by_id(req["kb_id"], {"status": StatusEnum.IN_VALID.value}): return get_data_error_result(retmsg="Database error (Knowledgebase removal)!")
|
||||||
|
return get_json_result(data=True)
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
226
web_server/apps/user_app.py
Normal file
@ -0,0 +1,226 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
from flask import request, session, redirect, url_for
|
||||||
|
from werkzeug.security import generate_password_hash, check_password_hash
|
||||||
|
from flask_login import login_required, current_user, login_user, logout_user
|
||||||
|
from web_server.utils.api_utils import server_error_response, validate_request
|
||||||
|
from web_server.utils import get_uuid, get_format_time, decrypt, download_img
|
||||||
|
from web_server.db import UserTenantRole
|
||||||
|
from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
|
||||||
|
from web_server.db.services.user_service import UserService, TenantService, UserTenantService
|
||||||
|
from web_server.settings import stat_logger
|
||||||
|
from web_server.utils.api_utils import get_json_result, cors_reponse
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/login', methods=['POST', 'GET'])
|
||||||
|
def login():
|
||||||
|
userinfo = None
|
||||||
|
login_channel = "password"
|
||||||
|
if session.get("access_token"):
|
||||||
|
login_channel = session["access_token_from"]
|
||||||
|
if session["access_token_from"] == "github":
|
||||||
|
userinfo = user_info_from_github(session["access_token"])
|
||||||
|
elif not request.json:
|
||||||
|
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
|
||||||
|
retmsg='Unauthorized!')
|
||||||
|
|
||||||
|
email = request.json.get('email') if not userinfo else userinfo["email"]
|
||||||
|
users = UserService.query(email=email)
|
||||||
|
if not users:
|
||||||
|
if request.json is not None:
|
||||||
|
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
|
||||||
|
avatar = ""
|
||||||
|
try:
|
||||||
|
avatar = download_img(userinfo["avatar_url"])
|
||||||
|
except Exception as e:
|
||||||
|
stat_logger.exception(e)
|
||||||
|
try:
|
||||||
|
users = user_register({
|
||||||
|
"access_token": session["access_token"],
|
||||||
|
"email": userinfo["email"],
|
||||||
|
"avatar": avatar,
|
||||||
|
"nickname": userinfo["login"],
|
||||||
|
"login_channel": login_channel,
|
||||||
|
"last_login_time": get_format_time(),
|
||||||
|
"is_superuser": False,
|
||||||
|
})
|
||||||
|
if not users: raise Exception('Register user failure.')
|
||||||
|
if len(users) > 1: raise Exception('Same E-mail exist!')
|
||||||
|
user = users[0]
|
||||||
|
login_user(user)
|
||||||
|
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
|
||||||
|
except Exception as e:
|
||||||
|
stat_logger.exception(e)
|
||||||
|
return server_error_response(e)
|
||||||
|
elif not request.json:
|
||||||
|
login_user(users[0])
|
||||||
|
return cors_reponse(data=users[0].to_json(), auth=users[0].get_id(), retmsg="Welcome back!")
|
||||||
|
|
||||||
|
password = request.json.get('password')
|
||||||
|
try:
|
||||||
|
password = decrypt(password)
|
||||||
|
except:
|
||||||
|
return get_json_result(data=False, retcode=RetCode.SERVER_ERROR, retmsg='Fail to crypt password')
|
||||||
|
|
||||||
|
user = UserService.query_user(email, password)
|
||||||
|
if user:
|
||||||
|
response_data = user.to_json()
|
||||||
|
user.access_token = get_uuid()
|
||||||
|
login_user(user)
|
||||||
|
user.save()
|
||||||
|
msg = "Welcome back!"
|
||||||
|
return cors_reponse(data=response_data, auth=user.get_id(), retmsg=msg)
|
||||||
|
else:
|
||||||
|
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Email and Password do not match!')
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/github_callback', methods=['GET'])
|
||||||
|
def github_callback():
|
||||||
|
try:
|
||||||
|
import requests
|
||||||
|
res = requests.post(GITHUB_OAUTH.get("url"), data={
|
||||||
|
"client_id": GITHUB_OAUTH.get("client_id"),
|
||||||
|
"client_secret": GITHUB_OAUTH.get("secret_key"),
|
||||||
|
"code": request.args.get('code')
|
||||||
|
},headers={"Accept": "application/json"})
|
||||||
|
res = res.json()
|
||||||
|
if "error" in res:
|
||||||
|
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
|
||||||
|
retmsg=res["error_description"])
|
||||||
|
|
||||||
|
if "user:email" not in res["scope"].split(","):
|
||||||
|
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
|
||||||
|
|
||||||
|
session["access_token"] = res["access_token"]
|
||||||
|
session["access_token_from"] = "github"
|
||||||
|
return redirect(url_for("user.login"), code=307)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
stat_logger.exception(e)
|
||||||
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
def user_info_from_github(access_token):
|
||||||
|
import requests
|
||||||
|
headers = {"Accept": "application/json", 'Authorization': f"token {access_token}"}
|
||||||
|
res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
|
||||||
|
user_info = res.json()
|
||||||
|
email_info = requests.get(f"https://api.github.com/user/emails?access_token={access_token}", headers=headers).json()
|
||||||
|
user_info["email"] = next((email for email in email_info if email['primary'] == True), None)["email"]
|
||||||
|
return user_info
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/logout", methods=['GET'])
|
||||||
|
@login_required
|
||||||
|
def log_out():
|
||||||
|
current_user.access_token = ""
|
||||||
|
current_user.save()
|
||||||
|
logout_user()
|
||||||
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/setting", methods=["POST"])
|
||||||
|
@login_required
|
||||||
|
def setting_user():
|
||||||
|
update_dict = {}
|
||||||
|
request_data = request.json
|
||||||
|
if request_data.get("password"):
|
||||||
|
new_password = request_data.get("new_password")
|
||||||
|
if not check_password_hash(current_user.password, decrypt(request_data["password"])):
|
||||||
|
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!')
|
||||||
|
|
||||||
|
if new_password: update_dict["password"] = generate_password_hash(decrypt(new_password))
|
||||||
|
|
||||||
|
for k in request_data.keys():
|
||||||
|
if k in ["password", "new_password"]:continue
|
||||||
|
update_dict[k] = request_data[k]
|
||||||
|
|
||||||
|
try:
|
||||||
|
UserService.update_by_id(current_user.id, update_dict)
|
||||||
|
return get_json_result(data=True)
|
||||||
|
except Exception as e:
|
||||||
|
stat_logger.exception(e)
|
||||||
|
return get_json_result(data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/info", methods=["GET"])
|
||||||
|
@login_required
|
||||||
|
def user_info():
|
||||||
|
return get_json_result(data=current_user.to_dict())
|
||||||
|
|
||||||
|
|
||||||
|
def user_register(user):
|
||||||
|
user_id = get_uuid()
|
||||||
|
user["id"] = user_id
|
||||||
|
tenant = {
|
||||||
|
"id": user_id,
|
||||||
|
"name": user["nickname"] + "‘s Kingdom",
|
||||||
|
"llm_id": CHAT_MDL,
|
||||||
|
"embd_id": EMBEDDING_MDL,
|
||||||
|
"asr_id": ASR_MDL,
|
||||||
|
"parser_ids": PARSERS,
|
||||||
|
"img2txt_id": IMAGE2TEXT_MDL
|
||||||
|
}
|
||||||
|
usr_tenant = {
|
||||||
|
"tenant_id": user_id,
|
||||||
|
"user_id": user_id,
|
||||||
|
"invited_by": user_id,
|
||||||
|
"role": UserTenantRole.OWNER
|
||||||
|
}
|
||||||
|
|
||||||
|
if not UserService.save(**user):return
|
||||||
|
TenantService.save(**tenant)
|
||||||
|
UserTenantService.save(**usr_tenant)
|
||||||
|
return UserService.query(email=user["email"])
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/register", methods=["POST"])
|
||||||
|
@validate_request("nickname", "email", "password")
|
||||||
|
def user_add():
|
||||||
|
req = request.json
|
||||||
|
if UserService.query(email=req["email"]):
|
||||||
|
return get_json_result(data=False, retmsg=f'Email: {req["email"]} has already registered!', retcode=RetCode.OPERATING_ERROR)
|
||||||
|
|
||||||
|
user_dict = {
|
||||||
|
"access_token": get_uuid(),
|
||||||
|
"email": req["email"],
|
||||||
|
"nickname": req["nickname"],
|
||||||
|
"password": decrypt(req["password"]),
|
||||||
|
"login_channel": "password",
|
||||||
|
"last_login_time": get_format_time(),
|
||||||
|
"is_superuser": False,
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
users = user_register(user_dict)
|
||||||
|
if not users: raise Exception('Register user failure.')
|
||||||
|
if len(users) > 1: raise Exception('Same E-mail exist!')
|
||||||
|
user = users[0]
|
||||||
|
login_user(user)
|
||||||
|
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
|
||||||
|
except Exception as e:
|
||||||
|
stat_logger.exception(e)
|
||||||
|
return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/tenant_info", methods=["GET"])
|
||||||
|
@login_required
|
||||||
|
def tenant_info():
|
||||||
|
try:
|
||||||
|
tenants = TenantService.get_by_user_id(current_user.id)
|
||||||
|
return get_json_result(data=tenants)
|
||||||
|
except Exception as e:
|
||||||
|
return server_error_response(e)
|
web_server/db/__init__.py (new file)
@@ -0,0 +1,54 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from enum import Enum
from enum import IntEnum
from strenum import StrEnum


class StatusEnum(Enum):
    VALID = "1"
    IN_VALID = "0"


class UserTenantRole(StrEnum):
    OWNER = 'owner'
    ADMIN = 'admin'
    NORMAL = 'normal'


class TenantPermission(StrEnum):
    ME = 'me'
    TEAM = 'team'


class SerializedType(IntEnum):
    PICKLE = 1
    JSON = 2


class FileType(StrEnum):
    PDF = 'pdf'
    DOC = 'doc'
    VISUAL = 'visual'
    AURAL = 'aural'
    VIRTUAL = 'virtual'


class LLMType(StrEnum):
    CHAT = 'chat'
    EMBEDDING = 'embedding'
    SPEECH2TEXT = 'speech2text'
    IMAGE2TEXT = 'image2text'
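A small sketch of how these enums behave; StrEnum members compare equal to their string values, which is why they can be written straight into the CharField status and role columns used elsewhere in this commit. The snippet is illustrative only.

from web_server.db import StatusEnum, UserTenantRole, LLMType

# StrEnum members are also plain strings, so they can be stored in CharField columns directly.
assert UserTenantRole.OWNER == "owner"
assert LLMType.CHAT == "chat"

# StatusEnum is a regular Enum, so its stored value is accessed via .value.
assert StatusEnum.IN_VALID.value == "0"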
web_server/db/db_models.py (new file)
@@ -0,0 +1,616 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
import inspect
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import typing
|
||||||
|
import operator
|
||||||
|
from functools import wraps
|
||||||
|
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||||
|
from flask_login import UserMixin
|
||||||
|
|
||||||
|
from peewee import (
|
||||||
|
BigAutoField, BigIntegerField, BooleanField, CharField,
|
||||||
|
CompositeKey, Insert, IntegerField, TextField, FloatField, DateTimeField,
|
||||||
|
Field, Model, Metadata
|
||||||
|
)
|
||||||
|
from playhouse.pool import PooledMySQLDatabase
|
||||||
|
|
||||||
|
from web_server.db import SerializedType
|
||||||
|
from web_server.settings import DATABASE, stat_logger, SECRET_KEY
|
||||||
|
from web_server.utils.log_utils import getLogger
|
||||||
|
from web_server import utils
|
||||||
|
|
||||||
|
LOGGER = getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
def singleton(cls, *args, **kw):
|
||||||
|
instances = {}
|
||||||
|
|
||||||
|
def _singleton():
|
||||||
|
key = str(cls) + str(os.getpid())
|
||||||
|
if key not in instances:
|
||||||
|
instances[key] = cls(*args, **kw)
|
||||||
|
return instances[key]
|
||||||
|
|
||||||
|
return _singleton
|
||||||
|
|
||||||
|
|
||||||
|
CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}
|
||||||
|
AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {"create", "start", "end", "update", "read_access", "write_access"}
|
||||||
|
|
||||||
|
|
||||||
|
class LongTextField(TextField):
|
||||||
|
field_type = 'LONGTEXT'
|
||||||
|
|
||||||
|
|
||||||
|
class JSONField(LongTextField):
|
||||||
|
default_value = {}
|
||||||
|
|
||||||
|
def __init__(self, object_hook=None, object_pairs_hook=None, **kwargs):
|
||||||
|
self._object_hook = object_hook
|
||||||
|
self._object_pairs_hook = object_pairs_hook
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def db_value(self, value):
|
||||||
|
if value is None:
|
||||||
|
value = self.default_value
|
||||||
|
return utils.json_dumps(value)
|
||||||
|
|
||||||
|
def python_value(self, value):
|
||||||
|
if not value:
|
||||||
|
return self.default_value
|
||||||
|
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||||
|
|
||||||
|
|
||||||
|
class ListField(JSONField):
|
||||||
|
default_value = []
|
||||||
|
|
||||||
|
|
||||||
|
class SerializedField(LongTextField):
|
||||||
|
def __init__(self, serialized_type=SerializedType.PICKLE, object_hook=None, object_pairs_hook=None, **kwargs):
|
||||||
|
self._serialized_type = serialized_type
|
||||||
|
self._object_hook = object_hook
|
||||||
|
self._object_pairs_hook = object_pairs_hook
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def db_value(self, value):
|
||||||
|
if self._serialized_type == SerializedType.PICKLE:
|
||||||
|
return utils.serialize_b64(value, to_str=True)
|
||||||
|
elif self._serialized_type == SerializedType.JSON:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
return utils.json_dumps(value, with_type=True)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
||||||
|
|
||||||
|
def python_value(self, value):
|
||||||
|
if self._serialized_type == SerializedType.PICKLE:
|
||||||
|
return utils.deserialize_b64(value)
|
||||||
|
elif self._serialized_type == SerializedType.JSON:
|
||||||
|
if value is None:
|
||||||
|
return {}
|
||||||
|
return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"the serialized type {self._serialized_type} is not supported")
|
||||||
|
|
||||||
|
|
||||||
|
def is_continuous_field(cls: typing.Type) -> bool:
|
||||||
|
if cls in CONTINUOUS_FIELD_TYPE:
|
||||||
|
return True
|
||||||
|
for p in cls.__bases__:
|
||||||
|
if p in CONTINUOUS_FIELD_TYPE:
|
||||||
|
return True
|
||||||
|
elif p != Field and p != object:
|
||||||
|
if is_continuous_field(p):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def auto_date_timestamp_field():
|
||||||
|
return {f"{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
|
||||||
|
|
||||||
|
|
||||||
|
def auto_date_timestamp_db_field():
|
||||||
|
return {f"f_{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
|
||||||
|
|
||||||
|
|
||||||
|
def remove_field_name_prefix(field_name):
|
||||||
|
return field_name[2:] if field_name.startswith('f_') else field_name
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModel(Model):
|
||||||
|
create_time = BigIntegerField(null=True)
|
||||||
|
create_date = DateTimeField(null=True)
|
||||||
|
update_time = BigIntegerField(null=True)
|
||||||
|
update_date = DateTimeField(null=True)
|
||||||
|
|
||||||
|
def to_json(self):
|
||||||
|
# This function is obsolete
|
||||||
|
return self.to_dict()
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return self.__dict__['__data__']
|
||||||
|
|
||||||
|
def to_human_model_dict(self, only_primary_with: list = None):
|
||||||
|
model_dict = self.__dict__['__data__']
|
||||||
|
|
||||||
|
if not only_primary_with:
|
||||||
|
return {remove_field_name_prefix(k): v for k, v in model_dict.items()}
|
||||||
|
|
||||||
|
human_model_dict = {}
|
||||||
|
for k in self._meta.primary_key.field_names:
|
||||||
|
human_model_dict[remove_field_name_prefix(k)] = model_dict[k]
|
||||||
|
for k in only_primary_with:
|
||||||
|
human_model_dict[k] = model_dict[f'f_{k}']
|
||||||
|
return human_model_dict
|
||||||
|
|
||||||
|
@property
|
||||||
|
def meta(self) -> Metadata:
|
||||||
|
return self._meta
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_primary_keys_name(cls):
|
||||||
|
return cls._meta.primary_key.field_names if isinstance(cls._meta.primary_key, CompositeKey) else [
|
||||||
|
cls._meta.primary_key.name]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def getter_by(cls, attr):
|
||||||
|
return operator.attrgetter(attr)(cls)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def query(cls, reverse=None, order_by=None, **kwargs):
|
||||||
|
filters = []
|
||||||
|
for f_n, f_v in kwargs.items():
|
||||||
|
attr_name = '%s' % f_n
|
||||||
|
if not hasattr(cls, attr_name) or f_v is None:
|
||||||
|
continue
|
||||||
|
if type(f_v) in {list, set}:
|
||||||
|
f_v = list(f_v)
|
||||||
|
if is_continuous_field(type(getattr(cls, attr_name))):
|
||||||
|
if len(f_v) == 2:
|
||||||
|
for i, v in enumerate(f_v):
|
||||||
|
if isinstance(v, str) and f_n in auto_date_timestamp_field():
|
||||||
|
# time type: %Y-%m-%d %H:%M:%S
|
||||||
|
f_v[i] = utils.date_string_to_timestamp(v)
|
||||||
|
lt_value = f_v[0]
|
||||||
|
gt_value = f_v[1]
|
||||||
|
if lt_value is not None and gt_value is not None:
|
||||||
|
filters.append(cls.getter_by(attr_name).between(lt_value, gt_value))
|
||||||
|
elif lt_value is not None:
|
||||||
|
filters.append(operator.attrgetter(attr_name)(cls) >= lt_value)
|
||||||
|
elif gt_value is not None:
|
||||||
|
filters.append(operator.attrgetter(attr_name)(cls) <= gt_value)
|
||||||
|
else:
|
||||||
|
filters.append(operator.attrgetter(attr_name)(cls) << f_v)
|
||||||
|
else:
|
||||||
|
filters.append(operator.attrgetter(attr_name)(cls) == f_v)
|
||||||
|
if filters:
|
||||||
|
query_records = cls.select().where(*filters)
|
||||||
|
if reverse is not None:
|
||||||
|
if not order_by or not hasattr(cls, f"{order_by}"):
|
||||||
|
order_by = "create_time"
|
||||||
|
if reverse is True:
|
||||||
|
query_records = query_records.order_by(cls.getter_by(f"{order_by}").desc())
|
||||||
|
elif reverse is False:
|
||||||
|
query_records = query_records.order_by(cls.getter_by(f"{order_by}").asc())
|
||||||
|
return [query_record for query_record in query_records]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def insert(cls, __data=None, **insert):
|
||||||
|
if isinstance(__data, dict) and __data:
|
||||||
|
__data[cls._meta.combined["create_time"]] = utils.current_timestamp()
|
||||||
|
if insert:
|
||||||
|
insert["create_time"] = utils.current_timestamp()
|
||||||
|
|
||||||
|
return super().insert(__data, **insert)
|
||||||
|
|
||||||
|
# update and insert will call this method
|
||||||
|
@classmethod
|
||||||
|
def _normalize_data(cls, data, kwargs):
|
||||||
|
normalized = super()._normalize_data(data, kwargs)
|
||||||
|
if not normalized:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
normalized[cls._meta.combined["update_time"]] = utils.current_timestamp()
|
||||||
|
|
||||||
|
for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
|
||||||
|
if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
|
||||||
|
cls._meta.combined[f"{f_n}_time"] in normalized and \
|
||||||
|
normalized[cls._meta.combined[f"{f_n}_time"]] is not None:
|
||||||
|
normalized[cls._meta.combined[f"{f_n}_date"]] = utils.timestamp_to_date(
|
||||||
|
normalized[cls._meta.combined[f"{f_n}_time"]])
|
||||||
|
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
|
class JsonSerializedField(SerializedField):
|
||||||
|
def __init__(self, object_hook=utils.from_dict_hook, object_pairs_hook=None, **kwargs):
|
||||||
|
super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
|
||||||
|
object_pairs_hook=object_pairs_hook, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@singleton
|
||||||
|
class BaseDataBase:
|
||||||
|
def __init__(self):
|
||||||
|
database_config = DATABASE.copy()
|
||||||
|
db_name = database_config.pop("name")
|
||||||
|
self.database_connection = PooledMySQLDatabase(db_name, **database_config)
|
||||||
|
stat_logger.info('init mysql database on cluster mode successfully')
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseLock:
|
||||||
|
def __init__(self, lock_name, timeout=10, db=None):
|
||||||
|
self.lock_name = lock_name
|
||||||
|
self.timeout = int(timeout)
|
||||||
|
self.db = db if db else DB
|
||||||
|
|
||||||
|
def lock(self):
|
||||||
|
# SQL parameters only support %s format placeholders
|
||||||
|
cursor = self.db.execute_sql("SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
|
||||||
|
ret = cursor.fetchone()
|
||||||
|
if ret[0] == 0:
|
||||||
|
raise Exception(f'acquire mysql lock {self.lock_name} timeout')
|
||||||
|
elif ret[0] == 1:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
raise Exception(f'failed to acquire lock {self.lock_name}')
|
||||||
|
|
||||||
|
def unlock(self):
|
||||||
|
cursor = self.db.execute_sql("SELECT RELEASE_LOCK(%s)", (self.lock_name,))
|
||||||
|
ret = cursor.fetchone()
|
||||||
|
if ret[0] == 0:
|
||||||
|
raise Exception(f'mysql lock {self.lock_name} was not established by this thread')
|
||||||
|
elif ret[0] == 1:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
raise Exception(f'mysql lock {self.lock_name} does not exist')
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
if isinstance(self.db, PooledMySQLDatabase):
|
||||||
|
self.lock()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
if isinstance(self.db, PooledMySQLDatabase):
|
||||||
|
self.unlock()
|
||||||
|
|
||||||
|
def __call__(self, func):
|
||||||
|
@wraps(func)
|
||||||
|
def magic(*args, **kwargs):
|
||||||
|
with self:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return magic
|
||||||
|
|
||||||
|
|
||||||
|
DB = BaseDataBase().database_connection
|
||||||
|
DB.lock = DatabaseLock
|
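A brief usage sketch for the lock helper wired up above; the lock name and timeout below are illustrative, not taken from this commit.

# As a context manager: acquires the named MySQL advisory lock (GET_LOCK) on entry
# and releases it (RELEASE_LOCK) on exit, raising if the lock cannot be obtained in time.
with DB.lock("kb_index_build", 10):
    pass  # critical section goes here

# As a decorator: the wrapped function runs with the lock held.
@DB.lock("kb_index_build", 10)
def rebuild_index():
    pass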
||||||
|
|
||||||
|
|
||||||
|
def close_connection():
|
||||||
|
try:
|
||||||
|
if DB:
|
||||||
|
DB.close()
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.exception(e)
|
||||||
|
|
||||||
|
|
||||||
|
class DataBaseModel(BaseModel):
|
||||||
|
class Meta:
|
||||||
|
database = DB
|
||||||
|
|
||||||
|
|
||||||
|
@DB.connection_context()
|
||||||
|
def init_database_tables():
|
||||||
|
members = inspect.getmembers(sys.modules[__name__], inspect.isclass)
|
||||||
|
table_objs = []
|
||||||
|
create_failed_list = []
|
||||||
|
for name, obj in members:
|
||||||
|
if obj != DataBaseModel and issubclass(obj, DataBaseModel):
|
||||||
|
table_objs.append(obj)
|
||||||
|
LOGGER.info(f"start create table {obj.__name__}")
|
||||||
|
try:
|
||||||
|
obj.create_table()
|
||||||
|
LOGGER.info(f"create table success: {obj.__name__}")
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.exception(e)
|
||||||
|
create_failed_list.append(obj.__name__)
|
||||||
|
if create_failed_list:
|
||||||
|
LOGGER.info(f"create tables failed: {create_failed_list}")
|
||||||
|
raise Exception(f"create tables failed: {create_failed_list}")
|
||||||
|
|
||||||
|
|
||||||
|
def fill_db_model_object(model_object, human_model_dict):
|
||||||
|
for k, v in human_model_dict.items():
|
||||||
|
attr_name = '%s' % k
|
||||||
|
if hasattr(model_object.__class__, attr_name):
|
||||||
|
setattr(model_object, attr_name, v)
|
||||||
|
return model_object
|
||||||
|
|
||||||
|
|
||||||
|
class User(DataBaseModel, UserMixin):
|
||||||
|
id = CharField(max_length=32, primary_key=True)
|
||||||
|
access_token = CharField(max_length=255, null=True)
|
||||||
|
nickname = CharField(max_length=100, null=False, help_text="nickname")
|
||||||
|
password = CharField(max_length=255, null=True, help_text="password")
|
||||||
|
email = CharField(max_length=255, null=False, help_text="email", index=True)
|
||||||
|
avatar = TextField(null=True, help_text="avatar base64 string")
|
||||||
|
language = CharField(max_length=32, null=True, help_text="English|Chinese", default="Chinese")
|
||||||
|
color_schema = CharField(max_length=32, null=True, help_text="Bright|Dark", default="Dark")
|
||||||
|
last_login_time = DateTimeField(null=True)
|
||||||
|
is_authenticated = CharField(max_length=1, null=False, default="1")
|
||||||
|
is_active = CharField(max_length=1, null=False, default="1")
|
||||||
|
is_anonymous = CharField(max_length=1, null=False, default="0")
|
||||||
|
login_channel = CharField(null=True, help_text="from which user login")
|
||||||
|
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||||
|
is_superuser = BooleanField(null=True, help_text="is root", default=False)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.email
|
||||||
|
|
||||||
|
def get_id(self):
|
||||||
|
jwt = Serializer(secret_key=SECRET_KEY)
|
||||||
|
return jwt.dumps(str(self.access_token))
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
db_table = "user"
|
||||||
|
|
||||||
|
|
||||||
|
class Tenant(DataBaseModel):
|
||||||
|
id = CharField(max_length=32, primary_key=True)
|
||||||
|
name = CharField(max_length=100, null=True, help_text="Tenant name")
|
||||||
|
public_key = CharField(max_length=255, null=True)
|
||||||
|
llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
|
||||||
|
embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
|
||||||
|
asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID")
|
||||||
|
img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID")
|
||||||
|
parser_ids = CharField(max_length=128, null=False, help_text="default document parser IDs")
|
||||||
|
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
db_table = "tenant"
|
||||||
|
|
||||||
|
|
||||||
|
class UserTenant(DataBaseModel):
|
||||||
|
id = CharField(max_length=32, primary_key=True)
|
||||||
|
user_id = CharField(max_length=32, null=False)
|
||||||
|
tenant_id = CharField(max_length=32, null=False)
|
||||||
|
role = CharField(max_length=32, null=False, help_text="UserTenantRole")
|
||||||
|
invited_by = CharField(max_length=32, null=False)
|
||||||
|
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
db_table = "user_tenant"
|
||||||
|
|
||||||
|
|
||||||
|
class InvitationCode(DataBaseModel):
|
||||||
|
id = CharField(max_length=32, primary_key=True)
|
||||||
|
code = CharField(max_length=32, null=False)
|
||||||
|
visit_time = DateTimeField(null=True)
|
||||||
|
user_id = CharField(max_length=32, null=True)
|
||||||
|
tenant_id = CharField(max_length=32, null=True)
|
||||||
|
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
db_table = "invitation_code"
|
||||||
|
|
||||||
|
|
||||||
|
class LLMFactories(DataBaseModel):
|
||||||
|
name = CharField(max_length=128, null=False, help_text="LLM factory name", primary_key=True)
|
||||||
|
logo = TextField(null=True, help_text="llm logo base64")
|
||||||
|
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
|
||||||
|
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
db_table = "llm_factories"
|
||||||
|
|
||||||
|
|
||||||
|
class LLM(DataBaseModel):
|
||||||
|
# default LLMs for every user
|
||||||
|
llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
|
||||||
|
fid = CharField(max_length=128, null=False, help_text="LLM factory id")
|
||||||
|
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
|
||||||
|
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.llm_name
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
db_table = "llm"
|
||||||
|
|
||||||
|
|
||||||
|
class TenantLLM(DataBaseModel):
|
||||||
|
tenant_id = CharField(max_length=32, null=False)
|
||||||
|
llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
|
||||||
|
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
|
||||||
|
llm_name = CharField(max_length=128, null=False, help_text="LLM name")
|
||||||
|
api_key = CharField(max_length=255, null=True, help_text="API KEY")
|
||||||
|
api_base = CharField(max_length=255, null=True, help_text="API Base")
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.llm_name
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
db_table = "tenant_llm"
|
||||||
|
primary_key = CompositeKey('tenant_id', 'llm_factory')
|
||||||
|
|
||||||
|
|
||||||
|
class Knowledgebase(DataBaseModel):
|
||||||
|
id = CharField(max_length=32, primary_key=True)
|
||||||
|
avatar = TextField(null=True, help_text="avatar base64 string")
|
||||||
|
tenant_id = CharField(max_length=32, null=False)
|
||||||
|
name = CharField(max_length=128, null=False, help_text="KB name", index=True)
|
||||||
|
description = TextField(null=True, help_text="KB description")
|
||||||
|
permission = CharField(max_length=16, null=False, help_text="me|team")
|
||||||
|
created_by = CharField(max_length=32, null=False)
|
||||||
|
doc_num = IntegerField(default=0)
|
||||||
|
embd_id = CharField(max_length=32, null=False, help_text="default embedding model ID")
|
||||||
|
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
|
||||||
|
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
db_table = "knowledgebase"
|
||||||
|
|
||||||
|
|
||||||
|
class Document(DataBaseModel):
|
||||||
|
id = CharField(max_length=32, primary_key=True)
|
||||||
|
thumbnail = TextField(null=True, help_text="thumbnail base64 string")
|
||||||
|
kb_id = CharField(max_length=256, null=False, index=True)
|
||||||
|
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
|
||||||
|
source_type = CharField(max_length=128, null=False, default="local", help_text="where does this document come from")
|
||||||
|
type = CharField(max_length=32, null=False, help_text="file extension")
|
||||||
|
created_by = CharField(max_length=32, null=False, help_text="who created it")
|
||||||
|
name = CharField(max_length=255, null=True, help_text="file name", index=True)
|
||||||
|
location = CharField(max_length=255, null=True, help_text="where is it stored")
|
||||||
|
size = IntegerField(default=0)
|
||||||
|
token_num = IntegerField(default=0)
|
||||||
|
chunk_num = IntegerField(default=0)
|
||||||
|
progress = FloatField(default=0)
|
||||||
|
progress_msg = CharField(max_length=255, null=True, help_text="process message", default="")
|
||||||
|
process_begin_at = DateTimeField(null=True)
|
||||||
|
process_duation = FloatField(default=0)
|
||||||
|
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
db_table = "document"
|
||||||
|
|
||||||
|
|
||||||
|
class Dialog(DataBaseModel):
|
||||||
|
id = CharField(max_length=32, primary_key=True)
|
||||||
|
tenant_id = CharField(max_length=32, null=False)
|
||||||
|
name = CharField(max_length=255, null=True, help_text="dialog application name")
|
||||||
|
description = TextField(null=True, help_text="Dialog description")
|
||||||
|
icon = CharField(max_length=16, null=False, help_text="dialog icon")
|
||||||
|
language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese")
|
||||||
|
llm_id = CharField(max_length=32, null=False, help_text="default llm ID")
|
||||||
|
llm_setting_type = CharField(max_length=8, null=False, help_text="Creative|Precise|Evenly|Custom",
|
||||||
|
default="Creative")
|
||||||
|
llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
|
||||||
|
"presence_penalty": 0.4, "max_tokens": 215})
|
||||||
|
prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced")
|
||||||
|
prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?",
|
||||||
|
"parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
|
||||||
|
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
db_table = "dialog"
|
||||||
|
|
||||||
|
|
||||||
|
class DialogKb(DataBaseModel):
|
||||||
|
dialog_id = CharField(max_length=32, null=False, index=True)
|
||||||
|
kb_id = CharField(max_length=32, null=False)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
db_table = "dialog_kb"
|
||||||
|
primary_key = CompositeKey('dialog_id', 'kb_id')
|
||||||
|
|
||||||
|
|
||||||
|
class Conversation(DataBaseModel):
|
||||||
|
id = CharField(max_length=32, primary_key=True)
|
||||||
|
dialog_id = CharField(max_length=32, null=False, index=True)
|
||||||
|
name = CharField(max_length=255, null=True, help_text="conversation name")
|
||||||
|
message = JSONField(null=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
db_table = "conversation"
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
class Job(DataBaseModel):
|
||||||
|
# multi-party common configuration
|
||||||
|
f_user_id = CharField(max_length=25, null=True)
|
||||||
|
f_job_id = CharField(max_length=25, index=True)
|
||||||
|
f_name = CharField(max_length=500, null=True, default='')
|
||||||
|
f_description = TextField(null=True, default='')
|
||||||
|
f_tag = CharField(max_length=50, null=True, default='')
|
||||||
|
f_dsl = JSONField()
|
||||||
|
f_runtime_conf = JSONField()
|
||||||
|
f_runtime_conf_on_party = JSONField()
|
||||||
|
f_train_runtime_conf = JSONField(null=True)
|
||||||
|
f_roles = JSONField()
|
||||||
|
f_initiator_role = CharField(max_length=50)
|
||||||
|
f_initiator_party_id = CharField(max_length=50)
|
||||||
|
f_status = CharField(max_length=50)
|
||||||
|
f_status_code = IntegerField(null=True)
|
||||||
|
f_user = JSONField()
|
||||||
|
# this party configuration
|
||||||
|
f_role = CharField(max_length=50, index=True)
|
||||||
|
f_party_id = CharField(max_length=10, index=True)
|
||||||
|
f_is_initiator = BooleanField(null=True, default=False)
|
||||||
|
f_progress = IntegerField(null=True, default=0)
|
||||||
|
f_ready_signal = BooleanField(default=False)
|
||||||
|
f_ready_time = BigIntegerField(null=True)
|
||||||
|
f_cancel_signal = BooleanField(default=False)
|
||||||
|
f_cancel_time = BigIntegerField(null=True)
|
||||||
|
f_rerun_signal = BooleanField(default=False)
|
||||||
|
f_end_scheduling_updates = IntegerField(null=True, default=0)
|
||||||
|
|
||||||
|
f_engine_name = CharField(max_length=50, null=True)
|
||||||
|
f_engine_type = CharField(max_length=10, null=True)
|
||||||
|
f_cores = IntegerField(default=0)
|
||||||
|
f_memory = IntegerField(default=0) # MB
|
||||||
|
f_remaining_cores = IntegerField(default=0)
|
||||||
|
f_remaining_memory = IntegerField(default=0) # MB
|
||||||
|
f_resource_in_use = BooleanField(default=False)
|
||||||
|
f_apply_resource_time = BigIntegerField(null=True)
|
||||||
|
f_return_resource_time = BigIntegerField(null=True)
|
||||||
|
|
||||||
|
f_inheritance_info = JSONField(null=True)
|
||||||
|
f_inheritance_status = CharField(max_length=50, null=True)
|
||||||
|
|
||||||
|
f_start_time = BigIntegerField(null=True)
|
||||||
|
f_start_date = DateTimeField(null=True)
|
||||||
|
f_end_time = BigIntegerField(null=True)
|
||||||
|
f_end_date = DateTimeField(null=True)
|
||||||
|
f_elapsed = BigIntegerField(null=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
db_table = "t_job"
|
||||||
|
primary_key = CompositeKey('f_job_id', 'f_role', 'f_party_id')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineComponentMeta(DataBaseModel):
|
||||||
|
f_model_id = CharField(max_length=100, index=True)
|
||||||
|
f_model_version = CharField(max_length=100, index=True)
|
||||||
|
f_role = CharField(max_length=50, index=True)
|
||||||
|
f_party_id = CharField(max_length=10, index=True)
|
||||||
|
f_component_name = CharField(max_length=100, index=True)
|
||||||
|
f_component_module_name = CharField(max_length=100)
|
||||||
|
f_model_alias = CharField(max_length=100, index=True)
|
||||||
|
f_model_proto_index = JSONField(null=True)
|
||||||
|
f_run_parameters = JSONField(null=True)
|
||||||
|
f_archive_sha256 = CharField(max_length=100, null=True)
|
||||||
|
f_archive_from_ip = CharField(max_length=100, null=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
db_table = 't_pipeline_component_meta'
|
||||||
|
indexes = (
|
||||||
|
(('f_model_id', 'f_model_version', 'f_role', 'f_party_id', 'f_component_name'), True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
web_server/db/db_services.py (new file)
@@ -0,0 +1,157 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2021 The FATE Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
import abc
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from functools import wraps
|
||||||
|
from shortuuid import ShortUUID
|
||||||
|
|
||||||
|
from web_server.versions import get_fate_version
|
||||||
|
|
||||||
|
from web_server.errors.error_services import *
|
||||||
|
from web_server.settings import (
|
||||||
|
GRPC_PORT, HOST, HTTP_PORT,
|
||||||
|
RANDOM_INSTANCE_ID, stat_logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
instance_id = ShortUUID().random(length=8) if RANDOM_INSTANCE_ID else f'flow-{HOST}-{HTTP_PORT}'
|
||||||
|
server_instance = (
|
||||||
|
f'{HOST}:{GRPC_PORT}',
|
||||||
|
json.dumps({
|
||||||
|
'instance_id': instance_id,
|
||||||
|
'timestamp': round(time.time() * 1000),
|
||||||
|
'version': get_fate_version() or '',
|
||||||
|
'host': HOST,
|
||||||
|
'grpc_port': GRPC_PORT,
|
||||||
|
'http_port': HTTP_PORT,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def check_service_supported(method):
|
||||||
|
"""Decorator to check if `service_name` is supported.
|
||||||
|
The attribute `supported_services` MUST be defined in class.
|
||||||
|
The first and second arguments of `method` MUST be `self` and `service_name`.
|
||||||
|
|
||||||
|
:param Callable method: The class method.
|
||||||
|
:return: The inner wrapper function.
|
||||||
|
:rtype: Callable
|
||||||
|
"""
|
||||||
|
@wraps(method)
|
||||||
|
def magic(self, service_name, *args, **kwargs):
|
||||||
|
if service_name not in self.supported_services:
|
||||||
|
raise ServiceNotSupported(service_name=service_name)
|
||||||
|
return method(self, service_name, *args, **kwargs)
|
||||||
|
return magic
|
||||||
|
|
||||||
|
|
||||||
|
class ServicesDB(abc.ABC):
|
||||||
|
"""Database for storage service urls.
|
||||||
|
Abstract base class for the real backends.
|
||||||
|
|
||||||
|
"""
|
||||||
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
|
def supported_services(self):
|
||||||
|
"""The names of supported services.
|
||||||
|
The returned list SHOULD contain `fateflow` (model download) and `servings` (FATE-Serving).
|
||||||
|
|
||||||
|
:return: The service names.
|
||||||
|
:rtype: list
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _get_serving(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_serving(self):
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self._get_serving()
|
||||||
|
except ServicesError as e:
|
||||||
|
stat_logger.exception(e)
|
||||||
|
return []
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _insert(self, service_name, service_url, value=''):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@check_service_supported
|
||||||
|
def insert(self, service_name, service_url, value=''):
|
||||||
|
"""Insert a service url to database.
|
||||||
|
|
||||||
|
:param str service_name: The service name.
|
||||||
|
:param str service_url: The service url.
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._insert(service_name, service_url, value)
|
||||||
|
except ServicesError as e:
|
||||||
|
stat_logger.exception(e)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _delete(self, service_name, service_url):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@check_service_supported
|
||||||
|
def delete(self, service_name, service_url):
|
||||||
|
"""Delete a service url from database.
|
||||||
|
|
||||||
|
:param str service_name: The service name.
|
||||||
|
:param str service_url: The service url.
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._delete(service_name, service_url)
|
||||||
|
except ServicesError as e:
|
||||||
|
stat_logger.exception(e)
|
||||||
|
|
||||||
|
def register_flow(self):
|
||||||
|
"""Call `self.insert` for insert the flow server address to databae.
|
||||||
|
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
self.insert('flow-server', *server_instance)
|
||||||
|
|
||||||
|
def unregister_flow(self):
|
||||||
|
"""Call `self.delete` for delete the flow server address from databae.
|
||||||
|
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
self.delete('flow-server', server_instance[0])
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _get_urls(self, service_name, with_values=False):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@check_service_supported
|
||||||
|
def get_urls(self, service_name, with_values=False):
|
||||||
|
"""Query service urls from database. The urls may belong to other nodes.
|
||||||
|
Currently, only `fateflow` (model download) urls and `servings` (FATE-Serving) urls are supported.
|
||||||
|
`fateflow` is a url containing scheme, host, port and path,
|
||||||
|
while `servings` only contains host and port.
|
||||||
|
|
||||||
|
:param str service_name: The service name.
|
||||||
|
:return: The service urls.
|
||||||
|
:rtype: list
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return self._get_urls(service_name, with_values)
|
||||||
|
except ServicesError as e:
|
||||||
|
stat_logger.exception(e)
|
||||||
|
return []
|
web_server/db/db_utils.py (new file)
@@ -0,0 +1,131 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
import operator
|
||||||
|
from functools import reduce
|
||||||
|
from typing import Dict, Type, Union
|
||||||
|
|
||||||
|
from web_server.utils import current_timestamp, timestamp_to_date
|
||||||
|
|
||||||
|
from web_server.db.db_models import DB, DataBaseModel
|
||||||
|
from web_server.db.runtime_config import RuntimeConfig
|
||||||
|
from web_server.utils.log_utils import getLogger
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
LOGGER = getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
@DB.connection_context()
|
||||||
|
def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
|
||||||
|
DB.create_tables([model])
|
||||||
|
|
||||||
|
current_time = current_timestamp()
|
||||||
|
current_date = timestamp_to_date(current_time)
|
||||||
|
|
||||||
|
for data in data_source:
|
||||||
|
if 'f_create_time' not in data:
|
||||||
|
data['f_create_time'] = current_time
|
||||||
|
data['f_create_date'] = timestamp_to_date(data['f_create_time'])
|
||||||
|
data['f_update_time'] = current_time
|
||||||
|
data['f_update_date'] = current_date
|
||||||
|
|
||||||
|
preserve = tuple(data_source[0].keys() - {'f_create_time', 'f_create_date'})
|
||||||
|
|
||||||
|
batch_size = 50 if RuntimeConfig.USE_LOCAL_DATABASE else 1000
|
||||||
|
|
||||||
|
for i in range(0, len(data_source), batch_size):
|
||||||
|
with DB.atomic():
|
||||||
|
query = model.insert_many(data_source[i:i + batch_size])
|
||||||
|
if replace_on_conflict:
|
||||||
|
query = query.on_conflict(preserve=preserve)
|
||||||
|
query.execute()
|
||||||
|
|
||||||
|
|
||||||
|
def get_dynamic_db_model(base, job_id):
|
||||||
|
return type(base.model(table_index=get_dynamic_tracking_table_index(job_id=job_id)))
|
||||||
|
|
||||||
|
|
||||||
|
def get_dynamic_tracking_table_index(job_id):
|
||||||
|
return job_id[:8]
|
||||||
|
|
||||||
|
|
||||||
|
def fill_db_model_object(model_object, human_model_dict):
|
||||||
|
for k, v in human_model_dict.items():
|
||||||
|
attr_name = 'f_%s' % k
|
||||||
|
if hasattr(model_object.__class__, attr_name):
|
||||||
|
setattr(model_object, attr_name, v)
|
||||||
|
return model_object
|
||||||
|
|
||||||
|
|
||||||
|
# https://docs.peewee-orm.com/en/latest/peewee/query_operators.html
|
||||||
|
supported_operators = {
|
||||||
|
'==': operator.eq,
|
||||||
|
'<': operator.lt,
|
||||||
|
'<=': operator.le,
|
||||||
|
'>': operator.gt,
|
||||||
|
'>=': operator.ge,
|
||||||
|
'!=': operator.ne,
|
||||||
|
'<<': operator.lshift,
|
||||||
|
'>>': operator.rshift,
|
||||||
|
'%': operator.mod,
|
||||||
|
'**': operator.pow,
|
||||||
|
'^': operator.xor,
|
||||||
|
'~': operator.inv,
|
||||||
|
}
|
||||||
|
|
||||||
|
def query_dict2expression(model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]):
|
||||||
|
expression = []
|
||||||
|
|
||||||
|
for field, value in query.items():
|
||||||
|
if not isinstance(value, (list, tuple)):
|
||||||
|
value = ('==', value)
|
||||||
|
op, *val = value
|
||||||
|
|
||||||
|
field = getattr(model, f'f_{field}')
|
||||||
|
value = supported_operators[op](field, val[0]) if op in supported_operators else getattr(field, op)(*val)
|
||||||
|
expression.append(value)
|
||||||
|
|
||||||
|
return reduce(operator.iand, expression)
|
||||||
|
|
||||||
|
|
||||||
|
def query_db(model: Type[DataBaseModel], limit: int = 0, offset: int = 0,
|
||||||
|
query: dict = None, order_by: Union[str, list, tuple] = None):
|
||||||
|
data = model.select()
|
||||||
|
if query:
|
||||||
|
data = data.where(query_dict2expression(model, query))
|
||||||
|
count = data.count()
|
||||||
|
|
||||||
|
if not order_by:
|
||||||
|
order_by = 'create_time'
|
||||||
|
if not isinstance(order_by, (list, tuple)):
|
||||||
|
order_by = (order_by, 'asc')
|
||||||
|
order_by, order = order_by
|
||||||
|
order_by = getattr(model, f'f_{order_by}')
|
||||||
|
order_by = getattr(order_by, order)()
|
||||||
|
data = data.order_by(order_by)
|
||||||
|
|
||||||
|
if limit > 0:
|
||||||
|
data = data.limit(limit)
|
||||||
|
if offset > 0:
|
||||||
|
data = data.offset(offset)
|
||||||
|
|
||||||
|
return list(data), count
|
||||||
|
|
||||||
|
|
||||||
|
class StatusEnum(Enum):
|
||||||
|
# 样本可用状态
|
||||||
|
VALID = "1"
|
||||||
|
IN_VALID = "0"
|
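The query helpers above keep the FATE convention of f_-prefixed column names (query_dict2expression and query_db both prepend f_ to every field), which the new models in db_models.py do not use. The sketch below therefore uses a hypothetical f_-prefixed model purely for illustration.

import peewee
from web_server.db.db_models import DataBaseModel
from web_server.db.db_utils import query_db

class DemoTask(DataBaseModel):           # hypothetical model, not part of this commit
    f_task_id = peewee.CharField(max_length=32)
    f_status = peewee.CharField(max_length=16)
    f_create_time = peewee.BigIntegerField(null=True)

    class Meta:
        db_table = "t_demo_task"

# {"status": "running"} becomes DemoTask.f_status == "running";
# {"create_time": (">", 0)} becomes DemoTask.f_create_time > 0.
rows, total = query_db(DemoTask, limit=10, offset=0,
                       query={"status": "running", "create_time": (">", 0)},
                       order_by=("create_time", "desc"))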
web_server/db/init_data.py (new file)
@@ -0,0 +1,141 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from web_server.db import LLMType
|
||||||
|
from web_server.db.db_models import init_database_tables as init_web_db
|
||||||
|
from web_server.db.services import UserService
|
||||||
|
from web_server.db.services.llm_service import LLMFactoriesService, LLMService
|
||||||
|
|
||||||
|
|
||||||
|
def init_superuser():
|
||||||
|
user_info = {
|
||||||
|
"id": uuid.uuid1().hex,
|
||||||
|
"password": "admin",
|
||||||
|
"nickname": "admin",
|
||||||
|
"is_superuser": True,
|
||||||
|
"email": "kai.hu@infiniflow.org",
|
||||||
|
"creator": "system",
|
||||||
|
"status": "1",
|
||||||
|
}
|
||||||
|
UserService.save(**user_info)
|
||||||
|
|
||||||
|
|
||||||
|
def init_llm_factory():
|
||||||
|
factory_infos = [{
|
||||||
|
"name": "OpenAI",
|
||||||
|
"logo": "",
|
||||||
|
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||||
|
"status": "1",
|
||||||
|
},{
|
||||||
|
"name": "通义千问",
|
||||||
|
"logo": "",
|
||||||
|
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||||
|
"status": "1",
|
||||||
|
},{
|
||||||
|
"name": "智普AI",
|
||||||
|
"logo": "",
|
||||||
|
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||||
|
"status": "1",
|
||||||
|
},{
|
||||||
|
"name": "文心一言",
|
||||||
|
"logo": "",
|
||||||
|
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||||
|
"status": "1",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
llm_infos = [{
|
||||||
|
"fid": factory_infos[0]["name"],
|
||||||
|
"llm_name": "gpt-3.5-turbo",
|
||||||
|
"tags": "LLM,CHAT,4K",
|
||||||
|
"model_type": LLMType.CHAT.value
|
||||||
|
},{
|
||||||
|
"fid": factory_infos[0]["name"],
|
||||||
|
"llm_name": "gpt-3.5-turbo-16k-0613",
|
||||||
|
"tags": "LLM,CHAT,16k",
|
||||||
|
"model_type": LLMType.CHAT.value
|
||||||
|
},{
|
||||||
|
"fid": factory_infos[0]["name"],
|
||||||
|
"llm_name": "text-embedding-ada-002",
|
||||||
|
"tags": "TEXT EMBEDDING,8K",
|
||||||
|
"model_type": LLMType.EMBEDDING.value
|
||||||
|
},{
|
||||||
|
"fid": factory_infos[0]["name"],
|
||||||
|
"llm_name": "whisper-1",
|
||||||
|
"tags": "SPEECH2TEXT",
|
||||||
|
"model_type": LLMType.SPEECH2TEXT.value
|
||||||
|
},{
|
||||||
|
"fid": factory_infos[0]["name"],
|
||||||
|
"llm_name": "gpt-4",
|
||||||
|
"tags": "LLM,CHAT,8K",
|
||||||
|
"model_type": LLMType.CHAT.value
|
||||||
|
},{
|
||||||
|
"fid": factory_infos[0]["name"],
|
||||||
|
"llm_name": "gpt-4-32k",
|
||||||
|
"tags": "LLM,CHAT,32K",
|
||||||
|
"model_type": LLMType.CHAT.value
|
||||||
|
},{
|
||||||
|
"fid": factory_infos[0]["name"],
|
||||||
|
"llm_name": "gpt-4-vision-preview",
|
||||||
|
"tags": "LLM,CHAT,IMAGE2TEXT",
|
||||||
|
"model_type": LLMType.IMAGE2TEXT.value
|
||||||
|
},{
|
||||||
|
"fid": factory_infos[1]["name"],
|
||||||
|
"llm_name": "qwen-turbo",
|
||||||
|
"tags": "LLM,CHAT,8K",
|
||||||
|
"model_type": LLMType.CHAT.value
|
||||||
|
},{
|
||||||
|
"fid": factory_infos[1]["name"],
|
||||||
|
"llm_name": "qwen-plus",
|
||||||
|
"tags": "LLM,CHAT,32K",
|
||||||
|
"model_type": LLMType.CHAT.value
|
||||||
|
},{
|
||||||
|
"fid": factory_infos[1]["name"],
|
||||||
|
"llm_name": "text-embedding-v2",
|
||||||
|
"tags": "TEXT EMBEDDING,2K",
|
||||||
|
"model_type": LLMType.EMBEDDING.value
|
||||||
|
},{
|
||||||
|
"fid": factory_infos[1]["name"],
|
||||||
|
"llm_name": "paraformer-realtime-8k-v1",
|
||||||
|
"tags": "SPEECH2TEXT",
|
||||||
|
"model_type": LLMType.SPEECH2TEXT.value
|
||||||
|
},{
|
||||||
|
"fid": factory_infos[1]["name"],
|
||||||
|
"llm_name": "qwen_vl_chat_v1",
|
||||||
|
"tags": "LLM,CHAT,IMAGE2TEXT",
|
||||||
|
"model_type": LLMType.IMAGE2TEXT.value
|
||||||
|
},
|
||||||
|
]
|
||||||
|
for info in factory_infos:
|
||||||
|
LLMFactoriesService.save(**info)
|
||||||
|
for info in llm_infos:
|
||||||
|
LLMService.save(**info)
|
||||||
|
|
||||||
|
|
||||||
|
def init_web_data():
|
||||||
|
start_time = time.time()
|
||||||
|
if not UserService.get_all().count():
|
||||||
|
init_superuser()
|
||||||
|
|
||||||
|
if not LLMService.get_all().count():init_llm_factory()
|
||||||
|
|
||||||
|
print("init web data success:{}".format(time.time() - start_time))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
init_web_db()
|
||||||
|
init_web_data()
|
web_server/db/operatioins.py (new file)
@@ -0,0 +1,21 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import operator
import time
import typing
from web_server.utils.log_utils import sql_logger
import peewee
web_server/db/reload_config_base.py (new file)
@@ -0,0 +1,27 @@
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class ReloadConfigBase:
    @classmethod
    def get_all(cls):
        configs = {}
        for k, v in cls.__dict__.items():
            if not callable(getattr(cls, k)) and not k.startswith("__") and not k.startswith("_"):
                configs[k] = v
        return configs

    @classmethod
    def get(cls, config_name):
        return getattr(cls, config_name) if hasattr(cls, config_name) else None
54  web_server/db/runtime_config.py  Normal file
@ -0,0 +1,54 @@
#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
from web_server.versions import get_versions
from .reload_config_base import ReloadConfigBase


class RuntimeConfig(ReloadConfigBase):
    DEBUG = None
    WORK_MODE = None
    HTTP_PORT = None
    JOB_SERVER_HOST = None
    JOB_SERVER_VIP = None
    ENV = dict()
    SERVICE_DB = None
    LOAD_CONFIG_MANAGER = False

    @classmethod
    def init_config(cls, **kwargs):
        for k, v in kwargs.items():
            if hasattr(cls, k):
                setattr(cls, k, v)

    @classmethod
    def init_env(cls):
        cls.ENV.update(get_versions())

    @classmethod
    def load_config_manager(cls):
        cls.LOAD_CONFIG_MANAGER = True

    @classmethod
    def get_env(cls, key):
        return cls.ENV.get(key, None)

    @classmethod
    def get_all_env(cls):
        return cls.ENV

    @classmethod
    def set_service_db(cls, service_db):
        cls.SERVICE_DB = service_db
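
Note (not part of the diff): RuntimeConfig holds process-wide state as class attributes, and init_config only overwrites names that already exist on the class. A hedged sketch with illustrative values:

    RuntimeConfig.init_config(WORK_MODE=0, HTTP_PORT=9380, DEBUG=True)
    RuntimeConfig.init_env()                # merges get_versions() into ENV
    RuntimeConfig.HTTP_PORT                 # 9380
    RuntimeConfig.init_config(UNKNOWN=1)    # silently ignored: no such attribute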
164  web_server/db/service_registry.py  Normal file
@ -0,0 +1,164 @@
#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import socket
from pathlib import Path
from web_server import utils
from .db_models import DB, ServiceRegistryInfo, ServerRegistryInfo
from .reload_config_base import ReloadConfigBase


class ServiceRegistry(ReloadConfigBase):
    @classmethod
    @DB.connection_context()
    def load_service(cls, **kwargs) -> [ServiceRegistryInfo]:
        service_registry_list = ServiceRegistryInfo.query(**kwargs)
        return [service for service in service_registry_list]

    @classmethod
    @DB.connection_context()
    def save_service_info(cls, server_name, service_name, uri, method="POST", server_info=None, params=None, data=None, headers=None, protocol="http"):
        if not server_info:
            server_list = ServerRegistry.query_server_info_from_db(server_name=server_name)
            if not server_list:
                raise Exception(f"no found server {server_name}")
            server_info = server_list[0]
            url = f"{server_info.f_protocol}://{server_info.f_host}:{server_info.f_port}{uri}"
        else:
            url = f"{server_info.get('protocol', protocol)}://{server_info.get('host')}:{server_info.get('port')}{uri}"
        service_info = {
            "f_server_name": server_name,
            "f_service_name": service_name,
            "f_url": url,
            "f_method": method,
            "f_params": params if params else {},
            "f_data": data if data else {},
            "f_headers": headers if headers else {}
        }
        entity_model, status = ServiceRegistryInfo.get_or_create(
            f_server_name=server_name,
            f_service_name=service_name,
            defaults=service_info)
        if status is False:
            for key in service_info:
                setattr(entity_model, key, service_info[key])
            entity_model.save(force_insert=False)


class ServerRegistry(ReloadConfigBase):
    FATEBOARD = None
    FATE_ON_STANDALONE = None
    FATE_ON_EGGROLL = None
    FATE_ON_SPARK = None
    MODEL_STORE_ADDRESS = None
    SERVINGS = None
    FATEMANAGER = None
    STUDIO = None

    @classmethod
    def load(cls):
        cls.load_server_info_from_conf()
        cls.load_server_info_from_db()

    @classmethod
    def load_server_info_from_conf(cls):
        path = Path(utils.file_utils.get_project_base_directory()) / 'conf' / utils.SERVICE_CONF
        conf = utils.file_utils.load_yaml_conf(path)
        if not isinstance(conf, dict):
            raise ValueError('invalid config file')

        local_path = path.with_name(f'local.{utils.SERVICE_CONF}')
        if local_path.exists():
            local_conf = utils.file_utils.load_yaml_conf(local_path)
            if not isinstance(local_conf, dict):
                raise ValueError('invalid local config file')
            conf.update(local_conf)
        for k, v in conf.items():
            if isinstance(v, dict):
                setattr(cls, k.upper(), v)

    @classmethod
    def register(cls, server_name, server_info):
        cls.save_server_info_to_db(server_name, server_info.get("host"), server_info.get("port"), protocol=server_info.get("protocol", "http"))
        setattr(cls, server_name, server_info)

    @classmethod
    def save(cls, service_config):
        update_server = {}
        for server_name, server_info in service_config.items():
            cls.parameter_check(server_info)
            api_info = server_info.pop("api", {})
            for service_name, info in api_info.items():
                ServiceRegistry.save_service_info(server_name, service_name, uri=info.get('uri'), method=info.get('method', 'POST'), server_info=server_info)
            cls.save_server_info_to_db(server_name, server_info.get("host"), server_info.get("port"), protocol="http")
            setattr(cls, server_name.upper(), server_info)
        return update_server

    @classmethod
    def parameter_check(cls, service_info):
        if "host" in service_info and "port" in service_info:
            cls.connection_test(service_info.get("host"), service_info.get("port"))

    @classmethod
    def connection_test(cls, ip, port):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        result = s.connect_ex((ip, port))
        if result != 0:
            raise ConnectionRefusedError(f"connection refused: host {ip}, port {port}")

    @classmethod
    def query(cls, service_name, default=None):
        service_info = getattr(cls, service_name, default)
        if not service_info:
            service_info = utils.get_base_config(service_name, default)
        return service_info

    @classmethod
    @DB.connection_context()
    def query_server_info_from_db(cls, server_name=None) -> [ServerRegistryInfo]:
        if server_name:
            server_list = ServerRegistryInfo.select().where(ServerRegistryInfo.f_server_name == server_name.upper())
        else:
            server_list = ServerRegistryInfo.select()
        return [server for server in server_list]

    @classmethod
    @DB.connection_context()
    def load_server_info_from_db(cls):
        for server in cls.query_server_info_from_db():
            server_info = {
                "host": server.f_host,
                "port": server.f_port,
                "protocol": server.f_protocol
            }
            setattr(cls, server.f_server_name.upper(), server_info)


    @classmethod
    @DB.connection_context()
    def save_server_info_to_db(cls, server_name, host, port, protocol="http"):
        server_info = {
            "f_server_name": server_name,
            "f_host": host,
            "f_port": port,
            "f_protocol": protocol
        }
        entity_model, status = ServerRegistryInfo.get_or_create(
            f_server_name=server_name,
            defaults=server_info)
        if status is False:
            for key in server_info:
                setattr(entity_model, key, server_info[key])
            entity_model.save(force_insert=False)
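
Note (not part of the diff): the intended flow is to load server entries from the conf file and DB, then register services against them. A sketch with hypothetical server/service names, ports, and URIs:

    ServerRegistry.load()                      # conf file + DB rows -> class attributes
    ServerRegistry.register("MINIO", {"host": "127.0.0.1", "port": 9000, "protocol": "http"})
    info = ServerRegistry.query("MINIO")       # falls back to utils.get_base_config(...) if unset
    ServiceRegistry.save_service_info("MINIO", "upload", uri="/v1/upload", server_info=info)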
38  web_server/db/services/__init__.py  Normal file
@ -0,0 +1,38 @@
#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import pathlib
import re
from .user_service import UserService


def duplicate_name(query_func, **kwargs):
    fnm = kwargs["name"]
    objs = query_func(**kwargs)
    if not objs: return fnm
    ext = pathlib.Path(fnm).suffix  # e.g. ".jpg"
    nm = re.sub(r"%s$" % ext, "", fnm)
    r = re.search(r"\([0-9]+\)$", nm)
    c = 0
    if r:
        c = int(r.group(1))
        nm = re.sub(r"\([0-9]+\)$", "", nm)
    c += 1
    nm = f"{nm}({c})"
    if ext: nm += f"{ext}"

    kwargs["name"] = nm
    return duplicate_name(query_func, **kwargs)
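
Note (not part of the diff): duplicate_name keeps appending or incrementing a "(n)" suffix until query_func reports no collision. A self-contained sketch with a stand-in query function (the names below are illustrative):

    existing = {"report.pdf", "report(1).pdf"}           # pretend these rows exist
    fake_query = lambda **kw: kw["name"] in existing     # stands in for a DB lookup

    duplicate_name(fake_query, name="report.pdf")        # -> "report(2).pdf"
    duplicate_name(fake_query, name="notes.txt")         # -> "notes.txt" (no clash)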
153  web_server/db/services/common_service.py  Normal file
@ -0,0 +1,153 @@
#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
from datetime import datetime

import peewee

from web_server.db.db_models import DB
from web_server.utils import datetime_format


class CommonService:
    model = None

    @classmethod
    @DB.connection_context()
    def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
        return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)

    @classmethod
    @DB.connection_context()
    def get_all(cls, cols=None, reverse=None, order_by=None):
        if cols:
            query_records = cls.model.select(*cols)
        else:
            query_records = cls.model.select()
        if reverse is not None:
            if not order_by or not hasattr(cls, order_by):
                order_by = "create_time"
            if reverse is True:
                query_records = query_records.order_by(cls.model.getter_by(order_by).desc())
            elif reverse is False:
                query_records = query_records.order_by(cls.model.getter_by(order_by).asc())
        return query_records

    @classmethod
    @DB.connection_context()
    def get(cls, **kwargs):
        return cls.model.get(**kwargs)

    @classmethod
    @DB.connection_context()
    def get_or_none(cls, **kwargs):
        try:
            return cls.model.get(**kwargs)
        except peewee.DoesNotExist:
            return None

    @classmethod
    @DB.connection_context()
    def save(cls, **kwargs):
        #if "id" not in kwargs:
        #    kwargs["id"] = get_uuid()
        sample_obj = cls.model(**kwargs).save(force_insert=True)
        return sample_obj

    @classmethod
    @DB.connection_context()
    def insert_many(cls, data_list, batch_size=100):
        with DB.atomic():
            for i in range(0, len(data_list), batch_size):
                cls.model.insert_many(data_list[i:i + batch_size]).execute()

    @classmethod
    @DB.connection_context()
    def update_many_by_id(cls, data_list):
        cur = datetime_format(datetime.now())
        with DB.atomic():
            for data in data_list:
                data["update_time"] = cur
                cls.model.update(data).where(cls.model.id == data["id"]).execute()

    @classmethod
    @DB.connection_context()
    def update_by_id(cls, pid, data):
        data["update_time"] = datetime_format(datetime.now())
        num = cls.model.update(data).where(cls.model.id == pid).execute()
        return num

    @classmethod
    @DB.connection_context()
    def get_by_id(cls, pid):
        try:
            obj = cls.model.query(id=pid)[0]
            return True, obj
        except Exception as e:
            return False, None

    @classmethod
    @DB.connection_context()
    def get_by_ids(cls, pids, cols=None):
        if cols:
            objs = cls.model.select(*cols)
        else:
            objs = cls.model.select()
        return objs.where(cls.model.id.in_(pids))

    @classmethod
    @DB.connection_context()
    def delete_by_id(cls, pid):
        return cls.model.delete().where(cls.model.id == pid).execute()


    @classmethod
    @DB.connection_context()
    def filter_delete(cls, filters):
        with DB.atomic():
            num = cls.model.delete().where(*filters).execute()
            return num

    @classmethod
    @DB.connection_context()
    def filter_update(cls, filters, update_data):
        with DB.atomic():
            cls.model.update(update_data).where(*filters).execute()

    @staticmethod
    def cut_list(tar_list, n):
        length = len(tar_list)
        arr = range(length)
        result = [tuple(tar_list[x:(x + n)]) for x in arr[::n]]
        return result

    @classmethod
    @DB.connection_context()
    def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None):
        in_filters_tuple_list = cls.cut_list(in_filters_list, 20)
        if not filters:
            filters = []
        res_list = []
        if cols:
            for i in in_filters_tuple_list:
                query_records = cls.model.select(*cols).where(getattr(cls.model, in_key).in_(i), *filters)
                if query_records:
                    res_list.extend([query_record for query_record in query_records])
        else:
            for i in in_filters_tuple_list:
                query_records = cls.model.select().where(getattr(cls.model, in_key).in_(i), *filters)
                if query_records:
                    res_list.extend([query_record for query_record in query_records])
        return res_list
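
Note (not part of the diff): concrete services bind a peewee model to CommonService.model and inherit the CRUD helpers, all of which run inside DB.connection_context(). A sketch, assuming a hypothetical User model defined in web_server.db.db_models:

    from web_server.db.db_models import User        # hypothetical model class

    class UserService(CommonService):
        model = User

    ok, user = UserService.get_by_id("some-id")      # (True, obj) or (False, None)
    UserService.update_by_id("some-id", {"nickname": "alice"})
    UserService.filter_delete([User.status == "0"])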
Some files were not shown because too many files have changed in this diff.