Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server,session: Add status vars for compression #48152

Merged
merged 11 commits into from
Nov 16, 2023
3 changes: 3 additions & 0 deletions pkg/parser/mysql/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,9 @@ const (
CursorTypeScrollable
)

// ZlibCompressDefaultLevel is the zlib compression level for the compressed protocol
const ZlibCompressDefaultLevel = 6

const (
// CompressionNone is no compression in use
CompressionNone = iota
Expand Down
1 change: 1 addition & 0 deletions pkg/server/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ go_library(
"//pkg/util/versioninfo",
"@com_github_blacktear23_go_proxyprotocol//:go-proxyprotocol",
"@com_github_gorilla_mux//:mux",
"@com_github_klauspost_compress//zstd",
"@com_github_pingcap_errors//:errors",
"@com_github_pingcap_failpoint//:failpoint",
"@com_github_pingcap_fn//:fn",
Expand Down
57 changes: 52 additions & 5 deletions pkg/server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ import (
"time"
"unsafe"

"github.com/klauspost/compress/zstd"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/config"
Expand Down Expand Up @@ -113,9 +114,15 @@ const (
connStatusWaitShutdown = 3 // Notified by server to close.
)

var (
statusCompression = "Compression"
statusCompressionAlgorithm = "Compression_algorithm"
statusCompressionLevel = "Compression_level"
)

// newClientConn creates a *clientConn object.
func newClientConn(s *Server) *clientConn {
return &clientConn{
cc := &clientConn{
server: s,
connectionID: s.dom.NextConnID(),
collation: mysql.DefaultCollationID,
Expand All @@ -127,6 +134,8 @@ func newClientConn(s *Server) *clientConn {
quit: make(chan struct{}),
ppEnabled: s.cfg.ProxyProtocol.Networks != "",
}
variable.RegisterStatistics(cc)
return cc
}

// clientConn represents a connection between server and client, it maintains connection specific state,
Expand Down Expand Up @@ -323,8 +332,10 @@ func (cc *clientConn) handshake(ctx context.Context) error {
// With mysql --compression-algorithms=zlib,zstd both flags are set, the result is Zlib
if cc.capability&mysql.ClientCompress > 0 {
cc.pkt.SetCompressionAlgorithm(mysql.CompressionZlib)
cc.ctx.SetCompressionAlgorithm(mysql.CompressionZlib)
} else if cc.capability&mysql.ClientZstdCompressionAlgorithm > 0 {
cc.pkt.SetCompressionAlgorithm(mysql.CompressionZstd)
cc.ctx.SetCompressionAlgorithm(mysql.CompressionZstd)
}

return err
Expand Down Expand Up @@ -546,7 +557,7 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con
cc.dbname = resp.DBName
cc.collation = resp.Collation
cc.attrs = resp.Attrs
cc.pkt.SetZstdLevel(resp.ZstdLevel)
cc.pkt.SetZstdLevel(zstd.EncoderLevelFromZstd(resp.ZstdLevel))

err = cc.handleAuthPlugin(ctx, &resp)
if err != nil {
Expand Down Expand Up @@ -575,7 +586,7 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con
return errors.New("Unknown auth plugin")
}

err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin)
err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin, resp.ZstdLevel)
if err != nil {
logutil.Logger(ctx).Warn("open new session or authentication failure", zap.Error(err))
}
Expand Down Expand Up @@ -716,7 +727,7 @@ func (cc *clientConn) openSession() error {
return nil
}

func (cc *clientConn) openSessionAndDoAuth(authData []byte, authPlugin string) error {
func (cc *clientConn) openSessionAndDoAuth(authData []byte, authPlugin string, zstdLevel int) error {
// Open a context unless this was done before.
if ctx := cc.getCtx(); ctx == nil {
err := cc.openSession()
Expand Down Expand Up @@ -744,6 +755,7 @@ func (cc *clientConn) openSessionAndDoAuth(authData []byte, authPlugin string) e
return err
}
cc.ctx.SetPort(port)
cc.ctx.SetCompressionLevel(zstdLevel)
if cc.dbname != "" {
_, err = cc.useDB(context.Background(), cc.dbname)
if err != nil {
Expand Down Expand Up @@ -2434,7 +2446,7 @@ func (cc *clientConn) handleChangeUser(ctx context.Context, data []byte) error {
fakeResp.Auth = newpass
}
}
if err := cc.openSessionAndDoAuth(fakeResp.Auth, fakeResp.AuthPlugin); err != nil {
if err := cc.openSessionAndDoAuth(fakeResp.Auth, fakeResp.AuthPlugin, fakeResp.ZstdLevel); err != nil {
return err
}
return cc.handleCommonConnectionReset(ctx)
Expand Down Expand Up @@ -2581,3 +2593,38 @@ func (cc *clientConn) ReadPacket() ([]byte, error) {
func (cc *clientConn) Flush(ctx context.Context) error {
return cc.flush(ctx)
}

// Stats returns the connection statistics.
func (*clientConn) Stats(vars *variable.SessionVars) (map[string]interface{}, error) {
m := make(map[string]interface{}, 3)

switch vars.CompressionAlgorithm {
case mysql.CompressionNone:
m[statusCompression] = "OFF"
m[statusCompressionAlgorithm] = ""
m[statusCompressionLevel] = 0
case mysql.CompressionZlib:
m[statusCompression] = "ON"
m[statusCompressionAlgorithm] = "zlib"
m[statusCompressionLevel] = mysql.ZlibCompressDefaultLevel
case mysql.CompressionZstd:
m[statusCompression] = "ON"
m[statusCompressionAlgorithm] = "zstd"
m[statusCompressionLevel] = vars.CompressionLevel
default:
logutil.BgLogger().Debug(
"unexpected compression algorithm value",
zap.Int("algorithm", vars.CompressionAlgorithm),
)
m[statusCompression] = "OFF"
m[statusCompressionAlgorithm] = ""
m[statusCompressionLevel] = 0
}

return m, nil
}

// GetScope gets the status variables scope.
func (*clientConn) GetScope(_ string) variable.ScopeFlag {
return variable.ScopeSession
}
63 changes: 58 additions & 5 deletions pkg/server/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1617,25 +1617,25 @@ func TestAuthSessionTokenPlugin(t *testing.T) {
}
err = cc.handleAuthPlugin(ctx, &resp)
require.NoError(t, err)
err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin)
err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin, resp.ZstdLevel)
require.NoError(t, err)

// login succeeds even if the password expires now
tk.MustExec("ALTER USER auth_session_token PASSWORD EXPIRE")
err = cc.openSessionAndDoAuth([]byte{}, mysql.AuthNativePassword)
err = cc.openSessionAndDoAuth([]byte{}, mysql.AuthNativePassword, 0)
require.ErrorContains(t, err, "Your password has expired")
err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin)
err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin, resp.ZstdLevel)
require.NoError(t, err)

// wrong token should fail
tokenBytes[0] ^= 0xff
err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin)
err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin, resp.ZstdLevel)
require.ErrorContains(t, err, "Access denied")
tokenBytes[0] ^= 0xff

// using the token to auth with another user should fail
cc.user = "another_user"
err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin)
err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin, resp.ZstdLevel)
require.ErrorContains(t, err, "Access denied")
}

