# Author: mooki0
# Initial commit of Gradio app
# Commit: 57276d4 (verified)
import os
import cv2
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
# build sr model
def build_sr_model(scale=2, model_name=None, tile=0, tile_pad=10, pre_pad=0, fp32=False, gpu_id=None):
    """Construct a ``RealESRGANer`` upsampler for one of the known RRDBNet models.

    Args:
        scale: Only consulted when ``model_name`` is None. A value of 2 selects
            'RealESRGAN_x2plus'; every other value selects 'RealESRGAN_x4plus'.
        model_name: Explicit model key: 'RealESRGAN_x2plus' or 'RealESRGAN_x4plus'.
        tile: Forwarded unchanged to ``RealESRGANer(tile=...)``.
        tile_pad: Forwarded unchanged to ``RealESRGANer(tile_pad=...)``.
        pre_pad: Forwarded unchanged to ``RealESRGANer(pre_pad=...)``.
        fp32: When False (the default), ``half=True`` is passed to ``RealESRGANer``.
        gpu_id: Forwarded unchanged to ``RealESRGANer(gpu_id=...)``.

    Weights are expected at ``<directory of this file>/weights/<model_name>.pth``.
    When that file is missing, they are fetched from the model's release URL
    into that same directory.

    Returns:
        The configured ``RealESRGANer`` instance.

    Raises:
        ValueError: ``model_name`` is not one of the known models.
    """
    if model_name is None:
        model_name = 'RealESRGAN_x2plus' if scale == 2 else 'RealESRGAN_x4plus'

    # Networks are built lazily (via lambdas) so only the selected one is instantiated.
    registry = {
        'RealESRGAN_x2plus': {
            'internal_scale': 2,
            'model': lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2),
            'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
        },
        'RealESRGAN_x4plus': {
            'internal_scale': 4,
            'model': lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
            'url': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
        },
    }

    spec = registry.get(model_name)
    if spec is None:
        raise ValueError(
            f'Unknown model name: {model_name}. Available models: {list(registry)}')

    network = spec['model']()

    weights_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'weights')
    model_path = os.path.join(weights_dir, model_name + '.pth')
    if not os.path.isfile(model_path):
        # load_file_from_url hands back the location of the downloaded file.
        model_path = load_file_from_url(
            url=spec['url'], model_dir=weights_dir, progress=True, file_name=None)

    # The restorer is told the network's native scale; the final output scale is
    # chosen later via enhance(outscale=...).
    return RealESRGANer(
        scale=spec['internal_scale'],
        model_path=model_path,
        dni_weight=None,
        model=network,
        tile=tile,
        tile_pad=tile_pad,
        pre_pad=pre_pad,
        half=not fp32,
        gpu_id=gpu_id,
    )
# sr inference code
def sr_inference(input, output_path, upsampler, scale=2, ext='auto', suffix='sr'):
    """Super-resolve one image file with horizontal wrap padding and save the result.

    Before inference, the image is padded left and right by a quarter of its
    width using ``cv2.BORDER_WRAP``, so each side continues from the opposite
    edge. The padded columns are cropped off the enhanced output afterwards.
    This is intended to eliminate border artifacts, presumably for images whose
    left and right edges meet (e.g. 360-degree panoramas). TODO confirm with caller.

    Args:
        input: Path of the image to read with ``cv2.IMREAD_UNCHANGED``.
        output_path: Directory to write into; created if it does not exist.
        upsampler: Object whose ``enhance(img, outscale=...)`` returns a
            ``(output, mode)`` pair, e.g. the result of ``build_sr_model``.
        scale: Final output scale factor, passed as ``outscale``.
        ext: Output extension without a dot, or 'auto' to reuse the input's
            extension ('png' if the input has none). 4-channel images are
            always written as png.
        suffix: Appended to the file stem as ``_<suffix>``; '' adds nothing.

    Returns:
        The path of the written file. Returns None if ``enhance`` raised
        ``RuntimeError`` or ``cv2.imwrite`` reported failure (both are printed).

    Raises:
        ValueError: The input file could not be read as an image.
    """
    os.makedirs(output_path, exist_ok=True)
    path = input
    imgname, extension = os.path.splitext(os.path.basename(path))

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread reports missing/undecodable files by returning None, not by raising.
        raise ValueError(f'Could not read image: {path}')

    width = img.shape[1]
    # Wrap-pad horizontally so the model sees continuous content at the
    # left/right edges, eliminating border artifacts there.
    pad_len = width // 4
    img = cv2.copyMakeBorder(img, 0, 0, pad_len, pad_len, cv2.BORDER_WRAP)
    # A 4-channel image carries alpha; it must be saved in a format that keeps it.
    img_mode = 'RGBA' if len(img.shape) == 3 and img.shape[2] == 4 else None

    try:
        # Use the requested scale as the final output amplification factor.
        output, _ = upsampler.enhance(img, outscale=scale)
    except RuntimeError as error:
        print('Error', error)
        print(
            'If you encounter CUDA out of memory, try to set --tile with a smaller number.')
        return None

    # Remove the padding. Only the width axis is sliced so that a 2-D output
    # (e.g. a grayscale image) works too: a third ':' index would raise IndexError.
    output = output[:, int(pad_len * scale):int((width + pad_len) * scale)]

    extension = extension[1:] if ext == 'auto' else ext
    if img_mode == 'RGBA':  # RGBA images should be saved in png format
        extension = 'png'
    if not extension:
        # Input had no extension; cv2.imwrite needs one to pick an encoder.
        extension = 'png'

    if suffix == '':
        save_path = os.path.join(output_path, f'{imgname}.{extension}')
    else:
        save_path = os.path.join(output_path, f'{imgname}_{suffix}.{extension}')

    # cv2.imwrite signals failure through its return value rather than raising.
    if not cv2.imwrite(save_path, output):
        print('Error', f'failed to write {save_path}')
        return None
    return save_path