from dataclasses import dataclass

import numpy as np
import numpy.typing as npt
from code_loader.contract.enums import LeapDataType
@dataclass
class LeapGraph:
    """Visualizer payload carrying graph data.

    NOTE(review): a later `from code_loader.contract.visualizer_classes import
    LeapGraph` rebinds this name, so subsequent code in this file uses the
    imported class, not this local definition. Confirm the local copy is meant
    only as reference.
    """

    # Graph values, annotated as a float32 NumPy array.
    # Presumably rows are points and columns are series/channels — TODO confirm
    # the expected shape.
    data: npt.NDArray[np.float32]
    # Data-type tag identifying this payload; defaults to LeapDataType.Graph.
    # The field name shadows the builtin `type` as a class attribute.
    type: LeapDataType = LeapDataType.Graph
from code_loader.contract.visualizer_classes import LeapGraph
import numpy as np
...
def diff_per_channel_visualizer(prediction: np.ndarray, ground_truth: np.ndarray) -> LeapGraph:
    """Return a graph visualization of ``prediction - ground_truth``.

    Args:
        prediction: Model output values. Assumed to have the same shape as
            ``ground_truth`` (e.g. points x channels) — TODO confirm against
            the dataset's prediction layout.
        ground_truth: Reference values; must be broadcast-compatible with
            ``prediction``.

    Returns:
        A ``LeapGraph`` whose ``data`` is the element-wise difference,
        as float32.
    """
    # Convert both operands to float32 *before* subtracting. This matches
    # LeapGraph.data's declared float32 dtype, and it keeps unsigned-integer
    # inputs (e.g. uint8) from wrapping around when the result is negative.
    diff = np.asarray(prediction, dtype=np.float32) - np.asarray(ground_truth, dtype=np.float32)
    return LeapGraph(diff)
# Register the prediction-vs-ground-truth difference visualizer under the name
# 'diff_per_channel'; its output type is taken from LeapGraph.type.
leap_binder.set_visualizer(
    function=diff_per_channel_visualizer,
    name='diff_per_channel',
    visualizer_type=LeapGraph.type,
)