Skip to content

Commit

Permalink
text2img task to support negative prompts (#407)
Browse files Browse the repository at this point in the history
Co-authored-by: grajguru <grajguru@microsoft.com>
  • Loading branch information
gauravrajguru and grajguru authored Feb 12, 2024
1 parent 5cc9012 commit af874da
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 22 deletions.
3 changes: 3 additions & 0 deletions mii/legacy/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ def query(self, request_dict, **query_kwargs):
elif self.task == TaskType.ZERO_SHOT_IMAGE_CLASSIFICATION:
args = (request_dict["image"], request_dict["candidate_labels"])
kwargs = query_kwargs
elif self.task == TaskType.TEXT2IMG:
args = (request_dict["prompt"], request_dict.get("negative_prompt", None))
kwargs = query_kwargs
else:
args = (request_dict["query"], )
kwargs = query_kwargs
Expand Down
2 changes: 1 addition & 1 deletion mii/legacy/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class ModelProvider(str, Enum):
"past_user_inputs",
"generated_responses",
],
TaskType.TEXT2IMG: ["query"],
TaskType.TEXT2IMG: ["prompt"],
TaskType.ZERO_SHOT_IMAGE_CLASSIFICATION: ["image",
"candidate_labels"],
}
Expand Down
8 changes: 7 additions & 1 deletion mii/legacy/grpc_related/proto/legacymodelresponse.proto
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ service ModelResponse {
rpc FillMaskReply(SingleStringRequest) returns (SingleStringReply) {}
rpc TokenClassificationReply(SingleStringRequest) returns (SingleStringReply) {}
rpc ConversationalReply(ConversationRequest) returns (ConversationReply) {}
rpc Txt2ImgReply(MultiStringRequest) returns (ImageReply) {}
rpc Txt2ImgReply(Text2ImageRequest) returns (ImageReply) {}
rpc ZeroShotImgClassificationReply (ZeroShotImgClassificationRequest) returns (SingleStringReply) {}
}

Expand Down Expand Up @@ -103,6 +103,12 @@ message ImageReply {
float time_taken = 6;
}

message Text2ImageRequest {
repeated string prompt = 1;
repeated string negative_prompt = 2;
map<string,Value> query_kwargs = 3;
}

