mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-04-22 06:00:00 +08:00

### What problem does this PR solve?

Add license statement.

### Type of change

- [x] Refactoring

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
272 lines
7.5 KiB
Python
272 lines
7.5 KiB
Python
#
|
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
import numpy as np
|
|
|
|
# Default ``topn`` for dense-vector matches; used as the default of
# ``MatchDenseExpr.__init__``'s ``topn`` parameter.
DEFAULT_MATCH_VECTOR_TOPN = 10
|
|
# Default ``topn`` for sparse-vector matches, going by the name.
# NOTE(review): not referenced by the classes visible in this file — confirm usage.
DEFAULT_MATCH_SPARSE_TOPN = 10
|
|
# Type alias for a dense vector payload: either a plain list or a NumPy array.
VEC = list | np.ndarray
|
|
|
|
|
|
@dataclass
|
|
class SparseVector:
|
|
indices: list[int]
|
|
values: list[float] | list[int] | None = None
|
|
|
|
def __post_init__(self):
|
|
assert (self.values is None) or (len(self.indices) == len(self.values))
|
|
|
|
def to_dict_old(self):
|
|
d = {"indices": self.indices}
|
|
if self.values is not None:
|
|
d["values"] = self.values
|
|
return d
|
|
|
|
def to_dict(self):
|
|
if self.values is None:
|
|
raise ValueError("SparseVector.values is None")
|
|
result = {}
|
|
for i, v in zip(self.indices, self.values):
|
|
result[str(i)] = v
|
|
return result
|
|
|
|
@staticmethod
|
|
def from_dict(d):
|
|
return SparseVector(d["indices"], d.get("values"))
|
|
|
|
def __str__(self):
|
|
return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})"
|
|
|
|
def __repr__(self):
|
|
return str(self)
|
|
|
|
|
|
class MatchTextExpr(ABC):
|
|
def __init__(
|
|
self,
|
|
fields: list[str],
|
|
matching_text: str,
|
|
topn: int,
|
|
extra_options: dict = dict(),
|
|
):
|
|
self.fields = fields
|
|
self.matching_text = matching_text
|
|
self.topn = topn
|
|
self.extra_options = extra_options
|
|
|
|
|
|
class MatchDenseExpr(ABC):
    """Dense-vector match expression.

    The constructor only stores its arguments as same-named attributes.

    Args:
        vector_column_name: name of the vector column to match against.
        embedding_data: the query vector (list or NumPy array).
        embedding_data_type: element type of ``embedding_data``, as a string.
        distance_type: distance/similarity metric name, as a string.
        topn: number of top results requested; defaults to
            ``DEFAULT_MATCH_VECTOR_TOPN``.
        extra_options: optional dict of additional options. When omitted (or
            ``None``), each instance gets its own empty dict.
    """

    def __init__(
        self,
        vector_column_name: str,
        embedding_data: VEC,
        embedding_data_type: str,
        distance_type: str,
        topn: int = DEFAULT_MATCH_VECTOR_TOPN,
        extra_options: dict | None = None,
    ):
        self.vector_column_name = vector_column_name
        self.embedding_data = embedding_data
        self.embedding_data_type = embedding_data_type
        self.distance_type = distance_type
        self.topn = topn
        # A default of ``dict()`` in the signature would be evaluated once and
        # shared by every instance; create a fresh dict per instance instead.
        self.extra_options = {} if extra_options is None else extra_options
|
|
|
|
|
|
class MatchSparseExpr(ABC):
    """Sparse-vector match expression.

    Plain value holder: every constructor argument is stored unchanged on the
    instance under the same name. ``opt_params`` stays ``None`` when omitted.
    """

    def __init__(
        self,
        vector_column_name: str,
        sparse_data: SparseVector | dict,
        distance_type: str,
        topn: int,
        opt_params: dict | None = None,
    ):
        # What to match: the target column and the query payload.
        self.vector_column_name, self.sparse_data = vector_column_name, sparse_data
        # How to match: metric, result count and optional tuning parameters.
        self.distance_type, self.topn = distance_type, topn
        self.opt_params = opt_params
|
|
|
|
|
|
class MatchTensorExpr(ABC):
    """Tensor match expression.

    Plain value holder: every constructor argument is stored unchanged on the
    instance under the same name. ``extra_option`` stays ``None`` when omitted.
    """

    def __init__(
        self,
        column_name: str,
        query_data: VEC,
        query_data_type: str,
        topn: int,
        extra_option: dict | None = None,
    ):
        # Target column and the query payload with its declared element type.
        self.column_name, self.query_data = column_name, query_data
        self.query_data_type, self.topn = query_data_type, topn
        self.extra_option = extra_option
