#!/usr/bin/env python3
"""
HuggingFace Spaces startup script with Ollama support.
Starts Ollama server in background, then launches Streamlit.
"""

import os
import sys
import subprocess
import time
import signal
from datetime import datetime


def log(message):
    """Log message to stderr for visibility in HF Spaces."""
    print(f"[{datetime.now().isoformat()}] {message}", file=sys.stderr, flush=True)


def start_ollama():
    """Start Ollama server in background."""
    log("πŸ¦™ Starting Ollama server...")
    
    try:
        # Start Ollama in the background; discard its output so an unread
        # PIPE buffer can never fill up and block the server
        ollama_process = subprocess.Popen(
            ["ollama", "serve"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            start_new_session=True,  # own process group, so we can kill the whole tree on shutdown
        )
        
        log(f"πŸ¦™ Ollama server started with PID {ollama_process.pid}")
        
        # Wait for Ollama to be ready
        log("⏳ Waiting for Ollama to be ready...")
        max_retries = 12  # 12 attempts x 5 s sleep = 60 seconds total
        for attempt in range(max_retries):
            try:
                # Probe the Ollama HTTP API; curl exits 0 once the server answers
                result = subprocess.run(
                    ["curl", "-s", "http://localhost:11434/api/tags"],
                    capture_output=True,
                    timeout=5,
                )
                if result.returncode == 0:
                    log("βœ… Ollama server is ready!")

                    # Pre-pull a small model suited to container deployment
                    log("πŸ“₯ Pulling llama3.2:1b model (optimized for container deployment)...")
                    pull_result = subprocess.run(
                        ["ollama", "pull", "llama3.2:1b"],
                        capture_output=True,
                        timeout=300,  # allow up to 5 minutes for the download
                    )
                    if pull_result.returncode == 0:
                        log("βœ… Model llama3.2:1b ready!")
                    else:
                        log("⚠️ Model pull failed, will download on first use")

                    return ollama_process
            except (subprocess.TimeoutExpired, OSError) as e:
                log(f"πŸ”„ Probe failed: {e}")
            # Sleep on every unsuccessful attempt, whether curl exited nonzero
            # or the probe raised, so the retry budget really spans 60 seconds
            log(f"πŸ”„ Ollama not ready yet (attempt {attempt + 1}/{max_retries})")
            time.sleep(5)
        
        log("❌ Ollama failed to start after 60 seconds")
        ollama_process.terminate()
        return None
        
    except Exception as e:
        log(f"❌ Failed to start Ollama: {e}")
        return None


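# A stdlib-only readiness probe; a minimal sketch of an alternative to the
# curl call above (assumes the default Ollama port, 11434), useful for base
# images that don't ship curl:
def ollama_ready(timeout=5):
    """Return True once the local Ollama API answers /api/tags."""
    import urllib.request
    try:
        with urllib.request.urlopen(
            "http://localhost:11434/api/tags", timeout=timeout
        ) as resp:
            return resp.status == 200
    except OSError:
        return False

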
def main():
    """Start services and Streamlit based on configuration."""
    log("πŸš€ Starting Technical RAG Assistant in HuggingFace Spaces...")

    # Check which inference method to use
    use_ollama = os.getenv("USE_OLLAMA", "false").lower() == "true"
    use_inference_providers = os.getenv("USE_INFERENCE_PROVIDERS", "false").lower() == "true"
    
    ollama_process = None
    
    # Configure environment variables based on selected inference method
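    # Precedence: Inference Providers first, then Ollama, then the classic HF API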
    if use_inference_providers:
        os.environ["USE_INFERENCE_PROVIDERS"] = "true"
        os.environ["USE_OLLAMA"] = "false"
        log("πŸš€ Using Inference Providers API")
    elif use_ollama:
        os.environ["USE_OLLAMA"] = "true"
        os.environ["USE_INFERENCE_PROVIDERS"] = "false"
        log("πŸ¦™ Ollama enabled - starting server...")
        ollama_process = start_ollama()
        
        if ollama_process is None:
            log("πŸ”„ Ollama failed to start, falling back to HuggingFace API")
            os.environ["USE_OLLAMA"] = "false"
            os.environ["USE_INFERENCE_PROVIDERS"] = "false"
    else:
        os.environ["USE_OLLAMA"] = "false"
        os.environ["USE_INFERENCE_PROVIDERS"] = "false"
        log("πŸ€— Using classic HuggingFace API")

    # Start Streamlit
    log("🎯 Starting Streamlit application...")
    
    def signal_handler(signum, frame):
        """Handle shutdown signals."""
        log("πŸ›‘ Received shutdown signal, cleaning up...")
        if ollama_process:
            log("πŸ¦™ Stopping Ollama server...")
            try:
                os.killpg(os.getpgid(ollama_process.pid), signal.SIGTERM)
            except OSError:
                pass  # process group already gone
        sys.exit(0)
    
    # Register signal handlers
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)
    
    try:
        subprocess.run(
            [
                "streamlit",
                "run",
                "streamlit_app.py",
                "--server.port=8501",
                "--server.address=0.0.0.0",
                "--server.headless=true",
                "--server.enableCORS=false",
                "--server.enableXsrfProtection=false",
            ],
            check=True,
        )
    except Exception as e:
        log(f"❌ Failed to start Streamlit: {e}")
        sys.exit(1)
    finally:
        # Clean up Ollama if it was started
        if ollama_process:
            log("πŸ¦™ Cleaning up Ollama server...")
            try:
                os.killpg(os.getpgid(ollama_process.pid), signal.SIGTERM)
            except OSError:
                pass  # process group already gone


if __name__ == "__main__":
    main()
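
# Local smoke test (assumes Ollama and Streamlit are installed; the script
# filename below is illustrative):
#   USE_OLLAMA=true python startup.py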