Skip to content

Commit e727363

Browse files
committed
Refactor AI analysis functionality to support multiple providers and prompt types.
1 parent 8535142 commit e727363

6 files changed

Lines changed: 197 additions & 29 deletions

File tree

config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ ai_providers:
2121
api_key_env: ''
2222
enabled: false
2323
max_tokens: 512
24-
model: microsoft/DialoGPT-medium
24+
model: deepseek-r1:1.5b
2525
name: local
2626
priority: 4
2727
rate_limit: 0.1

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "sqlmap-ai"
7-
version = "2.0.3"
7+
version = "2.0.4"
88
description = "AI-powered SQL injection testing tool with multiple AI providers"
99
readme = "README.md"
1010
license = "MIT"

sqlmap_ai/ai_analyzer.py

Lines changed: 89 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,81 @@
1-
from utils.groq_utils import get_groq_response
1+
from utils.ai_providers import ai_manager, AIProvider
22
from sqlmap_ai.ui import print_info, print_warning, print_success
33
from sqlmap_ai.parser import extract_sqlmap_info
44
import json
5-
def ai_suggest_next_steps(report, scan_history=None, extracted_data=None):
5+
import asyncio
6+
def ai_suggest_next_steps(report, scan_history=None, extracted_data=None, ai_provider=None, use_advanced=None):
67
print_info("Analyzing SQLMap results with AI...")
78
if not report:
89
return ["--technique=BT", "--level=2", "--risk=1"]
910
if report.startswith("TIMEOUT_WITH_PARTIAL_DATA:"):
1011
report = report[len("TIMEOUT_WITH_PARTIAL_DATA:"):]
1112
structured_info = extract_sqlmap_info(report)
12-
prompt = create_advanced_prompt(report, structured_info, scan_history, extracted_data)
13-
print_info("Sending detailed analysis request to Groq AI...")
14-
response = get_groq_response(prompt=prompt)
15-
if not response:
13+
14+
# Determine which prompt to use based on provider and user preference
15+
use_simple = False
16+
17+
# Default behavior: simple for Ollama, advanced for others
18+
if ai_provider == AIProvider.OLLAMA or ai_provider == "ollama":
19+
use_simple = True
20+
21+
# Override based on user preference
22+
if use_advanced is not None:
23+
use_simple = not use_advanced
24+
25+
if use_simple:
26+
prompt = create_simple_prompt(report, structured_info, scan_history, extracted_data)
27+
print_info("Using simple prompt for AI analysis")
28+
else:
29+
prompt = create_advanced_prompt(report, structured_info, scan_history, extracted_data)
30+
print_info("Using advanced prompt for AI analysis")
31+
32+
# Determine which AI provider to use
33+
provider_name = "AI"
34+
if ai_provider:
35+
provider_name = ai_provider.upper()
36+
print_info(f"Sending detailed analysis request to {provider_name}...")
37+
38+
# Use the AI provider system
39+
try:
40+
# Convert string provider to AIProvider enum if needed
41+
provider_enum = None
42+
if ai_provider:
43+
try:
44+
provider_enum = AIProvider(ai_provider)
45+
print_info(f"Using AI provider: {provider_enum}")
46+
except ValueError:
47+
print_warning(f"Invalid AI provider: {ai_provider}")
48+
return ["--technique=BEU", "--level=3"]
49+
50+
response = asyncio.run(ai_manager.get_response(prompt, provider=provider_enum))
51+
if response and response.success:
52+
response_text = response.content
53+
else:
54+
print_warning(f"AI provider {ai_provider} failed: {response.error if response else 'Unknown error'}")
55+
return ["--technique=BEU", "--level=3"]
56+
except Exception as e:
57+
print_warning(f"AI analysis failed: {e}")
58+
return ["--technique=BEU", "--level=3"]
59+
if not response_text:
1660
print_warning("AI couldn't suggest options, using fallback options")
1761
return ["--technique=BEU", "--level=3"]
1862
print_success("Received AI recommendations!")
1963
try:
2064
# Try parsing JSON responses
21-
if "```json" in response:
22-
json_start = response.find("```json") + 7
23-
json_end = response.find("```", json_start)
24-
json_str = response[json_start:json_end].strip()
65+
if "```json" in response_text:
66+
json_start = response_text.find("```json") + 7
67+
json_end = response_text.find("```", json_start)
68+
json_str = response_text[json_start:json_end].strip()
2569
recommendation = json.loads(json_str)
2670
if "sqlmap_options" in recommendation:
2771
return recommendation["sqlmap_options"]
2872
elif "options" in recommendation:
2973
return recommendation["options"]
3074
# Look for code blocks without json tag
31-
elif "```" in response:
32-
code_start = response.find("```") + 3
33-
code_end = response.find("```", code_start)
34-
code_block = response[code_start:code_end].strip()
75+
elif "```" in response_text:
76+
code_start = response_text.find("```") + 3
77+
code_end = response_text.find("```", code_start)
78+
code_block = response_text[code_start:code_end].strip()
3579
# Check if content is JSON
3680
try:
3781
recommendation = json.loads(code_block)
@@ -44,7 +88,7 @@ def ai_suggest_next_steps(report, scan_history=None, extracted_data=None):
4488

