import asyncio
import json
from typing import Awaitable, Callable, TypeVar

from fastapi import APIRouter, Depends
from httpx import AsyncClient
from jinja2 import Environment, TemplateNotFound
from litellm.router import Router

from dependencies import (INSIGHT_FINDER_BASE_URL, get_http_client,
                          get_llm_router, get_prompt_templates)
from schemas import (_BootstrappedSolutionModel, _RefinedSolutionModel,
                     _SolutionCriticismOutput, CriticizeSolutionsRequest,
                     CritiqueResponse, InsightFinderConstraintsList,
                     PriorArtSearchRequest, PriorArtSearchResponse,
                     ReqGroupingCategory, ReqGroupingRequest,
                     ReqGroupingResponse, ReqSearchLLMResponse,
                     ReqSearchRequest, ReqSearchResponse, SolutionCriticism,
                     SolutionModel, SolutionBootstrapRequest,
                     SolutionBootstrapResponse, TechnologyData)

# Router for solution generation and critique
router = APIRouter(tags=["solution generation and critique"])

# ============================== utilities ==============================

T = TypeVar("T")
A = TypeVar("A")


async def retry_until(
    func: Callable[[A], Awaitable[T]],
    arg: A,
    predicate: Callable[[T], bool],
    max_retries: int,
) -> T:
    """Retries the given async function until the validation predicate returns true,
    returning the last value if all retries are exhausted."""
    last_value = await func(arg)
    for _ in range(max_retries):
        if predicate(last_value):
            return last_value
        last_value = await func(arg)
    return last_value
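
# Illustrative sketch (an assumption, not used by any endpoint below): retry_until
# can guard a structured LLM call by retrying until the raw response parses as
# valid JSON. The helper name and the retry budget are hypothetical.
async def _complete_with_json_retry(llm_router: Router, prompt: str) -> str:
    async def _ask(p: str) -> str:
        completion = await llm_router.acompletion(
            "gemini-v2", messages=[{"role": "user", "content": p}])
        return completion.choices[0].message.content

    def _is_valid_json(raw: str) -> bool:
        try:
            json.loads(raw)
            return True
        except (json.JSONDecodeError, TypeError):
            return False

    # Retry up to 2 times after the initial attempt before giving up.
    return await retry_until(_ask, prompt, _is_valid_json, max_retries=2)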
""" async def _bootstrap_solution_inner(cat: ReqGroupingCategory): # process requirements into insight finder format fmt_completion = await llm_router.acompletion("gemini-v2", messages=[ { "role": "user", "content": await prompt_env.get_template("format_requirements.txt").render_async(**{ "category": cat.model_dump(), "response_schema": InsightFinderConstraintsList.model_json_schema() }) }], response_format=InsightFinderConstraintsList) fmt_model = InsightFinderConstraintsList.model_validate_json( fmt_completion.choices[0].message.content) # translate from a structured output to a dict for insights finder formatted_constraints = {'constraints': { cons.title: cons.description for cons in fmt_model.constraints}} # fetch technologies from insight finder technologies_req = await http_client.post(INSIGHT_FINDER_BASE_URL + "process-constraints", content=json.dumps(formatted_constraints)) technologies = TechnologyData.model_validate(technologies_req.json()) # =============================================================== synthesize solution using LLM ========================================= format_solution = await llm_router.acompletion("gemini-v2", messages=[{ "role": "user", "content": await prompt_env.get_template("bootstrap_solution.txt").render_async(**{ "category": cat.model_dump(), "technologies": technologies.model_dump()["technologies"], "user_constraints": req.user_constraints, "response_schema": _BootstrappedSolutionModel.model_json_schema() })} ], response_format=_BootstrappedSolutionModel) format_solution_model = _BootstrappedSolutionModel.model_validate_json( format_solution.choices[0].message.content) final_solution = SolutionModel( context="", requirements=[ cat.requirements[i].requirement for i in format_solution_model.requirement_ids ], problem_description=format_solution_model.problem_description, solution_description=format_solution_model.solution_description, references=[], category_id=cat.id, ) # ======================================================================================================================================== return final_solution tasks = await asyncio.gather(*[_bootstrap_solution_inner(cat) for cat in req.categories], return_exceptions=True) final_solutions = [sol for sol in tasks if not isinstance(sol, Exception)] return SolutionBootstrapResponse(solutions=final_solutions) @router.post("/criticize_solution", response_model=CritiqueResponse) async def criticize_solution(params: CriticizeSolutionsRequest, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> CritiqueResponse: """Criticize the challenges, weaknesses and limitations of the provided solutions.""" async def __criticize_single(solution: SolutionModel): req_prompt = await prompt_env.get_template("criticize.txt").render_async(**{ "solutions": [solution.model_dump()], "response_schema": _SolutionCriticismOutput.model_json_schema() }) req_completion = await llm_router.acompletion( model="gemini-v2", messages=[{"role": "user", "content": req_prompt}], response_format=_SolutionCriticismOutput ) criticism_out = _SolutionCriticismOutput.model_validate_json( req_completion.choices[0].message.content ) return SolutionCriticism(solution=solution, criticism=criticism_out.criticisms[0]) critiques = await asyncio.gather(*[__criticize_single(sol) for sol in params.solutions], return_exceptions=False) return CritiqueResponse(critiques=critiques) # =================================================================== Refine solution 

# =================================================== Refine solutions ===================================================


@router.post("/refine_solutions", response_model=SolutionBootstrapResponse)
async def refine_solutions(params: CritiqueResponse,
                           prompt_env: Environment = Depends(get_prompt_templates),
                           llm_router: Router = Depends(get_llm_router)) -> SolutionBootstrapResponse:
    """Refines the previously critiqued solutions."""

    async def __refine_solution(crit: SolutionCriticism) -> SolutionModel:
        req_prompt = await prompt_env.get_template("refine_solution.txt").render_async(**{
            "solution": crit.solution.model_dump(),
            "criticism": crit.criticism,
            "response_schema": _RefinedSolutionModel.model_json_schema(),
        })

        req_completion = await llm_router.acompletion(model="gemini-v2", messages=[
            {"role": "user", "content": req_prompt}
        ], response_format=_RefinedSolutionModel)

        req_model = _RefinedSolutionModel.model_validate_json(
            req_completion.choices[0].message.content)

        # Copy the previous solution and overwrite only the refined fields.
        refined_solution = crit.solution.model_copy(deep=True)
        refined_solution.problem_description = req_model.problem_description
        refined_solution.solution_description = req_model.solution_description

        return refined_solution

    refined_solutions = await asyncio.gather(*[__refine_solution(crit) for crit in params.critiques],
                                             return_exceptions=False)
    return SolutionBootstrapResponse(solutions=refined_solutions)


@router.post("/search_prior_art")
async def search_prior_art(req: PriorArtSearchRequest,
                           prompt_env: Environment = Depends(get_prompt_templates),
                           llm_router: Router = Depends(get_llm_router)) -> PriorArtSearchResponse:
    """Performs a comprehensive prior art / FTO search against the provided topics for a drafted solution."""

    # Bound the number of concurrent search completions.
    sema = asyncio.Semaphore(4)

    async def __search_topic(topic: str) -> dict:
        search_prompt = await prompt_env.get_template("search/search_topic.txt").render_async(**{
            "topic": topic
        })
        async with sema:
            search_completion = await llm_router.acompletion(model="gemini-v2", messages=[
                {"role": "user", "content": search_prompt}
            ], temperature=0.3, tools=[{"googleSearch": {}}])
            return {"topic": topic, "content": search_completion.choices[0].message.content}

    # Dispatch the individual tasks for topic search.
    topics = await asyncio.gather(*[__search_topic(top) for top in req.topics],
                                  return_exceptions=False)

    consolidation_prompt = await prompt_env.get_template("search/build_final_report.txt").render_async(**{
        "searches": topics
    })

    # Then consolidate everything into a single detailed report.
    consolidation_completion = await llm_router.acompletion(model="gemini-v2", messages=[
        {"role": "user", "content": consolidation_prompt}
    ], temperature=0.5)

    return PriorArtSearchResponse(content=consolidation_completion.choices[0].message.content,
                                  topic_contents=topics)
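
# Hypothetical end-to-end usage of the routes above from an async HTTP client.
# The base URL, timeout, and payload shapes are assumptions for illustration;
# the authoritative request/response schemas live in `schemas.py`.
async def _example_pipeline(base_url: str, bootstrap_payload: dict) -> dict:
    async with AsyncClient(base_url=base_url, timeout=120.0) as client:
        # 1. Bootstrap one solution per requirement category.
        boot = (await client.post("/bootstrap_solutions", json=bootstrap_payload)).json()
        # 2. Criticize the bootstrapped solutions (same "solutions" shape).
        crit = (await client.post("/criticize_solution", json=boot)).json()
        # 3. Refine the solutions using the critiques.
        refined = (await client.post("/refine_solutions", json=crit)).json()
        return refined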