# Spaces:
# Runtime error
# Runtime error
import sys
import os
import shutil
import subprocess
import urllib.request

# Install dependencies (basicsr is installed explicitly, before CodeFormer's own
# requirements). Failures are reported on stderr but do not abort start-up,
# matching the best-effort behaviour of plain os.system() calls.


def _run(cmd, cwd=None):
    """Run *cmd* (an argv list, no shell) and print a warning if it fails."""
    try:
        returncode = subprocess.run(cmd, cwd=cwd, check=False).returncode
    except OSError as err:  # e.g. executable not found, or cwd missing
        print(f"WARNING: could not run {cmd!r}: {err}", file=sys.stderr)
        return
    if returncode != 0:
        print(f"WARNING: {cmd!r} exited with status {returncode}", file=sys.stderr)


# Use the running interpreter's pip so packages land in the same environment.
_PIP_INSTALL = [sys.executable, '-m', 'pip', 'install']

# Cloning into an existing directory always fails, so only clone once.
if not os.path.isdir('CodeFormer'):
    _run(['git', 'clone', 'https://github.com/sczhou/CodeFormer.git'])
_run(_PIP_INSTALL + ['basicsr', 'facexlib', 'gfpgan'])  # main dependencies first
_run(_PIP_INSTALL + ['-r', 'requirements.txt'], cwd='CodeFormer')
_run(_PIP_INSTALL + ['-e', './CodeFormer'])
_run(_PIP_INSTALL + ['-e', './CodeFormer/basicsr'])

# Add import paths (CodeFormer and its subdirectories). Each insert(0, ...)
# goes to the front, so the last entry inserted is searched first.
sys.path.insert(0, os.path.abspath('.'))
sys.path.insert(0, os.path.abspath('CodeFormer'))
sys.path.insert(0, os.path.abspath('CodeFormer/CodeFormer'))
sys.path.insert(0, os.path.abspath('CodeFormer/basicsr'))

# Write a basicsr/version.py stub (presumably because the cloned basicsr
# package imports a generated version module - TODO confirm). The directory is
# resolved relative to the working directory like the sys.path entries above,
# rather than a hard-coded /home/user/app prefix that only exists on Spaces.
_BASICSR_DIR = os.path.abspath(os.path.join('CodeFormer', 'basicsr'))
os.makedirs(_BASICSR_DIR, exist_ok=True)
with open(os.path.join(_BASICSR_DIR, 'version.py'), 'w', encoding='utf-8') as f:
    f.write("__version__ = '1.0.0'\n__gitsha__ = 'unknown'\n")

# Download weight files.
weights = {
    'realesr-general-x4v3.pth': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
    'GFPGANv1.2.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth',
    'GFPGANv1.3.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
    'GFPGANv1.4.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
    'RestoreFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth',
    'CodeFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth'
}


def _download(url, filename):
    """Fetch *url* into *filename*; on failure warn and leave no file behind.

    Data is written to ``<filename>.part`` and renamed only after the transfer
    completes, so an interrupted download is never mistaken for a complete one.
    """
    partial = filename + '.part'
    try:
        with urllib.request.urlopen(url, timeout=60) as response, open(partial, 'wb') as out:
            shutil.copyfileobj(response, out)
        os.replace(partial, filename)
    except OSError as err:  # URLError/HTTPError/timeouts are OSError subclasses
        print(f"WARNING: failed to download {url}: {err}", file=sys.stderr)
        if os.path.exists(partial):
            os.remove(partial)


for filename, url in weights.items():
    # A zero-byte file is what a failed `wget -O` leaves behind; refetch it.
    if not os.path.exists(filename) or os.path.getsize(filename) == 0:
        _download(url, filename)
import cv2 | |
import torch | |
from flask import Flask, request, jsonify, send_file | |
from basicsr.archs.srvgg_arch import SRVGGNetCompact | |
from gfpgan.utils import GFPGANer | |
from realesrgan.utils import RealESRGANer | |
import uuid | |
import tempfile | |
from torchvision.transforms.functional import normalize | |
from torchvision import transforms | |
from PIL import Image | |
from basicsr.utils import img2tensor, tensor2img | |
from facexlib.utils.face_restoration_helper import FaceRestoreHelper | |
from CodeFormer.codeformer_arch import CodeFormer | |
app = Flask(__name__)

