PredictionTypeHandler

code_loader.contract.datasetclasses.PredictionTypeHandler

PredictionTypeHandler describes a model's prediction tensor(s) so that they can be visualized and analyzed.

from dataclasses import dataclass
from typing import List

@dataclass
class PredictionTypeHandler:
    name: str
    labels: List[str]
    channel_dim: int = -1


Args

name

(str) The name given to this output, e.g. image, logits, or entity.

labels

(List[str]) The labels associated with this prediction, one per channel along channel_dim.

channel_dim

(int, defaults to -1) The dimension along which the channels lie in the prediction. Set channel_dim to 1 for channel-first predictions (i.e., C,H,W).
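As the examples below suggest, each entry in labels names one channel, so a reasonable sanity check is that the size of the prediction along channel_dim equals len(labels). The sketch below is illustrative only: it uses a local stand-in dataclass mirroring the fields above instead of importing code_loader, it assumes a leading batch axis, and the helper `labels_match_shape` is hypothetical, not part of the library.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class PredictionTypeHandler:
    # Local stand-in mirroring the documented fields; not imported from code_loader.
    name: str
    labels: List[str]
    channel_dim: int = -1


def labels_match_shape(handler: PredictionTypeHandler, shape: Tuple[int, ...]) -> bool:
    """Hypothetical check: the size along channel_dim should equal len(labels)."""
    # Python indexing handles a negative channel_dim (e.g. -1 = last axis) directly.
    return shape[handler.channel_dim] == len(handler.labels)


# Channel-last prediction of shape (batch, H, W, C) with 3 channels.
seg_last = PredictionTypeHandler("mask", ["bg", "cat", "dog"])
print(labels_match_shape(seg_last, (8, 64, 64, 3)))   # True

# Channel-first prediction of shape (batch, C, H, W) needs channel_dim=1.
seg_first = PredictionTypeHandler("mask", ["bg", "cat", "dog"], channel_dim=1)
print(labels_match_shape(seg_first, (8, 3, 64, 64)))  # True
print(labels_match_shape(seg_last, (8, 3, 64, 64)))   # False: axis -1 has size 64
```

Leaving channel_dim at its default of -1 on a channel-first tensor makes the check compare against the last spatial axis instead, which is the mismatch the last line shows.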

Examples

Basic Usage

MNIST example

One output, named classes, with 10 channels, each holding the logit for one digit (0 to 9).

from code_loader.contract.datasetclasses import PredictionTypeHandler
prediction_type1 = PredictionTypeHandler('classes', [str(i) for i in range(10)])
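To illustrate how the 10 labels line up with the 10 logit channels, the sketch below pairs each label with its logit and reads off the predicted digit. It uses a local stand-in for PredictionTypeHandler rather than the code_loader import; the logit values and the helper `top_label` are made up for illustration.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class PredictionTypeHandler:
    # Local stand-in mirroring the documented fields; not imported from code_loader.
    name: str
    labels: List[str]
    channel_dim: int = -1


def top_label(handler: PredictionTypeHandler, logits: Sequence[float]) -> str:
    """Hypothetical helper: return the label of the largest logit for one sample."""
    if len(logits) != len(handler.labels):
        raise ValueError("expected one logit per label")
    # labels[i] describes channel i, so the argmax index selects the label.
    best = max(range(len(logits)), key=lambda i: logits[i])
    return handler.labels[best]


prediction_type1 = PredictionTypeHandler('classes', [str(i) for i in range(10)])
logits = [0.1, -1.2, 0.3, 2.5, 0.0, -0.4, 1.1, 0.2, -2.0, 0.7]  # made-up values
print(top_label(prediction_type1, logits))  # 3
```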

YOLO example

Four outputs:

  • a concatenated prediction with #channels = 4 + #classes

  • three per-scale outputs with 20, 40, and 80 channels, respectively

# all_clss: a mapping whose values are the class names, defined elsewhere in your dataset code
prediction_type1 = PredictionTypeHandler(name='object detection', labels=["x", "y", "w", "h"] + list(all_clss.values()), channel_dim=1)
prediction_type2 = PredictionTypeHandler(name='concatenate_20', labels=[str(i) for i in range(20)], channel_dim=-1)
prediction_type3 = PredictionTypeHandler(name='concatenate_40', labels=[str(i) for i in range(40)], channel_dim=-1)
prediction_type4 = PredictionTypeHandler(name='concatenate_80', labels=[str(i) for i in range(80)], channel_dim=-1)
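The label lists in the YOLO example can be checked against the stated channel counts: the concatenated output should have 4 + #classes labels, and the three scale outputs 20, 40, and 80. The sketch below assumes a hypothetical three-class `all_clss` mapping (class id to class name), since the real mapping is dataset-specific, and uses a local stand-in dataclass instead of the code_loader import.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class PredictionTypeHandler:
    # Local stand-in mirroring the documented fields; not imported from code_loader.
    name: str
    labels: List[str]
    channel_dim: int = -1


# Hypothetical class mapping; in practice this comes from your dataset.
all_clss: Dict[int, str] = {0: "person", 1: "car", 2: "dog"}

# Box coordinates first, then one channel per class.
detection = PredictionTypeHandler(
    name='object detection',
    labels=["x", "y", "w", "h"] + list(all_clss.values()),
    channel_dim=1,
)
# One handler per scale, each with numeric channel labels.
scales = [
    PredictionTypeHandler(name=f'concatenate_{n}', labels=[str(i) for i in range(n)], channel_dim=-1)
    for n in (20, 40, 80)
]

print(len(detection.labels))            # 7 == 4 + len(all_clss)
print([len(s.labels) for s in scales])  # [20, 40, 80]
print([s.name for s in scales])         # ['concatenate_20', 'concatenate_40', 'concatenate_80']
```

Building the three scale handlers in a loop is equivalent to the three explicit calls in the YOLO example; it only avoids repeating the label construction.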
