@tensorleap_load_model
code_loader.inner_leap_binder.leapbinder_decorators.tensorleap_load_model
import os
from code_loader.contract.datasetclasses import PredictionTypeHandler
from code_loader.inner_leap_binder.leapbinder_decorators import tensorleap_load_model
import tensorflow as tf
# One output named 'classes' with 10 channels, labelled '0'..'9'.
prediction_type1 = PredictionTypeHandler('classes', [str(i) for i in range(10)])


@tensorleap_load_model([prediction_type1])
def load_model():
    """Load and return the Keras model stored next to this module.

    The model file ``model/model.h5`` is resolved relative to the directory
    that contains this source file (via ``__file__``), not the current
    working directory, so the loader works wherever the process starts.

    Returns:
        The model object returned by ``tf.keras.models.load_model``.
    """
    dir_path = os.path.dirname(os.path.abspath(__file__))
    model_path = 'model/model.h5'
    cnn = tf.keras.models.load_model(os.path.join(dir_path, model_path))
    return cnn
Args
prediction_types
(Optional, List[PredictionTypeHandler]) Defines the outputs of the model uploaded to Tensorleap: their names, labels, and `channel_dim` (`1` for channels-first outputs, `-1` for channels-last outputs).
PredictionTypes Examples
MNIST example
One output, which we name `classes`,
with 10 channels, each a logit for the classification of one digit (labels `'0'`–`'9'`).
from code_loader.contract.datasetclasses import PredictionTypeHandler
from code_loader.inner_leap_binder.leapbinder_decorators import tensorleap_load_model
# One output named 'classes' with 10 channels, labelled '0'..'9'.
prediction_type1 = PredictionTypeHandler('classes', [str(i) for i in range(10)])


@tensorleap_load_model([prediction_type1])
def load_model():
    """Placeholder loader: replace the body with code returning the model."""
    # Return an .onnx or .h5 model
    ...
YOLO example
Four outputs:
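A concatenated prediction with 4 + #classes channels (x, y, w, h followed by one channel per class), laid out channels-first (`channel_dim=1`)
Three per-scale outputs with 20, 40, and 80 channels respectively, laid out channels-last (`channel_dim=-1`)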
from code_loader.contract.datasetclasses import PredictionTypeHandler
from code_loader.inner_leap_binder.leapbinder_decorators import tensorleap_load_model
# Concatenated detection output: x, y, w, h followed by one channel per
# class, with channels on axis 1 (channels-first).
# NOTE(review): assumes `all_clss` is a mapping defined elsewhere whose
# values are the class names in channel order -- confirm against the caller.
prediction_type1 = PredictionTypeHandler(
    name='object detection',
    labels=["x", "y", "w", "h"] + list(all_clss.values()),
    channel_dim=1,
)
# Per-scale outputs with 20, 40 and 80 channels, channels on the last axis.
prediction_type2 = PredictionTypeHandler(
    name='concatenate_20',
    labels=[str(i) for i in range(20)],
    channel_dim=-1,
)
prediction_type3 = PredictionTypeHandler(
    name='concatenate_40',
    labels=[str(i) for i in range(40)],
    channel_dim=-1,
)
prediction_type4 = PredictionTypeHandler(
    name='concatenate_80',
    labels=[str(i) for i in range(80)],
    channel_dim=-1,
)


# NOTE(review): the handler order presumably has to match the order of the
# model's outputs -- confirm with the Tensorleap docs.
@tensorleap_load_model(
    [prediction_type1, prediction_type2, prediction_type3, prediction_type4]
)
def load_model():
    """Placeholder loader: replace the body with code returning the model."""
    # Return an .onnx or .h5 model
    ...
Last updated
Was this helpful?