import torch
import gradio as gr
import json
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from typing import Dict, Any
from utils import *
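# `utils` is star-imported, so the helpers called below are defined there. As a
# hedged reference only (names taken from the call sites in this file; the
# signatures are assumptions):
#
#   xlsx_to_json(path: str) -> Dict[str, Any]       # tables + paragraphs from an XLSX file
#   json_to_jsonl(data: Dict[str, Any]) -> str      # one JSON object per line
#   json_to_markdown(data: Dict[str, Any]) -> str   # human-readable table preview
#   create_prompt(data: Dict[str, Any], question: str) -> str  # builds the model prompt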
# Earlier module-level loading, kept commented out for reference; loading now
# happens inside generate_answer() so each request is self-contained
# (e.g. under ZeroGPU):
# base_model_id = "NousResearch/Nous-Hermes-2-Mistral-7B-DPO"
# lora_path = "tat-llm-final-e4"
# base_model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.float16)
# model = PeftModel.from_pretrained(base_model, lora_path)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = model.to(device)
# model.eval()
# tokenizer = AutoTokenizer.from_pretrained(lora_path)
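# A module-level cache would avoid reloading ~7B parameters on every question,
# at the cost of holding GPU memory between requests. A minimal sketch of that
# alternative (an assumption, not the deployed behaviour), kept commented out:
#
#   _cache = {}
#   def load_model_once():
#       if not _cache:
#           base = AutoModelForCausalLM.from_pretrained(
#               "NousResearch/Nous-Hermes-2-Mistral-7B-DPO", torch_dtype=torch.float16)
#           _cache["model"] = PeftModel.from_pretrained(base, "tat-llm-final-e4").eval()
#           _cache["tokenizer"] = AutoTokenizer.from_pretrained("tat-llm-final-e4")
#       return _cache["model"], _cache["tokenizer"]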
@spaces.GPU  # `spaces` was imported but unused; this decorator (the likely intent) requests a GPU worker on ZeroGPU Spaces
def generate_answer(json_data: Dict[str, Any], question: str) -> str:
    """
    Generate an answer with the fine-tuned model (LoRA adapter on top of
    Nous-Hermes-2-Mistral-7B-DPO). The base model and adapter are reloaded
    on every call: slow, but no GPU memory is held between requests.
    """
    base_model_id = "NousResearch/Nous-Hermes-2-Mistral-7B-DPO"
    lora_path = "tat-llm-final-e4"
    # Load base model and LoRA adapter
    base_model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.float16)
    model = PeftModel.from_pretrained(base_model, lora_path)
    tokenizer = AutoTokenizer.from_pretrained(lora_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    prompt = create_prompt(json_data, question)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    # Move the inputs to the same device as the model
    inputs = {k: v.to(device) for k, v in inputs.items()}
    input_length = inputs["input_ids"].shape[1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,  # greedy decoding for reproducible answers
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Drop the prompt tokens and decode only the newly generated continuation
    generated_tokens = outputs[0][input_length:]
    answer = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return answer
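# Hypothetical usage sketch; the JSON layout here is an assumption, the real
# schema is whatever utils.xlsx_to_json produces:
#
#   sample = {"table": [["Year", "Revenue"], ["2022", "100"], ["2023", "130"]],
#             "paragraphs": ["Revenue grew on the back of new contracts."]}
#   print(generate_answer(sample, "How did revenue change year over year?"))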
# Gradio interface functions
def process_xlsx(file):
    """
    Process an uploaded XLSX file and return its JSON, JSONL, and Markdown forms.
    """
    if file is None:
        return None, "", "", ""
    try:
        json_data = xlsx_to_json(file.name)
        json_str = json.dumps(json_data, indent=2, ensure_ascii=False)
        jsonl_str = json_to_jsonl(json_data)
        markdown_str = json_to_markdown(json_data)
        return json_data, json_str, jsonl_str, markdown_str
    except Exception as e:
        return None, f"Error: {str(e)}", "", ""
def chat_interface(json_data, question, history):
    """
    Chat handler: answer a question about the uploaded data and append the
    exchange to the chat history.
    """
    if json_data is None:
        return history + [[question, "Please upload an XLSX file first."]]
    if not question.strip():
        return history + [[question, "Please enter a question."]]
    try:
        answer = generate_answer(json_data, question)
        return history + [[question, answer]]
    except Exception as e:
        return history + [[question, f"Error generating answer: {str(e)}"]]
# Gradio UI
with gr.Blocks(title="terTATa-LLM: From Tables and Text to Strategic Business Steps", theme=gr.themes.Soft()) as demo:
    gr.HTML("""
    <style>
    body, .gradio-container {
        font-family: 'Poppins', sans-serif;
    }
    h1, h2, h3, h4, h5 {
        font-family: 'Poppins', sans-serif;
    }
    </style>
    <link href="https://fonts.googleapis.com/css2?family=Poppins&display=swap" rel="stylesheet">
    """)

    gr.Markdown("""
    # terTATa-LLM: From Tables and Text to Strategic Business Steps

    Upload an XLSX file containing tables and paragraphs, then ask questions about the data.
    The system converts your file to JSON and uses the terTATa-LLM model to answer them.
    """)
    json_data_state = gr.State()

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="Upload XLSX File",
                file_types=[".xlsx"],
                type="filepath"
            )
            process_btn = gr.Button("Process File", variant="primary")

            with gr.Tabs():
                with gr.Tab("Markdown Preview"):
                    markdown_output = gr.Markdown(label="Markdown Preview")
                with gr.Tab("JSON Output"):
                    json_output = gr.Code(
                        label="JSON Format",
                        language="json",
                        lines=15
                    )
                with gr.Tab("JSONL Output"):
                    jsonl_output = gr.Code(
                        label="JSONL Format",
                        language="json",
                        lines=5
                    )
        with gr.Column(scale=1):
            gr.Markdown("### Ask Questions About Your Data")
            chatbot = gr.Chatbot(height=400)
            msg = gr.Textbox(
                label="Prompt",
                placeholder="Ask a question about the table data...",
                lines=2
            )
            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary")
                clear_btn = gr.Button("Clear Chat")
            gr.Examples(
                examples=[
                    "What insights can we draw from this data?",
                    "How did the figures change from year to year?",
                    "What are the main trends visible in the data?",
                    "Calculate the percentage change between years!",
                    "What recommendations can be made based on this data?"
                ],
                inputs=msg
            )
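    # Event wiring: both submit paths append to the chat history, then a
    # chained .then() callback clears the textbox once the answer is in.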
    process_btn.click(
        fn=process_xlsx,
        inputs=[file_input],
        outputs=[json_data_state, json_output, jsonl_output, markdown_output]
    )

    msg.submit(
        fn=chat_interface,
        inputs=[json_data_state, msg, chatbot],
        outputs=[chatbot]
    ).then(
        lambda: "",
        outputs=[msg]
    )

    submit_btn.click(
        fn=chat_interface,
        inputs=[json_data_state, msg, chatbot],
        outputs=[chatbot]
    ).then(
        lambda: "",
        outputs=[msg]
    )

    clear_btn.click(
        lambda: [],
        outputs=[chatbot]
    )
if __name__ == "__main__":
    demo.queue().launch(share=True)