4589
# Extract options from the response text
4690
options = []
47-
for line in response.split('\n'):
91+
for line in response_text.split('\n'):
4892
line = line.strip()
4993
if line.startswith('--') or line.startswith('-p ') or line.startswith('-D ') or line.startswith('-T ') or \
5094
line.startswith('--data=') or line.startswith('--cookie=') or line.startswith('--headers=') or \
@@ -58,7 +102,7 @@ def ai_suggest_next_steps(report, scan_history=None, extracted_data=None):
58102
print_warning(f"Error parsing AI response: {str(e)}")
59103
# Fallback to simple extraction
60104
options = []
61-
for line in response.strip().split('\n'):
105+
for line in response_text.strip().split('\n'):
62106
for part in line.split():
63107
if part.startswith('--') or part.startswith('-p ') or part.startswith('-D ') or part.startswith('-T ') or \
64108
part.startswith('--data=') or part.startswith('--cookie=') or part.startswith('--headers=') or \
@@ -126,6 +170,35 @@ def ai_suggest_next_steps(report, scan_history=None, extracted_data=None):
126170
return ["--technique=BEU", "--level=3"]
127171

128172
return valid_options
173+
174+
def create_simple_prompt(report, structured_info, scan_history=None, extracted_data=None):
    """Build a compact analysis prompt for lightweight providers (e.g. Ollama).

    Only the key facts pulled from *structured_info* are embedded so that
    small local models answer quickly instead of timing out.  The
    *report*, *scan_history* and *extracted_data* arguments are accepted
    for signature parity with create_advanced_prompt but are not used
    here.

    Returns:
        str: the fully formatted prompt text.
    """
    dbms = structured_info.get("dbms", "Unknown")
    vulnerable_params = ', '.join(structured_info.get("vulnerable_parameters", [])) or "None"
    databases = ', '.join(structured_info.get("databases", [])) or "None"

    return f"""
You are a SQLMap expert. Analyze this SQL injection scan result and suggest the next steps.

DBMS: {dbms}
Vulnerable Parameters: {vulnerable_params}
Databases Found: {databases}

Based on this information, suggest the next SQLMap options to use. Focus on:
1. Extracting more database information
2. Using conservative settings to prevent timeouts
3. Using specific techniques rather than broad scanning

Return your recommendation as a simple list of options, one per line:
--level=2
--risk=1
--dbs
"""
201+
129202
def create_advanced_prompt(report, structured_info, scan_history=None, extracted_data=None):
130203
prompt = """
131204
You are a SQLMap expert. You are given a SQLMap scan report and a list of previous scan steps.

sqlmap_ai/enhanced_cli.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,28 @@ def create_parser(self) -> argparse.ArgumentParser:
9595
ai_group = parser.add_argument_group('AI Configuration')
9696
ai_group.add_argument(
9797
'--ai-provider',
98-
choices=['groq', 'openai', 'anthropic', 'local', 'auto'],
98+
choices=['groq', 'openai', 'anthropic', 'local', 'ollama', 'auto'],
9999
default='auto',
100100
help='AI provider to use (default: auto)'
101101
)
102102
ai_group.add_argument(
103103
'--ai-model',
104104
help='Specific AI model to use'
105105
)
106+
ai_group.add_argument(
107+
'--ollama-model',
108+
help='Specific Ollama model to use (overrides OLLAMA_MODEL env var)'
109+
)
110+
ai_group.add_argument(
111+
'--advanced',
112+
action='store_true',
113+
help='Use advanced AI prompts (default: simple for Ollama, advanced for other providers)'
114+
)
115+
ai_group.add_argument(
116+
'--simple',
117+
action='store_true',
118+
help='Use simple AI prompts (default: simple for Ollama, advanced for other providers)'
119+
)
106120
ai_group.add_argument(
107121
'--disable-ai',
108122
action='store_true',

sqlmap_ai/main.py

Lines changed: 76 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from sqlmap_ai.adaptive_testing import run_adaptive_test_sequence
2424
from sqlmap_ai.advanced_reporting import report_generator
2525
from sqlmap_ai.evasion_engine import evasion_engine
26-
from utils.ai_providers import ai_manager, get_available_ai_providers
26+
from utils.ai_providers import ai_manager, get_available_ai_providers, AIProvider
2727
from typing import Optional
2828
def main():
2929
"""Enhanced main function with improved CLI and security"""
@@ -160,7 +160,7 @@ def run_enhanced_adaptive_mode(runner, target_url, user_timeout, interactive_mod
160160
# Enhance with AI analysis if available
161161
if result and not args.disable_ai:
162162
try:
163-
asyncio.run(enhance_with_ai_analysis(result, target_url))
163+
asyncio.run(enhance_with_ai_analysis(result, target_url, args))
164164
except Exception as e:
165165
print_warning(f"AI analysis failed: {e}")
166166

@@ -171,17 +171,62 @@ def run_enhanced_standard_mode(runner, target_url, user_timeout, interactive_mod
171171
"""Run enhanced standard mode"""
172172
print_info("Starting enhanced standard testing...")
173173

174+
# Determine AI provider from arguments
175+
ai_provider = None
176+
if hasattr(args, 'ai_provider') and args.ai_provider and args.ai_provider != 'auto':
177+
ai_provider = args.ai_provider
178+
179+
# Enable Ollama if selected
180+
if ai_provider == 'ollama':
181+
import os
182+
os.environ['ENABLE_OLLAMA'] = 'true'
183+
# Reinitialize AI manager to include Ollama provider
184+
ai_manager.reinitialize_providers()
185+
186+
# Set Ollama model if specified
187+
if hasattr(args, 'ollama_model') and args.ollama_model:
188+
import os
189+
os.environ['OLLAMA_MODEL'] = args.ollama_model
190+
191+
# Set Ollama model if specified
192+
if hasattr(args, 'ollama_model') and args.ollama_model:
193+
import os
194+
os.environ['OLLAMA_MODEL'] = args.ollama_model
195+
# Update the Ollama provider's model if it exists
196+
if AIProvider.OLLAMA in ai_manager.providers:
197+
ai_manager.providers[AIProvider.OLLAMA].update_model(args.ollama_model)
198+
174199
# Use the existing standard mode but with enhancements
175-
return run_standard_mode(runner, target_url, user_timeout, interactive_mode)
200+
return run_standard_mode(runner, target_url, user_timeout, interactive_mode, ai_provider, args)
176201

177202

178-
async def enhance_with_ai_analysis(result, target_url):
203+
async def enhance_with_ai_analysis(result, target_url, args):
179204
"""Enhance scan results with AI analysis"""
180205
if not result:
181206
return
182207

183208
print_info("Enhancing results with AI analysis...")
184209

210+
# Determine AI provider from arguments
211+
ai_provider = None
212+
if hasattr(args, 'ai_provider') and args.ai_provider and args.ai_provider != 'auto':
213+
ai_provider = args.ai_provider
214+
215+
# Enable Ollama if selected
216+
if ai_provider == 'ollama':
217+
import os
218+
os.environ['ENABLE_OLLAMA'] = 'true'
219+
# Reinitialize AI manager to include Ollama provider
220+
ai_manager.reinitialize_providers()
221+
222+
# Set Ollama model if specified
223+
if hasattr(args, 'ollama_model') and args.ollama_model:
224+
import os
225+
os.environ['OLLAMA_MODEL'] = args.ollama_model
226+
# Update the Ollama provider's model if it exists
227+
if AIProvider.OLLAMA in ai_manager.providers:
228+
ai_manager.providers[AIProvider.OLLAMA].update_model(args.ollama_model)
229+
185230
# Get AI response for advanced analysis
186231
scan_data = result.get('scan_history', [])
187232
if scan_data:
@@ -197,7 +242,17 @@ async def enhance_with_ai_analysis(result, target_url):
197242
"""
198243

