from typing import List

import navi
from nodes.base_output import BaseOutput, OutputKind

from ...impl.pytorch.types import PyTorchModel
from ...utils.format import format_channel_numbers


def _get_sizes(value: PyTorchModel) -> List[str]:
    """Return human-readable size tags for a model's architecture parameters."""
    if "SRVGG" in value.model_arch:
        return [f"{value.num_feat}nf", f"{value.num_conv}nc"]
    elif (
        "SwinIR" in value.model_arch
        or "Swin2SR" in value.model_arch
        or "HAT" in value.model_arch
        or "SRFormer" in value.model_arch
    ):
        # Bucket transformer SR architectures into size classes by stage count
        head_length = len(value.depths)  # type: ignore
        if head_length <= 4:
            size_tag = "small"
        elif head_length < 9:
            size_tag = "medium"
        else:
            size_tag = "large"
        return [
            size_tag,
            f"s{value.img_size}w{value.window_size}",
            f"{value.num_feat}nf",
            f"{value.embed_dim}dim",
            f"{value.resi_connection}",
        ]
    elif "DAT" in value.model_arch:
        # DAT stores its per-stage depths under `depth` (singular)
        head_length = len(value.depth)  # type: ignore
        if head_length <= 4:
            size_tag = "small"
        elif head_length < 9:
            size_tag = "medium"
        else:
            size_tag = "large"
        return [
            size_tag,
            f"s{value.img_size}|{value.split_size[0]}x{value.split_size[1]}",  # type: ignore
            f"{value.num_feat}nf",
            f"{value.embed_dim}dim",
            f"{value.resi_connection}",
        ]
    elif "OmniSR" in value.model_arch:
        return [
            f"{value.num_feat}nf",
            f"w{value.window_size}",
            f"{value.res_num}nr",
        ]
    # Face-restoration, inpainting, and denoising archs expose no size tags
    elif value.model_arch in [
        "GFPGAN",
        "RestoreFormer",
        "CodeFormer",
        "LaMa",
        "MAT",
        "SCUNet",
    ]:
        return []
    else:
        # Default: classic CNN SR architectures (e.g. ESRGAN-style models)
        return [
            f"{value.num_filters}nf",
            f"{value.num_blocks}nb",
        ]
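
# Illustrative tag output (assumed, typical parameter values; not taken from
# any specific checkpoint):
#   - an SRVGG model with num_feat=64 and num_conv=32 yields ["64nf", "32nc"]
#   - a default ESRGAN-style model (num_filters=64, num_blocks=23) falls
#     through to the final branch and yields ["64nf", "23nb"]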


class ModelOutput(BaseOutput):
    def __init__(
        self,
        model_type: navi.ExpressionJson = "PyTorchModel",
        label: str = "Model",
        kind: OutputKind = "generic",
    ):
        super().__init__(model_type, label, kind=kind, associated_type=PyTorchModel)

    def get_broadcast_data(self, value: PyTorchModel):
        return {
            "tags": [
                value.model_arch,
                format_channel_numbers(value.in_nc, value.out_nc),
                *_get_sizes(value),
            ]
        }

    def get_broadcast_type(self, value: PyTorchModel):
        return navi.named(
            "PyTorchModel",
            {
                "scale": value.scale,
                "inputChannels": value.in_nc,
                "outputChannels": value.out_nc,
                "arch": navi.literal(value.model_arch),
                "subType": navi.literal(value.sub_type),
                "size": navi.literal("x".join(_get_sizes(value))),
            },
        )
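
    # Illustrative broadcast type (assumed values, hypothetical model): a 4x
    # RGB ESRGAN-style model would broadcast roughly as
    #   PyTorchModel {
    #       scale: 4, inputChannels: 3, outputChannels: 3,
    #       arch: "ESRGAN", subType: "SR", size: "64nfx23nb"
    #   }
    # where "size" is the "x"-joined tag list from _get_sizes above.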


def TorchScriptOutput():
    """Output for a JIT-traced model."""
    return BaseOutput("PyTorchScript", "Traced Model")
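

# Usage sketch (illustrative only): how a node definition might declare these
# outputs. The `outputs=[...]` registration shown here is a simplified,
# hypothetical stand-in for the real node framework defined elsewhere in the
# codebase.
#
#     outputs=[
#         ModelOutput(label="Upscaling Model"),
#         TorchScriptOutput(),
#     ]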