# ai-astro / app.py
# Prak2005's picture
# Update app.py
# 3bcb6cd verified
import gradio as gr
import torch
from transformers import pipeline, AutoFeatureExtractor, AutoModelForImageClassification
from PIL import Image
import requests
import io
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import os
from datetime import datetime, timedelta
import json
import google.generativeai as genai
# Constants
# NASA API key: taken from the NASA_API_KEY environment variable when it is set
# and non-empty, otherwise the public "DEMO_KEY" is used (fine for demos,
# replace with your own key for production).
NASA_API_KEY = os.environ.get("NASA_API_KEY") or "DEMO_KEY"
# Gemini API key; empty string means Gemini features are disabled.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")  # Will be set via Hugging Face Spaces environment variables
# Endpoint for NASA's Astronomy Picture of the Day service.
APOD_URL = "https://api.nasa.gov/planetary/apod"
# Choices offered by the celestial object selector in the UI.
CELESTIAL_BODIES = ["Sun", "Moon", "Mercury", "Venus", "Mars", "Jupiter", "Saturn", "Uranus", "Neptune", "Pluto"]
CELESTIAL_OBJECTS = ["Galaxy", "Nebula", "Star Cluster", "Supernova Remnant", "Black Hole", "Quasar", "Pulsar"]
# Initialize models
# Every handle starts as None so the helper functions can always test for
# availability; previously a failed load left `feature_extractor` and `model`
# undefined, which turned the availability check into a NameError.
feature_extractor = None
model = None
caption_model = None
gemini_model = None
gemini_text_model = None

# Each model family is loaded in its own try block so that one failure does
# not disable the features that loaded fine.
try:
    # Astronomy image classifier
    feature_extractor = AutoFeatureExtractor.from_pretrained("matthewberryman/astronomy-image-classifier")
    model = AutoModelForImageClassification.from_pretrained("matthewberryman/astronomy-image-classifier")
except Exception as e:
    print(f"Model loading error: {e}")
    # Keep the pair consistent: the classifier needs both objects.
    feature_extractor = None
    model = None

try:
    # Image captioning model for astronomy images
    caption_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
except Exception as e:
    print(f"Model loading error: {e}")
    caption_model = None

try:
    # Initialize Gemini if an API key is available
    if GEMINI_API_KEY:
        genai.configure(api_key=GEMINI_API_KEY)
        # Configure the generative models (image analysis and text use the same model name)
        gemini_model = genai.GenerativeModel('gemini-2.0-flash')
        gemini_text_model = genai.GenerativeModel('gemini-2.0-flash')
        print("Gemini models initialized successfully")
    else:
        print("Gemini API key not found. Advanced features will be disabled.")
except Exception as e:
    print(f"Model loading error: {e}")
    gemini_model = None
    gemini_text_model = None
# Helper functions
def get_astronomy_picture_of_day(date=None):
    """Fetch NASA's Astronomy Picture of the Day.

    Args:
        date: Optional value sent as the ``date`` query parameter. When falsy,
            no date is sent and the service chooses the entry.

    Returns:
        The decoded JSON response as a dict. If the request or JSON decoding
        fails, or the payload is not a JSON object, a dict with the keys
        ``error``, ``title`` and ``explanation`` is returned instead.
    """
    params = {'api_key': NASA_API_KEY}
    if date:
        params['date'] = date
    try:
        # Without a timeout a stalled connection would block the UI handler indefinitely.
        response = requests.get(APOD_URL, params=params, timeout=15)
        data = response.json()
    except Exception as e:
        return {"error": str(e), "title": "Error fetching APOD", "explanation": "Could not connect to NASA API"}
    # Callers use dict methods on the result, so never hand back another JSON type.
    if not isinstance(data, dict):
        return {
            "error": "Unexpected response format",
            "title": "Error fetching APOD",
            "explanation": "Could not connect to NASA API",
        }
    return data