message ZeroShotImgClassificationRequest {
string image = 1;
repeated string candidate_labels = 2;
Expand Down
38 changes: 23 additions & 15 deletions mii/legacy/grpc_related/proto/legacymodelresponse_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# DeepSpeed Team
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: legacymodelresponse.proto
# Protobuf Python Version: 4.25.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
Expand All @@ -16,24 +17,27 @@
from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2

DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x19legacymodelresponse.proto\x12\x13legacymodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xc7\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12O\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x39.legacymodelresponse.SingleStringRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\xc5\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12N\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x38.legacymodelresponse.MultiStringRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"\xc5\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12\x45\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32/.legacymodelresponse.QARequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\x94\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x17\n\x0f\x63onversation_id\x18\x02 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12O\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x39.legacymodelresponse.ConversationRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\"\xf9\x01\n ZeroShotImgClassificationRequest\x12\r\n\x05image\x18\x01 \x01(\t\x12\x18\n\x10\x63\x61ndidate_labels\x18\x02 \x03(\t\x12\\\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32\x46.legacymodelresponse.ZeroShotImgClassificationRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\x32\xb8\x08\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12I\n\rCreateSession\x12\x1e.legacymodelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12J\n\x0e\x44\x65stroySession\x12\x1e.legacymodelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x62\n\x0eGeneratorReply\x12\'.legacymodelresponse.MultiStringRequest\x1a%.legacymodelresponse.MultiStringReply\"\x00\x12i\n\x13\x43lassificationReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12\x62\n\x16QuestionAndAnswerReply\x12\x1e.legacymodelresponse.QARequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12\x63\n\rFillMaskReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12n\n\x18TokenClassificationReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12i\n\x13\x43onversationalReply\x12(.legacymodelresponse.ConversationRequest\x1a&.legacymodelresponse.ConversationReply\"\x00\x12Z\n\x0cTxt2ImgReply\x12\'.legacymodelresponse.MultiStringRequest\x1a\x1f.legacymodelresponse.ImageReply\"\x00\x12\x81\x01\n\x1eZeroShotImgClassificationReply\x12\x35.legacymodelresponse.ZeroShotImgClassificationRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x62\x06proto3'
b'\n\x19legacymodelresponse.proto\x12\x13legacymodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xc7\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12O\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x39.legacymodelresponse.SingleStringRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\xc5\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12N\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x38.legacymodelresponse.MultiStringRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"\xc5\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12\x45\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32/.legacymodelresponse.QARequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\x94\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x17\n\x0f\x63onversation_id\x18\x02 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12O\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x39.legacymodelresponse.ConversationRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\"\xdb\x01\n\x11Text2ImageRequest\x12\x0e\n\x06prompt\x18\x01 \x03(\t\x12\x17\n\x0fnegative_prompt\x18\x02 \x03(\t\x12M\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32\x37.legacymodelresponse.Text2ImageRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\xf9\x01\n ZeroShotImgClassificationRequest\x12\r\n\x05image\x18\x01 \x01(\t\x12\x18\n\x10\x63\x61ndidate_labels\x18\x02 \x03(\t\x12\\\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32\x46.legacymodelresponse.ZeroShotImgClassificationRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\x32\xb7\x08\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12I\n\rCreateSession\x12\x1e.legacymodelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12J\n\x0e\x44\x65stroySession\x12\x1e.legacymodelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x62\n\x0eGeneratorReply\x12\'.legacymodelresponse.MultiStringRequest\x1a%.legacymodelresponse.MultiStringReply\"\x00\x12i\n\x13\x43lassificationReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12\x62\n\x16QuestionAndAnswerReply\x12\x1e.legacymodelresponse.QARequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12\x63\n\rFillMaskReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12n\n\x18TokenClassificationReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12i\n\x13\x43onversationalReply\x12(.legacymodelresponse.ConversationRequest\x1a&.legacymodelresponse.ConversationReply\"\x00\x12Y\n\x0cTxt2ImgReply\x12&.legacymodelresponse.Text2ImageRequest\x1a\x1f.legacymodelresponse.ImageReply\"\x00\x12\x81\x01\n\x1eZeroShotImgClassificationReply\x12\x35.legacymodelresponse.ZeroShotImgClassificationRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x62\x06proto3'
)

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'legacymodelresponse_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_SINGLESTRINGREQUEST_QUERYKWARGSENTRY._options = None
_SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001'
_MULTISTRINGREQUEST_QUERYKWARGSENTRY._options = None
_MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001'
_QAREQUEST_QUERYKWARGSENTRY._options = None
_QAREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001'
_CONVERSATIONREQUEST_QUERYKWARGSENTRY._options = None
_CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001'
_ZEROSHOTIMGCLASSIFICATIONREQUEST_QUERYKWARGSENTRY._options = None
_ZEROSHOTIMGCLASSIFICATIONREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001'
_globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._options = None
_globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_options = b'8\001'
_globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._options = None
_globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_options = b'8\001'
_globals['_QAREQUEST_QUERYKWARGSENTRY']._options = None
_globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_options = b'8\001'
_globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._options = None
_globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_options = b'8\001'
_globals['_TEXT2IMAGEREQUEST_QUERYKWARGSENTRY']._options = None
_globals['_TEXT2IMAGEREQUEST_QUERYKWARGSENTRY']._serialized_options = b'8\001'
_globals['_ZEROSHOTIMGCLASSIFICATIONREQUEST_QUERYKWARGSENTRY']._options = None
_globals[
'_ZEROSHOTIMGCLASSIFICATIONREQUEST_QUERYKWARGSENTRY']._serialized_options = b'8\001'
_globals['_VALUE']._serialized_start = 79
_globals['_VALUE']._serialized_end = 174
_globals['_SESSIONID']._serialized_start = 176
Expand Down Expand Up @@ -62,11 +66,15 @@
_globals['_CONVERSATIONREPLY']._serialized_end = 1405
_globals['_IMAGEREPLY']._serialized_start = 1407
_globals['_IMAGEREPLY']._serialized_end = 1532
_globals['_ZEROSHOTIMGCLASSIFICATIONREQUEST']._serialized_start = 1535
_globals['_ZEROSHOTIMGCLASSIFICATIONREQUEST']._serialized_end = 1784
_globals['_TEXT2IMAGEREQUEST']._serialized_start = 1535
_globals['_TEXT2IMAGEREQUEST']._serialized_end = 1754
_globals['_TEXT2IMAGEREQUEST_QUERYKWARGSENTRY']._serialized_start = 331
_globals['_TEXT2IMAGEREQUEST_QUERYKWARGSENTRY']._serialized_end = 409
_globals['_ZEROSHOTIMGCLASSIFICATIONREQUEST']._serialized_start = 1757
_globals['_ZEROSHOTIMGCLASSIFICATIONREQUEST']._serialized_end = 2006
_globals[
'_ZEROSHOTIMGCLASSIFICATIONREQUEST_QUERYKWARGSENTRY']._serialized_start = 331
_globals['_ZEROSHOTIMGCLASSIFICATIONREQUEST_QUERYKWARGSENTRY']._serialized_end = 409
_globals['_MODELRESPONSE']._serialized_start = 1787
_globals['_MODELRESPONSE']._serialized_end = 2867
_globals['_MODELRESPONSE']._serialized_start = 2009
_globals['_MODELRESPONSE']._serialized_end = 3088
# @@protoc_insertion_point(module_scope)
6 changes: 3 additions & 3 deletions mii/legacy/grpc_related/proto/legacymodelresponse_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(self, channel):
)
self.Txt2ImgReply = channel.unary_unary(
'/legacymodelresponse.ModelResponse/Txt2ImgReply',
request_serializer=legacymodelresponse__pb2.MultiStringRequest.
request_serializer=legacymodelresponse__pb2.Text2ImageRequest.
SerializeToString,
response_deserializer=legacymodelresponse__pb2.ImageReply.FromString,
)
Expand Down Expand Up @@ -220,7 +220,7 @@ def add_ModelResponseServicer_to_server(servicer, server):
'Txt2ImgReply':
grpc.unary_unary_rpc_method_handler(
servicer.Txt2ImgReply,
request_deserializer=legacymodelresponse__pb2.MultiStringRequest.FromString,
request_deserializer=legacymodelresponse__pb2.Text2ImageRequest.FromString,
response_serializer=legacymodelresponse__pb2.ImageReply.SerializeToString,
),
'ZeroShotImgClassificationReply':
Expand Down Expand Up @@ -490,7 +490,7 @@ def Txt2ImgReply(request,
request,
target,
'/legacymodelresponse.ModelResponse/Txt2ImgReply',
legacymodelresponse__pb2.MultiStringRequest.SerializeToString,
legacymodelresponse__pb2.Text2ImageRequest.SerializeToString,
legacymodelresponse__pb2.ImageReply.FromString,
options,
channel_credentials,
Expand Down
25 changes: 23 additions & 2 deletions mii/legacy/method_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,24 @@ class Text2ImgMethods(TaskMethods):
def method(self):
return "Txt2ImgReply"

pack_request_to_proto = multi_string_request_to_proto
unpack_request_from_proto = proto_request_to_list
def run_inference(self, inference_pipeline, args, kwargs):
prompt, negative_prompt = args
return inference_pipeline(prompt=prompt,
negative_prompt=negative_prompt,
**kwargs)

def pack_request_to_proto(self, request_dict, **query_kwargs):
prompt = request_dict["prompt"]
prompt = [prompt] if isinstance(prompt, str) else prompt
negative_prompt = request_dict.get("negative_prompt", [""] * len(prompt))
negative_prompt = [negative_prompt] if isinstance(negative_prompt,
str) else negative_prompt

return modelresponse_pb2.Text2ImageRequest(
prompt=prompt,
negative_prompt=negative_prompt,
query_kwargs=kwarg_dict_to_proto(query_kwargs),
)

def pack_response_to_proto(self, response, time_taken, model_time_taken):
images_bytes = []
Expand All @@ -266,6 +282,11 @@ def pack_response_to_proto(self, response, time_taken, model_time_taken):
def unpack_response_from_proto(self, response):
return ImageResponse(response)

def unpack_request_from_proto(self, request):
kwargs = unpack_proto_query_kwargs(request.query_kwargs)
args = (list(request.prompt), list(request.negative_prompt))
return args, kwargs


class ZeroShotImgClassificationMethods(TaskMethods):
@property
Expand Down

0 comments on commit af874da

Please sign in to comment.