| 上一掩码 | 低分辨率掩码输入 | 迭代优化 |
交互式分割
点提示
# 单个前景点
input_point = np.array([[500, 375]])
input_label = np.array([1])
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True
)
# 多个点(前景 + 背景)
input_points = np.array([[500, 375], [600, 400], [450, 300]])
input_labels = np.array([1, 1, 0]) # 2 个前景, 1 个背景
masks, scores, logits = predictor.predict(
point_coords=input_points,
point_labels=input_labels,
multimask_output=False # 提示清晰时返回单掩码
)
框提示
# 边界框 [x1, y1, x2, y2]
input_box = np.array([425, 600, 700, 875])
masks, scores, logits = predictor.predict(
box=input_box,
multimask_output=False
)
组合提示
# 框 + 点以精确控制
masks, scores, logits = predictor.predict(
point_coords=np.array([[500, 375]]),
point_labels=np.array([1]),
box=np.array([400, 300, 700, 600]),
multimask_output=False
)
迭代优化
# 初始预测
masks, scores, logits = predictor.predict(
point_coords=np.array([[500, 375]]),
point_labels=np.array([1]),
multimask_output=True
)
# 使用上一掩码通过附加点优化
masks, scores, logits = predictor.predict(
point_coords=np.array([[500, 375], [550, 400]]),
point_labels=np.array([1, 0]), # 添加背景点
mask_input=logits[np.argmax(scores)][None, :, :], # 使用最佳掩码
multimask_output=False
)
自动掩码生成
基础自动分割
from segment_anything import SamAutomaticMaskGenerator
# 创建生成器
mask_generator = SamAutomaticMaskGenerator(sam)
# 生成所有掩码
masks = mask_generator.generate(image)
# 每个掩码包含:
# - segmentation: 二值掩码
# - bbox: [x, y, w, h]
# - area: 像素数量
# - predicted_iou: 质量分数
# - stability_score: 鲁棒性分数
# - point_coords: 生成点
自定义生成
mask_generator = SamAutomaticMaskGenerator(
model=sam,
points_per_side=32, # 网格密度(更多 = 更多掩码)
pred_iou_thresh=0.88, # 质量阈值
stability_score_thresh=0.95, # 稳定性阈值
crop_n_layers=1, # 多尺度裁剪
crop_n_points_downscale_factor=2,
min_mask_region_area=100, # 移除微小掩码
)
masks = mask_generator.generate(image)
过滤掩码
# 按面积排序(最大优先)
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
# 按预测 IoU 过滤
high_quality = [m for m in masks if m['predicted_iou'] > 0.9]
# 按稳定性分数过滤
stable_masks = [m for m in masks if m['stability_score'] > 0.95]
批量推理
多张图像
# 高效处理多张图像
images = [cv2.imread(f"image_{i}.jpg") for i in range(10)]
all_masks = []
for image in images:
predictor.set_image(image)
masks, _, _ = predictor.predict(
point_coords=np.array([[500, 375]]),
point_labels=np.array([1]),
multimask_output=True
)
all_masks.append(masks)
单图多提示
# 高效处理多个提示(一次图像编码)
predictor.set_image(image)
# 点提示批次
points = [
np.array([[100, 100]]),
np.array([[200, 200]]),
np.array([[300, 300]])
]
all_masks = []
for point in points:
masks, scores, _ = predictor.predict(
point_coords=point,
point_labels=np.array([1]),
multimask_output=True
)
all_masks.append(masks[np.argmax(scores)])
ONNX 部署
导出模型
python scripts/export_onnx_model.py \
--checkpoint sam_vit_h_4b8939.pth \
--model-type vit_h \
--output sam_onnx.onnx \
--return-single-mask
使用 ONNX 模型
import onnxruntime
# 加载 ONNX 模型
ort_session = onnxruntime.InferenceSession("sam_onnx.onnx")
# 运行推理(图像嵌入需单独计算)
masks = ort_session.run(
None,
{
"image_embeddings": image_embeddings,
"point_coords": point_coords,
"point_labels": point_labels,
"mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
"has_mask_input": np.array([0], dtype=np.float32),
"orig_im_size": np.array([h, w], dtype=np.float32)
}
)
常见工作流
工作流 1:标注工具
import cv2
# 加载模型
predictor = SamPredictor(sam)
predictor.set_image(image)
def on_click(event, x, y, flags, param):
if event == cv2.EVENT_LBUTTONDOWN:
# 前景点
masks, scores, _ = predictor.predict(
point_coords=np.array([[x, y]]),
point_labels=np.array([1]),
multimask_output=True
)
# 显示最佳掩码
display_mask(masks[np.argmax(scores)])
工作流 2:物体提取
def extract_object(image, point):
"""用透明背景提取指定点的物体。"""
predictor.set_image(image)
masks, scores, _ = predictor.predict(
point_coords=np.array([point]),
point_labels=np.array([1]),
multimask_output=True
)
best_mask = masks[np.argmax(scores)]
# 创建 RGBA 输出
rgba = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
rgba[:, :, :3] = image
rgba[:, :, 3] = best_mask * 255
return rgba
工作流 3:医学图像分割
# 处理医学图像(灰度转 RGB)
medical_image = cv2.imread("scan.png", cv2.IMREAD_GRAYSCALE)
rgb_image = cv2.cvtColor(medical_image, cv2.COLOR_GRAY2RGB)
predictor.set_image(rgb_image)
# 分割感兴趣区域
masks, scores, _ = predictor.predict(
box=np.array([x1, y1, x2, y2]), # ROI 边界框
multimask_output=True
)
输出格式
掩码数据结构
# SamAutomaticMaskGenerator 输出
{
"segmentation": np.ndarray, # H×W 二值掩码
"bbox": [x, y, w, h], # 边界框
"area": int, # 像素数量
"predicted_iou": float, # 0-1 质量分数
"stability_score": float, # 0-1 鲁棒性分数
"crop_box": [x, y, w, h], # 生成裁剪区域
"point_coords": [[x, y]], # 输入点
}
COCO RLE 格式
from pycocotools import mask as mask_utils
# 编码掩码为 RLE
rle = mask_utils.encode(np.asfortranarray(mask.astype(np.uint8)))
rle["counts"] = rle["counts"].decode("utf-8")
# 解码 RLE 为掩码
decoded_mask = mask_utils.decode(rle)
性能优化
GPU 内存
# VRAM 受限时使用较小模型
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
# 批量处理图像
# 大批次间清空 CUDA 缓存
torch.cuda.empty_cache()
速度优化
# 使用半精度
sam = sam.half()
# 减少自动生成的点数
mask_generator = SamAutomaticMaskGenerator(
model=sam,
points_per_side=16, # 默认是 32
)
# 使用 ONNX 部署
# 导出时加 --return-single-mask 加速推理
常见问题
| 问题 | 解决方案 |
| 内存不足 | 使用 ViT-B 模型,减小图像尺寸 |
| 推理慢 | 使用 ViT-B,减少 points_per_side |
| 掩码质量差 | 尝试不同提示,使用框 + 点组合 |
| 边缘伪影 | 使用稳定性分数过滤 |
| 小物体遗漏 | 增加 points_per_side |
参考资料
资源链接
- GitHub: https://github.com/facebookresearch/segment-anything
- 论文: https://arxiv.org/abs/2304.02643
- 在线演示: https://segment-anything.com
- SAM 2(视频): https://github.com/facebookresearch/segment-anything-2
- HuggingFace: https://huggingface.co/facebook/sam-vit-huge