|
|
|
|
|
|
class FusionExpr(ABC):
|
|
def __init__(self, method: str, topn: int, fusion_params: dict | None = None):
|
|
self.method = method
|
|
self.topn = topn
|
|
self.fusion_params = fusion_params
|
|
|
|
|
|
# Union of all match/fusion expression classes; used as the element type of
# the ``matchExprs`` parameter of ``DocStoreConnection.search``.
MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr
|
|
|
|
class OrderByExpr(ABC):
    """Ordered list of sort keys, built fluently with ``asc`` / ``desc``.

    ``fields`` is a plain list attribute holding ``(field_name, direction)``
    tuples in the order they were added; direction is ``0`` for entries added
    via ``asc`` and ``1`` for entries added via ``desc``.

    Read the result as ``expr.fields`` (attribute access, no call).
    """

    def __init__(self):
        # NOTE: this class used to also define a ``fields()`` method. The
        # instance attribute assigned here always shadowed it, so calling
        # ``expr.fields()`` raised ``TypeError: 'list' object is not callable``.
        # The unreachable method was removed; the attribute is the API.
        self.fields: list[tuple[str, int]] = []

    def asc(self, field: str) -> "OrderByExpr":
        """Append ``field`` as an ascending sort key and return ``self``."""
        self.fields.append((field, 0))
        return self

    def desc(self, field: str) -> "OrderByExpr":
        """Append ``field`` as a descending sort key and return ``self``."""
        self.fields.append((field, 1))
        return self
|
|
|
|
class DocStoreConnection(ABC):
    """Abstract interface for a document-store backend.

    The methods are grouped into database, table (index), CRUD,
    search-result helper and SQL operations. Every method is abstract; the
    bodies here only raise ``NotImplementedError``, so concrete subclasses
    must override all of them.
    """

    # ------------------------------------------------------------------
    # Database operations
    # ------------------------------------------------------------------

    @abstractmethod
    def dbType(self) -> str:
        """
        Return the type of the database.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def health(self) -> dict:
        """
        Return the health status of the database.
        """
        raise NotImplementedError("Not implemented")

    # ------------------------------------------------------------------
    # Table operations
    # ------------------------------------------------------------------

    @abstractmethod
    def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
        """
        Create an index with given name
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def deleteIdx(self, indexName: str, knowledgebaseId: str):
        """
        Delete an index with given name
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
        """
        Check if an index with given name exists
        """
        raise NotImplementedError("Not implemented")

    # ------------------------------------------------------------------
    # CRUD operations
    # ------------------------------------------------------------------

    @abstractmethod
    def search(
        self,
        selectFields: list[str],
        highlightFields: list[str],
        condition: dict,
        matchExprs: list[MatchExpr],
        orderBy: OrderByExpr,
        offset: int,
        limit: int,
        indexNames: str | list[str],
        knowledgebaseIds: list[str],
        # Default kept as ``[]`` so the abstract signature is unchanged; it is
        # never mutated here because this method only raises.
        aggFields: list[str] = [],  # noqa: B006
        rank_feature: dict | None = None,
    ):
        """
        Search with given conjunctive equivalent filtering condition and return all fields of matched documents
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
        """
        Get single chunk with given id
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str | None = None) -> list[str]:
        """
        Update or insert a bulk of rows
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
        """
        Update rows with given conjunctive equivalent filtering condition
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
        """
        Delete rows with given conjunctive equivalent filtering condition
        """
        raise NotImplementedError("Not implemented")

    # ------------------------------------------------------------------
    # Helper functions for search result
    #
    # ``res`` is presumably the backend-specific object returned by
    # ``search``; its concrete type is defined by each implementation.
    # ------------------------------------------------------------------

    @abstractmethod
    def getTotal(self, res):
        """Return the total for search result ``res`` (backend-defined)."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getChunkIds(self, res):
        """Return the chunk ids contained in search result ``res``."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getFields(self, res, fields: list[str]) -> dict[str, dict]:
        """Return the requested ``fields`` from ``res`` as a dict of dicts."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getHighlight(self, res, keywords: list[str], fieldnm: str):
        """Return highlight data for field ``fieldnm`` of ``res`` (backend-defined)."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getAggregation(self, res, fieldnm: str):
        """Return aggregation data for field ``fieldnm`` of ``res`` (backend-defined)."""
        raise NotImplementedError("Not implemented")

    # ------------------------------------------------------------------
    # SQL
    # ------------------------------------------------------------------

    @abstractmethod
    def sql(self, sql: str, fetch_size: int, format: str):
        """
        Run the sql generated by text-to-sql
        """
        # ``self`` was missing from the original declaration, so the abstract
        # signature did not match how an instance method is actually bound.
        raise NotImplementedError("Not implemented")
|