#!/usr/bin/env python3
import io
import json
import logging
import os
import random
import re
import traceback
from pathlib import Path
from typing import Dict, Any, List, Set
from urllib.parse import urlparse

import faiss
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import requests
import torch
from git import Repo
from huggingface_hub import snapshot_download
from openai import AzureOpenAI
from sentence_transformers import SentenceTransformer, util
from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn.manifold import TSNE
from transformers import AutoTokenizer, AutoModelForCausalLM


def load_env():
    from dotenv import load_dotenv
    env_path = Path(__file__).parent.parent / '.env'
    load_dotenv(dotenv_path=env_path)


load_env()

# Centralized env parameters
HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
GITHUB_TOKEN = os.getenv("GITHUB_TOKEN")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
MODEL_NAME = "gpt-4o-mini"
DEPLOYMENT = "gpt-4o-mini"
API_VERSION = "2024-12-01-preview"

FILE_REGEX = re.compile(r"^diff --git a/(.+?) b/(.+)")
LINE_HUNK = re.compile(r"@@ -(?P<old_start>\d+),(?P<old_len>\d+) \+(?P<new_start>\d+),(?P<new_len>\d+) @@")
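# Illustrative (not part of the pipeline): the two regexes above target standard
# unified-diff markers as emitted by `git diff` / the GitHub diff API, e.g.
#   diff --git a/src/foo.py b/src/foo.py   -> FILE_REGEX groups ("src/foo.py", "src/foo.py")
#   @@ -10,7 +12,9 @@                       -> LINE_HUNK groups old_start=10, old_len=7,
#                                              new_start=12, new_len=9
# Note: headers of the form "@@ -10 +12 @@" (no explicit lengths) would not match LINE_HUNK.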

# Configure logging to capture all output
log_stream = io.StringIO()
log_handler = logging.StreamHandler(log_stream)
log_handler.setLevel(logging.INFO)
log_formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
log_handler.setFormatter(log_formatter)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    handlers=[log_handler, logging.StreamHandler()]
)
logger = logging.getLogger(__name__)


class InferenceContext:
    """Holds the repository identity and local artifact paths for a single GitHub repo."""

    def __init__(self, repo_url: str):
        self.repo_url = repo_url
        owner, name = self._parse_owner_repo(repo_url)
        self.repo_id = f"{owner}/{name}"
        self.repo_dir = f"{owner}-{name}"
        self.hf_repo_id = "kotlarmilos/repository-learning"

        # Local paths for downloaded models
        self.base = Path("artifacts") / self.repo_dir
        self.model_dirs = {
            'fine_tune': self.base / 'fine_tune',
            'contrastive': self.base / 'contrastive',
            'index': self.base / 'index'
        }
        self.code_dir = self.base / 'code'

        # Create directories
        for d in (*self.model_dirs.values(), self.code_dir):
            d.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _parse_owner_repo(url: str) -> tuple[str, str]:
        parts = urlparse(url).path.strip("/").split("/")
        if len(parts) < 2:
            raise ValueError(f"Invalid GitHub URL: {url}")
        return parts[-2], parts[-1]
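
# Illustrative usage (values follow from the default repository wired up below):
#   ctx = InferenceContext("https://github.com/dotnet/xharness")
#   ctx.repo_id  -> "dotnet/xharness"
#   ctx.repo_dir -> "dotnet-xharness"
#   ctx.base     -> Path("artifacts/dotnet-xharness"), with fine_tune/, contrastive/,
#                   index/ and code/ subdirectories created on construction.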


class InferencePipeline:
    def __init__(self, ctx: InferenceContext):
        self.ctx = ctx
        self.tokenizer = None
        self.local_llm = None
        self.enterprise_llm = None
        self.embedder = None
        self.faiss_index = None
        self.faiss_metadata = None
        self.download_artifacts()
        self.load_models()

    def download_artifacts(self):
        """Download models and index from Hugging Face if they don't exist locally."""
        self.repo_files = self._clone_or_pull()
        snapshot_download(
            repo_id=self.ctx.hf_repo_id,
            allow_patterns=f"{self.ctx.repo_dir}/**",
            local_dir=str(self.ctx.base.parent),
            local_dir_use_symlinks=False,
            token=HUGGINGFACE_HUB_TOKEN
        )
        logger.info("Artifact download complete.")

    def _clone_or_pull(self) -> List[str]:
        """Clone the target repository (or pull if it already exists) and list its files."""
        dest = self.ctx.code_dir
        git_dir = dest / ".git"
        if git_dir.exists():
            Repo(dest).remotes.origin.pull()
            logger.info("Pulled latest code into %s", dest)
        else:
            Repo.clone_from(self.ctx.repo_url, dest)
            logger.info("Cloned repo %s into %s", self.ctx.repo_url, dest)
        return [str(f.relative_to(dest)) for f in dest.rglob("*") if f.is_file()]

    def load_models(self):
        """Load the fine-tuned local LLM, the Azure OpenAI client, the contrastive embedder, and the FAISS index."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.ctx.model_dirs['fine_tune'])
        self.local_llm = AutoModelForCausalLM.from_pretrained(
            self.ctx.model_dirs['fine_tune'],
            device_map="auto",
            torch_dtype=torch.bfloat16
        )
        self.enterprise_llm = AzureOpenAI(
            api_version=API_VERSION,
            azure_endpoint=AZURE_OPENAI_ENDPOINT,
            api_key=AZURE_OPENAI_API_KEY,
        )
        self.embedder = SentenceTransformer(str(self.ctx.model_dirs['contrastive']))
        self.faiss_index = faiss.read_index(str(self.ctx.model_dirs['index'] / "index.faiss"))
        self.faiss_metadata = json.loads((self.ctx.model_dirs['index'] / "metadata.json").read_text())
        logger.info("FAISS index loaded successfully")

    def _extract_pr_data(self, pr_url: str) -> dict:
        """Collect PR title, body, review comments, and diff via the GitHub API."""
        match = re.search(r'/pull/(\d+)', pr_url)
        if not match:
            raise ValueError(f"Could not extract a PR number from URL: {pr_url}")
        pr_number = int(match.group(1))

        api_url = f"https://api.github.com/repos/{self.ctx.repo_id}/pulls/{pr_number}"
        comments_url = f"https://api.github.com/repos/{self.ctx.repo_id}/pulls/{pr_number}/comments"
        headers = {
            "Authorization": f"token {GITHUB_TOKEN}",
            "Accept": "application/vnd.github.v3+json"
        }
        try:
            logger.info(f"Fetching PR #{pr_number} details...")
            pr_response = requests.get(api_url, headers=headers)
            pr_response.raise_for_status()
            pr_data = pr_response.json()

            logger.info(f"Fetching PR #{pr_number} review comments...")
            comments_response = requests.get(comments_url, headers=headers)
            comments_response.raise_for_status()
            comments_data = comments_response.json()

            # Group review comments by the diff hunk they refer to
            grouped = {}
            for comment in comments_data:
                hunk = comment.get("diff_hunk", "")
                grouped.setdefault(hunk, []).append(comment.get("body", ""))
            review_comments = [
                {"diff_hunk": hunk, "comments": comments}
                for hunk, comments in grouped.items()
            ]

            logger.info(f"Fetching PR #{pr_number} diff...")
            diff_headers = headers.copy()
            diff_headers["Accept"] = "application/vnd.github.v3.diff"
            diff_response = requests.get(api_url, headers=diff_headers)
            diff_response.raise_for_status()
            parsed_diff = self.parse_diff_with_lines(diff_response.text)

            result = {
                "title": pr_data.get("title", ""),
                "body": pr_data.get("body", ""),
                "review_comments": review_comments,
                "diff": diff_response.text,
                "changed_files": list(parsed_diff['changed_files']),
                "diff_hunks": parsed_diff['diff_hunks']
            }
            logger.info(f"Successfully collected PR #{pr_number} data")
            return result
        except Exception as e:
            logger.error(f"Error processing PR #{pr_number} data: {e}")
            raise
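
    # Illustrative shape of the dict returned above (file names are placeholders):
    #   {
    #     "title": "...", "body": "...",
    #     "review_comments": [{"diff_hunk": "...", "comments": ["...", ...]}, ...],
    #     "diff": "<unified diff text>",
    #     "changed_files": ["src/foo.py", ...],
    #     "diff_hunks": {"src/foo.py": [{"line_range": (12, 20), "content": "..."}, ...]}
    #   }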

    def parse_diff_with_lines(self, diff_text: str) -> Dict[str, Any]:
        """Parse a unified diff into changed files and per-file hunks with new-side line ranges."""
        lines = diff_text.splitlines()
        result = {
            'changed_files': set(),
            'diff_hunks': {}
        }
        current_file = None
        current_hunk_content = []
        current_line_range = None
        file_header_lines = []

        for line in lines:
            # Check if this is a new file header
            file_match = FILE_REGEX.match(line)
            if file_match:
                # Save the previous hunk if one exists
                if current_file and current_hunk_content and current_line_range:
                    if current_file not in result['diff_hunks']:
                        result['diff_hunks'][current_file] = []
                    result['diff_hunks'][current_file].append({
                        'line_range': current_line_range,
                        'content': '\n'.join(current_hunk_content)
                    })
                # Start new file
                current_file = file_match.group(2)  # Use the 'b/' file path (new file)
                result['changed_files'].add(current_file)
                file_header_lines = [line]
                current_hunk_content = []
                current_line_range = None
            elif current_file:  # Only process if we're inside a file
                # Check for hunk headers to extract line ranges
                hunk_match = LINE_HUNK.match(line)
                if hunk_match:
                    # Save the previous hunk if one exists
                    if current_hunk_content and current_line_range:
                        if current_file not in result['diff_hunks']:
                            result['diff_hunks'][current_file] = []
                        result['diff_hunks'][current_file].append({
                            'line_range': current_line_range,
                            'content': '\n'.join(current_hunk_content)
                        })
                    # Start new hunk
                    old_start = int(hunk_match.group('old_start'))
                    old_len = int(hunk_match.group('old_len'))
                    new_start = int(hunk_match.group('new_start'))
                    new_len = int(hunk_match.group('new_len'))
                    # Calculate the range of changed lines on the new side of the diff
                    if new_len > 0:
                        line_start = new_start
                        line_end = new_start + new_len - 1
                        current_line_range = (line_start, line_end)
                    else:
                        current_line_range = (new_start, new_start)
                    # Start fresh hunk content with the file header and the current hunk header
                    current_hunk_content = file_header_lines + [line]
                else:
                    # Add a content line to the current hunk
                    if current_hunk_content is not None:
                        current_hunk_content.append(line)

        # Save the last hunk
        if current_file and current_hunk_content and current_line_range:
            if current_file not in result['diff_hunks']:
                result['diff_hunks'][current_file] = []
            result['diff_hunks'][current_file].append({
                'line_range': current_line_range,
                'content': '\n'.join(current_hunk_content)
            })
        return result
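
    # Worked example (illustrative): a hunk header "@@ -10,7 +12,9 @@" yields
    # new_start=12 and new_len=9, so the recorded line_range is (12, 12 + 9 - 1) = (12, 20),
    # i.e. the span the hunk occupies on the new side of the file.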

    def analyze_file_similarity(self, changed_files: List[str]) -> Dict[str, Any]:
        """Group changed files by embedding similarity and flag files that look semantically out of place."""
        result = {
            'similar_file_groups': [],
            'anomalous_files': [],
            'analysis_summary': {
                'total_files': len(changed_files),
                'num_groups': 0,
                'num_anomalies': 0,
                'avg_group_size': 0
            }
        }

        # Handle edge cases
        if len(changed_files) == 0:
            logger.info("No changed files to analyze")
            return result
        if len(changed_files) == 1:
            logger.info(f"Only one file changed: {changed_files[0]} - no similarity analysis needed")
            result['analysis_summary']['num_anomalies'] = 1
            result['anomalous_files'].append({
                'file': changed_files[0],
                'reason': 'single_file',
                'max_similarity_to_others': 0.0,
                'most_similar_file': None,
                'is_anomaly': False
            })
            return result

        # Encode all changed files
        file_embeddings = self.embedder.encode(changed_files, convert_to_tensor=True)
        similarity_matrix = util.pytorch_cos_sim(file_embeddings, file_embeddings)

        # Convert the similarity matrix to a distance matrix for clustering
        distance_matrix = 1 - similarity_matrix.cpu().numpy()

        # Perform hierarchical clustering
        clustering = AgglomerativeClustering(
            n_clusters=None,
            distance_threshold=0.3,  # 1 - 0.7 = 0.3 (similarity threshold of 0.7)
            metric='precomputed',
            linkage='average'
        )
        cluster_labels = clustering.fit_predict(distance_matrix)

        # Group files by cluster
        clusters = {}
        for i, label in enumerate(cluster_labels):
            if label not in clusters:
                clusters[label] = []
            clusters[label].append((changed_files[i], i))  # Store the file and its index

        # Process clusters to identify groups and anomalies
        for cluster_id, files_with_indices in clusters.items():
            files_in_cluster = [f[0] for f in files_with_indices]
            if len(files_in_cluster) > 1:
                # This is a group of similar files
                group_similarities = []
                pairwise_similarities = []
                for i in range(len(files_with_indices)):
                    for j in range(i + 1, len(files_with_indices)):
                        file_i_idx = files_with_indices[i][1]
                        file_j_idx = files_with_indices[j][1]
                        similarity = float(similarity_matrix[file_i_idx][file_j_idx])
                        group_similarities.append(similarity)
                        pairwise_similarities.append({
                            'file1': files_with_indices[i][0],
                            'file2': files_with_indices[j][0],
                            'similarity': similarity
                        })
                avg_similarity = sum(group_similarities) / len(group_similarities) if group_similarities else 0
                min_similarity = min(group_similarities) if group_similarities else 0
                max_similarity = max(group_similarities) if group_similarities else 0
                result['similar_file_groups'].append({
                    'cluster_id': cluster_id,
                    'files': files_in_cluster,
                    'avg_similarity': avg_similarity,
                    'min_similarity': min_similarity,
                    'max_similarity': max_similarity,
                    'pairwise_similarities': pairwise_similarities,
                    'coherence': 'high' if min_similarity > 0.6 else 'medium' if min_similarity > 0.4 else 'low'
                })
            else:
                # This is a singleton cluster - potentially anomalous
                file = files_in_cluster[0]
                file_idx = files_with_indices[0][1]
                # Calculate the maximum similarity to any other file
                max_similarity = 0
                most_similar_file = None
                similarities_to_others = []
                for other_idx, other_file in enumerate(changed_files):
                    if other_idx != file_idx:
                        similarity = float(similarity_matrix[file_idx][other_idx])
                        similarities_to_others.append({
                            'file': other_file,
                            'similarity': similarity
                        })
                        if similarity > max_similarity:
                            max_similarity = similarity
                            most_similar_file = other_file
                result['anomalous_files'].append({
                    'file': file,
                    'cluster_id': cluster_id,
                    'max_similarity_to_others': max_similarity,
                    'most_similar_file': most_similar_file,
                    'similarities_to_others': similarities_to_others,
                    'is_anomaly': max_similarity < 0.5,  # Strong anomaly threshold
                    'anomaly_strength': 'strong' if max_similarity < 0.3 else 'medium' if max_similarity < 0.5 else 'weak',
                    'reason': 'isolated_cluster'
                })

        # Additional anomaly detection: files that are far from the group average
        if len(changed_files) >= 3:
            # Calculate the average embedding of all changed files
            avg_embedding = torch.mean(file_embeddings, dim=0)
            # Find files that are far from the average
            for i, file in enumerate(changed_files):
                file_embedding = file_embeddings[i]
                similarity_to_avg = float(util.pytorch_cos_sim(file_embedding.unsqueeze(0), avg_embedding.unsqueeze(0))[0][0])
                # Check if this file is already in anomalous_files
                existing_anomaly = next((a for a in result['anomalous_files'] if a['file'] == file), None)
                if existing_anomaly:
                    # Update the existing anomaly record
                    existing_anomaly['similarity_to_group_avg'] = similarity_to_avg
                    existing_anomaly['is_strong_anomaly'] = (
                        similarity_to_avg < 0.4 and existing_anomaly['max_similarity_to_others'] < 0.5
                    )
                    if existing_anomaly['is_strong_anomaly']:
                        existing_anomaly['anomaly_strength'] = 'very_strong'
                elif similarity_to_avg < 0.4:  # Low similarity to the group average
                    # Calculate similarities to all other files
                    similarities_to_others = []
                    max_sim = 0
                    most_sim_file = None
                    for j, other_file in enumerate(changed_files):
                        if i != j:
                            sim = float(similarity_matrix[i][j])
                            similarities_to_others.append({
                                'file': other_file,
                                'similarity': sim
                            })
                            if sim > max_sim:
                                max_sim = sim
                                most_sim_file = other_file
                    result['anomalous_files'].append({
                        'file': file,
                        'cluster_id': None,
                        'max_similarity_to_others': max_sim,
                        'most_similar_file': most_sim_file,
                        'similarities_to_others': similarities_to_others,
                        'similarity_to_group_avg': similarity_to_avg,
                        'is_anomaly': True,
                        'is_strong_anomaly': max_sim < 0.5,
                        'anomaly_strength': 'very_strong' if max_sim < 0.3 else 'strong' if max_sim < 0.5 else 'medium',
                        'reason': 'distant_from_group_average'
                    })

        # Update the analysis summary
        result['analysis_summary']['num_groups'] = len(result['similar_file_groups'])
        result['analysis_summary']['num_anomalies'] = len(result['anomalous_files'])
        if result['similar_file_groups']:
            total_files_in_groups = sum(len(group['files']) for group in result['similar_file_groups'])
            result['analysis_summary']['avg_group_size'] = total_files_in_groups / len(result['similar_file_groups'])

        # Log results
        logger.info("File similarity analysis complete:")
        logger.info(f"  Total files: {result['analysis_summary']['total_files']}")
        logger.info(f"  Similar groups: {result['analysis_summary']['num_groups']}")
        logger.info(f"  Anomalous files: {result['analysis_summary']['num_anomalies']}")
        for i, group in enumerate(result['similar_file_groups']):
            logger.info(f"  Group {i + 1} ({group['coherence']} coherence): {group['files']} (avg: {group['avg_similarity']:.3f})")
        for anomaly in result['anomalous_files']:
            logger.info(f"  {anomaly['anomaly_strength'].upper()} ANOMALY: {anomaly['file']} (reason: {anomaly['reason']}, max_sim: {anomaly['max_similarity_to_others']:.3f})")
        return result
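
    # Note on thresholds (a sketch of the intent, not tuned values): with average linkage
    # and distance_threshold=0.3, clusters are only merged while the average pairwise
    # cosine distance between them stays below 0.3, i.e. similarity roughly above 0.7.
    # Files whose best similarity to any other changed file is below 0.5 are flagged as
    # anomalies, with the reported strength escalating below 0.3.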

    # TODO: Add local LLM reasoning
    # def generate_llm_response(self, prompt: str, max_new_tokens: int = 256) -> str:
    #     """Generate a response using the fine-tuned LLM."""
    #     if not self.tokenizer or not self.local_llm:
    #         raise ValueError("LLM not loaded. Call load_models() first.")
    #     inputs = self.tokenizer(prompt, return_tensors="pt").to(self.local_llm.device)
    #     outputs = self.local_llm.generate(
    #         **inputs,
    #         max_new_tokens=max_new_tokens,
    #         pad_token_id=self.tokenizer.eos_token_id
    #     )
    #     return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def search_code_snippets(self, diff_hunks) -> list:
        """For each diff hunk, find indexed functions whose line ranges overlap the changed lines."""
        metadata_file = self.ctx.model_dirs["index"] / "metadata.json"
        with open(metadata_file, 'r', encoding='utf-8') as f:
            metadata = json.load(f)

        result = []
        # Process each file's diff hunks
        for file_path, hunks in diff_hunks.items():
            logger.info(f"Searching functions for file: {file_path}")
            for hunk in hunks:
                line_range = hunk.get('line_range')
                if not line_range:
                    continue
                start_line, end_line = line_range
                logger.debug(f"Processing hunk at lines {start_line}-{end_line}")

                # Find functions that overlap with this line range
                overlapping_functions = []
                for func_metadata in metadata:
                    func_file = func_metadata.get('file', '')
                    func_start = func_metadata.get('start_line')
                    func_end = func_metadata.get('end_line')
                    func_name = func_metadata.get('name', 'unknown')
                    func_description = func_metadata.get('llm_description', '')

                    # Skip functions from other files
                    if func_file != file_path:
                        continue

                    # A function overlaps if it starts before the hunk ends
                    # and ends after the hunk starts
                    if func_start is not None and func_end is not None:
                        if func_start <= end_line and func_end >= start_line:
                            overlapping_functions.append({
                                'function_name': func_name,
                                'function_description': func_description,
                                'function_start_line': func_start,
                                'function_end_line': func_end,
                            })

                result.append({
                    'file_name': file_path,
                    'diff_hunk': hunk.get('content', ''),
                    'overlapping_functions': overlapping_functions
                })

        total_hunks = sum(len(hunks) for hunks in diff_hunks.values())
        total_functions = sum(len(entry['overlapping_functions']) for entry in result)
        logger.info(f"Processed {total_hunks} diff hunks across {len(diff_hunks)} files, found {total_functions} overlapping functions")
        return result
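
    # Illustrative result entry (the 'file', 'start_line', 'end_line', 'name', and
    # 'llm_description' fields are assumed to exist in metadata.json, as read above):
    #   {'file_name': 'src/foo.py',
    #    'diff_hunk': 'diff --git a/src/foo.py b/src/foo.py\n@@ -10,7 +12,9 @@\n...',
    #    'overlapping_functions': [{'function_name': 'parse', 'function_description': '...',
    #                               'function_start_line': 8, 'function_end_line': 25}]}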

    def _select_files_around_changed(self, changed_files: List[str] = None, max_files: int = 500) -> List[str]:
        """Select files to visualize, prioritizing changed files and semantically similar ones."""
        logger.info(f"Selecting {max_files} files around {len(changed_files)} changed files...")

        # Start with the changed files themselves
        selected_files = set(changed_files)

        # Find files similar to the changed files using embeddings
        try:
            # Encode changed files
            changed_embeddings = self.embedder.encode(changed_files, convert_to_tensor=False)

            # Target number of similar files to find (leave room for random files)
            target_similar = min(max_files - len(changed_files), 200)

            # Compare against a sample of repo files (for performance)
            sample_size = min(2000, len(self.repo_files))
            repo_sample = self.repo_files[:sample_size]
            # Remove already selected files from the sample
            repo_sample = [f for f in repo_sample if f not in selected_files]

            if len(repo_sample) > 0:
                # Encode sample files
                sample_embeddings = self.embedder.encode(repo_sample, convert_to_tensor=False, show_progress_bar=False)

                # For each sampled file, record its best cosine similarity to any changed file
                similarities = []
                for i, repo_file in enumerate(repo_sample):
                    max_sim = 0
                    for changed_emb in changed_embeddings:
                        sim = np.dot(changed_emb, sample_embeddings[i]) / (
                            np.linalg.norm(changed_emb) * np.linalg.norm(sample_embeddings[i])
                        )
                        max_sim = max(max_sim, sim)
                    similarities.append((repo_file, max_sim))

                # Sort by similarity and take the top ones, avoiding duplicates
                added = 0
                for file_path, sim in sorted(similarities, key=lambda x: x[1], reverse=True):
                    if file_path not in selected_files:
                        selected_files.add(file_path)
                        added += 1
                    if len(selected_files) >= max_files or added >= target_similar:
                        break
                logger.info(f"Added {added} similar files to visualization")
        except Exception as e:
            logger.warning(f"Could not compute file similarities: {e}")

        # Fill remaining slots with random files
        remaining_slots = max_files - len(selected_files)
        if remaining_slots > 0:
            remaining_files = [f for f in self.repo_files if f not in selected_files]
            random.shuffle(remaining_files)
            for file_path in remaining_files[:remaining_slots]:
                selected_files.add(file_path)

        result = list(selected_files)
        logger.info(f"Selected {len(result)} files total: {len(changed_files)} changed, {len(result) - len(changed_files)} related/random")
        return result

    def create_repo_visualization(self, changed_files: List[str] = None, max_files: int = 500):
        """Build a 3D t-SNE scatter plot of file-path embeddings, highlighting changed files."""
        # The selection budget scales with the number of changed files
        files_to_plot = self._select_files_around_changed(changed_files, max_files * len(changed_files))
        logger.info(f"Creating visualization for {len(files_to_plot)} files...")

        if len(files_to_plot) < 2:
            return self._create_dummy_plot(f"Only {len(files_to_plot)} files available")

        embeddings = self.embedder.encode(files_to_plot, convert_to_tensor=False, show_progress_bar=False)
        logger.info(f"Embeddings computed successfully: shape {getattr(embeddings, 'shape', None)}")

        n = len(files_to_plot)
        perplexity = min(30, max(1, n - 1))
        tsne = TSNE(n_components=3, perplexity=perplexity, init='random', random_state=42)
        reduced = tsne.fit_transform(embeddings)

        fig = go.Figure()
        colors = []
        sizes = []
        hover_texts = []
        for i, file_path in enumerate(files_to_plot):
            if changed_files and file_path in changed_files:
                colors.append('red')
            else:
                # Color by file type
                ext = os.path.splitext(file_path)[1].lower()
                if ext in ['.py', '.js', '.ts', '.java', '.cpp', '.c', '.cs', '.rb', '.go', '.rs']:
                    colors.append('blue')
                elif ext in ['.md', '.txt', '.rst', '.doc']:
                    colors.append('green')
                elif ext in ['.json', '.yaml', '.yml', '.xml', '.toml', '.ini']:
                    colors.append('orange')
                elif ext in ['.html', '.css', '.scss', '.sass']:
                    colors.append('purple')
                else:
                    colors.append('gray')
            sizes.append(8)
            hover_texts.append(f"{os.path.basename(file_path)}")

        fig.add_trace(go.Scatter3d(
            x=reduced[:, 0].tolist(),
            y=reduced[:, 1].tolist(),
            z=reduced[:, 2].tolist(),
            mode='markers+text',
            marker=dict(size=sizes, color=colors),
            text=[os.path.basename(f) for f in files_to_plot],
            hovertext=hover_texts,
            textposition='middle center',
            name='Repository Files'
        ))

        title_text = 'Repository File Embeddings (3D t-SNE)'
        if changed_files:
            title_text += f' - {len(changed_files)} Changed Files Highlighted in Red'
        fig.update_layout(
            title=title_text,
            scene=dict(
                xaxis_title='t-SNE 1',
                yaxis_title='t-SNE 2',
                zaxis_title='t-SNE 3',
                camera=dict(eye=dict(x=1.5, y=1.5, z=1.5))
            ),
            width=800,
            height=600,
            margin=dict(r=20, b=10, l=10, t=60)
        )
        return fig
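
    # Note: create_repo_visualization above calls _create_dummy_plot, which is not defined
    # in this file. A minimal sketch of such a helper (assumption: it only needs to return
    # an empty figure carrying the message so gr.Plot has something to render):
    def _create_dummy_plot(self, message: str):
        """Return an empty placeholder figure with an explanatory title."""
        fig = go.Figure()
        fig.update_layout(title=message, width=800, height=600)
        return fig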

    def build_structured_prompt(self, data: dict, sim_analysis: dict, code_desc: list) -> str:
        """Assemble the comprehensive review prompt from PR data, similarity analysis, and grounding snippets."""
        clusters = sim_analysis['similar_file_groups']
        anomalies = sim_analysis['anomalous_files']

        # Header
        prompt = []
        prompt.append("You are an expert reviewer. First give group summaries, then detailed line-by-line feedback.")
        prompt.append(f"Title: {data['title']}")
        prompt.append(f"Description: {data['body']}")

        # Clusters of semantically similar changed files
        for c in clusters:
            prompt.append(f"## Group {c['cluster_id']} ({len(c['files'])} files, avg_sim={c['avg_similarity']:.2f}): {', '.join(c['files'])}")
            prompt.append("Files:")
            for f in c['files']:
                prompt.append(f"- {f}")
            prompt.append("Summary: Changes in these files share a semantic pattern. Focus on shared logic.")

        # Anomalies
        if anomalies:
            prompt.append("## Isolated Files (low similarity with other changed files)")
            for a in anomalies:
                prompt.append(f"- {a['file']} (reason: {a['reason']}, strength: {a.get('anomaly_strength')})")

        # Grounding diffs per cluster/file
        prompt.append("## Diff Hunks and Context:")
        for entry in code_desc:
            prompt.append(f"File: {entry['file_name']}\n{entry['diff_hunk']}")
            if entry['overlapping_functions']:
                prompt.append("Affected functions:")
                for f in entry['overlapping_functions']:
                    prompt.append(f"- {f['function_name']}: {f['function_description']}")

        # Request
        prompt.append("Provide feedback on groups, then isolated files. After that provide line-by-line feedback in diff format.")
        return "\n".join(prompt)


def get_current_logs():
    return log_stream.getvalue()


# Pipeline
pipeline = InferencePipeline(InferenceContext("https://github.com/dotnet/xharness"))


def analyze_pr_streaming(pr_url):
    """Run the full analysis pipeline for a PR URL, yielding intermediate results to the UI."""
    log_stream.seek(0)
    log_stream.truncate()

    base_review = ""
    final_review = ""
    visualization = None

    data = pipeline._extract_pr_data(pr_url)
    yield base_review, final_review, get_current_logs(), visualization

    visualization = pipeline.create_repo_visualization(list(data["changed_files"]), max_files=20)
    yield "", "", get_current_logs(), visualization

    similarity_analysis = pipeline.analyze_file_similarity(list(data["changed_files"]))
    similar_file_groups = similarity_analysis['similar_file_groups']
    anomalous_files = similarity_analysis['anomalous_files']
    yield "", "", get_current_logs(), visualization

    code_description = pipeline.search_code_snippets(data["diff_hunks"])
    comprehensive_prompt = pipeline.build_structured_prompt(data, similarity_analysis, code_description)

    # Base prompt: title, description, and raw diff only (no similarity or grounding context)
    base_prompt = f"""You are an expert reviewer. Provide detailed line-by-line feedback.
Title: {data['title']}
Description: {data['body']}
Diff: {data['diff']}
"""

    # similar_file_groups_formatted = []
    # for i, group in enumerate(similar_file_groups):
    #     files_str = ", ".join(group['files'])
    #     similar_file_groups_formatted.append(f"group {i}: {files_str}")
    # anomalous_files_formatted = []
    # for anomaly in anomalous_files:
    #     anomalous_files_formatted.append(f"anomaly: {anomaly['file']} (reason: {anomaly['reason']}, strength: {anomaly['anomaly_strength']})")
    # grounding_formatted = ""
    # for entry in code_description:
    #     file_name = entry['file_name']
    #     overlapping_functions = entry['overlapping_functions']
    #     diff_hunk = entry['diff_hunk']
    #     if len(overlapping_functions) > 0:
    #         grounding_formatted += f"In file {file_name}, the following changes were made: {diff_hunk}\n"
    #         grounding_formatted += f"These changes affected the following functions:\n"
    #         for func in overlapping_functions:
    #             grounding_formatted += f"{func['function_name']} - {func['function_description']}\n"
    #     else:
    #         grounding_formatted += f"In file {file_name}, the following changes were made: {diff_hunk}\n"
    #     grounding_formatted += "\n"
    # # Create formatted strings for f-string
    # similar_groups_text = "\n".join(similar_file_groups_formatted)
    # anomalous_files_text = "\n".join(anomalous_files_formatted)
    # # TODO: Add local LLM reasoning
    # # TODO: Add relevant files from the directory not included
    # comprehensive_prompt = f"""{base_prompt}
    # FILES THAT ARE SEMANTICALLY CLOSE CHANGED IN THIS PR:
    # {similar_groups_text}
    # UNEXPECTED CHANGES IN FILES:
    # {anomalous_files_text}
    # GROUNDING DATA: The following provides specific information about which functions are affected by each diff hunk:
    # {grounding_formatted}
    # """

    logger.info(f"Base prompt word count: {len(base_prompt.split())}")
    logger.info(f"Base prompt: {base_prompt}")
    logger.info(f"Comprehensive prompt word count: {len(comprehensive_prompt.split())}")
    logger.info(f"Comprehensive prompt: {comprehensive_prompt}")
    logger.info("Calling Azure OpenAI...")
    yield "", "", get_current_logs(), visualization

    base_review_response = pipeline.enterprise_llm.chat.completions.create(
        model=DEPLOYMENT,
        messages=[
            {"role": "system", "content": "You are an expert code reviewer. Provide thorough, constructive feedback."},
            {"role": "user", "content": base_prompt}
        ],
        max_tokens=8192,
        temperature=0.3
    )
    base_review = base_review_response.choices[0].message.content
    logger.info("Base review completed")

    final_review_response = pipeline.enterprise_llm.chat.completions.create(
        model=DEPLOYMENT,
        messages=[
            {"role": "system", "content": "You are an expert code reviewer. Provide thorough, constructive feedback."},
            {"role": "user", "content": comprehensive_prompt}
        ],
        max_tokens=8192,
        temperature=0.3
    )
    final_review = final_review_response.choices[0].message.content
    logger.info("Final review completed")

    yield base_review, final_review, get_current_logs(), visualization


with gr.Blocks(title="PR Code Review Assistant") as demo:
    gr.Markdown("# PR Code Review Assistant")
    gr.Markdown("Enter a GitHub PR URL to get comprehensive code review analysis with interactive repository visualization.")

    with gr.Row():
        pr_url_input = gr.Textbox(
            label="GitHub PR URL",
            placeholder="https://github.com/owner/repo/pull/123",
            value="https://github.com/dotnet/xharness/pull/1416"
        )
        analyze_btn = gr.Button("Analyze PR", variant="primary")

    with gr.Row():
        with gr.Column(scale=1):
            base_review_output = gr.Textbox(
                label="Base Review",
                lines=15,
                max_lines=30,
                interactive=False
            )
        with gr.Column(scale=1):
            final_review_output = gr.Textbox(
                label="Comprehensive Review",
                lines=15,
                max_lines=30,
                interactive=False
            )

    with gr.Row():
        with gr.Column(scale=1):
            visualization_output = gr.Plot(
                label="Repository Files Visualization (3D)",
                value=None
            )
        with gr.Column(scale=1):
            logs_output = gr.Textbox(
                label="Analysis Logs",
                lines=15,
                max_lines=25,
                interactive=False,
                show_copy_button=True
            )

    analyze_btn.click(
        fn=analyze_pr_streaming,
        inputs=[pr_url_input],
        outputs=[base_review_output, final_review_output, logs_output, visualization_output],
        show_progress=True
    )

if __name__ == "__main__":
    demo.launch()