victor (HF Staff) committed
Commit f71c1c7 · 1 Parent(s): 1d7f241

fetch: add private-host allowlist + env switch; include resolved IPs in error and recheck redirects; document env vars in README

Files changed (2)
  1. README.md +8 -0
  2. app.py +64 -8
README.md CHANGED
@@ -44,6 +44,14 @@ pip install -r requirements.txt
 ```
 If unset, the app automatically prefers `/data` when available, otherwise `./data`.
 
+3. (Optional) Control private/local address policy for `fetch`:
+   - `FETCH_ALLOW_PRIVATE` — set to `1`/`true` to disable the SSRF guard entirely (not recommended except for trusted, local testing).
+   - `FETCH_PRIVATE_ALLOWLIST` — comma/space separated host patterns allowed even if they resolve to private/local IPs, e.g.:
+     ```bash
+     export FETCH_PRIVATE_ALLOWLIST="*.corp.local, my-proxy.internal"
+     ```
+   If neither is set, the fetcher refuses URLs whose host resolves to private, loopback, link-local, multicast, reserved, or unspecified addresses. It also re-checks the final redirect target.
+
 The request counters live in `<DATA_DIR>/request_counts.json`, guarded by a file lock to support concurrent MCP calls.
 
 ## Running Locally
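
The allowlist documented above is a plain glob match against the URL's hostname. A minimal standalone sketch of the intended parsing and matching semantics (the `patterns` list and `host_allowed` helper are illustrative names, not the app's actual identifiers):

```python
import fnmatch
import os
import re

# Split FETCH_PRIVATE_ALLOWLIST on commas and/or whitespace, dropping empty entries.
patterns = [
    p for p in re.split(r"[\s,]+", os.getenv("FETCH_PRIVATE_ALLOWLIST", "").strip()) if p
]

def host_allowed(host: str) -> bool:
    """True if host equals a pattern or matches it as an fnmatch-style glob."""
    return any(host == pat or fnmatch.fnmatch(host, pat) for pat in patterns)

# With FETCH_PRIVATE_ALLOWLIST="*.corp.local, my-proxy.internal":
#   host_allowed("build01.corp.local")  -> True   (glob match)
#   host_allowed("my-proxy.internal")   -> True   (exact match)
#   host_allowed("corp.local")          -> False  ("*." requires a leading label)
```

Note that `*.corp.local` does not cover the bare `corp.local`; add it as a separate pattern if both forms should bypass the guard.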
app.py CHANGED
@@ -6,6 +6,7 @@ import asyncio
 import ipaddress
 import socket
 from typing import Optional, Dict, Any, List, Tuple
+import fnmatch
 from urllib.parse import urlsplit
 from datetime import datetime, timezone
 
@@ -84,6 +85,23 @@ EXTRACT_CONCURRENCY = max(
 SEARCH_CACHE_TTL = max(0, int(os.getenv("SEARCH_CACHE_TTL", "30")))
 FETCH_CACHE_TTL = max(0, int(os.getenv("FETCH_CACHE_TTL", "300")))
 
+# Controls for private/local address handling in fetch()
+def _env_flag(name: str, default: bool = False) -> bool:
+    """Parse boolean-like env vars such as 1/true/yes/on."""
+    v = os.getenv(name)
+    if v is None:
+        return default
+    return str(v).strip().lower() in {"1", "true", "yes", "on", "y"}
+
+# When True, allow any destination (disables the SSRF guard — not recommended)
+FETCH_ALLOW_PRIVATE = _env_flag("FETCH_ALLOW_PRIVATE", False)
+
+# Optional comma/space separated host patterns to allow even if private, e.g.:
+# FETCH_PRIVATE_ALLOWLIST="*.internal.example.com, my-proxy.local"
+FETCH_PRIVATE_ALLOWLIST = [
+    p for p in re.split(r"[\s,]+", os.getenv("FETCH_PRIVATE_ALLOWLIST", "").strip()) if p
+]
+
 _search_cache: Dict[Tuple[str, str, int], Dict[str, Any]] = {}
 _fetch_cache: Dict[str, Dict[str, Any]] = {}
 _search_cache_lock: Optional[asyncio.Lock] = None
@@ -162,20 +180,39 @@ def _client_ip(request: Optional[gr.Request]) -> str:
     return "unknown"
 
 
-async def _host_is_public(host: str) -> bool:
+def _host_matches_allowlist(host: str) -> bool:
+    """Return True if host matches any pattern in FETCH_PRIVATE_ALLOWLIST."""
     if not host:
         return False
+    for pat in FETCH_PRIVATE_ALLOWLIST:
+        # Support bare host equality and fnmatch-style patterns (*.foo.bar)
+        if host == pat or fnmatch.fnmatch(host, pat):
+            return True
+    return False
+
 
+async def _resolve_addresses(host: str) -> List[str]:
     def _resolve() -> List[str]:
         try:
             return list({ai[4][0] for ai in socket.getaddrinfo(host, None)})
         except Exception:
             return []
 
-    addresses = await asyncio.to_thread(_resolve)
+    return await asyncio.to_thread(_resolve)
+
+
+async def _host_is_public(host: str) -> Tuple[bool, List[str]]:
+    """Return (is_public, resolved_addresses).
+
+    - If resolution fails, treat the host as public and let the HTTP request decide.
+    - Allowlist/env overrides are applied by the caller.
+    """
+    if not host:
+        return False, []
+
+    addresses = await _resolve_addresses(host)
     if not addresses:
-        # If resolution fails we let the HTTP fetch decide.
-        return True
+        return True, []
 
     for addr in addresses:
         ip_obj = ipaddress.ip_address(addr)
@@ -187,8 +224,8 @@ async def _host_is_public(host: str) -> bool:
             or ip_obj.is_reserved
            or ip_obj.is_unspecified
         ):
-            return False
-    return True
+            return False, addresses
+    return True, addresses
 
 
 async def _check_rate_limits(bucket: str, ip: str) -> Optional[str]:
@@ -374,9 +411,14 @@ async def fetch(
         await record_request("fetch")
         return cached
 
-    if not await _host_is_public(host):
+    is_public, addrs = await _host_is_public(host)
+    if not is_public and not (FETCH_ALLOW_PRIVATE or _host_matches_allowlist(host)):
         await record_request("fetch")
-        return {"error": "Refusing to fetch private or local addresses."}
+        detail = f" (resolved: {', '.join(addrs)})" if addrs else ""
+        return {
+            "error": "Refusing to fetch private or local addresses." + detail,
+            "host": host,
+        }
 
     fetch_sema = _get_semaphore("fetch")
     await fetch_sema.acquire()
@@ -397,6 +439,20 @@ async def fetch(
         fetch_sema.release()
 
     truncated = total > FETCH_MAX_BYTES
+    # Extra guard: if the final URL host ended up private due to a redirect and
+    # the user hasn't allowed private hosts, refuse to return body content.
+    try:
+        final_host = urlsplit(final_url_str).hostname or ""
+    except Exception:
+        final_host = ""
+    if final_host and not (FETCH_ALLOW_PRIVATE or _host_matches_allowlist(final_host)):
+        final_public, _ = await _host_is_public(final_host)
+        if not final_public:
+            await record_request("fetch")
+            return {
+                "error": "Refusing to fetch private or local addresses after redirect.",
+                "host": final_host,
+            }
     text = body.decode(encoding, errors="ignore")
 
     extract_sema = _get_semaphore("extract")
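
For reference, what the guard counts as "private or local" comes from the standard library's `ipaddress` address properties, applied to every address the hostname resolves to and re-applied to the host of the final URL after redirects. A rough standalone sketch under those assumptions (it omits the async wrapper, the env switch, the allowlist, and the structured error payload):

```python
import ipaddress
import socket

def non_public_addresses(host: str) -> list[str]:
    """Resolve host and return any addresses the guard would treat as non-public."""
    try:
        addrs = {ai[4][0] for ai in socket.getaddrinfo(host, None)}
    except Exception:
        return []  # resolution failed: let the HTTP request itself succeed or fail
    flagged = []
    for addr in addrs:
        ip = ipaddress.ip_address(addr)
        if (
            ip.is_private
            or ip.is_loopback
            or ip.is_link_local
            or ip.is_multicast
            or ip.is_reserved
            or ip.is_unspecified
        ):
            flagged.append(addr)
    return flagged

# Typical results:
#   non_public_addresses("localhost")   -> ["127.0.0.1"] (plus "::1" where IPv6 is configured)
#   non_public_addresses("example.com") -> [] for a normally routed public host
```

The same classification runs once against the request host before the connection is made and once against the host of the final URL, so an initially public host cannot redirect the fetcher onto an internal address unless the env switch or allowlist permits it.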