# Source code for rofunc.utils.visualab.segment.sam_seg

import os

import numpy as np
import pip
from matplotlib import pyplot as plt

from rofunc.learning.pre_trained_models.download import download_ckpt
from rofunc.utils.logger.beauty_logger import beauty_print
from rofunc.utils.oslab import get_rofunc_path
from rofunc.utils.visualab.image import show_anns, show_mask, show_star_points_w_labels
from rofunc.utils.visualab.interact import mouse_click_coords, mouse_select_rec_region


def sam_generate(image, sam_checkpoint="sam_vit_h_4b8939.pth", model_type="vit_h", device="cuda:0"):
    """
    Run SAM's automatic mask generator on an image and display the resulting masks.

    :param image: image to segment; it is handed to ``SamAutomaticMaskGenerator.generate`` and ``plt.imshow``
    :param sam_checkpoint: checkpoint file name, looked up under ``learning/pre_trained_models`` inside
        ``get_rofunc_path()`` and fetched with ``download_ckpt`` when it does not exist there
    :param model_type: key into ``sam_model_registry`` selecting which SAM model to build
    :param device: device the model is moved to; the default ``"cuda:0"`` is the value that used to be hard-coded
    :return: None, the masks are only plotted
    """
    try:
        from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
    except ImportError:
        # ``pip.main`` no longer exists in pip >= 10, so pip is run as a
        # subprocess of the current interpreter instead.
        import importlib
        import subprocess
        import sys

        print("segment-anything is not installed. Install it automatically...")
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install",
             "git+https://github.com/facebookresearch/segment-anything.git"]
        )
        # Let the running import system notice the freshly installed package.
        importlib.invalidate_caches()
        from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

    ckpt_path = os.path.join(get_rofunc_path(), "learning/pre_trained_models", sam_checkpoint)
    if not os.path.exists(ckpt_path):
        download_ckpt(f"https://dl.fbaipublicfiles.com/segment_anything/{sam_checkpoint}", sam_checkpoint)

    sam = sam_model_registry[model_type](checkpoint=ckpt_path).to(device=device)
    mask_generator = SamAutomaticMaskGenerator(
        sam,
        points_per_side=32,
        pred_iou_thresh=0.86,
        stability_score_thresh=0.92,
        crop_n_layers=1,
        crop_n_points_downscale_factor=2,
        min_mask_region_area=100,  # Requires open-cv to run post-processing
    )
    masks = mask_generator.generate(image)

    plt.figure(figsize=(20, 20))
    plt.imshow(image)
    show_anns(masks)
    plt.axis('off')
    plt.show()
def sam_predict(image, use_point=False, use_box=False, choose_best_mask=True, sam_checkpoint="sam_vit_h_4b8939.pth",
                model_type="vit_h", device="cuda:0"):
    """
    Use mouse to select points or a box, and segment the object with prompt in the image.

    Exactly one of ``use_point`` / ``use_box`` must be True. The predicted mask(s) are drawn over the
    image with matplotlib; nothing is returned.

    :param image: image to segment; it is handed to ``SamPredictor.set_image`` and ``plt.imshow``
    :param use_point: whether to use pos/neg points to segment
    :param use_box: whether to use a box to segment
    :param choose_best_mask: whether to only choose the best mask; when False, ``multimask_output`` is
        enabled and every returned mask is shown with its score
    :param sam_checkpoint: checkpoint file name, looked up under ``learning/pre_trained_models`` inside
        ``get_rofunc_path()`` and fetched with ``download_ckpt`` when it does not exist there
    :param model_type: sam model type (key into ``sam_model_registry``)
    :param device: device the model is moved to; the default ``"cuda:0"`` is the value that used to be hard-coded
    :return: None
    :raises AssertionError: if ``use_point`` and ``use_box`` are both True or both False
    """
    # Validate before anything expensive (package install, checkpoint download,
    # model load). Kept as an ``assert`` so the exception type callers may rely
    # on is unchanged; the ``ValueError`` further down still covers ``python -O``.
    assert use_point != use_box, "Either use_point or use_box should be True"

    try:
        from segment_anything import SamPredictor, sam_model_registry
    except ImportError:
        # ``pip.main`` no longer exists in pip >= 10, so pip is run as a
        # subprocess of the current interpreter instead.
        import importlib
        import subprocess
        import sys

        print("segment-anything is not installed. Install it automatically...")
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install",
             "git+https://github.com/facebookresearch/segment-anything.git"]
        )
        # Let the running import system notice the freshly installed package.
        importlib.invalidate_caches()
        from segment_anything import SamPredictor, sam_model_registry

    ckpt_path = os.path.join(get_rofunc_path(), "learning/pre_trained_models", sam_checkpoint)
    if not os.path.exists(ckpt_path):
        download_ckpt(f"https://dl.fbaipublicfiles.com/segment_anything/{sam_checkpoint}", sam_checkpoint)

    beauty_print("Segment with prompt", type="module")

    sam = sam_model_registry[model_type](checkpoint=ckpt_path).to(device=device)
    predictor = SamPredictor(sam)
    predictor.set_image(image)

    if use_point and not use_box:
        # One interactive figure per click kind: positive first, then negative.
        clicked = []
        for click_kind in ("positive", "negative"):
            fig = plt.figure(figsize=(10, 10))
            ax = fig.add_subplot(111)
            plt.imshow(image)
            # NOTE(review): assumes mouse_click_coords returns a sequence of
            # (x, y) pairs -- confirm. Reshaping to (-1, 2) keeps an empty
            # selection concatenable with a non-empty one.
            clicked.append(np.asarray(mouse_click_coords(fig, ax, click_kind)).reshape(-1, 2))
            plt.axis('off')
            plt.show()
        pos_input_point, neg_input_point = clicked

        input_point = np.concatenate([pos_input_point, neg_input_point], axis=0)
        # Label 1 marks positive clicks, label 0 marks negative clicks.
        input_label = np.concatenate([np.ones(len(pos_input_point)), np.zeros(len(neg_input_point))], axis=0)

        masks, scores, _ = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=not choose_best_mask,
        )
    elif use_box and not use_point:
        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111)
        plt.imshow(image)
        input_box = mouse_select_rec_region(fig, ax)
        plt.axis('off')
        plt.show()

        masks, scores, _ = predictor.predict(
            box=np.array(input_box),
            multimask_output=not choose_best_mask,
        )
    else:
        # Only reachable when assertions are disabled (``python -O``).
        raise ValueError("Either use_point or use_box should be True")

    # Build the list of (mask, title) pairs to draw; the single best mask is
    # shown without a title, as before.
    if choose_best_mask:
        to_show = [(masks[0], None)]
    else:
        to_show = [
            (mask, f"Mask {i + 1}, Score: {score:.3f}")
            for i, (mask, score) in enumerate(zip(masks, scores))
        ]

    for mask, title in to_show:
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca())
        if use_point:
            show_star_points_w_labels(input_point, input_label, plt.gca())
        if title is not None:
            plt.title(title, fontsize=18)
        plt.axis('off')
        plt.show()