diff --git a/app/models/generate_text_request.rb b/app/models/generate_text_request.rb
index 4add62b0..b704aaa8 100644
--- a/app/models/generate_text_request.rb
+++ b/app/models/generate_text_request.rb
@@ -86,7 +86,7 @@ def response_wrapper_class
     when :anthropic
       GenerativeText::Anthropic::InvokeModelResponse
     when :aws
-      GenerativeText::AWS::Client::InvokeModelResponse
+      GenerativeText::AWS::InvokeModelResponse
     end
   end
 
diff --git a/lib/generative_text/anthropic/client.rb b/lib/generative_text/anthropic/client.rb
index 6b0985be..73bff263 100644
--- a/lib/generative_text/anthropic/client.rb
+++ b/lib/generative_text/anthropic/client.rb
@@ -20,9 +20,7 @@ def invoke_model(generate_text_request)
       response = conn.post(MESSAGES_PATH) do |req|
         req.body = InvokeModelRequest.new(generate_text_request).to_json
       end
-      InvokeModelResponse.new(response.body).tap do |r|
-        Rails.logger.info "#{self.class}: invoke_model usage: #{r.usage}"
-      end
+      InvokeModelResponse.new(response.body)
     rescue Faraday::Error => e
       raise ClientError, "#{e.response_status}: #{e.response_body}"
     end
diff --git a/lib/generative_text/aws/client/event_stream_handler.rb b/lib/generative_text/aws/client/event_stream_handler.rb
deleted file mode 100644
index d5e37cec..00000000
--- a/lib/generative_text/aws/client/event_stream_handler.rb
+++ /dev/null
@@ -1,39 +0,0 @@
-class GenerativeText
-  module AWS
-    class Client
-      class EventStreamHandler
-        def initialize(&block)
-          @block = block
-        end
-
-        def to_proc
-          ->(stream) {
-            stream.on_chunk_event(&method(:on_chunk_event))
-            stream.on_internal_server_exception_event(&method(:on_exception))
-            stream.on_validation_exception_event(&method(:on_exception))
-            stream.on_throttling_exception_event(&method(:on_exception))
-            stream.on_model_timeout_exception_event(&method(:on_exception))
-            stream.on_model_stream_error_exception_event(&method(:on_model_stream_error))
-            stream.on_error_event(&method(:on_generic_error))
-          }
-        end
-
-        def on_chunk_event(event)
-          @block.call(InvokeModelStreamResponse.new(event.bytes))
-        end
-
-        def on_exception(event)
-          raise InvalidRequestError, "#{event.event_type}: #{event.message}"
-        end
-
-        def on_model_stream_error(event)
-          raise InvalidRequestError, "#{event.event_type}: #{event.message} : #{event.original_message}"
-        end
-
-        def on_generic_error(event)
-          raise InvalidRequestError, "#{event.event_type}: #{event.error_code} : #{event.error_message}"
-        end
-      end
-    end
-  end
-end
diff --git a/lib/generative_text/aws/client/invoke_model_request.rb b/lib/generative_text/aws/client/invoke_model_request.rb
deleted file mode 100644
index c09bbfc7..00000000
--- a/lib/generative_text/aws/client/invoke_model_request.rb
+++ /dev/null
@@ -1,68 +0,0 @@
-class GenerativeText
-  module AWS
-    class Client
-      class InvokeModelRequest
-        TOP_P = 0.8 # Between 0..1. Only consider possibilities gte this value. Lower it for weirdness.
-
-        attr_reader :generate_text_request
-
-        delegate :model, :prompt, :temperature, to: :generate_text_request
-
-        def initialize(generate_text_request, **opts)
-          @generate_text_request = generate_text_request
-          @opts = opts
-        end
-
-        def to_h
-          {
-            model_id: model.api_name,
-            content_type: 'application/json',
-            accept: 'application/json',
-            body:
-          }.merge(event_stream_options)
-        end
-
-        private
-
-        def body
-          {
-            'inputText' => input_text,
-            'textGenerationConfig' => text_gen_config
-          }.to_json
-        end
-
-        def input_text
-          <<~TXT
-            #{generate_text_request.system_message}
-            #{turns.join}
-          TXT
-        end
-
-        def turns
-          generate_text_request.conversation.exchange << next_turn
-        end
-
-        def next_turn
-          format(Turn::TEMPLATE, prompt:, response: '')
-        end
-
-        def text_gen_config
-          {
-            'maxTokenCount' => model.max_tokens,
-            'stopSequences' => @opts.fetch(:stop_sequences, []),
-            'temperature' => temperature,
-            'topP' => @opts.fetch(:top_p, TOP_P)
-          }
-        end
-
-        def event_stream_options
-          return {} if @opts[:event_stream_handler].blank?
-
-          {
-            event_stream_handler: @opts[:event_stream_handler]
-          }
-        end
-      end
-    end
-  end
-end
diff --git a/lib/generative_text/aws/client/invoke_model_response.rb b/lib/generative_text/aws/client/invoke_model_response.rb
deleted file mode 100644
index b0eac02f..00000000
--- a/lib/generative_text/aws/client/invoke_model_response.rb
+++ /dev/null
@@ -1,41 +0,0 @@
-class GenerativeText
-  module AWS
-    class Client
-      class InvokeModelResponse
-        attr_reader :data
-
-        # JSON response will look like this
-        # {"inputTextTokenCount"=>5,
-        #  "results"=>
-        #   [{"tokenCount"=>55,
-        #     "outputText"=> "\nHere are 10 fruits:\n1. Apple\n2. Banana\n3..."
-        #     "completionReason"=>"FINISH"}]}
-        def initialize(data)
-          @data = if data.respond_to? :keys
-                    data
-                  else
-                    JSON.parse(data)
-                  end
-        end
-
-        def content
-          results.fetch('outputText')
-        end
-
-        def results
-          data.fetch('results')[0]
-        end
-
-        def completion_reason
-          # completionReason could be one of ["LENGTH", "FINISH"]. The latter
-          # meaning the response was truncated per the max_tokens.
-          results.fetch('completionReason')
-        end
-
-        def token_count
-          (data['inputTextTokenCount'] || 0) + (results['tokenCount'] || 0)
-        end
-      end
-    end
-  end
-end
diff --git a/lib/generative_text/aws/client/invoke_model_stream_response.rb b/lib/generative_text/aws/client/invoke_model_stream_response.rb
deleted file mode 100644
index 40df51cd..00000000
--- a/lib/generative_text/aws/client/invoke_model_stream_response.rb
+++ /dev/null
@@ -1,45 +0,0 @@
-class GenerativeText
-  module AWS
-    class Client
-      class InvokeModelStreamResponse
-        attr_reader :response_data
-
-        # A series of event json will look like this:
-        #
-        # {"outputText": "0..-2 parts of stream", "index": 0,
-        # "totalOutputTextTokenCount": nil, "completionReason": nil,
-        # "inputTextTokenCount": 9}
-        #
-        # {"outputText": "final part of stream", "index": 0,
-        # "totalOutputTextTokenCount": 104, "completionReason": "FINISH",
-        # "inputTextTokenCount": nil, "amazon-bedrock-invocationMetrics":
-        # {"inputTokenCount": 9, "outputTokenCount": 104, "invocationLatency":
-        # 3407, "firstByteLatency": 2440}}
-
-        def initialize(json)
-          @response_data = JSON.parse(json)
-        end
-
-        def content
-          response_data.fetch('outputText')
-        end
-
-        def final_chunk?
-          # completionReason could be one of ["LENGTH", "FINISH"]. The latter
-          # meaning the response was truncated per the max_tokens.
-          response_data.fetch('completionReason').present?
-        end
-
-        def token_count
-          (metrics['inputTokenCount'] || 0) + (metrics['outputTokenCount'] || 0)
-        end
-
-        private
-
-        def metrics
-          response_data.fetch('amazon-bedrock-invocationMetrics', {})
-        end
-      end
-    end
-  end
-end
diff --git a/lib/generative_text/aws/event_stream_handler.rb b/lib/generative_text/aws/event_stream_handler.rb
new file mode 100644
index 00000000..ba40b920
--- /dev/null
+++ b/lib/generative_text/aws/event_stream_handler.rb
@@ -0,0 +1,37 @@
+class GenerativeText
+  module AWS
+    class EventStreamHandler
+      def initialize(&block)
+        @block = block
+      end
+
+      def to_proc
+        ->(stream) {
+          stream.on_chunk_event(&method(:on_chunk_event))
+          stream.on_internal_server_exception_event(&method(:on_exception))
+          stream.on_validation_exception_event(&method(:on_exception))
+          stream.on_throttling_exception_event(&method(:on_exception))
+          stream.on_model_timeout_exception_event(&method(:on_exception))
+          stream.on_model_stream_error_exception_event(&method(:on_model_stream_error))
+          stream.on_error_event(&method(:on_generic_error))
+        }
+      end
+
+      def on_chunk_event(event)
+        @block.call(InvokeModelStreamResponse.new(event.bytes))
+      end
+
+      def on_exception(event)
+        raise InvalidRequestError, "#{event.event_type}: #{event.message}"
+      end
+
+      def on_model_stream_error(event)
+        raise InvalidRequestError, "#{event.event_type}: #{event.message} : #{event.original_message}"
+      end
+
+      def on_generic_error(event)
+        raise InvalidRequestError, "#{event.event_type}: #{event.error_code} : #{event.error_message}"
+      end
+    end
+  end
+end
diff --git a/lib/generative_text/aws/invoke_model_request.rb b/lib/generative_text/aws/invoke_model_request.rb
new file mode 100644
index 00000000..7ddfd6b2
--- /dev/null
+++ b/lib/generative_text/aws/invoke_model_request.rb
@@ -0,0 +1,66 @@
+class GenerativeText
+  module AWS
+    class InvokeModelRequest
+      TOP_P = 0.8 # Between 0..1. Nucleus sampling: only the smallest set of tokens whose cumulative probability reaches this value is considered. Lower it for more focused output.
+
+      attr_reader :generate_text_request
+
+      delegate :model, :prompt, :temperature, to: :generate_text_request
+
+      def initialize(generate_text_request, **opts)
+        @generate_text_request = generate_text_request
+        @opts = opts
+      end
+
+      def to_h
+        {
+          model_id: model.api_name,
+          content_type: 'application/json',
+          accept: 'application/json',
+          body:
+        }.merge(event_stream_options)
+      end
+
+      private
+
+      def body
+        {
+          'inputText' => input_text,
+          'textGenerationConfig' => text_gen_config
+        }.to_json
+      end
+
+      def input_text
+        <<~TXT.strip
+          #{generate_text_request.system_message}
+          #{turns.join}
+        TXT
+      end
+
+      def turns
+        generate_text_request.conversation.exchange << next_turn
+      end
+
+      def next_turn
+        format(Turn::TEMPLATE, prompt:, response: '')
+      end
+
+      def text_gen_config
+        {
+          'maxTokenCount' => model.max_tokens,
+          'stopSequences' => @opts.fetch(:stop_sequences, []),
+          'temperature' => temperature,
+          'topP' => @opts.fetch(:top_p, TOP_P)
+        }
+      end
+
+      def event_stream_options
+        return {} if @opts[:event_stream_handler].blank?
+
+        {
+          event_stream_handler: @opts[:event_stream_handler]
+        }
+      end
+    end
+  end
+end
diff --git a/lib/generative_text/aws/invoke_model_response.rb b/lib/generative_text/aws/invoke_model_response.rb
new file mode 100644
index 00000000..61afa401
--- /dev/null
+++ b/lib/generative_text/aws/invoke_model_response.rb
@@ -0,0 +1,39 @@
+class GenerativeText
+  module AWS
+    class InvokeModelResponse
+      attr_reader :data
+
+      # JSON response will look like this
+      # {"inputTextTokenCount"=>5,
+      #  "results"=>
+      #   [{"tokenCount"=>55,
+      #     "outputText"=> "\nHere are 10 fruits:\n1. Apple\n2. Banana\n3...",
+      #     "completionReason"=>"FINISH"}]}
+      def initialize(data)
+        @data = if data.respond_to? :keys
+                  data
+                else
+                  JSON.parse(data)
+                end
+      end
+
+      def content
+        results.fetch('outputText')
+      end
+
+      def results
+        data.fetch('results')[0]
+      end
+
+      def completion_reason
+        # completionReason could be one of ["LENGTH", "FINISH"]. The former
+        # means the response was truncated per the max_tokens.
+        results.fetch('completionReason')
+      end
+
+      def token_count
+        (data['inputTextTokenCount'] || 0) + (results['tokenCount'] || 0)
+      end
+    end
+  end
+end
diff --git a/lib/generative_text/aws/invoke_model_stream_response.rb b/lib/generative_text/aws/invoke_model_stream_response.rb
new file mode 100644
index 00000000..f187bbb5
--- /dev/null
+++ b/lib/generative_text/aws/invoke_model_stream_response.rb
@@ -0,0 +1,43 @@
+class GenerativeText
+  module AWS
+    class InvokeModelStreamResponse
+      attr_reader :response_data
+
+      # A series of event json will look like this:
+      #
+      # {"outputText": "0..-2 parts of stream", "index": 0,
+      # "totalOutputTextTokenCount": nil, "completionReason": nil,
+      # "inputTextTokenCount": 9}
+      #
+      # {"outputText": "final part of stream", "index": 0,
+      # "totalOutputTextTokenCount": 104, "completionReason": "FINISH",
+      # "inputTextTokenCount": nil, "amazon-bedrock-invocationMetrics":
+      # {"inputTokenCount": 9, "outputTokenCount": 104, "invocationLatency":
+      # 3407, "firstByteLatency": 2440}}
+
+      def initialize(json)
+        @response_data = JSON.parse(json)
+      end
+
+      def content
+        response_data.fetch('outputText')
+      end
+
+      def final_chunk?
+        # completionReason is nil until the final chunk, then one of ["LENGTH",
+        # "FINISH"]; "LENGTH" means the response was truncated per the max_tokens.
+        response_data.fetch('completionReason').present?
+      end
+
+      def token_count
+        (metrics['inputTokenCount'] || 0) + (metrics['outputTokenCount'] || 0)
+      end
+
+      private
+
+      def metrics
+        response_data.fetch('amazon-bedrock-invocationMetrics', {})
+      end
+    end
+  end
+end
diff --git a/lib/generative_text/prompt.rb b/lib/generative_text/prompt.rb
deleted file mode 100644
index e69de29b..00000000
diff --git a/spec/factories/generate_text_requests.rb b/spec/factories/generate_text_requests.rb
index 91da2c1e..c83ce129 100644
--- a/spec/factories/generate_text_requests.rb
+++ b/spec/factories/generate_text_requests.rb
@@ -6,6 +6,14 @@
     model { GenerativeText::MODELS.sample.api_name }
     user
 
+    trait :with_anthropic_model do
+      model { GenerativeText::Anthropic::MODELS.sample.api_name }
+    end
+
+    trait :with_aws_model do
+      model { GenerativeText::AWS::MODELS.sample.api_name }
+    end
+
     trait :with_preset do
       generate_text_preset
     end
@@ -19,6 +27,7 @@
     end
 
     trait :with_response do
+      model { GenerativeText::Anthropic::MODELS.sample.api_name }
       status { 'completed' }
       response do
         {
diff --git a/spec/lib/generative_text/anthropic/client_spec.rb b/spec/lib/generative_text/anthropic/client_spec.rb
index a630edfc..6c6e1c70 100644
--- a/spec/lib/generative_text/anthropic/client_spec.rb
+++ b/spec/lib/generative_text/anthropic/client_spec.rb
@@ -4,18 +4,18 @@
 RSpec.describe GenerativeText::Anthropic::Client do
   let(:client) { described_class.new }
   let(:prompt) { 'Write a haiku about a rainy day.' }
-  let(:messages) { [] }
-  let(:model) { GenerativeText::Anthropic::MODELS.sample }
-  let(:params) { { temperature: 0.7, system_message: 'this is the system message', model: } }
+  let(:model) { generate_text_request.model }
+  let(:temperature) { 0.1 }
+  let(:generate_text_request) { build_stubbed :generate_text_request, :with_anthropic_model, prompt:, temperature: }
 
   describe '#invoke_model' do
     let(:request_body) do
      { model: model.api_name,
        max_tokens: model.max_tokens,
-       system: params[:system_message],
-       temperature: params[:temperature],
+       system: generate_text_request.system_message,
+       temperature:,
        messages: [{ role: 'user',
-                    content: [{ type: 'text', text: 'Write a haiku about a rainy day.' }] }] }
+                    content: [{ type: 'text', text: prompt }] }] }
     end
 
     before do
@@ -32,19 +32,9 @@
     context 'with a valid request' do
       it 'returns an InvokeModelResponse object' do
-        response = client.invoke_model(prompt:, messages:, **params)
+        response = client.invoke_model(generate_text_request)
         expect(response).to be_a(GenerativeText::Anthropic::InvokeModelResponse)
       end
-
-      context 'with the model param' do
-        let(:model) { GenerativeText::Anthropic::MODELS.sample }
-        let(:params) { { temperature: 0.7, system_message: 'this is the system message', model: } }
-
-        it 'returns an InvokeModelResponse object' do
-          response = client.invoke_model(prompt:, messages:, **params)
-          expect(response).to be_a(GenerativeText::Anthropic::InvokeModelResponse)
-        end
-      end
     end
 
     context 'with invalid parameters' do
@@ -54,7 +44,7 @@
     end
 
     it 'raises a ClientError exception' do
-      expect { client.invoke_model(prompt:, messages:, **params) }
+      expect { client.invoke_model(generate_text_request) }
        .to raise_error(GenerativeText::Anthropic::ClientError, '500: Invalid request')
     end
   end
diff --git a/spec/lib/generative_text/aws/client/invoke_model_request_spec.rb b/spec/lib/generative_text/aws/client/invoke_model_request_spec.rb
deleted file mode 100644
index 2100f8da..00000000
--- a/spec/lib/generative_text/aws/client/invoke_model_request_spec.rb
+++ /dev/null
@@ -1,61 +0,0 @@
-require 'rails_helper'
-
-RSpec.describe GenerativeText::AWS::Client::InvokeModelRequest do
-  subject(:request) { described_class.new(prompt:, **params) }
-
-  let(:prompt) { 'Generate something interesting.' }
-  let(:params) do
-    {
-      max_tokens: 100,
-      stop_sequences: ['', ''],
-      temp: 0.5,
-      top_p: 0.9
-    }
-  end
-  let(:model_id) { 'amazon.titan-text-express-v1' }
-
-  describe '#to_h' do
-    it 'returns the request as a hash' do
-      expected_hash = {
-        model_id:,
-        content_type: 'application/json',
-        accept: 'application/json',
-        body: {
-          'inputText' => prompt,
-          'textGenerationConfig' => {
-            'maxTokenCount' => params.fetch(:max_tokens),
-            'stopSequences' => params.fetch(:stop_sequences),
-            'temperature' => params.fetch(:temp),
-            'topP' => params.fetch(:top_p)
-          }
-        }.to_json
-      }
-
-      expect(request.to_h).to eq(expected_hash)
-    end
-
-    context 'with default values' do
-      let(:params) { {} }
-      let(:default_model) { GenerativeText::AWS::MODELS.find { _1.api_name == 'amazon.titan-text-express-v1' } }
-
-      it 'sets default values for each of the text gen config params' do
-        expected_hash = {
-          model_id: default_model.api_name,
-          content_type: 'application/json',
-          accept: 'application/json',
-          body: {
-            'inputText' => prompt,
-            'textGenerationConfig' => {
-              'maxTokenCount' => default_model.max_tokens,
-              'stopSequences' => [],
-              'temperature' => described_class::TEMP,
-              'topP' => described_class::TOP_P
-            }
-          }.to_json
-        }
-
-        expect(request.to_h).to eq(expected_hash)
-      end
-    end
-  end
-end
diff --git a/spec/lib/generative_text/aws/client_spec.rb b/spec/lib/generative_text/aws/client_spec.rb
index 925f8a58..ddde94c3 100644
--- a/spec/lib/generative_text/aws/client_spec.rb
+++ b/spec/lib/generative_text/aws/client_spec.rb
@@ -34,48 +34,46 @@
   describe '#invoke_model' do
     subject(:client) { described_class.new }
 
-    let(:prompt) { 'List 18.5 fruits' }
-    let(:params) { { a: 1, b: 2 } }
-    let(:request) { instance_double GenerativeText::AWS::Client::InvokeModelRequest, to_h: request_hash }
+    let(:generate_text_request) { build_stubbed :generate_text_request }
+    let(:request) { instance_double GenerativeText::AWS::InvokeModelRequest, to_h: request_hash }
     let(:request_hash) { double }
     let(:response_string) { double }
     let(:body) { instance_double StringIO, read: response_string }
     let(:client_response) { instance_double Aws::BedrockRuntime::Types::InvokeModelResponse, body: }
-    let(:response) { instance_double GenerativeText::AWS::Client::InvokeModelResponse }
+    let(:response) { instance_double GenerativeText::AWS::InvokeModelResponse }
 
     before do
-      allow(GenerativeText::AWS::Client::InvokeModelRequest).to receive(:new)
-        .with(prompt:, **params).and_return(request)
+      allow(GenerativeText::AWS::InvokeModelRequest).to receive(:new)
+        .with(generate_text_request).and_return(request)
       allow(aws_client).to receive(:invoke_model).with(request_hash).and_return(client_response)
-      allow(GenerativeText::AWS::Client::InvokeModelResponse).to receive(:new)
+      allow(GenerativeText::AWS::InvokeModelResponse).to receive(:new)
        .with(response_string).and_return(response)
     end
 
     it 'returns the InvokeModelResponse object' do
-      expect(client.invoke_model(prompt:, **params)).to eq response
+      expect(client.invoke_model(generate_text_request)).to eq response
     end
   end
 
   describe '#invoke_model_stream' do
     subject(:client) { described_class.new }
 
-    let(:prompt) { 'tell me something awesme' }
+    let(:generate_text_request) { build_stubbed :generate_text_request }
     let(:handler_proc) { proc {} }
-    let(:handler) { instance_double(GenerativeText::AWS::Client::EventStreamHandler, to_proc: handler_proc) }
-    let(:request) { instance_double(GenerativeText::AWS::Client::InvokeModelRequest, to_h: request_params) }
+    let(:handler) { instance_double(GenerativeText::AWS::EventStreamHandler, to_proc: handler_proc) }
+    let(:request) { instance_double(GenerativeText::AWS::InvokeModelRequest, to_h: request_params) }
     let(:request_params) { { b: 2 } }
-    let(:params) { { a: 1 } }
 
     before do
-      allow(GenerativeText::AWS::Client::EventStreamHandler).to receive(:new).and_return(handler)
-      allow(GenerativeText::AWS::Client::InvokeModelRequest).to(
-        receive(:new).with(prompt:, **params.merge(event_stream_handler: handler_proc)).and_return(request)
+      allow(GenerativeText::AWS::EventStreamHandler).to receive(:new).and_return(handler)
+      allow(GenerativeText::AWS::InvokeModelRequest).to(
+        receive(:new).with(generate_text_request, event_stream_handler: handler_proc).and_return(request)
       )
     end
 
     it 'delegates to the aws client with the proper params' do
       block = proc {}
-      client.invoke_model_stream(prompt:, **params, &block)
+      client.invoke_model_stream(generate_text_request, &block)
 
       expect(aws_client).to have_received(:invoke_model_with_response_stream).with(request_params)
     end
diff --git a/spec/lib/generative_text/aws/client/event_stream_handler_spec.rb b/spec/lib/generative_text/aws/event_stream_handler_spec.rb
similarity index 89%
rename from spec/lib/generative_text/aws/client/event_stream_handler_spec.rb
rename to spec/lib/generative_text/aws/event_stream_handler_spec.rb
index 3c6d2ac6..d81d0861 100644
--- a/spec/lib/generative_text/aws/client/event_stream_handler_spec.rb
+++ b/spec/lib/generative_text/aws/event_stream_handler_spec.rb
@@ -1,6 +1,6 @@
 require 'rails_helper'
 
-RSpec.describe GenerativeText::AWS::Client::EventStreamHandler do
+RSpec.describe GenerativeText::AWS::EventStreamHandler do
   describe '#initialize' do
     it 'stores the block' do
       block = proc { |event| puts event }
@@ -24,8 +24,8 @@
     it 'calls the provided block with the event' do
       bytes = double
       event = instance_double(Aws::BedrockRuntime::Types::PayloadPart, bytes:)
-      response = instance_double(GenerativeText::AWS::Client::InvokeModelStreamResponse)
-      allow(GenerativeText::AWS::Client::InvokeModelStreamResponse).to receive(:new).with(bytes).and_return(response)
+      response = instance_double(GenerativeText::AWS::InvokeModelStreamResponse)
+      allow(GenerativeText::AWS::InvokeModelStreamResponse).to receive(:new).with(bytes).and_return(response)
 
       block = proc { |e| e }
       handler = described_class.new(&block)
diff --git a/spec/lib/generative_text/aws/invoke_model_request_spec.rb b/spec/lib/generative_text/aws/invoke_model_request_spec.rb
new file mode 100644
index 00000000..14d9351d
--- /dev/null
+++ b/spec/lib/generative_text/aws/invoke_model_request_spec.rb
@@ -0,0 +1,33 @@
+require 'rails_helper'
+
+RSpec.describe GenerativeText::AWS::InvokeModelRequest do
+  subject(:request) { described_class.new(generate_text_request) }
+
+  let(:generate_text_request) do
+    build_stubbed :generate_text_request, :with_aws_model, prompt:, temperature:, markdown_format: false
+  end
+  let(:model) { generate_text_request.model }
+  let(:temperature) { 0.5 }
+  let(:prompt) { 'Generate something interesting.' }
+
+  describe '#to_h' do
+    it 'returns the request as a hash' do
+      expected_hash = {
+        model_id: model.api_name,
+        content_type: 'application/json',
+        accept: 'application/json',
+        body: {
+          'inputText' => format(GenerativeText::AWS::Turn::TEMPLATE, prompt:, response: '').strip,
+          'textGenerationConfig' => {
+            'maxTokenCount' => model.max_tokens,
+            'stopSequences' => [],
+            'temperature' => temperature,
+            'topP' => described_class::TOP_P
+          }
+        }.to_json
+      }
+
+      expect(request.to_h).to eq(expected_hash)
+    end
+  end
+end
diff --git a/spec/lib/generative_text/aws/client/invoke_model_response_spec.rb b/spec/lib/generative_text/aws/invoke_model_response_spec.rb
similarity index 94%
rename from spec/lib/generative_text/aws/client/invoke_model_response_spec.rb
rename to spec/lib/generative_text/aws/invoke_model_response_spec.rb
index b0603e05..de286259 100644
--- a/spec/lib/generative_text/aws/client/invoke_model_response_spec.rb
+++ b/spec/lib/generative_text/aws/invoke_model_response_spec.rb
@@ -1,6 +1,6 @@
 require 'rails_helper'
 
-RSpec.describe GenerativeText::AWS::Client::InvokeModelResponse do
+RSpec.describe GenerativeText::AWS::InvokeModelResponse do
   subject(:response) { described_class.new(json_response) }
 
   let(:content) { Faker::Lorem.paragraph }
diff --git a/spec/lib/generative_text/aws/client/invoke_model_stream_response_spec.rb b/spec/lib/generative_text/aws/invoke_model_stream_response_spec.rb
similarity index 95%
rename from spec/lib/generative_text/aws/client/invoke_model_stream_response_spec.rb
rename to spec/lib/generative_text/aws/invoke_model_stream_response_spec.rb
index a0601f3b..c109bf29 100644
--- a/spec/lib/generative_text/aws/client/invoke_model_stream_response_spec.rb
+++ b/spec/lib/generative_text/aws/invoke_model_stream_response_spec.rb
@@ -1,6 +1,6 @@
 require 'rails_helper'
 
-RSpec.describe GenerativeText::AWS::Client::InvokeModelStreamResponse do
+RSpec.describe GenerativeText::AWS::InvokeModelStreamResponse do
   let(:input_token_count) { 9 }
   let(:output_token_count) { 104 }
   let(:data) do
diff --git a/spec/models/generate_text_request_spec.rb b/spec/models/generate_text_request_spec.rb
index fcf5fe05..145b07cc 100644
--- a/spec/models/generate_text_request_spec.rb
+++ b/spec/models/generate_text_request_spec.rb
@@ -25,13 +25,14 @@
   describe '#system_message' do
     subject(:request) do
-      build(:generate_text_request, generate_text_preset: preset)
+      build(:generate_text_request, generate_text_preset: preset, markdown_format:)
     end
 
+    let(:markdown_format) { true }
     let(:preset) { build(:generate_text_preset, system_message: 'Custom system message') }
 
     it 'combines markdown format message with preset system message' do
-      expected_message = "#{described_class::MARKDOWN_FORMAT_SYSTEM_MESSAGE}\nCustom system message"
+      expected_message = "#{GenerativeText::Helpers.markdown_sys_msg}\nCustom system message"
 
       expect(request.system_message).to eq(expected_message)
     end
@@ -39,26 +40,48 @@
       let(:preset) { nil }
 
       it 'returns only the markdown format message' do
-        expect(request.system_message).to eq(described_class::MARKDOWN_FORMAT_SYSTEM_MESSAGE)
+        expect(request.system_message).to eq(GenerativeText::Helpers.markdown_sys_msg)
+      end
+
+      context 'when markdown_format is false' do
+        let(:markdown_format) { false }
+
+        it 'is blank' do
+          expect(request.system_message).to be_blank
+        end
       end
     end
-  end
 
-  describe '#response' do
-    subject(:request) do
-      build(:generate_text_request, response:)
+    context 'when markdown_format is false' do
+      let(:markdown_format) { false }
+
+      it 'returns only the preset system message' do
+        expect(request.system_message).to eq(preset.system_message)
+      end
     end
+  end
 
+  describe '#response' do
     context 'when response data exists' do
-      let(:response) { { 'content' => 'Test response' } }
+      context 'with an Anthropic model' do
+        subject(:request) { build(:generate_text_request, :with_response, :with_anthropic_model) }
+
+        it 'returns an Anthropic::InvokeModelResponse instance' do
+          expect(request.response).to be_a(GenerativeText::Anthropic::InvokeModelResponse)
+        end
+      end
+
+      context 'with an AWS model' do
+        subject(:request) { build(:generate_text_request, :with_response, :with_aws_model) }
 
-      it 'returns an InvokeModelResponse instance' do
-        expect(request.response).to be_a(GenerativeText::Anthropic::InvokeModelResponse)
+        it 'returns an AWS::InvokeModelResponse instance' do
+          expect(request.response).to be_a(GenerativeText::AWS::InvokeModelResponse)
+        end
       end
     end
 
     context 'when response is nil' do
-      let(:response) { nil }
+      subject(:request) { build(:generate_text_request, response: nil) }
 
       it 'returns nil' do
         expect(request.response).to be_nil
diff --git a/spec/sidekiq/transcription_summary_job_spec.rb b/spec/sidekiq/transcription_summary_job_spec.rb
index 4e6c32b0..5639b0c2 100644
--- a/spec/sidekiq/transcription_summary_job_spec.rb
+++ b/spec/sidekiq/transcription_summary_job_spec.rb
@@ -9,7 +9,7 @@
   let(:transcriptions) { Transcription.none }
   let(:generative_text) { instance_double(GenerativeText) }
   let(:invoke_model_stream_response) do
-    instance_double GenerativeText::AWS::Client::InvokeModelStreamResponse,
+    instance_double GenerativeText::AWS::InvokeModelStreamResponse,
                     content: 'response content', final_chunk?: final_chunk?
   end
  let(:final_chunk?) { false }
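
Reviewer note: a minimal call-site sketch of what this refactor changes, for orientation only. It assumes the constructor and method signatures shown in the diff; the `GenerateTextRequest.last` lookup and the client wiring are hypothetical stand-ins, not code from this PR.

# Wrappers now live directly under GenerativeText::AWS (previously
# GenerativeText::AWS::Client::*), and both the AWS and Anthropic clients
# take a GenerateTextRequest record instead of loose keyword params.
generate_text_request = GenerateTextRequest.last # hypothetical lookup for illustration

client = GenerativeText::AWS::Client.new
response = client.invoke_model(generate_text_request)
response.class       # => GenerativeText::AWS::InvokeModelResponse
response.content     # the model's output text
response.token_count # input + output tokens

# Streaming: each chunk the block receives is an
# GenerativeText::AWS::InvokeModelStreamResponse (see EventStreamHandler).
client.invoke_model_stream(generate_text_request) do |chunk|
  print chunk.content
  Rails.logger.info("tokens: #{chunk.token_count}") if chunk.final_chunk?
end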