Spaces:
Sleeping
Sleeping
| import os | |
| import streamlit as st | |
| import zipfile | |
| import torch | |
| from utils import * | |
| import matplotlib.pyplot as plt | |
| from matplotlib import colors | |
# Attach a `paths` attribute to the imported streamlit module the first time the
# script runs; later runs in the same process see the value stored there.
# NOTE(review): module objects are process-wide singletons, so this state is
# shared by every browser session served by this process; st.session_state is
# the per-session alternative.
if getattr(st, 'paths', None) is None:
    st.paths = None
# Load Model
# @title Load pretrained weights

best_model_daily_file_name = "best_model_daily.pth"
best_model_annual_file_name = "best_model_annual.pth"


def _wrap_for_device(module):
    """Return *module* wrapped in ``torch.nn.DataParallel``.

    The wrapper is moved to the GPU when CUDA is available, otherwise to CPU.
    """
    wrapped = torch.nn.DataParallel(module)
    if torch.cuda.is_available():
        return wrapped.cuda()
    return wrapped.cpu()


# Dummy input handed to the FPN constructor; its shape is hard-coded here.
first_input_batch = torch.zeros(71, 9, 5, 48, 48)
daily_model = FPN(opt, first_input_batch, opt.win_size)
annual_model = SimpleNN(opt)

# Each model is wrapped in DataParallel twice, in the same interleaved order
# (daily, annual, daily, annual).
# NOTE(review): the double wrap looks deliberate -- presumably so parameter
# names match the saved checkpoints -- confirm before simplifying it.
for _wrap_round in range(2):
    daily_model = _wrap_for_device(daily_model)
    annual_model = _wrap_for_device(annual_model)

print('trying to resume previous saved models...')
# Restore daily weights first, then annual; `state` keeps the last (annual) result.
for _ckpt_name, _target_model in (
        (best_model_daily_file_name, daily_model),
        (best_model_annual_file_name, annual_model)):
    state = resume(
        os.path.join(opt.resume_path, _ckpt_name),
        model=_target_model, optimizer=None)

# Put both networks in evaluation mode.
daily_model = daily_model.eval()
annual_model = annual_model.eval()
# Page header plus input selection: an uploaded zip or the bundled demo sample.
# Leaves `sample_path` set to the directory to process, or None.
st.title('Lombardia Sentinel 2 Crop Mapping')
st.markdown('Using a daily FPN and giving a zip that contains 30 correctly named TIFF files, '
            'you can get a crop-mapping prediction of the area.')
file_uploaded = st.file_uploader(
    "Upload",
    type=["zip"],
    accept_multiple_files=False,
)
sample_path = None
if file_uploaded is not None:
    # zipfile strips absolute paths and ".." components from member names
    # when extracting, so members stay under "uploaded_samples".
    with zipfile.ZipFile(file_uploaded, "r") as z:
        z.extractall("uploaded_samples")
    # NOTE(review): assumes the archive contains a top-level folder named
    # after the zip file itself (without its extension) -- confirm.
    sample_path = os.path.join(
        "uploaded_samples", os.path.splitext(file_uploaded.name)[0])
st.markdown('or use a demo sample')
# The demo button wins over an upload when both are present.
if st.button('sample 1'):
    sample_path = 'demo_data/lombardia'