# Model initialization: one shared Real-ESRGAN upsampler built on the compact
# SRVGG network from realesr-general-x4v3.pth (4x upscale).
model_path = 'realesr-general-x4v3.pth'
half = torch.cuda.is_available()  # fp16 inference only when CUDA is present
model = SRVGGNetCompact(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_conv=32,
    upscale=4,
    act_type='prelu',
)
upsampler = RealESRGANer(
    scale=4,
    model_path=model_path,
    model=model,
    tile=0,
    tile_pad=10,
    pre_pad=0,
    half=half,
)

os.makedirs('output', exist_ok=True)
# CodeFormer networks keyed by device string. Loading the checkpoint is slow,
# so each device's network is built once and reused across requests.
_CODEFORMER_NETS = {}


def _get_codeformer(device):
    """Return a cached, eval-mode CodeFormer network on *device*."""
    key = str(device)
    net = _CODEFORMER_NETS.get(key)
    if net is None:
        net = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
                         connect_list=['32', '64', '128', '256']).to(device)
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        checkpoint = torch.load('CodeFormer.pth', map_location=device)
        net.load_state_dict(checkpoint['params_ema'])
        net.eval()
        _CODEFORMER_NETS[key] = net
    return net


def restore_with_codeformer(img, scale=2, weight=0.5):
    """Restore every detected face in *img* with CodeFormer and paste them back.

    Args:
        img: OpenCV-style BGR image (as produced by cv2.imread / cv2.imdecode).
        scale: ``upscale_factor`` given to FaceRestoreHelper; the pasted-back
            result is upscaled by this factor.
        weight: CodeFormer fidelity weight, passed to the network as ``w``.

    Returns:
        The restored image in BGR channel order.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = _get_codeformer(device)

    face_helper = FaceRestoreHelper(
        upscale_factor=scale, face_size=512, crop_ratio=(1, 1), use_parse=True,
        device=device)
    face_helper.clean_all()
    # FaceRestoreHelper reads a BGR numpy array; it cannot take a PIL image,
    # so the input is passed through unchanged.
    face_helper.read_image(img)
    face_helper.get_face_landmarks_5(only_center_face=False, resize=640)
    face_helper.align_warp_face()

    for cropped_face in face_helper.cropped_faces:
        # BGR uint8 crop -> RGB float tensor normalised to [-1, 1], matching the
        # min_max=(-1, 1) used when converting the network output back.
        face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
        normalize(face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        face_t = face_t.unsqueeze(0).to(device)
        with torch.no_grad():
            output = net(face_t, w=weight, adain=True)[0]
        restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
        face_helper.add_restored_face(restored_face)

    # paste_faces_to_input_image() needs one inverse affine matrix per
    # restored face; they are only computed by get_inverse_affine().
    face_helper.get_inverse_affine(None)
    # The whole pipeline stays in BGR, so no final colour conversion is needed.
    return face_helper.paste_faces_to_input_image()
# (model weights, GFPGANer arch) for each GFPGAN-style option offered by the UI.
_GFPGAN_MODELS = {
    'v1.2': ('GFPGANv1.2.pth', 'clean'),
    'v1.3': ('GFPGANv1.3.pth', 'clean'),
    'v1.4': ('GFPGANv1.4.pth', 'clean'),
    'RestoreFormer': ('RestoreFormer.pth', 'RestoreFormer'),
}
# Every value accepted for the ``version`` form field.
_SUPPORTED_VERSIONS = (*_GFPGAN_MODELS, 'CodeFormer', 'RealESR-General-x4v3')


def _enhance(img, version, scale, weight):
    """Run the model selected by *version* on the BGR(A) image *img*.

    GFPGAN-style models and the plain Real-ESRGAN option return a 2x image;
    CodeFormer returns an image upscaled by *scale*.
    """
    if version in _GFPGAN_MODELS:
        model_path, arch = _GFPGAN_MODELS[version]
        face_enhancer = GFPGANer(
            model_path=model_path, upscale=2, arch=arch, channel_multiplier=2, bg_upsampler=upsampler)
        _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
        return output
    if version == 'CodeFormer':
        return restore_with_codeformer(img, scale=scale, weight=weight)
    if version == 'RealESR-General-x4v3':
        # This is a general upscaler, not a face-restoration model, and GFPGANer
        # has no 'realesr-general' arch. The shared `upsampler` already holds
        # these weights, so use it directly.
        output, _ = upsampler.enhance(img, outscale=2)
        return output
    raise ValueError(f'Unsupported version: {version}')


def _resize_to_scale(output, orig_w, orig_h, scale):
    """Resize *output* to exactly ``scale`` times the uploaded image's size."""
    target_w = max(1, int(orig_w * scale))
    target_h = max(1, int(orig_h * scale))
    out_h, out_w = output.shape[:2]
    if (out_w, out_h) == (target_w, target_h):
        return output
    interpolation = cv2.INTER_AREA if target_w < out_w else cv2.INTER_LANCZOS4
    return cv2.resize(output, (target_w, target_h), interpolation=interpolation)


