# Hugging Face Spaces status at time of capture: "Build error"
from datasets.arrow_dataset import InMemoryTable | |
import streamlit as st | |
from PIL import Image, ImageDraw | |
from streamlit_image_coordinates import streamlit_image_coordinates | |
import numpy as np | |
from datasets import load_dataset | |
st.set_page_config(layout="wide")
ds = load_dataset("Circularmachines/batch_indexing_machine_green_test", split="test")

# Patch-grid geometry: square patches of `patch_size` pixels taken every
# `stride` pixels over a frame resized to `image_size` x `image_size`
# (an earlier configuration used image_size=2304).
patch_size = 32
stride = 16
image_size = 512
# Derived rather than hard-coded so the grid always agrees with the constants
# above: (512 - 32) // 16 + 1 = 31 patches per side, 31 ** 2 = 961 per image.
gridsize = (image_size - patch_size) // stride + 1
n_patches = gridsize ** 2

# Precomputed patch embeddings per model, reshaped to one 64-value row per
# patch. find() and patch() index rows as image * n_patches + patch; presumably
# the .npy files were exported in that order -- confirm against the export script.
pred_dict = {
    "Trained on augmented images 230809": np.load("pred_all_green_random.npy").reshape(-1, 64),
    "Trained on unaugmented images 230805": np.load("pred_all_scratch.npy").reshape(-1, 64),
}
# NOTE(review): random_i is loaded but not read by any live code in this file.
random_i = np.load("random.npy")
# Per-session defaults, applied only the first time a session runs: the target
# patch (top-left pixel corner) together with the initially selected model,
# then the frame index and whether the target box is outlined.
if "point" not in st.session_state:
    first_model = next(iter(pred_dict))
    st.session_state["point"] = (128, 64)
    st.session_state["model"] = first_model
