gorm log with traceId 打印带有traceId信息的日志,通过context实现

背景

无论是单体项目,还是分布式项目,一个请求进来总会有一定的链路,单体项目中会调用各种方法,分布式服务中更麻烦一点,跨服务调用。于是乎,我们就希望有一个全局的traceId可以把一个请求过程中经过的所有链路的关键信息串联起来,这样的话在检索日志的时候可以带来极大的方便,根据traceId把整个链路上的日志全部打印出来。

在golang项目中,通用的写法是通过context实现traceId信息传递。那么gorm如何通过context把traceId传进去,以实现打印日志带上traceId信息呢?

我们得通过阅读源码来寻找这个问题的解决方案。

gorm源码解读

我们首先需要了解gorm日志打印是如何实现的,任意找一个sql执行方法进去,比如,查询的方法。

func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
	tx = db.getInstance()
	if len(conds) > 0 {
		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
			tx.Statement.AddClause(clause.Where{Exprs: exprs})
		}
	}
	tx.Statement.Dest = dest
	return tx.callbacks.Query().Execute(tx)
}

进一步寻找打印日志的逻辑,定位到Execute方法。在Execute方法中找到了打印日志的逻辑。

if stmt.SQL.Len() > 0 {
		db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
			sql, vars := stmt.SQL.String(), stmt.Vars
			if filter, ok := db.Logger.(ParamsFilter); ok {
				sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...)
			}
			return db.Dialector.Explain(sql, vars...), db.RowsAffected
		}, db.Error)
	}

到了这边,我们发现,日志打印调用的Trace方法的第一个传参是Context。所以,我们继续顺腾摸瓜看这个Context是通过什么方式传进来的。Context从db.Statement中获取的。所以,我们需要寻找给db.Statement赋值的方法。

func (db *DB) getInstance() *DB {
	if db.clone > 0 {
		tx := &DB{Config: db.Config, Error: db.Error}

		if db.clone == 1 {
			// clone with new statement
			tx.Statement = &Statement{
				DB:        tx,
				ConnPool:  db.Statement.ConnPool,
				Context:   db.Statement.Context,
				Clauses:   map[string]clause.Clause{},
				Vars:      make([]interface{}, 0, 8),
				SkipHooks: db.Statement.SkipHooks,
			}
		} else {
			// with clone statement
			tx.Statement = db.Statement.clone()
			tx.Statement.DB = tx
		}

		return tx
	}

	return db
}

然后,我们就在WithContext的方法中找到了把context传递进来的入口。

func (db *DB) WithContext(ctx context.Context) *DB {
	return db.Session(&Session{Context: ctx})
}

传Context的入口找到了,那么,gorm中如何根据context中自定义值打印日志呢?比如,Context中塞了自定义的traceId的key,value值?

我们回到前面打印日志的地方,看打印日志的方法,打印日志的Trace方法是这个接口下的一个方法。

type Interface interface {
	LogMode(LogLevel) Interface
	Info(context.Context, string, ...interface{})
	Warn(context.Context, string, ...interface{})
	Error(context.Context, string, ...interface{})
	Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error)
}
func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
	if l.LogLevel <= Silent {
		return
	}

	elapsed := time.Since(begin)
	switch {
	case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError):
		sql, rows := fc()
		if rows == -1 {
			l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql)
		} else {
			l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql)
		}
	case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn:
		sql, rows := fc()
		slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold)
		if rows == -1 {
			l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql)
		} else {
			l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql)
		}
	case l.LogLevel == Info:
		sql, rows := fc()
		if rows == -1 {
			l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql)
		} else {
			l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
		}
	}
}

定位到trace方法中,我们发现并没有处理Context,其实很正常。

所以,我们需要重写这个Trace方法,自定义一个log对象,实现gorm的log接口。

解决方案

直接上代码。

package main

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"go.uber.org/zap"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
	"gorm.io/gorm/utils"

	"gorm.io/driver/mysql"
)

