Skip to content

Commit

Permalink
update to Spring AI 1.0.0-M5
Browse files Browse the repository at this point in the history
  • Loading branch information
alexcheng1982 committed Dec 26, 2024
1 parent c09afe2 commit cf2004f
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ open class LLMPlanExecutor(
spec.text(userInput).param("agent_scratchpad", thoughts)
}
.call().content()
if (response.isEmpty()) {
if (response?.isEmpty() != false) {
return ActionPlanningResult.finish(
AgentFinish.fromOutput(
"No response from LLM",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,21 +65,20 @@ class ReActJsonPromptAdvisor : CallAroundAdvisor {
advisedRequest: AdvisedRequest,
chain: CallAroundAdvisorChain
): AdvisedResponse {
val systemParams = HashMap(advisedRequest.systemParams ?: mapOf())
val systemParams = HashMap(advisedRequest.systemParams)
systemParams["system_instruction"] = advisedRequest.systemText ?: ""
val userParams = HashMap(advisedRequest.userParams ?: mapOf())
userParams["user_input"] = advisedRequest.userText ?: ""
val userParams = HashMap(advisedRequest.userParams)
userParams["user_input"] = advisedRequest.userText
val chatOptions = ChatOptionsHelper.buildChatOptions(
advisedRequest.chatOptions,
ChatOptionsConfigurer.ChatOptionsConfig(listOf("Observation:"))
)
val request = AdvisedRequest.from(advisedRequest)
.withSystemText(defaultSystemTextTemplate)
.withSystemParams(systemParams)
.withUserText(defaultUserTextTemplate)
.withUserParams(userParams)
.withChatOptions(chatOptions)
.build()
return chain.nextAroundCall(request)
val builder = AdvisedRequest.from(advisedRequest)
.systemText(defaultSystemTextTemplate)
.systemParams(systemParams)
.userText(defaultUserTextTemplate)
.userParams(userParams)
chatOptions?.let { builder.chatOptions(it) }
return chain.nextAroundCall(builder.build())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,20 @@ class ReActPromptAdvisor : CallAroundAdvisor {
advisedRequest: AdvisedRequest,
chain: CallAroundAdvisorChain
): AdvisedResponse {
val systemParams = HashMap(advisedRequest.systemParams ?: mapOf())
val systemParams = HashMap(advisedRequest.systemParams)
systemParams["system_instruction"] = advisedRequest.systemText ?: ""
val userParams = HashMap(advisedRequest.userParams ?: mapOf())
userParams["user_input"] = advisedRequest.userText ?: ""
val userParams = HashMap(advisedRequest.userParams)
userParams["user_input"] = advisedRequest.userText
val chatOptions = ChatOptionsHelper.buildChatOptions(
advisedRequest.chatOptions,
ChatOptionsConfigurer.ChatOptionsConfig(listOf("\\nObservation"))
)
val request = AdvisedRequest.from(advisedRequest)
.withSystemText(defaultSystemTextTemplate)
.withSystemParams(systemParams)
.withUserText(defaultUserTextTemplate)
.withUserParams(userParams)
.withChatOptions(chatOptions)
.build()
return chain.nextAroundCall(request)
val builder = AdvisedRequest.from(advisedRequest)
.systemText(defaultSystemTextTemplate)
.systemParams(systemParams)
.userText(defaultUserTextTemplate)
.userParams(userParams)
chatOptions?.let { builder.chatOptions(it) }
return chain.nextAroundCall(builder.build())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,15 @@ class StructuredChatPromptAdvisor : CallAroundAdvisor {
advisedRequest: AdvisedRequest,
chain: CallAroundAdvisorChain
): AdvisedResponse {
val systemParams = HashMap(advisedRequest.systemParams ?: mapOf())
val systemParams = HashMap(advisedRequest.systemParams)
systemParams["system_instruction"] = advisedRequest.systemText ?: ""
val userParams = HashMap(advisedRequest.userParams ?: mapOf())
userParams["user_input"] = advisedRequest.userText ?: ""
val userParams = HashMap(advisedRequest.userParams)
userParams["user_input"] = advisedRequest.userText
val request = AdvisedRequest.from(advisedRequest)
.withSystemText(defaultSystemTextTemplate)
.withSystemParams(systemParams)
.withUserText(defaultUserTextTemplate)
.withUserParams(userParams)
.systemText(defaultSystemTextTemplate)
.systemParams(systemParams)
.userText(defaultUserTextTemplate)
.userParams(userParams)
.build()
return chain.nextAroundCall(request)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package io.github.llmagentbuilder.core

import org.springframework.ai.chat.model.ChatModel
import org.springframework.ai.model.function.FunctionCallbackContext
import org.springframework.ai.model.function.FunctionCallbackResolver

interface ChatModelProvider {
fun configKey(): String

fun provideChatModel(
functionCallbackContext: FunctionCallbackContext,
functionCallbackResolver: FunctionCallbackResolver,
config: Map<String, Any?>? = null,
): ChatModel?
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,31 @@ package io.github.llmagentbuilder.core.tool

import io.micrometer.observation.ObservationRegistry
import org.slf4j.LoggerFactory
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver
import org.springframework.ai.model.function.FunctionCallback
import org.springframework.ai.model.function.FunctionCallbackContext

class AgentToolFunctionCallbackContext(
agentToolsProvider: AgentToolsProvider,
observationRegistry: ObservationRegistry? = null,
) :
FunctionCallbackContext() {
DefaultFunctionCallbackResolver() {
private val agentToolWrappersProvider =
AgentToolWrappersProvider(agentToolsProvider, observationRegistry)

private val logger = LoggerFactory.getLogger(javaClass)

override fun getFunctionCallback(
beanName: String,
defaultDescription: String?
): FunctionCallback {
override fun resolve(name: String): FunctionCallback {
try {
return super.getFunctionCallback(beanName, defaultDescription)
return super.resolve(name)
} catch (e: Exception) {
if (logger.isDebugEnabled) {
logger.debug(
"Failed to get bean {} from application context, ignoring",
beanName
name
)
}
}
return agentToolWrappersProvider.get()[beanName]
?: throw IllegalArgumentException("Function $beanName not found")
return agentToolWrappersProvider.get()[name]
?: throw IllegalArgumentException("Function $name not found")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@ import io.github.llmagentbuilder.core.observation.AgentToolExecutionObservationD
import io.github.llmagentbuilder.core.observation.DefaultAgentToolExecutionObservationConvention
import io.micrometer.observation.ObservationRegistry
import org.slf4j.LoggerFactory
import org.springframework.ai.model.function.DefaultFunctionCallbackBuilder
import org.springframework.ai.model.function.FunctionCallback
import org.springframework.ai.model.function.FunctionCallbackContext
import org.springframework.ai.model.function.FunctionCallbackWrapper
import org.springframework.core.GenericTypeResolver
import java.util.*
import java.util.function.Supplier
Expand All @@ -34,39 +33,38 @@ class AgentToolWrappersProvider(
AgentTool::class.java
)
InstrumentedFunctionCallbackWrapper(
FunctionCallbackWrapper.builder(tool)
.withName(tool.name())
.withSchemaType(FunctionCallbackContext.SchemaType.JSON_SCHEMA)
.withDescription(tool.description())
.withInputType(
DefaultFunctionCallbackBuilder().function(tool.name(), tool)
.schemaType(FunctionCallback.SchemaType.JSON_SCHEMA)
.description(tool.description())
.inputType(
types?.get(0)
?: throw IllegalArgumentException("Bad type")
)
.withObjectMapper(objectMapper)
.objectMapper(objectMapper)
.build(), observationRegistry
)
}
}

private class InstrumentedFunctionCallbackWrapper<I, O>(
private val functionCallbackWrapper: FunctionCallbackWrapper<I, O>,
private class InstrumentedFunctionCallbackWrapper(
private val functionCallback: FunctionCallback,
private val observationRegistry: ObservationRegistry? = null
) :
FunctionCallback {
override fun getName(): String {
return functionCallbackWrapper.name
return functionCallback.name
}

override fun getDescription(): String {
return functionCallbackWrapper.description
return functionCallback.description
}

override fun getInputTypeSchema(): String {
return functionCallbackWrapper.inputTypeSchema
return functionCallback.inputTypeSchema
}

override fun call(functionInput: String): String {
val action = { functionCallbackWrapper.call(functionInput) }
val action = { functionCallback.call(functionInput) }
return observationRegistry?.let { registry ->
instrumentedCall(functionInput, action, registry)
} ?: action.invoke()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import io.github.llmagentbuilder.core.ChatModelProvider
import io.github.llmagentbuilder.core.MapToObject
import org.apache.commons.lang3.StringUtils
import org.springframework.ai.chat.model.ChatModel
import org.springframework.ai.model.function.FunctionCallbackContext
import org.springframework.ai.model.function.FunctionCallbackResolver
import org.springframework.ai.openai.OpenAiChatModel
import org.springframework.ai.openai.OpenAiChatOptions
import org.springframework.ai.openai.api.OpenAiApi
Expand All @@ -16,7 +16,7 @@ class OpenAiChatModelProvider : ChatModelProvider {
}

override fun provideChatModel(
functionCallbackContext: FunctionCallbackContext,
functionCallbackResolver: FunctionCallbackResolver,
config: Map<String, Any?>?,
): ChatModel? {
val openAiConfig = MapToObject.toObject<OpenAiConfig>(config)
Expand All @@ -33,8 +33,8 @@ class OpenAiChatModelProvider : ChatModelProvider {
openAiConfig?.model ?: OpenAiApi.ChatModel.GPT_3_5_TURBO.value
val chatModel = OpenAiChatModel(
OpenAiApi(apiKey),
OpenAiChatOptions.builder().withModel(model).build(),
functionCallbackContext,
OpenAiChatOptions.builder().model(model).build(),
functionCallbackResolver,
RetryTemplate.defaultInstance()
)
return chatModel
Expand Down
4 changes: 2 additions & 2 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@
<kotlin.code.style>official</kotlin.code.style>
<kotlin.compiler.jvmTarget>21</kotlin.compiler.jvmTarget>
<java.version>21</java.version>
<spring-ai.version>1.0.0-SNAPSHOT</spring-ai.version>
<kotlin.version>1.9.23</kotlin.version>
<spring-ai.version>1.0.0-M5</spring-ai.version>
<kotlin.version>2.1.0</kotlin.version>
<spring-boot.version>3.2.4</spring-boot.version>
<spring.version>6.1.4</spring.version>
<jackson.version>2.16.1</jackson.version>
Expand Down

0 comments on commit cf2004f

Please sign in to comment.