for _state_key, _state_default in (("img", 0), ("draw", True)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
def patch(ij):
    """Return a crop of the dataset frame around the patch with flat index ``ij``.

    ``ij // n_patches`` selects the dataset image and ``ij % n_patches`` the
    patch, laid out row-major on a ``gridsize`` x ``gridsize`` grid with
    ``stride`` pixels between neighbouring patches.  The crop is ``4 * stride``
    pixels square and starts one stride above and to the left of the patch's
    top-left corner; PIL pads any part of the box outside the frame.

    Args:
        ij: flat patch index (Python or NumPy integer).

    Returns:
        A ``PIL.Image`` crop.
    """
    image_idx, patch_idx = divmod(int(ij), n_patches)
    row, col = divmod(patch_idx, gridsize)
    # Resize with the shared image_size constant so crops line up with the grid
    # used to compute the embeddings and with the main view.
    frame = ds[image_idx]["image"].resize(size=(image_size, image_size))
    left, top = (col - 1) * stride, (row - 1) * stride
    return frame.crop((left, top, left + 4 * stride, top + 4 * stride))
def find():
    """Find the patches most similar to the selected target, one per batch.

    Reads the current frame (``img``), target point (``point``, top-left pixel
    corner of a patch) and model name (``model``) from ``st.session_state``.
    Ranks every patch by Euclidean distance between its embedding and the
    target's. It then walks that ranking and keeps the first patch from each
    distinct batch until four batches are collected. The crops are written into
    ``st.session_state["sideimg"]`` and the batch numbers into
    ``st.session_state["sideix"]``.  Slot 0 is normally the target itself,
    because its own distance is zero.
    """
    n_matches = 4          # one slot per side panel in the UI
    frames_per_batch = 20  # consecutive dataset images that belong to one batch

    preds = pred_dict[st.session_state["model"]]

    # Pixel point -> grid indices, clamped to the valid grid so an edge point
    # cannot wrap into the neighbouring row (or the neighbouring image).
    px, py = st.session_state["point"]
    gx = min(max(px // stride, 0), gridsize - 1)
    gy = min(max(py // stride, 0), gridsize - 1)
    query = st.session_state["img"] * n_patches + gy * gridsize + gx

    dist = np.linalg.norm(preds - preds[query], axis=-1)
    # Sort once; the original re-sorted the full array on every iteration.
    order = np.argsort(dist)

    batches = []
    for flat_idx in order:
        batch = int(flat_idx) // n_patches // frames_per_batch
        if batch in batches:
            continue
        st.session_state["sideimg"][len(batches)] = patch(flat_idx)
        batches.append(batch)
        if len(batches) == n_matches:
            break
    st.session_state["sideix"] = batches
def button_click():
    """Callback for "Change Image": switch to a random frame and hide the target box.

    The box stays hidden until a different point is clicked on the new frame.
    """
    # NOTE(review): draws only from indices 0-99; confirm this is the intended
    # range rather than the full dataset length.
    chosen = np.random.randint(100)
    state = st.session_state
    state["img"] = chosen
    state["draw"] = False
# First run of a session: seed the four side panels with placeholder crops,
# then fill them with matches for the default target.
if "sideimg" not in st.session_state:
    st.session_state["sideimg"] = list(map(patch, range(4)))
if "sideix" not in st.session_state:
    find()
def get_ellipse_coords(point, size=None):
    """Return the PIL bounding box ``(x0, y0, x1, y1)`` for the patch at ``point``.

    Despite the name, the box is used to draw a rectangle, and ``point`` is the
    patch's top-left corner, not its centre.

    Args:
        point: ``(x, y)`` pixel coordinates of the patch's top-left corner.
        size: side length of the box in pixels; defaults to the module-level
            ``patch_size`` when omitted.

    Returns:
        ``(x, y, x + size, y + size)``.
    """
    if size is None:
        size = patch_size
    x, y = point[0], point[1]
    return (x, y, x + size, y + size)
col1, col2, col3 = st.columns([3, 1, 1])

with col1:
    # Main view: the current frame at the working resolution, with the target
    # patch outlined while `draw` is set.
    current_image = ds[st.session_state["img"]]["image"].resize(size=(image_size, image_size))
    draw = ImageDraw.Draw(current_image)
    if st.session_state["draw"]:
        coords = get_ellipse_coords(st.session_state["point"])
        draw.rectangle(coords, outline="green", width=2)

    value = streamlit_image_coordinates(current_image, key="pil")
    if value is not None:
        # Snap the click to the patch whose centre is nearest (shift by half a
        # stride, round down to the stride), then clamp to the valid grid. This
        # stops clicks near the border from producing a point at -stride or
        # gridsize * stride, which would select a wrapped patch.
        half = stride // 2
        max_offset = (gridsize - 1) * stride
        point = tuple(
            min(max((value[axis] - half) // stride * stride, 0), max_offset)
            for axis in ("x", "y")
        )
        if point != st.session_state["point"]:
            st.session_state["point"] = point
            st.session_state["draw"] = True
            # st.experimental_rerun was superseded by st.rerun in newer
            # Streamlit releases; prefer the new name when it exists.
            rerun = getattr(st, "rerun", None) or st.experimental_rerun
            rerun()

    scol1, scol2 = st.columns(2)
    with scol1:
        st.button("Change Image", on_click=button_click)
        st.selectbox("Model", tuple(pred_dict.keys()), key="model")
    with scol2:
        st.button("Find similar parts", on_click=find)

# Side panels: slot 0 is labelled as the target, slots 1-3 as matches.
# Even slots go in the middle column, odd slots in the right column.
for column, slots in ((col2, (0, 2)), (col3, (1, 3))):
    with column:
        for slot in slots:
            label = "Target" if slot == 0 else "Match #" + str(slot)
            st.write(label + " in batch " + str(st.session_state["sideix"][slot]))
            st.image(st.session_state["sideimg"][slot].resize((192, 192)))
"The batch indexing machine shakes parts while recording a video." | |
"The machine processed 20 batches of random parts, with each batch running for 30 seconds." | |
"INSCTRUCTIONS:" | |
"Click in the image to set target part" | |
"Click “Find similar parts” to find the best matches in other batches" | |
"The model is trained completely unsupervised using a CNN with a custom contrastive loss." | |
"https://github.com/circularmachines/batch_indexing_machine/" | |
"johan.lagerloef@gmail.com" | |