diff --git a/rembg/sessions/sam.py b/rembg/sessions/sam.py index 8b358a4..b0c0221 100644 --- a/rembg/sessions/sam.py +++ b/rembg/sessions/sam.py @@ -143,7 +143,16 @@ class SamSession(BaseSession): Returns: List[PILImage]: A list of masks generated by the decoder. """ - prompt = kwargs.get("sam_prompt", "{}") + prompt = kwargs.get( + "sam_prompt", + [ + { + "type": "point", + "label": 1, + "data": [int(img.width / 2), int(img.height / 2)], + } + ], + ) schema = { "type": "array", "items": {