|
import asyncio
import logging
import os
from pathlib import Path
from typing import List, Literal, Optional, Union

import gradio as gr
import torch
from diffusers import DiffusionPipeline
from gradio import Blocks, Component, Interface
from gradio.data_classes import FileData, GradioModel, GradioRootModel
from pydantic import BaseModel, ConfigDict, Field
from transformers import pipeline
|
|
|
|
|
# Module-level FLUX.1-dev text-to-image pipeline, loaded once at import time.
# UniversalReasoning instances store a reference to it as `self.image_model`.
image_model = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    # bfloat16 when CUDA is available, otherwise float32 for CPU execution.
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    # `token` supersedes the deprecated `use_auth_token` keyword; the value is
    # read from the HUGGINGFACE_TOKEN environment variable (None if unset).
    token=os.getenv("HUGGINGFACE_TOKEN"),
)

# Model CPU offload moves sub-models onto an accelerator only while they run.
# Without CUDA there is no device to offload to (diffusers errors either here
# or at inference time), so on CPU-only hosts keep the pipeline on the CPU.
if torch.cuda.is_available():
    image_model.enable_model_cpu_offload()
|
|
|
|
|
class FileDataDict(BaseModel):
    """Pydantic description of a file that can be used as message content.

    Only ``path`` is required; every other field is optional metadata.
    """

    # Location of the file; the only required field.
    path: str
    # Optional URL for the file.
    url: Optional[str] = None
    # Optional file size (presumably in bytes — confirm with producers).
    size: Optional[int] = None
    # Original file name, if known.
    orig_name: Optional[str] = None
    # MIME type, if known.
    mime_type: Optional[str] = None
    # Whether the file is a stream; defaults to False.
    is_stream: Optional[bool] = False

    # Pydantic v2 configuration; replaces the deprecated nested `class Config`.
    model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
|
class MessageDict(BaseModel):
    """One chat message in the form accepted by ``MultimodalChatbot(value=...)``."""

    # Message body: plain text, a file description, a tuple, or a gradio Component.
    content: Union[str, FileDataDict, tuple, Component]
    # Who produced the message.
    role: Literal["user", "assistant", "system"]
    # Free-form extra data; None when absent.
    metadata: Optional[dict] = None
    # Optional list of option dicts attached to the message.
    options: Optional[List[dict]] = None

    # `Component` is not a pydantic type, so arbitrary types must be allowed.
    # Pydantic v2 style; replaces the deprecated nested `class Config`.
    model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
|
class ChatMessage(GradioModel):
    """One chat message in the Gradio data model used by ``ChatbotDataMessages``."""

    # Who produced the message.
    role: Literal["user", "assistant", "system"]
    # Message body: plain text, a gradio FileData, or a gradio Component.
    content: Union[str, FileData, Component]
    # Extra data; default_factory gives each instance its own empty dict.
    metadata: dict = Field(default_factory=dict)
    # Optional list of option dicts attached to the message.
    options: Optional[List[dict]] = None

    # `Component` is not a pydantic type, so arbitrary types must be allowed.
    # Pydantic v2 style; replaces the deprecated nested `class Config`.
    model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
|
class ChatbotDataMessages(GradioRootModel):
    """Root model whose value is an entire chat transcript.

    The payload is the bare list stored in ``root``, one ``ChatMessage``
    per entry, in transcript order.
    """

    # The transcript itself; GradioRootModel serialises this list directly.
    root: List[ChatMessage]
