|
|
import os |
|
|
import csv |
|
|
import json |
|
|
import shutil |
|
|
from typing import Optional, List, Any |
|
|
from huggingface_hub import login |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
from tools import DEFAULT_SYSTEM_MSG |
|
|
|
|
|
|
|
|
def authenticate_hf(token: Optional[str]) -> None:
    """Authenticate with the Hugging Face Hub when a token is provided.

    Args:
        token: HF access token. A falsy value (``None``/empty string)
            skips the login entirely.
    """
    # Guard clause: nothing to do without a token.
    if not token:
        print("Skipping Hugging Face login: HF_TOKEN not set.")
        return
    print("Logging into Hugging Face Hub...")
    login(token=token)
|
|
|
|
|
def load_model_and_tokenizer(model_name: str):
    """Load a causal-LM and its tokenizer via ``transformers``.

    If ``model_name`` looks like a relative local path (starts with "..")
    that does not exist on disk, falls back to the default hub model
    ``google/gemma-2b-it``.

    Args:
        model_name: Hub repo id or local filesystem path of the model.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        Exception: whatever ``from_pretrained`` raised, after logging it.
    """
    print(f"Loading Transformer model: {model_name}")
    # Assigned before anything that can raise, so the except clause below
    # can always reference target_model safely.
    target_model = model_name
    if model_name.startswith("..") and not os.path.exists(model_name):
        print(f"Warning: Local path {model_name} not found. Falling back to default hub model.")
        target_model = "google/gemma-2b-it"

    # Keep the try body minimal: only the calls that can actually raise.
    try:
        tokenizer = AutoTokenizer.from_pretrained(target_model)
        model = AutoModelForCausalLM.from_pretrained(target_model)
    except Exception as e:
        print(f"Error loading Transformer model {target_model}: {e}")
        # Bare `raise` preserves the original traceback; `raise e` would
        # rewrite it at this frame.
        raise
    print("Model loaded successfully.")
    return model, tokenizer
|
|
|
|
|
|
|
|
def create_conversation_format(sample, tools_list):
    """Build the SFT conversational training record for one dataset row.

    ``sample`` must provide ``user_content``, ``tool_name`` and
    ``tool_arguments`` (the latter a JSON-encoded string; invalid or
    non-string values fall back to an empty dict).
    """
    try:
        parsed_args = json.loads(sample["tool_arguments"])
    except (json.JSONDecodeError, TypeError):
        parsed_args = {}

    assistant_turn = {
        "role": "assistant",
        "tool_calls": [
            {
                "type": "function",
                "function": {"name": sample["tool_name"], "arguments": parsed_args},
            }
        ],
    }
    conversation = [
        {"role": "developer", "content": DEFAULT_SYSTEM_MSG},
        {"role": "user", "content": sample["user_content"]},
        assistant_turn,
    ]
    return {"messages": conversation, "tools": tools_list}
|
|
|
|
|
def parse_csv_dataset(file_path: str) -> List[List[str]]:
    """Read (user_content, tool_name, tool_arguments) triples from a CSV file.

    A first row whose leading cell mentions ``user_content`` is treated as
    a header and skipped; otherwise every row is data. Rows with fewer than
    three columns are ignored; columns beyond the third are dropped and
    each kept cell is whitespace-stripped.
    """
    if not file_path:
        return []

    rows: List[List[str]] = []
    with open(file_path, 'r', newline='', encoding='utf-8') as handle:
        reader = csv.reader(handle)
        try:
            first_row = next(reader)
        except StopIteration:
            # Completely empty file.
            return rows
        looks_like_header = bool(first_row) and "user_content" in first_row[0].lower()
        if not looks_like_header:
            # Not a header: rewind so the first row is re-read as data.
            handle.seek(0)
        for record in reader:
            if len(record) >= 3:
                rows.append([cell.strip() for cell in record[:3]])
    return rows
|
|
|
|
|
def zip_directory(source_dir: str, output_name_base: str) -> str:
    """Create ``<output_name_base>.zip`` from *source_dir*.

    Returns the path of the archive that was written.
    """
    archive_path = shutil.make_archive(output_name_base, 'zip', root_dir=source_dir)
    return archive_path
|
|
|