Spaces:
Runtime error
Runtime error
File size: 3,038 Bytes
c19ca42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
from typing import List
import navi
from nodes.base_output import BaseOutput, OutputKind
from ...impl.pytorch.types import PyTorchModel
from ...utils.format import format_channel_numbers
def _size_tag(head_length: int) -> str:
    """Map a depth-list length to a coarse size label.

    ``<= 4`` -> ``"small"``, ``5..8`` -> ``"medium"``, ``>= 9`` -> ``"large"``.
    Shared by every transformer branch of ``_get_sizes`` so the thresholds
    cannot drift apart between architectures.
    """
    if head_length <= 4:
        return "small"
    if head_length < 9:
        return "medium"
    return "large"


# Architectures (matched by substring of ``model_arch``) that expose
# ``depths``, ``img_size``, ``window_size``, ``num_feat``, ``embed_dim`` and
# ``resi_connection``.
_SWIN_LIKE_ARCHS = ("SwinIR", "Swin2SR", "HAT", "SRFormer")

# Architectures (matched exactly) for which no size tags are reported.
_SIZELESS_ARCHS = (
    "GFPGAN",
    "RestoreFormer",
    "CodeFormer",
    "LaMa",
    "MAT",
    "SCUNet",
)


def _get_sizes(value: PyTorchModel) -> List[str]:
    """Return short hyper-parameter tags describing ``value``'s size.

    The tags depend on ``value.model_arch``:

    * contains ``"SRVGG"``: ``["<num_feat>nf", "<num_conv>nc"]``
    * contains any of ``_SWIN_LIKE_ARCHS``: a size label from
      ``len(value.depths)``, then ``s<img_size>w<window_size>``, ``nf``,
      ``dim`` and the residual-connection tag
    * contains ``"DAT"``: like the Swin-like case, but the size label comes
      from ``len(value.depth)`` and the second tag is
      ``s<img_size>|<split_size[0]>x<split_size[1]>``
    * contains ``"OmniSR"``: ``nf``, ``w<window_size>`` and ``nr`` tags
    * exactly one of ``_SIZELESS_ARCHS``: ``[]``
    * anything else: ``["<num_filters>nf", "<num_blocks>nb"]``

    Branches are checked in that order, so the first match wins.
    """
    arch = value.model_arch

    if "SRVGG" in arch:
        return [f"{value.num_feat}nf", f"{value.num_conv}nc"]

    if any(name in arch for name in _SWIN_LIKE_ARCHS):
        return [
            _size_tag(len(value.depths)),  # type: ignore
            f"s{value.img_size}w{value.window_size}",
            f"{value.num_feat}nf",
            f"{value.embed_dim}dim",
            f"{value.resi_connection}",
        ]

    if "DAT" in arch:
        return [
            _size_tag(len(value.depth)),  # type: ignore
            f"s{value.img_size}|{value.split_size[0]}x{value.split_size[1]}",  # type: ignore
            f"{value.num_feat}nf",
            f"{value.embed_dim}dim",
            f"{value.resi_connection}",
        ]

    if "OmniSR" in arch:
        return [
            f"{value.num_feat}nf",
            f"w{value.window_size}",
            f"{value.res_num}nr",
        ]

    if arch in _SIZELESS_ARCHS:
        return []

    # Fallback for the remaining (ESRGAN-style) architectures.
    return [
        f"{value.num_filters}nf",
        f"{value.num_blocks}nb",
    ]
class ModelOutput(BaseOutput):
    """Node output carrying a ``PyTorchModel`` instance.

    On top of the base output, it reports display tags and a refined
    ``PyTorchModel`` navi type, both derived from the concrete model value.
    """

    def __init__(
        self,
        model_type: navi.ExpressionJson = "PyTorchModel",
        label: str = "Model",
        kind: OutputKind = "generic",
    ):
        # associated_type ties this output to the PyTorchModel Python class.
        super().__init__(model_type, label, kind=kind, associated_type=PyTorchModel)

    def get_broadcast_data(self, value: PyTorchModel):
        """Return ``{"tags": [...]}`` for ``value``.

        The tags are the architecture name, then a channel summary from
        ``format_channel_numbers``, then every tag from ``_get_sizes``.
        """
        arch_tag = value.model_arch
        channel_tag = format_channel_numbers(value.in_nc, value.out_nc)
        tags = [arch_tag, channel_tag]
        tags.extend(_get_sizes(value))
        return {"tags": tags}

    def get_broadcast_type(self, value: PyTorchModel):
        """Return a named ``PyTorchModel`` navi type whose fields describe ``value``.

        The ``size`` field is the ``_get_sizes`` tags joined with ``"x"``.
        """
        fields = {
            "scale": value.scale,
            "inputChannels": value.in_nc,
            "outputChannels": value.out_nc,
            "arch": navi.literal(value.model_arch),
            "subType": navi.literal(value.sub_type),
        }
        size_label = "x".join(_get_sizes(value))
        fields["size"] = navi.literal(size_label)
        return navi.named("PyTorchModel", fields)
def TorchScriptOutput():
    """Output a JIT traced model.

    Builds a plain ``BaseOutput`` whose type expression is
    ``"PyTorchScript"`` and whose label is ``"Traced Model"``.
    """
    script_type = "PyTorchScript"
    display_label = "Traced Model"
    return BaseOutput(script_type, display_label)
|