def classify_astronomy_image(image):
    """Classify an astronomy image using the pretrained model.

    Args:
        image: Image passed to the feature extractor (the UI supplies a PIL image).

    Returns:
        On success, a dict with ``prediction`` (label string), ``confidence``
        (probability of that label) and ``top_3`` (list of ``(label,
        probability)`` tuples, fewer than three if the model has fewer labels).
        On any failure, a dict with a single ``error`` message.
    """
    if feature_extractor is None or model is None:
        return {"error": "Model not loaded"}
    if image is None:
        return {"error": "No image provided"}
    try:
        inputs = feature_extractor(images=image, return_tensors="pt")
        # Pure inference: no need to record an autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)
        probs = outputs.logits.softmax(1)
        pred_class = outputs.logits.argmax(-1).item()
        # Get class labels and probabilities
        id2label = model.config.id2label
        prediction = id2label[pred_class]
        confidence = probs[0][pred_class].item()
        # Get up to 3 top predictions; topk raises if k exceeds the number of classes
        k = min(3, probs.shape[-1])
        top = probs[0].topk(k)
        top_3_preds = [
            (id2label[idx.item()], prob.item())
            for idx, prob in zip(top.indices, top.values)
        ]
        return {
            "prediction": prediction,
            "confidence": confidence,
            "top_3": top_3_preds,
        }
    except Exception as e:
        return {"error": str(e)}
def generate_image_caption(image):
    """Produce a caption for ``image`` with the captioning pipeline.

    Returns the ``generated_text`` of the first pipeline result. If the
    pipeline is not loaded, or running it raises, a human-readable message
    string is returned instead of raising.
    """
    if caption_model is None:
        return "Image captioning model not available"

    try:
        results = caption_model(image)
        first_result = results[0]
        text = first_result['generated_text']
    except Exception as exc:
        return f"Error generating caption: {str(exc)}"
    else:
        return text
# Prompt used when the caller does not supply one.
_DEFAULT_GEMINI_IMAGE_PROMPT = (
    "\n"
    "You are an expert astrophysicist. Analyze this astronomy image in detail.\n"
    "Include:\n"
    "1. Identification of the celestial object(s)\n"
    "2. Scientific explanation of what's visible\n"
    "3. Approximate distance from Earth (if applicable)\n"
    "4. Interesting scientific facts about this type of object\n"
    "5. Technological details about how such images are captured\n"
    "6. Research value of studying this object\n"
    "Format your analysis professionally as if for a scientific publication.\n"
)


def analyze_with_gemini(image, prompt=None):
    """Analyze an astronomy image with the configured Gemini model.

    Args:
        image: PIL image to analyze (it is re-encoded as PNG before sending).
        prompt: Optional custom instructions. When missing or blank, a default
            astrophysics analysis prompt is used.

    Returns:
        The model's text response, or a human-readable message string when
        Gemini is not configured, no image was given, or the request fails.
    """
    if gemini_model is None:
        return "Gemini API not configured. Please add your API key in the Space settings."
    if image is None:
        return "Please upload an image to analyze."
    try:
        # Convert PIL image to PNG bytes for Gemini
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        # Raw bytes are not a valid content part; inline image data must be
        # wrapped as a blob with its MIME type.
        image_part = {"mime_type": "image/png", "data": buffer.getvalue()}
        # The UI sends "" (or whitespace) when the prompt box is left empty.
        if not prompt or not prompt.strip():
            prompt = _DEFAULT_GEMINI_IMAGE_PROMPT
        # Generate analysis using Gemini
        response = gemini_model.generate_content([prompt, image_part])
        return response.text
    except Exception as e:
        return f"Error analyzing with Gemini: {str(e)}"
def get_professional_insights(query, context=None):
    """Answer an astronomy question with the configured Gemini text model.

    Args:
        query: The user's question.
        context: Optional extra background appended to the prompt.

    Returns:
        The model's text response, or a human-readable message string when
        Gemini is not configured, the query is empty, or the request fails.
    """
    if gemini_text_model is None:
        return "Gemini API not configured. Please add your API key in the Space settings."
    # Do not spend an API call on an empty question.
    if not query or not query.strip():
        return "Please enter an astronomy question."
    try:
        # Build prompt with context if provided
        prompt = (
            "\n"
            "You are a professional astrophysicist with expertise in observational astronomy,\n"
            "cosmology, planetary science, and stellar evolution.\n"
            "Please provide a comprehensive, scientifically accurate response to the following query:\n"
            f"{query}\n"
        )
        if context:
            prompt += f"\n\nAdditional context: {context}"
        # Generate insights
        response = gemini_text_model.generate_content(prompt)
        return response.text
    except Exception as e:
        return f"Error getting insights: {str(e)}"
