# prompt_content_detection.py
from enum import Enum
from typing import List, Tuple

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    RobertaForSequenceClassification,
    RobertaTokenizer,
    TextClassificationPipeline,
    pipeline,
)

from prompt_utils_for_llamaguard import build_default_prompt, create_conversation, LlamaGuardVersion


class AgentType(Enum):
    AGENT = "Agent"
    USER = "User"


# Index-to-label mapping for the two output classes of the toxicity classifier.
Label = ['non-toxic', 'toxic']

class ToxicDetector_Roberta:
    """Toxicity classifier based on SkolkovoInstitute's fine-tuned RoBERTa model."""

    def __init__(self):
        self.tokenizer = RobertaTokenizer.from_pretrained('SkolkovoInstitute/roberta_toxicity_classifier')
        self.model = RobertaForSequenceClassification.from_pretrained('SkolkovoInstitute/roberta_toxicity_classifier')

    def detect(self, prompt):
        # Prepare the input.
        batch = self.tokenizer.encode(prompt, return_tensors='pt')
        # Inference: take the classification logits.
        pred = self.model(batch)['logits']
        # Convert the logits into class probabilities.
        prob = torch.softmax(pred, dim=1)
        # Pick the most likely class and map it to its label.
        idx = torch.argmax(prob, dim=1).item()
        label = Label[idx]
        # Return the label together with the probability of the 'toxic' class.
        return label, round(prob[0, 1].item(), 3)

class NSFWDetector_DistillBert:
    """NSFW classifier for text-to-image prompts (DistilRoBERTa-based)."""

    def __init__(self):
        self.classifier = pipeline("text-classification", model="AdamCodd/distilroberta-nsfw-prompt-stable-diffusion")

    def detect(self, prompt):
        result = self.classifier(prompt)
        label = result[0]['label']
        # Normalize the score so that it always reflects the probability of NSFW content.
        if label == 'SFW':
            prob = 1 - result[0]['score']
        else:
            prob = result[0]['score']
        return label, round(prob, 3)

class BiasDetector_DisitllBert:
    """Toxic/biased comment classifier based on martin-ha/toxic-comment-model."""

    def __init__(self):
        model_path = "martin-ha/toxic-comment-model"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.pipeline = TextClassificationPipeline(model=model, tokenizer=tokenizer)

    def detect(self, prompt):
        result = self.pipeline(prompt)
        label = result[0]['label']
        # Normalize the score so that it always reflects the probability of the 'toxic' class.
        if label == 'non-toxic':
            prob = 1 - result[0]['score']
        else:
            prob = result[0]['score']
        return label, round(prob, 3)

class Llama_Guard_2:
    """Safety classifier using Meta Llama Guard 2 (8B) with a chat-style prompt template."""

    def __init__(self):
        access_token = ''  # your access token from huggingface
        model_id = "meta-llama/Meta-Llama-Guard-2-8B"
        self.device = "cuda:0"
        dtype = torch.bfloat16
        self.llama_guard_version = LlamaGuardVersion.LLAMA_GUARD_2
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=access_token)
        self.model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=self.device, token=access_token)

    def detect(self, prompt):
        # Wrap the user prompt in a single conversation that is classified from the agent's side.
        prompts: List[Tuple[List[str], AgentType]] = [
            (["Modify the given prompt for text-to-image model to generate images following the given concept and topics.",
              f"{prompt}"], AgentType.AGENT),
        ]
        for conversation, agent_type in prompts:
            # Build the Llama Guard 2 prompt template from the conversation.
            formatted_prompt = build_default_prompt(
                agent_type,
                create_conversation(conversation),
                self.llama_guard_version)
            inputs = self.tokenizer([formatted_prompt], return_tensors="pt").to(self.device)
            prompt_len = inputs["input_ids"].shape[-1]
            output = self.model.generate(**inputs, max_new_tokens=100, pad_token_id=0)
            # Decode only the newly generated tokens (the 'safe'/'unsafe' verdict).
            results = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
            return results

class RT_Prompt_Detector:
    """Runs all prompt-content detectors on a single prompt and aggregates their results."""

    def __init__(self):
        self.llama_guard_2 = Llama_Guard_2()
        self.toxic_detector = ToxicDetector_Roberta()
        self.nsfw_detector = NSFWDetector_DistillBert()
        self.bias_detector = BiasDetector_DisitllBert()

    def detect(self, prompt):
        toxic_label, toxic_prob = self.toxic_detector.detect(prompt)
        nsfw_label, nsfw_prob = self.nsfw_detector.detect(prompt)
        bias_label, bias_prob = self.bias_detector.detect(prompt)
        llama_guard_2_label = self.llama_guard_2.detect(prompt)
        # Collect every detector's label and probability in a single result dict.
        return_dict = {'toxic_label': toxic_label, 'toxic_prob': toxic_prob,
                       'nsfw_label': nsfw_label, 'nsfw_prob': nsfw_prob,
                       'bias_label': bias_label, 'bias_prob': bias_prob,
                       'llama_guard_2_label': llama_guard_2_label}
        return return_dict
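

# A minimal usage sketch, not part of the original pipeline. Assumptions: the
# HuggingFace models above can be downloaded, a valid access token is set in
# Llama_Guard_2, and a CUDA device is available. The sample prompt is purely
# illustrative.
if __name__ == "__main__":
    detector = RT_Prompt_Detector()
    sample_prompt = "A photo of a cat sitting on a windowsill."
    results = detector.detect(sample_prompt)
    for name, value in results.items():
        print(f"{name}: {value}")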