背景
无论是单体项目,还是分布式项目,一个请求进来总会有一定的链路,单体项目中会调用各种方法,分布式服务中更麻烦一点,跨服务调用。于是乎,我们就希望有一个全局的traceId可以把一个请求过程中经过的所有链路的关键信息串联起来,这样的话在检索日志的时候可以带来极大的方便,根据traceId把整个链路上的日志全部打印出来。
在golang项目中,通用的写法是通过context实现traceId信息传递。那么gorm如何通过context把traceId传进去,以实现打印日志带上traceId信息呢?
我们得通过阅读源码来寻找这个问题的解决方案。
gorm源码解读
我们首先需要了解gorm日志打印是如何实现的,任意找一个sql执行方法进去,比如,查询的方法。
// Find (excerpt from gorm's source) loads the records matching conds into dest.
// It obtains a working *DB via getInstance, turns any inline conditions into a
// WHERE clause, stores dest on the statement and runs the registered Query
// callbacks through Execute, which is where the SQL log line is produced.
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance()
// Inline conditions: conds[0] is the condition, conds[1:] its arguments.
if len(conds) > 0 {
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
}
tx.Statement.Dest = dest
return tx.callbacks.Query().Execute(tx)
}
进一步寻找打印日志的逻辑,定位到Execute方法。在Execute方法中找到了打印日志的逻辑。
// Excerpt from gorm's callback Execute: once the statement has generated SQL,
// the configured Logger's Trace is called with the statement's Context,
// curTime (received by Trace as its begin time), a callback that renders the
// SQL and returns RowsAffected, and the statement's error.
// The callback is only invoked by Trace if it actually decides to log.
if stmt.SQL.Len() > 0 {
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
sql, vars := stmt.SQL.String(), stmt.Vars
// A logger that also implements ParamsFilter may rewrite the SQL and its
// vars before they are rendered for the log.
if filter, ok := db.Logger.(ParamsFilter); ok {
sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...)
}
// Dialector.Explain combines sql and vars into the text that gets logged.
return db.Dialector.Explain(sql, vars...), db.RowsAffected
}, db.Error)
}
到了这里,我们发现,日志打印调用的Trace方法的第一个参数就是Context。所以,我们继续顺藤摸瓜,看这个Context是通过什么方式传进来的。可以看到,Context是从db.Statement中获取的,因此我们需要寻找给db.Statement赋值的方法。
// getInstance (gorm internal, excerpt) returns the *DB that a chainable method
// should operate on.
//
// When db.clone > 0 a new *DB sharing db.Config and db.Error is built:
//   - clone == 1: it gets a fresh Statement that keeps only the connection
//     pool, the Context and SkipHooks of the current Statement;
//   - otherwise: it gets a copy of the current Statement, re-pointed at the
//     new *DB (Statement.clone presumably copies Context too — verify in gorm).
//
// When clone is 0 the receiver itself is returned.
// Carrying Statement.Context over is what lets a context set earlier (e.g. via
// WithContext) reach db.Logger.Trace.
func (db *DB) getInstance() *DB {
if db.clone > 0 {
tx := &DB{Config: db.Config, Error: db.Error}
if db.clone == 1 {
// clone with new statement
tx.Statement = &Statement{
DB: tx,
ConnPool: db.Statement.ConnPool,
Context: db.Statement.Context,
Clauses: map[string]clause.Clause{},
Vars: make([]interface{}, 0, 8),
SkipHooks: db.Statement.SkipHooks,
}
} else {
// with clone statement
tx.Statement = db.Statement.clone()
tx.Statement.DB = tx
}
return tx
}
return db
}
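可以看到,getInstance在创建新的DB实例时,会把原Statement上的Context一并带过去(见 Context: db.Statement.Context 这一行)。那么最初的Context是从哪里设置进来的呢?继续查找,我们在WithContext方法中找到了把context传递进来的入口: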
// WithContext (gorm source, excerpt) returns a new session of db that carries
// ctx. It simply delegates to Session with Session.Context set to ctx.
// Session presumably installs ctx as the new Statement's Context — confirm in
// gorm's Session implementation. That Statement.Context is what Execute later
// passes to Logger.Trace.
func (db *DB) WithContext(ctx context.Context) *DB {
return db.Session(&Session{Context: ctx})
}
传Context的入口找到了,那么,gorm中如何根据context中自定义值打印日志呢?比如,Context中塞了自定义的traceId的key,value值?
我们回到前面打印日志的地方。打印日志所调用的Trace方法,是gorm的logger包中Interface接口定义的一个方法:
// Interface (gorm/logger, excerpt) is the contract every gorm logger must
// satisfy. All logging methods receive a context.Context, which is what makes
// per-request values such as a trace ID available to a custom implementation.
type Interface interface {
// LogMode returns a logger that uses the given level.
LogMode(LogLevel) Interface
// Info, Warn and Error log a message with optional extra arguments.
Info(context.Context, string, ...interface{})
Warn(context.Context, string, ...interface{})
Error(context.Context, string, ...interface{})
// Trace is called for executed SQL: begin is the start time, fc renders the
// SQL and reports the affected row count, err is the execution error.
Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error)
}
// Trace (gorm's default logger, excerpt) prints at most one line per executed
// SQL statement via l.Printf. Nothing is printed when LogLevel <= Silent;
// otherwise the first matching case wins:
//  1. err != nil, LogLevel >= Error, and err is not an ErrRecordNotFound that
//     IgnoreRecordNotFoundError says to skip;
//  2. elapsed exceeds a non-zero SlowThreshold and LogLevel >= Warn;
//  3. LogLevel == Info.
//
// ctx is accepted but never read, so values stored in it (such as a trace ID)
// never appear in the output. fc is only called inside a matching case.
func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
if l.LogLevel <= Silent {
return
}
elapsed := time.Since(begin)
switch {
case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError):
sql, rows := fc()
// rows == -1 is printed as "-"; elapsed is printed in milliseconds.
if rows == -1 {
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn:
sql, rows := fc()
slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold)
if rows == -1 {
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
case l.LogLevel == Info:
sql, rows := fc()
if rows == -1 {
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
}
}
查看默认实现的Trace方法可以发现,它并没有使用传入的Context。这其实很正常:默认logger并不知道业务在Context中放了哪些值,自然也无法把它们打印出来。
所以,我们需要自定义一个logger对象,实现gorm的logger.Interface接口,并在它的Trace(以及Info/Warn/Error)方法中从Context读取traceId等信息。
解决方案
直接上代码。
package main
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/utils"
"gorm.io/driver/mysql"
)
// main demonstrates wiring the zap-backed gorm logger so that every SQL log
// entry carries values taken from the request context (a request/trace ID and
// a method name) plus the time the entry was written.
func main() {
	zapL, err := zap.NewProduction()
	if err != nil {
		panic(err)
	}
	// Flush buffered entries before exit; a failed flush at exit is not
	// actionable in this demo, so its error is ignored.
	defer func() { _ = zapL.Sync() }()

	// Context keys use an unexported named type instead of a bare string so
	// they cannot collide with keys set by other packages (context.WithValue
	// guidance; staticcheck SA1029).
	type ctxKey string
	const (
		requestIDKey ctxKey = "requestId"
		methodKey    ctxKey = "method"
	)

	log := New(zapL,
		WithCustomFields(
			// Built per entry: a field created once at start-up would stamp every
			// log line with the program's start time.
			func(context.Context) zap.Field {
				return zap.String("timeStamp", time.Now().Format("2006-01-02 15:04:05"))
			},
			func(ctx context.Context) zap.Field {
				if v, ok := ctx.Value(requestIDKey).(string); ok {
					return zap.String("trace", v)
				}
				return zap.Skip()
			},
			func(ctx context.Context) zap.Field {
				if v, ok := ctx.Value(methodKey).(string); ok {
					return zap.String("method", v)
				}
				return zap.Skip()
			},
		),
		WithConfig(logger.Config{
			SlowThreshold:             200 * time.Millisecond,
			Colorful:                  false,
			IgnoreRecordNotFoundError: false,
			LogLevel:                  logger.Info,
		}),
	)

	mysqlConfig := mysql.Config{
		DSN:                       "*******", // DSN data source name
		DefaultStringSize:         191,       // default length for string columns
		SkipInitializeWithVersion: false,     // auto-configure based on the server version
	}
	// your dialector
	db, err := gorm.Open(mysql.New(mysqlConfig), &gorm.Config{Logger: log})
	if err != nil {
		// Stop on connection failure instead of issuing queries against an
		// unusable handle.
		panic(err)
	}

	// do your things. Query errors are reported through the logger's Trace;
	// this demo does not inspect them further.
	result := make(map[string]interface{})
	ctx := context.WithValue(context.Background(), methodKey, "method")
	db.WithContext(context.WithValue(ctx, requestIDKey, "requestId123456")).Table("privacy_detail").Find(&result)
	db.WithContext(context.WithValue(context.Background(), requestIDKey, "requestId123457")).Table("privacy_detail").Find(&result)
	db.WithContext(context.WithValue(context.Background(), requestIDKey, "requestId123458")).Table("privacy_detail").Create(&result)
	log.Info(context.WithValue(context.Background(), requestIDKey, "requestId123456"), "msg", zap.String("args", "args"))
}
// Logger is a gorm (v2) logger.Interface implementation that writes through
// zap. Unlike gorm's default logger it uses the context.Context passed to each
// logging call: every function in customFields is called with that context and
// its result is added to the entry as a zap field (e.g. a trace ID).
type Logger struct {
// log is the zap logger every entry is written to.
log *zap.Logger
// Config holds the embedded gorm settings; LogLevel, SlowThreshold and
// IgnoreRecordNotFoundError decide what gets logged.
logger.Config
// customFields derive extra fields from the context of each logging call.
customFields []func(ctx context.Context) zap.Field
}
// Option configures a Logger. New applies options in order after filling in
// its defaults, so a later option overrides an earlier one.
type Option func(l *Logger)
// WithCustomFields returns an Option that sets the functions used to derive
// extra zap fields from the context of each logging call (for example a trace
// ID stored in ctx). Each function is called once per written entry with that
// entry's context; it may return zap.Skip() to contribute nothing.
//
// The functions are copied, so later changes to a slice passed as fs... do not
// affect the logger. Nil entries are dropped because calling them would panic
// on every log call. Applying the option replaces any previously set functions.
func WithCustomFields(fields ...func(ctx context.Context) zap.Field) Option {
	own := make([]func(ctx context.Context) zap.Field, 0, len(fields))
	for _, f := range fields {
		if f != nil {
			own = append(own, f)
		}
	}
	return func(l *Logger) {
		l.customFields = own
	}
}
// WithConfig returns an Option that replaces the logger's entire gorm
// logger.Config (log level, slow-query threshold, record-not-found handling,
// colour flag) with cfg.
func WithConfig(cfg logger.Config) Option {
	apply := func(target *Logger) {
		target.Config = cfg
	}
	return apply
}
// SetGormDBLogger installs l as the logger of db by assigning it to the
// Logger field of db's embedded gorm configuration.
func SetGormDBLogger(db *gorm.DB, l logger.Interface) {
	db.Config.Logger = l
}
// New builds a gorm (v2) logger.Interface backed by zapLogger.
//
// Without options the logger logs at logger.Warn level, treats statements
// slower than 200ms as slow, is not colourful and does not ignore
// ErrRecordNotFound. The options are then applied in the order given.
func New(zapLogger *zap.Logger, opts ...Option) logger.Interface {
	defaults := logger.Config{
		SlowThreshold:             200 * time.Millisecond,
		Colorful:                  false,
		IgnoreRecordNotFoundError: false,
		LogLevel:                  logger.Warn,
	}
	gl := &Logger{Config: defaults, log: zapLogger}
	for i := range opts {
		opts[i](gl)
	}
	return gl
}
// LogMode returns a copy of l whose LogLevel is level; l itself is not
// modified. The copy shares l's zap logger and custom field functions.
func (l *Logger) LogMode(level logger.LogLevel) logger.Interface {
	clone := new(Logger)
	*clone = *l
	clone.LogLevel = level
	return clone
}
// Info logs msg at info level when LogLevel >= logger.Info.
//
// Following gorm's logger.Interface contract, msg may be a printf-style format
// string: arguments that are zap.Field values are attached unchanged as
// structured fields, and all other arguments are used to format msg. Every
// entry also carries the caller's "file" location and the fields produced by
// the custom field functions for ctx.
func (l Logger) Info(ctx context.Context, msg string, args ...interface{}) {
	if l.LogLevel < logger.Info {
		return
	}
	// FileWithLineNum reports a location relative to the frame that calls it,
	// so it is evaluated here rather than inside buildEntry.
	text, fields := l.buildEntry(ctx, utils.FileWithLineNum(), msg, args)
	l.log.Info(text, fields...)
}

// Warn logs msg at warn level when LogLevel >= logger.Warn. Arguments are
// handled as described on Info.
func (l Logger) Warn(ctx context.Context, msg string, args ...interface{}) {
	if l.LogLevel < logger.Warn {
		return
	}
	text, fields := l.buildEntry(ctx, utils.FileWithLineNum(), msg, args)
	l.log.Warn(text, fields...)
}

// Error logs msg at error level when LogLevel >= logger.Error. Arguments are
// handled as described on Info.
func (l Logger) Error(ctx context.Context, msg string, args ...interface{}) {
	if l.LogLevel < logger.Error {
		return
	}
	text, fields := l.buildEntry(ctx, utils.FileWithLineNum(), msg, args)
	l.log.Error(text, fields...)
}

// buildEntry assembles the message and fields shared by Info, Warn and Error.
// The fields are the "file" field, then the custom context fields, then any
// zap.Field arguments passed through as-is. Passing fields through keeps their
// type: rebuilding them from the raw zapcore.Field members would mangle ints
// <= 0, bools, durations and floats. The remaining arguments, if any, format msg.
func (l Logger) buildEntry(ctx context.Context, file, msg string, args []interface{}) (string, []zap.Field) {
	fields := make([]zap.Field, 0, 1+len(l.customFields)+len(args))
	fields = append(fields, zap.String("file", file))
	for _, customField := range l.customFields {
		fields = append(fields, customField(ctx))
	}
	var fmtArgs []interface{}
	for _, arg := range args {
		if f, ok := arg.(zap.Field); ok {
			fields = append(fields, f)
			continue
		}
		fmtArgs = append(fmtArgs, arg)
	}
	if len(fmtArgs) > 0 {
		msg = fmt.Sprintf(msg, fmtArgs...)
	}
	return msg, fields
}
// Trace logs one executed SQL statement. gorm calls it with the statement's
// context, the time execution began, a callback that renders the SQL and
// reports the affected row count, and the execution error, if any.
//
// Nothing is logged when LogLevel <= logger.Silent. Otherwise at most one
// entry is written, and the first matching case wins:
//   - error level: err != nil, LogLevel >= logger.Error, and err is not an
//     ErrRecordNotFound that IgnoreRecordNotFoundError says to skip;
//   - warn level: the elapsed time exceeds a non-zero SlowThreshold and
//     LogLevel >= logger.Warn (a "slow!!!" marker field is added);
//   - info level: LogLevel == logger.Info.
//
// The entry message is empty; all details are structured fields. When no case
// matches, fc is never called and no field slice is built.
func (l Logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
	if l.LogLevel <= logger.Silent {
		return
	}
	elapsed := time.Since(begin)
	// utils.FileWithLineNum is evaluated directly in Trace in every case,
	// because the location it reports depends on the frame that calls it.
	switch {
	case err != nil && l.LogLevel >= logger.Error && (!l.IgnoreRecordNotFoundError || !errors.Is(err, gorm.ErrRecordNotFound)):
		l.log.Error("", l.traceFields(ctx, utils.FileWithLineNum(), elapsed, fc, err)...)
	case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= logger.Warn:
		slow := zap.String("slow!!!", fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold))
		l.log.Warn("", l.traceFields(ctx, utils.FileWithLineNum(), elapsed, fc, err, slow)...)
	case l.LogLevel == logger.Info:
		l.log.Info("", l.traceFields(ctx, utils.FileWithLineNum(), elapsed, fc, err)...)
	}
}

// traceFields assembles the fields of one Trace entry in a fixed order:
// custom context fields, the error, the caller file, any extra fields (the
// slow-query marker), latency, affected rows ("-" when fc reports -1) and the
// SQL text. fc is called exactly once, here.
func (l Logger) traceFields(ctx context.Context, file string, elapsed time.Duration, fc func() (string, int64), err error, extra ...zap.Field) []zap.Field {
	fields := make([]zap.Field, 0, len(l.customFields)+len(extra)+5)
	for _, customField := range l.customFields {
		fields = append(fields, customField(ctx))
	}
	fields = append(fields, zap.Error(err), zap.String("file", file))
	fields = append(fields, extra...)
	fields = append(fields, zap.Duration("latency", elapsed))
	sql, rows := fc()
	if rows == -1 {
		fields = append(fields, zap.String("rows", "-"))
	} else {
		fields = append(fields, zap.Int64("rows", rows))
	}
	return append(fields, zap.String("sql", sql))
}
// Immutable returns a custom-field function that always yields the same
// zap.Any(key, value) field, whatever context it is called with.
//
// Deprecated: use Any instead.
func Immutable(key string, value interface{}) func(ctx context.Context) zap.Field {
	return Any(key, value)
}
// Any returns a custom-field function that always yields zap.Any(key, value).
// The field is built once, when Any is called; the context passed to the
// returned function is ignored.
func Any(key string, value interface{}) func(ctx context.Context) zap.Field {
	fixed := zap.Any(key, value)
	return func(context.Context) zap.Field {
		return fixed
	}
}
// String returns a custom-field function that always yields
// zap.String(key, value). The field is built once, when String is called; the
// context passed to the returned function is ignored.
func String(key string, value string) func(ctx context.Context) zap.Field {
	fixed := zap.String(key, value)
	return func(context.Context) zap.Field {
		return fixed
	}
}
// Int64 returns a custom-field function that always yields
// zap.Int64(key, value). The field is built once, when Int64 is called; the
// context passed to the returned function is ignored.
func Int64(key string, value int64) func(ctx context.Context) zap.Field {
	fixed := zap.Int64(key, value)
	return func(context.Context) zap.Field {
		return fixed
	}
}
// Uint64 returns a custom-field function that always yields
// zap.Uint64(key, value). The field is built once, when Uint64 is called; the
// context passed to the returned function is ignored.
func Uint64(key string, value uint64) func(ctx context.Context) zap.Field {
	fixed := zap.Uint64(key, value)
	return func(context.Context) zap.Field {
		return fixed
	}
}
// Float64 returns a custom-field function that always yields the float64
// field zap.Float64(key, value). The field is built once, when Float64 is
// called; the context passed to the returned function is ignored.
func Float64(key string, value float64) func(ctx context.Context) zap.Field {
	field := zap.Float64(key, value)
	return func(ctx context.Context) zap.Field { return field }
}
自定义结构体
// Logger is a gorm (v2) logger implementation that writes through zap.
// Unlike gorm's default logger it uses the context.Context passed to each
// logging call: every function in customFields is called with that context and
// its result is added to the entry as a zap field (e.g. a trace ID).
type Logger struct {
// log is the zap logger every entry is written to.
log *zap.Logger
// Config holds the embedded gorm settings (log level, slow-query threshold,
// record-not-found handling).
logger.Config
// customFields derive extra fields from the context of each logging call.
customFields []func(ctx context.Context) zap.Field
}
关键在于customFields字段:它是一组接收Context参数、返回zap.Field的函数。在初始化日志对象时,传入从Context中获取对应值的函数(例如从Context中读取traceId)。之后每次打印日志时都会用当前的Context调用这些函数,并把返回结果作为日志字段输出。
至此,我们就实现了让gorm打印的日志带上traceId的目标。