Reqxtract-v2 / api /requirements.py
Lucas ARRIESSE
WIP solution drafting
e97be0e
import logging
from fastapi import APIRouter, Depends, HTTPException
from jinja2 import Environment
from litellm.router import Router
from dependencies import get_llm_router, get_prompt_templates
from schemas import _ReqGroupingCategory, _ReqGroupingOutput, ReqGroupingCategory, ReqGroupingRequest, ReqGroupingResponse, ReqSearchLLMResponse, ReqSearchRequest, ReqSearchResponse
# FastAPI router exposing the requirement-processing endpoints, all grouped
# under a single OpenAPI tag.
router = APIRouter(
    tags=["requirement processing"],
)
@router.post("/get_reqs_from_query", response_model=ReqSearchResponse)
def find_requirements_from_problem_description(req: ReqSearchRequest, llm_router: Router = Depends(get_llm_router)):
    """Find the requirements that address a given problem description from an extracted list.

    Every requirement is listed in the prompt with its ``req_id`` as "Selection ID".
    The LLM answers with the selected IDs, which are then used as positional
    indices into ``req.requirements``. An empty selection is a valid answer
    and yields an empty response.

    Raises:
        HTTPException: (500) if the LLM returns an ID that is not a valid index
            into ``req.requirements``.
    """
    requirements = req.requirements
    query = req.query

    requirements_text = "\n".join(
        f"[Selection ID: {r.req_id} | Document: {r.document} | Context: {r.context} | Requirement: {r.requirement}]"
        for r in requirements)

    resp_ai = llm_router.completion(
        model="gemini-v2",
        messages=[{"role": "user", "content": f"Given all the requirements : \n {requirements_text} \n and the problem description \"{query}\", return a list of 'Selection ID' for the most relevant corresponding requirements that reference or best cover the problem. If none of the requirements covers the problem, simply return an empty list"}],
        response_format=ReqSearchLLMResponse
    )

    out_llm = ReqSearchLLMResponse.model_validate_json(
        resp_ai.choices[0].message.content).selected

    logging.info(f"Found {len(out_llm)} reqs matching case.")

    # The prompt explicitly allows an empty answer, so do not use max() here
    # (it raises ValueError on an empty list). Negative IDs must also be
    # rejected, otherwise Python indexing would silently pick from the end.
    # NOTE(review): this assumes req_id == position in req.requirements — confirm.
    n_requirements = len(requirements)
    if any(not 0 <= i < n_requirements for i in out_llm):
        raise HTTPException(
            status_code=500, detail="LLM error : Generated a wrong index, please try again.")

    return ReqSearchResponse(requirements=[requirements[i] for i in out_llm])
@router.post("/categorize_requirements")
async def categorize_reqs(params: ReqGroupingRequest, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> ReqGroupingResponse:
    """Categorize the given service requirements into categories.

    The requirements are rendered into the ``classify.txt`` prompt template and
    sent to the LLM. The LLM references each requirement by its index in
    ``params.requirements``. If some requirements are not assigned to any
    category, the model's answer is fed back together with a correction
    request, for up to ``MAX_ATTEMPTS`` attempts.

    Raises:
        HTTPException: (500) if requirements are still unassigned after
            ``MAX_ATTEMPTS`` attempts.
    """
    MAX_ATTEMPTS = 5

    n_requirements = len(params.requirements)
    # requirements are identified by their index in params.requirements
    valid_ids_universe = set(range(n_requirements))

    # categorize the requirements using their indices
    req_prompt = await prompt_env.get_template("classify.txt").render_async(**{
        "requirements": [rq.model_dump() for rq in params.requirements],
        "max_n_categories": params.max_n_categories,
        "response_schema": _ReqGroupingOutput.model_json_schema()})

    # initial user prompt containing the requirements
    messages = [{"role": "user", "content": req_prompt}]

    # ensure all requirements items are processed
    for _attempt in range(MAX_ATTEMPTS):
        req_completion = await llm_router.acompletion(model="gemini-v2", messages=messages, response_format=_ReqGroupingOutput)
        output = _ReqGroupingOutput.model_validate_json(
            req_completion.choices[0].message.content)

        # check that every valid ID appears in at least one category; IDs outside
        # the universe (likely hallucinated) do not count as assignments
        assigned_ids = {
            req_id for cat in output.categories for req_id in cat.items}
        unassigned_ids = valid_ids_universe - assigned_ids
        if not unassigned_ids:
            break

        # give the model its own answer back and ask it to complete it
        messages.append(req_completion.choices[0].message)
        messages.append(
            {"role": "user", "content": f"You haven't categorized the following requirements in at least one category {unassigned_ids}. Please do so."})
    else:
        # every attempt left some requirements unassigned
        raise HTTPException(
            status_code=500, detail="Failed to classify all requirements")

    # build the final category objects, dropping invalid (likely hallucinated)
    # requirement IDs; negative IDs are excluded too, as Python indexing would
    # otherwise silently map them to requirements at the end of the list
    final_categories = [
        ReqGroupingCategory(
            id=idx,
            title=cat.title,
            requirements=[params.requirements[i]
                          for i in cat.items if 0 <= i < n_requirements],
        )
        for idx, cat in enumerate(output.categories)
    ]
    return ReqGroupingResponse(categories=final_categories)