PredictionTypeHandler
code_loader.contract.datasetclasses.PredictionTypeHandler
The purpose of PredictionTypeHandler is to describe the prediction tensor(s) for visualization and analysis purposes.
from dataclasses import dataclass
from typing import List

@dataclass
class PredictionTypeHandler:
    name: str
    labels: List[str]
    channel_dim: int = -1
The fields of PredictionTypeHandler are:
name
(str) The name given to the output, e.g. image, logits, or entity.
labels
(List[str]) The labels associated with this prediction, one per channel.
channel_dim
(int, defaults to -1) The dimension of the prediction along which the channels lie. Set channel_dim to 1 for channels-first predictions (i.e. C,H,W, with the batch dimension first).
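As a rough illustration of how labels and channel_dim relate, the sketch below checks that a prediction shape has exactly one channel per label along channel_dim. It uses a local stand-in dataclass that mirrors the documented fields (so it runs without code_loader installed); the helpers channel_count and labels_match_shape, and the "mask" output with its labels, are hypothetical and not part of the library.

```python
from dataclasses import dataclass
from typing import List, Sequence


# Local stand-in mirroring the documented fields; not the code_loader class itself.
@dataclass
class PredictionTypeHandler:
    name: str
    labels: List[str]
    channel_dim: int = -1


def channel_count(shape: Sequence[int], channel_dim: int) -> int:
    """Return the size of the channel axis; a negative channel_dim counts from the end."""
    return shape[channel_dim]


def labels_match_shape(handler: PredictionTypeHandler, shape: Sequence[int]) -> bool:
    """Hypothetical consistency check: one label per channel along handler.channel_dim."""
    return len(handler.labels) == channel_count(shape, handler.channel_dim)


# Channels-last prediction: a batch of 32 samples with 10 logits each.
digits = PredictionTypeHandler("classes", [str(i) for i in range(10)])
print(labels_match_shape(digits, (32, 10)))  # True

# Channels-first prediction: (batch, channels, height, width), so channel_dim=1.
seg = PredictionTypeHandler("mask", ["background", "foreground"], channel_dim=1)
print(labels_match_shape(seg, (8, 2, 64, 64)))  # True
```

With the default channel_dim=-1 the last axis is treated as the channel axis; for the channels-first tensor the same labels would not match the last axis (size 64), which is why channel_dim must be set to 1 there.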
Examples
Basic Usage
MNIST example
One output, named classes, with 10 channels, each holding the logit for the classification of one digit.
from code_loader.contract.datasetclasses import PredictionTypeHandler
prediction_type1 = PredictionTypeHandler('classes', [str(i) for i in range(10)])
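To make the role of the labels list concrete, here is a minimal sketch, assuming the labels are ordered to match the channels, that pairs one sample's logits with the MNIST labels and picks the highest-scoring digit. The helpers label_scores and top_label and the logit values are hypothetical illustrations, not code_loader APIs.

```python
from typing import Dict, List


def label_scores(labels: List[str], channel_values: List[float]) -> Dict[str, float]:
    """Pair each label with the value of the channel at the same index."""
    if len(labels) != len(channel_values):
        raise ValueError("expected one label per channel")
    return dict(zip(labels, channel_values))


def top_label(labels: List[str], channel_values: List[float]) -> str:
    """Return the label whose channel holds the largest value (the argmax)."""
    scores = label_scores(labels, channel_values)
    return max(scores, key=scores.get)


digit_labels = [str(i) for i in range(10)]
# Hypothetical logits for a single MNIST sample (10 channels, one per digit).
logits = [0.1, -1.2, 0.3, 2.5, 0.0, -0.4, 0.2, 0.9, -2.0, 1.1]
print(top_label(digit_labels, logits))  # 3
```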
YOLO example
Four outputs:
A concatenated prediction with #channels = 4 + #classes
Three scales, with 20, 40, and 80 channels respectively
from code_loader.contract.datasetclasses import PredictionTypeHandler

# all_clss: dict mapping class index to class name, defined elsewhere in your dataset script
prediction_type1 = PredictionTypeHandler(name='object detection', labels=["x", "y", "w", "h"] + list(all_clss.values()), channel_dim=1)
prediction_type2 = PredictionTypeHandler(name='concatenate_20', labels=[str(i) for i in range(20)], channel_dim=-1)
prediction_type3 = PredictionTypeHandler(name='concatenate_40', labels=[str(i) for i in range(40)], channel_dim=-1)
prediction_type4 = PredictionTypeHandler(name='concatenate_80', labels=[str(i) for i in range(80)], channel_dim=-1)
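For the concatenated YOLO output, the labels list encodes the channel layout: four box coordinates followed by one score per class. The sketch below, a hypothetical helper rather than anything provided by code_loader, splits the channel values of a single detection according to that layout; the all_clss mapping and the numbers are made up for illustration.

```python
from typing import Dict, List, Tuple

# Hypothetical class-index-to-name mapping standing in for all_clss.
all_clss = {0: "person", 1: "car", 2: "dog"}
detection_labels = ["x", "y", "w", "h"] + list(all_clss.values())


def split_detection(
    labels: List[str], channel_values: List[float]
) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Split one detection's channel vector into box coordinates and class scores.

    Assumes the layout of the example above: the first four channels are
    x, y, w, h and the remaining channels are per-class scores in label order.
    """
    if len(labels) != len(channel_values):
        raise ValueError("expected one label per channel")
    box = dict(zip(labels[:4], channel_values[:4]))
    class_scores = dict(zip(labels[4:], channel_values[4:]))
    return box, class_scores


box, class_scores = split_detection(detection_labels, [0.5, 0.4, 0.2, 0.3, 0.1, 0.7, 0.2])
print(box)  # {'x': 0.5, 'y': 0.4, 'w': 0.2, 'h': 0.3}
print(max(class_scores, key=class_scores.get))  # car
```

Keeping the box coordinates and class names in a single labels list means the channel count of this output is 4 + #classes, matching the description of the concatenated prediction.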