Mirror of https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
Synced 2025-08-12 07:09:01 +08:00
Refactor embedding batch_size (#3825)
### What problem does this PR solve?

Refactor embedding batch_size. Close #3657

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
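Every provider touched by this diff gets the same treatment: `encode` drops its caller-supplied `batch_size` parameter, pins a provider-appropriate limit internally, slices the input into batches, and accumulates embeddings and token usage across the batched calls. Below is a minimal sketch of that pattern, not code from this commit; `ExampleEmbed` and `embed_batch` are hypothetical stand-ins for a provider class and its SDK call:

```python
import numpy as np


class ExampleEmbed:
    """Sketch of the batching pattern this commit applies per provider."""

    def encode(self, texts: list):
        batch_size = 16  # provider-specific limit, now fixed inside encode
        ress = []        # accumulated embedding vectors
        token_count = 0  # accumulated token usage across batches
        for i in range(0, len(texts), batch_size):
            vectors, used = self.embed_batch(texts[i:i + batch_size])
            ress.extend(vectors)
            token_count += used
        return np.array(ress), token_count

    def embed_batch(self, batch: list):
        # Hypothetical stand-in for the provider SDK call, e.g.
        # client.embeddings.create(input=batch, model=...).
        # Returns (list of vectors, tokens used).
        return [[0.0] * 8 for _ in batch], sum(len(t.split()) for t in batch)
```

Callers such as `LLMBundle.encode` and `Benchmark.embedding` can then pass the full text list and let each provider batch internally, which is what the hunks below do.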
This commit is contained in:
parent 934dbc2e2b
commit 92ab7ef659
@@ -232,13 +232,13 @@ class LLMBundle(object):
                 self.max_length = lm.max_tokens
                 break
 
-    def encode(self, texts: list, batch_size=32):
-        emd, used_tokens = self.mdl.encode(texts, batch_size)
+    def encode(self, texts: list):
+        embeddings, used_tokens = self.mdl.encode(texts)
         if not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, used_tokens):
             logging.error(
                 "LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
-        return emd, used_tokens
+        return embeddings, used_tokens
 
     def encode_queries(self, query: str):
         emd, used_tokens = self.mdl.encode_queries(query)
@@ -63,16 +63,13 @@ class Benchmark:
             run[query][c["chunk_id"]] = c["similarity"]
         return run
 
-    def embedding(self, docs, batch_size=16):
-        vects = []
-        cnts = [d["content_with_weight"] for d in docs]
-        for i in range(0, len(cnts), batch_size):
-            vts, c = self.embd_mdl.encode(cnts[i: i + batch_size])
-            vects.extend(vts.tolist())
-        assert len(docs) == len(vects)
+    def embedding(self, docs):
+        texts = [d["content_with_weight"] for d in docs]
+        embeddings, _ = self.embd_mdl.encode(texts)
+        assert len(docs) == len(embeddings)
         vector_size = 0
         for i, d in enumerate(docs):
-            v = vects[i]
+            v = embeddings[i]
             vector_size = len(v)
             d["q_%d_vec" % len(v)] = v
         return docs, vector_size
@@ -38,7 +38,7 @@ class Base(ABC):
     def __init__(self, key, model_name):
         pass
 
-    def encode(self, texts: list, batch_size=16):
+    def encode(self, texts: list):
         raise NotImplementedError("Please implement encode method!")
 
     def encode_queries(self, text: str):
@@ -78,15 +78,16 @@ class DefaultEmbedding(Base):
                     use_fp16=torch.cuda.is_available())
             self._model = DefaultEmbedding._model
 
-    def encode(self, texts: list, batch_size=16):
+    def encode(self, texts: list):
+        batch_size = 16
         texts = [truncate(t, 2048) for t in texts]
         token_count = 0
         for t in texts:
             token_count += num_tokens_from_string(t)
-        res = []
+        ress = []
         for i in range(0, len(texts), batch_size):
-            res.extend(self._model.encode(texts[i:i + batch_size]).tolist())
-        return np.array(res), token_count
+            ress.extend(self._model.encode(texts[i:i + batch_size]).tolist())
+        return np.array(ress), token_count
 
     def encode_queries(self, text: str):
         token_count = num_tokens_from_string(text)
@@ -101,12 +102,18 @@ class OpenAIEmbed(Base):
             self.client = OpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=16):
+    def encode(self, texts: list):
+        # OpenAI requires batch size <=16
+        batch_size = 16
         texts = [truncate(t, 8191) for t in texts]
-        res = self.client.embeddings.create(input=texts,
-                                            model=self.model_name)
-        return np.array([d.embedding for d in res.data]
-                        ), res.usage.total_tokens
+        ress = []
+        total_tokens = 0
+        for i in range(0, len(texts), batch_size):
+            res = self.client.embeddings.create(input=texts[i:i + batch_size],
+                                                model=self.model_name)
+            ress.extend([d.embedding for d in res.data])
+            total_tokens += res.usage.total_tokens
+        return np.array(ress), total_tokens
 
     def encode_queries(self, text):
         res = self.client.embeddings.create(input=[truncate(text, 8191)],
@@ -123,12 +130,14 @@ class LocalAIEmbed(Base):
         self.client = OpenAI(api_key="empty", base_url=base_url)
         self.model_name = model_name.split("___")[0]
 
-    def encode(self, texts: list, batch_size=16):
-        res = self.client.embeddings.create(input=texts, model=self.model_name)
-        return (
-            np.array([d.embedding for d in res.data]),
-            1024,
-        )  # local embedding for LmStudio donot count tokens
+    def encode(self, texts: list):
+        batch_size = 16
+        ress = []
+        for i in range(0, len(texts), batch_size):
+            res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
+            ress.extend([d.embedding for d in res.data])
+        # local embedding for LmStudio donot count tokens
+        return np.array(ress), 1024
 
     def encode_queries(self, text):
         embds, cnt = self.encode([text])
@@ -155,12 +164,12 @@ class BaiChuanEmbed(OpenAIEmbed):
 
 class QWenEmbed(Base):
     def __init__(self, key, model_name="text_embedding_v2", **kwargs):
-        dashscope.api_key = key
+        self.key = key
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=10):
+    def encode(self, texts: list):
         import dashscope
-        batch_size = min(batch_size, 4)
+        batch_size = 4
         try:
             res = []
             token_count = 0
@@ -169,6 +178,7 @@ class QWenEmbed(Base):
                 resp = dashscope.TextEmbedding.call(
                     model=self.model_name,
                     input=texts[i:i + batch_size],
+                    api_key=self.key,
                     text_type="document"
                 )
                 embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
@@ -186,6 +196,7 @@ class QWenEmbed(Base):
             resp = dashscope.TextEmbedding.call(
                 model=self.model_name,
                 input=text[:2048],
+                api_key=self.key,
                 text_type="query"
             )
             return np.array(resp["output"]["embeddings"][0]
@@ -200,7 +211,7 @@ class ZhipuEmbed(Base):
         self.client = ZhipuAI(api_key=key)
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=16):
+    def encode(self, texts: list):
         arr = []
         tks_num = 0
         for txt in texts:
@@ -221,7 +232,7 @@ class OllamaEmbed(Base):
         self.client = Client(host=kwargs["base_url"])
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=16):
+    def encode(self, texts: list):
         arr = []
         tks_num = 0
         for txt in texts:
@@ -252,13 +263,13 @@ class FastEmbed(Base):
         from fastembed import TextEmbedding
         self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
 
-    def encode(self, texts: list, batch_size=16):
+    def encode(self, texts: list):
         # Using the internal tokenizer to encode the texts and get the total
         # number of tokens
         encodings = self._model.model.tokenizer.encode_batch(texts)
         total_tokens = sum(len(e) for e in encodings)
 
-        embeddings = [e.tolist() for e in self._model.embed(texts, batch_size)]
+        embeddings = [e.tolist() for e in self._model.embed(texts, batch_size=16)]
 
         return np.array(embeddings), total_tokens
 
@@ -278,11 +289,15 @@ class XinferenceEmbed(Base):
         self.client = OpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=16):
-        res = self.client.embeddings.create(input=texts,
-                                            model=self.model_name)
-        return np.array([d.embedding for d in res.data]
-                        ), res.usage.total_tokens
+    def encode(self, texts: list):
+        batch_size = 16
+        ress = []
+        total_tokens = 0
+        for i in range(0, len(texts), batch_size):
+            res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
+            ress.extend([d.embedding for d in res.data])
+            total_tokens += res.usage.total_tokens
+        return np.array(ress), total_tokens
 
     def encode_queries(self, text):
         res = self.client.embeddings.create(input=[text],
@@ -306,7 +321,8 @@ class YoudaoEmbed(Base):
                 model_name_or_path=model_name.replace(
                     "maidalun1020", "InfiniFlow"))
 
-    def encode(self, texts: list, batch_size=10):
+    def encode(self, texts: list):
+        batch_size = 10
         res = []
         token_count = 0
         for t in texts:
@@ -332,15 +348,21 @@ class JinaEmbed(Base):
         }
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=None):
+    def encode(self, texts: list):
         texts = [truncate(t, 8196) for t in texts]
-        data = {
-            "model": self.model_name,
-            "input": texts,
-            'encoding_type': 'float'
-        }
-        res = requests.post(self.base_url, headers=self.headers, json=data).json()
-        return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"]
+        batch_size = 16
+        ress = []
+        token_count = 0
+        for i in range(0, len(texts), batch_size):
+            data = {
+                "model": self.model_name,
+                "input": texts[i:i + batch_size],
+                'encoding_type': 'float'
+            }
+            res = requests.post(self.base_url, headers=self.headers, json=data).json()
+            ress.extend([d["embedding"] for d in res["data"]])
+            token_count += res["usage"]["total_tokens"]
+        return np.array(ress), token_count
 
     def encode_queries(self, text):
         embds, cnt = self.encode([text])
@@ -394,12 +416,17 @@ class MistralEmbed(Base):
         self.client = MistralClient(api_key=key)
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=16):
+    def encode(self, texts: list):
         texts = [truncate(t, 8196) for t in texts]
-        res = self.client.embeddings(input=texts,
-                                     model=self.model_name)
-        return np.array([d.embedding for d in res.data]
-                        ), res.usage.total_tokens
+        batch_size = 16
+        ress = []
+        token_count = 0
+        for i in range(0, len(texts), batch_size):
+            res = self.client.embeddings(input=texts[i:i + batch_size],
+                                         model=self.model_name)
+            ress.extend([d.embedding for d in res.data])
+            token_count += res.usage.total_tokens
+        return np.array(ress), token_count
 
     def encode_queries(self, text):
         res = self.client.embeddings(input=[truncate(text, 8196)],
@@ -418,7 +445,7 @@ class BedrockEmbed(Base):
         self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
                                    aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
 
-    def encode(self, texts: list, batch_size=16):
+    def encode(self, texts: list):
         texts = [truncate(t, 8196) for t in texts]
         embeddings = []
         token_count = 0
@@ -436,7 +463,6 @@ class BedrockEmbed(Base):
         return np.array(embeddings), token_count
 
     def encode_queries(self, text):
-
         embeddings = []
         token_count = num_tokens_from_string(text)
         if self.model_name.split('.')[0] == 'amazon':
@@ -453,20 +479,26 @@ class BedrockEmbed(Base):
 class GeminiEmbed(Base):
     def __init__(self, key, model_name='models/text-embedding-004',
                  **kwargs):
-        genai.configure(api_key=key)
+        self.key = key
         self.model_name = 'models/' + model_name
 
-    def encode(self, texts: list, batch_size=16):
+    def encode(self, texts: list):
         texts = [truncate(t, 2048) for t in texts]
         token_count = sum(num_tokens_from_string(text) for text in texts)
-        result = genai.embed_content(
-            model=self.model_name,
-            content=texts,
-            task_type="retrieval_document",
-            title="Embedding of list of strings")
-        return np.array(result['embedding']),token_count
+        genai.configure(api_key=self.key)
+        batch_size = 16
+        ress = []
+        for i in range(0, len(texts), batch_size):
+            result = genai.embed_content(
+                model=self.model_name,
+                content=texts[i:i + batch_size],
+                task_type="retrieval_document",
+                title="Embedding of single string")
+            ress.extend(result['embedding'])
+        return np.array(ress), token_count
 
     def encode_queries(self, text):
+        genai.configure(api_key=self.key)
         result = genai.embed_content(
             model=self.model_name,
             content=truncate(text,2048),
@@ -495,19 +527,22 @@ class NvidiaEmbed(Base):
         if model_name == "snowflake/arctic-embed-l":
             self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings"
 
-    def encode(self, texts: list, batch_size=None):
-        payload = {
-            "input": texts,
-            "input_type": "query",
-            "model": self.model_name,
-            "encoding_format": "float",
-            "truncate": "END",
-        }
-        res = requests.post(self.base_url, headers=self.headers, json=payload).json()
-        return (
-            np.array([d["embedding"] for d in res["data"]]),
-            res["usage"]["total_tokens"],
-        )
+    def encode(self, texts: list):
+        batch_size = 16
+        ress = []
+        token_count = 0
+        for i in range(0, len(texts), batch_size):
+            payload = {
+                "input": texts[i : i + batch_size],
+                "input_type": "query",
+                "model": self.model_name,
+                "encoding_format": "float",
+                "truncate": "END",
+            }
+            res = requests.post(self.base_url, headers=self.headers, json=payload).json()
+            ress.extend([d["embedding"] for d in res["data"]])
+            token_count += res["usage"]["total_tokens"]
+        return np.array(ress), token_count
 
     def encode_queries(self, text):
         embds, cnt = self.encode([text])
@@ -541,16 +576,20 @@ class CoHereEmbed(Base):
         self.client = Client(api_key=key)
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=16):
-        res = self.client.embed(
-            texts=texts,
-            model=self.model_name,
-            input_type="search_query",
-            embedding_types=["float"],
-        )
-        return np.array([d for d in res.embeddings.float]), int(
-            res.meta.billed_units.input_tokens
-        )
+    def encode(self, texts: list):
+        batch_size = 16
+        ress = []
+        token_count = 0
+        for i in range(0, len(texts), batch_size):
+            res = self.client.embed(
+                texts=texts[i : i + batch_size],
+                model=self.model_name,
+                input_type="search_document",
+                embedding_types=["float"],
+            )
+            ress.extend([d for d in res.embeddings.float])
+            token_count += res.meta.billed_units.input_tokens
+        return np.array(ress), token_count
 
     def encode_queries(self, text):
         res = self.client.embed(
@@ -599,19 +638,23 @@ class SILICONFLOWEmbed(Base):
         self.base_url = base_url
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=16):
-        payload = {
-            "model": self.model_name,
-            "input": texts,
-            "encoding_format": "float",
-        }
-        res = requests.post(self.base_url, json=payload, headers=self.headers).json()
-        if "data" not in res or not isinstance(res["data"], list) or len(res["data"])!= len(texts):
-            raise ValueError(f"SILICONFLOWEmbed.encode got invalid response from {self.base_url}")
-        return (
-            np.array([d["embedding"] for d in res["data"]]),
-            res["usage"]["total_tokens"],
-        )
+    def encode(self, texts: list):
+        batch_size = 16
+        ress = []
+        token_count = 0
+        for i in range(0, len(texts), batch_size):
+            texts_batch = texts[i : i + batch_size]
+            payload = {
+                "model": self.model_name,
+                "input": texts_batch,
+                "encoding_format": "float",
+            }
+            res = requests.post(self.base_url, json=payload, headers=self.headers).json()
+            if "data" not in res or not isinstance(res["data"], list) or len(res["data"]) != len(texts_batch):
+                raise ValueError(f"SILICONFLOWEmbed.encode got invalid response from {self.base_url}")
+            ress.extend([d["embedding"] for d in res["data"]])
+            token_count += res["usage"]["total_tokens"]
+        return np.array(ress), token_count
 
     def encode_queries(self, text):
         payload = {
@@ -632,9 +675,14 @@ class ReplicateEmbed(Base):
         self.model_name = model_name
         self.client = Client(api_token=key)
 
-    def encode(self, texts: list, batch_size=16):
-        res = self.client.run(self.model_name, input={"texts": json.dumps(texts)})
-        return np.array(res), sum([num_tokens_from_string(text) for text in texts])
+    def encode(self, texts: list):
+        batch_size = 16
+        token_count = sum([num_tokens_from_string(text) for text in texts])
+        ress = []
+        for i in range(0, len(texts), batch_size):
+            res = self.client.run(self.model_name, input={"texts": texts[i : i + batch_size]})
+            ress.extend(res)
+        return np.array(ress), token_count
 
     def encode_queries(self, text):
         res = self.client.embed(self.model_name, input={"texts": [text]})
@@ -673,11 +721,17 @@ class VoyageEmbed(Base):
         self.client = voyageai.Client(api_key=key)
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=16):
-        res = self.client.embed(
-            texts=texts, model=self.model_name, input_type="document"
-        )
-        return np.array(res.embeddings), res.total_tokens
+    def encode(self, texts: list):
+        batch_size = 16
+        ress = []
+        token_count = 0
+        for i in range(0, len(texts), batch_size):
+            res = self.client.embed(
+                texts=texts[i : i + batch_size], model=self.model_name, input_type="document"
+            )
+            ress.extend(res.embeddings)
+            token_count += res.total_tokens
+        return np.array(ress), token_count
 
     def encode_queries(self, text):
         res = self.client.embed(
@@ -694,7 +748,7 @@ class HuggingFaceEmbed(Base):
         self.model_name = model_name
         self.base_url = base_url or "http://127.0.0.1:8080"
 
-    def encode(self, texts: list, batch_size=16):
+    def encode(self, texts: list):
         embeddings = []
         for text in texts:
             response = requests.post(