|
|
|
|
|
|
|
import sys |
|
import pathlib |
|
import subprocess |
|
import PIL.Image |
|
import torch |
|
|
|
|
|
REPO_ROOT = pathlib.Path(__file__).parent.resolve()
NOVIC_DIR = REPO_ROOT / 'novic'
NOVIC_TEST = NOVIC_DIR / '__init__.py'


def _ensure_novic_importable() -> None:
    """Make the NOVIC sources available and put them on ``sys.path``.

    If ``NOVIC_DIR/__init__.py`` is missing, run
    ``git submodule update --init --recursive`` in ``REPO_ROOT`` (raising
    ``subprocess.CalledProcessError`` if git fails). Then raise
    ``RuntimeError`` if the file is still absent. Finally, prepend
    ``NOVIC_DIR`` to ``sys.path`` unless it is already listed, so that
    ``import infer`` resolves.
    """
    if not NOVIC_TEST.exists():
        print("Initialising git submodules as NOVIC code was not found yet...")
        git_cmd = ['git', 'submodule', 'update', '--init', '--recursive']
        subprocess.run(git_cmd, cwd=REPO_ROOT, check=True)
        if not NOVIC_TEST.exists():
            raise RuntimeError("Failed to find NOVIC code")
    novic_path = str(NOVIC_DIR)
    if novic_path not in sys.path:
        # Prepend so the submodule's modules take precedence on import.
        sys.path.insert(0, novic_path)


_ensure_novic_importable()
|
|
|
|
|
import infer |
|
|
|
|
|
# Ask the NOVIC helpers to enable TF32 (presumably PyTorch's TensorFloat-32
# matmul/cuDNN flags — see infer.utils.allow_tf32 to confirm).
infer.utils.allow_tf32(enable=True)

# Use the (current) CUDA device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Cache of loaded models keyed by checkpoint string; filled lazily by get_model.
MODELS: dict[str, infer.NOVICModel] = dict()
|
|
|
|
|
def get_model(checkpoint: str) -> infer.NOVICModel:
    """Return the NOVIC model for ``checkpoint``, creating and caching it on first use.

    The model is built with fixed settings:
    - ``gencfg='beam_k10_vnone_gp_t1_a0'``
    - no guide targets
    - ``torch_compile=False``
    - ``batch_size=1``
    - on ``DEVICE``
    - no config/embedder overrides

    Only ``checkpoint`` is used as the cache key in ``MODELS``.

    The model is used as a long-lived context manager, so a ``with`` block is
    not possible. It is entered once by hand, and a matching
    ``__exit__(None, None, None)`` is registered with ``atexit``. That way the
    context is exited on normal interpreter shutdown instead of never.

    Args:
        checkpoint: Checkpoint identifier passed to ``infer.NOVICModel``;
            also the cache key.

    Returns:
        The entered ``infer.NOVICModel`` instance (the cached one on later calls).
    """
    if (model := MODELS.get(checkpoint)) is not None:
        return model

    import atexit  # only needed on a cache miss

    model = infer.NOVICModel(
        checkpoint=checkpoint,
        gencfg='beam_k10_vnone_gp_t1_a0',
        guide_targets=None,
        torch_compile=False,
        batch_size=1,
        device=DEVICE,
        cfg_flat_override=None,
        embedder_override=None,
    )

    # Enter manually (the model outlives any single call) and pair the enter
    # with an exit at shutdown. Registration happens only after __enter__
    # succeeds, so a failed load is neither cached nor exited.
    model.__enter__()
    atexit.register(model.__exit__, None, None, None)

    MODELS[checkpoint] = model
    return model
|
|
|
|
|
def classify_image(image: PIL.Image.Image, checkpoint: str) -> dict[str, float]:
    """Classify ``image`` with the NOVIC model for ``checkpoint``.

    Args:
        image: Image to classify.
        checkpoint: Checkpoint identifier, forwarded to ``get_model``.

    Returns:
        Mapping from each predicted label to its probability, as plain Python
        floats, in first-occurrence order of the labels. If the model emits the
        same label more than once, the highest probability seen for it is kept.
        A bare ``dict(zip(...))`` would silently keep whichever occurrence came
        last instead.

    Raises:
        ValueError: If the prediction and probability sequences for the image
            have different lengths (from ``zip(..., strict=True)``).
    """
    model = get_model(checkpoint=checkpoint)
    output = model.classify_image(image=image)

    # Index 0: presumably one entry per input image, and a single image is
    # passed here — TODO confirm against infer's output type.
    label_probs: dict[str, float] = {}
    for label, prob in zip(output.preds[0], output.probs[0], strict=True):
        # float() normalises tensor/numpy scalars to match the annotation.
        prob = float(prob)
        if label not in label_probs or prob > label_probs[label]:
            label_probs[label] = prob
    return label_probs
|
|
|
|