PeterYu commited on
Commit
ffdaaba
·
1 Parent(s): a785d5a

update T4 env

Browse files
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -47,6 +47,7 @@ class TimeSeriesEditor:
47
  (200, 250, 0, self.function_parser.string_to_function("sin(2*pi*x)"), 0.05)
48
  # 200,250,0,sin(2*pi*x),0.05
49
  ]
 
50
 
51
  def format_value(self, value: float, feature_idx: int) -> str:
52
  """Format value with appropriate units and notation"""
@@ -440,6 +441,10 @@ class TimeSeriesEditor:
440
 
441
  # Run prediction
442
  with torch.no_grad():
 
 
 
 
443
  sample = self.trainer.predict_weighted_points(
444
  observed_points, # (seq_length, feature_dim)
445
  observed_mask, # (seq_length, feature_dim)
@@ -614,7 +619,7 @@ class TimeSeriesEditor:
614
  def create_gradio_interface(editor: TimeSeriesEditor):
615
  with gr.Blocks() as app:
616
  gr.Markdown("# Time Series Editor")
617
- gr.Markdown("## Instruction: Scroll Down + Click [Update Figure] [~20s-30s] [Running on CPU...]")
618
 
619
  metrics_display = gr.JSON(label="Metrics", value={})
620
 
@@ -1073,7 +1078,9 @@ if __name__ == "__main__":
1073
  print(os.getcwd())
1074
 
1075
  device = torch.device(f"cuda:0") if torch.cuda.is_available() else "cpu"
1076
-
 
 
1077
  from models.Tiffusion import tiffusion
1078
 
1079
  model = tiffusion.Tiffusion(
 
47
  (200, 250, 0, self.function_parser.string_to_function("sin(2*pi*x)"), 0.05)
48
  # 200,250,0,sin(2*pi*x),0.05
49
  ]
50
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
51
 
52
  def format_value(self, value: float, feature_idx: int) -> str:
53
  """Format value with appropriate units and notation"""
 
441
 
442
  # Run prediction
443
  with torch.no_grad():
444
+ # to cuda
445
+ observed_points = observed_points.to(self.device)
446
+ observed_mask = observed_mask.to(self.device)
447
+
448
  sample = self.trainer.predict_weighted_points(
449
  observed_points, # (seq_length, feature_dim)
450
  observed_mask, # (seq_length, feature_dim)
 
619
  def create_gradio_interface(editor: TimeSeriesEditor):
620
  with gr.Blocks() as app:
621
  gr.Markdown("# Time Series Editor")
622
+ gr.Markdown("## Instruction: Scroll Down + Click [Update Figure] [~20s]")
623
 
624
  metrics_display = gr.JSON(label="Metrics", value={})
625
 
 
1078
  print(os.getcwd())
1079
 
1080
  device = torch.device(f"cuda:0") if torch.cuda.is_available() else "cpu"
1081
+ print(f"Device: {device}")
1082
+ print(f"Using device: {device}")
1083
+
1084
  from models.Tiffusion import tiffusion
1085
 
1086
  model = tiffusion.Tiffusion(