File size: 713 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import navi
from nodes.base_output import BaseOutput

from ...impl.onnx.model import OnnxModel


class OnnxModelOutput(BaseOutput):
    """Node output that carries an ``OnnxModel`` value.

    The broadcast type is a navi ``OnnxModel`` named type. It always has a
    ``subType`` field and only includes ``scaleWidth``/``scaleHeight`` when
    the model reports truthy values for them.
    """

    def __init__(
        self,
        model_type: navi.ExpressionJson = "OnnxModel",
        label: str = "Model",
    ):
        # The runtime class associated with this output is always OnnxModel.
        super().__init__(model_type, label, associated_type=OnnxModel)

    def get_broadcast_type(self, value: OnnxModel):
        """Build the navi type describing *value*.

        ``subType`` is always present, as a literal of ``value.sub_type``.
        Scale fields are added in width-then-height order, and only when the
        corresponding attribute is truthy. Falsy values such as ``None`` or
        ``0`` are omitted.
        """
        fields = {"subType": navi.literal(value.sub_type)}

        # Optional dimensions. They are only advertised when the model
        # actually provides a (truthy) scale.
        optional_scales = (
            ("scaleWidth", value.scale_width),
            ("scaleHeight", value.scale_height),
        )
        for field_name, scale in optional_scales:
            if scale:
                fields[field_name] = scale

        return navi.named("OnnxModel", fields)