from typing import List

import navi
from nodes.base_output import BaseOutput, OutputKind

from ...impl.pytorch.types import PyTorchModel
from ...utils.format import format_channel_numbers

# Substrings of ``model_arch`` that select the Swin-style tag layout in
# ``_get_sizes`` (matched with ``in``, i.e. substring containment).
_SWIN_LIKE_ARCHS = ("SwinIR", "Swin2SR", "HAT", "SRFormer")

# Exact ``model_arch`` values for which ``_get_sizes`` reports no size tags.
_ARCHS_WITHOUT_SIZE_TAGS = (
    "GFPGAN",
    "RestoreFormer",
    "CodeFormer",
    "LaMa",
    "MAT",
    "SCUNet",
)


def _get_size_tag(head_length: int) -> str:
    """Map the length of a model's ``depths``/``depth`` sequence to a coarse label.

    Returns ``"small"`` for lengths up to 4, ``"medium"`` for 5-8 and
    ``"large"`` otherwise. Shared by the Swin-style and DAT branches of
    ``_get_sizes`` so both use identical thresholds.
    """
    if head_length <= 4:
        return "small"
    if head_length < 9:
        return "medium"
    return "large"


def _get_sizes(value: PyTorchModel) -> List[str]:
    """Return architecture-specific size tags describing ``value``.

    Each tag is a short hyperparameter summary such as ``f"{num_feat}nf"``.
    The attributes read depend on ``value.model_arch``:

    * contains ``"SRVGG"``: ``num_feat`` and ``num_conv``;
    * contains any of ``_SWIN_LIKE_ARCHS``: ``depths``, ``img_size``,
      ``window_size``, ``num_feat``, ``embed_dim`` and ``resi_connection``;
    * contains ``"DAT"``: ``depth``, ``img_size``, ``split_size``,
      ``num_feat``, ``embed_dim`` and ``resi_connection``;
    * contains ``"OmniSR"``: ``num_feat``, ``window_size`` and ``res_num``;
    * exactly one of ``_ARCHS_WITHOUT_SIZE_TAGS``: no tags (empty list);
    * anything else: ``num_filters`` and ``num_blocks``.
    """
    arch = value.model_arch

    if "SRVGG" in arch:
        return [f"{value.num_feat}nf", f"{value.num_conv}nc"]

    if any(name in arch for name in _SWIN_LIKE_ARCHS):
        head_length = len(value.depths)  # type: ignore
        return [
            _get_size_tag(head_length),
            f"s{value.img_size}w{value.window_size}",
            f"{value.num_feat}nf",
            f"{value.embed_dim}dim",
            f"{value.resi_connection}",
        ]

    if "DAT" in arch:
        head_length = len(value.depth)  # type: ignore
        return [
            _get_size_tag(head_length),
            # NOTE(review): this tag itself contains "x", which is also the
            # separator ModelOutput.get_broadcast_type uses to join the tags.
            f"s{value.img_size}|{value.split_size[0]}x{value.split_size[1]}",  # type: ignore
            f"{value.num_feat}nf",
            f"{value.embed_dim}dim",
            f"{value.resi_connection}",
        ]

    if "OmniSR" in arch:
        return [
            f"{value.num_feat}nf",
            f"w{value.window_size}",
            f"{value.res_num}nr",
        ]

    if arch in _ARCHS_WITHOUT_SIZE_TAGS:
        return []

    # Fallback for every other architecture.
    return [
        f"{value.num_filters}nf",
        f"{value.num_blocks}nb",
    ]


class ModelOutput(BaseOutput):
    """Node output carrying a ``PyTorchModel``.

    Broadcasts display tags (architecture, channel counts, size tags) and a
    ``PyTorchModel`` navi type built from the model's scale, channel counts,
    architecture, sub-type and size tags.
    """

    def __init__(
        self,
        model_type: navi.ExpressionJson = "PyTorchModel",
        label: str = "Model",
        kind: OutputKind = "generic",
    ):
        super().__init__(model_type, label, kind=kind, associated_type=PyTorchModel)

    def get_broadcast_data(self, value: PyTorchModel):
        """Return a ``{"tags": [...]}`` dict.

        The tags are the architecture name, the formatted in/out channel
        counts and the size tags from ``_get_sizes``.
        """
        return {
            "tags": [
                value.model_arch,
                format_channel_numbers(value.in_nc, value.out_nc),
                *_get_sizes(value),
            ]
        }

    def get_broadcast_type(self, value: PyTorchModel):
        """Return a named ``PyTorchModel`` navi type describing ``value``.

        The ``size`` field is the size tags from ``_get_sizes`` joined with
        ``"x"``.
        """
        return navi.named(
            "PyTorchModel",
            {
                "scale": value.scale,
                "inputChannels": value.in_nc,
                "outputChannels": value.out_nc,
                "arch": navi.literal(value.model_arch),
                "subType": navi.literal(value.sub_type),
                "size": navi.literal("x".join(_get_sizes(value))),
            },
        )


def TorchScriptOutput():
    """Output a JIT traced model"""
    return BaseOutput("PyTorchScript", "Traced Model")