Reqxtract-v2 / api /solutions.py
Lucas ARRIESSE
wip
46800f4
import asyncio
import json
import logging
from typing import Awaitable, Callable, TypeVar

from fastapi import APIRouter, Depends
from httpx import AsyncClient
from jinja2 import Environment, TemplateNotFound
from litellm.router import Router

from dependencies import INSIGHT_FINDER_BASE_URL, get_http_client, get_llm_router, get_prompt_templates
from schemas import _RefinedSolutionModel, _BootstrappedSolutionModel, _SolutionCriticismOutput, CriticizeSolutionsRequest, CritiqueResponse, InsightFinderConstraintsList, PriorArtSearchRequest, PriorArtSearchResponse, ReqGroupingCategory, ReqGroupingRequest, ReqGroupingResponse, ReqSearchLLMResponse, ReqSearchRequest, ReqSearchResponse, SolutionCriticism, SolutionModel, SolutionBootstrapResponse, SolutionBootstrapRequest, TechnologyData
# FastAPI router grouping the endpoints that bootstrap, critique, refine and
# search prior art for solutions.
router = APIRouter(
    tags=["solution generation and critique"],
)
# ============== utilities ===========================
T = TypeVar("T")
A = TypeVar("A")


async def retry_until(
    func: Callable[[A], Awaitable[T]],
    arg: A,
    predicate: Callable[[T], bool],
    max_retries: int,
) -> T:
    """
    Await ``func(arg)`` repeatedly until ``predicate`` accepts the result.

    ``func`` is always called once. It is then re-called at most ``max_retries``
    more times, as long as ``predicate`` rejects the most recent value.

    Returns the first value accepted by ``predicate``. If no value is accepted,
    returns the value from the final call as-is (that value is not passed to
    ``predicate``).
    """
    remaining = max_retries
    value = await func(arg)
    # `remaining > 0` is checked first so the predicate is never evaluated once
    # the retry budget is exhausted.
    while remaining > 0 and not predicate(value):
        value = await func(arg)
        remaining -= 1
    return value
# =================================================== Search solutions ============================================================================
@router.post("/bootstrap_solutions")
async def bootstrap_solutions(req: SolutionBootstrapRequest, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router), http_client: AsyncClient = Depends(get_http_client)) -> SolutionBootstrapResponse:
    """
    Bootstraps a solution for each of the passed-in requirement categories using Insight Finder's API.

    Each category in ``req.categories`` goes through the same pipeline, concurrently:

    1. an LLM reformats the category's requirements into Insight Finder constraints;
    2. the constraints are POSTed to Insight Finder's ``process-constraints`` endpoint,
       which returns candidate technologies;
    3. an LLM synthesizes a solution from the category, those technologies and
       ``req.user_constraints``.

    A category whose pipeline fails is logged and omitted. The response may
    therefore hold fewer solutions than there were categories.
    """
    logger = logging.getLogger(__name__)

    async def _bootstrap_solution_inner(cat: ReqGroupingCategory) -> SolutionModel:
        """Run the format -> Insight Finder -> synthesis pipeline for a single category."""
        # process requirements into insight finder format
        fmt_prompt = await prompt_env.get_template("format_requirements.txt").render_async(**{
            "category": cat.model_dump(),
            "response_schema": InsightFinderConstraintsList.model_json_schema()
        })
        fmt_completion = await llm_router.acompletion("gemini-v2", messages=[
            {"role": "user", "content": fmt_prompt}
        ], response_format=InsightFinderConstraintsList)
        fmt_model = InsightFinderConstraintsList.model_validate_json(
            fmt_completion.choices[0].message.content)

        # translate from a structured output to a {title: description} dict for Insight Finder
        # (constraints sharing the same title collapse into a single entry)
        formatted_constraints = {'constraints': {
            cons.title: cons.description for cons in fmt_model.constraints}}

        # fetch technologies from Insight Finder; `json=` serializes the payload and
        # also sets the JSON Content-Type header
        technologies_req = await http_client.post(INSIGHT_FINDER_BASE_URL + "process-constraints", json=formatted_constraints)
        # fail this category with an explicit HTTP error instead of trying to validate an error body
        technologies_req.raise_for_status()
        technologies = TechnologyData.model_validate(technologies_req.json())

        # synthesize solution using LLM
        solution_prompt = await prompt_env.get_template("bootstrap_solution.txt").render_async(**{
            "category": cat.model_dump(),
            "technologies": technologies.model_dump()["technologies"],
            "user_constraints": req.user_constraints,
            "response_schema": _BootstrappedSolutionModel.model_json_schema()
        })
        format_solution = await llm_router.acompletion("gemini-v2", messages=[
            {"role": "user", "content": solution_prompt}
        ], response_format=_BootstrappedSolutionModel)
        format_solution_model = _BootstrappedSolutionModel.model_validate_json(
            format_solution.choices[0].message.content)

        # The LLM refers to requirements by index into cat.requirements. Skip indices that
        # do not exist: an out-of-range one would raise IndexError (dropping the whole
        # category), and a negative one would silently select from the end of the list.
        requirement_count = len(cat.requirements)
        requirements = [
            cat.requirements[i].requirement
            for i in format_solution_model.requirement_ids
            if 0 <= i < requirement_count
        ]

        return SolutionModel(
            context="",
            requirements=requirements,
            problem_description=format_solution_model.problem_description,
            solution_description=format_solution_model.solution_description,
            references=[],
            category_id=cat.id,
        )

    results = await asyncio.gather(*[_bootstrap_solution_inner(cat) for cat in req.categories], return_exceptions=True)

    final_solutions = []
    for cat, result in zip(req.categories, results):
        # With return_exceptions=True, gather can also return BaseException subclasses that
        # are not Exception (e.g. asyncio.CancelledError), so filter on BaseException.
        if isinstance(result, BaseException):
            logger.warning("Failed to bootstrap a solution for category %s", cat.id, exc_info=result)
            continue
        final_solutions.append(result)

    return SolutionBootstrapResponse(solutions=final_solutions)
