mirror of
https://github.com/kevwan/tproxy.git
synced 2026-05-23 07:40:35 +00:00
feat: add mysql protocol, surport client and server parse. (#100)
* fix: bug only processed client, not server. * fix: remove unused code * feat: transmission from the server to the client is inefficient when limited to 32KB, especially for handling large response bodies. * fix: code runner * chore: optimize result reset output --------- Co-authored-by: jxli <jiaxin.li@corerain.com>
This commit is contained in:
committed by
GitHub
parent
ba53ceee53
commit
d52f1095bd
+218
-39
@@ -4,16 +4,20 @@ import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/kevwan/tproxy/display"
|
||||
"io"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
type mysqlInterop struct{}
|
||||
|
||||
const maxDecodeResponseBodySize = 32 * 1 << 10 // Limit 32KB (only result set may reach this limitation.)
|
||||
|
||||
var comTypeMap = map[byte]string{
|
||||
0x00: "SLEEP",
|
||||
0x01: "QUIT",
|
||||
@@ -46,6 +50,14 @@ var comTypeMap = map[byte]string{
|
||||
0x1c: "FETCH",
|
||||
}
|
||||
|
||||
var statusFlagMap = map[uint16]string{
|
||||
0x01: "SERVER_STATUS_AUTOCOMMIT",
|
||||
0x02: "SERVER_STATUS_COMMAND",
|
||||
0x04: "SERVER_STATUS_CONNECTED",
|
||||
0x08: "SERVER_STATUS_MORE_RESULTS",
|
||||
0x10: "SERVER_STATUS_SESSION_STATE ",
|
||||
}
|
||||
|
||||
type ServerResponse struct {
|
||||
PacketLength int
|
||||
SequenceID byte
|
||||
@@ -53,9 +65,205 @@ type ServerResponse struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
type ResponsePkgType string
|
||||
|
||||
const (
|
||||
MySQLResponseTypeOK ResponsePkgType = "OK"
|
||||
MySQLResponseTypeError ResponsePkgType = "Error"
|
||||
MySQLResponseTypeEOF ResponsePkgType = "EOF"
|
||||
MySQLResponseTypeResultSet ResponsePkgType = "Result Set"
|
||||
MySQLResponseTypeUnknown ResponsePkgType = "Unknown"
|
||||
)
|
||||
|
||||
func getPkgType(flag byte) ResponsePkgType {
|
||||
if flag == 0x00 || flag == 0xfe {
|
||||
return MySQLResponseTypeOK
|
||||
} else if flag == 0xff {
|
||||
return MySQLResponseTypeError
|
||||
} else if flag > 0x01 && flag < 0xfa {
|
||||
return MySQLResponseTypeResultSet
|
||||
} else {
|
||||
return MySQLResponseTypeUnknown
|
||||
}
|
||||
}
|
||||
|
||||
func readLCInt(buf []byte) ([]byte, uint64, error) {
|
||||
if len(buf) == 0 {
|
||||
return nil, 0, errors.New("empty buffer")
|
||||
}
|
||||
|
||||
lcbyte := buf[0]
|
||||
|
||||
switch {
|
||||
case lcbyte == 0xFB: // 0xFB
|
||||
return buf[1:], 0, nil
|
||||
case lcbyte < 0xFB:
|
||||
return buf[1:], uint64(lcbyte), nil
|
||||
case lcbyte == 0xFC: // 0xFC
|
||||
return buf[3:], uint64(binary.LittleEndian.Uint16(buf[1:3])), nil
|
||||
case lcbyte == 0xFD: // 0xFD
|
||||
return buf[4:], uint64(binary.LittleEndian.Uint32(append(buf[1:4], 0))), nil
|
||||
case lcbyte == 0xFE: // 0xFE
|
||||
return buf[9:], binary.LittleEndian.Uint64(buf[1:9]), nil
|
||||
default:
|
||||
return nil, 0, errors.New("failed reading length encoded integer")
|
||||
}
|
||||
}
|
||||
|
||||
func processOkResponse(sequenceId byte, payload []byte) {
|
||||
var (
|
||||
affectedRows, lastInsertID uint64
|
||||
statusFlag string
|
||||
ok bool
|
||||
err error
|
||||
remaining []byte
|
||||
)
|
||||
remaining, affectedRows, err = readLCInt(payload[1:])
|
||||
if err != nil {
|
||||
display.PrintlnWithTime(color.HiRedString("Failed reading length encoded integer: " + err.Error()))
|
||||
return
|
||||
}
|
||||
remaining, lastInsertID, err = readLCInt(remaining)
|
||||
if err != nil {
|
||||
display.PrintlnWithTime(color.HiRedString("Failed reading length encoded integer: " + err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
display.PrintlnWithTime(color.HiRedString("Failed reading length encoded integer: " + err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
statusFlag, ok = statusFlagMap[binary.LittleEndian.Uint16(remaining[:2])]
|
||||
if !ok {
|
||||
statusFlag = "unknown"
|
||||
}
|
||||
|
||||
remaining = remaining[2:]
|
||||
|
||||
warningsCount := binary.LittleEndian.Uint16(remaining[:2])
|
||||
|
||||
remaining = remaining[2:]
|
||||
|
||||
display.PrintlnWithTime(
|
||||
fmt.Sprintf("[Server -> Client] %d-%s: affectRows: %d, lastInsertID: %d, warningsCount: %d, status: %s, data: %s",
|
||||
sequenceId, MySQLResponseTypeOK, affectedRows, lastInsertID, warningsCount, statusFlag, remaining))
|
||||
}
|
||||
|
||||
var sqlStateDescriptions = map[string]string{
|
||||
"42000": "Syntax error or access rule violation.",
|
||||
"23000": "Integrity constraint violation.",
|
||||
"08000": "Connection exception.",
|
||||
"28000": "Invalid authorization specification.",
|
||||
"42001": "Syntax error in SQL statement.",
|
||||
}
|
||||
|
||||
func processErrorResponse(sequenceId byte, payload []byte) {
|
||||
errCode := binary.LittleEndian.Uint16(payload[1:3])
|
||||
sqlStateMarker := payload[3]
|
||||
sqlState := string(payload[5:9])
|
||||
sqlStateDescription, ok := sqlStateDescriptions[sqlState[1:]]
|
||||
if !ok {
|
||||
sqlStateDescription = "Unknown SQLSTATE"
|
||||
}
|
||||
errorMessage := string(payload[9:])
|
||||
|
||||
display.PrintfWithTime(
|
||||
color.HiYellowString(fmt.Sprintf("[Server -> Client] %d-%s: ErrCode: %d, ErrMsg: %s, SqlState: %s, sqlStateMaker: %v",
|
||||
sequenceId, MySQLResponseTypeError, errCode, errorMessage, sqlStateDescription, sqlStateMarker)),
|
||||
)
|
||||
}
|
||||
|
||||
func processResultSetResponse(sequenceId byte, payload []byte) {
|
||||
display.PrintfWithTime(fmt.Sprintf("[Server -> Client] %d-%s: \n %s", sequenceId, MySQLResponseTypeResultSet, hexDump(payload)))
|
||||
|
||||
}
|
||||
|
||||
func insertSpace(hexStr string) string {
|
||||
var spaced strings.Builder
|
||||
for i := 0; i < len(hexStr); i += 2 {
|
||||
spaced.WriteString(hexStr[i:i+2] + " ")
|
||||
}
|
||||
return spaced.String()
|
||||
}
|
||||
|
||||
func toPrintableASCII(data []byte) string {
|
||||
var result strings.Builder
|
||||
for i := 0; i < len(data); {
|
||||
r, size := utf8.DecodeRune(data[i:])
|
||||
if r == utf8.RuneError && size == 1 {
|
||||
result.WriteByte('.')
|
||||
i++
|
||||
} else {
|
||||
if unicode.IsPrint(r) {
|
||||
result.WriteRune(r)
|
||||
} else {
|
||||
result.WriteByte('.')
|
||||
}
|
||||
i += size
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
func hexDump(data []byte) string {
|
||||
var result strings.Builder
|
||||
const chunkSize = 16
|
||||
|
||||
for i := 0; i < len(data); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(data) {
|
||||
end = len(data)
|
||||
}
|
||||
chunk := data[i:end]
|
||||
|
||||
hexStr := hex.EncodeToString(chunk)
|
||||
hexStr = insertSpace(hexStr)
|
||||
asciiStr := toPrintableASCII(chunk)
|
||||
result.WriteString(fmt.Sprintf("%04x %-48s |%s|\n", i, hexStr, asciiStr))
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
||||
func processUnknownResponse(sequenceId byte, payload []byte) {
|
||||
display.PrintlnWithTime(fmt.Sprintf("[Server -> Client] %d-%s:\n%s", sequenceId, MySQLResponseTypeUnknown, hexDump(payload)))
|
||||
}
|
||||
|
||||
func (mysql *mysqlInterop) dumpServer(r io.Reader, id int, quiet bool, data []byte) {
|
||||
if len(data) < 4 {
|
||||
display.PrintlnWithTime("Invalid packet: insufficient data for header")
|
||||
return
|
||||
}
|
||||
|
||||
sequenceId := data[3]
|
||||
payload := data[4:]
|
||||
|
||||
if len(payload) > maxDecodeResponseBodySize {
|
||||
display.PrintlnWithTime(color.HiRedString(fmt.Sprintf("Packet too large to, just decode %d MB", maxDecodeResponseBodySize/1024/1024)))
|
||||
payload = payload[:maxDecodeResponseBodySize]
|
||||
}
|
||||
|
||||
switch getPkgType(payload[0]) {
|
||||
case MySQLResponseTypeOK:
|
||||
processOkResponse(sequenceId, payload)
|
||||
case MySQLResponseTypeError:
|
||||
processErrorResponse(sequenceId, payload)
|
||||
case MySQLResponseTypeResultSet:
|
||||
processResultSetResponse(sequenceId, payload)
|
||||
case MySQLResponseTypeEOF:
|
||||
default:
|
||||
processUnknownResponse(sequenceId, payload)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (mysql *mysqlInterop) dumpClient(r io.Reader, id int, quiet bool, data []byte) {
|
||||
// parse packet length
|
||||
var packetLength uint32
|
||||
var (
|
||||
packetLength uint32
|
||||
sequenceId uint32
|
||||
)
|
||||
reader := bytes.NewReader(data[:4])
|
||||
err := binary.Read(reader, binary.BigEndian, &packetLength)
|
||||
if err != nil {
|
||||
@@ -68,7 +276,11 @@ func (mysql *mysqlInterop) dumpClient(r io.Reader, id int, quiet bool, data []by
|
||||
commandName := comTypeMap[commandType]
|
||||
|
||||
// parse sequence id
|
||||
sequenceId := data[5]
|
||||
if len(data) < 6 {
|
||||
sequenceId = 0
|
||||
} else {
|
||||
sequenceId = uint32(data[5])
|
||||
}
|
||||
|
||||
// parse query
|
||||
var query []byte
|
||||
@@ -80,44 +292,12 @@ func (mysql *mysqlInterop) dumpClient(r io.Reader, id int, quiet bool, data []by
|
||||
}
|
||||
|
||||
if utf8.Valid(query) {
|
||||
display.PrintlnWithTime(fmt.Sprintf("[Client] %d-%s: %s", sequenceId, commandName, string(query)))
|
||||
display.PrintlnWithTime(fmt.Sprintf("[Client -> Server] %d-%s: %s", sequenceId, commandName, string(query)))
|
||||
} else {
|
||||
display.PrintlnWithTime(color.HiRedString("Invalid Query %v", query))
|
||||
}
|
||||
}
|
||||
|
||||
func (mysql *mysqlInterop) dumpServer(r io.Reader, id int, quiet bool, data []byte) {
|
||||
header := make([]byte, 4)
|
||||
_, err := r.Read(header)
|
||||
if err != nil {
|
||||
display.PrintlnWithTime(color.HiRedString("Error reading packet length: %v\n", err))
|
||||
return
|
||||
}
|
||||
|
||||
packetLength := int(binary.BigEndian.Uint16(header[:3]))
|
||||
responseData := make([]byte, packetLength)
|
||||
_, err = r.Read(responseData)
|
||||
if err != nil {
|
||||
display.PrintlnWithTime(color.HiRedString("Error reading packet data: %v\n", err))
|
||||
return
|
||||
}
|
||||
|
||||
// OK packet, value is 0x00
|
||||
// Error packet, value is 0xFF
|
||||
responseType := data[0]
|
||||
if responseType == 0x00 {
|
||||
fmt.Println("OK packet", hex.Dump(responseData))
|
||||
} else if responseType == 0xff {
|
||||
fmt.Println("Error packet", hex.Dump(responseData))
|
||||
} else if responseType == 0xfe {
|
||||
fmt.Println("EOF packet", hex.Dump(responseData))
|
||||
} else if responseType > 0x00 && responseType < 0xfa {
|
||||
fmt.Println("other packet", hex.Dump(responseData))
|
||||
} else {
|
||||
display.PrintlnWithTime(color.HiRedString("invalid packet"))
|
||||
}
|
||||
}
|
||||
|
||||
func (mysql *mysqlInterop) Dump(r io.Reader, source string, id int, quiet bool) {
|
||||
buffer := make([]byte, bufferSize)
|
||||
for {
|
||||
@@ -129,7 +309,6 @@ func (mysql *mysqlInterop) Dump(r io.Reader, source string, id int, quiet bool)
|
||||
|
||||
if n > 0 && !quiet {
|
||||
data := buffer[:n]
|
||||
|
||||
if source == "CLIENT" {
|
||||
mysql.dumpClient(r, id, quiet, data)
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user