@tensorleap_load_model

code_loader.inner_leap_binder.leapbinder_decorators.tensorleap_load_model

import os
from code_loader.contract.datasetclasses import PredictionTypeHandler
from code_loader.inner_leap_binder.leapbinder_decorators import tensorleap_load_model
import tensorflow as tf


# Single model output named 'classes', labelled '0' through '9'.
prediction_type1 = PredictionTypeHandler('classes', [str(digit) for digit in range(10)])


@tensorleap_load_model([prediction_type1])
def load_model():
    """Load the Keras model stored alongside this file and return it.

    The model is read from ``model/model.h5``, resolved relative to the
    directory that contains this source file rather than the current
    working directory.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    h5_file = os.path.join(here, 'model/model.h5')
    return tf.keras.models.load_model(h5_file)
    
Args

prediction_types

(Optional, List[PredictionTypeHandler]) Defines the outputs of the model uploaded to Tensorleap. Each PredictionTypeHandler specifies an output's name, its labels, and its channel_dim (1 for channels-first, -1 for channels-last).

PredictionTypes Examples

MNIST example

One output, named classes, with 10 channels; each channel is the logit for one digit class.

from code_loader.contract.datasetclasses import PredictionTypeHandler
from code_loader.inner_leap_binder.leapbinder_decorators import tensorleap_load_model
# Single model output named 'classes', labelled '0' through '9'.
prediction_type1 = PredictionTypeHandler('classes', [str(i) for i in range(10)])


@tensorleap_load_model([prediction_type1])
def load_model():
    """Placeholder loader: replace the body with code that returns the model."""
    # Return an .onnx or .h5 model
    ...

YOLO example

Four outputs:

  • a concatenated prediction with #channels = 4 + #classes

  • three scales, with 20, 40, and 80 channels respectively

from code_loader.contract.datasetclasses import PredictionTypeHandler
from code_loader.inner_leap_binder.leapbinder_decorators import tensorleap_load_model

# Concatenated detection output: the four box channels followed by one channel
# per class, declared channels-first (channel_dim=1).
# NOTE(review): all_clss is not defined in this snippet -- presumably a mapping
# whose values are the class names; confirm against the integration script.
prediction_type1 = PredictionTypeHandler(
    name='object detection',
    labels=["x", "y", "w", "h"] + list(all_clss.values()),
    channel_dim=1,
)
# Three further outputs with 20, 40 and 80 channels, declared channels-last
# (channel_dim=-1).
prediction_type2 = PredictionTypeHandler(name='concatenate_20', labels=[str(i) for i in range(20)], channel_dim=-1)
prediction_type3 = PredictionTypeHandler(name='concatenate_40', labels=[str(i) for i in range(40)], channel_dim=-1)
prediction_type4 = PredictionTypeHandler(name='concatenate_80', labels=[str(i) for i in range(80)], channel_dim=-1)


@tensorleap_load_model([prediction_type1, prediction_type2, prediction_type3, prediction_type4])
def load_model():
    """Placeholder loader: replace the body with code that returns the model."""
    # Return an .onnx or .h5 model
    ...

Last updated

Was this helpful?