t0-0
committed on
Commit
·
559d198
1
Parent(s):
bd95334
Remove 'auto' from Enum and add handling for submissions with 'auto'.
Browse files- app.py +1 -1
- src/display/utils.py +0 -3
- src/submission/submit.py +25 -7
app.py
CHANGED
|
@@ -579,7 +579,7 @@ with gr.Blocks() as demo_submission:
|
|
| 579 |
with gr.Column():
|
| 580 |
precision = gr.Dropdown(
|
| 581 |
label="Precision",
|
| 582 |
-
choices=[i.value.name for i in Precision],
|
| 583 |
multiselect=False,
|
| 584 |
value="auto",
|
| 585 |
)
|
|
|
|
| 579 |
with gr.Column():
|
| 580 |
precision = gr.Dropdown(
|
| 581 |
label="Precision",
|
| 582 |
+
choices=[i.value.name for i in Precision] + ["auto"],
|
| 583 |
multiselect=False,
|
| 584 |
value="auto",
|
| 585 |
)
|
src/display/utils.py
CHANGED
|
@@ -129,15 +129,12 @@ class WeightType(Enum):
|
|
| 129 |
|
| 130 |
|
| 131 |
class Precision(Enum):
|
| 132 |
-
auto = ModelDetails("auto")
|
| 133 |
float16 = ModelDetails("float16")
|
| 134 |
bfloat16 = ModelDetails("bfloat16")
|
| 135 |
float32 = ModelDetails("float32")
|
| 136 |
|
| 137 |
@staticmethod
|
| 138 |
def from_str(precision: str) -> "Precision":
|
| 139 |
-
if precision == "auto":
|
| 140 |
-
return Precision.auto
|
| 141 |
if precision == "float16":
|
| 142 |
return Precision.float16
|
| 143 |
if precision == "bfloat16":
|
|
|
|
| 129 |
|
| 130 |
|
| 131 |
class Precision(Enum):
|
|
|
|
| 132 |
float16 = ModelDetails("float16")
|
| 133 |
bfloat16 = ModelDetails("bfloat16")
|
| 134 |
float32 = ModelDetails("float32")
|
| 135 |
|
| 136 |
@staticmethod
|
| 137 |
def from_str(precision: str) -> "Precision":
|
|
|
|
|
|
|
| 138 |
if precision == "float16":
|
| 139 |
return Precision.float16
|
| 140 |
if precision == "bfloat16":
|
src/submission/submit.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
import json
|
| 2 |
from datetime import datetime, timezone
|
| 3 |
|
|
|
|
|
|
|
| 4 |
from src.display.formatting import styled_error, styled_message, styled_warning
|
| 5 |
from src.display.utils import EvalQueuedModel, LLMJpEvalVersion, VllmVersion
|
| 6 |
from src.envs import API, EVAL_REQUESTS_PATH, HF_TOKEN, QUEUE_REPO
|
|
@@ -25,6 +27,29 @@ def add_new_eval(
|
|
| 25 |
|
| 26 |
revision = revision or "main"
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
model_data = EvalQueuedModel(
|
| 29 |
model=model_id,
|
| 30 |
revision=revision,
|
|
@@ -47,13 +72,6 @@ def add_new_eval(
|
|
| 47 |
if model_type is None or model_type == "":
|
| 48 |
return styled_error("Please select a model type.")
|
| 49 |
|
| 50 |
-
# Is the model on the hub?
|
| 51 |
-
model_on_hub, error, _ = is_model_on_hub(
|
| 52 |
-
model_name=model_id, revision=revision, token=HF_TOKEN, test_tokenizer=True
|
| 53 |
-
)
|
| 54 |
-
if not model_on_hub:
|
| 55 |
-
return styled_error(f'Model "{model_id}" {error}')
|
| 56 |
-
|
| 57 |
# Is the model info correctly filled?
|
| 58 |
try:
|
| 59 |
model_info = API.model_info(repo_id=model_id, revision=revision)
|
|
|
|
| 1 |
import json
|
| 2 |
from datetime import datetime, timezone
|
| 3 |
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
from src.display.formatting import styled_error, styled_message, styled_warning
|
| 7 |
from src.display.utils import EvalQueuedModel, LLMJpEvalVersion, VllmVersion
|
| 8 |
from src.envs import API, EVAL_REQUESTS_PATH, HF_TOKEN, QUEUE_REPO
|
|
|
|
| 27 |
|
| 28 |
revision = revision or "main"
|
| 29 |
|
| 30 |
+
# Is the model on the hub?
|
| 31 |
+
model_on_hub, error, config = is_model_on_hub(
|
| 32 |
+
model_name=model_id, revision=revision, token=HF_TOKEN, test_tokenizer=True
|
| 33 |
+
)
|
| 34 |
+
if not model_on_hub:
|
| 35 |
+
return styled_error(f'Model "{model_id}" {error}')
|
| 36 |
+
if precision == "auto":
|
| 37 |
+
dtype = ""
|
| 38 |
+
if hasattr(config, "dtype"):
|
| 39 |
+
dtype = config.dtype
|
| 40 |
+
elif hasattr(config, "torch_dtype"):
|
| 41 |
+
dtype = config.torch_dtype
|
| 42 |
+
if dtype == torch.float16:
|
| 43 |
+
precision = "float16"
|
| 44 |
+
elif dtype in torch.bfloat16:
|
| 45 |
+
precision = "bfloat16"
|
| 46 |
+
elif dtype in torch.float32:
|
| 47 |
+
precision = "float32"
|
| 48 |
+
else:
|
| 49 |
+
return styled_error(
|
| 50 |
+
"Unable to retrieve a valid dtype from config.json. Please select an appropriate one from fp16/fp32/bf16 and resubmit."
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
model_data = EvalQueuedModel(
|
| 54 |
model=model_id,
|
| 55 |
revision=revision,
|
|
|
|
| 72 |
if model_type is None or model_type == "":
|
| 73 |
return styled_error("Please select a model type.")
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
# Is the model info correctly filled?
|
| 76 |
try:
|
| 77 |
model_info = API.model_info(repo_id=model_id, revision=revision)
|