Expand Down Expand Up @@ -1999,3 +1999,56 @@ func TestEmptyOrgName(t *testing.T) {

testDispatch(t, inputs, 0)
}

func TestStats(t *testing.T) {
var outBuffer bytes.Buffer

store := testkit.CreateMockStore(t)
cfg := serverutil.NewTestConfig()
cfg.Port = 0
cfg.Status.StatusPort = 0
drv := NewTiDBDriver(store)
server, err := NewServer(cfg, drv)
require.NoError(t, err)
tk := testkit.NewTestKit(t, store)

cc := &clientConn{
connectionID: 1,
salt: []byte{
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14,
},
server: server,
pkt: internal.NewPacketIOForTest(bufio.NewWriter(&outBuffer)),
collation: mysql.DefaultCollationID,
peerHost: "localhost",
alloc: arena.NewAllocator(512),
chunkAlloc: chunk.NewAllocator(),
capability: mysql.ClientProtocol41,
}

// No compression
vars := tk.Session().GetSessionVars()
m, err := cc.Stats(vars)
require.NoError(t, err)
require.Equal(t, "OFF", m["Compression"])
require.Equal(t, "", m["Compression_algorithm"])
require.Equal(t, 0, m["Compression_level"])

// zlib compression
vars.CompressionAlgorithm = mysql.CompressionZlib
m, err = cc.Stats(vars)
require.NoError(t, err)
require.Equal(t, "ON", m["Compression"])
require.Equal(t, "zlib", m["Compression_algorithm"])
require.Equal(t, mysql.ZlibCompressDefaultLevel, m["Compression_level"])

// zstd compression, with level 1
vars.CompressionAlgorithm = mysql.CompressionZstd
vars.CompressionLevel = 1
m, err = cc.Stats(vars)
require.NoError(t, err)
require.Equal(t, "ON", m["Compression"])
require.Equal(t, "zstd", m["Compression_algorithm"])
require.Equal(t, 1, m["Compression_level"])
}
1 change: 0 additions & 1 deletion pkg/server/internal/handshake/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,4 @@ go_library(
srcs = ["handshake.go"],
importpath = "github.com/pingcap/tidb/pkg/server/internal/handshake",
visibility = ["//pkg/server:__subpackages__"],
deps = ["@com_github_klauspost_compress//zstd"],
)
4 changes: 1 addition & 3 deletions pkg/server/internal/handshake/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,14 @@

