Source code for rofunc.utils.visualab.segment.vlpart_sam_seg

import os

import matplotlib.pyplot as plt
import numpy as np
import torch

from rofunc.learning.pre_trained_models.download import download_ckpt
from rofunc.utils.oslab import get_rofunc_path
from rofunc.utils.visualab.segment.vlpart.vlpart import build_vlpart


[docs]def vlpart_sam_predict(image, text_prompt, vlpart_checkpoint="swinbase_part_0a0000.pth", sam_checkpoint="sam_vit_h_4b8939.pth", box_threshold=0.3, text_threshold=0.25, device="cuda" ): import detectron2.data.transforms as T from segment_anything import build_sam, SamPredictor from segment_anything.utils.amg import remove_small_regions vlpart_ckpt_path = os.path.join(get_rofunc_path(), "learning/pre_trained_models", vlpart_checkpoint) sam_ckpt_path = os.path.join(get_rofunc_path(), "learning/pre_trained_models", sam_checkpoint) if not os.path.exists(vlpart_ckpt_path): download_ckpt( f"https://github.com/Cheems-Seminar/grounded-segment-any-parts/releases/download/v1.0/{vlpart_checkpoint}", vlpart_checkpoint) if not os.path.exists(sam_ckpt_path): download_ckpt(f"https://dl.fbaipublicfiles.com/segment_anything/{sam_checkpoint}", sam_checkpoint) # initialize VLPart vlpart = build_vlpart(checkpoint=vlpart_ckpt_path).to(device=device) # initialize SAM sam_predictor = SamPredictor(build_sam(checkpoint=sam_ckpt_path).to(device=device)) original_image = image # vlpart model inference preprocess = T.ResizeShortestEdge([800, 800], 1333) height, width = original_image.shape[:2] image = preprocess.get_transform(original_image).apply_image(original_image) image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)) inputs = {"image": image, "height": height, "width": width} with torch.no_grad(): predictions = vlpart.inference([inputs], text_prompt=text_prompt)[0] boxes, masks = None, None filter_scores, filter_boxes, filter_classes = [], [], [] if "instances" in predictions: instances = predictions['instances'].to('cpu') boxes = instances.pred_boxes.tensor if instances.has("pred_boxes") else None scores = instances.scores if instances.has("scores") else None classes = instances.pred_classes.tolist() if instances.has("pred_classes") else None num_obj = len(scores) for obj_ind in range(num_obj): category_score = scores[obj_ind] if category_score < 0.7: continue filter_scores.append(category_score) filter_boxes.append(boxes[obj_ind]) filter_classes.append(classes[obj_ind]) if len(filter_boxes) > 0: # sam model inference sam_predictor.set_image(original_image) boxes_filter = torch.stack(filter_boxes) transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filter, original_image.shape[:2]) masks, _, _ = sam_predictor.predict_torch( point_coords=None, point_labels=None, boxes=transformed_boxes.to(device), multimask_output=False, ) # remove small disconnected regions and holes fine_masks = [] for mask in masks.to('cpu').numpy(): # masks: [num_masks, 1, h, w] fine_masks.append(remove_small_regions(mask[0], 400, mode="holes")[0]) masks = np.stack(fine_masks, axis=0)[:, np.newaxis] masks = torch.from_numpy(masks) # draw output image plt.figure(figsize=(10, 10)) plt.imshow(original_image) if len(filter_boxes) > 0: show_predictions_with_masks(filter_scores, filter_boxes, filter_classes, masks.to('cpu'), text_prompt) plt.axis('off') plt.show() return masks
[docs]def show_predictions_with_masks(scores, boxes, classes, masks, text_prompt): num_obj = len(scores) if num_obj == 0: return text_prompts = text_prompt.split('.') ax = plt.gca() ax.set_autoscale_on(False) colors = plt.cm.gist_rainbow(np.linspace(0, 1, num_obj)) for obj_ind in range(num_obj): box = boxes[obj_ind] score = scores[obj_ind] name = text_prompts[classes[obj_ind]] if score < 0.5: continue # color_mask = np.random.random((1, 3)).tolist()[0] color_mask = colors[obj_ind] m = masks[obj_ind][0] img = np.ones((m.shape[0], m.shape[1], 3)) for i in range(3): img[:, :, i] = color_mask[i] ax.imshow(np.dstack((img, m * 0.45))) x0, y0, w, h = box[0], box[1], box[2] - box[0], box[3] - box[1] ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=color_mask, facecolor=(0, 0, 0, 0), lw=2)) label = name + ': {:.2}'.format(score) ax.text(x0, y0, label, color=color_mask, fontsize='large', fontfamily='sans-serif')