""" | |
Upload Fine-tuned TinyLlama Model to Hugging Face Hub | |
This script uploads a fine-tuned TinyLlama model to the Hugging Face Hub. | |
It handles authentication, model card creation, and repository management. | |
Usage: | |
python upload_to_huggingface.py --model_dir ./mt564_tinyllama_model --repo_name username/mt564-tinyllama | |
""" | |
import os | |
import argparse | |
import logging | |
from datetime import datetime | |
from typing import Optional, List, Dict, Any | |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | |
logger = logging.getLogger(__name__) | |


def parse_args():
    parser = argparse.ArgumentParser(description="Upload fine-tuned TinyLlama model to Hugging Face Hub")
    parser.add_argument(
        "--model_dir",
        type=str,
        required=True,
        help="Directory containing the fine-tuned model"
    )
    parser.add_argument(
        "--repo_name",
        type=str,
        required=True,
        help="Name for the Hugging Face repository (format: username/repo-name)"
    )
    parser.add_argument(
        "--commit_message",
        type=str,
        default=f"Upload fine-tuned TinyLlama model - {datetime.now().strftime('%Y-%m-%d')}",
        help="Commit message for the model upload"
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="Make the repository private"
    )
    # BooleanOptionalAction keeps the default of True while still allowing
    # --no-create_model_card to skip card creation (store_true combined with
    # default=True could never be disabled).
    parser.add_argument(
        "--create_model_card",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Create a model card README.md"
    )
    parser.add_argument(
        "--base_model",
        type=str,
        default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        help="Base model name used for fine-tuning"
    )
    parser.add_argument(
        "--tags",
        type=str,
        nargs="+",
        default=["swift", "mt564", "financial", "tinyllama", "finance"],
        help="Tags for the model"
    )
    return parser.parse_args()
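
# Example invocation (illustrative; substitute your own model directory and
# repository name, and set HUGGING_FACE_TOKEN in the environment first):
#
#   export HUGGING_FACE_TOKEN=your_token_here
#   python upload_to_huggingface.py \
#       --model_dir ./mt564_tinyllama_model \
#       --repo_name username/mt564-tinyllama \
#       --private \
#       --tags swift mt564 finance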


def create_model_card(
    base_model: str,
    repo_name: str,
    dataset_info: str = "SWIFT MT564 documentation",
    tags: Optional[List[str]] = None,
    training_details: Optional[Dict[str, Any]] = None
) -> str:
    """
    Create a model card for the Hugging Face Hub.

    Args:
        base_model: Name of the base model used for fine-tuning
        repo_name: Name of the Hugging Face repository
        dataset_info: Information about the dataset used
        tags: Tags for the model
        training_details: Dictionary with training hyperparameters

    Returns:
        Model card content as a string
    """
    if tags is None:
        tags = ["swift", "mt564", "finance", "tinyllama"]
    if training_details is None:
        training_details = {
            "epochs": 3,
            "learning_rate": "2e-5",
            "batch_size": 2,
            "gradient_accumulation_steps": 4,
            "training_date": datetime.now().strftime("%Y-%m-%d")
        }

    repo_owner, repo_id = repo_name.split('/')
    # chr(10) is a newline: f-string expressions cannot contain backslashes in
    # older Python versions, so the tag list is joined this way.
    model_card = f"""---
language: en
license: apache-2.0
tags:
{chr(10).join([f'- {tag}' for tag in tags])}
datasets:
- custom
metrics:
- accuracy
---

# {repo_id}

This is a fine-tuned version of [{base_model}](https://huggingface.co/{base_model}) specialized for understanding SWIFT MT564 message formats and financial documentation.

## Model Description

This model was fine-tuned on SWIFT MT564 documentation to help financial professionals understand and work with Corporate Action Notification messages. It can answer questions about message structure, field specifications, and usage guidelines for MT564 messages.

### Base Model

- **Base Model**: {base_model}
- **Model Type**: TinyLlama
- **Language**: English
- **Fine-tuning Focus**: SWIFT financial messaging formats, particularly MT564

## Training Data

The model was fine-tuned on the following data:

- {dataset_info}
- The data includes message specifications, field descriptions, sequence structures, and usage guidelines

## Training Procedure

The model was fine-tuned with the following parameters:

- **Epochs**: {training_details['epochs']}
- **Learning Rate**: {training_details['learning_rate']}
- **Batch Size**: {training_details['batch_size']}
- **Gradient Accumulation Steps**: {training_details['gradient_accumulation_steps']}
- **Training Date**: {training_details['training_date']}

## Intended Use & Limitations

This model is specifically designed to:

- Answer questions about SWIFT MT564 message formats
- Assist with understanding Corporate Action Notifications
- Help parse and interpret MT564 messages

**Limitations**:

- This model specializes in MT564 and may have limited knowledge of other SWIFT message types
- The model should not be used for generating actual SWIFT messages for production systems
- Always verify critical financial information with official SWIFT documentation

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("{repo_name}")
tokenizer = AutoTokenizer.from_pretrained("{repo_name}")

# Format prompt for the chat model
prompt = "<|im_start|>user\\nWhat is the purpose of Sequence A in MT564 messages?<|im_end|>\\n<|im_start|>assistant\\n"

# Tokenize and generate a response
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(inputs.input_ids, max_new_tokens=256, do_sample=True, temperature=0.7)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
```

## Citation & Contact

If you use this model, please cite:

```
@misc{{{repo_id.replace('-', '_').lower()},
  author = {{{repo_owner}}},
  title = {{{repo_id} - A fine-tuned TinyLlama model for SWIFT MT564 documentation}},
  year = {{{datetime.now().year}}},
  publisher = {{Hugging Face}},
  journal = {{Hugging Face Repository}},
  howpublished = {{https://huggingface.co/{repo_name}}},
}}
```

For questions or feedback, please reach out through the [Hugging Face community](https://discuss.huggingface.co/) or the GitHub repository linked to this project.
"""
    return model_card
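
# For illustration, with the default tags the YAML front matter emitted by
# create_model_card() above begins roughly like this (actual values depend on
# the arguments passed in):
#
#   ---
#   language: en
#   license: apache-2.0
#   tags:
#   - swift
#   - mt564
#   - finance
#   - tinyllama
#   datasets:
#   - custom
#   ...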


def upload_to_hub(
    model_dir: str,
    repo_name: str,
    commit_message: str = "Upload fine-tuned model",
    private: bool = False,
    create_card: bool = True,
    base_model: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    tags: Optional[List[str]] = None
) -> bool:
    """
    Upload a model to the Hugging Face Hub.

    Args:
        model_dir: Directory containing the fine-tuned model
        repo_name: Name for the Hugging Face repository (username/repo-name)
        commit_message: Commit message for the upload
        private: Whether to make the repository private
        create_card: Whether to create a model card
        base_model: Base model used for fine-tuning
        tags: Tags for the model

    Returns:
        Success status (True if upload was successful)
    """
    try:
        # Import here so the script doesn't fail at startup if huggingface_hub
        # is not installed
        from huggingface_hub import HfApi, create_repo

        # The Hub token must be provided via the HUGGING_FACE_TOKEN environment variable
        token = os.environ.get("HUGGING_FACE_TOKEN")
        if not token:
            logger.error("HUGGING_FACE_TOKEN environment variable is not set.")
            logger.error("Set it using: export HUGGING_FACE_TOKEN=your_token_here")
            return False

        api = HfApi(token=token)
        logger.info("Authenticated with Hugging Face Hub")

        # Create the repository if it doesn't exist
        try:
            repo_url = create_repo(
                repo_id=repo_name,
                private=private,
                token=token,
                exist_ok=True
            )
            logger.info(f"Repository created/accessed: {repo_url}")
        except Exception as e:
            logger.error(f"Error creating repository: {e}")
            return False

        # Create and save a model card if requested
        if create_card:
            logger.info("Creating model card")
            model_card_content = create_model_card(
                base_model=base_model,
                repo_name=repo_name,
                tags=tags
            )
            model_card_path = os.path.join(model_dir, "README.md")
            with open(model_card_path, "w", encoding="utf-8") as f:
                f.write(model_card_content)
            logger.info(f"Model card saved to {model_card_path}")

        # Upload the model folder to the Hub
        logger.info(f"Uploading model from {model_dir} to {repo_name}")
        api.upload_folder(
            folder_path=model_dir,
            repo_id=repo_name,
            commit_message=commit_message
        )

        logger.info(f"Model successfully uploaded to {repo_name}")
        logger.info(f"View your model at: https://huggingface.co/{repo_name}")
        return True
    except ImportError as e:
        logger.error(f"Required libraries not installed: {e}")
        logger.error("Please install huggingface_hub: pip install huggingface_hub")
        return False
    except Exception as e:
        logger.error(f"Error uploading model: {e}")
        return False
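
# A minimal programmatic sketch (hypothetical values): upload_to_hub() can also
# be called directly from other code instead of via the CLI, e.g.:
#
#   from upload_to_huggingface import upload_to_hub
#
#   ok = upload_to_hub(
#       model_dir="./mt564_tinyllama_model",
#       repo_name="username/mt564-tinyllama",
#       private=True,
#   )
#   if not ok:
#       raise RuntimeError("Upload failed; see logs above")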


def main():
    args = parse_args()

    # Check that the model directory exists
    if not os.path.isdir(args.model_dir):
        logger.error(f"Model directory {args.model_dir} does not exist")
        return

    # Check for the expected model files. Note: models saved in safetensors
    # format contain model.safetensors instead of pytorch_model.bin; a missing
    # file only triggers a warning here.
    model_files = [
        "pytorch_model.bin", "config.json",
        "generation_config.json", "tokenizer_config.json",
        "tokenizer.json", "special_tokens_map.json"
    ]
    missing_files = [f for f in model_files if not os.path.exists(os.path.join(args.model_dir, f))]
    if missing_files:
        logger.warning(f"The following model files are missing: {', '.join(missing_files)}")
        logger.warning("The model might be in a different format or incomplete")

    # Upload to the Hugging Face Hub
    logger.info(f"Uploading model from {args.model_dir} to {args.repo_name}")
    success = upload_to_hub(
        model_dir=args.model_dir,
        repo_name=args.repo_name,
        commit_message=args.commit_message,
        private=args.private,
        create_card=args.create_model_card,
        base_model=args.base_model,
        tags=args.tags
    )

    if success:
        logger.info(f"Model upload complete! Your model is now available at: https://huggingface.co/{args.repo_name}")
        logger.info("You can use it with the Transformers library:")
        logger.info("from transformers import AutoModelForCausalLM, AutoTokenizer")
        logger.info(f"model = AutoModelForCausalLM.from_pretrained('{args.repo_name}')")
        logger.info(f"tokenizer = AutoTokenizer.from_pretrained('{args.repo_name}')")
    else:
        logger.error("Model upload failed. Please check the error messages above.")


if __name__ == "__main__":
    main()