@app.route('/api/restore', methods=['POST'])
def restore_image():
    """Restore/upscale an uploaded image and return it as a file download.

    Form fields:
        file: the image to process (required).
        version: one of ``_SUPPORTED_VERSIONS`` (default ``'v1.4'``).
        scale: final size relative to the uploaded image (default 2).
        weight: CodeFormer fidelity weight (default 0.5; CodeFormer only).

    Returns a PNG when the upload has an alpha channel, otherwise a JPEG.
    Invalid input yields ``{'error': ...}`` with status 400; unexpected
    failures yield ``{'error': ...}`` with status 500.
    """
    import io
    import math

    import numpy as np

    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file uploaded'}), 400
        file = request.files['file']
        version = request.form.get('version', 'v1.4')
        if version not in _SUPPORTED_VERSIONS:
            return jsonify({'error': f'Unsupported version: {version}'}), 400
        try:
            scale = float(request.form.get('scale', 2))
            weight = float(request.form.get('weight', 0.5))
        except ValueError:
            return jsonify({'error': 'scale and weight must be numbers'}), 400
        if not math.isfinite(scale) or scale <= 0:
            return jsonify({'error': 'scale must be a positive number'}), 400

        # Decode straight from the upload. No temp file means the untrusted
        # client filename is never used as a filesystem path.
        data = np.frombuffer(file.read(), dtype=np.uint8)
        img = cv2.imdecode(data, cv2.IMREAD_UNCHANGED) if data.size else None
        if img is None:
            return jsonify({'error': 'Uploaded file is not a readable image'}), 400

        img_mode = 'RGBA' if img.ndim == 3 and img.shape[2] == 4 else None
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        orig_h, orig_w = img.shape[0:2]
        if orig_h < 300:
            # Small inputs are enlarged before processing (presumably so the
            # face detector has enough pixels to work with).
            img = cv2.resize(img, (orig_w * 2, orig_h * 2), interpolation=cv2.INTER_LANCZOS4)

        output = _enhance(img, version, scale, weight)
        # Models return different sizes (2x, or `scale`x of the possibly
        # enlarged input); normalise every result to `scale` x the upload.
        output = _resize_to_scale(output, orig_w, orig_h, scale)

        if img_mode == 'RGBA':
            ext, mimetype = '.png', 'image/png'
        else:
            ext, mimetype = '.jpg', 'image/jpeg'
        ok, encoded = cv2.imencode(ext, output)
        if not ok:
            raise RuntimeError('Failed to encode the output image')
        # Served from memory so results do not accumulate on disk.
        download_name = f'output_{uuid.uuid4().hex}{ext}'
        return send_file(io.BytesIO(encoded.tobytes()), mimetype=mimetype,
                         as_attachment=True, download_name=download_name)
    except Exception as e:
        app.logger.exception('Image restoration failed')
        return jsonify({'error': str(e)}), 500
