Files
2025-06-18 16:02:45 +03:00

933 lines
23 KiB
Go

// Generic data manipulation utilities.
package main
import (
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"net"
"path/filepath"
"reflect"
"regexp"
"sort"
"strconv"
"strings"
"time"
"unicode"
"unicode/utf8"
"github.com/tinode/chat/server/auth"
"github.com/tinode/chat/server/store"
"github.com/tinode/chat/server/store/types"
"maps"
"golang.org/x/crypto/acme/autocert"
)
// Tag validation patterns.
var (
	// prefixedTagRegexp matches a tag with a namespace prefix, e.g. "basic:alice":
	//   - the prefix starts with a lowercase ASCII letter followed by 1 to 15 word
	//     characters (ASCII letters, digits, underscore), i.e. 2 to 16 characters in total;
	//   - the tag body is 1 to 96 characters long and may contain Unicode letters and
	//     numbers as well as the symbols +-.!?#@_
	//
	// Submatch 1 is the prefix, submatch 2 is the tag body.
	prefixedTagRegexp = regexp.MustCompile(`^([a-z]\w{1,15}):([-_+.!?#@\pL\pN]{1,96})$`)

	// tagRegexp matches a generic (unprefixed) tag: the same restrictions as the tag body above.
	tagRegexp = regexp.MustCompile(`^[-_+.!?#@\pL\pN]{1,96}$`)
)

