在线时间:8:00-16:00
迪恩网络APP
随时随地掌握行业动态
扫描二维码
关注迪恩网络微信公众号
在以前的Go语言jaeger和opentracing 有用来做日志,但是很多时候我们希望数据库的操作也可以记录下来,程序一般作为http或者grpc 服务, 所以grpc和http也是需要用中间件来实现的。首先看程序的目录, 只是一个简单的demo: 因为程序最后会部署到k8s上,计划采用docker来收集,灌到elk或者graylog,所以这里直接输出,程序设计采用切换数据库 实现简单的saas。 来看看主要的几个文件 logger.go package logger import ( "context" "fmt" "io" "runtime" "strings" "time" "github.com/opentracing/opentracing-go" "github.com/uber/jaeger-client-go" "github.com/uber/jaeger-client-go/config" "github.com/uber/jaeger-client-go/log" "github.com/uber/jaeger-lib/metrics" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) var ( logTimeFormat = "2006-01-02T15:04:05.000+08:00" zapLogger *zap.Logger ) //配置默认初始化 func init() { c := zap.NewProductionConfig() c.EncoderConfig.LevelKey = "" c.EncoderConfig.CallerKey = "" c.EncoderConfig.MessageKey = "logModel" c.EncoderConfig.TimeKey = "" c.Level = zap.NewAtomicLevelAt(zap.DebugLevel) zapLogger, _ = c.Build() } //初始化 Jaeger client func NewJaegerTracer(serviceName string, agentHost string) (tracer opentracing.Tracer, closer io.Closer, err error) { cfg := config.Configuration{ ServiceName: serviceName, Sampler: &config.SamplerConfig{ Type: jaeger.SamplerTypeRateLimiting, Param: 10, }, Reporter: &config.ReporterConfig{ LogSpans: false, BufferFlushInterval: 1 * time.Second, LocalAgentHostPort: agentHost, }, } jLogger := log.StdLogger jMetricsFactory := metrics.NullFactory tracer, closer, err = cfg.NewTracer(config.Logger(jLogger), config.Metrics(jMetricsFactory)) if err == nil { opentracing.SetGlobalTracer(tracer) } return tracer, closer, err } func Error(ctx context.Context, format interface{}, args ...interface{}) { msg := "" if e, ok := format.(error); ok { msg = fmt.Sprintf(e.Error(), args...) } else if e, ok := format.(string); ok { msg = fmt.Sprintf(e, args...) } jsonStdOut(ctx, zap.ErrorLevel, msg) } func Warn(ctx context.Context, format string, args ...interface{}) { jsonStdOut(ctx, zap.WarnLevel, fmt.Sprintf(format, args...)) } func Info(ctx context.Context, format string, args ...interface{}) { jsonStdOut(ctx, zap.InfoLevel, fmt.Sprintf(format, args...)) } func Debug(ctx context.Context, format string, args ...interface{}) { jsonStdOut(ctx, zap.DebugLevel, fmt.Sprintf(format, args...)) } //本地打印 Json func jsonStdOut(ctx context.Context, level zapcore.Level, msg string) { traceId, spanId := getTraceId(ctx) if ce := zapLogger.Check(level, "zap"); ce != nil { ce.Write( zap.Any("message", JsonLogger{ LogTime: time.Now().Format(logTimeFormat), Level: level, Content: msg, CallPath: getCallPath(), TraceId: traceId, SpanId: spanId, }), ) } } type JsonLogger struct { TraceId string `json:"traceId"` SpanId uint64 `json:"spanId"` Content interface{} `json:"content"` CallPath interface{} `json:"callPath"` LogTime string `json:"logDate"` //日志时间 Level zapcore.Level `json:"level"` //日志级别 } func getTraceId(ctx context.Context) (string, uint64) { span := opentracing.SpanFromContext(ctx) if span == nil { return "", 0 } if sc, ok := span.Context().(jaeger.SpanContext); ok { return fmt.Sprintf("%v", sc.TraceID()), uint64(sc.SpanID()) } return "", 0 } func getCallPath() string { _, file, lineno, ok := runtime.Caller(2) if ok { return strings.Replace(fmt.Sprintf("%s:%d", stringTrim(file, ""), lineno), "%2e", ".", -1) } return "" } func stringTrim(s, cut string) string { ss := strings.SplitN(s, cut, 2) if len(ss) == 1 { return ss[0] } return ss[1] } db.go package db import ( "context" "database/sql/driver" "fmt" "net/url" "reflect" "regexp" "strings" "tracedemo/logger" "unicode" "github.com/jinzhu/gorm" "github.com/pkg/errors" "sync" "time" _ "github.com/go-sql-driver/mysql" "github.com/opentracing/opentracing-go" ) // DB连接配置信息 type Config struct { DbHost string DbPort int DbUser string DbPass string DbName string Debug bool } // 连接的数据库类型 const ( dbMaster string = "master" jaegerContextKey = "jeager:context" callbackPrefix = "jeager" startTime = "start:time" ) func init() { connMap = make(map[string]*gorm.DB) } var ( connMap map[string]*gorm.DB connLock sync.RWMutex ) // 初始化DB func InitDb(siteCode string, cfg *Config) (err error) { url := url.Values{} url.Add("parseTime", "True") url.Add("loc", "Local") url.Add("charset", "utf8mb4") url.Add("collation", "utf8mb4_unicode_ci") url.Add("readTimeout", "0s") url.Add("writeTimeout", "0s") url.Add("timeout", "0s") dsn := fmt.Sprintf("%s:%s@tcp(%s:%v)/%s?%s", cfg.DbUser, cfg.DbPass, cfg.DbHost, cfg.DbPort, cfg.DbName, url.Encode()) conn, err := gorm.Open("mysql", dsn) if err != nil { return errors.Wrap(err, "fail to connect db") } //新增gorm插件 if cfg.Debug == true { registerCallbacks(conn) } //打印日志 //conn.LogMode(true) conn.DB().SetMaxIdleConns(30) conn.DB().SetMaxOpenConns(200) conn.DB().SetConnMaxLifetime(60 * time.Second) if err := conn.DB().Ping(); err != nil { return errors.Wrap(err, "fail to ping db") } connLock.Lock() dbName := fmt.Sprintf("%s-%s", siteCode, dbMaster) connMap[dbName] = conn connLock.Unlock() go mysqlHeart(conn) return nil } func GetMaster(ctx context.Context) *gorm.DB { connLock.RLock() defer connLock.RUnlock() siteCode := fmt.Sprintf("%v", ctx.Value("SiteCode")) if strings.Contains(siteCode, "nil") { panic(errors.New("当前上下文没有找到DB")) } dbName := fmt.Sprintf("%s-%s", siteCode, dbMaster) ctx = context.WithValue(ctx, "DbName", dbName) db := connMap[dbName] if db == nil { panic(errors.New(fmt.Sprintf("当前上下文没有找到DB:%s", dbName))) } return db.Set(jaegerContextKey, ctx) } func mysqlHeart(conn *gorm.DB) { for { if conn != nil { err := conn.DB().Ping() if err != nil { fmt.Println(fmt.Sprintf("mysqlHeart has err:%v", err)) } } time.Sleep(3 * time.Minute) } } func registerCallbacks(db *gorm.DB) { driverName := db.Dialect().GetName() switch driverName { case "postgres": driverName = "postgresql" } spanTypePrefix := fmt.Sprintf("gorm.db.%s.", driverName) querySpanType := spanTypePrefix + "query" execSpanType := spanTypePrefix + "exec" type params struct { spanType string processor func() *gorm.CallbackProcessor } callbacks := map[string]params{ "gorm:create": { spanType: execSpanType, processor: func() *gorm.CallbackProcessor { return db.Callback().Create() }, }, "gorm:delete": { spanType: execSpanType, processor: func() *gorm.CallbackProcessor { return db.Callback().Delete() }, }, "gorm:query": { spanType: querySpanType, processor: func() *gorm.CallbackProcessor { return db.Callback().Query() }, }, "gorm:update": { spanType: execSpanType, processor: func() *gorm.CallbackProcessor { return db.Callback().Update() }, }, "gorm:row_query": { spanType: querySpanType, processor: func() *gorm.CallbackProcessor { return db.Callback().RowQuery() }, }, } for name, params := range callbacks { params.processor().Before(name).Register( fmt.Sprintf("%s:before:%s", callbackPrefix, name), newBeforeCallback(params.spanType), ) params.processor().After(name).Register( fmt.Sprintf("%s:after:%s", callbackPrefix, name), newAfterCallback(), ) } } func newBeforeCallback(spanType string) func(*gorm.Scope) { return func(scope *gorm.Scope) { ctx, ok := scopeContext(scope) if !ok { return } //新增链路追踪 span, ctx := opentracing.StartSpanFromContext(ctx, spanType) if span.Tracer() == nil { span.Finish() ctx = nil } scope.Set(jaegerContextKey, ctx) scope.Set(startTime, time.Now().UnixNano()) } } func newAfterCallback() func(*gorm.Scope) { return func(scope *gorm.Scope) { ctx, ok := scopeContext(scope) if !ok { return } span := opentracing.SpanFromContext(ctx) if span == nil { return } defer span.Finish() duration := int64(0) if t, ok := scopeStartTime(scope); ok { duration = (time.Now().UnixNano() - t) / 1e6 } logger.Debug(ctx, "[gorm] [%vms] [RowsReturned(%v)] %v ", duration, scope.DB().RowsAffected, gormSQL(scope.SQL, scope.SQLVars)) for _, err := range scope.DB().GetErrors() { if gorm.IsRecordNotFoundError(err) || err == errors.New("sql: no rows in result set") { continue } //打印错误日志 logger.Error(ctx, "%v", err.Error()) } //span.LogFields(traceLog.String("sql", scope.SQL)) } } func scopeContext(scope *gorm.Scope) (context.Context, bool) { value, ok := scope.Get(jaegerContextKey) if !ok { return nil, false } ctx, _ := value.(context.Context) return ctx, ctx != nil } func scopeStartTime(scope *gorm.Scope) (int64, bool) { value, ok := scope.Get(startTime) if !ok { return 0, false } t, ok := value.(int64) return t, ok } /*===============Log=======================================*/ var ( sqlRegexp = regexp.MustCompile(`\?`) numericPlaceHolderRegexp = regexp.MustCompile(`\$\d+`) ) func gormSQL(inputSql interface{}, value interface{}) string { var sql string var formattedValues []string for _, value := range value.([]interface{}) { indirectValue := reflect.Indirect(reflect.ValueOf(value)) if indirectValue.IsValid() { value = indirectValue.Interface() if t, ok := value.(time.Time); ok { if t.IsZero() { formattedValues = append(formattedValues, fmt.Sprintf("'%v'", "0000-00-00 00:00:00")) } else { formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05"))) } } else if b, ok := value.([]byte); ok { if str := string(b); isPrintable(str) { formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str)) } else { formattedValues = append(formattedValues, "'<binary>'") } } else if r, ok := value.(driver.Valuer); ok { if value, err := r.Value(); err == nil && value != nil { formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) } else { formattedValues = append(formattedValues, "NULL") } } else { switch value.(type) { case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool: formattedValues = append(formattedValues, fmt.Sprintf("%v", value)) default: formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) } } } else { formattedValues = append(formattedValues, "NULL") } } if formattedValues == nil || len(formattedValues) < 1 { return sql } // differentiate between $n placeholders or else treat like ? if numericPlaceHolderRegexp.MatchString(inputSql.(string)) { sql = inputSql.(string) for index, value := range formattedValues { placeholder := fmt.Sprintf(`\$%d([^\d]|$)`, index+1) sql = regexp.MustCompile(placeholder).ReplaceAllString(sql, value+"$1") } } else { formattedValuesLength := len(formattedValues) for index, value := range sqlRegexp.Split(inputSql.(string), -1) { sql += value if index < formattedValuesLength { sql += formattedValues[index] } } } return sql } func isPrintable(s string) bool { for _, r := range s { if !unicode.IsPrint(r) { return false } } return true } server.go package apiserver import ( contextV2 "context" "fmt" "runtime/debug" "tracedemo/apiserver/userinfo" "tracedemo/logger" "github.com/kataras/iris/v12" "github.com/kataras/iris/v12/context" "github.com/opentracing/opentracing-go" ) func StartApiServerr() { addr := ":8080" app := iris.New() app.Use(openTracing()) app.Use(withSiteCode()) app.Use(withRecover()) app.Get("/", func(c context.Context) { c.WriteString("pong") }) initIris(app) logger.Info(contextV2.Background(), "[apiServer]开始监听%s,", addr) err := app.Run(iris.Addr(addr), iris.WithoutInterruptHandler) if err != nil { logger.Error(contextV2.Background(), "[apiServer]开始监听%s 错误%v,", addr,err) } } func initIris(app *iris.Application) { api:= userinfo.ApiServer{} userGroup := app.Party("/user") { userGroup.Get("/test",api.TestUserInfo) userGroup.Get("/rpc",api.TestRpc) } } func openTracing() context.Handler { return func(c iris.Context) { span := opentracing.GlobalTracer().StartSpan("apiServer") c.ResetRequest(c.Request().WithContext(opentracing.ContextWithSpan(c.Request().Context(), span))) logger.Info(c.Request().Context(), "Api请求地址%v", c.Request().URL) c.Next() } } func withSiteCode() context.Handler { return func(c iris.Context) { siteCode := c.GetHeader("SiteCode") if len(siteCode) < 1 { siteCode = "001" } ctx := contextV2.WithValue(c.Request().Context(), "SiteCode", siteCode) c.ResetRequest(c.Request().WithContext(ctx)) c.Next() } } func withRecover() context.Handler { return func(c iris.Context) { defer func() { if e := recover(); e != nil { stack := debug.Stack() logger.Error(c.Request().Context(), fmt.Sprintf("Api has err:%v, stack:%v", e, string(stack))) } }() c.Next() } } grpc的中间件middleware.go package middleware import ( "context" "encoding/json" "fmt" "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" "google.golang.org/grpc" "google.golang.org/grpc/metadata" "runtime/debug" "strings" "time" "tracedemo/logger" ) type MDCarrier struct { metadata.MD } func (m MDCarrier) ForeachKey(handler func(key, val string) error) error { for k, strs := range m.MD { for _, v := range strs { if err := handler(k, v); err != nil { return err } } } return nil } func (m MDCarrier) Set(key, val string) { m.MD[key] = append(m.MD[key], val) } // ClientInterceptor 客户端拦截器 func ClientTracing(tracer opentracing.Tracer) grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, request, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { //一个RPC调用的服务端的span,和RPC服务客户端的span构成ChildOf关系 var parentCtx opentracing.SpanContext parentSpan := opentracing.SpanFromContext(ctx) if parentSpan != nil { parentCtx = parentSpan.Context() } span := tracer.StartSpan( method, opentracing.ChildOf(parentCtx), opentracing.Tag{Key: string(ext.Component), Value: "gRPC Client"}, ext.SpanKindRPCClient, ) defer span.Finish() md, ok := metadata.FromOutgoingContext(ctx) if !ok { md = metadata.New(nil) } else { md = md.Copy() } err := tracer.Inject( span.Context(), opentracing.TextMap, MDCarrier{md}, // 自定义 carrier ) if err != nil { logger.Error(ctx, "ClientTracing inject span error :%v", err.Error()) } ///SiteCode siteCode := fmt.Sprintf("%v", ctx.Value("SiteCode")) if len(siteCode) < 1 || strings.Contains(siteCode, "nil") { siteCode = "001" } md.Set("SiteCode", siteCode) // newCtx := metadata.NewOutgoingContext(ctx, md) err = invoker(newCtx, method, request, reply, cc, opts...) if err != nil { logger.Error(ctx, "ClientTracing call error : %v", err.Error()) } return err } } func ClientSiteCode() grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, request, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { md, ok := metadata.FromOutgoingContext(ctx) if !ok { md = metadata.New(nil) } else { md = md.Copy() } ///SiteCode siteCode := fmt.Sprintf("%v", ctx.Value("SiteCode")) if len(siteCode) < 1 || strings.Contains(siteCode, "nil") { siteCode = "< |
请发表评论