diff --git a/go.mod b/go.mod index 66dcaa76b..848478289 100644 --- a/go.mod +++ b/go.mod @@ -67,7 +67,6 @@ require ( github.com/docker/go-units v0.5.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/gage-technologies/mistral-go v1.0.0 // indirect github.com/getsentry/sentry-go v0.12.0 // indirect github.com/go-logr/logr v1.3.0 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -185,6 +184,7 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.27.4 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.1 github.com/cohere-ai/tokenizer v1.1.2 + github.com/gage-technologies/mistral-go v1.0.0 github.com/go-openapi/strfmt v0.21.3 github.com/go-sql-driver/mysql v1.7.1 github.com/gocolly/colly v1.2.0 diff --git a/textsplitter/options.go b/textsplitter/options.go index 1d0c140d2..288101af5 100644 --- a/textsplitter/options.go +++ b/textsplitter/options.go @@ -7,6 +7,7 @@ type Options struct { ChunkSize int ChunkOverlap int Separators []string + KeepSeparator bool LenFunc func(string) int ModelName string EncodingName string @@ -20,10 +21,11 @@ type Options struct { // DefaultOptions returns the default options for all text splitter. func DefaultOptions() Options { return Options{ - ChunkSize: _defaultTokenChunkSize, - ChunkOverlap: _defaultTokenChunkOverlap, - Separators: []string{"\n\n", "\n", " ", ""}, - LenFunc: utf8.RuneCountInString, + ChunkSize: _defaultTokenChunkSize, + ChunkOverlap: _defaultTokenChunkOverlap, + Separators: []string{"\n\n", "\n", " ", ""}, + KeepSeparator: false, + LenFunc: utf8.RuneCountInString, ModelName: _defaultTokenModelName, EncodingName: _defaultTokenEncoding, @@ -118,3 +120,14 @@ func WithReferenceLinks(referenceLinks bool) Option { o.ReferenceLinks = referenceLinks } } + +// WithKeepSeparator sets whether the separators should be kept in the resulting +// split text or not. When it is set to True, the separators are included in the +// resulting split text. When it is set to False, the separators are not included +// in the resulting split text. The purpose of having this parameter is to provide +// flexibility in how text splitting is handled. Default to False if not specified. +func WithKeepSeparator(keepSeparator bool) Option { + return func(o *Options) { + o.KeepSeparator = keepSeparator + } +} diff --git a/textsplitter/recursive_character.go b/textsplitter/recursive_character.go index 98d4d7bd5..36db7b677 100644 --- a/textsplitter/recursive_character.go +++ b/textsplitter/recursive_character.go @@ -7,10 +7,11 @@ import ( // RecursiveCharacter is a text splitter that will split texts recursively by different // characters. type RecursiveCharacter struct { - Separators []string - ChunkSize int - ChunkOverlap int - LenFunc func(string) int + Separators []string + ChunkSize int + ChunkOverlap int + LenFunc func(string) int + KeepSeparator bool } // NewRecursiveCharacter creates a new recursive character splitter with default values. By @@ -23,10 +24,11 @@ func NewRecursiveCharacter(opts ...Option) RecursiveCharacter { } s := RecursiveCharacter{ - Separators: options.Separators, - ChunkSize: options.ChunkSize, - ChunkOverlap: options.ChunkOverlap, - LenFunc: options.LenFunc, + Separators: options.Separators, + ChunkSize: options.ChunkSize, + ChunkOverlap: options.ChunkOverlap, + LenFunc: options.LenFunc, + KeepSeparator: options.KeepSeparator, } return s @@ -34,20 +36,40 @@ func NewRecursiveCharacter(opts ...Option) RecursiveCharacter { // SplitText splits a text into multiple text. func (s RecursiveCharacter) SplitText(text string) ([]string, error) { + return s.splitText(text, s.Separators) +} + +// addSeparatorInSplits adds the separator in each of splits. +func (s RecursiveCharacter) addSeparatorInSplits(splits []string, separator string) []string { + splitsWithSeparator := make([]string, 0, len(splits)) + for i, s := range splits { + if i > 0 { + s = separator + s + } + splitsWithSeparator = append(splitsWithSeparator, s) + } + return splitsWithSeparator +} + +func (s RecursiveCharacter) splitText(text string, separators []string) ([]string, error) { finalChunks := make([]string, 0) - // Find the appropriate separator - separator := s.Separators[len(s.Separators)-1] + // Find the appropriate separator. + separator := separators[len(separators)-1] newSeparators := []string{} - for i, c := range s.Separators { + for i, c := range separators { if c == "" || strings.Contains(text, c) { separator = c - newSeparators = s.Separators[i+1:] + newSeparators = separators[i+1:] break } } splits := strings.Split(text, separator) + if s.KeepSeparator { + splits = s.addSeparatorInSplits(splits, separator) + separator = "" + } goodSplits := make([]string, 0) // Merge the splits, recursively splitting larger texts. @@ -67,7 +89,7 @@ func (s RecursiveCharacter) SplitText(text string) ([]string, error) { if len(newSeparators) == 0 { finalChunks = append(finalChunks, split) } else { - otherInfo, err := s.SplitText(split) + otherInfo, err := s.splitText(split, newSeparators) if err != nil { return nil, err } diff --git a/textsplitter/recursive_character_test.go b/textsplitter/recursive_character_test.go index 8a0a880f8..2e087b1df 100644 --- a/textsplitter/recursive_character_test.go +++ b/textsplitter/recursive_character_test.go @@ -1,8 +1,10 @@ package textsplitter import ( + "strings" "testing" + "github.com/pkoukk/tiktoken-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tmc/langchaingo/schema" @@ -10,13 +12,17 @@ import ( //nolint:dupword,funlen func TestRecursiveCharacterSplitter(t *testing.T) { + tokenEncoder, _ := tiktoken.GetEncoding("cl100k_base") + t.Parallel() type testCase struct { - text string - chunkOverlap int - chunkSize int - separators []string - expectedDocs []schema.Document + text string + chunkOverlap int + chunkSize int + separators []string + expectedDocs []schema.Document + keepSeparator bool + LenFunc func(string) int } testCases := []testCase{ { @@ -106,12 +112,39 @@ Bye! {PageContent: "Bye!\n\n-H.", Metadata: map[string]any{}}, }, }, + { + text: "Hi, Harrison. \nI am glad to meet you", + chunkOverlap: 0, + chunkSize: 10, + separators: []string{"\n", "$"}, + keepSeparator: true, + expectedDocs: []schema.Document{ + {PageContent: "Hi, Harrison. ", Metadata: map[string]any{}}, + {PageContent: "\nI am glad to meet you", Metadata: map[string]any{}}, + }, + }, + { + text: strings.Repeat("The quick brown fox jumped over the lazy dog. ", 2), + chunkOverlap: 0, + chunkSize: 10, + separators: []string{" "}, + keepSeparator: true, + LenFunc: func(s string) int { return len(tokenEncoder.Encode(s, nil, nil)) }, + expectedDocs: []schema.Document{ + {PageContent: "The quick brown fox jumped over the lazy dog.", Metadata: map[string]any{}}, + {PageContent: "The quick brown fox jumped over the lazy dog.", Metadata: map[string]any{}}, + }, + }, } splitter := NewRecursiveCharacter() for _, tc := range testCases { splitter.ChunkOverlap = tc.chunkOverlap splitter.ChunkSize = tc.chunkSize splitter.Separators = tc.separators + splitter.KeepSeparator = tc.keepSeparator + if tc.LenFunc != nil { + splitter.LenFunc = tc.LenFunc + } docs, err := CreateDocuments(splitter, []string{tc.text}, nil) require.NoError(t, err)