diff --git a/minsu3d/model/hais.py b/minsu3d/model/hais.py
index 582c0eb..e667861 100644
--- a/minsu3d/model/hais.py
+++ b/minsu3d/model/hais.py
@@ -1,6 +1,7 @@
 import torch
 import time
 import torch.nn as nn
+import numpy as np
 from minsu3d.evaluation.instance_segmentation import get_gt_instances, rle_encode
 from minsu3d.evaluation.object_detection import get_gt_bbox
 from minsu3d.common_ops.functions import hais_ops, common_ops
@@ -9,7 +10,6 @@
 from minsu3d.model.module import TinyUnet
 from minsu3d.evaluation.semantic_segmentation import *
 from minsu3d.model.general_model import GeneralModel, clusters_voxelization, get_batch_offsets
-from minsu3d.util.nms import get_nms_instance
 
 
 class HAIS(GeneralModel):
@@ -229,15 +229,8 @@ def _get_pred_instances(self, scan_id, gt_xyz, scores, proposals_idx, num_propos
         scores_pred = scores_pred[npoint_mask]
         proposals_pred = proposals_pred[npoint_mask]
 
-        if scores_pred.shape[0] == 0:
-            pick_idxs = np.empty(0)
-        elif self.hparams.inference.TEST_NMS_THRESH >= 1:
-            pick_idxs = list(range(0, scores_pred.shape[0]))
-        else:
-            pick_idxs = get_nms_instance(proposals_pred.float(), scores_pred.numpy(), self.hparams.inference.TEST_NMS_THRESH)
-
-        clusters = proposals_pred[pick_idxs].numpy()
-        cluster_scores = scores_pred[pick_idxs].numpy()
+        clusters = proposals_pred.numpy()
+        cluster_scores = scores_pred.numpy()
 
         nclusters = clusters.shape[0]
 
diff --git a/minsu3d/model/pointgroup.py b/minsu3d/model/pointgroup.py
index b5c6d00..09fb826 100755
--- a/minsu3d/model/pointgroup.py
+++ b/minsu3d/model/pointgroup.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import time
+import numpy as np
 from minsu3d.evaluation.instance_segmentation import get_gt_instances, rle_encode
 from minsu3d.evaluation.object_detection import get_gt_bbox
 from minsu3d.common_ops.functions import pointgroup_ops, common_ops
@@ -9,7 +10,6 @@
 from minsu3d.model.module import TinyUnet
 from minsu3d.evaluation.semantic_segmentation import *
 from minsu3d.model.general_model import GeneralModel, clusters_voxelization, get_batch_offsets
-from minsu3d.util.nms import get_nms_instance
 
 
 class PointGroup(GeneralModel):
@@ -44,15 +44,25 @@ def forward(self, data_dict):
             semantic_preds_cpu = semantic_preds[object_idxs].cpu().int()
             object_idxs_cpu = object_idxs.cpu()
 
-            idx_shift, start_len_shift = common_ops.ballquery_batch_p(coords_ + pt_offsets_, batch_idxs_, batch_offsets_, self.hparams.model.cluster.cluster_radius, self.hparams.model.cluster.cluster_shift_meanActive)
-            proposals_idx_shift, proposals_offset_shift = pointgroup_ops.pg_bfs_cluster(semantic_preds_cpu, idx_shift.cpu(), start_len_shift.cpu(), self.hparams.model.cluster.cluster_npoint_thre)
+            idx_shift, start_len_shift = common_ops.ballquery_batch_p(coords_ + pt_offsets_, batch_idxs_,
+                                                                      batch_offsets_,
+                                                                      self.hparams.model.cluster.cluster_radius,
+                                                                      self.hparams.model.cluster.cluster_shift_meanActive)
+            proposals_idx_shift, proposals_offset_shift = pointgroup_ops.pg_bfs_cluster(semantic_preds_cpu,
+                                                                                        idx_shift.cpu(),
+                                                                                        start_len_shift.cpu(),
+                                                                                        self.hparams.model.cluster.cluster_npoint_thre)
             proposals_idx_shift[:, 1] = object_idxs_cpu[proposals_idx_shift[:, 1].long()].int()
             # proposals_idx_shift: (sumNPoint, 2), int, dim 0 for cluster_id, dim 1 for corresponding point idxs in N
             # proposals_offset_shift: (nProposal + 1), int
             # proposals_batchId_shift_all: (sumNPoint,) batch id
 
-            idx, start_len = common_ops.ballquery_batch_p(coords_, batch_idxs_, batch_offsets_, self.hparams.model.cluster.cluster_radius, self.hparams.model.cluster.cluster_meanActive)
-            proposals_idx, proposals_offset = pointgroup_ops.pg_bfs_cluster(semantic_preds_cpu, idx.cpu(), start_len.cpu(), self.hparams.model.cluster.cluster_npoint_thre)
+            idx, start_len = common_ops.ballquery_batch_p(coords_, batch_idxs_, batch_offsets_,
+                                                          self.hparams.model.cluster.cluster_radius,
+                                                          self.hparams.model.cluster.cluster_meanActive)
+            proposals_idx, proposals_offset = pointgroup_ops.pg_bfs_cluster(semantic_preds_cpu, idx.cpu(),
+                                                                            start_len.cpu(),
+                                                                            self.hparams.model.cluster.cluster_npoint_thre)
             proposals_idx[:, 1] = object_idxs_cpu[proposals_idx[:, 1].long()].int()
             # proposals_idx: (sumNPoint, 2), int, dim 0 for cluster_id, dim 1 for corresponding point idxs in N
             # proposals_offset: (nProposal + 1), int
@@ -82,7 +92,7 @@ def forward(self, data_dict):
         proposals_score_feats = common_ops.roipool(pt_score_feats, proposals_offset)  # (nProposal, C)
         scores = self.score_branch(proposals_score_feats)  # (nProposal, 1)
         output_dict["proposal_scores"] = (scores, proposals_idx, proposals_offset)
-        
+
         return output_dict
 
     def _loss(self, data_dict, output_dict):
@@ -100,7 +110,8 @@ def _loss(self, data_dict, output_dict):
             # proposals_idx: (sumNPoint, 2), int, cpu, dim 0 for cluster_id, dim 1 for corresponding point idxs in N
             # proposals_offset: (nProposal + 1), int, cpu
             # instance_pointnum: (total_nInst), int
-            ious = common_ops.get_iou(proposals_idx[:, 1].cuda(), proposals_offset, data_dict["instance_ids"], instance_pointnum)  # (nProposal, nInstance), float
+            ious = common_ops.get_iou(proposals_idx[:, 1].cuda(), proposals_offset, data_dict["instance_ids"],
+                                      instance_pointnum)  # (nProposal, nInstance), float
             gt_ious, gt_instance_idxs = ious.max(1)  # (nProposal) float, long
             gt_scores = get_segmented_scores(gt_ious, self.hparams.model.fg_thresh, self.hparams.model.bg_thresh)
             score_criterion = ScoreLoss()
@@ -115,7 +126,8 @@ def validation_step(self, data_dict, idx):
         losses, total_loss = self._loss(data_dict, output_dict)
 
         # log losses
-        self.log("val/total_loss", total_loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True, batch_size=1)
+        self.log("val/total_loss", total_loss, prog_bar=True, on_step=False,
+                 on_epoch=True, sync_dist=True, batch_size=1)
         for key, value in losses.items():
             self.log(f"val/{key}", value, on_step=False, on_epoch=True, sync_dist=True, batch_size=1)
 
@@ -125,8 +137,10 @@ def validation_step(self, data_dict, idx):
                                                  ignore_label=self.hparams.data.ignore_label)
         semantic_mean_iou = evaluate_semantic_miou(semantic_predictions, data_dict["sem_labels"].cpu().numpy(),
                                                    ignore_label=self.hparams.data.ignore_label)
-        self.log("val_eval/semantic_accuracy", semantic_accuracy, on_step=False, on_epoch=True, sync_dist=True, batch_size=1)
-        self.log("val_eval/semantic_mean_iou", semantic_mean_iou, on_step=False, on_epoch=True, sync_dist=True, batch_size=1)
+        self.log("val_eval/semantic_accuracy", semantic_accuracy, on_step=False, on_epoch=True, sync_dist=True,
+                 batch_size=1)
+        self.log("val_eval/semantic_mean_iou", semantic_mean_iou, on_step=False, on_epoch=True, sync_dist=True,
+                 batch_size=1)
 
         if self.current_epoch > self.hparams.model.prepare_epochs:
             pred_instances = self._get_pred_instances(data_dict["scan_ids"][0],
@@ -134,10 +148,14 @@ def validation_step(self, data_dict, idx):
                                                       output_dict["proposal_scores"][0].cpu(),
                                                       output_dict["proposal_scores"][1].cpu(),
                                                       output_dict["proposal_scores"][2].size(0) - 1,
-                                                      output_dict["semantic_scores"].cpu(), len(self.hparams.data.ignore_classes))
-            gt_instances = get_gt_instances(data_dict["sem_labels"].cpu(), data_dict["instance_ids"].cpu(), self.hparams.data.ignore_classes)
+                                                      output_dict["semantic_scores"].cpu(),
+                                                      len(self.hparams.data.ignore_classes))
+            gt_instances = get_gt_instances(data_dict["sem_labels"].cpu(), data_dict["instance_ids"].cpu(),
+                                            self.hparams.data.ignore_classes)
             gt_instances_bbox = get_gt_bbox(data_dict["locs"].cpu().numpy(),
-                                            data_dict["instance_ids"].cpu().numpy(), data_dict["sem_labels"].cpu().numpy(), self.hparams.data.ignore_label, self.hparams.data.ignore_classes)
+                                            data_dict["instance_ids"].cpu().numpy(),
+                                            data_dict["sem_labels"].cpu().numpy(), self.hparams.data.ignore_label,
+                                            self.hparams.data.ignore_classes)
 
             return pred_instances, gt_instances, gt_instances_bbox
 
@@ -161,14 +179,41 @@ def test_step(self, data_dict, idx):
                                                       output_dict["proposal_scores"][0].cpu(),
                                                       output_dict["proposal_scores"][1].cpu(),
                                                       output_dict["proposal_scores"][2].size(0) - 1,
-                                                      output_dict["semantic_scores"].cpu(), len(self.hparams.data.ignore_classes))
+                                                      output_dict["semantic_scores"].cpu(),
+                                                      len(self.hparams.data.ignore_classes))
             gt_instances = get_gt_instances(sem_labels_cpu, data_dict["instance_ids"].cpu(),
                                             self.hparams.data.ignore_classes)
             gt_instances_bbox = get_gt_bbox(data_dict["locs"].cpu().numpy(),
-                                            data_dict["instance_ids"].cpu().numpy(), data_dict["sem_labels"].cpu().numpy(), self.hparams.data.ignore_label, self.hparams.data.ignore_classes)
+                                            data_dict["instance_ids"].cpu().numpy(),
+                                            data_dict["sem_labels"].cpu().numpy(), self.hparams.data.ignore_label,
+                                            self.hparams.data.ignore_classes)
         return semantic_accuracy, semantic_mean_iou, pred_instances, gt_instances, gt_instances_bbox, end_time
 
-    def _get_pred_instances(self, scan_id, gt_xyz, proposals_scores, proposals_idx, num_proposals, semantic_scores, num_ignored_classes):
+    def _get_nms_instances(self, cross_ious, scores, threshold):
+        """Non-max suppression for 3D instance proposals based on cross IoUs and scores.
+
+        Args:
+            cross_ious (np.array): cross ious, (n, n)
+            scores (np.array): scores for each proposal, (n,)
+            threshold (float): iou threshold
+
+        Returns:
+            np.array: idx of picked instance proposals
+        """
+        ixs = np.argsort(-scores)  # descending order
+        pick = []
+        while len(ixs) > 0:
+            i = ixs[0]
+            pick.append(i)
+            ious = cross_ious[i, ixs[1:]]
+            remove_ixs = np.where(ious > threshold)[0] + 1  # +1 because ious indexes ixs[1:]
+            ixs = np.delete(ixs, remove_ixs)
+            ixs = np.delete(ixs, 0)
+
+        return np.array(pick, dtype=np.int32)
+
+    def _get_pred_instances(self, scan_id, gt_xyz, proposals_scores, proposals_idx, num_proposals, semantic_scores,
+                            num_ignored_classes):
         semantic_pred_labels = semantic_scores.max(1)[1]
         proposals_score = torch.sigmoid(proposals_scores.view(-1))  # (nProposal,) float
         # proposals_idx: (sumNPoint, 2), int, cpu, dim 0 for cluster_id, dim 1 for corresponding point idxs in N
@@ -190,10 +235,16 @@ def _get_pred_instances(self, scan_id, gt_xyz, proposals_scores, proposals_idx,
         # instance masks non_max_suppression
         if proposals_score.shape[0] == 0:
             pick_idxs = np.empty(0)
-        elif self.hparams.model.test.TEST_NMS_THRESH >= 1:
-            pick_idxs = list(range(0, proposals_score.shape[0]))
         else:
-            pick_idxs = get_nms_instance(proposals_mask.float(), proposals_score.numpy(), self.hparams.model.test.TEST_NMS_THRESH)
+            proposals_mask_f = proposals_mask.float()  # (nProposal, N), float
+            intersection = torch.mm(proposals_mask_f, proposals_mask_f.t())  # (nProposal, nProposal), float
+            proposals_npoint = proposals_mask_f.sum(1)  # (nProposal), float, cpu
+            proposals_np_repeat_h = proposals_npoint.unsqueeze(-1).repeat(1, proposals_npoint.shape[0])
+            proposals_np_repeat_v = proposals_npoint.unsqueeze(0).repeat(proposals_npoint.shape[0], 1)
+            cross_ious = intersection / (
+                    proposals_np_repeat_h + proposals_np_repeat_v - intersection)  # (nProposal, nProposal), float, cpu
+            pick_idxs = self._get_nms_instances(cross_ious.numpy(), proposals_score.numpy(),
+                                                self.hparams.model.test.TEST_NMS_THRESH)  # int, (nCluster,)
 
         clusters_mask = proposals_mask[pick_idxs].numpy()  # int, (nCluster, N)
         score_pred = proposals_score[pick_idxs].numpy()  # float, (nCluster,)
diff --git a/minsu3d/model/softgroup.py b/minsu3d/model/softgroup.py
index cf9c265..56eb349 100644
--- a/minsu3d/model/softgroup.py
+++ b/minsu3d/model/softgroup.py
@@ -1,5 +1,6 @@
 import torch
 import time
+import numpy as np
 import torch.nn as nn
 from minsu3d.evaluation.instance_segmentation import get_gt_instances, rle_encode
 from minsu3d.evaluation.object_detection import get_gt_bbox
@@ -8,7 +9,6 @@
 from minsu3d.evaluation.semantic_segmentation import *
 from minsu3d.model.module import TinyUnet
 from minsu3d.model.general_model import GeneralModel, clusters_voxelization, get_batch_offsets
-from minsu3d.util.nms import get_nms_instance
 
 
 class SoftGroup(GeneralModel):
@@ -306,19 +306,9 @@ def _get_pred_instances(self, scan_id, gt_xyz, proposals_idx, num_points, cls_sc
             score_pred_list.append(score_pred)
             mask_pred_list.append(mask_pred)
 
-        cls_pred = torch.cat(cls_pred_list)
-        score_pred = torch.cat(score_pred_list)
-        mask_pred = torch.cat(mask_pred_list)
-
-        if score_pred.shape[0] == 0:
-            pick_idxs = np.empty(0)
-        elif self.hparams.inference.TEST_NMS_THRESH >= 1:
-            pick_idxs = list(range(0, score_pred.shape[0]))
-        else:
-            pick_idxs = get_nms_instance(mask_pred.float(), score_pred.numpy(), self.hparams.inference.TEST_NMS_THRESH)
-        mask_pred = mask_pred[pick_idxs].numpy()  # int, (nCluster, N)
-        score_pred = score_pred[pick_idxs].numpy()  # float, (nCluster,)
-        cls_pred = cls_pred[pick_idxs].numpy()
+        cls_pred = torch.cat(cls_pred_list).numpy()
+        score_pred = torch.cat(score_pred_list).numpy()
+        mask_pred = torch.cat(mask_pred_list).numpy()
 
         pred_instances = []
        for i in range(cls_pred.shape[0]):
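
The mask-IoU NMS that this patch inlines into PointGroup can be sanity-checked in isolation. Below is a minimal, self-contained sketch (not part of the patch) that mirrors _get_nms_instances and the cross-IoU computation from the patched _get_pred_instances; the toy proposals_mask and scores are made-up inputs, and broadcasting stands in for the explicit repeat() calls.

# Editor's sketch -- not part of the patch. Standalone check of the inlined
# mask-IoU NMS; `proposals_mask` and `scores` are toy inputs.
import numpy as np
import torch


def get_nms_instances(cross_ious, scores, threshold):
    """Greedy NMS: keep the best-scoring proposal, drop others whose IoU with it exceeds threshold."""
    ixs = np.argsort(-scores)  # proposal indices, best score first
    pick = []
    while len(ixs) > 0:
        i = ixs[0]
        pick.append(i)
        ious = cross_ious[i, ixs[1:]]  # IoU of the kept proposal vs. the rest
        remove_ixs = np.where(ious > threshold)[0] + 1  # +1 because ious indexes ixs[1:]
        ixs = np.delete(ixs, remove_ixs)
        ixs = np.delete(ixs, 0)
    return np.array(pick, dtype=np.int32)


# Toy masks over N=6 points: proposals 0 and 1 overlap heavily, proposal 2 is disjoint.
proposals_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 0, 0, 0, 0],
                               [0, 0, 0, 1, 1, 0]], dtype=torch.float32)
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)

# Pairwise mask IoU via one matmul, as in the patched _get_pred_instances:
# intersection(A, B) = mask_A . mask_B, union(A, B) = |A| + |B| - intersection(A, B).
intersection = torch.mm(proposals_mask, proposals_mask.t())  # (nProposal, nProposal)
npoint = proposals_mask.sum(1)  # points per proposal
cross_ious = intersection / (npoint.unsqueeze(-1) + npoint.unsqueeze(0) - intersection)

print(get_nms_instances(cross_ious.numpy(), scores, threshold=0.3))  # -> [0 2]: proposal 1 suppressed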