# Offline data used when Gemini is unavailable or its reply cannot be used.
_FALLBACK_OBJECT_INFO = {
    "Sun": {
        "type": "Star",
        "distance": "1 AU (149.6 million km)",
        "diameter": "1,391,000 km",
        "mass": "1.989 × 10^30 kg",
        "temperature": "5,778 K (surface)",
        "description": "The Sun is the star at the center of the Solar System. It is a nearly perfect sphere of hot plasma, heated to incandescence by nuclear fusion reactions in its core."
    },
    "Moon": {
        "type": "Natural Satellite",
        "distance": "384,400 km from Earth",
        "diameter": "3,474 km",
        "mass": "7.342 × 10^22 kg",
        "temperature": "-173°C to 127°C",
        "description": "The Moon is Earth's only natural satellite. It is the fifth-largest satellite in the Solar System and the largest among planetary satellites relative to the size of the planet it orbits."
    },
    "Mars": {
        "type": "Planet",
        "distance": "1.5 AU (227.9 million km)",
        "diameter": "6,779 km",
        "mass": "6.39 × 10^23 kg",
        "temperature": "-87°C to -5°C",
        "description": "Mars is the fourth planet from the Sun and the second-smallest planet in the Solar System. Mars is often called the 'Red Planet' due to its reddish appearance."
    },
    "Galaxy": {
        "type": "Galaxy",
        "description": "A galaxy is a gravitationally bound system of stars, stellar remnants, interstellar gas, dust, and dark matter. The Milky Way is the galaxy that contains our Solar System."
    },
    "Nebula": {
        "type": "Nebula",
        "description": "A nebula is an interstellar cloud of dust, hydrogen, helium and other ionized gases. Many nebulae are regions where new stars are being formed."
    },
}


