Skip to content

Commit

Permalink
llama: preserve field order in user-defined JSON schemas (ollama#8002)
Browse files Browse the repository at this point in the history
Previously we decoded and re-encoded JSON schemas during validation,
which served no purpose since json.RawMessage already validates JSON
syntax. Worse, the re-encoding lost field ordering from the original
schema, which affects inference quality during step-by-step reasoning.

While fixing this ordering issue by using json.RawMessage directly,
testing revealed that schema_to_grammar (from llama.cpp) also fails to
preserve field order during grammar generation. This appears to be the
root cause of inference degradation.

This change prevents us from mangling the user's original schema order,
but we still need to address the ordering issue in schema_to_grammar.
That will be a separate change.

Updates ollama#7978
  • Loading branch information
bmizerany authored Dec 11, 2024
1 parent 581a4a5 commit 9039c82
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 114 deletions.
80 changes: 80 additions & 0 deletions llama/grammar_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package llama

import (
"bufio"
"bytes"
"strings"
"testing"
)

// issue7978JSONSchema is the JSON schema fixture used by TestIssue7978. Its
// top-level properties are declared with "steps" before "final_answer", and
// each step declares "explanation" before "output".
//
// https://github.com/ollama/ollama/issues/7978
const issue7978JSONSchema = `{
"type": "object",
"properties": {
"steps": {
"type": "array",
"items": {
"type": "object",
"properties": {
"explanation": { "type": "string" },
"output": { "type": "string" }
},
"required": ["explanation", "output"],
"additionalProperties": false
}
},
"final_answer": { "type": "string" }
},
"required": ["steps", "final_answer"],
"additionalProperties": false
}`

// TestIssue7978 checks that the grammar generated from issue7978JSONSchema
// mentions "steps" before it mentions "final-answer", i.e. that the property
// order of the schema is carried over into the grammar.
//
// The test is currently skipped; see the t.Skip message.
func TestIssue7978(t *testing.T) {
	t.Skip("schema_to_grammar is broken; skipping until fixed")

	g := SchemaToGrammar([]byte(issue7978JSONSchema))
	if g == nil {
		t.Fatal("failed to convert JSON schema to grammar")
	}

	t.Logf("grammar:\n%s", g)

	// Walk the grammar line by line and remember whether "steps" has been
	// seen by the time "final-answer" shows up.
	var sawSteps bool
	s := bufio.NewScanner(bytes.NewReader(g))
	for s.Scan() {
		line := s.Text()
		if strings.Contains(line, "steps") {
			sawSteps = true
		}
		if strings.Contains(line, "final-answer") && !sawSteps {
			t.Error("expected 'steps' before 'final-answer'")
		}
	}
	// A scanner error (for example a line longer than the scanner's buffer)
	// ends the loop early; without this check the ordering assertion could
	// pass on a partially read grammar.
	if err := s.Err(); err != nil {
		t.Fatalf("scanning grammar: %v", err)
	}
}

// TestSchemaToGrammer smoke-tests SchemaToGrammar: input that is not valid
// JSON must yield a nil grammar, and a minimal object schema must yield a
// grammar that defines the "object" rule.
//
// The test is currently skipped; see the t.Skip message.
func TestSchemaToGrammer(t *testing.T) {
	t.Skip("schema_to_grammar is broken; skipping until fixed")

	cases := []struct {
		schema string
		want   []byte // rule text that must appear in the grammar; nil means the grammar must be nil
	}{
		{`invalid`, nil},

		// Simple heuristic/smoke test
		{`{"type":"object"}`, []byte("object ::=")},
	}

	for _, c := range cases {
		// Name each subtest after its schema so a failure identifies the case.
		t.Run(c.schema, func(t *testing.T) {
			g := SchemaToGrammar([]byte(c.schema))
			if c.want == nil {
				if g != nil {
					t.Fatalf("grammar = %q, want nil", g)
				}
				return
			}
			// Match anywhere rather than at the start: the grammar may emit
			// other rules (such as "array ::=") before the object rule.
			if !bytes.Contains(g, c.want) {
				t.Errorf("grammar = %q, want it to contain %q", g, c.want)
			}
		})
	}
}
32 changes: 9 additions & 23 deletions llama/llama.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,9 @@ COMPILER inline get_compiler() {
import "C"

import (
"bytes"
_ "embed"
"encoding/json"
"errors"
"fmt"
"log/slog"
"runtime"
"runtime/cgo"
"slices"
Expand Down Expand Up @@ -721,32 +718,21 @@ func (s *SamplingContext) Accept(id int, applyGrammar bool) {
C.common_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
}

// JsonSchema holds the JSON Schema keywords that are decoded from a
// user-provided schema: "$defs", "properties", "required", "title" and
// "type". Keywords not listed here are not represented.
//
// Defs and Properties are Go maps, so the key order of the original schema
// document is not retained when a JsonSchema value is re-encoded.
type JsonSchema struct {
Defs map[string]any `json:"$defs,omitempty"`
Properties map[string]any `json:"properties,omitempty"`
Required []string `json:"required,omitempty"`
Title string `json:"title,omitempty"`
Type string `json:"type,omitempty"`
}

func (js JsonSchema) AsGrammar() string {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(js); err != nil {
return ""
}

cStr := C.CString(b.String())
// SchemaToGrammar converts the provided JSON schema to a grammar. It returns
// nil if the provided schema is invalid JSON or an invalid JSON schema.
func SchemaToGrammar(schema []byte) []byte {
cStr := C.CString(string(schema))
defer C.free(unsafe.Pointer(cStr))

// Allocate buffer for grammar output with reasonable size
const maxLen = 32768 // 32KB
buf := make([]byte, maxLen)

// Call C function to convert schema to grammar
length := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
if length == 0 {
slog.Warn("unable to convert schema to grammar")
n := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
if n == 0 {
// preserve nil
return nil
}

return string(buf[:length])
return buf[:n]
}
69 changes: 0 additions & 69 deletions llama/llama_test.go
Original file line number Diff line number Diff line change
@@ -1,70 +1 @@
package llama

import (
"strings"
"testing"

"github.com/google/go-cmp/cmp"
)

// TestJsonSchema converts a handful of JsonSchema values to grammars via
// AsGrammar and compares the result with the expected grammar text. Schemas
// that are expected to be rejected must produce an empty string.
//
// Leading and trailing whitespace is ignored; everything else must match
// exactly, including letter case, because grammar rules are case-sensitive.
func TestJsonSchema(t *testing.T) {
	testCases := []struct {
		name     string
		schema   JsonSchema
		expected string
	}{
		{
			name: "empty schema",
			schema: JsonSchema{
				Type: "object",
			},
			expected: `array ::= "[" space ( value ("," space value)* )? "]" space
boolean ::= ("true" | "false") space
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
null ::= "null" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
root ::= object
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
value ::= object | array | string | number | boolean | null`,
		},
		{
			name: "invalid schema with circular reference",
			schema: JsonSchema{
				Type: "object",
				Properties: map[string]any{
					"self": map[string]any{
						"$ref": "#", // Self reference
					},
				},
			},
			expected: "", // Should return empty string for invalid schema
		},
		{
			name: "schema with invalid type",
			schema: JsonSchema{
				Type: "invalid_type", // Invalid type
				Properties: map[string]any{
					"foo": map[string]any{
						"type": "string",
					},
				},
			},
			expected: "", // Should return empty string for invalid schema
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Trim both sides once so the check and the reported diff are
			// computed on the same text.
			want := strings.TrimSpace(tc.expected)
			got := strings.TrimSpace(tc.schema.AsGrammar())
			if diff := cmp.Diff(want, got); diff != "" {
				t.Fatalf("grammar mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
29 changes: 13 additions & 16 deletions llm/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
}
}

const jsonGrammar = `
var grammarJSON = `
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws
object ::=
Expand Down Expand Up @@ -722,22 +722,19 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("unexpected server status: %s", status.ToString())
}

// TODO (parthsareen): Move conversion to grammar with sampling logic
// API should do error handling for invalid formats
if req.Format != nil && strings.TrimSpace(string(req.Format)) != "null" {
if strings.ToLower(strings.TrimSpace(string(req.Format))) == `"json"` {
request["grammar"] = jsonGrammar
if !strings.Contains(strings.ToLower(req.Prompt), "json") {
slog.Warn("prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
if len(req.Format) > 0 {
switch {
case bytes.Equal(req.Format, []byte(`"json"`)):
request["grammar"] = grammarJSON
case bytes.HasPrefix(req.Format, []byte("{")):
// User provided a JSON schema
g := llama.SchemaToGrammar(req.Format)
if g == nil {
return fmt.Errorf("invalid JSON schema in format")
}
} else if schema, err := func() (llama.JsonSchema, error) {
var schema llama.JsonSchema
err := json.Unmarshal(req.Format, &schema)
return schema, err
}(); err == nil {
request["grammar"] = schema.AsGrammar()
} else {
slog.Warn(`format is neither a schema or "json"`, "format", req.Format)
request["grammar"] = string(g)
default:
return errors.New(`invalid format: expected "json" or a JSON schema`)
}
}

Expand Down
8 changes: 2 additions & 6 deletions openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ type ResponseFormat struct {
}

type JsonSchema struct {
Schema map[string]any `json:"schema"`
Schema json.RawMessage `json:"schema"`
}

type EmbedRequest struct {
Expand Down Expand Up @@ -495,11 +495,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
format = json.RawMessage(`"json"`)
case "json_schema":
if r.ResponseFormat.JsonSchema != nil {
schema, err := json.Marshal(r.ResponseFormat.JsonSchema.Schema)
if err != nil {
return nil, fmt.Errorf("failed to marshal json schema: %w", err)
}
format = schema
format = r.ResponseFormat.JsonSchema.Schema
}
}
}
Expand Down

0 comments on commit 9039c82

Please sign in to comment.