Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add WithKeepSeparator option for RecursiveCharacter #721

Merged
merged 5 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ require (
github.com/docker/go-units v0.5.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/gage-technologies/mistral-go v1.0.0 // indirect
github.com/getsentry/sentry-go v0.12.0 // indirect
github.com/go-logr/logr v1.3.0 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
Expand Down Expand Up @@ -185,6 +184,7 @@ require (
github.com/aws/aws-sdk-go-v2/config v1.27.4
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.1
github.com/cohere-ai/tokenizer v1.1.2
github.com/gage-technologies/mistral-go v1.0.0
github.com/go-openapi/strfmt v0.21.3
github.com/go-sql-driver/mysql v1.7.1
github.com/gocolly/colly v1.2.0
Expand Down
21 changes: 17 additions & 4 deletions textsplitter/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ type Options struct {
ChunkSize int
ChunkOverlap int
Separators []string
KeepSeparator bool
LenFunc func(string) int
ModelName string
EncodingName string
Expand All @@ -20,10 +21,11 @@ type Options struct {
// DefaultOptions returns the default options for all text splitter.
func DefaultOptions() Options {
return Options{
ChunkSize: _defaultTokenChunkSize,
ChunkOverlap: _defaultTokenChunkOverlap,
Separators: []string{"\n\n", "\n", " ", ""},
LenFunc: utf8.RuneCountInString,
ChunkSize: _defaultTokenChunkSize,
ChunkOverlap: _defaultTokenChunkOverlap,
Separators: []string{"\n\n", "\n", " ", ""},
KeepSeparator: false,
LenFunc: utf8.RuneCountInString,

ModelName: _defaultTokenModelName,
EncodingName: _defaultTokenEncoding,
Expand Down Expand Up @@ -118,3 +120,14 @@ func WithReferenceLinks(referenceLinks bool) Option {
o.ReferenceLinks = referenceLinks
}
}

// WithKeepSeparator sets whether the separators should be kept in the resulting
// split text or not. When it is set to true, the separators are included in the
// resulting split text. When it is set to false, the separators are not included
zhangi marked this conversation as resolved.
Show resolved Hide resolved
// in the resulting split text. The purpose of having this parameter is to provide
// flexibility in how text splitting is handled.
func WithKeepSeparator(keepSeparator bool) Option {
// The returned Option stores keepSeparator in Options.KeepSeparator when it
// is applied to an Options value; nothing else is touched.
return func(o *Options) {
o.KeepSeparator = keepSeparator
}
}
48 changes: 35 additions & 13 deletions textsplitter/recursive_character.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ import (
// RecursiveCharacter is a text splitter that will split texts recursively by different
// characters.
type RecursiveCharacter struct {
Separators []string
ChunkSize int
ChunkOverlap int
LenFunc func(string) int
Separators []string
ChunkSize int
ChunkOverlap int
LenFunc func(string) int
KeepSeparator bool
}

// NewRecursiveCharacter creates a new recursive character splitter with default values. By
Expand All @@ -23,31 +24,52 @@ func NewRecursiveCharacter(opts ...Option) RecursiveCharacter {
}

s := RecursiveCharacter{
Separators: options.Separators,
ChunkSize: options.ChunkSize,
ChunkOverlap: options.ChunkOverlap,
LenFunc: options.LenFunc,
Separators: options.Separators,
ChunkSize: options.ChunkSize,
ChunkOverlap: options.ChunkOverlap,
LenFunc: options.LenFunc,
KeepSeparator: options.KeepSeparator,
}

return s
}

// SplitText splits text into multiple chunks. It is the public entry point:
// it starts the split with the splitter's full configured Separators list and
// returns whatever chunks and error the internal splitText call produces.
func (s RecursiveCharacter) SplitText(text string) ([]string, error) {
return s.splitText(text, s.Separators)
}

// addSeparatorInSplits returns a new slice with the same length as splits in
// which every element except the first is prefixed with separator. The first
// element is copied through unchanged and the input slice is not modified.
func (s RecursiveCharacter) addSeparatorInSplits(splits []string, separator string) []string {
	prefixed := make([]string, len(splits))
	for idx := range splits {
		if idx == 0 {
			// The leading piece had no separator in front of it in the source text.
			prefixed[idx] = splits[idx]
			continue
		}
		prefixed[idx] = separator + splits[idx]
	}
	return prefixed
}

func (s RecursiveCharacter) splitText(text string, separators []string) ([]string, error) {
finalChunks := make([]string, 0)

// Find the appropriate separator
separator := s.Separators[len(s.Separators)-1]
// Find the appropriate separator.
separator := separators[len(separators)-1]
newSeparators := []string{}
for i, c := range s.Separators {
for i, c := range separators {
if c == "" || strings.Contains(text, c) {
separator = c
newSeparators = s.Separators[i+1:]
newSeparators = separators[i+1:]
break
}
}

splits := strings.Split(text, separator)
if s.KeepSeparator {
splits = s.addSeparatorInSplits(splits, separator)
separator = ""
}
goodSplits := make([]string, 0)

// Merge the splits, recursively splitting larger texts.
Expand All @@ -67,7 +89,7 @@ func (s RecursiveCharacter) SplitText(text string) ([]string, error) {
if len(newSeparators) == 0 {
finalChunks = append(finalChunks, split)
} else {
otherInfo, err := s.SplitText(split)
otherInfo, err := s.splitText(split, newSeparators)
if err != nil {
return nil, err
}
Expand Down
43 changes: 38 additions & 5 deletions textsplitter/recursive_character_test.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
package textsplitter

import (
"strings"
"testing"

"github.com/pkoukk/tiktoken-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tmc/langchaingo/schema"
)

//nolint:dupword,funlen
func TestRecursiveCharacterSplitter(t *testing.T) {
tokenEncoder, _ := tiktoken.GetEncoding("cl100k_base")

t.Parallel()
type testCase struct {
text string
chunkOverlap int
chunkSize int
separators []string
expectedDocs []schema.Document
text string
chunkOverlap int
chunkSize int
separators []string
expectedDocs []schema.Document
keepSeparator bool
LenFunc func(string) int
}
testCases := []testCase{
{
Expand Down Expand Up @@ -106,12 +112,39 @@ Bye!
{PageContent: "Bye!\n\n-H.", Metadata: map[string]any{}},
},
},
{
text: "Hi, Harrison. \nI am glad to meet you",
chunkOverlap: 0,
chunkSize: 10,
separators: []string{"\n", "$"},
keepSeparator: true,
expectedDocs: []schema.Document{
{PageContent: "Hi, Harrison. ", Metadata: map[string]any{}},
{PageContent: "\nI am glad to meet you", Metadata: map[string]any{}},
},
},
{
text: strings.Repeat("The quick brown fox jumped over the lazy dog. ", 2),
chunkOverlap: 0,
chunkSize: 10,
separators: []string{" "},
keepSeparator: true,
LenFunc: func(s string) int { return len(tokenEncoder.Encode(s, nil, nil)) },
expectedDocs: []schema.Document{
{PageContent: "The quick brown fox jumped over the lazy dog.", Metadata: map[string]any{}},
{PageContent: "The quick brown fox jumped over the lazy dog.", Metadata: map[string]any{}},
},
},
}
splitter := NewRecursiveCharacter()
for _, tc := range testCases {
splitter.ChunkOverlap = tc.chunkOverlap
splitter.ChunkSize = tc.chunkSize
splitter.Separators = tc.separators
splitter.KeepSeparator = tc.keepSeparator
if tc.LenFunc != nil {
splitter.LenFunc = tc.LenFunc
}

docs, err := CreateDocuments(splitter, []string{tc.text}, nil)
require.NoError(t, err)
Expand Down
Loading