-
-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathOpenAI.kt
executable file
·334 lines (298 loc) · 12.2 KB
/
OpenAI.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
package com.cjcrafter.openai
import com.cjcrafter.openai.assistants.AssistantHandler
import com.cjcrafter.openai.assistants.AssistantHandlerImpl
import com.cjcrafter.openai.chat.*
import com.cjcrafter.openai.chat.tool.ToolChoice
import com.cjcrafter.openai.completions.CompletionRequest
import com.cjcrafter.openai.completions.CompletionResponse
import com.cjcrafter.openai.completions.CompletionResponseChunk
import com.cjcrafter.openai.embeddings.EmbeddingsRequest
import com.cjcrafter.openai.embeddings.EmbeddingsResponse
import com.cjcrafter.openai.files.*
import com.cjcrafter.openai.moderations.ModerationHandler
import com.cjcrafter.openai.threads.ThreadHandler
import com.cjcrafter.openai.threads.message.TextAnnotation
import com.cjcrafter.openai.util.OpenAIDslMarker
import com.fasterxml.jackson.annotation.JsonAutoDetect
import com.fasterxml.jackson.annotation.JsonInclude
import com.fasterxml.jackson.databind.DeserializationFeature
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.module.SimpleModule
import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
import okhttp3.OkHttpClient
import org.jetbrains.annotations.ApiStatus
import org.jetbrains.annotations.Contract
import org.slf4j.LoggerFactory
/**
 * The main interface for the OpenAI API. This interface contains methods for
 * all the API endpoints. To instantiate an instance of this interface, use
 * [builder] (or [azureBuilder] for Azure's OpenAI service).
 *
 * All the methods in this class are blocking, except the stream methods
 * ([streamCompletion] and [streamChatCompletion]). Those return an
 * [Iterable] whose iteration blocks the calling thread until the next chunk
 * arrives.
 *
 * The methods in this class all throw IO exceptions if the request fails. The
 * error message will contain the JSON response from the API (if present).
 * Common errors include:
 * 1. Not having a valid API key
 * 2. Passing a bad parameter to a request
 */
interface OpenAI {

    /**
     * Calls the [completions](https://platform.openai.com/docs/api-reference/completions)
     * API endpoint. This method is blocking.
     *
     * Completions are considered Legacy, and OpenAI officially recommends that
     * all developers use the **chat completion** endpoint instead. See
     * [createChatCompletion].
     *
     * @param request The request to send to the API
     * @return The response from the API
     */
    @ApiStatus.Obsolete
    @Contract(pure = true)
    fun createCompletion(request: CompletionRequest): CompletionResponse

    /**
     * Calls the [completions](https://platform.openai.com/docs/api-reference/completions)
     * API endpoint and streams each token 1 at a time for a faster response
     * time.
     *
     * This method is **technically** not blocking, but the returned iterable
     * will block until the next token is generated.
     * ```
     * // Each iteration of the loop will block until the next token is streamed
     * for (chunk in openAI.streamCompletion(request)) {
     *     // Do something with the chunk
     * }
     * ```
     *
     * Completions are considered Legacy, and OpenAI officially recommends that
     * all developers use the **chat completion** endpoint instead. See
     * [streamChatCompletion].
     *
     * @param request The request to send to the API
     * @return The response from the API
     */
    @ApiStatus.Obsolete
    @Contract(pure = true)
    fun streamCompletion(request: CompletionRequest): Iterable<CompletionResponseChunk>

    /**
     * Calls the [chat completions](https://platform.openai.com/docs/api-reference/chat)
     * API endpoint. This method is blocking.
     *
     * @param request The request to send to the API
     * @return The response from the API
     */
    @Contract(pure = true)
    fun createChatCompletion(request: ChatRequest): ChatResponse

    /**
     * Calls the [chat completions](https://platform.openai.com/docs/api-reference/chat)
     * API endpoint and streams each token 1 at a time for a faster response.
     *
     * This method is **technically** not blocking, but the returned iterable
     * will block until the next token is generated.
     * ```
     * // Each iteration of the loop will block until the next token is streamed
     * for (chunk in openAI.streamChatCompletion(request)) {
     *     // Do something with the chunk
     * }
     * ```
     *
     * @param request The request to send to the API
     * @return The response from the API
     */
    @Contract(pure = true)
    fun streamChatCompletion(request: ChatRequest): Iterable<ChatResponseChunk>

