Uses generate text request object as argument to client
apmiller108 committed Jan 24, 2025
1 parent a0bbe08 commit d9e0514
Showing 8 changed files with 61 additions and 41 deletions.
11 changes: 10 additions & 1 deletion app/models/generate_text_request.rb
@@ -40,7 +40,7 @@ def system_message
end

def response
- @response ||= GenerativeText::Anthropic::InvokeModelResponse.new(super) if super.present?
+ @response ||= response_wrapper_class.new(super) if super.present?
end

def response_token_count
@@ -68,4 +68,13 @@ def acceptable_file
errors.add(:file, 'must be less than 10 MB') if file.blob.byte_size > 10.megabytes
errors.add(:file, 'must be GIF, JPEG, PNG or WEBP') unless file.blob.content_type.in? SUPPORTED_MIME_TYPES
end
+
+ def response_wrapper_class
+   case model.vendor
+   when :aws
+     GenerativeText::AWS::Client::InvokeModelResponse
+   when :anthropic
+     GenerativeText::Anthropic::InvokeModelResponse
+   end
+ end
end
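
The response wrapper is now chosen by the model's vendor instead of being hardcoded to the Anthropic class. A minimal usage sketch, assuming a persisted record whose model is Anthropic-hosted (the record and vendor value here are illustrative, not from the commit):

  request = GenerateTextRequest.last  # hypothetical record
  request.model.vendor                # => :anthropic (assumed)
  request.response.class              # => GenerativeText::Anthropic::InvokeModelResponse
  # An :aws vendor would wrap with GenerativeText::AWS::Client::InvokeModelResponse instead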
6 changes: 1 addition & 5 deletions app/sidekiq/generate_text_job.rb
@@ -41,11 +41,7 @@ def broadcast_component(generate_text_request, user)
end

def invoke_model(generate_text_request)
- client = GenerativeText::Anthropic::Client.new
- GenerativeText.new(client).invoke_model(
-   **generate_text_request.slice(:prompt, :temperature, :system_message, :model).symbolize_keys,
-   messages: generate_text_request.conversation.exchange
- )
+ GenerativeText.new.invoke_model(generate_text_request)
end

class << self
3 changes: 2 additions & 1 deletion app/sidekiq/transcription_summary_job.rb
@@ -7,11 +7,12 @@ def perform(user_id, transcription_id)
transcription = user.transcriptions.find(transcription_id)
summary = transcription.summary
prompt = GenerativeText.summary_prompt_for(transcription:)
+ client = GenerativeText::AWS::Client.new

summary.in_progress!

# @stream_response [InvokeModelStreamResponse, #content, #final_chunk?]
- GenerativeText.new.invoke_model_stream(prompt:) do |stream_response|
+ GenerativeText.new.invoke_model_stream(client:, prompt:) do |stream_response|
summary.content += stream_response.content

if stream_response.final_chunk?
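
Streaming still goes through the AWS client, which the job now builds and passes in explicitly as a keyword. A sketch of the call shape under that assumption, mirroring the job body:

  client = GenerativeText::AWS::Client.new
  GenerativeText.new.invoke_model_stream(client:, prompt:) do |stream_response|
    summary.content += stream_response.content
    # final_chunk? signals the last chunk of the stream
  end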
23 changes: 14 additions & 9 deletions lib/generative_text.rb
@@ -14,20 +14,25 @@ def self.summary_prompt_for(transcription:)
Prompt.transcription_summary_prompt(transcription)
end

- attr_reader :client
-
- def initialize(client = AWS::Client.new)
-   @client = client
+ def self.client_for(generate_text_request)
+   case generate_text_request.model.vendor
+   when :aws
+     AWS::Client
+   when :anthropic
+     Anthropic::Client
+   end
end

- def invoke_model_stream(prompt:, **opts, &block)
+ # Each stream chunk is yielded to the block.
+ # Only the AWS client supports streaming at this time.
+ def invoke_model_stream(prompt:, client: AWS::Client.new, **opts, &block)
client.invoke_model_stream(prompt:, **opts, &block)
end

- # @param prompt [String]
- # See InvokeModelRequest for available options
+ # @param generate_text_request [GenerateTextRequest]
# @return [InvokeModelResponse] the object containing the generated text.
- def invoke_model(prompt:, **opts)
-   client.invoke_model(prompt:, **opts)
+ def invoke_model(generate_text_request)
+   client = self.class.client_for(generate_text_request).new
+   client.invoke_model(generate_text_request)
end
end
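
`client_for` returns a client class rather than an instance; `invoke_model` instantiates it and delegates, so callers no longer choose a vendor themselves. A minimal sketch, assuming `request` is a GenerateTextRequest whose model vendor is :aws:

  request = GenerateTextRequest.last        # hypothetical record
  GenerativeText.client_for(request)        # => GenerativeText::AWS::Client
  GenerativeText.new.invoke_model(request)  # instantiates the client, returns its InvokeModelResponse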
13 changes: 7 additions & 6 deletions lib/generative_text/anthropic/client.rb
@@ -17,23 +17,24 @@ def initialize

# rubocop:disable Metrics/AbcSize
# @return [InvokeModelResponse]
- def invoke_model(prompt:, messages: [], **params)
-   model = params.fetch(:model)
+ def invoke_model(generate_text_request)
+   model = generate_text_request.model
+   messages = generate_text_request.conversation.exchange
request_body = {
model: model.api_name,
-   max_tokens: params.fetch(:max_tokens, model.max_tokens),
-   temperature: params[:temperature],
+   max_tokens: model.max_tokens,
+   temperature: generate_text_request.temperature,
messages: messages.push(
{
role: :user,
content: [
type: :text,
-       text: prompt
+       text: generate_text_request.prompt
]
}
)
}
- request_body[:system] = params[:system_message] if params[:system_message]
+ request_body[:system] = generate_text_request.system_message if generate_text_request.system_message
response = conn.post(MESSAGES_PATH) do |req|
req.body = request_body.to_json
end
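
For orientation, a sketch of the request body this method now assembles; every field is read from the GenerateTextRequest, and the concrete values below are illustrative rather than taken from the commit:

  {
    model: 'claude-3-5-sonnet-latest',  # model.api_name (illustrative)
    max_tokens: 4096,                   # model.max_tokens (illustrative)
    temperature: 0.7,                   # generate_text_request.temperature
    messages: [
      # prior conversation exchange, then the new user turn:
      { role: :user, content: [{ type: :text, text: 'the prompt' }] }
    ],
    system: 'You are...'                # only when system_message is present
  }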
10 changes: 7 additions & 3 deletions lib/generative_text/aws/client.rb
@@ -17,14 +17,18 @@ def initialize

def invoke_model_stream(prompt:, **params, &block)
params[:event_stream_handler] = EventStreamHandler.new(&block).to_proc
- params = InvokeModelRequest.new(prompt:, **params).to_h
+ request = GenerateTextRequest.new(prompt:,
+                                   model: 'amazon.titan-text-express-v1',
+                                   temperature: 0.2)
+ params = InvokeModelRequest.new(request, **params).to_h
invoke_model_with_response_stream(params)
rescue Aws::BedrockRuntime::Errors
raise InvalidRequestError
end

- def invoke_model(prompt:, **params)
-   params = InvokeModelRequest.new(prompt:, **params).to_h
+ # @param generate_text_request [GenerateTextRequest]
+ def invoke_model(generate_text_request)
+   params = InvokeModelRequest.new(generate_text_request).to_h
response = @client.invoke_model(params)
InvokeModelResponse.new(response.body.read)
rescue Aws::BedrockRuntime::Errors
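
Note the asymmetry this leaves: the non-streaming path takes a GenerateTextRequest from the caller, while the streaming path synthesizes one internally with a fixed model and temperature. A sketch of the non-streaming call, assuming a persisted record:

  client = GenerativeText::AWS::Client.new
  response = client.invoke_model(generate_text_request)  # => InvokeModelResponse
  response.content                                       # the generated text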
28 changes: 14 additions & 14 deletions lib/generative_text/aws/client/invoke_model_request.rb
@@ -2,20 +2,20 @@ class GenerativeText
module AWS
class Client
class InvokeModelRequest
- TEMP = 0.2 # Between 0..1. Increase for more randomness.
TOP_P = 0.8 # Between 0..1. Only consider possibilities gte this value. Lower it for weirdness.

- attr_reader :params
+ attr_reader :generate_text_request

- def initialize(prompt:, **params)
-   @prompt = prompt
-   @params = params
-   @model = MODELS.find { _1.api_name == 'amazon.titan-text-express-v1' }
+ delegate :model, :prompt, :temperature, to: :generate_text_request
+
+ def initialize(generate_text_request, **opts)
+   @generate_text_request = generate_text_request
+   @opts = opts
end

def to_h
{
-   model_id: @model.api_name,
+   model_id: model.api_name,
content_type: 'application/json',
accept: 'application/json',
body:
@@ -26,25 +26,25 @@ def to_h

def body
{
- 'inputText' => @prompt,
+ 'inputText' => prompt,
'textGenerationConfig' => text_gen_config
}.to_json
end

def text_gen_config
{
- 'maxTokenCount' => params.fetch(:max_tokens, @model.max_tokens),
- 'stopSequences' => params.fetch(:stop_sequences, []),
- 'temperature' => params.fetch(:temp, TEMP),
- 'topP' => params.fetch(:top_p, TOP_P)
+ 'maxTokenCount' => model.max_tokens,
+ 'stopSequences' => @opts.fetch(:stop_sequences, []),
+ 'temperature' => temperature,
+ 'topP' => @opts.fetch(:top_p, TOP_P)
}
end

def event_stream_options
- return {} if params[:event_stream_handler].blank?
+ return {} if @opts[:event_stream_handler].blank?

{
- event_stream_handler: params[:event_stream_handler]
+ event_stream_handler: @opts[:event_stream_handler]
}
end
end
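
A sketch of the hash `to_h` now produces for the Bedrock SDK; the body values are illustrative, though the model id and temperature match the defaults the streaming path uses above:

  InvokeModelRequest.new(generate_text_request).to_h
  # => {
  #   model_id: 'amazon.titan-text-express-v1',
  #   content_type: 'application/json',
  #   accept: 'application/json',
  #   body: '{"inputText":"...","textGenerationConfig":{"maxTokenCount":8192,"stopSequences":[],"temperature":0.2,"topP":0.8}}'
  # }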
8 changes: 6 additions & 2 deletions lib/generative_text/aws/client/invoke_model_response.rb
@@ -10,8 +10,12 @@ class InvokeModelResponse
# [{"tokenCount"=>55,
# "outputText"=> "\nHere are 10 fruits:\n1. Apple\n2. Banana\n3..."
# "completionReason"=>"FINISH"}]}
- def initialize(json)
-   @data = JSON.parse(json)
+ def initialize(data)
+   @data = if data.respond_to? :keys
+     data
+   else
+     JSON.parse(data)
+   end
end

def content
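
The wrapper now accepts either a raw JSON string (as returned by `response.body.read` in the client) or an already-parsed Hash, presumably what the stored response column deserializes to when `response_wrapper_class` rewraps it. A quick sketch with hypothetical payloads:

  InvokeModelResponse.new(raw_json_string)  # String: run through JSON.parse
  InvokeModelResponse.new(parsed_hash)      # Hash (responds to :keys): used as-is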