package handshake

import "github.com/klauspost/compress/zstd"

// Response41 is the response message for a successful initial handshake.
type Response41 struct {
Attrs map[string]string
User string
DBName string
AuthPlugin string
Auth []byte
ZstdLevel zstd.EncoderLevel
ZstdLevel int
Capability uint32
Collation uint8
}
3 changes: 1 addition & 2 deletions pkg/server/internal/packetio.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,13 +362,12 @@ func (cw *compressedWriter) Flush() error {
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_compression_packet.html
// suggests a MIN_COMPRESS_LENGTH of 50.
minCompressLength := 50
zlibCompressDefaultLevel := 6
data := cw.buf.Bytes()
cw.buf.Reset()

switch cw.compressionAlgorithm {
case mysql.CompressionZlib:
w, err = zlib.NewWriterLevel(&payload, zlibCompressDefaultLevel)
w, err = zlib.NewWriterLevel(&payload, mysql.ZlibCompressDefaultLevel)
case mysql.CompressionZstd:
w, err = zstd.NewWriter(&payload, zstd.WithEncoderLevel(cw.zstdLevel))
default:
Expand Down
1 change: 0 additions & 1 deletion pkg/server/internal/parse/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ go_library(
"//pkg/server/internal/handshake",
"//pkg/server/internal/util",
"//pkg/util/logutil",
"@com_github_klauspost_compress//zstd",
"@org_uber_go_zap//:zap",
],
)
Expand Down
3 changes: 1 addition & 2 deletions pkg/server/internal/parse/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"context"
"encoding/binary"

"github.com/klauspost/compress/zstd"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/server/internal/handshake"
util2 "github.com/pingcap/tidb/pkg/server/internal/util"
Expand Down Expand Up @@ -146,7 +145,7 @@ func HandshakeResponseBody(ctx context.Context, packet *handshake.Response41, da
}

if packet.Capability&mysql.ClientZstdCompressionAlgorithm > 0 {
packet.ZstdLevel = zstd.EncoderLevelFromZstd(int(data[offset]))
packet.ZstdLevel = int(data[offset])
}

return nil
Expand Down
10 changes: 10 additions & 0 deletions pkg/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ type Session interface {
SetClientCapability(uint32) // Set client capability flags.
SetConnectionID(uint64)
SetCommandValue(byte)
SetCompressionAlgorithm(int)
SetCompressionLevel(int)
SetProcessInfo(string, time.Time, byte, uint64)
SetTLSState(*tls.ConnectionState)
SetCollation(coID int) error
Expand Down Expand Up @@ -408,6 +410,14 @@ func (s *session) SetTLSState(tlsState *tls.ConnectionState) {
}
}

func (s *session) SetCompressionAlgorithm(ca int) {
s.sessionVars.CompressionAlgorithm = ca
}

func (s *session) SetCompressionLevel(level int) {
s.sessionVars.CompressionLevel = level
}

func (s *session) SetCommandValue(command byte) {
atomic.StoreUint32(&s.sessionVars.CommandValue, uint32(command))
}
Expand Down
3 changes: 3 additions & 0 deletions pkg/sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1559,6 +1559,9 @@ type SessionVars struct {
// OptObjectiveModerate: The default value. The optimizer considers the real-time stats (real-time row count, modify count).
// OptObjectiveDeterminate: The optimizer doesn't consider the real-time stats.
OptObjective string

CompressionAlgorithm int
CompressionLevel int
}

// GetOptimizerFixControlMap returns the specified value of the optimizer fix control.
Expand Down