|
|
|
|
|
class UniversalReasoning:
    """Run questions through several Hugging Face pipelines and an image model.

    Holds a sentiment-analysis pipeline, a text-classification pipeline,
    a text2text-generation pipeline (T5-small) and a text-generation pipeline
    (GPT-Neo 125M), plus a reference to the module-level ``image_model``.
    """

    def __init__(self, config):
        """Load the text pipelines and attach the module-level image model.

        Args:
            config: Stored on ``self.config``; not read anywhere in this class.
        """
        self.config = config
        # Every question passed to generate_response is appended here.
        # NOTE(review): nothing in this class reads it and it is never
        # trimmed, so it grows for the lifetime of the instance.
        self.context_history = []
        self.sentiment_analyzer = pipeline("sentiment-analysis")

        # Despite the attribute name this is a DistilBERT SST-2 classifier.
        # NOTE(review): likely the same checkpoint the default
        # "sentiment-analysis" pipeline loads — confirm and deduplicate.
        self.deepseek_model = pipeline(
            "text-classification",
            model="distilbert-base-uncased-finetuned-sst-2-english",
            truncation=True,
        )

        # Despite the attribute name this is T5-small.
        self.davinci_model = pipeline(
            "text2text-generation",
            model="t5-small",
            truncation=True,
        )

        self.additional_model = pipeline(
            "text-generation",
            model="EleutherAI/gpt-neo-125M",
            truncation=True,
        )

        self.image_model = image_model

    async def generate_response(self, question: str) -> str:
        """Return a text report combining every text pipeline's output.

        The pipelines are blocking calls, so they run in the event loop's
        default executor rather than inside this coroutine; running them
        inline would stall every other task on the loop (e.g. other Gradio
        sessions) for the duration of inference.

        Args:
            question: The user's input text; also appended to
                ``context_history``.

        Returns:
            Four ``"<label>: <raw pipeline output>"`` sections joined by
            blank lines.
        """
        self.context_history.append(question)
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run_text_pipelines, question)

    def _run_text_pipelines(self, question: str) -> str:
        """Synchronously run all text pipelines on *question* and format them."""
        sentiment_score = self.analyze_sentiment(question)
        deepseek_response = self.deepseek_model(question)
        davinci_response = self.davinci_model(question, max_length=50)
        # For a decoder-only model `max_length` counts the prompt tokens too,
        # so long prompts left little or no room for generated text;
        # `max_new_tokens` bounds only the continuation.
        additional_response = self.additional_model(question, max_new_tokens=100)

        responses = [
            f"Sentiment score: {sentiment_score}",
            f"DeepSeek Response: {deepseek_response}",
            f"T5 Response: {davinci_response}",
            f"Additional Model Response: {additional_response}",
        ]
        return "\n\n".join(responses)

    def generate_image(
        self,
        prompt: str,
        *,
        height: int = 1024,
        width: int = 1024,
        guidance_scale: float = 3.5,
        num_inference_steps: int = 50,
        max_sequence_length: int = 512,
        seed: int = 0,
        output_path: Optional[Union[str, Path]] = "flux-dev.png",
    ):
        """Generate one image for *prompt* with ``self.image_model``.

        All settings are keyword-only; their defaults are the values this
        method always used, so existing ``generate_image(prompt)`` callers
        are unaffected.

        Args:
            prompt: Text prompt passed to the pipeline.
            height: Passed to the pipeline as the output height.
            width: Passed to the pipeline as the output width.
            guidance_scale: Passed through to the pipeline.
            num_inference_steps: Passed through to the pipeline.
            max_sequence_length: Passed through to the pipeline.
            seed: Seed for the CPU ``torch.Generator`` given to the pipeline.
            output_path: Where the image is saved; ``None`` skips saving.
                NOTE: the default is a fixed file name in the working
                directory, so successive calls overwrite it.

        Returns:
            The first entry of the pipeline result's ``.images``
            (presumably a PIL image).
        """
        image = self.image_model(
            prompt,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            max_sequence_length=max_sequence_length,
            generator=torch.Generator("cpu").manual_seed(seed),
        ).images[0]
        if output_path is not None:
            image.save(output_path)
        return image

    def analyze_sentiment(self, text: str) -> list:
        """Run the sentiment pipeline on *text*, log and return its raw output."""
        sentiment_score = self.sentiment_analyzer(text)
        # Module logger with lazy %-formatting; unlike `logging.info`, this
        # never implicitly configures the root logger as a side effect.
        logging.getLogger(__name__).info("Sentiment analysis result: %s", sentiment_score)
        return sentiment_score
