File size: 849 Bytes
f499d3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
from torch.optim import Optimizer
from lightning.pytorch import LightningModule
from lightning.pytorch.callbacks import BasePredictionWriter

from .ar import ARSystem, ARWriter
from .skin import SkinSystem, SkinWriter

def get_system(**kwargs) -> LightningModule:
    """Instantiate the system class selected by the ``__target__`` keyword.

    Args:
        **kwargs: Must contain ``__target__`` (``'ar'`` or ``'skin'``). Every
            other keyword argument is forwarded unchanged to the constructor
            of the selected class.

    Returns:
        The object produced by calling the selected class (``ARSystem`` or
        ``SkinSystem``) with the remaining keyword arguments.

    Raises:
        KeyError: If ``__target__`` is not supplied.
        AssertionError: If ``__target__`` is not one of the registered names.
    """
    MAP = {
        'ar': ARSystem,
        'skin': SkinSystem,
    }
    # pop() reads the selector and keeps it out of the constructor kwargs.
    target = kwargs.pop('__target__')
    # Raised explicitly rather than via `assert` so the check still runs under
    # `python -O`; the exception type and message stay the same for callers.
    if target not in MAP:
        raise AssertionError(
            f"expect: [{','.join(MAP.keys())}], found: {target}"
        )
    return MAP[target](**kwargs)

def get_writer(**kwargs) -> BasePredictionWriter:
    """Instantiate the prediction writer selected by the ``__target__`` keyword.

    Args:
        **kwargs: Must contain ``__target__`` (``'ar'`` or ``'skin'``). Every
            other keyword argument is forwarded unchanged to the constructor
            of the selected class.

    Returns:
        The object produced by calling the selected class (``ARWriter`` or
        ``SkinWriter``) with the remaining keyword arguments.

    Raises:
        KeyError: If ``__target__`` is not supplied.
        AssertionError: If ``__target__`` is not one of the registered names.
    """
    MAP = {
        'ar': ARWriter,
        'skin': SkinWriter,
    }
    # pop() reads the selector and keeps it out of the constructor kwargs.
    target = kwargs.pop('__target__')
    # Raised explicitly rather than via `assert` so the check still runs under
    # `python -O`; the exception type and message stay the same for callers.
    if target not in MAP:
        raise AssertionError(
            f"expect: [{','.join(MAP.keys())}], found: {target}"
        )
    return MAP[target](**kwargs)