"""Notebook (Colab) front-end for configuring and launching an AutoTrain run."""

import json
import os
import random
import string
import subprocess

import ipywidgets as widgets
import yaml

from autotrain.app.models import fetch_models
from autotrain.app.params import get_task_params


def generate_random_string():
    """Return a random default project name.

    The name has the form ``autotrain-<5 chars>-<5 chars>`` where each
    character is a lowercase ASCII letter or a digit.
    """
    prefix = "autotrain"
    alphabet = string.ascii_lowercase + string.digits
    part1 = "".join(random.choices(alphabet, k=5))
    part2 = "".join(random.choices(alphabet, k=5))
    return f"{prefix}-{part1}-{part2}"


def colab_app():
    """Build the AutoTrain configuration UI and return its root widget.

    Creates a ``data`` directory in the current working directory if it
    does not exist, builds the credential / project / dataset / parameter
    widgets, and wires up the observers that keep the column mapping,
    parameters and base model in sync with the selected task.

    Clicking "Start Training" writes ``config.yml`` to the current working
    directory, exports ``HF_USERNAME`` / ``HF_TOKEN`` into ``os.environ`` and
    runs ``autotrain --config config.yml``, streaming its output via
    ``print``.

    Returns:
        The top-level ``widgets.VBox`` containing the whole form.
    """
    if not os.path.exists("data"):
        os.makedirs("data")

    MODEL_CHOICES = fetch_models()

    # Display name -> task identifier. Insertion order is the dropdown order.
    TASK_MAP = {
        "LLM SFT": "llm:sft",
        "LLM ORPO": "llm:orpo",
        "LLM Generic": "llm:generic",
        "LLM DPO": "llm:dpo",
        "LLM Reward": "llm:reward",
        "Text Classification": "text-classification",
        "Text Regression": "text-regression",
        "Sequence to Sequence": "seq2seq",
        "Token Classification": "token-classification",
        "Image Classification": "image-classification",
        "Image Regression": "image-regression",
        "Object Detection": "image-object-detection",
        "Tabular Classification": "tabular:classification",
        "Tabular Regression": "tabular:regression",
        "ST Pair": "st:pair",
        "ST Pair Classification": "st:pair_class",
        "ST Pair Scoring": "st:pair_score",
        "ST Triplet": "st:triplet",
        "ST Question Answering": "st:qa",
    }
    TASK_NAMES = list(TASK_MAP)

    # Default column mapping (as JSON text) shown for each task identifier.
    COLUMN_MAPPINGS = {
        "llm:sft": '{"text": "text"}',
        "llm:generic": '{"text": "text"}',
        "llm:dpo": '{"prompt": "prompt", "text": "text", "rejected_text": "rejected_text"}',
        "llm:orpo": '{"prompt": "prompt", "text": "text", "rejected_text": "rejected_text"}',
        "llm:reward": '{"text": "text", "rejected_text": "rejected_text"}',
        "text-classification": '{"text": "text", "label": "target"}',
        "text-regression": '{"text": "text", "label": "target"}',
        "token-classification": '{"text": "tokens", "label": "tags"}',
        "seq2seq": '{"text": "text", "label": "target"}',
        "image-classification": '{"image": "image", "label": "label"}',
        "image-regression": '{"image": "image", "label": "target"}',
        "image-object-detection": '{"image": "image", "objects": "objects"}',
        "tabular:classification": '{"id": "id", "label": ["target"]}',
        "tabular:regression": '{"id": "id", "label": ["target"]}',
        "st:pair": '{"sentence1": "anchor", "sentence2": "positive"}',
        "st:pair_class": '{"sentence1": "premise", "sentence2": "hypothesis", "target": "label"}',
        "st:pair_score": '{"sentence1": "sentence1", "sentence2": "sentence2", "target": "score"}',
        "st:triplet": '{"sentence1": "anchor", "sentence2": "positive", "sentence3": "negative"}',
        # The second key was a duplicate "sentence1", which json.loads
        # collapses into a single entry and silently drops the query column.
        "st:qa": '{"sentence1": "query", "sentence2": "answer"}',
    }

    # Tasks for which the validation-split field is disabled.
    NO_VALID_SPLIT_TASKS = {"llm:sft", "llm:generic", "llm:dpo", "llm:orpo", "llm:reward"}

    # Task identifier -> key in MODEL_CHOICES. "llm*" and "st:*" tasks are
    # resolved by prefix in update_base_model.
    # NOTE(review): "image-regression" has no entry, so selecting it shows the
    # "Enter base model..." placeholder -- confirm whether fetch_models()
    # provides a matching key before adding one.
    BASE_MODEL_KEYS = {
        "text-classification": "text-classification",
        "image-classification": "image-classification",
        "seq2seq": "seq2seq",
        "tabular:classification": "tabular-classification",
        "tabular:regression": "tabular-regression",
        "token-classification": "token-classification",
        "text-regression": "text-regression",
        "image-object-detection": "image-object-detection",
    }

    def _get_params(task, param_type):
        """Return the task's default parameters as indented JSON text."""
        _p = get_task_params(task, param_type=param_type)
        _p["push_to_hub"] = True
        return json.dumps(_p, indent=4)

    # NOTE(review): the label / title / separator strings below contain only
    # text and newlines; if heading or styling markup is expected here,
    # confirm and restore it.
    def _label(text):
        """Create the small caption widget shown above an input field."""
        return widgets.HTML(f"\n{text}\n")

    def _section_title(text):
        """Create the title widget shown at the top of a form section."""
        return widgets.HTML(f"\n\n{text}\n\n")

    def _narrow_layout():
        """Return a fresh layout for the 200px-wide input fields."""
        return widgets.Layout(margin="0 0 0 0", width="200px")

    # --- Hugging Face credentials -------------------------------------
    hf_token_label = _label("Hugging Face Write Token")
    hf_token = widgets.Password(value="", description="", disabled=False, layout=_narrow_layout())

    hf_user_label = _label("Hugging Face Username")
    hf_user = widgets.Text(value="", description="", disabled=False, layout=_narrow_layout())

    # --- Project details ----------------------------------------------
    base_model_label = _label("Base Model")
    base_model = widgets.Text(value=MODEL_CHOICES["llm"][0], disabled=False, layout=widgets.Layout(width="420px"))

    project_name_label = _label("Project Name")
    project_name = widgets.Text(
        value=generate_random_string(),
        description="",
        disabled=False,
        layout=_narrow_layout(),
    )

    task_dropdown_label = _label("Task")
    task_dropdown = widgets.Dropdown(
        options=TASK_NAMES,
        value=TASK_NAMES[0],
        description="",
        disabled=False,
        layout=_narrow_layout(),
    )

    # --- Dataset details ----------------------------------------------
    dataset_path_label = _label("Path")
    dataset_path = widgets.Text(value="", description="", disabled=False, layout=_narrow_layout())

    train_split_label = _label("Train Split")
    train_split = widgets.Text(value="", description="", disabled=False, layout=_narrow_layout())

    valid_split_label = _label("Valid Split")
    valid_split = widgets.Text(
        value="",
        placeholder="optional",
        description="",
        disabled=False,
        layout=_narrow_layout(),
    )

    dataset_source_dropdown_label = _label("Source")
    dataset_source_dropdown = widgets.Dropdown(
        options=["Hugging Face Hub", "Local"],
        value="Hugging Face Hub",
        description="",
        disabled=False,
        layout=_narrow_layout(),
    )

    col_mapping_label = _label("Column Mapping")
    col_mapping = widgets.Text(
        value='{"text": "text"}',
        placeholder="",
        description="",
        disabled=False,
        layout=widgets.Layout(margin="0 0 0 0", width="420px"),
    )

    # --- Parameters ----------------------------------------------------
    parameters_dropdown = widgets.Dropdown(
        options=["Basic", "Full"],
        value="Basic",
        description="",
        disabled=False,
        layout=widgets.Layout(width="400px"),
    )

    parameters = widgets.Textarea(
        value=_get_params("llm:sft", "basic"),
        description="",
        disabled=False,
        layout=widgets.Layout(height="400px", width="400px"),
    )

    start_training_button = widgets.Button(
        description="Start Training",
        layout=widgets.Layout(width="1000px"),
        disabled=False,
        button_style="",  # 'success', 'info', 'warning', 'danger' or ''
        tooltip="Click to start training",
        icon="check",  # (FontAwesome names without the `fa-` prefix)
    )

    # --- Layout --------------------------------------------------------
    spacer = widgets.Box(layout=widgets.Layout(width="20px"))

    title_hbox0 = _section_title("Hugging Face Credentials")
    title_hbox1 = _section_title("Project Details")
    title_hbox2 = _section_title("Dataset Details")
    title_hbox3 = _section_title("Parameters")

    hbox0 = widgets.HBox(
        [
            widgets.VBox([hf_token_label, hf_token]),
            spacer,
            widgets.VBox([hf_user_label, hf_user]),
        ]
    )
    hbox1 = widgets.HBox(
        [
            widgets.VBox([project_name_label, project_name]),
            spacer,
            widgets.VBox([task_dropdown_label, task_dropdown]),
        ]
    )
    hbox2_1 = widgets.HBox(
        [
            widgets.VBox([dataset_source_dropdown_label, dataset_source_dropdown]),
            spacer,
            widgets.VBox([dataset_path_label, dataset_path]),
        ]
    )
    hbox2_2 = widgets.HBox(
        [
            widgets.VBox([train_split_label, train_split]),
            spacer,
            widgets.VBox([valid_split_label, valid_split]),
        ]
    )
    hbox2_3 = widgets.HBox(
        [
            widgets.VBox([col_mapping_label, col_mapping]),
        ]
    )
    hbox3 = widgets.VBox([parameters_dropdown, parameters])

    vbox0 = widgets.VBox([title_hbox0, hbox0])
    vbox1 = widgets.VBox([title_hbox1, base_model_label, base_model, hbox1])
    vbox2 = widgets.VBox([title_hbox2, hbox2_1, hbox2_2, hbox2_3])
    vbox3 = widgets.VBox([title_hbox3, hbox3])

    left_column = widgets.VBox([vbox0, vbox1, vbox2], layout=widgets.Layout(width="500px"))
    right_column = widgets.VBox([vbox3], layout=widgets.Layout(width="500px", align_items="flex-end"))

    separator = widgets.HTML("\n")

    _main_layout = widgets.HBox([left_column, separator, right_column])
    main_layout = widgets.VBox([_main_layout, start_training_button])

    # --- Observers -----------------------------------------------------
    def on_dataset_change(change):
        """Prefill the path / split fields when the dataset source changes."""
        is_local = change["new"] == "Local"
        dataset_path.value = "data/" if is_local else ""
        train_split.value = "train" if is_local else ""
        valid_split.value = ""

    def update_parameters(*args):
        """Reload the parameter JSON for the selected task and detail level."""
        task = TASK_MAP[task_dropdown.value]
        param_type = parameters_dropdown.value.lower()
        parameters.value = _get_params(task, param_type)

    def update_col_mapping(*args):
        """Show the default column mapping for the selected task."""
        task = TASK_MAP[task_dropdown.value]
        mapping = COLUMN_MAPPINGS.get(task)
        if mapping is None:
            # Unknown task: only prompt for a mapping, leave other fields alone.
            col_mapping.value = "Enter column mapping..."
            return
        col_mapping.value = mapping
        dataset_source_dropdown.disabled = False
        valid_split.disabled = task in NO_VALID_SPLIT_TASKS

    def update_base_model(*args):
        """Set the base model field to the first model listed for the task."""
        task = TASK_MAP[task_dropdown.value]
        if task.startswith("llm"):
            choices_key = "llm"
        elif task.startswith("st:"):
            choices_key = "sentence-transformers"
        else:
            choices_key = BASE_MODEL_KEYS.get(task)

        if choices_key is None:
            base_model.value = "Enter base model..."
        else:
            base_model.value = MODEL_CHOICES[choices_key][0]

    def start_training(b):
        """Write ``config.yml`` from the form and run ``autotrain`` on it.

        Any failure (invalid JSON in the form, non-zero exit code, ...) is
        reported via ``print`` rather than propagated; the button is
        re-enabled in every case.
        """
        start_training_button.disabled = True
        try:
            print("Training is starting... Please wait!")
            os.environ["HF_USERNAME"] = hf_user.value
            os.environ["HF_TOKEN"] = hf_token.value

            # Blank split fields are written to the config as null.
            train_split_value = train_split.value.strip() or None
            valid_split_value = valid_split.value.strip() or None

            # The dropdown holds display names ("LLM SFT"); the prefix and
            # the trainer suffix live in the mapped identifier ("llm:sft").
            task = TASK_MAP[task_dropdown.value]

            params_val = json.loads(parameters.value)
            if task.startswith(("llm:", "st:")):
                params_val["trainer"] = task.split(":")[1]

            # chat_template and push_to_hub are moved out of "params" into
            # the "data" and "hub" sections of the config respectively.
            chat_template = params_val.get("chat_template")
            if chat_template is not None:
                del params_val["chat_template"]
            push_to_hub = params_val.pop("push_to_hub", True)

            config = {
                "task": task.split(":")[0],
                "base_model": base_model.value,
                "project_name": project_name.value,
                "log": "tensorboard",
                "backend": "local",
                "data": {
                    "path": dataset_path.value,
                    "train_split": train_split_value,
                    "valid_split": valid_split_value,
                    "column_mapping": json.loads(col_mapping.value),
                },
                "params": params_val,
                "hub": {
                    "username": "${{HF_USERNAME}}",
                    "token": "${{HF_TOKEN}}",
                    "push_to_hub": push_to_hub,
                },
            }
            if task.startswith("llm"):
                config["data"]["chat_template"] = chat_template
                if config["data"]["chat_template"] == "none":
                    config["data"]["chat_template"] = None

            with open("config.yml", "w") as f:
                yaml.dump(config, f)

            cmd = "autotrain --config config.yml"
            # The context manager closes the stdout pipe and reaps the child
            # even if printing is interrupted.
            with subprocess.Popen(
                cmd,
                shell=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
            ) as process:
                for output in process.stdout:
                    print(output.strip())
                poll_res = process.wait()

            if poll_res != 0:
                raise RuntimeError(f"Training failed with exit code: {poll_res}")
            print("Training completed successfully!")
        except Exception as e:
            # UI boundary: report the problem in the notebook output instead
            # of raising inside the widget callback.
            print("An error occurred while starting training!")
            print(f"Error: {e}")
        finally:
            start_training_button.disabled = False

    start_training_button.on_click(start_training)
    dataset_source_dropdown.observe(on_dataset_change, names="value")
    task_dropdown.observe(update_col_mapping, names="value")
    task_dropdown.observe(update_parameters, names="value")
    task_dropdown.observe(update_base_model, names="value")
    parameters_dropdown.observe(update_parameters, names="value")
    return main_layout