Spaces:
Sleeping
Sleeping
update T4 env
Browse files
app.py
CHANGED
@@ -47,6 +47,7 @@ class TimeSeriesEditor:
|
|
47 |
(200, 250, 0, self.function_parser.string_to_function("sin(2*pi*x)"), 0.05)
|
48 |
# 200,250,0,sin(2*pi*x),0.05
|
49 |
]
|
|
|
50 |
|
51 |
def format_value(self, value: float, feature_idx: int) -> str:
|
52 |
"""Format value with appropriate units and notation"""
|
@@ -440,6 +441,10 @@ class TimeSeriesEditor:
|
|
440 |
|
441 |
# Run prediction
|
442 |
with torch.no_grad():
|
|
|
|
|
|
|
|
|
443 |
sample = self.trainer.predict_weighted_points(
|
444 |
observed_points, # (seq_length, feature_dim)
|
445 |
observed_mask, # (seq_length, feature_dim)
|
@@ -614,7 +619,7 @@ class TimeSeriesEditor:
|
|
614 |
def create_gradio_interface(editor: TimeSeriesEditor):
|
615 |
with gr.Blocks() as app:
|
616 |
gr.Markdown("# Time Series Editor")
|
617 |
-
gr.Markdown("## Instruction: Scroll Down + Click [Update Figure] [~20s
|
618 |
|
619 |
metrics_display = gr.JSON(label="Metrics", value={})
|
620 |
|
@@ -1073,7 +1078,9 @@ if __name__ == "__main__":
|
|
1073 |
print(os.getcwd())
|
1074 |
|
1075 |
device = torch.device(f"cuda:0") if torch.cuda.is_available() else "cpu"
|
1076 |
-
|
|
|
|
|
1077 |
from models.Tiffusion import tiffusion
|
1078 |
|
1079 |
model = tiffusion.Tiffusion(
|
|
|
47 |
(200, 250, 0, self.function_parser.string_to_function("sin(2*pi*x)"), 0.05)
|
48 |
# 200,250,0,sin(2*pi*x),0.05
|
49 |
]
|
50 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
51 |
|
52 |
def format_value(self, value: float, feature_idx: int) -> str:
|
53 |
"""Format value with appropriate units and notation"""
|
|
|
441 |
|
442 |
# Run prediction
|
443 |
with torch.no_grad():
|
444 |
+
# to cuda
|
445 |
+
observed_points = observed_points.to(self.device)
|
446 |
+
observed_mask = observed_mask.to(self.device)
|
447 |
+
|
448 |
sample = self.trainer.predict_weighted_points(
|
449 |
observed_points, # (seq_length, feature_dim)
|
450 |
observed_mask, # (seq_length, feature_dim)
|
|
|
619 |
def create_gradio_interface(editor: TimeSeriesEditor):
|
620 |
with gr.Blocks() as app:
|
621 |
gr.Markdown("# Time Series Editor")
|
622 |
+
gr.Markdown("## Instruction: Scroll Down + Click [Update Figure] [~20s]")
|
623 |
|
624 |
metrics_display = gr.JSON(label="Metrics", value={})
|
625 |
|
|
|
1078 |
print(os.getcwd())
|
1079 |
|
1080 |
device = torch.device(f"cuda:0") if torch.cuda.is_available() else "cpu"
|
1081 |
+
print(f"Device: {device}")
|
1082 |
+
print(f"Using device: {device}")
|
1083 |
+
|
1084 |
from models.Tiffusion import tiffusion
|
1085 |
|
1086 |
model = tiffusion.Tiffusion(
|