|
|
|
|
|
class MultimodalChatbot(Component):
    """Gradio component holding chat messages plus a UniversalReasoning backend.

    The component's data model is ``ChatbotDataMessages``. Each instance
    constructs its own ``UniversalReasoning``, which loads several models.
    """

    # Declared on the class (the usual pattern for gradio custom components)
    # so it is already present while Component.__init__ runs, instead of
    # only being attached afterwards.
    data_model = ChatbotDataMessages

    def __init__(
        self,
        value: Optional[List[MessageDict]] = None,
        label: Optional[str] = None,
        render: bool = True,
        log_file: Optional[Path] = None,
    ):
        """Create the component.

        Args:
            value: Initial messages; ``None`` is treated as an empty list.
            label: Component label, forwarded to ``Component``.
            render: Forwarded to ``Component``. It must not be stored as
                ``self.render``: that would shadow the inherited ``render()``
                method with a bool.
            log_file: Stored on ``self.log_file``; not used in this class.
        """
        value = value or []
        super().__init__(label=label, value=value, render=render)
        self.log_file = log_file
        self.universal_reasoning = UniversalReasoning({})

    def preprocess(self, payload: Optional[ChatbotDataMessages]) -> List[MessageDict]:
        """Return the message list held by *payload*, or ``[]`` when it is None."""
        return payload.root if payload else []

    def postprocess(self, messages: Optional[List[MessageDict]]) -> ChatbotDataMessages:
        """Wrap *messages* (``None`` → empty) in a ``ChatbotDataMessages``.

        ``MessageDict`` is a different pydantic model from ``ChatMessage``,
        and pydantic v2 does not accept an instance of one model class where
        another is expected, so such messages are dumped to dicts first.
        ``exclude_none`` drops ``metadata=None``, which ``ChatMessage``'s
        ``metadata: dict`` field would otherwise reject.
        """
        messages = messages or []
        normalized = [
            m.model_dump(exclude_none=True) if isinstance(m, MessageDict) else m
            for m in messages
        ]
        return ChatbotDataMessages(root=normalized)
|
|
|
|
|
class HuggingFaceChatbot:
    """Gradio app exposing the chatbot's text and image generation as two tabs."""

    def __init__(self):
        """Create the backing MultimodalChatbot (which constructs UniversalReasoning)."""
        self.chatbot = MultimodalChatbot(value=[])

    def setup_interface(self):
        """Build the app: one tab for text chat, one for image generation.

        Returns:
            A ``gr.TabbedInterface`` (a ``Blocks`` subclass, so ``.launch()``
            works) containing both Interfaces.
        """

        async def chatbot_logic(input_text: str) -> str:
            return await self.chatbot.universal_reasoning.generate_response(input_text)

        def image_logic(prompt: str):
            return self.chatbot.universal_reasoning.generate_image(prompt)

        interface = Interface(
            fn=chatbot_logic,
            inputs="text",
            outputs="text",
            title="Hugging Face Multimodal Chatbot",
        )

        image_interface = Interface(
            fn=image_logic,
            inputs="text",
            outputs="image",
            title="Image Generator",
        )

        # gr.Blocks takes `theme` as its first positional parameter, so it
        # cannot combine existing Interfaces; TabbedInterface is gradio's API
        # for composing several Interfaces into one app.
        return gr.TabbedInterface(
            [interface, image_interface],
            tab_names=["Chatbot", "Image Generator"],
        )

    def launch(self):
        """Build the interface and start the Gradio server."""
        interface = self.setup_interface()
        interface.launch()
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: verbose timestamped console logging, then build
    # the app and serve it.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s',
        level=logging.DEBUG,
    )
    HuggingFaceChatbot().launch()