add support for cohere (#1849)
### What problem does this PR solve?

Adds Cohere as a model provider: chat via the command family, text embedding via the embed v2/v3 families, and reranking via the rerank v2/v3 families. The provider is registered in the factory catalog and in the chat/embedding/rerank model registries, `cohere==5.6.2` is pinned in the requirements, and the web UI gets a Cohere logo. A related fix makes NvidiaRerank re-index its scores by document position.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
Commit e34817c2a9 (parent 60428c4ad2)
@@ -2216,6 +2216,116 @@
             "tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT",
             "status": "1",
             "llm": []
+        },
+        {
+            "name": "cohere",
+            "logo": "",
+            "tags": "LLM,TEXT EMBEDDING, TEXT RE-RANK",
+            "status": "1",
+            "llm": [
+                {
+                    "llm_name": "command-r-plus",
+                    "tags": "LLM,CHAT,128k",
+                    "max_tokens": 131072,
+                    "model_type": "chat"
+                },
+                {
+                    "llm_name": "command-r",
+                    "tags": "LLM,CHAT,128k",
+                    "max_tokens": 131072,
+                    "model_type": "chat"
+                },
+                {
+                    "llm_name": "command",
+                    "tags": "LLM,CHAT,4k",
+                    "max_tokens": 4096,
+                    "model_type": "chat"
+                },
+                {
+                    "llm_name": "command-nightly",
+                    "tags": "LLM,CHAT,128k",
+                    "max_tokens": 131072,
+                    "model_type": "chat"
+                },
+                {
+                    "llm_name": "command-light",
+                    "tags": "LLM,CHAT,4k",
+                    "max_tokens": 4096,
+                    "model_type": "chat"
+                },
+                {
+                    "llm_name": "command-light-nightly",
+                    "tags": "LLM,CHAT,4k",
+                    "max_tokens": 4096,
+                    "model_type": "chat"
+                },
+                {
+                    "llm_name": "embed-english-v3.0",
+                    "tags": "TEXT EMBEDDING",
+                    "max_tokens": 512,
+                    "model_type": "embedding"
+                },
+                {
+                    "llm_name": "embed-english-light-v3.0",
+                    "tags": "TEXT EMBEDDING",
+                    "max_tokens": 512,
+                    "model_type": "embedding"
+                },
+                {
+                    "llm_name": "embed-multilingual-v3.0",
+                    "tags": "TEXT EMBEDDING",
+                    "max_tokens": 512,
+                    "model_type": "embedding"
+                },
+                {
+                    "llm_name": "embed-multilingual-light-v3.0",
+                    "tags": "TEXT EMBEDDING",
+                    "max_tokens": 512,
+                    "model_type": "embedding"
+                },
+                {
+                    "llm_name": "embed-english-v2.0",
+                    "tags": "TEXT EMBEDDING",
+                    "max_tokens": 512,
+                    "model_type": "embedding"
+                },
+                {
+                    "llm_name": "embed-english-light-v2.0",
+                    "tags": "TEXT EMBEDDING",
+                    "max_tokens": 512,
+                    "model_type": "embedding"
+                },
+                {
+                    "llm_name": "embed-multilingual-v2.0",
+                    "tags": "TEXT EMBEDDING",
+                    "max_tokens": 256,
+                    "model_type": "embedding"
+                },
+                {
+                    "llm_name": "rerank-english-v3.0",
+                    "tags": "RE-RANK,4k",
+                    "max_tokens": 4096,
+                    "model_type": "rerank"
+                },
+                {
+                    "llm_name": "rerank-multilingual-v3.0",
+                    "tags": "RE-RANK,4k",
+                    "max_tokens": 4096,
+                    "model_type": "rerank"
+                },
+                {
+                    "llm_name": "rerank-english-v2.0",
+                    "tags": "RE-RANK,512",
+                    "max_tokens": 8196,
+                    "model_type": "rerank"
+                },
+                {
+                    "llm_name": "rerank-multilingual-v2.0",
+                    "tags": "RE-RANK,512",
+                    "max_tokens": 512,
+                    "model_type": "rerank"
+                }
+            ]
         }
     ]
 }
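This hunk extends the JSON catalog of model providers (presumably conf/llm_factories.json, whose file header is not captured here) with a cohere entry: six chat models from the command family, seven embedding models, and four rerank models. max_tokens records each model's context window and the tags drive filtering in the UI. A quick sanity check of the new entry, as a sketch; the file path and top-level key are assumptions:

import json

# Load the provider catalog and pull out the new entry (path and key are assumptions)
catalog = json.load(open("conf/llm_factories.json"))
cohere = next(f for f in catalog["factory_llm_infos"] if f["name"] == "cohere")

# 17 models split across the three capabilities registered below
print(len(cohere["llm"]))                                # 17
print(sorted({m["model_type"] for m in cohere["llm"]}))  # ['chat', 'embedding', 'rerank']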
@@ -37,7 +37,8 @@ EmbeddingModel = {
     "Gemini": GeminiEmbed,
     "NVIDIA": NvidiaEmbed,
     "LM-Studio": LmStudioEmbed,
-    "OpenAI-API-Compatible": OpenAI_APIEmbed
+    "OpenAI-API-Compatible": OpenAI_APIEmbed,
+    "cohere": CoHereEmbed
 }
@@ -81,7 +82,8 @@ ChatModel = {
    "StepFun": StepFunChat,
    "NVIDIA": NvidiaChat,
    "LM-Studio": LmStudioChat,
-    "OpenAI-API-Compatible": OpenAI_APIChat
+    "OpenAI-API-Compatible": OpenAI_APIChat,
+    "cohere": CoHereChat
 }
@@ -92,7 +94,8 @@ RerankModel = {
    "Xinference": XInferenceRerank,
    "NVIDIA": NvidiaRerank,
    "LM-Studio": LmStudioRerank,
-    "OpenAI-API-Compatible": OpenAI_APIRerank
+    "OpenAI-API-Compatible": OpenAI_APIRerank,
+    "cohere": CoHereRerank
 }
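These three one-line registrations connect the factory name from the catalog to the wrapper classes added below; resolution is a plain dict lookup. A sketch of the lookup pattern (the API key is hypothetical):

# Resolve wrapper classes by the factory name recorded in the catalog
chat_cls = ChatModel["cohere"]        # CoHereChat
embed_cls = EmbeddingModel["cohere"]  # CoHereEmbed
rerank_cls = RerankModel["cohere"]    # CoHereRerank

chat = chat_cls(key="YOUR_COHERE_API_KEY", model_name="command-r")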
@@ -900,3 +900,84 @@ class OpenAI_APIChat(Base):
             base_url = os.path.join(base_url, "v1")
         model_name = model_name.split("___")[0]
         super().__init__(key, model_name, base_url)
+
+
+class CoHereChat(Base):
+    def __init__(self, key, model_name, base_url=""):
+        from cohere import Client
+
+        self.client = Client(api_key=key)
+        self.model_name = model_name
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        if "top_p" in gen_conf:
+            gen_conf["p"] = gen_conf.pop("top_p")
+        if "frequency_penalty" in gen_conf and "presence_penalty" in gen_conf:
+            gen_conf.pop("presence_penalty")
+        for item in history:
+            if "role" in item and item["role"] == "user":
+                item["role"] = "USER"
+            if "role" in item and item["role"] == "assistant":
+                item["role"] = "CHATBOT"
+            if "content" in item:
+                item["message"] = item.pop("content")
+        mes = history.pop()["message"]
+        ans = ""
+        try:
+            response = self.client.chat(
+                model=self.model_name, chat_history=history, message=mes, **gen_conf
+            )
+            ans = response.text
+            if response.finish_reason == "MAX_TOKENS":
+                ans += (
+                    "...\nFor the content length reason, it stopped, continue?"
+                    if is_english([ans])
+                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                )
+            return (
+                ans,
+                response.meta.tokens.input_tokens + response.meta.tokens.output_tokens,
+            )
+        except Exception as e:
+            return ans + "\n**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        if "top_p" in gen_conf:
+            gen_conf["p"] = gen_conf.pop("top_p")
+        if "frequency_penalty" in gen_conf and "presence_penalty" in gen_conf:
+            gen_conf.pop("presence_penalty")
+        for item in history:
+            if "role" in item and item["role"] == "user":
+                item["role"] = "USER"
+            if "role" in item and item["role"] == "assistant":
+                item["role"] = "CHATBOT"
+            if "content" in item:
+                item["message"] = item.pop("content")
+        mes = history.pop()["message"]
+        ans = ""
+        total_tokens = 0
+        try:
+            response = self.client.chat_stream(
+                model=self.model_name, chat_history=history, message=mes, **gen_conf
+            )
+            for resp in response:
+                if resp.event_type == "text-generation":
+                    ans += resp.text
+                    total_tokens += num_tokens_from_string(resp.text)
+                elif resp.event_type == "stream-end":
+                    if resp.finish_reason == "MAX_TOKENS":
+                        ans += (
+                            "...\nFor the content length reason, it stopped, continue?"
+                            if is_english([ans])
+                            else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                        )
+                yield ans
+
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
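CoHereChat adapts the OpenAI-style conversation format used by the other wrappers to Cohere's chat API: top_p is renamed to p, presence_penalty is dropped when frequency_penalty is also set, user/assistant roles become USER/CHATBOT, content becomes message, and the final user turn is popped off and sent as the message argument while the rest travels as chat_history. A worked sketch with a hypothetical conversation and key:

# Hypothetical OpenAI-style input, as produced elsewhere in the pipeline
history = [
    {"role": "user", "content": "What is RAGFlow?"},
    {"role": "assistant", "content": "An open-source RAG engine."},
    {"role": "user", "content": "Which Cohere models does it support?"},
]

chat = CoHereChat(key="YOUR_COHERE_API_KEY", model_name="command-r-plus")

# Blocking call: returns (answer, input_tokens + output_tokens)
answer, used_tokens = chat.chat(
    system="You are concise.",
    history=[dict(m) for m in history],  # chat() rewrites history in place, so pass copies
    gen_conf={"temperature": 0.2, "top_p": 0.9},
)

# Streaming call: yields the growing answer string, then the token estimate last
*partials, total_tokens = chat.chat_streamly(
    system="You are concise.", history=history, gen_conf={}
)
print(partials[-1] if partials else "", total_tokens)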
@@ -523,3 +523,33 @@ class OpenAI_APIEmbed(OpenAIEmbed):
             base_url = os.path.join(base_url, "v1")
         self.client = OpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name.split("___")[0]
+
+
+class CoHereEmbed(Base):
+    def __init__(self, key, model_name, base_url=None):
+        from cohere import Client
+
+        self.client = Client(api_key=key)
+        self.model_name = model_name
+
+    def encode(self, texts: list, batch_size=32):
+        res = self.client.embed(
+            texts=texts,
+            model=self.model_name,
+            input_type="search_query",
+            embedding_types=["float"],
+        )
+        return np.array([d for d in res.embeddings.float]), int(
+            res.meta.billed_units.input_tokens
+        )
+
+    def encode_queries(self, text):
+        res = self.client.embed(
+            texts=[text],
+            model=self.model_name,
+            input_type="search_query",
+            embedding_types=["float"],
+        )
+        return np.array([d for d in res.embeddings.float]), int(
+            res.meta.billed_units.input_tokens
+        )
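Both encode (the document side) and encode_queries send input_type="search_query"; Cohere's v3 embedding models also accept "search_document" for corpus text, so both sides are embedded identically here, and the batch_size parameter is accepted but unused since the whole list goes to the API in one request. Token usage comes from the response's billed_units. A minimal usage sketch (hypothetical key):

embedder = CoHereEmbed(key="YOUR_COHERE_API_KEY", model_name="embed-english-v3.0")

# encode() embeds a batch of chunks; encode_queries() embeds a single query
doc_vecs, doc_tokens = embedder.encode(["first chunk", "second chunk"])
q_vec, q_tokens = embedder.encode_queries("a user question")
print(doc_vecs.shape)  # (2, 1024) for embed-english-v3.0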
@@ -203,7 +203,9 @@ class NvidiaRerank(Base):
             "top_n": len(texts),
         }
         res = requests.post(self.base_url, headers=self.headers, json=data).json()
-        return (np.array([d["logit"] for d in res["rankings"]]), token_count)
+        rank = np.array([d["logit"] for d in res["rankings"]])
+        indexs = [d["index"] for d in res["rankings"]]
+        return rank[indexs], token_count
 
 
 class LmStudioRerank(Base):
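This NvidiaRerank change accounts for the rankings coming back sorted by score rather than in input order: each ranking carries the index of the document it scores, and the scores are now re-indexed through those positions before being returned with the token count. The CoHereRerank wrapper below follows the same pattern.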
@@ -220,3 +222,26 @@ class OpenAI_APIRerank(Base):
 
     def similarity(self, query: str, texts: list):
         raise NotImplementedError("The api has not been implement")
+
+
+class CoHereRerank(Base):
+    def __init__(self, key, model_name, base_url=None):
+        from cohere import Client
+
+        self.client = Client(api_key=key)
+        self.model_name = model_name
+
+    def similarity(self, query: str, texts: list):
+        token_count = num_tokens_from_string(query) + sum(
+            [num_tokens_from_string(t) for t in texts]
+        )
+        res = self.client.rerank(
+            model=self.model_name,
+            query=query,
+            documents=texts,
+            top_n=len(texts),
+            return_documents=False,
+        )
+        rank = np.array([d.relevance_score for d in res.results])
+        indexs = [d.index for d in res.results]
+        return rank[indexs], token_count
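similarity returns one relevance score per candidate text, re-indexed by each result's index field as in the NvidiaRerank change above, plus a token count estimated locally with num_tokens_from_string, presumably because the rerank endpoint does not report token usage the way chat and embed do. A usage sketch (hypothetical key and documents):

reranker = CoHereRerank(key="YOUR_COHERE_API_KEY", model_name="rerank-english-v3.0")

# scores is a numpy array with one relevance score per input text
scores, used_tokens = reranker.similarity(
    query="what is retrieval-augmented generation?",
    texts=["RAG pairs retrieval with generation.", "Bananas are yellow."],
)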
@@ -7,6 +7,7 @@ botocore==1.34.140
 cachetools==5.3.3
 chardet==5.2.0
 cn2an==0.5.22
+cohere==5.6.2
 dashscope==1.14.1
 datrie==0.8.2
 demjson3==3.0.6
@@ -14,6 +14,7 @@ certifi==2024.7.4
 cffi==1.16.0
 charset-normalizer==3.3.2
 click==8.1.7
+cohere==5.6.2
 coloredlogs==15.0.1
 cryptography==42.0.5
 dashscope==1.14.1
@@ -14,6 +14,7 @@ certifi==2024.7.4
 cffi==1.16.0
 charset-normalizer==3.3.2
 click==8.1.7
+cohere==5.6.2
 coloredlogs==15.0.1
 cryptography==42.0.5
 dashscope==1.14.1
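The same cohere==5.6.2 pin is added to each of the three requirements files touched by this PR.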
web/src/assets/svg/llm/cohere.svg (new file, 1 line, 742 B)
@@ -0,0 +1 @@
+<svg xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg" xml:space="preserve" style="enable-background:new 0 0 75 75" viewBox="0 0 75 75" width="75" height="75" ><path d="M24.3 44.7c2 0 6-.1 11.6-2.4 6.5-2.7 19.3-7.5 28.6-12.5 6.5-3.5 9.3-8.1 9.3-14.3C73.8 7 66.9 0 58.3 0h-36C10 0 0 10 0 22.3s9.4 22.4 24.3 22.4z" style="fill-rule:evenodd;clip-rule:evenodd;fill:#39594d"/><path d="M30.4 60c0-6 3.6-11.5 9.2-13.8l11.3-4.7C62.4 36.8 75 45.2 75 57.6 75 67.2 67.2 75 57.6 75H45.3c-8.2 0-14.9-6.7-14.9-15z" style="fill-rule:evenodd;clip-rule:evenodd;fill:#d18ee2"/><path d="M12.9 47.6C5.8 47.6 0 53.4 0 60.5v1.7C0 69.2 5.8 75 12.9 75c7.1 0 12.9-5.8 12.9-12.9v-1.7c-.1-7-5.8-12.8-12.9-12.8z" style="fill:#ff7759"/></svg>
@@ -22,7 +22,8 @@ export const IconMap = {
   StepFun: 'stepfun',
   NVIDIA:'nvidia',
   'LM-Studio':'lm-studio',
-  'OpenAI-API-Compatible':'openai-api'
+  'OpenAI-API-Compatible':'openai-api',
+  'cohere':'cohere'
 };
 
 export const BedrockRegionList = [
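On the frontend, the IconMap entry ties the 'cohere' factory name to the new cohere.svg asset added above, so the provider shows its logo in the model-provider settings.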