    /**
     * Calls the [embeddings](https://platform.openai.com/docs/api-reference/embeddings)
     * API endpoint to generate the vector representation of text. The returned
     * vector can be used in Machine Learning models. This method is blocking.
     *
     * @param request The request to send to the API
     * @return The response from the API
     */
    @Contract(pure = true)
    @ApiStatus.Experimental
    fun createEmbeddings(request: EmbeddingsRequest): EmbeddingsResponse

    /**
     * Returns the handler for the files endpoint. This handler can be used to
     * create, retrieve, and delete files.
     */
    val files: FileHandler

    /**
     * Returns the handler for the files endpoint. This method is purely
     * syntactic sugar for Java users.
     */
    @Contract(pure = true)
    fun files(): FileHandler = files

    /**
     * Returns the handler for the moderations endpoint. This handler can be used
     * to create moderations.
     */
    val moderations: ModerationHandler

    /**
     * Returns the handler for the moderations endpoint. This method is purely
     * syntactic sugar for Java users.
     */
    @Contract(pure = true)
    fun moderations(): ModerationHandler = moderations

    /**
     * Returns the handler for the assistants endpoint. This handler can be used
     * to create, retrieve, and delete assistants.
     */
    @get:ApiStatus.Experimental
    val assistants: AssistantHandler

    /**
     * Returns the handler for the assistants endpoint. This method is purely
     * syntactic sugar for Java users.
     */
    @ApiStatus.Experimental
    @Contract(pure = true)
    fun assistants(): AssistantHandler = assistants

    /**
     * Returns the handler for the threads endpoint. This handler can be used
     * to create, retrieve, and delete threads.
     */
    @get:ApiStatus.Experimental
    val threads: ThreadHandler

    /**
     * Returns the handler for the threads endpoint. This method is purely
     * syntactic sugar for Java users.
     */
    @ApiStatus.Experimental
    @Contract(pure = true)
    fun threads(): ThreadHandler = threads

    /**
     * Constructs a default [OpenAI] instance. Obtain one through [OpenAI.builder].
     *
     * The fluent setters are `open` so that subclasses (such as [AzureBuilder])
     * can narrow their return type. This keeps a call chain typed as the
     * subclass, so subclass-only setters remain reachable after any inherited
     * setter.
     */
    @OpenAIDslMarker
    open class Builder internal constructor() {
        protected var apiKey: String? = null
        protected var organization: String? = null
        protected var client: OkHttpClient = OkHttpClient()
        protected var baseUrl: String = "https://api.openai.com"

        /**
         * Sets the API key to use for requests. This is required.
         *
         * Your API key can be found at: [https://platform.openai.com/api-keys](https://platform.openai.com/api-keys).
         *
         * @param apiKey The API key to use for requests, starting with `sk-`
         */
        open fun apiKey(apiKey: String): Builder = apply { this.apiKey = apiKey }

        /**
         * If you belong to multiple organizations, you can specify which one to use.
         * Defaults to your default organization configured in the OpenAI dashboard.
         *
         * @param organization The organization ID to use for requests, starting with `org-`
         */
        open fun organization(organization: String?): Builder = apply { this.organization = organization }

        /**
         * Sets the [OkHttpClient] used to make requests. Modify this if you want to
         * change the timeout, add interceptors, add a proxy, etc.
         *
         * @param client The client to use for requests
         */
        open fun client(client: OkHttpClient): Builder = apply { this.client = client }

        /**
         * Sets the base URL to use for requests. This is useful for testing.
         * This can also be used to use the Azure OpenAI API, though we
         * recommend using [azureBuilder] instead for that. Defaults to
         * `https://api.openai.com`.
         *
         * @param baseUrl The base url
         */
        open fun baseUrl(baseUrl: String): Builder = apply { this.baseUrl = baseUrl }

        /**
         * Builds the OpenAI instance.
         *
         * @throws IllegalStateException if no API key was set
         */
        @Contract(pure = true)
        open fun build(): OpenAI {
            return OpenAIImpl(
                apiKey = checkNotNull(apiKey) { "apiKey must be defined to use OpenAI" },
                organization = organization,
                client = client,
                baseUrl = baseUrl,
            )
        }
    }

    /**
     * Builder for an [OpenAI] instance backed by Azure's OpenAI service.
     * Obtain one through [OpenAI.azureBuilder].
     *
     * In addition to an API key, Azure requires the base URL to be changed from
     * its `https://api.openai.com` default to your Azure endpoint. It also
     * requires an API version and a model name.
     */
    @OpenAIDslMarker
    class AzureBuilder internal constructor() : Builder() {
        private var apiVersion: String? = null
        private var modelName: String? = null