# NOTE(review): not read anywhere in this file; the viewer uses st.paths.
paths = None
if sample_path is not None:
    # Run the daily and annual models over every patch of the selected sample.
    # Writes one annual map ('patch-pred-nn.tif') plus one map per date as
    # GeoTIFFs under <opt.result_path>/seg_maps/.
    st.markdown(f'elaborating {sample_path}...')
    validationdataset = SentinelDailyAnnualDatasetNoLabel(
        sample_path,
        opt.years,
        opt.classes_path,
        opt.sample_duration,
        opt.win_size,
        tileids=None)
    validationdataloader = torch.utils.data.DataLoader(
        validationdataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.workers)
    st.markdown('predict in progress...')
    out_dir = os.path.join(opt.result_path, "seg_maps")
    os.makedirs(out_dir, exist_ok=True)
    for x_dailies, _loader_dates, dirs_path in validationdataloader:
        with torch.no_grad():
            # Remember the real leading dims before merging them. The last
            # loader batch can hold fewer than opt.batch_size samples, and
            # viewing back with opt.batch_size would raise on it.
            n_samples, n_steps = x_dailies.shape[:2]
            # Merge the first two dims so the daily model sees one flat batch.
            x_dailies = x_dailies.view(-1, *x_dailies.shape[2:])
            if torch.cuda.is_available():
                x_dailies = x_dailies.cuda()
            feat_daily, outs_daily = daily_model.forward(x_dailies)
            # Restore the (sample, step) split before the annual model.
            outs_daily = outs_daily.view(n_samples, n_steps, *outs_daily.shape[1:])
            feat_daily = feat_daily.view(n_samples, n_steps, *feat_daily.shape[1:])
            _, out_annual = annual_model.forward(feat_daily)
            pred_annual = torch.argmax(out_annual, dim=1).squeeze(1)
            pred_annual = pred_annual.cpu().numpy()
            # Remap the predicted class ids to dataset labels.
            pred_annual_nn = ids_to_labels(
                validationdataloader, pred_annual).astype(numpy.uint8)

            # The date list and the GeoTIFF profile template always come from
            # the dataset's sample 0 and its last image, independent of the
            # sample index, so compute them once per loader batch.
            # `tif_dates` holds plain date strings (".tif" is appended below).
            # It must not clobber the loader's dates.
            _, tmp_path = get_patch_id(validationdataset.samples, 0)
            tif_dates = get_all_dates(tmp_path, validationdataset.max_seq_length)
            last_tif_path = os.path.join(tmp_path, tif_dates[-1] + ".tif")
            _, profile = read(last_tif_path)
            # Single-band uint8 output; constant for every file written below.
            profile.update({
                'nodata': None,
                'dtype': 'uint8',
                'count': 1})

            for batch in range(n_samples):
                profile["name"] = dirs_path[batch]
                # Output folder "<parent>-<grandparent>/<dir>" built from the
                # last three components of the sample's path.
                pth = dirs_path[batch].split(os.path.sep)[-3:]
                full_pth_patch = os.path.join(
                    out_dir, pth[1] + '-' + pth[0], pth[2])
                os.makedirs(full_pth_patch, exist_ok=True)
                full_pth_pred = os.path.join(full_pth_patch, 'patch-pred-nn.tif')
                with rasterio.open(full_pth_pred, 'w', **profile) as dst:
                    dst.write_band(1, pred_annual_nn[batch])
                for ch, tif_date in enumerate(tif_dates):
                    soft_seg = outs_daily[batch, ch, :, :, :]
                    # Hard segmentation: most probable class along dim 0.
                    pred_daily = torch.argmax(soft_seg, dim=0).cpu()
                    daily_pred = ids_to_labels(
                        validationdataloader, pred_daily).astype(numpy.uint8)
                    # Previously dates[ch][batch], which took a single
                    # character of the date string for the filename. It also
                    # raised IndexError once batch exceeded the string length.
                    full_pth_date = os.path.join(
                        full_pth_patch, tif_date + f'-ch{ch}-b{batch}-daily-pred.tif')
                    with rasterio.open(full_pth_date, 'w', **profile) as dst:
                        dst.write_band(1, daily_pred)
    st.markdown('End prediction')
    # NOTE(review): the viewer lists a fixed demo folder, not the folders
    # written above under out_dir. Results for an uploaded sample may
    # therefore not be what is shown -- confirm intended behaviour.
    folder = "demo_data/results/seg_maps/example-lombardia/2"
    st.paths = os.listdir(folder)
if st.paths:
    # Display one saved prediction map (picked by the user) beside a colour
    # legend. Truthiness also skips an empty listing; with an empty listing
    # the old `.index(...)` call raised ValueError.
    folder = "demo_data/results/seg_maps/example-lombardia/2"
    annual_map_name = 'patch-pred-nn.tif'
    # Pre-select the annual map when present, otherwise the first file.
    default_index = (st.paths.index(annual_map_name)
                     if annual_map_name in st.paths else 0)
    file_picker = st.selectbox("Select day predict (annual is patch-pred-nn.tif)",
                               st.paths, index=default_index)
    file_path = os.path.join(folder, file_picker)
    print(file_path)
    target, profile = read(file_path)
    target = np.squeeze(target)
    # Replace every class id with its colour from classes_color_map, row by row.
    target = [classes_color_map[p] for p in target]
    fig, ax = plt.subplots()
    ax.imshow(target)
    markdown_legend = ''.join(
        f'<div style="color:gray;background-color: {colors.to_hex(c)};">{label}</div><br>'
        for c, label in zip(classes_color_map, labels_map))
    col1, col2 = st.columns(2)
    with col1:
        st.pyplot(fig)
    with col2:
        st.markdown(markdown_legend, unsafe_allow_html=True)
    # pyplot keeps every figure alive until it is closed. Each run of this
    # script creates a new one, so release it once it has been rendered.
    plt.close(fig)