Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Google GenerativeAI SDK updated to latest version #163

Merged
merged 4 commits into from
Feb 16, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
<PackageVersion Include="FluentAssertions" Version="8.0.1" />
<PackageVersion Include="GitHubActionsTestLogger" Version="2.4.1" />
<PackageVersion Include="Google.Cloud.AIPlatform.V1" Version="3.9.0" />
<PackageVersion Include="Google_GenerativeAI" Version="1.0.2" />
<PackageVersion Include="Google_GenerativeAI" Version="2.0.4" />
<PackageVersion Include="GroqSharp" Version="1.1.2" />
<PackageVersion Include="H.Generators.Extensions" Version="1.22.0" />
<PackageVersion Include="H.Generators.Tests.Extensions" Version="1.22.0" />
Expand Down
66 changes: 41 additions & 25 deletions src/Google/src/Extensions/GoogleGeminiExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,47 +1,63 @@
using System.Text.Json;
using System.Text.Json.Serialization;
using CSharpToJsonSchema;
using GenerativeAI.Tools;
using GenerativeAI;
using GenerativeAI.Types;
using Tool = CSharpToJsonSchema.Tool;

namespace LangChain.Providers.Google.Extensions;

internal static class GoogleGeminiExtensions
{
public static bool IsFunctionCall(this EnhancedGenerateContentResponse response)
public static bool IsFunctionCall(this GenerateContentResponse response)
{
return response.GetFunction() != null;
}

public static List<GenerativeAITool> ToGenerativeAiTools(this IEnumerable<Tool> functions)
public static List<GenerativeAI.Types.Tool?> ToGenerativeAiTools(this IEnumerable<Tool> functions)
{
return new List<GenerativeAITool>([
new GenerativeAITool
var declarations = functions
.Where(x => x != null)
.Select(x => new FunctionDeclaration
{
FunctionDeclaration = functions.Select(x => new ChatCompletionFunction
Name = x.Name ?? string.Empty,
Description = x.Description ?? string.Empty,
Parameters = x.Parameters is OpenApiSchema schema ? ToFunctionParameters(schema) : null,
})
.ToList();

if (declarations.Any())
{
return new List<GenerativeAI.Types.Tool?>
{
new GenerativeAI.Types.Tool
{
Name = x.Name ?? string.Empty,
Description = x.Description ?? string.Empty,
Parameters = ToFunctionParameters((OpenApiSchema)x.Parameters!),
}).ToList(),
}
]);
FunctionDeclarations = declarations
}
};
}

return null;
}

public static string GetStringForFunctionArgs(this object? arguments)
{
if (arguments == null)
return string.Empty;
if (arguments is JsonElement jsonElement)
return jsonElement.ToString();
else
{
return null;
}
}
public static ChatCompletionFunctionParameters ToFunctionParameters(this OpenApiSchema schema)

public static Schema? ToFunctionParameters(this OpenApiSchema openApiSchema)
{
if (schema.Items == null) return new ChatCompletionFunctionParameters();
var parameters = new ChatCompletionFunctionParameters();

parameters.AdditionalProperties.Add("type", schema.Items.Type);
if (schema.Items.Description != null && !string.IsNullOrEmpty(schema.Items.Description))
parameters.AdditionalProperties.Add("description", schema.Items.Description);
if (schema.Items.Properties != null)
parameters.AdditionalProperties.Add("properties", schema.Items.Properties);
if (schema.Items.Required != null)
parameters.AdditionalProperties.Add("required", schema.Items.Required);

return parameters;
var text = JsonSerializer.Serialize(openApiSchema);
return JsonSerializer.Deserialize<Schema?>(text);
}