// nullValue is the marker string (U+2421, SYMBOL FOR DELETE) which denotes a value to be
// cleared: mergeInterfaces and mergeMaps treat it as a request to remove the destination value.
const nullValue = "\u2421"
// rangeDeserialize converts database ranges (types.Range) into wire protocol
// ranges (MsgRange). A nil or empty input produces nil.
func rangeDeserialize(in []types.Range) []MsgRange {
	if len(in) == 0 {
		return nil
	}

	out := make([]MsgRange, len(in))
	for i := range in {
		out[i] = MsgRange{LowId: in[i].Low, HiId: in[i].Hi}
	}
	return out
}
// rangeSerialize converts wire protocol ranges (MsgRange) into database
// ranges (types.Range). A nil or empty input produces nil.
func rangeSerialize(in []MsgRange) []types.Range {
	if len(in) == 0 {
		return nil
	}

	out := make([]types.Range, len(in))
	for i := range in {
		out[i] = types.Range{Low: in[i].LowId, Hi: in[i].HiId}
	}
	return out
}
// stringSliceDelta extracts the slices of added and removed strings from two slices:
//
//	added        := newSlice - (oldSlice & newSlice) -- present in new but missing in old
//	removed      := oldSlice - (oldSlice & newSlice) -- present in old but missing in new
//	intersection := oldSlice & newSlice              -- present in both old and new
//
// When both inputs are non-empty the results are returned in sorted order.
// The inputs themselves are not modified: sorting is done on private copies.
// If one of the inputs is empty, the other one is returned as is (not copied).
func stringSliceDelta(rold, rnew []string) (added, removed, intersection []string) {
	switch {
	case len(rold) == 0 && len(rnew) == 0:
		return nil, nil, nil
	case len(rold) == 0:
		return rnew, nil, nil
	case len(rnew) == 0:
		return nil, rold, nil
	}

	// Sort copies so the caller's slices keep their original order.
	sortedOld := append([]string(nil), rold...)
	sortedNew := append([]string(nil), rnew...)
	sort.Strings(sortedOld)
	sort.Strings(sortedNew)

	// Merge-walk both sorted slices and classify each string.
	o, n := 0, 0
	for o < len(sortedOld) && n < len(sortedNew) {
		switch {
		case sortedOld[o] > sortedNew[n]:
			// Present in new, missing in old: added.
			added = append(added, sortedNew[n])
			n++
		case sortedOld[o] < sortedNew[n]:
			// Present in old, missing in new: removed.
			removed = append(removed, sortedOld[o])
			o++
		default:
			// Present in both.
			intersection = append(intersection, sortedOld[o])
			o++
			n++
		}
	}
	// Whatever is left in either slice has no counterpart in the other one.
	added = append(added, sortedNew[n:]...)
	removed = append(removed, sortedOld[o:]...)

	return added, removed, intersection
}
// normalizeCredentials processes credentials for correctness: it removes duplicate
// and unknown methods. A method is unknown if it has no entry in globals.validators.
// In case of duplicate methods only the first one satisfying valueRequired is kept.
// If valueRequired is true, only credentials with a non-empty Value are kept.
//
// The result preserves the order of the input. It is nil when the input is empty and
// a non-nil (possibly empty) slice otherwise.
func normalizeCredentials(creds []MsgCredClient, valueRequired bool) []MsgCredClient {
	if len(creds) == 0 {
		return nil
	}

	seen := make(map[string]bool, len(creds))
	out := make([]MsgCredClient, 0, len(creds))
	for i := range creds {
		c := &creds[i]
		if seen[c.Method] {
			// A credential with this method has already been accepted.
			continue
		}
		if _, known := globals.validators[c.Method]; !known {
			continue
		}
		if valueRequired && c.Value == "" {
			continue
		}
		seen[c.Method] = true
		out = append(out, *c)
	}
	return out
}
// credentialMethods returns the Method of every credential, in the same order as the input.
func credentialMethods(creds []MsgCredClient) []string {
	methods := make([]string, 0, len(creds))
	for i := range creds {
		methods = append(methods, creds[i].Method)
	}
	return methods
}
// msgOpts2storeOpts converts client-supplied query parameters (MsgGetOpts) into
// database query parameters (types.QueryOpt). A nil request yields nil.
func msgOpts2storeOpts(req *MsgGetOpts) *types.QueryOpt {
	if req == nil {
		return nil
	}

	return &types.QueryOpt{
		User:            types.ParseUserId(req.User),
		Topic:           req.Topic,
		IfModifiedSince: req.IfModifiedSince,
		Limit:           req.Limit,
		Since:           req.SinceId,
		Before:          req.BeforeId,
		IdRanges:        rangeSerialize(req.IdRanges),
	}
}
// isNullValue reports whether i holds a string equal to nullValue,
// i.e. the single character U+2421 (SYMBOL FOR DELETE).
// Any non-string value yields false.
func isNullValue(i any) bool {
	str, isString := i.(string)
	return isString && str == nullValue
}
// decodeStoreError converts a store error into a server response. It is a shorthand for
// decodeStoreErrorExplicitTs with an empty topic name and the same timestamp ts used both
// as the server timestamp and as the timestamp of the incoming request.
func decodeStoreError(err error, id string, ts time.Time, params map[string]any) *ServerComMessage {
	const noTopic = ""
	serverTs, incomingReqTs := ts, ts
	return decodeStoreErrorExplicitTs(err, id, noTopic, serverTs, incomingReqTs, params)
}
// decodeStoreErrorExplicitTs converts an error returned by the store into a server response.
//
//   - A nil error produces a success response.
//   - An error which is not (and does not wrap) a types.StoreError produces an "unknown error"
//     response.
//   - A types.StoreError is mapped to the corresponding response; unrecognized values
//     produce an "unknown error" response.
//
// For types.ErrRedirected the name of the topic to use instead is read from params["topic"];
// if params is nil or the value is missing or not a string, an empty name is used.
//
// If params is not nil, it is attached to the response as Ctrl.Params.
func decodeStoreErrorExplicitTs(err error, id, topic string, serverTs, incomingReqTs time.Time,
	params map[string]any) *ServerComMessage {

	var errmsg *ServerComMessage
	var storeErr types.StoreError

	if err == nil {
		errmsg = NoErrExplicitTs(id, topic, serverTs, incomingReqTs)
	} else if !errors.As(err, &storeErr) {
		// Not a store error, not even a wrapped one.
		errmsg = ErrUnknownExplicitTs(id, topic, serverTs, incomingReqTs)
	} else {
		switch storeErr {
		case types.ErrInternal:
			errmsg = ErrUnknownExplicitTs(id, topic, serverTs, incomingReqTs)
		case types.ErrMalformed:
			errmsg = ErrMalformedExplicitTs(id, topic, serverTs, incomingReqTs)
		case types.ErrFailed:
			errmsg = ErrAuthFailed(id, topic, serverTs, incomingReqTs)
		case types.ErrPermissionDenied:
			errmsg = ErrPermissionDeniedExplicitTs(id, topic, serverTs, incomingReqTs)
		case types.ErrDuplicate:
			errmsg = ErrDuplicateCredential(id, topic, serverTs, incomingReqTs)
		case types.ErrUnsupported:
			errmsg = ErrNotImplemented(id, topic, serverTs, incomingReqTs)
		case types.ErrExpired:
			errmsg = ErrAuthFailed(id, topic, serverTs, incomingReqTs)
		case types.ErrPolicy:
			errmsg = ErrPolicyExplicitTs(id, topic, serverTs, incomingReqTs)
		case types.ErrCredentials:
			errmsg = InfoValidateCredentialsExplicitTs(id, serverTs, incomingReqTs)
		case types.ErrUserNotFound:
			errmsg = ErrUserNotFound(id, topic, serverTs, incomingReqTs)
		case types.ErrTopicNotFound:
			errmsg = ErrTopicNotFound(id, topic, serverTs, incomingReqTs)
		case types.ErrNotFound:
			errmsg = ErrNotFoundExplicitTs(id, topic, serverTs, incomingReqTs)
		case types.ErrInvalidResponse:
			errmsg = ErrInvalidResponse(id, topic, serverTs, incomingReqTs)
		case types.ErrRedirected:
			// Checked assertion: a nil params map or a missing/non-string "topic"
			// yields an empty name instead of a panic.
			redirectTo, _ := params["topic"].(string)
			errmsg = InfoUseOther(id, topic, redirectTo, serverTs, incomingReqTs)
		default:
			errmsg = ErrUnknownExplicitTs(id, topic, serverTs, incomingReqTs)
		}
	}

	if params != nil {
		errmsg.Ctrl.Params = params
	}

	return errmsg
}
// selectAccessMode picks the access mode which corresponds to the given authentication level:
// anonMode for auth.LevelAnon, authMode for auth.LevelAuth, rootMode for auth.LevelRoot.
// auth.LevelNone and any unrecognized level yield types.ModeNone.
func selectAccessMode(authLvl auth.Level, anonMode, authMode, rootMode types.AccessMode) types.AccessMode {
	switch authLvl {
	case auth.LevelAnon:
		return anonMode
	case auth.LevelAuth:
		return authMode
	case auth.LevelRoot:
		return rootMode
	}
	// auth.LevelNone and everything else.
	return types.ModeNone
}
// getDefaultAccess returns the default modeWant for the given topic category.
// Users which are not authenticated (authUser is false) always get types.ModeNone.
// isChan affects group topics only: it selects the channel writer mode instead of the
// public mode. The function panics if the topic category is not one of the known ones.
func getDefaultAccess(cat types.TopicCat, authUser, isChan bool) types.AccessMode {
	if !authUser {
		return types.ModeNone
	}

	var mode types.AccessMode
	switch cat {
	case types.TopicCatP2P:
		mode = globals.typesModeCP2P
	case types.TopicCatFnd:
		mode = types.ModeNone
	case types.TopicCatGrp:
		mode = types.ModeCPublic
		if isChan {
			mode = types.ModeCChnWriter
		}
	case types.TopicCatMe:
		mode = types.ModeCMeFnd
	case types.TopicCatSlf:
		mode = types.ModeCSelf
	default:
		panic("Unknown topic category")
	}
	return mode
}
// parseTopicAccess parses topic access parameters.
//
// authMode and anonMode start as defAuth and defAnon; each one is replaced by the value parsed
// from acs.Auth or acs.Anon respectively when that string is not empty.
//
// Both strings are always parsed. If either of them is invalid the returned error is non-nil;
// when both are invalid the error for acs.Auth is reported.
func parseTopicAccess(acs *MsgDefaultAcsMode, defAuth, defAnon types.AccessMode) (authMode, anonMode types.AccessMode,
	err error) {

	authMode, anonMode = defAuth, defAnon

	if acs.Auth != "" {
		err = authMode.UnmarshalText([]byte(acs.Auth))
	}
	if acs.Anon != "" {
		// Do not let a successful parse of Anon hide a failure to parse Auth.
		if anonErr := anonMode.UnmarshalText([]byte(acs.Anon)); err == nil {
			err = anonErr
		}
	}

	return authMode, anonMode, err
}
// parseVersionPart parses one component of a semantic version string.
// Only the leading run of digits is used, anything after it is ignored: "34-rc5" gives 34.
// The result is 0 if the component cannot be parsed, is not positive or exceeds 0x1fff.
func parseVersionPart(vers string) int {
	digits := vers
	notDigit := func(r rune) bool { return !unicode.IsDigit(r) }
	if cut := strings.IndexFunc(vers, notDigit); cut > 0 {
		// Drop the non-numeric tail.
		digits = vers[:cut]
	}
	if digits == "" {
		return 0
	}

	num, err := strconv.Atoi(digits)
	if err != nil || num <= 0 || num > 0x1fff {
		return 0
	}
	return num
}
// parseVersion parses a semantic version string in one of the following formats:
//
//	1.2, 1.2abc, 1.2.3, 1.2.3-abc, v0.12.34-rc5
//
// The optional "v" prefix is ignored and at most three dot-separated parts are used.
// Unparseable parts are replaced with zeros.
// The result is packed as (major << 16) | (minor << 8) | patch.
func parseVersion(vers string) int {
	// Major, minor and patch; the parts missing from the string remain zero.
	var nums [3]int
	for i, part := range strings.SplitN(strings.TrimPrefix(vers, "v"), ".", 3) {
		nums[i] = parseVersionPart(part)
	}
	return (nums[0] << 16) | (nums[1] << 8) | nums[2]
}
// base10Version converts a packed version into a base-10 number: major*10000 + minor*100 + trailer,
// where major, minor and trailer are bits 16-23, 8-15 and 0-7 of the input respectively.
// Used by monitoring.
func base10Version(hex int) int64 {
	const byteMask = 0xFF

	major := (hex >> 16) & byteMask
	minor := (hex >> 8) & byteMask
	trailer := hex & byteMask

	return int64(major*10000 + minor*100 + trailer)
}
// versionToString formats a packed version as "major.minor", followed by "-trailer"
// when the lowest byte of the version is not zero.
func versionToString(vers int) string {
	major, minor, trailer := vers>>16, (vers>>8)&0xff, vers&0xff

	out := strconv.Itoa(major) + "." + strconv.Itoa(minor)
	if trailer == 0 {
		return out
	}
	return out + "-" + strconv.Itoa(trailer)
}
// Tag handling
// filterTags returns those of the given tags which have a valid prefix (namespace)
// listed in namespaces with the value true. Tags without a valid prefix are skipped.
// The result is nil if namespaces is empty or nothing matches.
// params: Tags to filter, namespaces to use as the filter.
func filterTags(tags []string, namespaces map[string]bool) []string {
	if len(namespaces) == 0 {
		return nil
	}

	var matched []string
	for _, tag := range tags {
		// m[0] is the whole tag, m[1] is the prefix.
		if m := prefixedTagRegexp.FindStringSubmatch(tag); len(m) >= 2 && namespaces[m[1]] {
			matched = append(matched, tag)
		}
	}
	return matched
}
// rewriteTag attempts to rewrite the original token using the configured validators.
// The tag is expected to be in lowercase.
//
// Returns:
//   - the original tag alone if it already has a valid prefix, e.g. basic:alice;
//   - the original tag and the rewritten tag if one of the validators with addToTags set
//     returns a non-empty tag from its PreCheck;
//   - the original tag alone if it is a valid generic tag;
//   - nil if the tag is invalid.
//
// TODO: consider inferring country code from user location.
func rewriteTag(orig, countryCode string) []string {
	if prefixedTagRegexp.MatchString(orig) {
		// Already prefixed, nothing to rewrite.
		return []string{orig}
	}

	// Check if the token can be rewritten by any of the validators.
	param := map[string]any{"countryCode": countryCode}
	for name, conf := range globals.validators {
		if !conf.addToTags {
			continue
		}
		if tag, _ := store.Store.GetValidator(name).PreCheck(orig, param); tag != "" {
			return []string{orig, tag}
		}
	}

	if !tagRegexp.MatchString(orig) {
		// Invalid generic tag.
		return nil
	}
	return []string{orig}
}
// rewriteTagSlice calls rewriteTag for every tag and returns a new slice with all
// original and converted values concatenated. Invalid tags contribute nothing.
func rewriteTagSlice(tags []string, countryCode string) []string {
	var result []string
	for i := range tags {
		// Appending an empty result is a no-op.
		result = append(result, rewriteTag(tags[i], countryCode)...)
	}
	return result
}
// restrictedTagsEqual checks if two sets of tags contain the same restricted tags, i.e. the
// tags selected by filterTags for the given namespaces: true - same, false - different.
// The order of tags is irrelevant.
func restrictedTagsEqual(oldTags, newTags []string, namespaces map[string]bool) bool {
	before := filterTags(oldTags, namespaces)
	after := filterTags(newTags, namespaces)
	if len(before) != len(after) {
		return false
	}

	// Compare position by position in sorted order.
	sort.Strings(before)
	sort.Strings(after)
	for i, tag := range before {
		if tag != after[i] {
			return false
		}
	}
	return true
}
// normalizeTags trims whitespace, converts tags to lowercase, sorts them, removes duplicates
// and tags which are too short, too long or do not start with a letter or a digit.
// At most maxTags first elements of src are considered.
//
// Note that src is modified in place (trimmed, lowercased and sorted).
//
// Returns nil if src is nil. If any of the tags is the null value, a non-nil empty slice is
// returned.
func normalizeTags(src []string, maxTags int) types.StringSlice {
	if src == nil {
		return nil
	}

	// Enforce the maximum before cleaning up. It may leave fewer tags than the maximum
	// after empty tags and duplicates are removed, but that's the client's fault.
	if len(src) > maxTags {
		src = src[:maxTags]
	}

	for i, tag := range src {
		src[i] = strings.ToLower(strings.TrimSpace(tag))
	}
	// Sorting places duplicates next to each other.
	sort.Strings(src)

	var (
		accepted []string
		last     string
	)
	for _, tag := range src {
		if isNullValue(tag) {
			// Return non-nil empty array.
			return types.StringSlice(make([]string, 0, 1))
		}

		// Length is enforced in characters, not in bytes.
		runes := []rune(tag)
		if tag == last || len(runes) < minTagLength || len(runes) > maxTagLength {
			continue
		}

		// The tag must start with a letter or a number.
		if first := runes[0]; unicode.IsLetter(first) || unicode.IsDigit(first) {
			accepted = append(accepted, tag)
			last = tag
		}
	}

	return types.StringSlice(accepted)
}
// validateTag checks the tag and splits it into the prefix and the value.
// Returns:
//   - prefix and value for a valid prefixed tag, e.g. basic:alice;
//   - an empty prefix and the tag itself for a valid unprefixed tag;
//   - two empty strings if the tag is invalid.
func validateTag(tag string) (string, string) {
	parts := prefixedTagRegexp.FindStringSubmatch(tag)
	switch {
	case len(parts) == 3:
		// Valid prefixed tag.
		return parts[1], parts[2]
	case tagRegexp.MatchString(tag):
		// Valid unprefixed tag (tag value only).
		return "", tag
	default:
		return "", ""
	}
}
// hasDuplicateNamespaceTags checks for duplication of unique NS tags: it reports true
// if src contains more than one valid tag with the prefix uniqueNS.
// Each namespace can have only one tag. This does not prevent tags from
// being duplicate across requests, just saves an extra DB call.
func hasDuplicateNamespaceTags(src []string, uniqueNS string) bool {
	seen := false
	for _, tag := range src {
		parts := prefixedTagRegexp.FindStringSubmatch(tag)
		if len(parts) != 3 || parts[1] != uniqueNS {
			// Invalid tag or a tag from another namespace, ignored.
			continue
		}
		if seen {
			return true
		}
		seen = true
	}
	return false
}
// parseSearchQuery is a parser for search queries. The query may contain non-ASCII
// characters, i.e. length of string in bytes != length of string in runes.
//
// Tokens are separated by whitespace (space or tab, the AND operator) or by commas
// (the OR operator). A token may be enclosed in double quotes, in which case it may contain
// spaces and commas; the enclosing quotes are removed. All tokens are converted to lowercase.
//
// Returns
//   - required tags: tokens which are not adjacent to a comma,
//   - optional tags: tokens which are preceded or followed by a comma,
//   - error if the query is malformed.
//
// An empty or whitespace-only query yields nil, nil, nil.
func parseSearchQuery(query string) ([]string, []string, error) {
	const (
		NONE = iota
		QUO  // 1
		AND  // 2
		OR   // 3
		END  // 4
		ORD  // 5
	)
	type token struct {
		op  int
		val string
	}
	type context struct {
		// Pre-token operand.
		preOp int
		// Post-token operand.
		postOp int
		// Inside quoted string.
		quo bool
		// Start of the current token.
		start int
		// End of the current token.
		end int
	}
	ctx := context{preOp: AND}
	var out []token
	var prev int
	query = strings.TrimSpace(query)
	if query == "" {
		// Nothing to parse. The tokenizer below indexes into the query when it emits
		// a token and must not be run on an empty string.
		return nil, nil, nil
	}
	// Split query into tokens.
	// i - byte index into the string.
	// pos - rune index into the string.
	// w - width of the current rune in bytes.
	for i, w, pos := 0, 0, 0; prev != END; i, pos = i+w, pos+1 {
		// Set when the previously accumulated token should be emitted.
		var emit bool
		// Lexer: get next rune.
		var r rune
		// Ordinary character by default.
		curr := ORD
		r, w = utf8.DecodeRuneInString(query[i:])
		switch {
		case w == 0:
			// Width zero: end of the string.
			curr = END
		case r == '"':
			// Quote opening or closing.
			curr = QUO
		case !ctx.quo:
			// Not inside quoted string, test for control characters.
			if r == ' ' || r == '\t' {
				// Tab or space.
				curr = AND
			} else if r == ',' {
				curr = OR
			}
		}
		if curr == QUO {
			if ctx.quo {
				// End of the quoted string. Close the quote.
				ctx.quo = false
			} else {
				if prev == ORD {
					// Reject strings like a"b
					return nil, nil, fmt.Errorf("missing operator at or near %d", pos)
				}
				// Start of the quoted string. Open the quote.
				ctx.quo = true
			}
			// Treat quoted string as ordinary.
			curr = ORD
		}
		// Parser: process the current lexem in context.
		switch curr {
		case OR:
			if ctx.postOp == OR {
				// More than one comma: ' , ,,'
				return nil, nil, fmt.Errorf("invalid operator sequence at or near %d", pos)
			}
			// Ensure context is not "and", i.e. the case like ' ,' -> ','
			ctx.postOp = OR
			if prev == ORD {
				// Close the current token.
				ctx.end = i
			}
		case AND:
			if prev == ORD {
				// Close the current token.
				ctx.end = i
				ctx.postOp = AND
			} else if ctx.postOp != OR {
				// "and" does not change the "or" context.
				ctx.postOp = AND
			}
		case ORD:
			if prev == OR || prev == AND {
				// Ordinary character after a comma or a space: ' a' or ',a'.
				// Emit without changing the operation.
				emit = true
			}
		case END:
			if prev == ORD {
				// Close the current token.
				ctx.end = i
			}
			emit = true
		}
		if emit {
			if ctx.quo && curr == END {
				return nil, nil, fmt.Errorf("unterminated quoted string at or near %d %#v", pos, ctx)
			}
			// Emit the new token.
			op := ctx.preOp
			if ctx.postOp == OR {
				op = OR
			}
			start, end := ctx.start, ctx.end
			// Strip the enclosing quotes.
			if query[start] == '"' && query[end-1] == '"' {
				start++
				end--
			}
			// Add token if non-empty.
			if start < end {
				out = append(out, token{val: strings.ToLower(query[start:end]), op: op})
			}
			ctx.start = i
			ctx.preOp, ctx.postOp = ctx.postOp, NONE
		}
		prev = curr
	}
	if len(out) == 0 {
		return nil, nil, nil
	}
	// Convert tokens to two string slices.
	var and []string
	var or []string
	for _, t := range out {
		switch t.op {
		case AND:
			and = append(and, t.val)
		case OR:
			or = append(or, t.val)
		}
	}
	return and, or, nil
}
// versionCompare compares two packed versions.
// Returns > 0 if v1 > v2; zero if equal; < 0 if v1 < v2.
// Only the major and minor parts are compared: the trailer (the lowest byte) is ignored.
func versionCompare(v1, v2 int) int {
	majorMinor1, majorMinor2 := v1>>8, v2>>8
	return majorMinor1 - majorMinor2
}
// max returns the greater of two integers.
func max(a, b int) int {
	if b > a {
		return b
	}
	return a
}
// truncateStringIfTooLong shortens the string to at most 1024 bytes and appends "..." if
// it is longer than that. Shorter strings are returned unchanged. Used in logging.
//
// The string is cut at a UTF-8 character boundary, so the result may be up to three bytes
// shorter than the limit, but a multi-byte character is never split in the middle.
func truncateStringIfTooLong(s string) string {
	const maxLen = 1024

	if len(s) <= maxLen {
		return s
	}

	// Back up to the start of the character which crosses the limit, if there is one.
	cut := maxLen
	for cut > maxLen-utf8.UTFMax && !utf8.RuneStart(s[cut]) {
		cut--
	}
	if !utf8.RuneStart(s[cut]) {
		// Not valid UTF-8 around the limit: fall back to cutting at the byte limit.
		cut = maxLen
	}

	return s[:cut] + "..."
}
// toAbsolutePath resolves path against base: an absolute path is returned as is,
// a relative one is joined with base and cleaned.
func toAbsolutePath(base, path string) string {
	if !filepath.IsAbs(path) {
		path = filepath.Clean(filepath.Join(base, path))
	}
	return path
}
// platformFromUA detects the client platform from the UserAgent string.
// The match is case-insensitive. Returns "ios", "android", "web" or an empty string
// if the platform is not recognized.
func platformFromUA(ua string) string {
	lower := strings.ToLower(ua)
	has := func(marker string) bool { return strings.Contains(lower, marker) }

	if has("reactnative") {
		// React Native clients are told apart by the device name only.
		if has("iphone") || has("ipad") {
			return "ios"
		}
		if has("android") {
			return "android"
		}
		return ""
	}

	// Native clients, checked in this order.
	for _, client := range []struct{ marker, platform string }{
		{"tinodejs", "web"},
		{"tindroid", "android"},
		{"tinodios", "ios"},
	} {
		if has(client.marker) {
			return client.platform
		}
	}
	return ""
}
// parseTLSConfig builds a TLS configuration from the JSON config.
//
// TLS is considered enabled if either tlsEnabled or the "enabled" field of the config is
// true; otherwise nil, nil is returned. When enabled, the function also updates
// globals.tlsStrictMaxAge (only if strict_max_age > 0) and globals.tlsRedirectHTTP.
//
// If the "autocert" section is present, the configuration is obtained from the autocert
// manager; otherwise the static certificate and key are loaded from cert_file and key_file.
func parseTLSConfig(tlsEnabled bool, jsconfig json.RawMessage) (*tls.Config, error) {
	type tlsAutocertConfig struct {
		// Domains to support by autocert
		Domains []string `json:"domains"`
		// Name of directory where auto-certificates are cached, e.g. /etc/letsencrypt/live/your-domain-here
		CertCache string `json:"cache"`
		// Contact email for letsencrypt
		Email string `json:"email"`
	}

	type tlsConfig struct {
		// Flag enabling TLS
		Enabled bool `json:"enabled"`
		// Listen for connections on this address:port and redirect them to HTTPS port.
		RedirectHTTP string `json:"http_redirect"`
		// Enable Strict-Transport-Security by setting max_age > 0
		StrictMaxAge int `json:"strict_max_age"`
		// ACME autocert config, e.g. letsencrypt.org
		Autocert *tlsAutocertConfig `json:"autocert"`
		// If Autocert is not defined, provide file names of static certificate and key
		CertFile string `json:"cert_file"`
		KeyFile  string `json:"key_file"`
	}

	var conf tlsConfig
	if jsconfig != nil {
		if err := json.Unmarshal(jsconfig, &conf); err != nil {
			return nil, errors.New("http: failed to parse tls_config: " + err.Error() + "(" + string(jsconfig) + ")")
		}
	}

	if enabled := tlsEnabled || conf.Enabled; !enabled {
		return nil, nil
	}

	if conf.StrictMaxAge > 0 {
		globals.tlsStrictMaxAge = strconv.Itoa(conf.StrictMaxAge)
	}
	globals.tlsRedirectHTTP = conf.RedirectHTTP

	// Autocert takes precedence over static keys.
	if ac := conf.Autocert; ac != nil {
		manager := &autocert.Manager{
			Prompt:     autocert.AcceptTOS,
			HostPolicy: autocert.HostWhitelist(ac.Domains...),
			Cache:      autocert.DirCache(ac.CertCache),
			Email:      ac.Email,
		}
		return manager.TLSConfig(), nil
	}

	// No autocert: use the static certificate and key.
	cert, err := tls.LoadX509KeyPair(conf.CertFile, conf.KeyFile)
	if err != nil {
		return nil, err
	}
	return &tls.Config{Certificates: []tls.Certificate{cert}}, nil
}
// mergeInterfaces merges the source value into the destination value.
//
//   - A nil src leaves dst as is.
//   - If src is a map[string]any, it is deep-merged into dst by mergeMaps (dst is treated
//     as an empty map if it is not a map[string]any).
//   - If src is a string equal to nullValue, the result is nil.
//   - Any other src replaces dst (shallow copy).
//
// Returns the resulting value and true if the value is considered changed.
func mergeInterfaces(dst, src any) (any, bool) {
	if src == nil {
		return dst, false
	}

	switch v := reflect.ValueOf(src); v.Kind() {
	case reflect.Map:
		if srcMap, ok := src.(map[string]any); ok {
			dstMap, _ := dst.(map[string]any)
			return mergeMaps(dstMap, srcMap)
		}
		// A map of any other type is replaced as a whole, see below.
	case reflect.String:
		if v.String() == nullValue {
			// Request to clear the value: a change only if there was something to clear.
			return nil, dst != nil
		}
	}

	return src, true
}
// mergeMaps merges the key-value pairs of src into dst.
//
//   - Values of type map[string]any are merged recursively.
//   - Maps of any other type and all other non-nil values are copied as is (shallow copy).
//   - A string value equal to nullValue removes the key from dst.
//   - nil values are ignored.
//
// If dst is nil and src is not empty, a new map is allocated.
// Returns the destination map and true if it is considered changed.
func mergeMaps(dst, src map[string]any) (map[string]any, bool) {
	if len(src) == 0 {
		return dst, false
	}
	if dst == nil {
		dst = make(map[string]any)
	}

	changed := false
	for key, val := range src {
		switch rv := reflect.ValueOf(val); rv.Kind() {
		case reflect.Map:
			nested, _ := val.(map[string]any)
			if nested != nil {
				// Deep-merge map[string]any.
				existing, _ := dst[key].(map[string]any)
				merged, nestedChanged := mergeMaps(existing, nested)
				dst[key] = merged
				changed = changed || nestedChanged
			} else if val != nil {
				// The map is shallow-copied if it's not of the type map[string]any.
				dst[key] = val
				changed = true
			}
		case reflect.String:
			// A string always counts as a change, including the removal of a key.
			changed = true
			if rv.String() == nullValue {
				delete(dst, key)
			} else if val != nil {
				dst[key] = val
			}
		default:
			if val != nil {
				dst[key] = val
				changed = true
			}
		}
	}

	return dst, changed
}
// copyMap makes a shallow copy of the map: keys and values are copied as is,
// nested values are shared with the source. The result is never nil.
func copyMap(src map[string]any) map[string]any {
	clone := make(map[string]any, len(src))
	maps.Copy(clone, src)
	return clone
}
// netListener creates net.Listener for tcp and unix domains:
// an address of the form "unix:/run/tinode.sock" is a unix socket with the path which
// follows the "unix:" prefix, anything else is treated as a TCP host:port.
func netListener(addr string) (net.Listener, error) {
	if socketPath, isUnix := strings.CutPrefix(addr, "unix:"); isUnix {
		return net.Listen("unix", socketPath)
	}
	return net.Listen("tcp", addr)
}
// isUnixAddr reports whether the address denotes a unix socket, i.e. starts with
// the "unix:" prefix like "unix:/run/tinode.sock".
func isUnixAddr(addr string) bool {
	return strings.HasPrefix(addr, "unix:")
}
// privateIPBlocks lists the address blocks reserved for private networks.
// It is built once at package initialization and is never modified afterwards,
// so it can be read by concurrent callers of isRoutableIP.
var privateIPBlocks = func() []*net.IPNet {
	cidrs := []string{
		"10.0.0.0/8",     // RFC1918
		"172.16.0.0/12",  // RFC1918
		"192.168.0.0/16", // RFC1918
		"fc00::/7",       // RFC4193, IPv6 unique local addr
	}
	blocks := make([]*net.IPNet, 0, len(cidrs))
	for _, cidr := range cidrs {
		_, block, err := net.ParseCIDR(cidr)
		if err != nil {
			// The list is a compile-time constant: a failure is a programming error.
			panic("invalid private network CIDR " + cidr + ": " + err.Error())
		}
		blocks = append(blocks, block)
	}
	return blocks
}()

// isRoutableIP reports whether the string is a valid IP address which is none of:
// loopback, link-local unicast, link-local multicast, or a member of one of
// the private network blocks listed in privateIPBlocks.
// A string which cannot be parsed as an IP address yields false.
func isRoutableIP(ipStr string) bool {
	ip := net.ParseIP(ipStr)
	if ip == nil {
		return false
	}

	if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
		return false
	}

	for _, block := range privateIPBlocks {
		if block.Contains(ip) {
			return false
		}
	}

	return true
}