Refa: remove max tokens for image2txt models. (#6078)

### What problem does this PR solve?

#6063

### Type of change


- [x] Refactoring
Kevin Hu 2025-03-14 13:51:45 +08:00 committed by GitHub
parent 42eb99554f
commit 56b228f187

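Every hunk below applies the same refactor: stop forwarding a `max_tokens` cap to the underlying model call, so each provider's own default output limit applies instead. A minimal sketch of the post-refactor call shape for the OpenAI-compatible wrappers (class and attribute names are taken from the hunks; the method signature is paraphrased, not copied from the source):

```python
# Minimal sketch of the pattern after this commit, for the OpenAI-compatible
# wrappers (Base, GptV4, AzureGptV4, Zhipu4V, XinferenceCV): only sampling
# parameters are forwarded; no max_tokens cap is sent any more.
def chat(self, system, history, gen_conf):
    response = self.client.chat.completions.create(
        model=self.model_name,
        messages=history,
        temperature=gen_conf.get("temperature", 0.3),
        top_p=gen_conf.get("top_p", 0.7),
    )
    return response.choices[0].message.content, response.usage.total_tokens
```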

@@ -49,7 +49,6 @@ class Base(ABC):
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=history,
-            max_tokens=gen_conf.get("max_tokens", 1000),
             temperature=gen_conf.get("temperature", 0.3),
             top_p=gen_conf.get("top_p", 0.7)
         )
@@ -71,7 +70,6 @@ class Base(ABC):
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=history,
-            max_tokens=gen_conf.get("max_tokens", 1000),
             temperature=gen_conf.get("temperature", 0.3),
             top_p=gen_conf.get("top_p", 0.7),
             stream=True
@@ -157,8 +155,7 @@ class GptV4(Base):
         res = self.client.chat.completions.create(
             model=self.model_name,
-            messages=prompt,
-            max_tokens=max_tokens,
+            messages=prompt
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens
@@ -181,8 +178,7 @@ class AzureGptV4(Base):
         res = self.client.chat.completions.create(
             model=self.model_name,
-            messages=prompt,
-            max_tokens=max_tokens,
+            messages=prompt
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens
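Note that the `describe(self, image, max_tokens=...)` signatures (visible in the OllamaCV and GeminiCV hunks further down) keep their `max_tokens` parameter; it simply is no longer forwarded. A hypothetical caller, for illustration only:

```python
# Hypothetical caller: cv_model stands for any wrapper touched by this diff.
# The max_tokens argument is still accepted for compatibility, but after this
# commit it no longer caps the description length.
caption, token_count = cv_model.describe(image_b64, max_tokens=300)
print(caption)
```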
@ -241,7 +237,6 @@ class QWenCV(Base):
if his["role"] == "user":
his["content"] = self.chat_prompt(his["content"], image)
response = MultiModalConversation.call(model=self.model_name, messages=history,
max_tokens=gen_conf.get("max_tokens", 1000),
temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7))
@@ -271,7 +266,6 @@ class QWenCV(Base):
         tk_count = 0
         try:
             response = MultiModalConversation.call(model=self.model_name, messages=history,
-                                                   max_tokens=gen_conf.get("max_tokens", 1000),
                                                    temperature=gen_conf.get("temperature", 0.3),
                                                    top_p=gen_conf.get("top_p", 0.7),
                                                    stream=True)
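For reference, a sketch of the resulting DashScope call; the keyword arguments mirror the hunk above, while the model name is an assumption not taken from this diff:

```python
# Sketch of the QWenCV call after the change, using the dashscope SDK's
# MultiModalConversation API as in the hunk above.
from dashscope import MultiModalConversation

response = MultiModalConversation.call(
    model="qwen-vl-plus",  # assumed model name, for illustration only
    messages=history,
    temperature=gen_conf.get("temperature", 0.3),
    top_p=gen_conf.get("top_p", 0.7),
    # max_tokens is no longer passed; DashScope's default limit applies.
)
```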
@@ -306,8 +300,7 @@ class Zhipu4V(Base):
         res = self.client.chat.completions.create(
             model=self.model_name,
-            messages=prompt,
-            max_tokens=max_tokens,
+            messages=prompt
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens
@@ -322,7 +315,6 @@ class Zhipu4V(Base):
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=history,
-            max_tokens=gen_conf.get("max_tokens", 1000),
             temperature=gen_conf.get("temperature", 0.3),
             top_p=gen_conf.get("top_p", 0.7)
         )
@@ -344,7 +336,6 @@ class Zhipu4V(Base):
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=history,
-            max_tokens=gen_conf.get("max_tokens", 1000),
             temperature=gen_conf.get("temperature", 0.3),
             top_p=gen_conf.get("top_p", 0.7),
             stream=True
@@ -376,12 +367,10 @@ class OllamaCV(Base):
     def describe(self, image, max_tokens=1024):
         prompt = self.prompt("")
         try:
-            options = {"num_predict": max_tokens}
             response = self.client.generate(
                 model=self.model_name,
                 prompt=prompt[0]["content"][1]["text"],
-                images=[image],
-                options=options
+                images=[image]
             )
             ans = response["response"].strip()
             return ans, 128
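Ollama expresses an output cap as the `num_predict` option, which is exactly what this hunk stops building. A caller that still wants a hard cap with Ollama can pass it explicitly; a hypothetical sketch using the ollama Python client:

```python
# Hypothetical caller-side cap with the ollama Python client; after this
# change the wrapper itself no longer sets num_predict.
from ollama import Client

client = Client(host="http://localhost:11434")  # assumed local server
response = client.generate(
    model="llava",                  # any vision-capable model tag
    prompt="Describe this picture.",
    images=[image_b64],             # base64-encoded image, as in the wrapper
    options={"num_predict": 1024},  # Ollama's own max-output-tokens knob
)
print(response["response"].strip())
```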
@@ -399,8 +388,6 @@ class OllamaCV(Base):
             options = {}
             if "temperature" in gen_conf:
                 options["temperature"] = gen_conf["temperature"]
-            if "max_tokens" in gen_conf:
-                options["num_predict"] = gen_conf["max_tokens"]
             if "top_p" in gen_conf:
                 options["top_k"] = gen_conf["top_p"]
             if "presence_penalty" in gen_conf:
@@ -429,8 +416,6 @@ class OllamaCV(Base):
             options = {}
             if "temperature" in gen_conf:
                 options["temperature"] = gen_conf["temperature"]
-            if "max_tokens" in gen_conf:
-                options["num_predict"] = gen_conf["max_tokens"]
             if "top_p" in gen_conf:
                 options["top_k"] = gen_conf["top_p"]
             if "presence_penalty" in gen_conf:
@@ -480,8 +465,7 @@ class XinferenceCV(Base):
         res = self.client.chat.completions.create(
             model=self.model_name,
-            messages=self.prompt(b64),
-            max_tokens=max_tokens,
+            messages=self.prompt(b64)
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens
@@ -497,15 +481,13 @@ class GeminiCV(Base):
     def describe(self, image, max_tokens=2048):
         from PIL.Image import open
-        gen_config = {'max_output_tokens':max_tokens}
         prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
             "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
         b64 = self.image2base64(image)
         img = open(BytesIO(base64.b64decode(b64)))
         input = [prompt,img]
         res = self.model.generate_content(
-            input,
-            generation_config=gen_config,
+            input
         )
         return res.text,res.usage_metadata.total_token_count
@@ -525,7 +507,7 @@ class GeminiCV(Base):
             history[-1]["parts"].append("data:image/jpeg;base64," + image)
         response = self.model.generate_content(history, generation_config=GenerationConfig(
-            max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
+            temperature=gen_conf.get("temperature", 0.3),
             top_p=gen_conf.get("top_p", 0.7)))
         ans = response.text
@@ -551,7 +533,7 @@ class GeminiCV(Base):
             history[-1]["parts"].append("data:image/jpeg;base64," + image)
         response = self.model.generate_content(history, generation_config=GenerationConfig(
-            max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
+            temperature=gen_conf.get("temperature", 0.3),
             top_p=gen_conf.get("top_p", 0.7)), stream=True)
         for resp in response:
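In the Gemini wrapper the cap was carried by `GenerationConfig.max_output_tokens`; after these hunks only the sampling parameters remain. A sketch with the google-generativeai SDK implied by `GenerationConfig` above:

```python
# Sketch of the GeminiCV chat call after the change; `model` stands in for
# the wrapper's google-generativeai GenerativeModel instance.
from google.generativeai.types import GenerationConfig

response = model.generate_content(
    history,
    generation_config=GenerationConfig(
        temperature=gen_conf.get("temperature", 0.3),
        top_p=gen_conf.get("top_p", 0.7),
        # max_output_tokens is intentionally no longer set here.
    ),
)
```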
@@ -618,8 +600,7 @@ class NvidiaCV(Base):
                 "Authorization": f"Bearer {self.key}",
             },
             json={
-                "messages": self.prompt(b64),
-                "max_tokens": max_tokens,
+                "messages": self.prompt(b64)
             },
         )
         response = response.json()
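NvidiaCV calls the HTTP endpoint directly, so here the cap was a plain `"max_tokens"` field in the JSON body. A hedged sketch of the resulting request; the invoke URL is a placeholder assumption, since this diff does not show the endpoint:

```python
# Sketch of the NvidiaCV request after the change; the URL below is a
# placeholder, and prompt_messages stands for the self.prompt(b64) payload.
import requests

response = requests.post(
    "https://ai.api.nvidia.com/v1/vlm/...",  # placeholder invoke URL
    headers={"Authorization": f"Bearer {api_key}"},
    json={"messages": prompt_messages},      # no "max_tokens" field any more
)
print(response.json())
```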