mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-14 02:45:54 +08:00
Support Ollama (#261)
### What problem does this PR solve? Issue link:#221 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
265a7a283a
commit
3708b97db9
13
README.md
13
README.md
@ -1,6 +1,6 @@
|
|||||||
<div align="center">
|
<div align="center">
|
||||||
<a href="https://demo.ragflow.io/">
|
<a href="https://demo.ragflow.io/">
|
||||||
<img src="web/src/assets/logo-with-text.png" width="350" alt="ragflow logo">
|
<img src="web/src/assets/logo-with-text.png" width="520" alt="ragflow logo">
|
||||||
</a>
|
</a>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@ -124,12 +124,12 @@
|
|||||||
|
|
||||||
* Running on all addresses (0.0.0.0)
|
* Running on all addresses (0.0.0.0)
|
||||||
* Running on http://127.0.0.1:9380
|
* Running on http://127.0.0.1:9380
|
||||||
* Running on http://172.22.0.5:9380
|
* Running on http://x.x.x.x:9380
|
||||||
INFO:werkzeug:Press CTRL+C to quit
|
INFO:werkzeug:Press CTRL+C to quit
|
||||||
```
|
```
|
||||||
|
|
||||||
5. In your web browser, enter the IP address of your server as prompted and log in to RAGFlow.
|
5. In your web browser, enter the IP address of your server and log in to RAGFlow.
|
||||||
> In the given scenario, you only need to enter `http://IP_of_RAGFlow ` (sans port number) as the default HTTP serving port `80` can be omitted when using the default configurations.
|
> In the given scenario, you only need to enter `http://IP_OF_YOUR_MACHINE` (sans port number) as the default HTTP serving port `80` can be omitted when using the default configurations.
|
||||||
6. In [service_conf.yaml](./docker/service_conf.yaml), select the desired LLM factory in `user_default_llm` and update the `API_KEY` field with the corresponding API key.
|
6. In [service_conf.yaml](./docker/service_conf.yaml), select the desired LLM factory in `user_default_llm` and update the `API_KEY` field with the corresponding API key.
|
||||||
|
|
||||||
> See [./docs/llm_api_key_setup.md](./docs/llm_api_key_setup.md) for more information.
|
> See [./docs/llm_api_key_setup.md](./docs/llm_api_key_setup.md) for more information.
|
||||||
@ -168,6 +168,11 @@ $ cd ragflow/docker
|
|||||||
$ docker compose up -d
|
$ docker compose up -d
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 🆕 Latest Features
|
||||||
|
|
||||||
|
- Support [Ollam](./docs/ollama.md) for local LLM deployment.
|
||||||
|
- Support Chinese UI.
|
||||||
|
|
||||||
## 📜 Roadmap
|
## 📜 Roadmap
|
||||||
|
|
||||||
See the [RAGFlow Roadmap 2024](https://github.com/infiniflow/ragflow/issues/162)
|
See the [RAGFlow Roadmap 2024](https://github.com/infiniflow/ragflow/issues/162)
|
||||||
|
@ -124,12 +124,12 @@
|
|||||||
|
|
||||||
* Running on all addresses (0.0.0.0)
|
* Running on all addresses (0.0.0.0)
|
||||||
* Running on http://127.0.0.1:9380
|
* Running on http://127.0.0.1:9380
|
||||||
* Running on http://172.22.0.5:9380
|
* Running on http://x.x.x.x:9380
|
||||||
INFO:werkzeug:Press CTRL+C to quit
|
INFO:werkzeug:Press CTRL+C to quit
|
||||||
```
|
```
|
||||||
|
|
||||||
5. ウェブブラウザで、プロンプトに従ってサーバーの IP アドレスを入力し、RAGFlow にログインします。
|
5. ウェブブラウザで、プロンプトに従ってサーバーの IP アドレスを入力し、RAGFlow にログインします。
|
||||||
> デフォルトの設定を使用する場合、デフォルトの HTTP サービングポート `80` は省略できるので、与えられたシナリオでは、`http://172.22.0.5`(ポート番号は省略)だけを入力すればよい。
|
> デフォルトの設定を使用する場合、デフォルトの HTTP サービングポート `80` は省略できるので、与えられたシナリオでは、`http://IP_OF_YOUR_MACHINE`(ポート番号は省略)だけを入力すればよい。
|
||||||
6. [service_conf.yaml](./docker/service_conf.yaml) で、`user_default_llm` で希望の LLM ファクトリを選択し、`API_KEY` フィールドを対応する API キーで更新する。
|
6. [service_conf.yaml](./docker/service_conf.yaml) で、`user_default_llm` で希望の LLM ファクトリを選択し、`API_KEY` フィールドを対応する API キーで更新する。
|
||||||
|
|
||||||
> 詳しくは [./docs/llm_api_key_setup.md](./docs/llm_api_key_setup.md) を参照してください。
|
> 詳しくは [./docs/llm_api_key_setup.md](./docs/llm_api_key_setup.md) を参照してください。
|
||||||
@ -168,6 +168,11 @@ $ cd ragflow/docker
|
|||||||
$ docker compose up -d
|
$ docker compose up -d
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 🆕 最新の新機能
|
||||||
|
|
||||||
|
- [Ollam](./docs/ollama.md) を使用した大規模モデルのローカライズされたデプロイメントをサポートします。
|
||||||
|
- 中国語インターフェースをサポートします。
|
||||||
|
|
||||||
## 📜 ロードマップ
|
## 📜 ロードマップ
|
||||||
|
|
||||||
[RAGFlow ロードマップ 2024](https://github.com/infiniflow/ragflow/issues/162) を参照
|
[RAGFlow ロードマップ 2024](https://github.com/infiniflow/ragflow/issues/162) を参照
|
||||||
|
15
README_zh.md
15
README_zh.md
@ -124,12 +124,12 @@
|
|||||||
|
|
||||||
* Running on all addresses (0.0.0.0)
|
* Running on all addresses (0.0.0.0)
|
||||||
* Running on http://127.0.0.1:9380
|
* Running on http://127.0.0.1:9380
|
||||||
* Running on http://172.22.0.5:9380
|
* Running on http://x.x.x.x:9380
|
||||||
INFO:werkzeug:Press CTRL+C to quit
|
INFO:werkzeug:Press CTRL+C to quit
|
||||||
```
|
```
|
||||||
|
|
||||||
5. 根据刚才的界面提示在你的浏览器中输入你的服务器对应的 IP 地址并登录 RAGFlow。
|
5. 在你的浏览器中输入你的服务器对应的 IP 地址并登录 RAGFlow。
|
||||||
> 上面这个例子中,您只需输入 http://172.22.0.5 即可:未改动过配置则无需输入端口(默认的 HTTP 服务端口 80)。
|
> 上面这个例子中,您只需输入 http://IP_OF_YOUR_MACHINE 即可:未改动过配置则无需输入端口(默认的 HTTP 服务端口 80)。
|
||||||
6. 在 [service_conf.yaml](./docker/service_conf.yaml) 文件的 `user_default_llm` 栏配置 LLM factory,并在 `API_KEY` 栏填写和你选择的大模型相对应的 API key。
|
6. 在 [service_conf.yaml](./docker/service_conf.yaml) 文件的 `user_default_llm` 栏配置 LLM factory,并在 `API_KEY` 栏填写和你选择的大模型相对应的 API key。
|
||||||
|
|
||||||
> 详见 [./docs/llm_api_key_setup.md](./docs/llm_api_key_setup.md)。
|
> 详见 [./docs/llm_api_key_setup.md](./docs/llm_api_key_setup.md)。
|
||||||
@ -168,9 +168,14 @@ $ cd ragflow/docker
|
|||||||
$ docker compose up -d
|
$ docker compose up -d
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 🆕 最近新特性
|
||||||
|
|
||||||
|
- 支持用 [Ollam](./docs/ollama.md) 对大模型进行本地化部署。
|
||||||
|
- 支持中文界面。
|
||||||
|
|
||||||
## 📜 路线图
|
## 📜 路线图
|
||||||
|
|
||||||
详见 [RAGFlow Roadmap 2024](https://github.com/infiniflow/ragflow/issues/162)。
|
详见 [RAGFlow Roadmap 2024](https://github.com/infiniflow/ragflow/issues/162) 。
|
||||||
|
|
||||||
## 🏄 开源社区
|
## 🏄 开源社区
|
||||||
|
|
||||||
@ -179,7 +184,7 @@ $ docker compose up -d
|
|||||||
|
|
||||||
## 🙌 贡献指南
|
## 🙌 贡献指南
|
||||||
|
|
||||||
RAGFlow 只有通过开源协作才能蓬勃发展。秉持这一精神,我们欢迎来自社区的各种贡献。如果您有意参与其中,请查阅我们的[贡献者指南](https://github.com/infiniflow/ragflow/blob/main/docs/CONTRIBUTING.md)。
|
RAGFlow 只有通过开源协作才能蓬勃发展。秉持这一精神,我们欢迎来自社区的各种贡献。如果您有意参与其中,请查阅我们的[贡献者指南](https://github.com/infiniflow/ragflow/blob/main/docs/CONTRIBUTING.md) 。
|
||||||
|
|
||||||
## 👥 加入社区
|
## 👥 加入社区
|
||||||
|
|
||||||
|
@ -126,7 +126,7 @@ def message_fit_in(msg, max_length=4000):
|
|||||||
if c < max_length:
|
if c < max_length:
|
||||||
return c, msg
|
return c, msg
|
||||||
|
|
||||||
msg_ = [m for m in msg[:-1] if m.role == "system"]
|
msg_ = [m for m in msg[:-1] if m["role"] == "system"]
|
||||||
msg_.append(msg[-1])
|
msg_.append(msg[-1])
|
||||||
msg = msg_
|
msg = msg_
|
||||||
c = count()
|
c = count()
|
||||||
|
@ -81,7 +81,7 @@ def upload():
|
|||||||
"parser_id": kb.parser_id,
|
"parser_id": kb.parser_id,
|
||||||
"parser_config": kb.parser_config,
|
"parser_config": kb.parser_config,
|
||||||
"created_by": current_user.id,
|
"created_by": current_user.id,
|
||||||
"type": filename_type(filename),
|
"type": filetype,
|
||||||
"name": filename,
|
"name": filename,
|
||||||
"location": location,
|
"location": location,
|
||||||
"size": len(blob),
|
"size": len(blob),
|
||||||
|
@ -91,6 +91,57 @@ def set_api_key():
|
|||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/add_llm', methods=['POST'])
|
||||||
|
@login_required
|
||||||
|
@validate_request("llm_factory", "llm_name", "model_type")
|
||||||
|
def add_llm():
|
||||||
|
req = request.json
|
||||||
|
llm = {
|
||||||
|
"tenant_id": current_user.id,
|
||||||
|
"llm_factory": req["llm_factory"],
|
||||||
|
"model_type": req["model_type"],
|
||||||
|
"llm_name": req["llm_name"],
|
||||||
|
"api_base": req.get("api_base", ""),
|
||||||
|
"api_key": "xxxxxxxxxxxxxxx"
|
||||||
|
}
|
||||||
|
|
||||||
|
factory = req["llm_factory"]
|
||||||
|
msg = ""
|
||||||
|
if llm["model_type"] == LLMType.EMBEDDING.value:
|
||||||
|
mdl = EmbeddingModel[factory](
|
||||||
|
key=None, model_name=llm["llm_name"], base_url=llm["api_base"])
|
||||||
|
try:
|
||||||
|
arr, tc = mdl.encode(["Test if the api key is available"])
|
||||||
|
if len(arr[0]) == 0 or tc == 0:
|
||||||
|
raise Exception("Fail")
|
||||||
|
except Exception as e:
|
||||||
|
msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
|
||||||
|
elif llm["model_type"] == LLMType.CHAT.value:
|
||||||
|
mdl = ChatModel[factory](
|
||||||
|
key=None, model_name=llm["llm_name"], base_url=llm["api_base"])
|
||||||
|
try:
|
||||||
|
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
|
||||||
|
"temperature": 0.9})
|
||||||
|
if not tc:
|
||||||
|
raise Exception(m)
|
||||||
|
except Exception as e:
|
||||||
|
msg += f"\nFail to access model({llm['llm_name']})." + str(
|
||||||
|
e)
|
||||||
|
else:
|
||||||
|
# TODO: check other type of models
|
||||||
|
pass
|
||||||
|
|
||||||
|
if msg:
|
||||||
|
return get_data_error_result(retmsg=msg)
|
||||||
|
|
||||||
|
|
||||||
|
if not TenantLLMService.filter_update(
|
||||||
|
[TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm["llm_name"]], llm):
|
||||||
|
TenantLLMService.save(**llm)
|
||||||
|
|
||||||
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/my_llms', methods=['GET'])
|
@manager.route('/my_llms', methods=['GET'])
|
||||||
@login_required
|
@login_required
|
||||||
def my_llms():
|
def my_llms():
|
||||||
@ -125,6 +176,12 @@ def list():
|
|||||||
for m in llms:
|
for m in llms:
|
||||||
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding"
|
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding"
|
||||||
|
|
||||||
|
llm_set = set([m["llm_name"] for m in llms])
|
||||||
|
for o in objs:
|
||||||
|
if not o.api_key:continue
|
||||||
|
if o.llm_name in llm_set:continue
|
||||||
|
llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True})
|
||||||
|
|
||||||
res = {}
|
res = {}
|
||||||
for m in llms:
|
for m in llms:
|
||||||
if model_type and m["model_type"] != model_type:
|
if model_type and m["model_type"] != model_type:
|
||||||
|
@ -181,6 +181,10 @@ def user_info():
|
|||||||
|
|
||||||
|
|
||||||
def rollback_user_registration(user_id):
|
def rollback_user_registration(user_id):
|
||||||
|
try:
|
||||||
|
UserService.delete_by_id(user_id)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
try:
|
try:
|
||||||
TenantService.delete_by_id(user_id)
|
TenantService.delete_by_id(user_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -18,7 +18,7 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from api.db import LLMType, UserTenantRole
|
from api.db import LLMType, UserTenantRole
|
||||||
from api.db.db_models import init_database_tables as init_web_db
|
from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM
|
||||||
from api.db.services import UserService
|
from api.db.services import UserService
|
||||||
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
|
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
|
||||||
from api.db.services.user_service import TenantService, UserTenantService
|
from api.db.services.user_service import TenantService, UserTenantService
|
||||||
@ -100,16 +100,16 @@ factory_infos = [{
|
|||||||
"status": "1",
|
"status": "1",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Local",
|
"name": "Ollama",
|
||||||
"logo": "",
|
"logo": "",
|
||||||
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
||||||
"status": "1",
|
"status": "1",
|
||||||
}, {
|
}, {
|
||||||
"name": "Moonshot",
|
"name": "Moonshot",
|
||||||
"logo": "",
|
"logo": "",
|
||||||
"tags": "LLM,TEXT EMBEDDING",
|
"tags": "LLM,TEXT EMBEDDING",
|
||||||
"status": "1",
|
"status": "1",
|
||||||
}
|
},
|
||||||
# {
|
# {
|
||||||
# "name": "文心一言",
|
# "name": "文心一言",
|
||||||
# "logo": "",
|
# "logo": "",
|
||||||
@ -230,20 +230,6 @@ def init_llm_factory():
|
|||||||
"max_tokens": 512,
|
"max_tokens": 512,
|
||||||
"model_type": LLMType.EMBEDDING.value
|
"model_type": LLMType.EMBEDDING.value
|
||||||
},
|
},
|
||||||
# ---------------------- 本地 ----------------------
|
|
||||||
{
|
|
||||||
"fid": factory_infos[3]["name"],
|
|
||||||
"llm_name": "qwen-14B-chat",
|
|
||||||
"tags": "LLM,CHAT,",
|
|
||||||
"max_tokens": 4096,
|
|
||||||
"model_type": LLMType.CHAT.value
|
|
||||||
}, {
|
|
||||||
"fid": factory_infos[3]["name"],
|
|
||||||
"llm_name": "flag-embedding",
|
|
||||||
"tags": "TEXT EMBEDDING,",
|
|
||||||
"max_tokens": 128 * 1000,
|
|
||||||
"model_type": LLMType.EMBEDDING.value
|
|
||||||
},
|
|
||||||
# ------------------------ Moonshot -----------------------
|
# ------------------------ Moonshot -----------------------
|
||||||
{
|
{
|
||||||
"fid": factory_infos[4]["name"],
|
"fid": factory_infos[4]["name"],
|
||||||
@ -282,6 +268,9 @@ def init_llm_factory():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
LLMFactoriesService.filter_delete([LLMFactories.name=="Local"])
|
||||||
|
LLMService.filter_delete([LLM.fid=="Local"])
|
||||||
|
|
||||||
"""
|
"""
|
||||||
drop table llm;
|
drop table llm;
|
||||||
drop table llm_factories;
|
drop table llm_factories;
|
||||||
@ -295,8 +284,7 @@ def init_llm_factory():
|
|||||||
def init_web_data():
|
def init_web_data():
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
if LLMFactoriesService.get_all().count() != len(factory_infos):
|
init_llm_factory()
|
||||||
init_llm_factory()
|
|
||||||
if not UserService.get_all().count():
|
if not UserService.get_all().count():
|
||||||
init_superuser()
|
init_superuser()
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ services:
|
|||||||
- 443:443
|
- 443:443
|
||||||
volumes:
|
volumes:
|
||||||
- ./service_conf.yaml:/ragflow/conf/service_conf.yaml
|
- ./service_conf.yaml:/ragflow/conf/service_conf.yaml
|
||||||
|
- ./entrypoint.sh:/ragflow/entrypoint.sh
|
||||||
- ./ragflow-logs:/ragflow/logs
|
- ./ragflow-logs:/ragflow/logs
|
||||||
- ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
|
- ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf
|
||||||
- ./nginx/proxy.conf:/etc/nginx/proxy.conf
|
- ./nginx/proxy.conf:/etc/nginx/proxy.conf
|
||||||
|
40
docs/ollama.md
Normal file
40
docs/ollama.md
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
# Ollama
|
||||||
|
|
||||||
|
<div align="center" style="margin-top:20px;margin-bottom:20px;">
|
||||||
|
<img src="https://github.com/infiniflow/ragflow/assets/12318111/2019e7ee-1e8a-412e-9349-11bbf702e549" width="130"/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
One-click deployment of local LLMs, that is [Ollama](https://github.com/ollama/ollama).
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
- [Ollama on Linux](https://github.com/ollama/ollama/blob/main/docs/linux.md)
|
||||||
|
- [Ollama Windows Preview](https://github.com/ollama/ollama/blob/main/docs/windows.md)
|
||||||
|
- [Docker](https://hub.docker.com/r/ollama/ollama)
|
||||||
|
|
||||||
|
## Launch Ollama
|
||||||
|
|
||||||
|
Decide which LLM you want to deploy ([here's a list for supported LLM](https://ollama.com/library)), say, **mistral**:
|
||||||
|
```bash
|
||||||
|
$ ollama run mistral
|
||||||
|
```
|
||||||
|
Or,
|
||||||
|
```bash
|
||||||
|
$ docker exec -it ollama ollama run mistral
|
||||||
|
```
|
||||||
|
|
||||||
|
## Use Ollama in RAGFlow
|
||||||
|
|
||||||
|
- Go to 'Settings > Model Providers > Models to be added > Ollama'.
|
||||||
|
|
||||||
|
<div align="center" style="margin-top:20px;margin-bottom:20px;">
|
||||||
|
<img src="https://github.com/infiniflow/ragflow/assets/12318111/2019e7ee-1e8a-412e-9349-11bbf702e549" width="130"/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
> Base URL: Enter the base URL where the Ollama service is accessible, like, http://<your-ollama-endpoint-domain>:11434
|
||||||
|
|
||||||
|
- Use Ollama Models.
|
||||||
|
|
||||||
|
<div align="center" style="margin-top:20px;margin-bottom:20px;">
|
||||||
|
<img src="https://github.com/infiniflow/ragflow/assets/12318111/2019e7ee-1e8a-412e-9349-11bbf702e549" width="130"/>
|
||||||
|
</div>
|
@ -19,7 +19,7 @@ from .cv_model import *
|
|||||||
|
|
||||||
|
|
||||||
EmbeddingModel = {
|
EmbeddingModel = {
|
||||||
"Local": HuEmbedding,
|
"Ollama": OllamaEmbed,
|
||||||
"OpenAI": OpenAIEmbed,
|
"OpenAI": OpenAIEmbed,
|
||||||
"Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
|
"Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
|
||||||
"ZHIPU-AI": ZhipuEmbed,
|
"ZHIPU-AI": ZhipuEmbed,
|
||||||
@ -29,7 +29,7 @@ EmbeddingModel = {
|
|||||||
|
|
||||||
CvModel = {
|
CvModel = {
|
||||||
"OpenAI": GptV4,
|
"OpenAI": GptV4,
|
||||||
"Local": LocalCV,
|
"Ollama": OllamaCV,
|
||||||
"Tongyi-Qianwen": QWenCV,
|
"Tongyi-Qianwen": QWenCV,
|
||||||
"ZHIPU-AI": Zhipu4V,
|
"ZHIPU-AI": Zhipu4V,
|
||||||
"Moonshot": LocalCV
|
"Moonshot": LocalCV
|
||||||
@ -40,7 +40,7 @@ ChatModel = {
|
|||||||
"OpenAI": GptTurbo,
|
"OpenAI": GptTurbo,
|
||||||
"ZHIPU-AI": ZhipuChat,
|
"ZHIPU-AI": ZhipuChat,
|
||||||
"Tongyi-Qianwen": QWenChat,
|
"Tongyi-Qianwen": QWenChat,
|
||||||
"Local": LocalLLM,
|
"Ollama": OllamaChat,
|
||||||
"Moonshot": MoonshotChat
|
"Moonshot": MoonshotChat
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@ from dashscope import Generation
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import openai
|
import openai
|
||||||
|
from ollama import Client
|
||||||
from rag.nlp import is_english
|
from rag.nlp import is_english
|
||||||
from rag.utils import num_tokens_from_string
|
from rag.utils import num_tokens_from_string
|
||||||
|
|
||||||
@ -129,6 +130,32 @@ class ZhipuChat(Base):
|
|||||||
return "**ERROR**: " + str(e), 0
|
return "**ERROR**: " + str(e), 0
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaChat(Base):
|
||||||
|
def __init__(self, key, model_name, **kwargs):
|
||||||
|
self.client = Client(host=kwargs["base_url"])
|
||||||
|
self.model_name = model_name
|
||||||
|
|
||||||
|
def chat(self, system, history, gen_conf):
|
||||||
|
if system:
|
||||||
|
history.insert(0, {"role": "system", "content": system})
|
||||||
|
try:
|
||||||
|
options = {"temperature": gen_conf.get("temperature", 0.1),
|
||||||
|
"num_predict": gen_conf.get("max_tokens", 128),
|
||||||
|
"top_k": gen_conf.get("top_p", 0.3),
|
||||||
|
"presence_penalty": gen_conf.get("presence_penalty", 0.4),
|
||||||
|
"frequency_penalty": gen_conf.get("frequency_penalty", 0.7),
|
||||||
|
}
|
||||||
|
response = self.client.chat(
|
||||||
|
model=self.model_name,
|
||||||
|
messages=history,
|
||||||
|
options=options
|
||||||
|
)
|
||||||
|
ans = response["message"]["content"].strip()
|
||||||
|
return ans, response["eval_count"]
|
||||||
|
except Exception as e:
|
||||||
|
return "**ERROR**: " + str(e), 0
|
||||||
|
|
||||||
|
|
||||||
class LocalLLM(Base):
|
class LocalLLM(Base):
|
||||||
class RPCProxy:
|
class RPCProxy:
|
||||||
def __init__(self, host, port):
|
def __init__(self, host, port):
|
||||||
|
@ -16,7 +16,7 @@
|
|||||||
from zhipuai import ZhipuAI
|
from zhipuai import ZhipuAI
|
||||||
import io
|
import io
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
from ollama import Client
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import os
|
import os
|
||||||
@ -140,6 +140,28 @@ class Zhipu4V(Base):
|
|||||||
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaCV(Base):
|
||||||
|
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
||||||
|
self.client = Client(host=kwargs["base_url"])
|
||||||
|
self.model_name = model_name
|
||||||
|
self.lang = lang
|
||||||
|
|
||||||
|
def describe(self, image, max_tokens=1024):
|
||||||
|
prompt = self.prompt("")
|
||||||
|
try:
|
||||||
|
options = {"num_predict": max_tokens}
|
||||||
|
response = self.client.generate(
|
||||||
|
model=self.model_name,
|
||||||
|
prompt=prompt[0]["content"][1]["text"],
|
||||||
|
images=[image],
|
||||||
|
options=options
|
||||||
|
)
|
||||||
|
ans = response["response"].strip()
|
||||||
|
return ans, 128
|
||||||
|
except Exception as e:
|
||||||
|
return "**ERROR**: " + str(e), 0
|
||||||
|
|
||||||
|
|
||||||
class LocalCV(Base):
|
class LocalCV(Base):
|
||||||
def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
|
def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
|
||||||
pass
|
pass
|
||||||
|
@ -16,13 +16,12 @@
|
|||||||
from zhipuai import ZhipuAI
|
from zhipuai import ZhipuAI
|
||||||
import os
|
import os
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
from ollama import Client
|
||||||
import dashscope
|
import dashscope
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from FlagEmbedding import FlagModel
|
from FlagEmbedding import FlagModel
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from api.utils.file_utils import get_project_base_directory
|
||||||
from rag.utils import num_tokens_from_string
|
from rag.utils import num_tokens_from_string
|
||||||
@ -150,3 +149,24 @@ class ZhipuEmbed(Base):
|
|||||||
res = self.client.embeddings.create(input=text,
|
res = self.client.embeddings.create(input=text,
|
||||||
model=self.model_name)
|
model=self.model_name)
|
||||||
return np.array(res.data[0].embedding), res.usage.total_tokens
|
return np.array(res.data[0].embedding), res.usage.total_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaEmbed(Base):
|
||||||
|
def __init__(self, key, model_name, **kwargs):
|
||||||
|
self.client = Client(host=kwargs["base_url"])
|
||||||
|
self.model_name = model_name
|
||||||
|
|
||||||
|
def encode(self, texts: list, batch_size=32):
|
||||||
|
arr = []
|
||||||
|
tks_num = 0
|
||||||
|
for txt in texts:
|
||||||
|
res = self.client.embeddings(prompt=txt,
|
||||||
|
model=self.model_name)
|
||||||
|
arr.append(res["embedding"])
|
||||||
|
tks_num += 128
|
||||||
|
return np.array(arr), tks_num
|
||||||
|
|
||||||
|
def encode_queries(self, text):
|
||||||
|
res = self.client.embeddings(prompt=text,
|
||||||
|
model=self.model_name)
|
||||||
|
return np.array(res["embedding"]), 128
|
||||||
|
@ -23,7 +23,8 @@ import re
|
|||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
import signal
|
||||||
|
from contextlib import contextmanager
|
||||||
from rag.settings import database_logger
|
from rag.settings import database_logger
|
||||||
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
|
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
|
||||||
|
|
||||||
@ -97,8 +98,21 @@ def collect(comm, mod, tm):
|
|||||||
cron_logger.info("TOTAL:{}, To:{}".format(len(tasks), mtm))
|
cron_logger.info("TOTAL:{}, To:{}".format(len(tasks), mtm))
|
||||||
return tasks
|
return tasks
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def timeout(time):
|
||||||
|
# Register a function to raise a TimeoutError on the signal.
|
||||||
|
signal.signal(signal.SIGALRM, raise_timeout)
|
||||||
|
# Schedule the signal to be sent after ``time``.
|
||||||
|
signal.alarm(time)
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
def raise_timeout(signum, frame):
|
||||||
|
raise TimeoutError
|
||||||
|
|
||||||
|
|
||||||
def build(row):
|
def build(row):
|
||||||
|
from timeit import default_timer as timer
|
||||||
if row["size"] > DOC_MAXIMUM_SIZE:
|
if row["size"] > DOC_MAXIMUM_SIZE:
|
||||||
set_progress(row["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
|
set_progress(row["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
|
||||||
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
|
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
|
||||||
@ -111,11 +125,14 @@ def build(row):
|
|||||||
row["to_page"])
|
row["to_page"])
|
||||||
chunker = FACTORY[row["parser_id"].lower()]
|
chunker = FACTORY[row["parser_id"].lower()]
|
||||||
try:
|
try:
|
||||||
cron_logger.info(
|
st = timer()
|
||||||
"Chunkking {}/{}".format(row["location"], row["name"]))
|
with timeout(30):
|
||||||
cks = chunker.chunk(row["name"], binary=MINIO.get(row["kb_id"], row["location"]), from_page=row["from_page"],
|
binary = MINIO.get(row["kb_id"], row["location"])
|
||||||
|
cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
|
||||||
to_page=row["to_page"], lang=row["language"], callback=callback,
|
to_page=row["to_page"], lang=row["language"], callback=callback,
|
||||||
kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
|
kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
|
||||||
|
cron_logger.info(
|
||||||
|
"Chunkking({}) {}/{}".format(timer()-st, row["location"], row["name"]))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if re.search("(No such file|not found)", str(e)):
|
if re.search("(No such file|not found)", str(e)):
|
||||||
callback(-1, "Can not find file <%s>" % row["name"])
|
callback(-1, "Can not find file <%s>" % row["name"])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user