# a2 / Tools / build_world.py
# ravenscoat619's picture
# story agent
# 67dff27
# Tools/build_world.py
from typing import Dict, Any
from smolagents import tool
import json
import torch
# from llm_utils import tokenizer, model, generate_completion # These are loaded globally elsewhere
@tool
def build_world(facts: Dict[str, Any]) -> Dict[str, Any]:
    """
    Given a structured `facts` dictionary, returns a world-building dictionary with keys:
    - setting_description: a vivid 2–3 sentence paragraph describing the environment.
    - flora: a list of 3–5 plant species commonly found here.
    - fauna: a list of 3–5 animals or creatures one might encounter.
    - ambiance: a list of 3–5 sensory details (sounds, smells, tactile sensations).
    Args:
        facts (Dict[str, Any]): The structured facts extracted from the scene.
    Returns:
        Dict[str, Any]: A dictionary with exactly the four keys:
            'setting_description', 'flora', 'fauna', and 'ambiance'.
            Missing or unparseable keys fall back to empty defaults.
    """
    # 1) Prepare the JSON-extraction prompt from the serialized facts.
    facts_json = json.dumps(facts, indent=2)
    prompt = f"""
You are a world-building assistant. Given these structured facts:
{facts_json}
Generate a JSON object with exactly these fields:
1) setting_description: a 2–3 sentence vivid paragraph describing the environment.
2) flora: a list of 3–5 plant species commonly found here.
3) fauna: a list of 3–5 animals or creatures one might encounter.
4) ambiance: a list of 3–5 sensory details (sounds, smells, tactile feelings).
Return ONLY valid JSON with those four keys.
"""
    # 2) Tokenize via the chat template.
    # NOTE(review): `tokenizer` and `model` are module-level globals loaded
    # elsewhere (see the commented-out llm_utils import above) — confirm they
    # are in scope wherever this tool is registered.
    inputs = tokenizer.apply_chat_template(
        [
            {"role": "system", "content": "You convert structured facts into world-building JSON."},
            {"role": "user", "content": prompt}
        ],
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True  # Keep return_dict=True to get attention_mask
    )
    # Move each tensor to the model's device before generation.
    for k, v in inputs.items():
        inputs[k] = v.to(model.device)
    # 3) Generate up to 256 new tokens; no gradients needed for inference.
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            max_new_tokens=256,
            attention_mask=inputs.get("attention_mask")  # pass mask if present
        )
    # 4) Slice off the prompt tokens so only the generated tail remains.
    prompt_len = inputs["input_ids"].shape[-1]
    gen_ids = outputs[0][prompt_len:]
    # 5) Decode and isolate the JSON object. FIX: slice to the LAST closing
    # brace as well as the first opening brace — models often append trailing
    # chatter after the JSON, which previously made json.loads fail and
    # discarded an otherwise-valid object.
    raw = tokenizer.decode(gen_ids, skip_special_tokens=True)
    start = raw.find("{")
    end = raw.rfind("}")
    if start >= 0 and end > start:
        candidate = raw[start:end + 1]
    else:
        candidate = raw
    # 6) Parse, with a defaults fallback so the tool never raises on bad output.
    defaults = {
        "setting_description": "",
        "flora": [],
        "fauna": [],
        "ambiance": []
    }
    try:
        parsed = json.loads(candidate)
    except Exception:
        parsed = None
    # FIX: the model may return valid JSON that is not an object (e.g. a list
    # or bare string); the old code called .setdefault on it and raised
    # AttributeError. Only accept a dict; otherwise use the defaults.
    world_dict = parsed if isinstance(parsed, dict) else defaults.copy()
    # 7) Ensure all four expected keys are present.
    for key, val in defaults.items():
        world_dict.setdefault(key, val)
    return world_dict