@app.route('/')
def index():
    """Serve the single-page HTML front end for the image restoration API.

    The page posts the chosen image, version, scale and (for CodeFormer)
    weight to ``/api/restore`` and shows a JavaScript ``fetch`` example of the
    same request.
    """
    return """
<!DOCTYPE html>
<html>
<head>
<title>Image Upscaling & Restoration API</title>
<style>
body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
.container { border: 1px solid #ddd; padding: 20px; border-radius: 5px; }
.form-group { margin-bottom: 15px; }
label { display: block; margin-bottom: 5px; }
input, select { width: 100%; padding: 8px; box-sizing: border-box; }
button { background-color: #4CAF50; color: white; padding: 10px 15px; border: none; border-radius: 4px; cursor: pointer; }
button:hover { background-color: #45a049; }
#result { margin-top: 20px; }
#preview { max-width: 100%; margin-top: 10px; }
#apiUsage { background-color: #f5f5f5; padding: 15px; border-radius: 5px; margin-top: 20px; font-family: monospace; white-space: pre-wrap; }
#apiUsage h3 { margin-top: 0; }
#formDataPreview { max-height: 200px; overflow-y: auto; margin-bottom: 10px; }
.code-block { background-color: #f8f8f8; padding: 10px; border-radius: 4px; border-left: 3px solid #4CAF50; }
.comment { color: #666; font-style: italic; }
.loader {
width: 48px;
height: 48px;
border: 5px solid #4CAF50;
border-bottom-color: transparent;
border-radius: 50%;
box-sizing: border-box;
animation: rotation 1s linear infinite;
margin: 20px auto;
display: none; /* shown as a block by the submit handler */
}
@keyframes rotation {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
</style>
</head>
<body>
<h1>Image Upscaling & Restoration API</h1>
<div class="container">
<form id="uploadForm" enctype="multipart/form-data">
<div class="form-group">
<label for="file">Upload Image:</label>
<input type="file" id="file" name="file" required>
</div>
<div class="form-group">
<label for="version">Version:</label>
<select id="version" name="version">
<option value="v1.2">GFPGANv1.2</option>
<option value="v1.3">GFPGANv1.3</option>
<option value="v1.4" selected>GFPGANv1.4</option>
<option value="RestoreFormer">RestoreFormer</option>
<option value="CodeFormer">CodeFormer</option>
<option value="RealESR-General-x4v3">RealESR-General-x4v3</option>
</select>
</div>
<div class="form-group">
<label for="scale">Rescaling factor:</label>
<input type="number" id="scale" name="scale" value="2" step="0.1" min="1" max="4" required>
</div>
<div class="form-group" id="weightGroup" style="display: none;">
<label for="weight">CodeFormer Weight (0-1):</label>
<input type="number" id="weight" name="weight" value="0.5" step="0.1" min="0" max="1">
</div>
<button type="submit" id="submitButton">Process Image</button>
</form>
<div id="loading" class="loader"></div>
<div id="result">
<h3>Result:</h3>
<div id="outputContainer" style="display: none;">
<img id="preview" src="" alt="Processed Image">
<a id="downloadLink" href="#" download>Download Image</a>
</div>
</div>
<div id="apiUsage">
<h3>API Usage:</h3>
<div id="fetchCode" class="code-block">
// JavaScript fetch code will appear here
</div>
</div>
</div>
<script>
// Only CodeFormer uses the weight field. The API-usage refresh for this
// select is done by the separate 'change' listener registered below.
document.getElementById('version').addEventListener('change', function() {
    const weightGroup = document.getElementById('weightGroup');
    weightGroup.style.display = this.value === 'CodeFormer' ? 'block' : 'none';
});
function updateApiUsage() {
    const fileInput = document.getElementById('file');
    const version = document.getElementById('version').value;
    const scale = document.getElementById('scale').value;
    const weight = document.getElementById('weight').value;
    const apiUrl = window.location.origin + '/api/restore';
    if (fileInput.files.length === 0) {
        updateFetchCode(apiUrl, version, scale, weight, '"img-dataURL"');
        return;
    }
    const reader = new FileReader();
    reader.onload = function(e) {
        const dataURL = e.target.result;
        // Backtick template literals, so the ${...} placeholders are interpolated.
        const filePreview = dataURL.length > 40
            ? `"${dataURL.substring(0, 20)}...${dataURL.substring(dataURL.length - 20)}"`
            : `"${dataURL}"`;
        updateFetchCode(apiUrl, version, scale, weight, filePreview);
    };
    reader.readAsDataURL(fileInput.files[0]);
}
function updateFetchCode(apiUrl, version, scale, weight, filePreview) {
    const fetchCodeDiv = document.getElementById('fetchCode');
    let code = `// JavaScript fetch example:
const formData = new FormData();
formData.append('file', ${filePreview});
formData.append('version', '${version}');
formData.append('scale', ${scale});`;
    if (version === 'CodeFormer') {
        code += `
formData.append('weight', ${weight});`;
    }
    code += `
fetch('${apiUrl}', {
method: 'POST',
body: formData
})
.then(response => {
if (!response.ok) {
return response.json().then(err => { throw new Error(err.error || 'Unknown error'); });
}
return response.blob();
})
.then(blob => {
const url = URL.createObjectURL(blob);
console.log('Image processed successfully', url);
})
.catch(error => {
console.error('Error:', error.message);
});`;
    // Plain text, not HTML: the sample embeds user-controlled values.
    fetchCodeDiv.textContent = code;
}
document.getElementById('file').addEventListener('change', updateApiUsage);
document.getElementById('version').addEventListener('change', updateApiUsage);
document.getElementById('scale').addEventListener('input', updateApiUsage);
document.getElementById('weight').addEventListener('input', updateApiUsage);
updateApiUsage();
// Object URL of the last result. It is revoked before the next result is
// shown so repeated submissions do not keep every result blob alive.
let currentObjectUrl = null;
document.getElementById('uploadForm').addEventListener('submit', function(e) {
    e.preventDefault();
    const submitButton = document.getElementById('submitButton');
    const loadingElement = document.getElementById('loading');
    const inputFile = document.getElementById('file').files[0];
    submitButton.disabled = true;
    loadingElement.style.display = 'block';
    const formData = new FormData();
    formData.append('file', inputFile);
    formData.append('version', document.getElementById('version').value);
    formData.append('scale', document.getElementById('scale').value);
    if (document.getElementById('version').value === 'CodeFormer') {
        formData.append('weight', document.getElementById('weight').value);
    }
    const apiUrl = window.location.origin + '/api/restore';
    fetch(apiUrl, {
        method: 'POST',
        body: formData
    })
    .then(response => {
        if (!response.ok) {
            return response.json().then(err => { throw new Error(err.error || 'Unknown error'); });
        }
        return response.blob();
    })
    .then(blob => {
        if (currentObjectUrl) {
            URL.revokeObjectURL(currentObjectUrl);
        }
        currentObjectUrl = URL.createObjectURL(blob);
        const preview = document.getElementById('preview');
        const downloadLink = document.getElementById('downloadLink');
        const outputContainer = document.getElementById('outputContainer');
        preview.src = currentObjectUrl;
        downloadLink.href = currentObjectUrl;
        // The server returns PNG or JPEG, so name the download after the
        // returned type rather than the uploaded file's extension.
        const name = inputFile.name;
        const dot = name.lastIndexOf('.');
        const stem = dot > 0 ? name.substring(0, dot) : name;
        downloadLink.download = 'restored_' + stem + (blob.type === 'image/png' ? '.png' : '.jpg');
        outputContainer.style.display = 'block';
    })
    .catch(error => {
        alert('Error: ' + error.message);
    })
    .finally(() => {
        loadingElement.style.display = 'none';
        submitButton.disabled = false;
    });
});
</script>
</body>
</html>
"""
if __name__ == '__main__':
    # The server listens on all interfaces, so the Werkzeug debugger (which
    # allows arbitrary code execution) is off unless FLASK_DEBUG is set
    # explicitly. Debug mode also enables the reloader, which re-executes this
    # script and repeats all of the module-level install/download work.
    app.run(
        host='0.0.0.0',
        port=int(os.environ.get('PORT', 7860)),
        debug=os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes'),
    )