mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-12 10:39:00 +08:00
add support for Replicate (#1980)
### What problem does this PR solve?

#1853 add support for Replicate

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
This commit is contained in:
parent
be5a67895e
commit
79426fc41f
@ -149,7 +149,7 @@ def add_llm():
|
|||||||
msg = ""
|
msg = ""
|
||||||
if llm["model_type"] == LLMType.EMBEDDING.value:
|
if llm["model_type"] == LLMType.EMBEDDING.value:
|
||||||
mdl = EmbeddingModel[factory](
|
mdl = EmbeddingModel[factory](
|
||||||
key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible"] else None,
|
key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate"] else None,
|
||||||
model_name=llm["llm_name"],
|
model_name=llm["llm_name"],
|
||||||
base_url=llm["api_base"])
|
base_url=llm["api_base"])
|
||||||
try:
|
try:
|
||||||
@ -160,7 +160,7 @@ def add_llm():
|
|||||||
msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
|
msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
|
||||||
elif llm["model_type"] == LLMType.CHAT.value:
|
elif llm["model_type"] == LLMType.CHAT.value:
|
||||||
mdl = ChatModel[factory](
|
mdl = ChatModel[factory](
|
||||||
key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible"] else None,
|
key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate"] else None,
|
||||||
model_name=llm["llm_name"],
|
model_name=llm["llm_name"],
|
||||||
base_url=llm["api_base"]
|
base_url=llm["api_base"]
|
||||||
)
|
)
|
||||||
|
@ -3113,6 +3113,13 @@
|
|||||||
"model_type": "image2text"
|
"model_type": "image2text"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Replicate",
|
||||||
|
"logo": "",
|
||||||
|
"tags": "LLM,TEXT EMBEDDING",
|
||||||
|
"status": "1",
|
||||||
|
"llm": []
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
@ -42,7 +42,8 @@ EmbeddingModel = {
|
|||||||
"TogetherAI": TogetherAIEmbed,
|
"TogetherAI": TogetherAIEmbed,
|
||||||
"PerfXCloud": PerfXCloudEmbed,
|
"PerfXCloud": PerfXCloudEmbed,
|
||||||
"Upstage": UpstageEmbed,
|
"Upstage": UpstageEmbed,
|
||||||
"SILICONFLOW": SILICONFLOWEmbed
|
"SILICONFLOW": SILICONFLOWEmbed,
|
||||||
|
"Replicate": ReplicateEmbed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -96,7 +97,8 @@ ChatModel = {
|
|||||||
"Upstage":UpstageChat,
|
"Upstage":UpstageChat,
|
||||||
"novita.ai": NovitaAIChat,
|
"novita.ai": NovitaAIChat,
|
||||||
"SILICONFLOW": SILICONFLOWChat,
|
"SILICONFLOW": SILICONFLOWChat,
|
||||||
"01.AI": YiChat
|
"01.AI": YiChat,
|
||||||
|
"Replicate": ReplicateChat
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1037,3 +1037,54 @@ class YiChat(Base):
|
|||||||
if not base_url:
|
if not base_url:
|
||||||
base_url = "https://api.01.ai/v1"
|
base_url = "https://api.01.ai/v1"
|
||||||
super().__init__(key, model_name, base_url)
|
super().__init__(key, model_name, base_url)
|
||||||
|
|
||||||
|
|
||||||
|
class ReplicateChat(Base):
    """Chat model backed by a model hosted on Replicate.

    The conversation is flattened into a single text prompt and sent to the
    model through ``replicate.client.Client.run``.
    """

    def __init__(self, key, model_name, base_url=None):
        """Create a Replicate client for ``model_name``.

        Args:
            key: Replicate API token, passed to the client as ``api_token``.
            model_name: Model identifier handed to ``Client.run``.
            base_url: Accepted for signature parity with the other chat
                models; not used by this class.
        """
        # Local import: the replicate package is only required once this
        # class is actually instantiated.
        from replicate.client import Client

        self.model_name = model_name
        self.client = Client(api_token=key)
        # Last non-empty system prompt; reused when a later call passes an
        # empty one.
        self.system = ""

    def _build_input(self, system, history, gen_conf):
        """Build the ``input`` payload shared by ``chat`` and ``chat_streamly``."""
        # Work on a copy so the caller's generation config is never mutated.
        params = dict(gen_conf)
        if "max_tokens" in params:
            # The request uses "max_new_tokens" in place of "max_tokens".
            params["max_new_tokens"] = params.pop("max_tokens")
        if system:
            self.system = system
        # Only the five most recent messages are sent, one "role:content"
        # line per message.
        prompt = "\n".join(
            item["role"] + ":" + item["content"] for item in history[-5:]
        )
        return {"system_prompt": self.system, "prompt": prompt, **params}

    def chat(self, system, history, gen_conf):
        """Run the model once and return the complete answer.

        Args:
            system: System prompt; when empty the previously stored one is
                reused.
            history: List of message dicts with ``"role"`` and ``"content"``.
            gen_conf: Generation parameters forwarded to the model; it is
                left unmodified.

        Returns:
            ``(answer, token_count)``. On failure the answer carries an
            ``**ERROR**`` message and the token count is ``0``.
        """
        model_input = self._build_input(system, history, gen_conf)
        ans = ""
        try:
            response = self.client.run(self.model_name, input=model_input)
            ans = "".join(response)
            return ans, num_tokens_from_string(ans)
        except Exception as e:
            # Errors are reported in-band, as the answer text.
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        """Run the model and yield the answer as it accumulates.

        Takes the same arguments as ``chat``.

        Yields:
            The accumulated answer text after every received chunk (with an
            ``**ERROR**`` suffix if the call fails), and finally the token
            count of the accumulated answer as an ``int``.
        """
        model_input = self._build_input(system, history, gen_conf)
        ans = ""
        try:
            response = self.client.run(self.model_name, input=model_input)
            for resp in response:
                ans += resp
                yield ans

        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        # The last item is always the token count, even after an error.
        yield num_tokens_from_string(ans)
|
||||||
|
@ -581,3 +581,21 @@ class SILICONFLOWEmbed(OpenAIEmbed):
|
|||||||
if not base_url:
|
if not base_url:
|
||||||
base_url = "https://api.siliconflow.cn/v1"
|
base_url = "https://api.siliconflow.cn/v1"
|
||||||
super().__init__(key, model_name, base_url)
|
super().__init__(key, model_name, base_url)
|
||||||
|
|
||||||
|
|
||||||
|
class ReplicateEmbed(Base):
    """Text-embedding model backed by a model hosted on Replicate."""

    def __init__(self, key, model_name, base_url=None):
        """Create a Replicate client for ``model_name``.

        Args:
            key: Replicate API token, passed to the client as ``api_token``.
            model_name: Model identifier handed to ``Client.run``.
            base_url: Accepted for signature parity with the other embedding
                models; not used by this class.
        """
        # Local import: the replicate package is only required once this
        # class is actually instantiated.
        from replicate.client import Client

        self.model_name = model_name
        self.client = Client(api_token=key)

    def _embed(self, texts):
        """Send ``texts`` to the model and return its raw output.

        Both public methods go through here so documents and queries use the
        same request format: ``texts`` as a JSON-encoded list.
        """
        from json import dumps

        return self.client.run(self.model_name, input={"texts": dumps(texts)})

    def encode(self, texts: list, batch_size=32):
        """Embed a list of texts.

        Args:
            texts: Strings to embed.
            batch_size: Maximum number of texts per request. A falsy or
                non-positive value sends everything in one request.

        Returns:
            ``(embeddings, token_count)``: a numpy array built from the model
            outputs and the summed ``num_tokens_from_string`` of the inputs.
            An empty ``texts`` makes no request.
        """
        token_count = sum(num_tokens_from_string(text) for text in texts)
        if not batch_size or batch_size <= 0:
            batch_size = len(texts) or 1
        vectors = []
        for start in range(0, len(texts), batch_size):
            # NOTE(review): assumes the model returns one item per input
            # text, so per-batch outputs can be concatenated - confirm
            # against the models in use.
            vectors.extend(self._embed(texts[start:start + batch_size]))
        return np.array(vectors), token_count

    def encode_queries(self, text):
        """Embed a single query string.

        Returns:
            ``(embedding, token_count)``: the model output for the
            one-element batch ``[text]`` as a numpy array, and the token
            count of ``text``.
        """
        # Uses Client.run like encode(); the previous call to client.embed
        # would fail if, as expected, the Replicate client has no such method.
        res = self._embed([text])
        return np.array(res), num_tokens_from_string(text)
|
||||||
|
@ -65,6 +65,7 @@ python_pptx==0.6.23
|
|||||||
readability_lxml==0.8.1
|
readability_lxml==0.8.1
|
||||||
redis==5.0.3
|
redis==5.0.3
|
||||||
Requests==2.32.2
|
Requests==2.32.2
|
||||||
|
replicate==0.31.0
|
||||||
roman_numbers==1.0.2
|
roman_numbers==1.0.2
|
||||||
ruamel.base==1.0.0
|
ruamel.base==1.0.0
|
||||||
scholarly==1.7.11
|
scholarly==1.7.11
|
||||||
|
@ -102,6 +102,7 @@ python-pptx==0.6.23
|
|||||||
PyYAML==6.0.1
|
PyYAML==6.0.1
|
||||||
redis==5.0.3
|
redis==5.0.3
|
||||||
regex==2023.12.25
|
regex==2023.12.25
|
||||||
|
replicate==0.31.0
|
||||||
requests==2.32.2
|
requests==2.32.2
|
||||||
ruamel.yaml==0.18.6
|
ruamel.yaml==0.18.6
|
||||||
ruamel.yaml.clib==0.2.8
|
ruamel.yaml.clib==0.2.8
|
||||||
|
1
web/src/assets/svg/llm/replicate.svg
Normal file
1
web/src/assets/svg/llm/replicate.svg
Normal file
@ -0,0 +1 @@
|
|||||||
|
<?xml version="1.0" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"><svg t="1723795491606" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="4258" xmlns:xlink="http://www.w3.org/1999/xlink" width="200" height="200"><path d="M0 0m256 0l512 0q256 0 256 256l0 512q0 256-256 256l-512 0q-256 0-256-256l0-512q0-256 256-256Z" fill="#000000" p-id="4259"></path><path d="M853.162667 170.666667v76.373333H338.602667v602.112H256V170.666667h597.162667zM853.333333 315.52v76.373333h-358.528V849.066667H412.117333V315.52H853.333333z m-0.170666 221.098667v-76.8h-285.226667v389.290666h82.645333v-312.490666h202.581334z" fill="#FFFFFF" p-id="4260"></path></svg>
|
After Width: | Height: | Size: 754 B |
@ -17,4 +17,4 @@ export const UserSettingIconMap = {
|
|||||||
|
|
||||||
export * from '@/constants/setting';
|
export * from '@/constants/setting';
|
||||||
|
|
||||||
export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio',"OpenAI-API-Compatible",'TogetherAI'];
|
export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio',"OpenAI-API-Compatible",'TogetherAI','Replicate'];
|
||||||
|
@ -30,7 +30,8 @@ export const IconMap = {
|
|||||||
Upstage: 'upstage',
|
Upstage: 'upstage',
|
||||||
'novita.ai': 'novita-ai',
|
'novita.ai': 'novita-ai',
|
||||||
SILICONFLOW: 'siliconflow',
|
SILICONFLOW: 'siliconflow',
|
||||||
"01.AI": 'yi'
|
"01.AI": 'yi',
|
||||||
|
"Replicate": 'replicate'
|
||||||
};
|
};
|
||||||
|
|
||||||
export const BedrockRegionList = [
|
export const BedrockRegionList = [
|
||||||
|
Loading…
x
Reference in New Issue
Block a user