func main() {
	zapL, err := zap.NewProduction()
	if err != nil {
		panic(err)
	}
	log := New(zapL,
		WithCustomFields(
			String("timeStamp", time.Now().Format("2006-01-02 15:04:05")),
			func(ctx context.Context) zap.Field {
				v := ctx.Value("requestId")
				if v == nil {
					return zap.Skip()
				}
				if vv, ok := v.(string); ok {
					return zap.String("trace", vv)
				}
				return zap.Skip()
			},
			func(ctx context.Context) zap.Field {
				v := ctx.Value("method")
				if v == nil {
					return zap.Skip()
				}
				if vv, ok := v.(string); ok {
					return zap.String("method", vv)
				}
				return zap.Skip()
			},
		),
		WithConfig(logger.Config{
			SlowThreshold:             200 * time.Millisecond,
			Colorful:                  false,
			IgnoreRecordNotFoundError: false,
			LogLevel:                  logger.Info,
		}),
	)
	mysqlConfig := mysql.Config{
		DSN:                       "*******", // DSN data source name
		DefaultStringSize:         191,                                                                                                // string 类型字段的默认长度
		SkipInitializeWithVersion: false,                                                                                              // 根据版本自动配置
	}
	// your dialector
	db, _ := gorm.Open(mysql.New(mysqlConfig), &gorm.Config{Logger: log})
	// do your things
	result := make(map[string]interface{})
	ctx := context.WithValue(context.Background(), "method", "method")
	db.WithContext(context.WithValue(ctx, "requestId", "requestId123456")).Table("privacy_detail").Find(&result)
	db.WithContext(context.WithValue(context.Background(), "requestId", "requestId123457")).Table("privacy_detail").Find(&result)
	db.WithContext(context.WithValue(context.Background(), "requestId", "requestId123458")).Table("privacy_detail").Create(&result)
	log.Info(context.WithValue(context.Background(), "requestId", "requestId123456"), "msg", "args")
}

// Logger logger for gorm2
type Logger struct {
	log *zap.Logger
	logger.Config
	customFields []func(ctx context.Context) zap.Field
}

// Option logger/recover option
type Option func(l *Logger)

// WithCustomFields optional custom field
func WithCustomFields(fields ...func(ctx context.Context) zap.Field) Option {
	return func(l *Logger) {
		l.customFields = fields
	}
}

// WithConfig optional custom logger.Config
func WithConfig(cfg logger.Config) Option {
	return func(l *Logger) {
		l.Config = cfg
	}
}

// SetGormDBLogger set db logger
func SetGormDBLogger(db *gorm.DB, l logger.Interface) {
	db.Logger = l
}

// New logger form gorm2
func New(zapLogger *zap.Logger, opts ...Option) logger.Interface {
	l := &Logger{
		log: zapLogger,
		Config: logger.Config{
			SlowThreshold:             200 * time.Millisecond,
			Colorful:                  false,
			IgnoreRecordNotFoundError: false,
			LogLevel:                  logger.Warn,
		},
	}
	for _, opt := range opts {
		opt(l)
	}
	return l
}

// LogMode log mode
func (l *Logger) LogMode(level logger.LogLevel) logger.Interface {
	newLogger := *l
	newLogger.LogLevel = level
	return &newLogger
}

// Info print info
func (l Logger) Info(ctx context.Context, msg string, args ...interface{}) {
	if l.LogLevel >= logger.Info {
		//预留10个字段位置
		fields := make([]zap.Field, 0, 10+len(l.customFields))
		fields = append(fields, zap.String("file", utils.FileWithLineNum()))
		for _, customField := range l.customFields {
			fields = append(fields, customField(ctx))
		}
		for _, arg := range args {
			if vv, ok := arg.(zapcore.Field); ok {
				if len(vv.String) > 0 {
					fields = append(fields, zap.String(vv.Key, vv.String))
				} else if vv.Integer > 0 {
					fields = append(fields, zap.Int64(vv.Key, vv.Integer))
				} else {
					fields = append(fields, zap.Any(vv.Key, vv.Interface))
				}
			}
		}
		l.log.Info(msg, fields...)
	}
}

// Warn print warn messages
func (l Logger) Warn(ctx context.Context, msg string, args ...interface{}) {
	if l.LogLevel >= logger.Warn {
		//预留10个字段位置
		fields := make([]zap.Field, 0, 10+len(l.customFields))
		fields = append(fields, zap.String("file", utils.FileWithLineNum()))
		for _, customField := range l.customFields {
			fields = append(fields, customField(ctx))
		}
		for _, arg := range args {
			if vv, ok := arg.(zapcore.Field); ok {
				if len(vv.String) > 0 {
					fields = append(fields, zap.String(vv.Key, vv.String))
				} else if vv.Integer > 0 {
					fields = append(fields, zap.Int64(vv.Key, vv.Integer))
				} else {
					fields = append(fields, zap.Any(vv.Key, vv.Interface))
				}
			}
		}
		l.log.Warn(msg, fields...)
	}
}

// Error print error messages
func (l Logger) Error(ctx context.Context, msg string, args ...interface{}) {
	if l.LogLevel >= logger.Error {
		//预留10个字段位置
		fields := make([]zap.Field, 0, 10+len(l.customFields))
		fields = append(fields, zap.String("file", utils.FileWithLineNum()))
		for _, customField := range l.customFields {
			fields = append(fields, customField(ctx))
		}

		for _, arg := range args {
			if vv, ok := arg.(zapcore.Field); ok {
				if len(vv.String) > 0 {
					fields = append(fields, zap.String(vv.Key, vv.String))
				} else if vv.Integer > 0 {
					fields = append(fields, zap.Int64(vv.Key, vv.Integer))
				} else {
					fields = append(fields, zap.Any(vv.Key, vv.Interface))
				}
			}
		}
		l.log.Error(msg, fields...)
	}
}