public static string GetString(this IDictionary<string, object>? arguments)
{
if (arguments == null)
Expand Down
24 changes: 18 additions & 6 deletions src/Google/src/Extensions/StringExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using System.Text.Json;
using System.Text.Json.Nodes;
using GenerativeAI.Extensions;
using GenerativeAI.Tools;
using GenerativeAI;
using GenerativeAI.Types;

namespace LangChain.Providers.Google.Extensions;
Expand Down Expand Up @@ -56,10 +55,9 @@ public static Content AsFunctionCallContent(this string args, string functionNam
var content = new Content([
new Part
{
FunctionCall = new ChatFunctionCall
FunctionCall = new FunctionCall()
{
Arguments = JsonSerializer.Deserialize(args, SourceGenerationContext.Default.DictionaryStringString)?
.ToDictionary(x => x.Key, x => (object)x.Value) ?? [],
Args = JsonNode.Parse(args),
Name = functionName
}
}
Expand All @@ -76,6 +74,20 @@ public static Content AsFunctionCallContent(this string args, string functionNam
[CLSCompliant(false)]
public static Content AsFunctionResultContent(this string args, string functionName)
{
return JsonNode.Parse(args).ToFunctionCallContent(functionName);
var functionResponse = new FunctionResponse()
{
Response = new
{
Name = functionName,
Content = JsonNode.Parse(args)
},
Name = functionName
};
var content = new Content(){Role = Roles.Function};
content.AddPart(new Part()
{
FunctionResponse = functionResponse
});
return content;
}
}
2 changes: 1 addition & 1 deletion src/Google/src/GoogleChatModel.Tokens.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ public async Task<int> CountTokens(string text)

public async Task<int> CountTokens(IEnumerable<Message> messages)
{
var response = await this.Api.CountTokens(new CountTokensRequest() { Contents = messages.Select(ToRequestMessage).ToArray() }).ConfigureAwait(false);
var response = await this.Api.CountTokensAsync(new CountTokensRequest() { Contents = messages.Select(ToRequestMessage).ToList() }).ConfigureAwait(false);

return response.TotalTokens;
}
Expand Down
44 changes: 29 additions & 15 deletions src/Google/src/GoogleChatModel.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
using System.Diagnostics;
using System.Runtime.CompilerServices;
using GenerativeAI.Models;
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using GenerativeAI;
using GenerativeAI.Core;
using GenerativeAI.Types;
using LangChain.Providers.Google.Extensions;

Expand All @@ -27,11 +31,14 @@ public partial class GoogleChatModel(
private GenerativeModel Api { get; } = new(
provider.ApiKey,
id,
provider.HttpClient)
httpClient:provider.HttpClient)
{
AutoCallFunction = false,
AutoReplyFunction = false,
AutoHandleBadFunctionCalls = false
FunctionCallingBehaviour = new FunctionCallingBehaviour()
{
AutoCallFunction = false,
AutoReplyFunction = false,
AutoHandleBadFunctionCalls = false
}
};

#endregion
Expand All @@ -52,13 +59,13 @@ private static Content ToRequestMessage(Message message)
};
}

private static Message ToMessage(EnhancedGenerateContentResponse message)
private static Message ToMessage(GenerateContentResponse message)
{
if (message.GetFunction() != null)
{
var function = message.GetFunction();

return new Message(function?.Arguments.GetString() ?? string.Empty,
return new Message( function?.Args.GetStringForFunctionArgs() ?? string.Empty,
MessageRole.ToolCall, function?.Name);
}

Expand All @@ -67,13 +74,13 @@ private static Message ToMessage(EnhancedGenerateContentResponse message)
MessageRole.Ai);
}

private async Task<EnhancedGenerateContentResponse> CreateChatCompletionAsync(
private async Task<GenerateContentResponse> CreateChatCompletionAsync(
IReadOnlyCollection<Message> messages,
CancellationToken cancellationToken = default)
{
var request = new GenerateContentRequest
{
Contents = messages.Select(ToRequestMessage).ToArray(),
Contents = messages.Select(ToRequestMessage).ToList(),
Tools = GlobalTools.ToGenerativeAiTools()
};

Expand All @@ -94,7 +101,7 @@ private async Task<Message> StreamCompletionAsync(IReadOnlyCollection<Message> m
{
var request = new GenerateContentRequest
{
Contents = messages.Select(ToRequestMessage).ToArray()
Contents = messages.Select(ToRequestMessage).ToList()
};
if (provider.Configuration != null)
request.GenerationConfig = new GenerationConfig
Expand All @@ -104,11 +111,18 @@ private async Task<Message> StreamCompletionAsync(IReadOnlyCollection<Message> m
TopP = provider.Configuration.TopP,
Temperature = provider.Configuration.Temperature
};
var res = await Api.StreamContentAsync(request, OnDeltaReceived, cancellationToken)
.ConfigureAwait(false);
StringBuilder sb = new StringBuilder();
await foreach (var response in Api.StreamContentAsync(request, cancellationToken))
{
var text = response.Text() ?? string.Empty;

sb.Append(text);
OnDeltaReceived(text);
}


return new Message(
res,
sb.ToString(),
MessageRole.Ai);
}

Expand Down Expand Up @@ -173,7 +187,7 @@ public override async IAsyncEnumerable<ChatResponse> GenerateAsync(

if (Calls.TryGetValue(name, out var func))
{
var args = function?.Arguments.GetString() ?? string.Empty;
var args = function?.Args.GetStringForFunctionArgs() ?? string.Empty;

var jsonResult = await func(args, cancellationToken).ConfigureAwait(false);
messages.Add(jsonResult.AsToolResultMessage(name));
Expand Down Expand Up @@ -216,7 +230,7 @@ public override async IAsyncEnumerable<ChatResponse> GenerateAsync(

yield return chatResponse;
}
private Usage GetUsage(EnhancedGenerateContentResponse response)
private Usage GetUsage(GenerateContentResponse response)
{
var outputTokens = response.UsageMetadata?.CandidatesTokenCount ?? 0;
var inputTokens = response.UsageMetadata?.PromptTokenCount ?? 0;
Expand Down
12 changes: 9 additions & 3 deletions src/Google/src/GoogleConfiguration.cs
Original file line number Diff line number Diff line change
@@ -1,32 +1,38 @@
namespace LangChain.Providers.Google;

/// <summary>
/// Configuration options for the Google AI provider.
/// </summary>
public class GoogleConfiguration
{
/// <summary>
/// Gets or sets the API key used for authentication with the Google AI service.
/// </summary>
public string? ApiKey { get; set; }

/// <summary>
/// ID of the model to use. <br />
/// Gets or sets the ID of the model to use. The default value is "gemini-1.5-flash".
/// </summary>
public string? ModelId { get; set; } = "gemini-pro";
public string? ModelId { get; set; } = "gemini-1.5-flash";

/// <summary>
/// Gets or sets the Top-K sampling value, which determines the number of highest-probability tokens considered during decoding.
/// </summary>
public int? TopK { get; set; } = default!;

/// <summary>
/// Gets or sets the Top-P sampling value, which determines the cumulative probability threshold for token selection during decoding.
/// </summary>
public double? TopP { get; set; } = default!;

/// <summary>
/// Gets or sets the temperature value, which controls the randomness of the output.
/// Higher values produce more random results, while lower values make the output more deterministic. The default is 1.0.
/// </summary>
public double? Temperature { get; set; } = 1D;

/// <summary>
/// Maximum Output Tokens
/// Gets or sets the maximum number of output tokens allowed in the response.
/// </summary>
public int? MaxOutputTokens { get; set; } = default!;
}
67 changes: 67 additions & 0 deletions src/Google/src/GoogleEmbeddingModel.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
using GenerativeAI;
using GenerativeAI.Exceptions;
using GenerativeAI.Types;
using LangChain.Providers.OpenAI;

Check failure on line 4 in src/Google/src/GoogleEmbeddingModel.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

The type or namespace name 'OpenAI' does not exist in the namespace 'LangChain.Providers' (are you missing an assembly reference?)

Check failure on line 4 in src/Google/src/GoogleEmbeddingModel.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

The type or namespace name 'OpenAI' does not exist in the namespace 'LangChain.Providers' (are you missing an assembly reference?)

Check failure on line 4 in src/Google/src/GoogleEmbeddingModel.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

The type or namespace name 'OpenAI' does not exist in the namespace 'LangChain.Providers' (are you missing an assembly reference?)

Check failure on line 4 in src/Google/src/GoogleEmbeddingModel.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

The type or namespace name 'OpenAI' does not exist in the namespace 'LangChain.Providers' (are you missing an assembly reference?)
using tryAGI.OpenAI;

Check failure on line 5 in src/Google/src/GoogleEmbeddingModel.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

The type or namespace name 'tryAGI' could not be found (are you missing a using directive or an assembly reference?)

Check failure on line 5 in src/Google/src/GoogleEmbeddingModel.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

The type or namespace name 'tryAGI' could not be found (are you missing a using directive or an assembly reference?)

Check failure on line 5 in src/Google/src/GoogleEmbeddingModel.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

The type or namespace name 'tryAGI' could not be found (are you missing a using directive or an assembly reference?)

namespace LangChain.Providers.Google;

using System.Diagnostics;

public class GoogleEmbeddingModel(
GoogleProvider provider,
string id)
: Model<EmbeddingSettings>(id), IEmbeddingModel
{
public GoogleEmbeddingModel(
GoogleProvider provider,
CreateEmbeddingRequestModel id)

Check failure on line 18 in src/Google/src/GoogleEmbeddingModel.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

The type or namespace name 'CreateEmbeddingRequestModel' could not be found (are you missing a using directive or an assembly reference?)

Check failure on line 18 in src/Google/src/GoogleEmbeddingModel.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

The type or namespace name 'CreateEmbeddingRequestModel' could not be found (are you missing a using directive or an assembly reference?)

Check failure on line 18 in src/Google/src/GoogleEmbeddingModel.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

The type or namespace name 'CreateEmbeddingRequestModel' could not be found (are you missing a using directive or an assembly reference?)
: this(provider, id.ToValueString())
{
}

public EmbeddingModel EmbeddingModel { get; } =

Check warning on line 23 in src/Google/src/GoogleEmbeddingModel.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Type of 'GoogleEmbeddingModel.EmbeddingModel' is not CLS-compliant

Check warning on line 23 in src/Google/src/GoogleEmbeddingModel.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Type of 'GoogleEmbeddingModel.EmbeddingModel' is not CLS-compliant

Check warning on line 23 in src/Google/src/GoogleEmbeddingModel.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Type of 'GoogleEmbeddingModel.EmbeddingModel' is not CLS-compliant
new EmbeddingModel(provider.ApiKey, id, httpClient: provider.HttpClient);


public async Task<EmbeddingResponse> CreateEmbeddingsAsync(EmbeddingRequest request,
EmbeddingSettings? settings = null,
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add input validation for request strings.

While the request itself is validated, there's no validation for request.Strings. Add checks for null or empty strings to prevent potential issues.

         request = request ?? throw new ArgumentNullException(nameof(request));
+        if (request.Strings == null || !request.Strings.Any())
+        {
+            throw new ArgumentException("Request must contain at least one string to embed.", nameof(request));
+        }
+        if (request.Strings.Any(string.IsNullOrEmpty))
+        {
+            throw new ArgumentException("Request strings cannot be null or empty.", nameof(request));
+        }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
request = request ?? throw new ArgumentNullException(nameof(request));
request = request ?? throw new ArgumentNullException(nameof(request));
if (request.Strings == null || !request.Strings.Any())
{
throw new ArgumentException("Request must contain at least one string to embed.", nameof(request));
}
if (request.Strings.Any(string.IsNullOrEmpty))
{
throw new ArgumentException("Request strings cannot be null or empty.", nameof(request));
}


var watch = Stopwatch.StartNew();

var usedSettings = GoogleEmbeddingSettings.Calculate(
requestSettings: settings,
modelSettings: Settings,
providerSettings: provider.EmbeddingSettings);

var embedRequest = new EmbedContentRequest();
embedRequest.Content = new Content();
embedRequest.Content.AddParts(request.Strings.Select(s => new Part(s)));

embedRequest.OutputDimensionality = usedSettings.OutputDimensionality;
var embedResponse = await this.EmbeddingModel.EmbedContentAsync(embedRequest, cancellationToken)
.ConfigureAwait(false);

var usage = Usage.Empty with
{
Time = watch.Elapsed,
};
AddUsage(usage);
provider.AddUsage(usage);

if (embedResponse.Embedding == null)
throw new GenerativeAIException("Failed to create embeddings.", "");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Provide meaningful error message.

The error message for embedding creation failure is empty. Include details about what might have caused the failure to help with debugging.

-            throw new GenerativeAIException("Failed to create embeddings.", "");
+            throw new GenerativeAIException(
+                "Failed to create embeddings. The embedding response was null.",
+                "Ensure the model is properly configured and the input is valid.");
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
throw new GenerativeAIException("Failed to create embeddings.", "");
throw new GenerativeAIException(
"Failed to create embeddings. The embedding response was null.",
"Ensure the model is properly configured and the input is valid.");

var values = embedResponse.Embedding.Values.ToList();

return new EmbeddingResponse
{
Values = new[] { values.ToArray() }.ToArray(),
Usage = Usage.Empty,
UsedSettings = usedSettings,
Dimensions = values.Count,
};
}
}
Loading
Loading