-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinjecdetect.py
More file actions
115 lines (80 loc) · 2.75 KB
/
injecdetect.py
File metadata and controls
115 lines (80 loc) · 2.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import joblib
import numpy as np
import warnings
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.exceptions import InconsistentVersionWarning
# Suppress scikit-learn's InconsistentVersionWarning, which fires when the
# joblib-pickled pipeline (loaded in TfidfLRDetector below) was trained under
# a different sklearn version than the one installed here.
warnings.filterwarnings(
"ignore",
category=InconsistentVersionWarning
)
# TF-IDF + Logistic Regression Detector
class TfidfLRDetector:
    """Scores prompts with a pickled TF-IDF + logistic-regression pipeline."""

    def __init__(self, model_path):
        # joblib restores the fitted sklearn pipeline from disk.
        self.model = joblib.load(model_path)

    def score(self, prompt: str) -> float:
        """Return the model's probability that *prompt* is malicious (class 1)."""
        proba = self.model.predict_proba([prompt])
        return proba[0][1]
# DistilBERT Detector
class DistilBERTDetector:
    """Scores prompts with a fine-tuned DistilBERT sequence classifier."""

    def __init__(self, model_path):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        # Inference only: switch off dropout / other train-time layers.
        self.model.eval()

    def score(self, prompt: str) -> float:
        """Return P(malicious) for *prompt* as the softmax of the class-1 logit."""
        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128,
        )
        with torch.no_grad():
            logits = self.model(**encoded).logits
        return torch.softmax(logits, dim=1)[0][1].item()
# Decision Logic
def classify(score: float, threshold: float) -> str:
    """Map a probability to a label: "MALICIOUS" iff score >= threshold."""
    if score < threshold:
        return "BENIGN"
    return "MALICIOUS"
def ensemble_decision(tfidf_score, bert_score,
                      tfidf_thresh=0.7,
                      bert_thresh=0.6):
    """OR-ensemble: flag "MALICIOUS" when either detector crosses its threshold."""
    flagged = tfidf_score >= tfidf_thresh or bert_score >= bert_thresh
    return "MALICIOUS" if flagged else "BENIGN"
# Interactive Prompt Analyzer
def analyze_prompt(prompt, tfidf_detector, bert_detector,
                   tfidf_thresh=0.7, bert_thresh=0.6):
    """Score *prompt* with both detectors and print a per-model + ensemble report.

    The 0.7 / 0.6 cut-offs were previously hard-coded here, duplicating the
    defaults of ensemble_decision (drift risk if one copy changed). They are
    now keyword parameters, passed through so the per-model labels and the
    ensemble verdict always use the same thresholds.
    """
    tfidf_score = tfidf_detector.score(prompt)
    bert_score = bert_detector.score(prompt)
    tfidf_decision = classify(tfidf_score, tfidf_thresh)
    bert_decision = classify(bert_score, bert_thresh)
    ensemble = ensemble_decision(tfidf_score, bert_score,
                                 tfidf_thresh=tfidf_thresh,
                                 bert_thresh=bert_thresh)
    print("\n" + "=" * 70)
    print("PROMPT:")
    print(prompt)
    print("-" * 70)
    print(f"TF-IDF + LR score : {tfidf_score:.4f} → {tfidf_decision}")
    print(f"DistilBERT score : {bert_score:.4f} → {bert_decision}")
    print("-" * 70)
    print(f"FINAL (ENSEMBLE) : {ensemble}")
    print("=" * 70)
# Main Loop
if __name__ == "__main__":
    print("\n Prompt Injection Detection System")
    print("Type a prompt to analyze, or 'exit' to quit.\n")
    # Both model artifacts are loaded from paths relative to the CWD.
    tfidf_detector = TfidfLRDetector(
        "prompt_injection_classifier.joblib"
    )
    bert_detector = DistilBERTDetector(
        "distilbert_prompt_injection_detector"
    )
    while True:
        try:
            user_prompt = input(">>> ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C (or a closed stdin pipe) previously crashed
            # with a traceback; exit cleanly instead.
            print("\nExiting.")
            break
        if user_prompt.lower() in {"exit", "quit"}:
            print("Exiting.")
            break
        analyze_prompt(
            user_prompt,
            tfidf_detector,
            bert_detector
        )