        // The inherited setters are re-declared with an AzureBuilder return type.
        // The base versions return Builder, which has no apiVersion/modelName, so
        // without these overrides `azureBuilder().apiKey(key).apiVersion(v)` would
        // not compile (in Kotlin or Java).

        override fun apiKey(apiKey: String): AzureBuilder {
            super.apiKey(apiKey)
            return this
        }

        override fun organization(organization: String?): AzureBuilder {
            super.organization(organization)
            return this
        }

        override fun client(client: OkHttpClient): AzureBuilder {
            super.client(client)
            return this
        }

        override fun baseUrl(baseUrl: String): AzureBuilder {
            super.baseUrl(baseUrl)
            return this
        }

        /**
         * Sets the azure api version. This is required.
         */
        fun apiVersion(apiVersion: String) = apply { this.apiVersion = apiVersion }

        /**
         * Sets the azure model name. This is required.
         */
        fun modelName(modelName: String) = apply { this.modelName = modelName }

        /**
         * Builds the OpenAI instance.
         *
         * @throws IllegalStateException if the API key, API version, or model name
         * is missing, or if the base URL was left at `https://api.openai.com`
         */
        @Contract(pure = true)
        override fun build(): OpenAI {
            return AzureOpenAI(
                apiKey = checkNotNull(apiKey) { "apiKey must be defined to use OpenAI" },
                organization = organization,
                client = client,
                baseUrl = baseUrl.also { url ->
                    check(url != "https://api.openai.com") { "baseUrl must be set to an azure endpoint" }
                },
                apiVersion = checkNotNull(apiVersion) { "apiVersion must be defined for azure" },
                modelName = checkNotNull(modelName) { "modelName must be defined for azure" },
            )
        }
    }

    companion object {
        internal val logger = LoggerFactory.getLogger(OpenAI::class.java)

        /**
         * Instantiates a builder for a default OpenAI instance. For Azure's
         * OpenAI, use [azureBuilder] instead.
         */
        @JvmStatic
        @Contract(pure = true)
        fun builder() = Builder()

        /**
         * Instantiates a builder for an Azure OpenAI.
         */
        @JvmStatic
        @Contract(pure = true)
        fun azureBuilder() = AzureBuilder()

        /**
         * Returns an ObjectMapper instance with the default OpenAI adapters registered.
         * This can be used to save conversations (and other data) to file.
         *
         * The mapper omits null values when serializing and ignores unknown
         * properties when deserializing. It also serializes fields only
         * (getters/setters/creators are not auto-detected).
         *
         * Marked [JvmStatic] (like [builder]) so Java callers can write
         * `OpenAI.createObjectMapper()` instead of going through `Companion`.
         */
        @JvmStatic
        @Contract(pure = true)
        fun createObjectMapper(): ObjectMapper = jacksonObjectMapper().apply {
            setSerializationInclusion(JsonInclude.Include.NON_NULL)
            configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)

            // By default, Jackson can serialize fields AND getters. We just want fields.
            setVisibility(serializationConfig.getDefaultVisibilityChecker()
                .withFieldVisibility(JsonAutoDetect.Visibility.ANY)
                .withGetterVisibility(JsonAutoDetect.Visibility.NONE)
                .withSetterVisibility(JsonAutoDetect.Visibility.NONE)
                .withCreatorVisibility(JsonAutoDetect.Visibility.NONE)
            )

            // Register modules with custom serializers/deserializers
            val module = SimpleModule().apply {
                addSerializer(ToolChoice::class.java, ToolChoice.serializer())
                addDeserializer(ToolChoice::class.java, ToolChoice.deserializer())
            }
            registerModule(module)
        }

        /**
         * Streams a completion and passes each [CompletionResponseChunk] to
         * [consumer] as it is received.
         *
         * This is **not** a coroutine/suspending API. It iterates the blocking
         * [Iterable] returned by the member `streamCompletion(request)`, so it
         * blocks the calling thread until the stream has been fully consumed.
         */
        fun OpenAI.streamCompletion(request: CompletionRequest, consumer: (CompletionResponseChunk) -> Unit) {
            streamCompletion(request).forEach(consumer)
        }

        /**
         * Streams a chat completion and passes each [ChatResponseChunk] to
         * [consumer] as it is received.
         *
         * This is **not** a coroutine/suspending API. It iterates the blocking
         * [Iterable] returned by the member `streamChatCompletion(request)`, so
         * it blocks the calling thread until the stream has been fully consumed.
         */
        fun OpenAI.streamChatCompletion(request: ChatRequest, consumer: (ChatResponseChunk) -> Unit) {
            streamChatCompletion(request).forEach(consumer)
        }
    }
}