| | from typing import Callable, Dict, Sequence, Union |
| |
|
| | import torch |
| | from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce |
| | from monai.apps.detection.metrics.coco import COCOMetric |
| | from monai.apps.detection.metrics.matching import matching_batch |
| | from monai.data import box_utils |
| |
|
| | from .utils import detach_to_numpy |
| |
|
| |
|
class IgniteCocoMetric(Metric):
    r"""Ignite ``Metric`` that accumulates detection outputs and scores them with a MONAI ``COCOMetric``."""

    def __init__(
        self,
        coco_metric_monai: Union[None, COCOMetric] = None,
        box_key: str = "box",
        label_key: str = "label",
        pred_score_key: str = "label_scores",
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device, None] = None,
        reduce_scalar: bool = True,
    ):
        r"""
        Computes coco detection metric in Ignite.

        Args:
            coco_metric_monai: the coco metric in monai.
                If not given, will assume COCOMetric(classes=[0], iou_list=[0.1], max_detection=[100])
            box_key: box key in the ground truth target dict and prediction dict.
            label_key: classification label key in the ground truth target dict and prediction dict.
            pred_score_key: classification score key in the prediction dict.
            output_transform: A callable that is used to transform the Engine's
                process_function's output into the form expected by the metric.
            device: specifies which device updates are accumulated on.
                Setting the metric's device to be the same as your update arguments ensures
                the update method is non-blocking. By default, CPU.
            reduce_scalar: if True, will return the average value of the coco metric values;
                if False, will return a dictionary of coco metrics.

        Examples:
            To use with ``Engine`` and ``process_function``,
            simply attach the metric instance to the engine.
            The output of the engine's ``process_function`` needs to be in format of
            ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.
            For more information on how metric works with :class:`~ignite.engine.engine.Engine`,
            visit :ref:`attach-engine`.

            .. include:: defaults.rst
                :start-after: :orphan:

            .. testcode::

                coco = IgniteCocoMetric()
                coco.attach(default_evaluator, 'coco')
                preds = [
                    {
                        'box': torch.Tensor([[1,1,1,2,2,2]]),
                        'label':torch.Tensor([0]),
                        'label_scores':torch.Tensor([0.8])
                    }
                ]
                target = [{'box': torch.Tensor([[1,1,1,2,2,2]]), 'label':torch.Tensor([0])}]
                state = default_evaluator.run([[preds, target]])
                print(state.metrics['coco'])

            .. testoutput::

                1.0...

        .. versionadded:: 0.4.3
        """
        self.box_key = box_key
        self.label_key = label_key
        self.pred_score_key = pred_score_key
        if coco_metric_monai is None:
            self.coco_metric = COCOMetric(classes=[0], iou_list=[0.1], max_detection=[100])
        else:
            self.coco_metric = coco_metric_monai
        self.reduce_scalar = reduce_scalar

        if device is None:
            device = torch.device("cpu")
        super().__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self) -> None:
        """Discard every target and prediction accumulated so far."""
        self.val_targets_all = []
        self.val_outputs_all = []

    @reinit__is_reduced
    def update(self, output: Sequence[Sequence[Dict]]) -> None:
        """
        Append one batch of predictions and targets to the accumulators.

        Args:
            output: a pair ``(y_pred, y)``; the items of ``y_pred`` are appended to the
                accumulated predictions and the items of ``y`` to the accumulated targets.
        """
        y_pred, y = output[0], output[1]
        self.val_outputs_all += y_pred
        self.val_targets_all += y

    @sync_all_reduce("val_targets_all", "val_outputs_all")
    def compute(self) -> Union[float, Dict]:
        """
        Match the accumulated predictions against the accumulated targets and score them.

        Returns:
            The mean of the values in the metric dictionary when ``reduce_scalar`` is True,
            otherwise the metric dictionary itself.
        """
        # Convert into locals instead of overwriting the accumulator attributes, so that
        # reading the metric does not replace the state that ``update`` keeps appending to.
        outputs = detach_to_numpy(self.val_outputs_all)
        targets = detach_to_numpy(self.val_targets_all)

        results_metric = matching_batch(
            iou_fn=box_utils.box_iou,
            iou_thresholds=self.coco_metric.iou_thresholds,
            pred_boxes=[item[self.box_key] for item in outputs],
            pred_classes=[item[self.label_key] for item in outputs],
            pred_scores=[item[self.pred_score_key] for item in outputs],
            gt_boxes=[item[self.box_key] for item in targets],
            gt_classes=[item[self.label_key] for item in targets],
        )
        # Only the first element of the metric's return value is used here.
        val_epoch_metric_dict = self.coco_metric(results_metric)[0]

        if not self.reduce_scalar:
            return val_epoch_metric_dict

        metric_values = val_epoch_metric_dict.values()
        return sum(metric_values) / len(metric_values)
| |
|