// middleware/benchmark_test.go
package middleware
import (
	"net/http"
	"net/http/httptest"
	"runtime"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
)
// BenchmarkBasicMiddleware is the baseline benchmark: a router with only
// gin.Recovery installed, serving a small JSON body, exercised from parallel
// goroutines.
func BenchmarkBasicMiddleware(b *testing.B) {
	gin.SetMode(gin.ReleaseMode)
	router := gin.New()
	router.Use(gin.Recovery())
	router.GET("/test", func(c *gin.Context) {
		c.JSON(200, gin.H{"message": "ok"})
	})

	req, err := http.NewRequest(http.MethodGet, "/test", nil)
	if err != nil {
		b.Fatalf("building request: %v", err)
	}

	// Fail fast if the route is misconfigured; otherwise the benchmark would
	// silently time the 404 path instead of the handler.
	probe := httptest.NewRecorder()
	router.ServeHTTP(probe, req)
	if probe.Code != http.StatusOK {
		b.Fatalf("GET /test: status %d, want %d", probe.Code, http.StatusOK)
	}

	b.ReportAllocs()
	b.ResetTimer()
	// One *http.Request is shared by every parallel goroutine. This is only
	// safe while nothing in the chain mutates the request in place; run with
	// -race if the middleware set changes.
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			w := httptest.NewRecorder()
			router.ServeHTTP(w, req)
		}
	})
}
// BenchmarkLoggerMiddleware measures the per-request cost of a router whose
// only middleware is RequestLogger, serving a small JSON body from parallel
// goroutines.
func BenchmarkLoggerMiddleware(b *testing.B) {
	gin.SetMode(gin.ReleaseMode)
	router := gin.New()
	router.Use(RequestLogger())
	router.GET("/test", func(c *gin.Context) {
		c.JSON(200, gin.H{"message": "ok"})
	})

	req, err := http.NewRequest(http.MethodGet, "/test", nil)
	if err != nil {
		b.Fatalf("building request: %v", err)
	}

	// Fail fast if the route or middleware rejects the request; otherwise the
	// benchmark would silently time an error path.
	probe := httptest.NewRecorder()
	router.ServeHTTP(probe, req)
	if probe.Code != http.StatusOK {
		b.Fatalf("GET /test: status %d, want %d", probe.Code, http.StatusOK)
	}

	b.ReportAllocs()
	b.ResetTimer()
	// NOTE(review): one *http.Request is shared by all parallel goroutines.
	// RequestLogger's implementation is not visible here; if it mutates the
	// request in place (e.g. sets headers), this is a data race — check with
	// -race.
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			w := httptest.NewRecorder()
			router.ServeHTTP(w, req)
		}
	})
}
// BenchmarkAsyncLoggerMiddleware measures the per-request cost of a router
// whose only middleware is AsyncLogger backed by a worker pool created with
// NewWorkerPool(10, 1000), serving a small JSON body from parallel goroutines.
func BenchmarkAsyncLoggerMiddleware(b *testing.B) {
	gin.SetMode(gin.ReleaseMode)
	// NOTE(review): the pool is never shut down. If WorkerPool exposes a
	// stop/close method, defer it here so worker goroutines do not outlive
	// the benchmark.
	pool := NewWorkerPool(10, 1000)
	router := gin.New()
	router.Use(AsyncLogger(pool))
	router.GET("/test", func(c *gin.Context) {
		c.JSON(200, gin.H{"message": "ok"})
	})

	req, err := http.NewRequest(http.MethodGet, "/test", nil)
	if err != nil {
		b.Fatalf("building request: %v", err)
	}

	// Fail fast if the route or middleware rejects the request; otherwise the
	// benchmark would silently time an error path.
	probe := httptest.NewRecorder()
	router.ServeHTTP(probe, req)
	if probe.Code != http.StatusOK {
		b.Fatalf("GET /test: status %d, want %d", probe.Code, http.StatusOK)
	}

	b.ReportAllocs()
	b.ResetTimer()
	// NOTE(review): one *http.Request is shared by all parallel goroutines and
	// may also be handed to pool workers; confirm AsyncLogger only reads it.
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			w := httptest.NewRecorder()
			router.ServeHTTP(w, req)
		}
	})
}
// BenchmarkCacheMiddleware measures the per-request cost of a router whose
// only middleware is ResponseCache with a 5-minute TTL, serving a small JSON
// body from parallel goroutines.
func BenchmarkCacheMiddleware(b *testing.B) {
	gin.SetMode(gin.ReleaseMode)
	router := gin.New()
	router.Use(ResponseCache(5 * time.Minute))
	router.GET("/test", func(c *gin.Context) {
		c.JSON(200, gin.H{"message": "ok"})
	})

	req, err := http.NewRequest(http.MethodGet, "/test", nil)
	if err != nil {
		b.Fatalf("building request: %v", err)
	}

	// Fail fast if the route or middleware rejects the request. This probe
	// also presumably primes the cache, so the timed loop measures hits —
	// NOTE(review): only when the cache backend (Redis) is enabled; otherwise
	// every request is a miss.
	probe := httptest.NewRecorder()
	router.ServeHTTP(probe, req)
	if probe.Code != http.StatusOK {
		b.Fatalf("GET /test: status %d, want %d", probe.Code, http.StatusOK)
	}

	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			w := httptest.NewRecorder()
			router.ServeHTTP(w, req)
		}
	})
}
// TestMiddlewarePerformance benchmarks each middleware in isolation and logs
// ns/op, B/op and allocs/op so the results can be compared side by side. It is
// a report and asserts nothing.
//
// testing.Benchmark runs every case for the full -benchtime (1s by default),
// which makes this far slower than an ordinary unit test, so it is skipped
// under -short.
func TestMiddlewarePerformance(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping middleware performance report in -short mode")
	}
	gin.SetMode(gin.ReleaseMode)

	// A slice rather than a map keeps the report order stable across runs.
	// Each middleware instance is shared by every round testing.Benchmark
	// runs, so stateful ones accumulate state between rounds.
	// NOTE(review): if GlobalAPIRateLimit's limit is lower than the number of
	// iterations, most of its timed requests exercise the rejection path.
	cases := []struct {
		name       string
		middleware gin.HandlerFunc
	}{
		{"Recovery", gin.Recovery()},
		{"Logger", RequestLogger()},
		{"AsyncLogger", AsyncLogger(NewWorkerPool(10, 1000))},
		{"RateLimit", GlobalAPIRateLimit()},
		{"Cache", ResponseCache(5 * time.Minute)},
		{"Security", SecureHeaders()},
		{"CORS", CORS()},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req, err := http.NewRequest(http.MethodGet, "/test", nil)
			if err != nil {
				t.Fatalf("building request: %v", err)
			}

			result := testing.Benchmark(func(b *testing.B) {
				router := gin.New()
				router.Use(tc.middleware)
				router.GET("/test", func(c *gin.Context) {
					c.JSON(200, gin.H{"message": "ok"})
				})
				b.ResetTimer()
				for i := 0; i < b.N; i++ {
					w := httptest.NewRecorder()
					router.ServeHTTP(w, req)
				}
			})
			t.Logf("%s: %d ns/op, %d B/op, %d allocs/op",
				tc.name, result.NsPerOp(), result.AllocedBytesPerOp(), result.AllocsPerOp())
		})
	}
}
// TestMiddlewareMemoryUsage reports how much heap memory 1000 sequential
// requests allocate through three middleware stacks: Recovery only, the full
// synchronous stack, and the same stack with AsyncLogger in place of
// RequestLogger. It logs the numbers and asserts nothing.
func TestMiddlewareMemoryUsage(t *testing.T) {
	gin.SetMode(gin.ReleaseMode)

	const requests = 1000

	// Each case is a different combination of middlewares.
	// NOTE(review): if GlobalAPIRateLimit's limit is below `requests`, later
	// requests in the "Full"/"Optimized" cases measure the rejection path.
	testCases := []struct {
		name        string
		middlewares []gin.HandlerFunc
	}{
		{
			name:        "Basic",
			middlewares: []gin.HandlerFunc{gin.Recovery()},
		},
		{
			name: "Full",
			middlewares: []gin.HandlerFunc{
				gin.Recovery(),
				RequestLogger(),
				GlobalAPIRateLimit(),
				SecureHeaders(),
				CORS(),
			},
		},
		{
			name: "Optimized",
			middlewares: []gin.HandlerFunc{
				gin.Recovery(),
				AsyncLogger(NewWorkerPool(10, 1000)),
				GlobalAPIRateLimit(),
				SecureHeaders(),
				CORS(),
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			router := gin.New()
			router.Use(tc.middlewares...)
			router.GET("/test", func(c *gin.Context) {
				c.JSON(200, gin.H{"message": "ok"})
			})

			// TotalAlloc and Mallocs are cumulative counters, so their
			// difference is the allocation volume of the loop. MemStats.Alloc
			// (live heap) is not: a GC cycle during the loop can leave
			// after < before, and the unsigned subtraction then wraps to a
			// huge bogus value. The counters are process-wide, so background
			// goroutines (e.g. async logger workers) are included.
			var before, after runtime.MemStats
			runtime.ReadMemStats(&before)
			for i := 0; i < requests; i++ {
				req, err := http.NewRequest(http.MethodGet, "/test", nil)
				if err != nil {
					t.Fatalf("building request: %v", err)
				}
				router.ServeHTTP(httptest.NewRecorder(), req)
			}
			runtime.ReadMemStats(&after)

			allocated := after.TotalAlloc - before.TotalAlloc
			t.Logf("%s: Memory used: %d bytes (%d allocs, %d B/request)",
				tc.name, allocated, after.Mallocs-before.Mallocs, allocated/requests)
		})
	}
}
c.Header("Cache-Control", fmt.Sprintf("max-age=%d", int(config.TTL.Seconds())))
}
}
}
// CachedResponse is the snapshot of a response that is stored in, and read
// back from, the cache. setCachedResponse and getCachedResponse serialize it
// with encoding/json, so Body is stored base64-encoded. Only the fields below
// are kept; response headers other than the content type are not part of the
// snapshot.
type CachedResponse struct {
	// Status code to replay; presumably the handler's original status.
	StatusCode int
	// Content type to replay; presumably copied from the original response.
	ContentType string
	// Response body bytes.
	Body []byte
	// NOTE(review): presumably the time the entry was cached; it is not read
	// by any helper visible here.
	Timestamp time.Time
}
// cacheResponseWriter wraps gin's ResponseWriter and keeps a copy of every
// byte successfully written to the client, so the response body can be cached
// once the handler has finished (presumably by swapping it in as c.Writer).
type cacheResponseWriter struct {
	gin.ResponseWriter
	body []byte // bytes accepted by the wrapped writer so far
}

