import gradio as gr
import psutil
import torch

# Memory (in GB, i.e. bytes / 1024**3) that a single instance is assumed to use.
MEMORY_PER_INSTANCE_GB = 4.5
_BYTES_PER_GB = 1024 ** 3


def get_available_memory():
    """Return the total memory to budget against and a label describing it.

    Uses the total memory of CUDA device 0 when a CUDA-capable GPU is
    available; otherwise falls back to the machine's total system RAM.

    Returns:
        A ``(total_gb, memory_type)`` tuple where ``total_gb`` is a float
        (bytes / 1024**3) and ``memory_type`` is ``"GPU RAM"`` or
        ``"System RAM"``.
    """
    if torch.cuda.is_available():
        total_bytes = torch.cuda.get_device_properties(0).total_memory
        return total_bytes / _BYTES_PER_GB, "GPU RAM"
    total_bytes = psutil.virtual_memory().total
    return total_bytes / _BYTES_PER_GB, "System RAM"


total_memory, memory_type = get_available_memory()

# Maximum number of whole instances that fit in the available memory. This can
# legitimately be 0 on small machines/GPUs (e.g. a "4 GB" card reports a bit
# under 4 GiB), which the UI below has to cope with.
max_instances = int(total_memory // MEMORY_PER_INSTANCE_GB)


def _progress_bar_html(percentage):
    """Return an HTML bar filled to ``percentage`` (0-100) with a blue gradient."""
    return f"""
<div style="width: 100%; height: 24px; background-color: #e5e7eb; border-radius: 6px; overflow: hidden;">
  <div style="width: {percentage:.1f}%; height: 100%; background: linear-gradient(90deg, #60a5fa, #1d4ed8); transition: width 0.3s ease;"></div>
</div>
"""


def update_usage(num_instances):
    """Clamp the requested instance count and build every display output.

    Args:
        num_instances: Requested number of instances (the slider value; it may
            arrive as a float such as ``3.0``).

    Returns:
        A ``(num_instances, memory_text, bar_html, warning)`` tuple: the
        clamped instance count, a "used / total" summary string, the progress
        bar HTML, and a warning string (empty when the request fit).
    """
    # Normalize to a non-negative whole number of instances.
    num_instances = max(int(num_instances), 0)
    if num_instances > max_instances:
        num_instances = max_instances  # Snap back to the maximum valid value
        warning = (
            f"⚠️ You tried to exceed the available {memory_type}. "
            f"Max instances set to {max_instances}."
        )
    else:
        warning = ""  # No warning if within limits

    usage = num_instances * MEMORY_PER_INSTANCE_GB
    # Guard against a zero total so the percentage never divides by zero.
    usage_percentage = min(usage / total_memory * 100, 100) if total_memory > 0 else 0.0

    bar_html = _progress_bar_html(usage_percentage)
    memory_text = f"{usage:.2f} GB / {total_memory:.2f} GB ({memory_type})"
    return num_instances, memory_text, bar_html, warning


# Start at one instance when at least one fits. Otherwise there is nothing to
# choose, so start at 0 and lock the slider instead of building an invalid
# slider with minimum=1 > maximum=0.
_has_capacity = max_instances >= 1
_initial_instances = 1 if _has_capacity else 0

# Derive the initial outputs from the same function the slider uses, so the
# first render matches the slider's starting value.
_, _initial_text, _initial_bar, _initial_warning = update_usage(_initial_instances)
if not _has_capacity:
    _initial_warning = (
        f"⚠️ Not enough {memory_type} for a single instance "
        f"({MEMORY_PER_INSTANCE_GB} GB required, {total_memory:.2f} GB available)."
    )

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown(f"# Memory Usage Tracker ({memory_type})")
    slider = gr.Slider(
        minimum=_initial_instances,  # 1 normally; 0 only when the slider is locked
        maximum=max(max_instances, 1),  # keep maximum >= minimum
        step=1,
        value=_initial_instances,
        label="Number of Instances",
        interactive=_has_capacity,
    )
    memory_display = gr.Label(value=_initial_text)
    progress_bar = gr.HTML(value=_initial_bar)
    warning_message = gr.HTML(value=_initial_warning)

    def handle_slider_change(num_instances):
        """Gradio callback: recompute every output for the new slider value."""
        return update_usage(num_instances)

    # The slider is also an output so an out-of-range value snaps back visually.
    slider.change(
        handle_slider_change,
        inputs=[slider],
        outputs=[slider, memory_display, progress_bar, warning_message],
    )

demo.launch()