ignore data
Browse files- .gitignore +3 -1
- inference.py +18 -12
.gitignore
CHANGED
|
@@ -1,2 +1,4 @@
|
|
| 1 |
hf_cache
|
| 2 |
-
__pycache__
|
|
|
|
|
|
|
|
|
| 1 |
hf_cache
|
| 2 |
+
__pycache__
|
| 3 |
+
.DS_Store
|
| 4 |
+
data/
|
inference.py
CHANGED
|
@@ -66,6 +66,9 @@ class InferenceDataLoader:
|
|
| 66 |
with rasterio.open(path) as src:
|
| 67 |
# Transform the coordinates from WGS84 to UTM (EPSG:32632)
|
| 68 |
utm_x, utm_y = self.transformer.transform(lon, lat)
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
try:
|
| 71 |
px, py = rowcol(src.transform, utm_x, utm_y)
|
|
@@ -77,20 +80,23 @@ class InferenceDataLoader:
|
|
| 77 |
|
| 78 |
half_window_size = self.window_size // 2
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
| 82 |
|
| 83 |
-
if col_off < 0:
|
| 84 |
-
col_off = 0
|
| 85 |
if row_off < 0:
|
| 86 |
row_off = 0
|
| 87 |
-
if col_off
|
| 88 |
-
col_off =
|
| 89 |
-
if row_off + self.window_size > src.
|
| 90 |
-
row_off = src.
|
|
|
|
|
|
|
| 91 |
|
| 92 |
window = Window(col_off, row_off, self.window_size, self.window_size)
|
| 93 |
window_transform = src.window_transform(window)
|
|
|
|
|
|
|
|
|
|
| 94 |
crs = src.crs
|
| 95 |
|
| 96 |
return window, window_transform, crs
|
|
@@ -193,10 +199,10 @@ def crop_predictions_to_gdf(field_ids, targets, predictions, transform, crs, cla
|
|
| 193 |
return gdf
|
| 194 |
|
| 195 |
def perform_inference(lon, lat, model, config, debug=False):
|
| 196 |
-
features_path = "
|
| 197 |
-
labels_path = "
|
| 198 |
-
field_ids_path = "
|
| 199 |
-
stats_path = "
|
| 200 |
|
| 201 |
loader = InferenceDataLoader(features_path, labels_path, field_ids_path, stats_path, n_timesteps=9, fold_indices=[0], debug=True)
|
| 202 |
|
|
|
|
| 66 |
with rasterio.open(path) as src:
|
| 67 |
# Transform the coordinates from WGS84 to UTM (EPSG:32632)
|
| 68 |
utm_x, utm_y = self.transformer.transform(lon, lat)
|
| 69 |
+
if self.debug:
|
| 70 |
+
print("Source Transform", src.transform)
|
| 71 |
+
print(f"UTM X: {utm_x}, UTM Y: {utm_y}")
|
| 72 |
|
| 73 |
try:
|
| 74 |
px, py = rowcol(src.transform, utm_x, utm_y)
|
|
|
|
| 80 |
|
| 81 |
half_window_size = self.window_size // 2
|
| 82 |
|
| 83 |
+
row_off = px - half_window_size
|
| 84 |
+
col_off = py - half_window_size
|
| 85 |
|
|
|
|
|
|
|
| 86 |
if row_off < 0:
|
| 87 |
row_off = 0
|
| 88 |
+
if col_off < 0:
|
| 89 |
+
col_off = 0
|
| 90 |
+
if row_off + self.window_size > src.width:
|
| 91 |
+
row_off = src.width - self.window_size
|
| 92 |
+
if col_off + self.window_size > src.height:
|
| 93 |
+
col_off = src.height - self.window_size
|
| 94 |
|
| 95 |
window = Window(col_off, row_off, self.window_size, self.window_size)
|
| 96 |
window_transform = src.window_transform(window)
|
| 97 |
+
if self.debug:
|
| 98 |
+
print(f"Window: {window}")
|
| 99 |
+
print(f"Window Transform: {window_transform}")
|
| 100 |
crs = src.crs
|
| 101 |
|
| 102 |
return window, window_transform, crs
|
|
|
|
| 199 |
return gdf
|
| 200 |
|
| 201 |
def perform_inference(lon, lat, model, config, debug=False):
|
| 202 |
+
features_path = "../data/stacked_features.tif"
|
| 203 |
+
labels_path = "../data/labels.tif"
|
| 204 |
+
field_ids_path = "../data/field_ids.tif"
|
| 205 |
+
stats_path = "../data/chips_stats.yaml"
|
| 206 |
|
| 207 |
loader = InferenceDataLoader(features_path, labels_path, field_ids_path, stats_path, n_timesteps=9, fold_indices=[0], debug=True)
|
| 208 |
|