remove redundant torch.cat()
Browse files
app.py
CHANGED
@@ -47,7 +47,7 @@ def process(path, progress = gr.Progress(), device = 'cpu'):
|
|
47 |
progress(0, desc="Recognizing..")
|
48 |
batch = torch.cat(batch).to(device)
|
49 |
print("batch.shape", batch.shape)
|
50 |
-
generated_ids = model.generate(
|
51 |
progress(0, desc="Decoding (token -> str)")
|
52 |
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
53 |
|
|
|
47 |
progress(0, desc="Recognizing..")
|
48 |
batch = torch.cat(batch).to(device)
|
49 |
print("batch.shape", batch.shape)
|
50 |
+
generated_ids = model.generate(batch)
|
51 |
progress(0, desc="Decoding (token -> str)")
|
52 |
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
53 |
|