From 9e6c46184d645b12e09de41f62123f2e9df89c7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=81=B5=E3=81=81?= Date: Fri, 29 Nov 2024 13:21:10 +0900 Subject: [PATCH] update default sam_prompt structure to include point data MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: ふぁ --- rembg/sessions/sam.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/rembg/sessions/sam.py b/rembg/sessions/sam.py index 8b358a4..b0c0221 100644 --- a/rembg/sessions/sam.py +++ b/rembg/sessions/sam.py @@ -143,7 +143,16 @@ class SamSession(BaseSession): Returns: List[PILImage]: A list of masks generated by the decoder. """ - prompt = kwargs.get("sam_prompt", "{}") + prompt = kwargs.get( + "sam_prompt", + [ + { + "type": "point", + "label": 1, + "data": [int(img.width / 2), int(img.height / 2)], + } + ], + ) schema = { "type": "array", "items": {