diff --git a/protocol/mysql.go b/protocol/mysql.go index 24e3b38..75b9150 100644 --- a/protocol/mysql.go +++ b/protocol/mysql.go @@ -4,16 +4,20 @@ import ( "bytes" "encoding/binary" "encoding/hex" + "errors" "fmt" - "io" - "unicode/utf8" - "github.com/fatih/color" "github.com/kevwan/tproxy/display" + "io" + "strings" + "unicode" + "unicode/utf8" ) type mysqlInterop struct{} +const maxDecodeResponseBodySize = 32 * 1 << 10 // Limit 32KB (only result set may reach this limitation.) + var comTypeMap = map[byte]string{ 0x00: "SLEEP", 0x01: "QUIT", @@ -46,6 +50,14 @@ var comTypeMap = map[byte]string{ 0x1c: "FETCH", } +var statusFlagMap = map[uint16]string{ + 0x01: "SERVER_STATUS_AUTOCOMMIT", + 0x02: "SERVER_STATUS_COMMAND", + 0x04: "SERVER_STATUS_CONNECTED", + 0x08: "SERVER_STATUS_MORE_RESULTS", + 0x10: "SERVER_STATUS_SESSION_STATE ", +} + type ServerResponse struct { PacketLength int SequenceID byte @@ -53,9 +65,205 @@ type ServerResponse struct { Data []byte } +type ResponsePkgType string + +const ( + MySQLResponseTypeOK ResponsePkgType = "OK" + MySQLResponseTypeError ResponsePkgType = "Error" + MySQLResponseTypeEOF ResponsePkgType = "EOF" + MySQLResponseTypeResultSet ResponsePkgType = "Result Set" + MySQLResponseTypeUnknown ResponsePkgType = "Unknown" +) + +func getPkgType(flag byte) ResponsePkgType { + if flag == 0x00 || flag == 0xfe { + return MySQLResponseTypeOK + } else if flag == 0xff { + return MySQLResponseTypeError + } else if flag > 0x01 && flag < 0xfa { + return MySQLResponseTypeResultSet + } else { + return MySQLResponseTypeUnknown + } +} + +func readLCInt(buf []byte) ([]byte, uint64, error) { + if len(buf) == 0 { + return nil, 0, errors.New("empty buffer") + } + + lcbyte := buf[0] + + switch { + case lcbyte == 0xFB: // 0xFB + return buf[1:], 0, nil + case lcbyte < 0xFB: + return buf[1:], uint64(lcbyte), nil + case lcbyte == 0xFC: // 0xFC + return buf[3:], uint64(binary.LittleEndian.Uint16(buf[1:3])), nil + case lcbyte == 0xFD: // 0xFD + return buf[4:], uint64(binary.LittleEndian.Uint32(append(buf[1:4], 0))), nil + case lcbyte == 0xFE: // 0xFE + return buf[9:], binary.LittleEndian.Uint64(buf[1:9]), nil + default: + return nil, 0, errors.New("failed reading length encoded integer") + } +} + +func processOkResponse(sequenceId byte, payload []byte) { + var ( + affectedRows, lastInsertID uint64 + statusFlag string + ok bool + err error + remaining []byte + ) + remaining, affectedRows, err = readLCInt(payload[1:]) + if err != nil { + display.PrintlnWithTime(color.HiRedString("Failed reading length encoded integer: " + err.Error())) + return + } + remaining, lastInsertID, err = readLCInt(remaining) + if err != nil { + display.PrintlnWithTime(color.HiRedString("Failed reading length encoded integer: " + err.Error())) + return + } + + if err != nil { + display.PrintlnWithTime(color.HiRedString("Failed reading length encoded integer: " + err.Error())) + return + } + + statusFlag, ok = statusFlagMap[binary.LittleEndian.Uint16(remaining[:2])] + if !ok { + statusFlag = "unknown" + } + + remaining = remaining[2:] + + warningsCount := binary.LittleEndian.Uint16(remaining[:2]) + + remaining = remaining[2:] + + display.PrintlnWithTime( + fmt.Sprintf("[Server -> Client] %d-%s: affectRows: %d, lastInsertID: %d, warningsCount: %d, status: %s, data: %s", + sequenceId, MySQLResponseTypeOK, affectedRows, lastInsertID, warningsCount, statusFlag, remaining)) +} + +var sqlStateDescriptions = map[string]string{ + "42000": "Syntax error or access rule violation.", + "23000": "Integrity constraint violation.", + "08000": "Connection exception.", + "28000": "Invalid authorization specification.", + "42001": "Syntax error in SQL statement.", +} + +func processErrorResponse(sequenceId byte, payload []byte) { + errCode := binary.LittleEndian.Uint16(payload[1:3]) + sqlStateMarker := payload[3] + sqlState := string(payload[5:9]) + sqlStateDescription, ok := sqlStateDescriptions[sqlState[1:]] + if !ok { + sqlStateDescription = "Unknown SQLSTATE" + } + errorMessage := string(payload[9:]) + + display.PrintfWithTime( + color.HiYellowString(fmt.Sprintf("[Server -> Client] %d-%s: ErrCode: %d, ErrMsg: %s, SqlState: %s, sqlStateMaker: %v", + sequenceId, MySQLResponseTypeError, errCode, errorMessage, sqlStateDescription, sqlStateMarker)), + ) +} + +func processResultSetResponse(sequenceId byte, payload []byte) { + display.PrintfWithTime(fmt.Sprintf("[Server -> Client] %d-%s: \n %s", sequenceId, MySQLResponseTypeResultSet, hexDump(payload))) + +} + +func insertSpace(hexStr string) string { + var spaced strings.Builder + for i := 0; i < len(hexStr); i += 2 { + spaced.WriteString(hexStr[i:i+2] + " ") + } + return spaced.String() +} + +func toPrintableASCII(data []byte) string { + var result strings.Builder + for i := 0; i < len(data); { + r, size := utf8.DecodeRune(data[i:]) + if r == utf8.RuneError && size == 1 { + result.WriteByte('.') + i++ + } else { + if unicode.IsPrint(r) { + result.WriteRune(r) + } else { + result.WriteByte('.') + } + i += size + } + } + return result.String() +} + +func hexDump(data []byte) string { + var result strings.Builder + const chunkSize = 16 + + for i := 0; i < len(data); i += chunkSize { + end := i + chunkSize + if end > len(data) { + end = len(data) + } + chunk := data[i:end] + + hexStr := hex.EncodeToString(chunk) + hexStr = insertSpace(hexStr) + asciiStr := toPrintableASCII(chunk) + result.WriteString(fmt.Sprintf("%04x %-48s |%s|\n", i, hexStr, asciiStr)) + } + + return result.String() +} + +func processUnknownResponse(sequenceId byte, payload []byte) { + display.PrintlnWithTime(fmt.Sprintf("[Server -> Client] %d-%s:\n%s", sequenceId, MySQLResponseTypeUnknown, hexDump(payload))) +} + +func (mysql *mysqlInterop) dumpServer(r io.Reader, id int, quiet bool, data []byte) { + if len(data) < 4 { + display.PrintlnWithTime("Invalid packet: insufficient data for header") + return + } + + sequenceId := data[3] + payload := data[4:] + + if len(payload) > maxDecodeResponseBodySize { + display.PrintlnWithTime(color.HiRedString(fmt.Sprintf("Packet too large to, just decode %d MB", maxDecodeResponseBodySize/1024/1024))) + payload = payload[:maxDecodeResponseBodySize] + } + + switch getPkgType(payload[0]) { + case MySQLResponseTypeOK: + processOkResponse(sequenceId, payload) + case MySQLResponseTypeError: + processErrorResponse(sequenceId, payload) + case MySQLResponseTypeResultSet: + processResultSetResponse(sequenceId, payload) + case MySQLResponseTypeEOF: + default: + processUnknownResponse(sequenceId, payload) + } + +} + func (mysql *mysqlInterop) dumpClient(r io.Reader, id int, quiet bool, data []byte) { // parse packet length - var packetLength uint32 + var ( + packetLength uint32 + sequenceId uint32 + ) reader := bytes.NewReader(data[:4]) err := binary.Read(reader, binary.BigEndian, &packetLength) if err != nil { @@ -68,7 +276,11 @@ func (mysql *mysqlInterop) dumpClient(r io.Reader, id int, quiet bool, data []by commandName := comTypeMap[commandType] // parse sequence id - sequenceId := data[5] + if len(data) < 6 { + sequenceId = 0 + } else { + sequenceId = uint32(data[5]) + } // parse query var query []byte @@ -80,44 +292,12 @@ func (mysql *mysqlInterop) dumpClient(r io.Reader, id int, quiet bool, data []by } if utf8.Valid(query) { - display.PrintlnWithTime(fmt.Sprintf("[Client] %d-%s: %s", sequenceId, commandName, string(query))) + display.PrintlnWithTime(fmt.Sprintf("[Client -> Server] %d-%s: %s", sequenceId, commandName, string(query))) } else { display.PrintlnWithTime(color.HiRedString("Invalid Query %v", query)) } } -func (mysql *mysqlInterop) dumpServer(r io.Reader, id int, quiet bool, data []byte) { - header := make([]byte, 4) - _, err := r.Read(header) - if err != nil { - display.PrintlnWithTime(color.HiRedString("Error reading packet length: %v\n", err)) - return - } - - packetLength := int(binary.BigEndian.Uint16(header[:3])) - responseData := make([]byte, packetLength) - _, err = r.Read(responseData) - if err != nil { - display.PrintlnWithTime(color.HiRedString("Error reading packet data: %v\n", err)) - return - } - - // OK packet, value is 0x00 - // Error packet, value is 0xFF - responseType := data[0] - if responseType == 0x00 { - fmt.Println("OK packet", hex.Dump(responseData)) - } else if responseType == 0xff { - fmt.Println("Error packet", hex.Dump(responseData)) - } else if responseType == 0xfe { - fmt.Println("EOF packet", hex.Dump(responseData)) - } else if responseType > 0x00 && responseType < 0xfa { - fmt.Println("other packet", hex.Dump(responseData)) - } else { - display.PrintlnWithTime(color.HiRedString("invalid packet")) - } -} - func (mysql *mysqlInterop) Dump(r io.Reader, source string, id int, quiet bool) { buffer := make([]byte, bufferSize) for { @@ -129,7 +309,6 @@ func (mysql *mysqlInterop) Dump(r io.Reader, source string, id int, quiet bool) if n > 0 && !quiet { data := buffer[:n] - if source == "CLIENT" { mysql.dumpClient(r, id, quiet, data) } else {