update default sam_prompt structure to include point data

Signed-off-by: ふぁ <yuki@yuki0311.com>
This commit is contained in:
ふぁ 2024-11-29 13:21:10 +09:00
parent 00ad95d1c0
commit 9e6c46184d
No known key found for this signature in database
GPG Key ID: 83A8A5E74872A8AA

View File

@ -143,7 +143,16 @@ class SamSession(BaseSession):
Returns:
List[PILImage]: A list of masks generated by the decoder.
"""
prompt = kwargs.get("sam_prompt", "{}")
prompt = kwargs.get(
"sam_prompt",
[
{
"type": "point",
"label": 1,
"data": [int(img.width / 2), int(img.height / 2)],
}
],
)
schema = {
"type": "array",
"items": {