@router.post("/criticize_solution", response_model=CritiqueResponse)
async def criticize_solution(params: CriticizeSolutionsRequest, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> CritiqueResponse:
    """
    Criticize the challenges, weaknesses and limitations of each provided solution.

    Each solution gets its own LLM call. The calls run concurrently, and the
    critiques come back in the same order as ``params.solutions``. If any call
    fails, its exception propagates and the whole request fails.
    """

    async def _critique_one(candidate: SolutionModel) -> SolutionCriticism:
        template = prompt_env.get_template("criticize.txt")
        # The template accepts a list of solutions; exactly one is sent per call.
        prompt_text = await template.render_async(
            solutions=[candidate.model_dump()],
            response_schema=_SolutionCriticismOutput.model_json_schema(),
        )
        completion = await llm_router.acompletion(
            model="gemini-v2",
            messages=[{"role": "user", "content": prompt_text}],
            response_format=_SolutionCriticismOutput,
        )
        raw_json = completion.choices[0].message.content
        parsed = _SolutionCriticismOutput.model_validate_json(raw_json)
        # One solution per call, so only the first criticism of the output is used.
        return SolutionCriticism(solution=candidate, criticism=parsed.criticisms[0])

    pending = [_critique_one(candidate) for candidate in params.solutions]
    critiques = await asyncio.gather(*pending)
    return CritiqueResponse(critiques=critiques)
# =================================================================== Refine solution ====================================
@router.post("/refine_solutions", response_model=SolutionBootstrapResponse)
async def refine_solutions(params: CritiqueResponse, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> SolutionBootstrapResponse:
    """
    Refine previously critiqued solutions.

    For every critique, an LLM rewrites the solution's problem and solution
    descriptions in light of the criticism. Every other field is carried over
    from a deep copy of the original solution. Refinements run concurrently and
    keep the order of ``params.critiques``. Any failure fails the whole request.
    """

    async def _refine_one(critique: SolutionCriticism) -> SolutionModel:
        template = prompt_env.get_template("refine_solution.txt")
        prompt_text = await template.render_async(
            solution=critique.solution.model_dump(),
            criticism=critique.criticism,
            response_schema=_RefinedSolutionModel.model_json_schema(),
        )
        completion = await llm_router.acompletion(
            model="gemini-v2",
            messages=[{"role": "user", "content": prompt_text}],
            response_format=_RefinedSolutionModel,
        )
        refined = _RefinedSolutionModel.model_validate_json(completion.choices[0].message.content)

        # Work on a deep copy so the incoming solution object is left untouched.
        updated = critique.solution.model_copy(deep=True)
        updated.problem_description = refined.problem_description
        updated.solution_description = refined.solution_description
        return updated

    refined_solutions = await asyncio.gather(*(_refine_one(critique) for critique in params.critiques))
    return SolutionBootstrapResponse(solutions=refined_solutions)
@router.post("/search_prior_art")
async def search_prior_art(req: PriorArtSearchRequest, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> PriorArtSearchResponse:
    """
    Perform a comprehensive prior art / FTO search for a drafted solution.

    Each topic in ``req.topics`` is searched by its own LLM call with the
    ``googleSearch`` tool enabled, with at most 4 searches in flight at once.
    The per-topic results are then consolidated into a single report by a
    final LLM call.

    Returns a PriorArtSearchResponse. Its ``content`` is the consolidated
    report, and its ``topic_contents`` holds one ``{"topic", "content"}`` dict
    per topic, in input order. Any failing search or consolidation call
    propagates and fails the request.
    """
    # Caps the number of concurrent topic searches for this request.
    sema = asyncio.Semaphore(4)

    async def __search_topic(topic: str) -> dict:
        """Search a single topic and return {"topic": ..., "content": ...}."""
        search_prompt = await prompt_env.get_template("search/search_topic.txt").render_async(**{
            "topic": topic
        })
        # `async with` releases the permit only if it was actually acquired. Calling
        # release() in a `finally` after `await acquire()` would hand back a permit
        # that was never taken if the wait is cancelled, raising the cap.
        async with sema:
            search_completion = await llm_router.acompletion(model="gemini-v2", messages=[
                {"role": "user", "content": search_prompt}
            ], temperature=0.3, tools=[{"googleSearch": {}}])
        return {"topic": topic, "content": search_completion.choices[0].message.content}

    # Dispatch the individual tasks for topic search
    search_results = await asyncio.gather(*[__search_topic(top) for top in req.topics], return_exceptions=False)

    consolidation_prompt = await prompt_env.get_template("search/build_final_report.txt").render_async(**{
        "searches": search_results
    })
    # Then consolidate everything into a single detailed report
    consolidation_completion = await llm_router.acompletion(model="gemini-v2", messages=[
        {"role": "user", "content": consolidation_prompt}
    ], temperature=0.5)
    return PriorArtSearchResponse(content=consolidation_completion.choices[0].message.content, topic_contents=search_results)