199244
try:
200-
ai_response = await ai_manager.get_response(prompt)
245+
# Convert string provider to AIProvider enum if needed
246+
provider_enum = None
247+
if ai_provider:
248+
try:
249+
from utils.ai_providers import AIProvider
250+
provider_enum = AIProvider(ai_provider)
251+
except ValueError:
252+
print_warning(f"Invalid AI provider: {ai_provider}")
253+
return
254+
255+
ai_response = await ai_manager.get_response(prompt, provider=provider_enum)
201256
if ai_response.success:
202257
result['ai_analysis'] = ai_response.content
203258
print_success("AI analysis completed")
@@ -317,7 +372,7 @@ def run_adaptive_mode(runner, target_url, user_timeout, interactive_mode):
317372
print_error("Adaptive testing failed. Check target URL and try again.")
318373
if result and "message" in result:
319374
print_info(f"Error: {result['message']}")
320-
def run_standard_mode(runner, target_url, user_timeout, interactive_mode):
375+
def run_standard_mode(runner, target_url, user_timeout, interactive_mode, ai_provider=None, args=None):
321376
print_info("Starting initial reconnaissance...")
322377
scan_history = []
323378
extracted_data = {}
@@ -346,11 +401,24 @@ def run_standard_mode(runner, target_url, user_timeout, interactive_mode):
346401
print_warning("Scan was interrupted by user. Stopping here.")
347402
return
348403
display_report(report)
349-
print_info("Analyzing results with Groq AI and determining next steps...")
404+
# Determine AI provider name for display
405+
provider_name = "AI"
406+
if ai_provider:
407+
provider_name = ai_provider.upper()
408+
print_info(f"Analyzing results with {provider_name} and determining next steps...")
409+
# Determine if user wants advanced prompts
410+
use_advanced = None
411+
if hasattr(args, 'advanced') and args.advanced:
412+
use_advanced = True
413+
elif hasattr(args, 'simple') and args.simple:
414+
use_advanced = False
415+
350416
next_options = ai_suggest_next_steps(
351417
report=report,
352418
scan_history=scan_history,
353-
extracted_data=extracted_data
419+
extracted_data=extracted_data,
420+
ai_provider=ai_provider,
421+
use_advanced=use_advanced
354422
)
355423
if next_options:
356424
user_options = get_user_choice(next_options)

