"""
Backend Code Generation Model Training Pipeline
===============================================

A comprehensive training pipeline for building an AI model that generates
framework-agnostic backend code with full application scaffolding.

Features:
- Data collection from multiple sources
- Multi-framework support (Express.js, FastAPI, Django, Flask, etc.)
- Full application scaffolding generation
- Model training with a transformer architecture
- Evaluation and benchmarking tools
"""

import os
import json
import logging
import asyncio
import ast
import random
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, asdict
from pathlib import Path

import aiohttp
import torch
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
    Trainer, DataCollatorForLanguageModeling
)
from datasets import Dataset as HFDataset

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


@dataclass
class CodeExample:
    """Represents a single training example."""
    description: str
    requirements: List[str]
    framework: str
    language: str
    code_files: Dict[str, str]
    project_structure: Dict[str, Any]
    metadata: Dict[str, Any]


class DataCollector:
    """Collects training data from various sources."""

    def __init__(self):
        self.github_token = os.getenv('GITHUB_TOKEN')
        self.collected_examples: List[CodeExample] = []

    async def collect_github_repositories(self, queries: List[str], max_repos: int = 100):
        """Collect backend projects from GitHub."""
        logger.info("Starting GitHub repository collection...")

        headers = {'Authorization': f'token {self.github_token}'} if self.github_token else {}

        async with aiohttp.ClientSession(headers=headers) as session:
            per_query = max(1, max_repos // max(1, len(queries)))
            for query in queries:
                await self._search_github_repos(session, query, per_query)

    async def _search_github_repos(self, session: aiohttp.ClientSession, query: str, limit: int):
        """Search GitHub for repositories matching a query."""
        url = "https://api.github.com/search/repositories"
        params = {
            'q': query,
            'sort': 'stars',
            'order': 'desc',
            'per_page': min(limit, 100)
        }

        try:
            async with session.get(url, params=params) as response:
                if response.status == 200:
                    data = await response.json()
                    for repo in data.get('items', []):
                        await self._process_repository(session, repo)
                else:
                    logger.warning(f"GitHub API request failed: {response.status}")
        except Exception as e:
            logger.error(f"Error searching GitHub: {e}")
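
    # Note: GitHub's search endpoint is aggressively rate limited (on the
    # order of 10 requests/minute unauthenticated and 30/minute with a token
    # at the time of writing). A hardened collector would back off on 403/429
    # responses instead of logging and moving on.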

    async def _process_repository(self, session: aiohttp.ClientSession, repo: Dict):
        """Process a single repository to extract code examples."""
        logger.info(f"Processing repository: {repo.get('full_name', '<unknown>')}")

        try:
            contents_url = f"https://api.github.com/repos/{repo['full_name']}/contents"
            async with session.get(contents_url) as response:
                if response.status == 200:
                    contents = await response.json()
                    await self._extract_code_example(session, repo, contents)
        except Exception as e:
            logger.error(f"Error processing repository {repo.get('full_name')}: {e}")

    async def _extract_code_example(self, session: aiohttp.ClientSession, repo: Dict, contents: List[Dict]):
        """Extract a structured code example from a repository."""
        framework = self._identify_framework(contents, repo.get('description', ''))
        language = self._identify_language(contents)

        if not framework or not language:
            return

        code_files: Dict[str, str] = {}
        for item in contents:
            if item.get('type') == 'file' and self._is_important_file(item.get('name', '')):
                try:
                    async with session.get(item['download_url']) as response:
                        if response.status == 200:
                            code_files[item['name']] = await response.text()
                except Exception:
                    continue

        if code_files:
            example = CodeExample(
                description=repo.get('description', ''),
                requirements=self._extract_requirements(code_files),
                framework=framework,
                language=language,
                code_files=code_files,
                project_structure=self._analyze_structure(contents),
                metadata={
                    'stars': repo.get('stargazers_count', 0),
                    'forks': repo.get('forks_count', 0),
                    'url': repo.get('html_url'),
                    'created_at': repo.get('created_at'),
                    'updated_at': repo.get('updated_at')
                }
            )
            self.collected_examples.append(example)

    def _identify_framework(self, contents: List[Dict], description: str) -> Optional[str]:
        """Identify the backend framework used."""
        filenames = [item.get('name', '').lower() for item in contents if item.get('type') == 'file']

        frameworks = {
            'express': ['package.json', 'app.js', 'server.js'],
            'fastapi': ['requirements.txt', 'main.py', 'app.py'],
            'django': ['manage.py', 'settings.py', 'requirements.txt'],
            'flask': ['app.py', 'requirements.txt'],
            'nestjs': ['nest-cli.json', 'package.json'],
            'koa': ['package.json'],
            'gin': ['go.mod', 'main.go'],
            'fiber': ['go.mod', 'main.go'],
        }

        # A framework matches when the first two files of its signature both
        # appear in the repository's top-level file listing.
        for framework, required_files in frameworks.items():
            if all(any(req in filename for filename in filenames) for req in required_files[:2]):
                return framework

        desc_lower = description.lower()
        for framework in frameworks.keys():
            if framework in desc_lower:
                return framework

        return None

    def _identify_language(self, contents: List[Dict]) -> Optional[str]:
        """Identify the primary programming language."""
        extensions: Dict[str, int] = {}
        for item in contents:
            if item.get('type') == 'file':
                ext = Path(item.get('name', '')).suffix.lower()
                if ext:
                    extensions[ext] = extensions.get(ext, 0) + 1

        lang_map = {
            '.js': 'javascript',
            '.ts': 'typescript',
            '.py': 'python',
            '.go': 'go',
            '.java': 'java',
            '.cs': 'csharp',
            '.rb': 'ruby',
            '.php': 'php'
        }

        if extensions:
            most_common_ext = max(extensions.items(), key=lambda x: x[1])[0]
            return lang_map.get(most_common_ext)

        return None

    def _is_important_file(self, filename: str) -> bool:
        """Check whether a file is important for training."""
        important_patterns = [
            'package.json', 'requirements.txt', 'go.mod', 'pom.xml',
            'dockerfile', 'docker-compose.yml', 'readme.md',
            'app.py', 'main.py', 'server.js', 'app.js', 'index.js',
            'settings.py', 'config.py', 'routes.py', 'models.py',
            'controller.js', 'service.js', 'middleware.js'
        ]

        filename_lower = filename.lower()
        return any(pattern in filename_lower for pattern in important_patterns)

    def _extract_requirements(self, code_files: Dict[str, str]) -> List[str]:
        """Extract functional requirements from code."""
        requirements: List[str] = []

        if 'package.json' in code_files:
            try:
                pkg_data = json.loads(code_files['package.json'])
                deps = list(pkg_data.get('dependencies', {}).keys())
                requirements.extend([f"Uses {dep}" for dep in deps[:5]])
            except Exception:
                pass

        if 'requirements.txt' in code_files:
            lines = code_files['requirements.txt'].strip().split('\n')
            deps = [line.split('==')[0].split('>=')[0].strip() for line in lines if line.strip()]
            requirements.extend([f"Uses {dep}" for dep in deps[:5]])

        for filename, content in code_files.items():
            if filename.endswith(('.js', '.py')):
                requirements.extend(self._extract_endpoints(content))

        return requirements[:10]

    def _extract_endpoints(self, code_content: str) -> List[str]:
        """Extract API endpoints from code."""
        endpoints: List[str] = []

        for line in code_content.split('\n'):
            s = line.strip()
            if any(method in s for method in ['app.get(', 'app.post(', 'app.put(', 'app.delete(']):
                endpoints.append(f"Implements {s}")
            elif any(decorator in s for decorator in ['@app.get(', '@app.post(', '@app.put(', '@app.delete(']):
                endpoints.append(f"Implements {s}")
            elif 'def ' in s and any(word in s for word in ['get', 'post', 'put', 'delete']):
                endpoints.append(f"Implements {s}")

        return endpoints[:5]

    def _analyze_structure(self, contents: List[Dict]) -> Dict[str, Any]:
        """Analyze the project structure."""
        structure: Dict[str, Any] = {
            'files': [],
            'directories': [],
            'total_files': 0,
            'has_tests': False,
            'has_docs': False
        }

        for item in contents:
            if item.get('type') == 'file':
                name = item.get('name', '')
                structure['files'].append(name)
                structure['total_files'] += 1
                if 'test' in name.lower():
                    structure['has_tests'] = True
                if name.lower() in ['readme.md', 'docs.md']:
                    structure['has_docs'] = True
            elif item.get('type') == 'dir':
                structure['directories'].append(item.get('name', ''))

        return structure
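
    # For illustration: a top-level listing with app.py, test_api.py and
    # README.md plus a routes/ directory would produce
    #
    #   {'files': ['app.py', 'test_api.py', 'README.md'],
    #    'directories': ['routes'], 'total_files': 3,
    #    'has_tests': True, 'has_docs': True}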

    def generate_synthetic_examples(self, count: int = 100):
        """Generate synthetic training examples."""
        logger.info(f"Generating {count} synthetic examples...")

        templates = [
            {
                'description': 'REST API for user management',
                'requirements': ['User registration', 'User authentication', 'Profile management'],
                'frameworks': ['express', 'fastapi', 'django']
            },
            {
                'description': 'E-commerce backend API',
                'requirements': ['Product catalog', 'Shopping cart', 'Order processing', 'Payment integration'],
                'frameworks': ['nestjs', 'fastapi', 'django']
            },
            {
                'description': 'Task management system',
                'requirements': ['Task CRUD operations', 'User assignments', 'Status tracking'],
                'frameworks': ['express', 'flask', 'gin']
            },
            {
                'description': 'Blog platform backend',
                'requirements': ['Article management', 'User comments', 'Category system'],
                'frameworks': ['express', 'django', 'fastapi']
            }
        ]

        for _ in range(count):
            template = random.choice(templates)
            framework = random.choice(template['frameworks'])

            code_files = self._generate_code_for_template(template, framework)

            # Map each framework to its implementation language.
            if framework in ['fastapi', 'django', 'flask']:
                language = 'python'
            elif framework in ['gin', 'fiber']:
                language = 'go'
            else:
                language = 'javascript'

            example = CodeExample(
                description=template['description'],
                requirements=template['requirements'],
                framework=framework,
                language=language,
                code_files=code_files,
                project_structure=self._generate_synthetic_structure(framework),
                metadata={'synthetic': True}
            )

            self.collected_examples.append(example)

    def _generate_code_for_template(self, template: Dict, framework: str) -> Dict[str, str]:
        """Generate code files for a template and framework."""
        if framework == 'express':
            return {
                'package.json': json.dumps({
                    "name": template['description'].lower().replace(' ', '-'),
                    "version": "1.0.0",
                    "dependencies": {
                        "express": "^4.18.0",
                        "mongoose": "^6.0.0",
                        "bcrypt": "^5.0.0",
                        "jsonwebtoken": "^8.5.0"
                    }
                }, indent=2),
                'app.js': '''const express = require('express');
const mongoose = require('mongoose');
const app = express();

// Middleware
app.use(express.json());

// Routes
app.get('/health', (req, res) => {
  res.json({ status: 'OK' });
});

// Start server
const PORT = process.env.PORT || 3000;
app.listen(PORT, () => {
  console.log(`Server running on port ${PORT}`);
});

module.exports = app;'''
            }
        elif framework == 'fastapi':
            return {
                'requirements.txt': '''fastapi==0.68.0
uvicorn==0.15.0
sqlalchemy==1.4.23
pydantic==1.8.2''',
                'main.py': '''from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional

app = FastAPI()

class Item(BaseModel):
    id: Optional[int] = None
    name: str
    description: str

@app.get("/")
async def root():
    return {"message": "Hello World"}

@app.get("/health")
async def health_check():
    return {"status": "OK"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)'''
            }
        else:
            return {'placeholder.txt': 'Generated code placeholder'}

    def _generate_synthetic_structure(self, framework: str) -> Dict[str, Any]:
        """Generate a project structure for a framework."""
        if framework in ['express', 'nestjs']:
            return {
                'files': ['package.json', 'app.js', 'README.md'],
                'directories': ['routes', 'controllers', 'middleware', 'models'],
                'total_files': 3,
                'has_tests': True,
                'has_docs': True
            }
        elif framework in ['fastapi', 'django', 'flask']:
            return {
                'files': ['requirements.txt', 'main.py', 'README.md'],
                'directories': ['models', 'routes', 'services'],
                'total_files': 3,
                'has_tests': True,
                'has_docs': True
            }
        else:
            return {}

    def save_dataset(self, filepath: str):
        """Save collected examples to a JSON file."""
        data = [asdict(example) for example in self.collected_examples]
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        logger.info(f"Saved {len(data)} examples to {filepath}")


class DataPreprocessor:
    """Preprocesses collected data for training."""

    def __init__(self, tokenizer_name: str = "microsoft/DialoGPT-medium"):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Some tokenizers report a sentinel "very large" model_max_length when
        # the limit is unknown; clamp to a practical ceiling of 1024 tokens.
        try:
            model_max = getattr(self.tokenizer, 'model_max_length', 1024)
            if model_max and 0 < model_max < 100000:
                self.max_length = min(1024, int(model_max))
            else:
                self.max_length = 1024
        except Exception:
            self.max_length = 1024

    def preprocess_examples(self, examples: List[CodeExample]) -> List[Dict[str, str]]:
        """Convert examples to the training format."""
        processed: List[Dict[str, str]] = []

        for example in examples:
            processed.append({
                'input': self._create_input_text(example),
                'output': self._create_output_text(example),
                'framework': example.framework,
                'language': example.language
            })

        return processed

    def _create_input_text(self, example: CodeExample) -> str:
        """Create the model input text."""
        input_parts: List[str] = [
            f"Description: {example.description}",
            f"Framework: {example.framework}",
            f"Language: {example.language}",
            "Requirements:",
        ]

        for req in example.requirements:
            input_parts.append(f"- {req}")

        input_parts.append("Generate the backend application:")

        return "\n".join(input_parts)

    def _create_output_text(self, example: CodeExample) -> str:
        """Create the model output text."""
        output_parts: List[str] = ["Project Structure:"]
        for directory in example.project_structure.get('directories', []):
            output_parts.append(f"/{directory}/")

        output_parts.append("\nGenerated Files:")

        for filename, content in example.code_files.items():
            output_parts.append(f"\n--- {filename} ---")
            output_parts.append(content)
            output_parts.append("--- End ---")

        return "\n".join(output_parts)

    def create_training_dataset(self, processed_examples: List[Dict[str, str]]) -> HFDataset:
        """Create a Hugging Face dataset for training."""

        def tokenize_function(examples: Dict[str, List[str]]):
            texts = [
                f"<|startoftext|>{inp}<|separator|>{out}<|endoftext|>"
                for inp, out in zip(examples['input'], examples['output'])
            ]
            return self.tokenizer(
                texts,
                truncation=True,
                padding=True,
                max_length=self.max_length
            )

        dataset = HFDataset.from_dict({
            'input': [ex['input'] for ex in processed_examples],
            'output': [ex['output'] for ex in processed_examples],
            'framework': [ex['framework'] for ex in processed_examples],
            'language': [ex['language'] for ex in processed_examples]
        })

        # Drop the raw string columns so the language-modeling collator only
        # sees token tensors (it cannot batch free-form strings).
        return dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
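
    # Caveat: "<|startoftext|>" and "<|separator|>" are not special tokens in
    # the DialoGPT vocabulary, so the tokenizer splits them into ordinary
    # subwords. If single-token delimiters are wanted, one sketch (an
    # assumption, not part of the original pipeline) is:
    #
    #   self.tokenizer.add_special_tokens(
    #       {'additional_special_tokens': ['<|startoftext|>', '<|separator|>']})
    #   # ...then resize the model's embeddings before training:
    #   model.resize_token_embeddings(len(self.tokenizer))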


class CodeGenerationModel:
    """Custom model for backend code generation."""

    def __init__(self, base_model: str = "microsoft/DialoGPT-medium"):
        self.base_model = base_model
        self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        self.model = AutoModelForCausalLM.from_pretrained(base_model)

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def fine_tune(self, dataset: HFDataset, output_dir: str = "./trained_model"):
        """Fine-tune the model on backend code generation."""
        logger.info("Starting model fine-tuning...")

        training_args = TrainingArguments(
            output_dir=output_dir,
            overwrite_output_dir=True,
            num_train_epochs=1,
            per_device_train_batch_size=1,
            per_device_eval_batch_size=1,
            warmup_steps=50,
            max_steps=100,
            logging_steps=10,
            save_steps=50,
            save_total_limit=2,
            prediction_loss_only=True,
            fp16=torch.cuda.is_available(),
            dataloader_pin_memory=False,
            gradient_accumulation_steps=4,
            learning_rate=5e-5,
        )

        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,  # causal LM objective, not masked LM
        )

        # Split with the Hugging Face API rather than torch's random_split so
        # the Trainer receives plain datasets instead of Subset wrappers.
        split = dataset.train_test_split(test_size=0.2)

        trainer = Trainer(
            model=self.model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=split['train'],
            eval_dataset=split['test'],
        )

        trainer.train()
        trainer.save_model()

        logger.info("Fine-tuning completed!")

    def generate_code(self, description: str, framework: str, language: str) -> str:
        """Generate backend code for the given requirements."""
        input_text = (
            f"Description: {description}\n"
            f"Framework: {framework}\n"
            f"Language: {language}\n"
            f"Generate the backend application:"
        )

        model_max_len = getattr(self.tokenizer, 'model_max_length', 1024)
        max_len = 1024 if model_max_len is None or model_max_len > 100000 else min(1024, int(model_max_len))

        inputs = self.tokenizer(input_text, return_tensors='pt', truncation=True, max_length=max_len)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=512,  # generation budget beyond the prompt
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens; slicing the decoded string
        # by len(input_text) breaks when tokenization does not round-trip.
        new_tokens = outputs[0][inputs['input_ids'].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
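
    # Usage sketch (illustrative values):
    #
    #   model = CodeGenerationModel()
    #   print(model.generate_code('REST API for notes', 'fastapi', 'python'))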


class ModelEvaluator:
    """Evaluates model performance."""

    def __init__(self):
        self.metrics: Dict[str, float] = {}

    def evaluate_code_quality(self, generated_code: str, language: str) -> Dict[str, float]:
        """Evaluate generated code quality."""
        return {
            'syntax_correctness': self._check_syntax(generated_code, language),
            'completeness': self._check_completeness(generated_code),
            'best_practices': self._check_best_practices(generated_code, language)
        }

    def _check_syntax(self, code: str, language: str) -> float:
        """Check whether generated code has valid syntax."""
        if language == 'python':
            try:
                ast.parse(code)
                return 1.0
            except SyntaxError:
                return 0.0
        elif language == 'javascript':
            # Crude heuristic: no JS parser is available here, so the presence
            # of braces earns partial credit.
            if '{' in code and '}' in code:
                return 0.8
            return 0.5

        return 0.5

    def _check_completeness(self, code: str) -> float:
        """Check whether the code appears complete."""
        completeness_indicators = [
            'import', 'require', 'function', 'def', 'class',
            'app.', 'router.', '@app.', 'app.listen', 'if __name__'
        ]

        indicators_found = sum(1 for indicator in completeness_indicators if indicator in code)
        return min(indicators_found / 3.0, 1.0)

    def _check_best_practices(self, code: str, language: str) -> float:
        """Check adherence to best practices."""
        best_practices_score = 0.0

        if 'try:' in code or 'catch' in code:
            best_practices_score += 0.2

        if any(comment in code for comment in ['#', '//', '/*']):
            best_practices_score += 0.2

        if language == 'python':
            if 'if __name__ == "__main__"' in code:
                best_practices_score += 0.2
        elif language == 'javascript':
            if 'const' in code or 'let' in code:
                best_practices_score += 0.2

        return min(best_practices_score, 1.0)

    def benchmark_model(self, model: 'CodeGenerationModel', test_cases: List[Dict]) -> Dict[str, float]:
        """Benchmark the model on test cases."""
        total_scores = {'syntax': 0.0, 'completeness': 0.0, 'best_practices': 0.0}

        for i, test_case in enumerate(test_cases):
            generated_code = model.generate_code(
                test_case['description'],
                test_case['framework'],
                test_case['language']
            )

            scores = self.evaluate_code_quality(generated_code, test_case['language'])

            total_scores['syntax'] += scores['syntax_correctness']
            total_scores['completeness'] += scores['completeness']
            total_scores['best_practices'] += scores['best_practices']

            logger.info(f"Test case {i + 1}: {scores}")

        num_cases = max(1, len(test_cases))
        return {key: value / num_cases for key, value in total_scores.items()}
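
    # Usage sketch (hypothetical test case, same shape as TrainingPipeline
    # uses below):
    #
    #   evaluator = ModelEvaluator()
    #   scores = evaluator.benchmark_model(model, [
    #       {'description': 'REST API for notes',
    #        'framework': 'flask', 'language': 'python'},
    #   ])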


class TrainingPipeline:
    """Main training pipeline orchestrator."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.data_collector = DataCollector()
        self.preprocessor = DataPreprocessor(config.get('tokenizer', 'microsoft/DialoGPT-medium'))
        self.model = CodeGenerationModel(config.get('base_model', 'microsoft/DialoGPT-medium'))
        self.evaluator = ModelEvaluator()

    async def run_full_pipeline(self):
        """Run the complete training pipeline."""
        logger.info("Starting full training pipeline...")

        logger.info("Step 1: Collecting training data...")
        if self.data_collector.github_token:
            github_queries = [
                'express api backend',
                'fastapi python backend',
                'django rest api',
                'nodejs backend server',
                'flask api backend'
            ]
            await self.data_collector.collect_github_repositories(github_queries, max_repos=50)

        self.data_collector.generate_synthetic_examples(count=200)
        self.data_collector.save_dataset('raw_dataset.json')

        logger.info("Step 2: Preprocessing data...")
        processed_examples = self.preprocessor.preprocess_examples(self.data_collector.collected_examples)
        training_dataset = self.preprocessor.create_training_dataset(processed_examples)

        logger.info("Step 3: Training model...")
        self.model.fine_tune(training_dataset, output_dir=self.config.get('output_dir', './trained_model'))

        logger.info("Step 4: Evaluating model...")
        test_cases = [
            {
                'description': 'REST API for user management with authentication',
                'framework': 'express',
                'language': 'javascript'
            },
            {
                'description': 'FastAPI backend for e-commerce platform',
                'framework': 'fastapi',
                'language': 'python'
            },
            {
                'description': 'Django REST API for blog platform',
                'framework': 'django',
                'language': 'python'
            }
        ]

        benchmark_results = self.evaluator.benchmark_model(self.model, test_cases)
        logger.info(f"Benchmark results: {benchmark_results}")

        logger.info("Training pipeline completed!")
        return benchmark_results


if __name__ == "__main__":
    config = {
        'base_model': 'microsoft/DialoGPT-medium',
        'tokenizer': 'microsoft/DialoGPT-medium',
        'output_dir': './backend_code_model',
        # Note: DataCollector reads GITHUB_TOKEN from the environment itself;
        # this entry is kept for configuration completeness.
        'github_token': os.getenv('GITHUB_TOKEN'),
    }

    pipeline = TrainingPipeline(config)
    asyncio.run(pipeline.run_full_pipeline())

    logger.info("Testing trained model...")
    generated_code = pipeline.model.generate_code(
        description="Create a REST API for managing tasks with CRUD operations",
        framework="express",
        language="javascript"
    )

    print("\nGenerated Code:")
    print("=" * 50)
    print(generated_code)