// Trace print sql message
func (l Logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
	if l.LogLevel <= logger.Silent {
		return
	}
	fields := make([]zap.Field, 0, 6+len(l.customFields))
	elapsed := time.Since(begin)
	switch {
	case err != nil && l.LogLevel >= logger.Error && (!l.IgnoreRecordNotFoundError || !errors.Is(err, gorm.ErrRecordNotFound)):
		for _, customField := range l.customFields {
			fields = append(fields, customField(ctx))
		}
		fields = append(fields,
			zap.Error(err),
			zap.String("file", utils.FileWithLineNum()),
			zap.Duration("latency", elapsed),
		)

		sql, rows := fc()
		if rows == -1 {
			fields = append(fields, zap.String("rows", "-"))
		} else {
			fields = append(fields, zap.Int64("rows", rows))
		}
		fields = append(fields, zap.String("sql", sql))
		l.log.Error("", fields...)
	case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= logger.Warn:
		for _, customField := range l.customFields {
			fields = append(fields, customField(ctx))
		}
		fields = append(fields,
			zap.Error(err),
			zap.String("file", utils.FileWithLineNum()),
			zap.String("slow!!!", fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold)),
			zap.Duration("latency", elapsed),
		)

		sql, rows := fc()
		if rows == -1 {
			fields = append(fields, zap.String("rows", "-"))
		} else {
			fields = append(fields, zap.Int64("rows", rows))
		}
		fields = append(fields, zap.String("sql", sql))
		l.log.Warn("", fields...)
	case l.LogLevel == logger.Info:
		for _, customField := range l.customFields {
			fields = append(fields, customField(ctx))
		}
		fields = append(fields,
			zap.Error(err),
			zap.String("file", utils.FileWithLineNum()),
			zap.Duration("latency", elapsed),
		)

		sql, rows := fc()
		if rows == -1 {
			fields = append(fields, zap.String("rows", "-"))
		} else {
			fields = append(fields, zap.Int64("rows", rows))
		}
		fields = append(fields, zap.String("sql", sql))
		l.log.Info("", fields...)
	}
}

// Immutable custom immutable field
// Deprecated: use Any instead
func Immutable(key string, value interface{}) func(ctx context.Context) zap.Field {
	return Any(key, value)
}

// Any custom immutable any field
func Any(key string, value interface{}) func(ctx context.Context) zap.Field {
	field := zap.Any(key, value)
	return func(ctx context.Context) zap.Field { return field }
}

// String custom immutable string field
func String(key string, value string) func(ctx context.Context) zap.Field {
	field := zap.String(key, value)
	return func(ctx context.Context) zap.Field { return field }
}

// Int64 custom immutable int64 field
func Int64(key string, value int64) func(ctx context.Context) zap.Field {
	field := zap.Int64(key, value)
	return func(ctx context.Context) zap.Field { return field }
}

// Uint64 custom immutable uint64 field
func Uint64(key string, value uint64) func(ctx context.Context) zap.Field {
	field := zap.Uint64(key, value)
	return func(ctx context.Context) zap.Field { return field }
}

// Float64 custom immutable float32 field
func Float64(key string, value float64) func(ctx context.Context) zap.Field {
	field := zap.Float64(key, value)
	return func(ctx context.Context) zap.Field { return field }
}

自定义结构体

// Logger logger for gorm2

type Logger struct {

    log *zap.Logger

    logger.Config

    customFields []func(ctx context.Context) zap.Field

}

关键在于 customFields定义了一个接受传Context参数的方法。在初始化日志的地方,传从Context中获取对应参数的函数,比如,从context中接受traceId。

由此,gorm log with traceId目的实现。

最近更新

  1. TCP协议是安全的吗?

    2024-05-16 11:50:20       19 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-05-16 11:50:20       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-05-16 11:50:20       20 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-05-16 11:50:20       20 阅读

热门阅读

  1. 解密 Unix 中的 “rc“ 后缀:自定义你的工作环境

    2024-05-16 11:50:20       10 阅读
  2. oracle 临时表 在sql 里面用完要删除吗

    2024-05-16 11:50:20       11 阅读
  3. 简单上手SpringBean的整个装配过程

    2024-05-16 11:50:20       13 阅读
  4. Oracle 数据块之变化时的SCN

    2024-05-16 11:50:20       12 阅读
  5. bert 的MLM框架任务-梯度累积

    2024-05-16 11:50:20       14 阅读