utils/ai_providers.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ async def get_response(
123123
"""Get response from AI provider with fallback"""
124124

125125
providers_to_try = [provider] if provider else self.active_providers
126-
127126
for attempt_provider in providers_to_try:
128127
if attempt_provider not in self.providers:
129128
continue
@@ -149,6 +148,12 @@ async def get_response(
149148
def get_available_providers(self) -> List[AIProvider]:
150149
"""Get list of available providers"""
151150
return self.active_providers.copy()
151+
152+
def reinitialize_providers(self):
    """Rebuild the provider registry from scratch.

    Clears both the provider map and the active-provider list, then
    re-runs _setup_providers() so that providers gated on environment
    variables (e.g. ENABLE_OLLAMA, set by the CLI before calling this)
    are picked up after those variables change at runtime.
    """
    self.providers = {}
    self.active_providers = []
    self._setup_providers()
152157

153158

154159
class BaseAIProvider:
@@ -351,6 +356,10 @@ def __init__(self):
351356
self.default_model = os.getenv("OLLAMA_MODEL", "llama3.2")
352357
self.rate_limit_delay = 0.5
353358

359+
def update_model(self, model_name: str) -> None:
    """Override this provider's default Ollama model.

    Replaces the default_model chosen at construction time (the
    OLLAMA_MODEL env var, falling back to "llama3.2") with *model_name*;
    subsequent get_response() calls that pass no explicit model will use
    the new value.
    """
    self.default_model = model_name
362+
354363
async def get_response(
355364
self,
356365
prompt: str,
@@ -363,6 +372,9 @@ async def get_response(
363372
model = model or self.default_model
364373
start_time = time.time()
365374

375+
logger.info(f"Ollama provider: Using model {model}")
376+
logger.info(f"Ollama provider: Base URL {self.base_url}")
377+
366378
for attempt in range(max_retries):
367379
try:
368380
await self._rate_limit()
@@ -388,7 +400,7 @@ async def get_response(
388400
"num_predict": kwargs.get('max_tokens', 512)
389401
}
390402
},
391-
timeout=kwargs.get('timeout', 60)
403+
timeout=kwargs.get('timeout', 120) # Increased timeout for complex prompts
392404
)
393405

394406
if response.status_code != 200:
@@ -407,6 +419,7 @@ async def get_response(
407419
)
408420

409421
except Exception as e:
422+
logger.warning(f"Ollama provider attempt {attempt + 1} failed: {e}")
410423
if attempt == max_retries - 1:
411424
return AIResponse(
412425
content="",

0 commit comments

Comments
 (0)