def _parse_object_info_response(text, object_name):
    """Turn a Gemini reply into an info dict, or return None if it is unusable."""
    import re  # local import, as in the original implementation

    # Collect JSON candidates: a fenced code block first, then any bare {...} span.
    candidates = []
    fenced = re.search(r'```(?:json)?\s*(.*?)```', text, re.DOTALL)
    if fenced:
        candidates.append(fenced.group(1))
    start, end = text.find('{'), text.rfind('}')
    if start != -1 and end > start:
        candidates.append(text[start:end + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except ValueError:
            continue
        # Callers use dict methods on the result, so reject lists/scalars.
        if isinstance(parsed, dict):
            return parsed
    if candidates:
        # The reply looked like JSON but could not be decoded; line-based
        # parsing would only produce garbage keys, so give up.
        return None

    # No JSON at all: fall back to simple "Key: value" line parsing.
    info = {"description": ""}
    current_key = None
    for line in text.split('\n'):
        if ':' in line and not line.startswith(' '):
            raw_key, raw_value = line.split(':', 1)
            key = raw_key.lower().strip().replace(' ', '_')
            info[key] = raw_value.strip()
            current_key = key
        elif current_key == "description" and line.strip():
            # Continuation lines are only folded into the description.
            info["description"] += " " + line.strip()
    if not info.get("description"):
        info["description"] = f"Information about {object_name} generated using AI."
    return info


def fetch_celestial_object_info(object_name):
    """Fetch information about a celestial object.

    When the Gemini text model is configured, it is asked for structured data
    about ``object_name`` and the reply is parsed into a dict. If Gemini is not
    configured, fails, or returns something unusable, a small built-in table
    is consulted instead.

    Args:
        object_name: Name of the object, e.g. ``"Mars"``.

    Returns:
        A dict of properties. Entries from the built-in table, and the generic
        "not available" entry, always contain ``description``.
    """
    # First check if Gemini is available for enhanced descriptions
    if gemini_text_model is not None and object_name:
        prompt = (
            "\n"
            f"You are an astronomy database. Provide comprehensive, scientifically accurate information about {object_name}.\n"
            "Include these sections:\n"
            "- Type of object\n"
            "- Physical characteristics (size, mass, composition)\n"
            "- Distance from Earth\n"
            "- Formation and evolution\n"
            "- Notable features\n"
            "- Scientific significance\n"
            "- Recent discoveries (if applicable)\n"
            "Format this as structured data that can be parsed as JSON with the following fields:\n"
            "type, distance, diameter, mass, temperature, composition, age, notable_features, research_value, description\n"
            "Ensure all values are scientifically accurate and use appropriate units.\n"
        )
        try:
            response = gemini_text_model.generate_content(prompt)
            info = _parse_object_info_response(response.text, object_name)
            if info is not None:
                return info
        except Exception as e:
            # Best effort: report the problem and continue with the fallback table.
            print(f"Gemini lookup for {object_name} failed, using fallback database: {e}")
    # Return a copy so callers cannot modify the shared fallback table.
    entry = _FALLBACK_OBJECT_INFO.get(object_name)
    if entry is not None:
        return dict(entry)
    # Otherwise return a generic message
    return {"description": f"Information about {object_name} is not available in the demo database."}
def generate_star_chart(latitude, longitude, date=None):
    """Generate a simple, simulated star chart for a location and date.

    This is a demo rendering, not an astronomical computation: the star field
    is pseudo-random with a fixed seed, and the Moon, planet and constellation
    positions are simple functions of the inputs or hard-coded.

    Args:
        latitude: Latitude in degrees (scales the star count and object positions).
        longitude: Longitude in degrees (affects object positions).
        date: Optional ``YYYY-MM-DD`` string. It scales the star count and is
            shown in the title; an unparsable value uses a neutral factor.

    Returns:
        An ``io.BytesIO`` positioned at the start of the PNG image data.
    """
    # Dedicated generator seeded for reproducibility. It yields the same
    # legacy stream as np.random.seed(42) without reseeding NumPy's global state.
    rng = np.random.RandomState(42)
    # Number of stars depends on date and location (simulated effect)
    lat_factor = abs(latitude) / 90.0  # 0 to 1
    season_factor = 0.5
    if date:
        try:
            date_obj = datetime.strptime(date, "%Y-%m-%d")
        except (TypeError, ValueError):
            pass  # keep the neutral default
        else:
            day_of_year = date_obj.timetuple().tm_yday
            season_factor = abs(((day_of_year + 10) % 365) - 182.5) / 182.5  # 0 to 1
    num_stars = int(1000 + 2000 * lat_factor * season_factor)
    # Create star positions
    star_x = rng.rand(num_stars) * 2 - 1  # -1 to 1
    star_y = rng.rand(num_stars) * 2 - 1  # -1 to 1
    # Create star brightnesses (magnitudes)
    magnitudes = rng.exponential(1, num_stars) * 5
    # Filter stars that would be below horizon
    horizon_mask = star_y > -0.2
    star_x = star_x[horizon_mask]
    star_y = star_y[horizon_mask]
    magnitudes = magnitudes[horizon_mask]

    # Add a few constellations (simplified)
    constellations = [
        {"name": "Big Dipper", "stars": [(0.2, 0.5), (0.3, 0.55), (0.4, 0.6),
                                         (0.5, 0.62), (0.55, 0.5), (0.5, 0.4), (0.4, 0.45)]},
        {"name": "Orion", "stars": [(-0.3, -0.1), (-0.25, 0), (-0.2, 0.1),
                                    (-0.15, 0), (-0.35, -0.15), (-0.25, -0.15), (-0.15, -0.15)]},
    ]

    # Create plot
    fig, ax = plt.subplots(figsize=(10, 10), facecolor='black')
    try:
        ax.set_facecolor('black')
        # Plot stars with varying sizes based on magnitude
        sizes = 50 * np.exp(-magnitudes / 2)
        ax.scatter(star_x, star_y, s=sizes, color='white', alpha=0.8)
        # Add celestial objects based on date and location (simulated)
        # Moon
        moon_x = 0.7 * np.cos(latitude / 30)
        moon_y = 0.6 * np.sin(longitude / 30)
        ax.scatter(moon_x, moon_y, s=300, color='lightgray', alpha=0.9)
        ax.text(moon_x + 0.05, moon_y, 'Moon', color='white', fontsize=12)
        # A bright planet
        planet_x = -0.5 * np.sin(latitude / 20)
        planet_y = 0.4 * np.cos(longitude / 20)
        ax.scatter(planet_x, planet_y, s=120, color='orange', alpha=0.9)
        ax.text(planet_x + 0.05, planet_y, 'Jupiter', color='white', fontsize=12)
        for constellation in constellations:
            points = np.array(constellation["stars"])
            # Draw lines connecting stars
            ax.plot(points[:, 0], points[:, 1], 'white', alpha=0.3, linestyle='-', linewidth=1)
            # Draw stars
            ax.scatter(points[:, 0], points[:, 1], s=100, color='white', alpha=0.9)
            # Label constellation above the centroid of its stars
            center_x = points[:, 0].mean()
            center_y = points[:, 1].mean()
            ax.text(center_x, center_y + 0.1, constellation["name"], color='white', fontsize=12, ha='center')
        # Set plot parameters
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.set_aspect('equal')
        ax.axis('off')
        # Set title with location and date
        location_str = f"Lat: {latitude:.1f}°, Long: {longitude:.1f}°"
        date_str = date if date else datetime.now().strftime("%Y-%m-%d")
        ax.set_title(f"Star Chart for {location_str} on {date_str}", color='white', fontsize=14)
        # Save this specific figure (not pyplot's "current" one) to a buffer
        buf = io.BytesIO()
        fig.savefig(buf, format='png', facecolor='black')
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
    buf.seek(0)
    return buf
def predict_space_weather(date=None):
    """Produce a simulated 7-day space weather forecast.

    The values are pseudo-random demo data, not real predictions. The random
    generator is seeded from the calendar date, so the same start date always
    gives the same table and different start dates give different tables.

    Args:
        date: Optional start date as ``YYYY-MM-DD``. Missing or unparsable
            values fall back to today.

    Returns:
        A pandas DataFrame with 7 rows and the columns ``Date``,
        ``Solar Activity``, ``Geomagnetic Activity``, ``Aurora Visibility``
        and ``Solar Flare Probability`` (all formatted as strings).
    """
    target_date = None
    if date:
        try:
            target_date = datetime.strptime(date, "%Y-%m-%d")
        except (TypeError, ValueError):
            target_date = None
    if target_date is None:
        target_date = datetime.now()
    # Generate predictions for 7 days
    dates = [(target_date + timedelta(days=i)).strftime("%Y-%m-%d") for i in range(7)]
    # Seed from the calendar day. The previous `timestamp % 1000` seed mapped
    # midnight dates onto only a handful of values (86400 * k mod 1000), was
    # timezone dependent, and changed every second for "today".
    rng = np.random.default_rng(target_date.toordinal())
    # Simulate solar activity (0-10 scale)
    solar_activity = np.clip(5 + np.cumsum(rng.normal(0, 1, 7)) * 0.5, 0, 10)
    # Simulate geomagnetic activity (Kp index, 0-9 scale)
    geomagnetic_activity = np.clip(np.round(4 + np.cumsum(rng.normal(0, 0.8, 7)) * 0.3), 0, 9)
    # Simulate aurora visibility (0-10 scale)
    aurora_visibility = np.clip(geomagnetic_activity * 1.1 + rng.normal(0, 1, 7), 0, 10)
    # Simulate solar flare probability (percentage)
    flare_probability = np.clip(solar_activity * 10 + rng.normal(0, 5, 7), 0, 100)
    # Create a dataframe
    weather_df = pd.DataFrame({
        'Date': dates,
        'Solar Activity': [f"{x:.1f}/10" for x in solar_activity],
        'Geomagnetic Activity': [f"Kp {int(x)}" for x in geomagnetic_activity],
        'Aurora Visibility': [f"{x:.1f}/10" for x in aurora_visibility],
        'Solar Flare Probability': [f"{int(x)}%" for x in flare_probability],
    })
    return weather_df
# UI Components
def _run_basic_analysis(img):
    """Classify and caption ``img`` once; return the four basic-result values."""
    result = classify_astronomy_image(img)
    prediction = result.get("prediction", "Unknown")
    confidence = f"{result.get('confidence', 0) * 100:.2f}%"
    top3 = [{"class": c, "probability": f"{p*100:.2f}%"} for c, p in result.get("top_3", [])]
    return prediction, confidence, top3, generate_image_caption(img)


def _load_apod(date):
    """Fetch the APOD entry once; return ``(image, title, description)``."""
    if isinstance(date, str):
        date = date.strip()
    data = get_astronomy_picture_of_day(date or None)
    if not isinstance(data, dict):
        data = {}
    title = data.get("title", "Error fetching APOD")
    description = data.get("explanation", "No description available")
    image = None
    url = data.get("url")
    # APOD entries are sometimes videos; only still images can be displayed.
    if url and data.get("media_type", "image") == "image":
        try:
            resp = requests.get(url, timeout=30)
            resp.raise_for_status()
            # The image component cannot display raw bytes, so decode to PIL.
            image = Image.open(io.BytesIO(resp.content))
            image.load()
        except Exception as e:
            print(f"Could not load APOD image: {e}")
            image = None
    return image, title, description


def _render_star_chart(lat, long, date):
    """Render the star chart and return it as a PIL image for display."""
    if isinstance(date, str):
        date = date.strip()
    buf = generate_star_chart(lat, long, date or None)
    chart = Image.open(buf)
    chart.load()
    return chart


def _describe_object(obj):
    """Look up ``obj`` once; return ``(properties_without_description, description)``."""
    info = fetch_celestial_object_info(obj)
    if not isinstance(info, dict):
        info = {"description": str(info)}
    details = {k: v for k, v in info.items() if k != "description"}
    return details, info.get("description", "No description available")


def build_ui():
    """Build the Gradio Blocks interface and wire up all event handlers.

    Date inputs are plain text boxes expecting ``YYYY-MM-DD`` (Gradio has no
    ``gr.Date`` component); leaving them blank means "today". Each button
    handler calls its backing helper exactly once per click.

    Returns:
        The ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo")) as app:
        gr.Markdown(
            """
            # 🌌 Professional AI Astronomy Explorer
            Explore the universe with the power of AI and Gemini Pro! Upload your astronomy images for classification,
            get the latest astronomy picture of the day, generate star charts based on your location,
            and access professional-grade astronomical analysis powered by Google's Gemini API.
            """
        )
        with gr.Tab("📸 Professional Image Analysis"):
            with gr.Row():
                with gr.Column(scale=1):
                    input_image = gr.Image(type="pil", label="Upload Astronomy Image")
                    with gr.Row():
                        classify_btn = gr.Button("Basic Analysis", variant="secondary", scale=1)
                        gemini_btn = gr.Button("Professional Analysis (Gemini)", variant="primary", scale=1)
                    gemini_prompt = gr.Textbox(
                        label="Customize Gemini Analysis Prompt (Optional)",
                        placeholder="Leave blank for default professional analysis",
                        lines=3,
                        visible=True
                    )
                with gr.Column(scale=1):
                    with gr.Tabs():
                        with gr.TabItem("Basic Results"):
                            prediction_output = gr.Textbox(label="Predicted Object Type")
                            confidence_output = gr.Textbox(label="Confidence")
                            top3_output = gr.JSON(label="Top 3 Predictions")
                            caption_output = gr.Textbox(label="AI-Generated Caption", lines=3)
                        with gr.TabItem("Professional Analysis"):
                            gemini_output = gr.Markdown(label="Gemini Pro Analysis")
            classify_btn.click(
                fn=_run_basic_analysis,
                inputs=input_image,
                outputs=[prediction_output, confidence_output, top3_output, caption_output]
            )
            gemini_btn.click(
                fn=analyze_with_gemini,
                inputs=[input_image, gemini_prompt],
                outputs=gemini_output
            )
        with gr.Tab("🔭 Astronomy Picture of the Day"):
            with gr.Row():
                with gr.Column(scale=1):
                    apod_date = gr.Textbox(
                        label="Select Date (or leave blank for today)",
                        placeholder="YYYY-MM-DD"
                    )
                    apod_btn = gr.Button("Get Astronomy Picture of the Day", variant="primary")
                with gr.Column(scale=2):
                    apod_image = gr.Image(label="APOD Image", interactive=False)
                    apod_title = gr.Textbox(label="Title")
                    apod_desc = gr.Textbox(label="Description", lines=5)
            apod_btn.click(
                fn=_load_apod,
                inputs=apod_date,
                outputs=[apod_image, apod_title, apod_desc]
            )
        with gr.Tab("🌠 Star Chart Generator"):
            with gr.Row():
                with gr.Column(scale=1):
                    latitude = gr.Slider(minimum=-90, maximum=90, value=40, step=0.1, label="Latitude")
                    longitude = gr.Slider(minimum=-180, maximum=180, value=-75, step=0.1, label="Longitude")
                    chart_date = gr.Textbox(
                        label="Date (leave blank for today)",
                        placeholder="YYYY-MM-DD"
                    )
                    chart_btn = gr.Button("Generate Star Chart", variant="primary")
                with gr.Column(scale=2):
                    star_chart = gr.Image(label="Generated Star Chart", interactive=False)
            chart_btn.click(
                fn=_render_star_chart,
                inputs=[latitude, longitude, chart_date],
                outputs=star_chart
            )
        with gr.Tab("☀️ Space Weather"):
            with gr.Row():
                with gr.Column(scale=1):
                    weather_date = gr.Textbox(
                        label="Start Date (leave blank for today)",
                        placeholder="YYYY-MM-DD"
                    )
                    weather_btn = gr.Button("Predict Space Weather", variant="primary")
                with gr.Column(scale=2):
                    weather_output = gr.Dataframe(label="7-Day Space Weather Forecast")
            weather_btn.click(
                fn=predict_space_weather,
                inputs=weather_date,
                outputs=weather_output
            )
        with gr.Tab("🪐 Professional Astronomy Knowledge Base"):
            with gr.Tabs():
                with gr.TabItem("Celestial Object Database"):
                    with gr.Row():
                        with gr.Column(scale=1):
                            object_selector = gr.Dropdown(
                                choices=CELESTIAL_BODIES + CELESTIAL_OBJECTS,
                                label="Select Celestial Object"
                            )
                            object_btn = gr.Button("Get Information", variant="primary")
                        with gr.Column(scale=2):
                            object_info = gr.JSON(label="Object Information")
                            object_desc = gr.Textbox(label="Description", lines=4)
                    object_btn.click(
                        fn=_describe_object,
                        inputs=object_selector,
                        outputs=[object_info, object_desc]
                    )
                with gr.TabItem("Ask a Professional Astronomer"):
                    with gr.Row():
                        with gr.Column(scale=1):
                            astro_query = gr.Textbox(
                                label="Your Astronomy Question",
                                placeholder="Ask about celestial objects, phenomena, theories, or observational techniques...",
                                lines=3
                            )
                            astro_context = gr.Textbox(
                                label="Additional Context (Optional)",
                                placeholder="Add any relevant context or background to your question",
                                lines=2
                            )
                            ask_btn = gr.Button("Get Professional Insights", variant="primary")
                        with gr.Column(scale=1):
                            pro_insights = gr.Markdown(label="Professional Insights")
                    ask_btn.click(
                        fn=get_professional_insights,
                        inputs=[astro_query, astro_context],
                        outputs=pro_insights
                    )
        gr.Markdown(
            """
            ### About This Professional Astronomy App
            This AI Astronomy Explorer combines advanced machine learning models with Google's Gemini AI to provide professional-grade astronomical analysis:
            - **Professional Image Analysis**:
              - Basic classification with standard ML models
              - Advanced analysis with Gemini Pro Vision providing expert-level insights
              - Customizable analysis prompts for specific research questions
            - **Research-Grade Tools**:
              - NASA APOD integration for daily astronomical phenomena
              - Interactive star chart generation with astronomical calculations
              - Space weather forecasting for observational planning
            - **Professional Knowledge Base**:
              - Comprehensive celestial object database enhanced by Gemini Pro
              - "Ask a Professional Astronomer" feature for research questions
              - Scientifically accurate information suitable for educational and research purposes

            Developed with ❤️ for astronomy professionals, researchers, educators, and enthusiasts.

            *Note: The full functionality of this app requires a valid Google Gemini API key to be configured in the Space settings.*
            """
        )
    return app
# Create the app at import time so that `app` exists as a module-level name.
app = build_ui()
# For Hugging Face Spaces deployment
# The server is only started when this file is executed as a script.
if __name__ == "__main__":
    app.launch()