// Write forwards data to the wrapped writer and records the bytes it accepted.
func (w *cacheResponseWriter) Write(data []byte) (int, error) {
	n, err := w.ResponseWriter.Write(data)
	// Record only what reached the client: on a failed or short write,
	// capturing all of data would cache a body the client never received.
	w.body = append(w.body, data[:n]...)
	return n, err
}

// WriteString overrides the method promoted from the embedded
// gin.ResponseWriter. Without this override, string writes (direct
// WriteString calls, or io.WriteString, which prefers WriteString when
// available) go straight to the underlying writer and are missing from the
// captured body.
func (w *cacheResponseWriter) WriteString(s string) (int, error) {
	n, err := w.ResponseWriter.WriteString(s)
	w.body = append(w.body, s[:n]...)
	return n, err
}
// generateCacheKey derives the cache key for the current request from four
// parts:
//   - the HTTP method
//   - the URL path
//   - the raw query string
//   - for requests with a positive "user_id" on the context, that user ID
//
// The parts are hashed with MD5, used here only as a compact fingerprint and
// not for security. The result is returned as "cache:" followed by 32 hex
// characters.
//
// Every part is length-prefixed before hashing, so two different requests can
// never produce the same hash input. A plain "METHOD:path?query:user:ID"
// concatenation would be ambiguous. For example, an anonymous
// GET /x?q=a:user:5 would yield the same key as user 5's GET /x?q=a, exposing
// one user's cached response to anyone.
//
// NOTE(review): gin's Context.GetInt returns 0 unless the stored value is an
// int. If the auth middleware stores user_id as another type, every user falls
// back to the shared anonymous key. Confirm the stored type.
func generateCacheKey(c *gin.Context) string {
	userPart := ""
	if userID := c.GetInt("user_id"); userID > 0 {
		userPart = fmt.Sprintf("%d", userID)
	}
	parts := [...]string{
		c.Request.Method,
		c.Request.URL.Path,
		c.Request.URL.RawQuery,
		userPart,
	}

	h := md5.New()
	for _, p := range parts {
		// "<len>:<bytes>" decodes unambiguously. hash.Hash writes never fail,
		// so the Fprintf error is safe to ignore.
		fmt.Fprintf(h, "%d:%s", len(p), p)
	}
	return "cache:" + hex.EncodeToString(h.Sum(nil))
}
// 从缓存获取响应
func getCachedResponse(key string) *CachedResponse {
if !common.RedisEnabled {
return nil
}
// 从Redis获取缓存数据
data, err := common.RDB.Get(context.Background(), key).Bytes()
if err != nil {
return nil
}
// 反序列化缓存数据
var response CachedResponse
if err := json.Unmarshal(data, &response); err != nil {
return nil
}
return &response
}
// setCachedResponse stores response under key in Redis as JSON, expiring
// after ttl. It is best-effort. Nothing is cached, and no error is reported,
// when any of these holds:
//   - Redis caching is disabled;
//   - response is nil;
//   - encoding fails;
//   - the SET fails.
//
// This way a caching problem never fails the request being served.
//
// NOTE(review): with go-redis a ttl of 0 means "no expiration". Confirm that
// callers never pass a zero TTL, or such entries will never expire.
func setCachedResponse(key string, response *CachedResponse, ttl time.Duration) {
	// A nil pointer would marshal to JSON null. getCachedResponse decodes that
	// without error into a zero CachedResponse (status 0, empty body), which
	// would be a bogus cache hit. Store nothing instead.
	if !common.RedisEnabled || response == nil {
		return
	}

	data, err := json.Marshal(response)
	if err != nil {
		return
	}

	// The SET result is deliberately not checked: a failed write only means
	// the next request misses the cache.
	common.RDB.Set(context.Background(), key, data, ttl)
}