|
|
|
import os

import numpy as np
import SimpleITK as sitk
from flask import Flask, request, jsonify

from medsam import MedSAM
|
|
|
app = Flask(__name__)

# Path of the MedSAM checkpoint to load. Set the MEDSAM_CHECKPOINT environment
# variable to load weights from a different location without editing the code;
# when it is unset, the original baked-in path is used.
MEDSAM_CHECKPOINT = os.environ.get("MEDSAM_CHECKPOINT", "/app/medsam_vit_b.pth")

# Loaded once at import time; the /segment handler reuses this module-level
# instance for every request.
model = MedSAM.load_from_checkpoint(MEDSAM_CHECKPOINT)
|
|
|
def _to_three_channel(image):
    """Return *image* as an (H, W, 3) array for the model.

    A 2-D grayscale image, or an (H, W, 1) image, has its single channel
    replicated three times. An (H, W, 3) image is returned unchanged.

    Raises:
        ValueError: if *image* has any other shape.
    """
    if image.ndim == 2:
        return np.repeat(image[:, :, None], 3, axis=-1)
    if image.ndim == 3 and image.shape[-1] == 1:
        return np.repeat(image, 3, axis=-1)
    if image.ndim == 3 and image.shape[-1] == 3:
        # Already three-channel; adding another axis here would corrupt the shape.
        return image
    raise ValueError(
        f"'image' must have shape (H, W), (H, W, 1) or (H, W, 3); got {image.shape}"
    )


def _build_medsam_prompt(prompt):
    """Translate client prompt objects into the list passed to ``model.predict``.

    Each client item is ``{"type": "point" | "box", "data": ...}``.
    - Points become ``{'point': data, 'label': 1}``.
    - Boxes become ``{'box': data}``.

    Raises:
        ValueError: if *prompt* is not a list, an item is not an object with a
            ``data`` field, or an item has an unsupported ``type``.
    """
    if not isinstance(prompt, list):
        raise ValueError("'prompt' must be a list of {'type', 'data'} objects")

    medsam_prompt = []
    for p in prompt:
        if not isinstance(p, dict) or 'data' not in p:
            raise ValueError("each prompt must be an object with 'type' and 'data'")
        kind = p.get('type')
        if kind == 'point':
            # Points are always sent as label 1, as in the original handler.
            medsam_prompt.append({'point': p['data'], 'label': 1})
        elif kind == 'box':
            medsam_prompt.append({'box': p['data']})
        else:
            # Previously such items were silently dropped; report them instead.
            raise ValueError(f"unsupported prompt type: {kind!r}")
    return medsam_prompt


@app.route('/segment', methods=['POST'])
def segment():
    """Run MedSAM segmentation for one image and a list of prompts.

    Expects a JSON object body::

        {"image": <nested list: (H, W), (H, W, 1) or (H, W, 3)>,
         "prompt": [{"type": "point" | "box", "data": ...}, ...]}

    Returns:
        ``{"mask": <nested list>}`` built from ``model.predict(...).tolist()``
        on success. On a malformed request it returns ``{"error": <message>}``
        with HTTP status 400.
    """
    # silent=True yields None for a missing or unparsable JSON body, instead of
    # letting the error escape as a server error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400
    if 'image' not in data or 'prompt' not in data:
        return jsonify({"error": "request must include 'image' and 'prompt'"}), 400

    try:
        # np.asarray raises ValueError for ragged (inhomogeneous) nested lists.
        img_3c = _to_three_channel(np.asarray(data['image']))
        medsam_prompt = _build_medsam_prompt(data['prompt'])
    except ValueError as err:
        return jsonify({"error": str(err)}), 400

    mask = model.predict(img_3c, medsam_prompt)

    return jsonify({"mask": mask.tolist()})
|
|
|
if __name__ == "__main__":
    # Run Flask's built-in server when this file is executed directly.
    # 0.0.0.0 binds every network interface rather than only localhost.
    bind_host, bind_port = "0.0.0.0", 7860
    app.run(host=bind_host, port=bind_port)