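Chapter 19: A Deep Dive into the New-API Project

How to read this chapter

  • Read the architecture overview and data flow first, then the module analysis, and finally the extension points and performance bottlenecks.

  • Cross-check the source code section by section against Chapters 4, 5, 6, 7, 8, 10, and 11 to close the loop from design to implementation to verification.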

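19.1 Project Overview

19.1.1 Introduction to New-API

New-API is an open-source management system for AI model interfaces. It gives enterprises and developers a single, unified interface to AI services. The project is written in Go and is designed for high performance, high availability, and easy extensibility.

Core features

Figure 1. New-API core feature architecture

System architecture design

Modular architecture: New-API is organized into relatively independent functional modules, which keeps each module easy to maintain and extend.

Middleware pattern: cross-cutting concerns such as authentication, logging, and rate limiting are implemented with the Gin framework's middleware mechanism.

Database abstraction layer: GORM serves as the ORM framework, supports multiple database types, and provides a uniform data access interface.

Configuration management: environment variables are combined with configuration files so that each deployment environment can be configured separately.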

Figure 2. New-API system architecture

19.1.2 Technical Architecture Overview

package main

import (
    "context"
    "fmt"
    "log"
    "net/http"
    "os"
    "os/signal"
    "syscall"
    "time"
    
    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/common"
    "github.com/songquanpeng/one-api/controller"
    "github.com/songquanpeng/one-api/middleware"
    "github.com/songquanpeng/one-api/model"
    "github.com/songquanpeng/one-api/router"
)

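// Application bundles the top-level components of the service.
type Application struct {
    Server   *http.Server
    Router   *gin.Engine
    Config   *common.Config
    Database *model.Database
}

// NewApplication loads the configuration and wires up the database, router,
// middleware, routes, and HTTP server.
func NewApplication() *Application {
    app := &Application{
        Config: common.LoadConfig(),
    }

    // Initialize the database.
    app.Database = model.InitDB(app.Config.DatabaseURL)

    // Create the router.
    app.Router = router.SetupRouter()

    // Register global middleware.
    app.setupMiddlewares()

    // Register routes.
    app.setupRoutes()

    // Create the HTTP server. Note that WriteTimeout also bounds
    // long-running or streamed responses.
    app.Server = &http.Server{
        Addr:         fmt.Sprintf(":%d", app.Config.Port),
        Handler:      app.Router,
        ReadTimeout:  30 * time.Second,
        WriteTimeout: 30 * time.Second,
        IdleTimeout:  60 * time.Second,
    }

    return app
}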

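// setupMiddlewares registers the middleware applied to every route.
func (app *Application) setupMiddlewares() {
    // CORS middleware.
    app.Router.Use(middleware.CORS())

    // Request logging middleware.
    app.Router.Use(middleware.Logger())

    // Panic recovery middleware.
    app.Router.Use(middleware.Recovery())

    // Rate limiting middleware.
    app.Router.Use(middleware.RateLimit())

    // Authentication middleware is not global; it is attached to the
    // individual route groups in setupRoutes.
}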

// setupRoutes registers the management API and the OpenAI-compatible API.
func (app *Application) setupRoutes() {
    // Versioned management API group.
    v1 := app.Router.Group("/api/v1")
    {
        // User routes.
        userGroup := v1.Group("/user")
        userGroup.Use(middleware.AuthRequired())
        {
            userGroup.GET("/profile", controller.GetUserProfile)
            userGroup.PUT("/profile", controller.UpdateUserProfile)
            userGroup.GET("/usage", controller.GetUserUsage)
        }
        
        // Token routes.
        tokenGroup := v1.Group("/token")
        tokenGroup.Use(middleware.AuthRequired())
        {
            tokenGroup.GET("/", controller.GetTokens)
            tokenGroup.POST("/", controller.CreateToken)
            tokenGroup.PUT("/:id", controller.UpdateToken)
            tokenGroup.DELETE("/:id", controller.DeleteToken)
        }
        
        // Channel routes (administrators only).
        channelGroup := v1.Group("/channel")
        channelGroup.Use(middleware.AdminRequired())
        {
            channelGroup.GET("/", controller.GetChannels)
            channelGroup.POST("/", controller.CreateChannel)
            channelGroup.PUT("/:id", controller.UpdateChannel)
            channelGroup.DELETE("/:id", controller.DeleteChannel)
            channelGroup.POST("/:id/test", controller.TestChannel)
        }
    }
    
    // OpenAI-compatible endpoints.
    openaiGroup := app.Router.Group("/v1")
    openaiGroup.Use(middleware.TokenAuth())
    {
        openaiGroup.POST("/chat/completions", controller.ChatCompletions)
        openaiGroup.POST("/completions", controller.Completions)
        openaiGroup.POST("/embeddings", controller.Embeddings)
        openaiGroup.GET("/models", controller.ListModels)
    }
}

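// Start runs the HTTP server, blocks until an interrupt or termination
// signal arrives, and then shuts the server down gracefully.
func (app *Application) Start() error {
    // Start background tasks.
    go app.startBackgroundTasks()

    log.Printf("Server starting on port %d", app.Config.Port)

    // Run the HTTP server in its own goroutine so that this function can
    // wait for signals.
    go func() {
        if err := app.Server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
            log.Fatalf("Server failed to start: %v", err)
        }
    }()

    // Wait for an interrupt or termination signal.
    quit := make(chan os.Signal, 1)
    signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
    <-quit

    log.Println("Server shutting down...")

    // Graceful shutdown: give in-flight requests up to 30 seconds to finish.
    ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
    defer cancel()

    return app.Server.Shutdown(ctx)
}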

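// startBackgroundTasks launches the periodic maintenance jobs.
func (app *Application) startBackgroundTasks() {
    // Periodically mark expired tokens.
    go app.cleanupExpiredTokens()

    // Periodically aggregate usage statistics.
    go app.collectUsageStats()

    // Periodically check channel health.
    go app.checkChannelHealth()
}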

// cleanupExpiredTokens marks expired tokens once an hour.
// The loop runs for the lifetime of the process.
func (app *Application) cleanupExpiredTokens() {
    ticker := time.NewTicker(1 * time.Hour)
    defer ticker.Stop()

    for range ticker.C {
        if err := model.CleanupExpiredTokens(); err != nil {
            log.Printf("Failed to cleanup expired tokens: %v", err)
        }
    }
}

// collectUsageStats aggregates usage statistics every five minutes.
func (app *Application) collectUsageStats() {
    ticker := time.NewTicker(5 * time.Minute)
    defer ticker.Stop()

    for range ticker.C {
        if err := model.CollectUsageStats(); err != nil {
            log.Printf("Failed to collect usage stats: %v", err)
        }
    }
}

// checkChannelHealth checks the health of all channels every ten minutes.
func (app *Application) checkChannelHealth() {
    ticker := time.NewTicker(10 * time.Minute)
    defer ticker.Stop()

    for range ticker.C {
        if err := model.CheckChannelHealth(); err != nil {
            log.Printf("Failed to check channel health: %v", err)
        }
    }
}

func main() {
    // Create the application instance.
    app := NewApplication()

    // Start the application; Start blocks until shutdown completes.
    if err := app.Start(); err != nil {
        log.Fatalf("Failed to start application: %v", err)
    }
}

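19.2 Core Module Analysis

19.2.1 User Management Module

User management flow

Figure 3. User management flowchart

Key concepts

User roles (Role)

  • RoleGuestUser (0): guest user, with the lowest privileges

  • RoleCommonUser (1): regular user, may use the basic features

  • RoleAdminUser (10): administrator, may manage channels and users

  • RoleRootUser (100): super administrator, holds all permissions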
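
The gaps between the role values suggest that roles are compared numerically, with a higher value implying every lower privilege. The sketch below illustrates that reading; the ordering rule and the helper `hasRole` are assumptions made for this example rather than something the listing in this section states.

```go
package main

import "fmt"

// Role constants as listed above.
const (
	RoleGuestUser  = 0
	RoleCommonUser = 1
	RoleAdminUser  = 10
	RoleRootUser   = 100
)

// hasRole reports whether userRole meets the required minimum role.
// Assumption: a numerically higher role includes all lower ones.
func hasRole(userRole, required int) bool {
	return userRole >= required
}

func main() {
	fmt.Println(hasRole(RoleAdminUser, RoleCommonUser)) // true
	fmt.Println(hasRole(RoleCommonUser, RoleAdminUser)) // false
	fmt.Println(hasRole(RoleRootUser, RoleAdminUser))   // true
}
```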

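User status (Status)

  • UserStatusEnabled (1): active, the account can be used normally

  • UserStatusDisabled (2): disabled, use is suspended

  • UserStatusDeleted (3): deleted, a soft-delete marker

Quota management (Quota)

  • Quota: the user's total quota, usually measured in tokens

  • UsedQuota: the quota consumed so far

  • RequestCount: the number of requests made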
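
To make the relationship between these three fields concrete, here is a small in-memory sketch in which a charge either succeeds completely or leaves the counters untouched. The `account` type and its methods are hypothetical stand-ins for a database row, and the rule that a charge fails when `UsedQuota + amount` would exceed `Quota` is an assumption for illustration.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// account is an in-memory stand-in for a user row (hypothetical type).
type account struct {
	mu           sync.Mutex
	Quota        int64 // total quota
	UsedQuota    int64 // quota consumed so far
	RequestCount int64 // number of successful requests
}

var errInsufficientQuota = errors.New("insufficient quota")

// Consume charges amount against the quota, or fails without changing state.
func (a *account) Consume(amount int64) error {
	a.mu.Lock()
	defer a.mu.Unlock()
	if a.UsedQuota+amount > a.Quota {
		return errInsufficientQuota
	}
	a.UsedQuota += amount
	a.RequestCount++
	return nil
}

// Remaining returns the quota that is still available.
func (a *account) Remaining() int64 {
	a.mu.Lock()
	defer a.mu.Unlock()
	return a.Quota - a.UsedQuota
}

func main() {
	acct := &account{Quota: 1000}
	fmt.Println(acct.Consume(600), acct.Remaining()) // <nil> 400
	fmt.Println(acct.Consume(600), acct.Remaining()) // insufficient quota 400
}
```

The mutex plays the role that a conditional `UPDATE ... WHERE` statement plays in a database: the check and the update happen as one step, so two concurrent requests cannot both spend the same remaining quota.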

Invitation mechanism (Invitation)

  • AffCode: the user's personal invitation code

  • InviterId: the inviter's user ID, used to calculate referral rewards

package model

import (
    "errors"
    "math/rand"
    "time"

    "golang.org/x/crypto/bcrypt"
    "gorm.io/gorm"
)

// User is the user model.
type User struct {
    ID           int    `json:"id" gorm:"primaryKey"`
    Username     string `json:"username" gorm:"uniqueIndex;not null"`
    Password     string `json:"-" gorm:"not null"`
    Email        string `json:"email" gorm:"uniqueIndex"`
    Role         int    `json:"role" gorm:"default:1"`
    Status       int    `json:"status" gorm:"default:1"`
    Quota        int64  `json:"quota" gorm:"default:0"`
    UsedQuota    int64  `json:"used_quota" gorm:"default:0"`
    RequestCount int64  `json:"request_count" gorm:"default:0"`
    Group        string `json:"group" gorm:"default:'default'"`
    AffCode      string `json:"aff_code"`
    InviterId    int    `json:"inviter_id"`
    CreatedTime  int64  `json:"created_time"`
    AccessTime   int64  `json:"access_time"`
}

// User role constants.
const (
    RoleGuestUser  = 0
    RoleCommonUser = 1
    RoleAdminUser  = 10
    RoleRootUser   = 100
)

// User status constants.
const (
    UserStatusEnabled  = 1
    UserStatusDisabled = 2
    UserStatusDeleted  = 3
)

// UserService provides user-related database operations.
type UserService struct {
    db *gorm.DB
}

func NewUserService(db *gorm.DB) *UserService {
    return &UserService{db: db}
}

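// CreateUser checks uniqueness, hashes the password, and stores the user.
func (us *UserService) CreateUser(user *User) error {
    // Reject duplicate usernames. The unique index remains the final guard
    // against concurrent registrations.
    var existingUser User
    if err := us.db.Where("username = ?", user.Username).First(&existingUser).Error; err == nil {
        return errors.New("username already exists")
    } else if !errors.Is(err, gorm.ErrRecordNotFound) {
        return err
    }

    // Reject duplicate email addresses.
    if user.Email != "" {
        if err := us.db.Where("email = ?", user.Email).First(&existingUser).Error; err == nil {
            return errors.New("email already exists")
        } else if !errors.Is(err, gorm.ErrRecordNotFound) {
            return err
        }
    }

    // Hash the password with bcrypt.
    hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost)
    if err != nil {
        return err
    }
    user.Password = string(hashedPassword)

    // Set the creation time.
    user.CreatedTime = time.Now().Unix()

    // Generate the invitation code.
    user.AffCode = generateAffCode()

    return us.db.Create(user).Error
}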

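// ValidateUser verifies a username and password for an enabled account.
func (us *UserService) ValidateUser(username, password string) (*User, error) {
    var user User
    err := us.db.Where("username = ? AND status = ?", username, UserStatusEnabled).First(&user).Error
    if err != nil {
        // Return the same message as for a wrong password so that callers
        // cannot tell which usernames exist.
        return nil, errors.New("invalid username or password")
    }

    // Compare the supplied password with the stored bcrypt hash.
    err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
    if err != nil {
        return nil, errors.New("invalid username or password")
    }

    // Record the access time. Updating a single column avoids overwriting
    // other fields, such as quota counters, with stale values.
    user.AccessTime = time.Now().Unix()
    us.db.Model(&user).Update("access_time", user.AccessTime)

    return &user, nil
}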

// GetUserByID returns the user with the given ID.
func (us *UserService) GetUserByID(id int) (*User, error) {
    var user User
    if err := us.db.First(&user, id).Error; err != nil {
        return nil, err
    }
    return &user, nil
}

// UpdateUser saves all fields of the given user.
func (us *UserService) UpdateUser(user *User) error {
    return us.db.Save(user).Error
}

// DeleteUser soft-deletes a user by setting its status to deleted.
func (us *UserService) DeleteUser(id int) error {
    return us.db.Model(&User{}).Where("id = ?", id).Update("status", UserStatusDeleted).Error
}

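// GetUsers returns one page of non-deleted users, optionally filtered by a
// keyword that matches the username or email, plus the total match count.
func (us *UserService) GetUsers(page, pageSize int, keyword string) ([]*User, int64, error) {
    var users []*User
    var total int64

    // Guard against a negative offset.
    if page < 1 {
        page = 1
    }

    query := us.db.Model(&User{}).Where("status != ?", UserStatusDeleted)

    if keyword != "" {
        query = query.Where("username LIKE ? OR email LIKE ?", "%"+keyword+"%", "%"+keyword+"%")
    }

    // Count the matching rows.
    if err := query.Count(&total).Error; err != nil {
        return nil, 0, err
    }

    // Fetch the requested page.
    offset := (page - 1) * pageSize
    err := query.Offset(offset).Limit(pageSize).Find(&users).Error

    return users, total, err
}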

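// UpdateUserQuota sets the user's total quota.
func (us *UserService) UpdateUserQuota(userID int, quota int64) error {
    return us.db.Model(&User{}).Where("id = ?", userID).Update("quota", quota).Error
}

// ConsumeUserQuota charges amount against the user's quota.
// The WHERE clause makes the check and the update a single statement.
func (us *UserService) ConsumeUserQuota(userID int, amount int64) error {
    result := us.db.Model(&User{}).Where("id = ? AND quota >= used_quota + ?", userID, amount).
        Updates(map[string]interface{}{
            "used_quota":    gorm.Expr("used_quota + ?", amount),
            "request_count": gorm.Expr("request_count + 1"),
        })
    if result.Error != nil {
        return result.Error
    }
    // No row is updated when the user is missing or the quota is insufficient.
    if result.RowsAffected == 0 {
        return errors.New("insufficient quota")
    }
    return nil
}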

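// CheckUserQuota reports an error when amount would exceed the user's quota.
// A Quota of 0 passes this check, whereas the WHERE clause in
// ConsumeUserQuota rejects it; the two rules should be aligned in real code.
func (us *UserService) CheckUserQuota(userID int, amount int64) error {
    var user User
    err := us.db.Select("quota, used_quota").First(&user, userID).Error
    if err != nil {
        return err
    }

    if user.Quota > 0 && user.UsedQuota+amount > user.Quota {
        return errors.New("insufficient quota")
    }

    return nil
}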

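// generateAffCode returns a new invitation code with the "aff_" prefix.
func generateAffCode() string {
    return "aff_" + generateRandomString(8)
}

// generateRandomString returns a random alphanumeric string of the given
// length. math/rand is not cryptographically secure; use crypto/rand when
// the value must be unguessable.
func generateRandomString(length int) string {
    const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
    b := make([]byte, length)
    for i := range b {
        b[i] = charset[rand.Intn(len(charset))]
    }
    return string(b)
}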

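19.2.2 Token Management Module

Token lifecycle management

Figure 4. Token lifecycle sequence diagram

Key concepts

Token status (Token Status)

  • TokenStatusEnabled (1): enabled, the token can be used normally

  • TokenStatusDisabled (2): disabled manually by an administrator

  • TokenStatusExpired (3): expired, the validity period has passed

  • TokenStatusExhausted (4): the quota has been used up

Quota fields (Quota Types)

  • RemainQuota: remaining quota, decremented on each use

  • UsedQuota: quota consumed so far, a running total

  • UnlimitedQuota: unlimited-quota flag; true means usage is not restricted

Model permissions (Models)

  • Stored as a comma-separated string, for example "gpt-3.5-turbo,gpt-4,claude-3"

  • Restricts which AI models the token may access

  • Supports wildcard matching, for example "gpt-*" matches all GPT models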
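
The matching rules above can be sketched as follows. The sketch assumes that a wildcard is a trailing `*` that matches by prefix and that an empty list allows every model; the function name `modelAllowed` is hypothetical, and the project may implement wildcards differently.

```go
package main

import (
	"fmt"
	"strings"
)

// modelAllowed reports whether model is permitted by the comma-separated
// allow list. Assumptions: an empty list allows every model, and a pattern
// ending in "*" matches any model that starts with the text before the "*".
func modelAllowed(allowed, model string) bool {
	if strings.TrimSpace(allowed) == "" {
		return true
	}
	for _, p := range strings.Split(allowed, ",") {
		p = strings.TrimSpace(p)
		if p == "" {
			continue
		}
		if strings.HasSuffix(p, "*") {
			if strings.HasPrefix(model, strings.TrimSuffix(p, "*")) {
				return true
			}
			continue
		}
		if p == model {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(modelAllowed("gpt-*,claude-3", "gpt-4"))         // true
	fmt.Println(modelAllowed("gpt-*,claude-3", "claude-3"))      // true
	fmt.Println(modelAllowed("gpt-*,claude-3", "claude-3-opus")) // false
	fmt.Println(modelAllowed("", "any-model"))                   // true
}
```

Exact entries are compared as whole strings, so "claude-3" does not accidentally grant access to "claude-3-opus"; only an explicit wildcard widens the match.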

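Security mechanisms

  • Token keys are generated from a cryptographically secure random source; uniqueness is enforced by the unique index on the key column

  • Access times are recorded for auditing and anomaly detection

  • Each token can carry an expiration time; -1 means the token never expires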

package model

import (
    "crypto/rand"
    "encoding/hex"
    "errors"
    "strings"
    "time"
    
    "gorm.io/gorm"
)

// Token is the API token model.
type Token struct {
    ID             int    `json:"id" gorm:"primaryKey"`
    UserId         int    `json:"user_id" gorm:"index"`
    Key            string `json:"key" gorm:"uniqueIndex;not null"`
    Status         int    `json:"status" gorm:"default:1"`
    Name           string `json:"name"`
    CreatedTime    int64  `json:"created_time"`
    AccessedTime   int64  `json:"accessed_time"`
    ExpiredTime    int64  `json:"expired_time"`
    RemainQuota    int64  `json:"remain_quota" gorm:"default:0"`
    UsedQuota      int64  `json:"used_quota" gorm:"default:0"`
    UnlimitedQuota bool   `json:"unlimited_quota" gorm:"default:false"`
    Models         string `json:"models"`
}

// Token status constants.
const (
    TokenStatusEnabled   = 1
    TokenStatusDisabled  = 2
    TokenStatusExpired   = 3
    TokenStatusExhausted = 4
)

// TokenService provides token-related database operations.
type TokenService struct {
    db *gorm.DB
}

func NewTokenService(db *gorm.DB) *TokenService {
    return &TokenService{db: db}
}

// CreateToken generates a key for the token and stores it as enabled.
func (ts *TokenService) CreateToken(token *Token) error {
    // Generate the token key.
    key, err := generateTokenKey()
    if err != nil {
        return err
    }

    token.Key = key
    token.CreatedTime = time.Now().Unix()
    token.Status = TokenStatusEnabled

    // Without an explicit expiration time, the token never expires (-1).
    if token.ExpiredTime == 0 {
        token.ExpiredTime = -1
    }

    return ts.db.Create(token).Error
}

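// ValidateToken looks up a token by key and checks its status, expiry, and quota.
func (ts *TokenService) ValidateToken(key string) (*Token, error) {
    var token Token
    // "key" is a reserved word in some databases (for example MySQL); a map
    // condition lets GORM quote the column name for the active dialect.
    err := ts.db.Where(map[string]interface{}{"key": key}).First(&token).Error
    if err != nil {
        return nil, errors.New("invalid token")
    }

    // Reject any token that is not enabled (disabled, expired, or exhausted).
    if token.Status != TokenStatusEnabled {
        return nil, errors.New("token is disabled")
    }

    // Check expiry; a non-positive ExpiredTime means the token never expires.
    if token.ExpiredTime > 0 && time.Now().Unix() > token.ExpiredTime {
        // Persist the expired status.
        ts.db.Model(&token).Update("status", TokenStatusExpired)
        return nil, errors.New("token is expired")
    }

    // Check the remaining quota.
    if !token.UnlimitedQuota && token.RemainQuota <= 0 {
        // Persist the exhausted status.
        ts.db.Model(&token).Update("status", TokenStatusExhausted)
        return nil, errors.New("token quota exhausted")
    }

    // Record the access time. A single-column update avoids overwriting the
    // quota fields with stale values under concurrent requests.
    token.AccessedTime = time.Now().Unix()
    ts.db.Model(&token).Update("accessed_time", token.AccessedTime)

    return &token, nil
}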

// ConsumeTokenQuota charges amount against the token's quota.
func (ts *TokenService) ConsumeTokenQuota(tokenID int, amount int64) error {
    if amount <= 0 {
        return nil
    }

    // Run the steps in one transaction to keep the counters consistent.
    return ts.db.Transaction(func(tx *gorm.DB) error {
        var token Token
        if err := tx.Where("id = ?", tokenID).First(&token).Error; err != nil {
            return err
        }

        // Unlimited tokens only accumulate usage.
        if token.UnlimitedQuota {
            return tx.Model(&Token{}).Where("id = ?", tokenID).
                Update("used_quota", gorm.Expr("used_quota + ?", amount)).Error
        }

        // Conditional update: the row changes only if enough quota remains,
        // so concurrent requests cannot drive the balance below zero.
        result := tx.Model(&Token{}).
            Where("id = ? AND remain_quota >= ?", tokenID, amount).
            Updates(map[string]interface{}{
                "remain_quota": gorm.Expr("remain_quota - ?", amount),
                "used_quota":   gorm.Expr("used_quota + ?", amount),
            })
        if result.Error != nil {
            return result.Error
        }
        if result.RowsAffected == 0 {
            return errors.New("insufficient token quota")
        }

        // Mark the token as exhausted once nothing remains.
        return tx.Model(&Token{}).
            Where("id = ? AND remain_quota <= 0", tokenID).
            Update("status", TokenStatusExhausted).Error
    })
}

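// CheckModelPermission reports whether the token may use the given model.
func (ts *TokenService) CheckModelPermission(token *Token, model string) bool {
    if token.Models == "" {
        return true // An empty list allows every model.
    }

    for _, m := range strings.Split(token.Models, ",") {
        pattern := strings.TrimSpace(m)
        // A trailing "*" matches by prefix, e.g. "gpt-*" allows "gpt-4".
        if strings.HasSuffix(pattern, "*") {
            if strings.HasPrefix(model, strings.TrimSuffix(pattern, "*")) {
                return true
            }
            continue
        }
        if pattern == model {
            return true
        }
    }

    return false
}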

// GetUserTokens returns all tokens that belong to the user.
func (ts *TokenService) GetUserTokens(userID int) ([]*Token, error) {
    var tokens []*Token
    err := ts.db.Where("user_id = ?", userID).Find(&tokens).Error
    return tokens, err
}

// UpdateToken saves all fields of the given token.
func (ts *TokenService) UpdateToken(token *Token) error {
    return ts.db.Save(token).Error
}

// DeleteToken deletes a token, but only if it belongs to the given user.
func (ts *TokenService) DeleteToken(id int, userID int) error {
    return ts.db.Where("id = ? AND user_id = ?", id, userID).Delete(&Token{}).Error
}

// CleanupExpiredTokens marks enabled tokens whose expiry has passed as expired.
func (ts *TokenService) CleanupExpiredTokens() error {
    now := time.Now().Unix()
    return ts.db.Model(&Token{}).
        Where("expired_time > 0 AND expired_time < ? AND status = ?", now, TokenStatusEnabled).
        Update("status", TokenStatusExpired).Error
}

// generateTokenKey returns "sk-" followed by 32 random bytes in hex,
// read from the cryptographically secure crypto/rand source.
func generateTokenKey() (string, error) {
    bytes := make([]byte, 32)
    if _, err := rand.Read(bytes); err != nil {
        return "", err
    }
    return "sk-" + hex.EncodeToString(bytes), nil
}

19.2.3 Channel Management Module

package model

import (
    "encoding/json"
    "fmt"
    "math/rand"
    "net/http"
    "strings"
    "time"

    "gorm.io/gorm"
)

// Channel is the upstream channel model.
type Channel struct {
    ID                 int     `json:"id" gorm:"primaryKey"`
    Type               int     `json:"type" gorm:"default:1"`
    Key                string  `json:"key"`
    Status             int     `json:"status" gorm:"default:1"`
    Name               string  `json:"name"`
    Weight             *uint   `json:"weight" gorm:"default:0"`
    CreatedTime        int64   `json:"created_time"`
    TestTime           int64   `json:"test_time"`
    ResponseTime       int     `json:"response_time"`
    BaseURL            *string `json:"base_url"`
    Other              string  `json:"other"`
    Balance            float64 `json:"balance"`
    BalanceUpdatedTime int64   `json:"balance_updated_time"`
    Models             string  `json:"models"`
    Group              string  `json:"group" gorm:"default:'default'"`
    UsedQuota          int64   `json:"used_quota" gorm:"default:0"`
    ModelMapping       *string `json:"model_mapping"`
    Priority           *int64  `json:"priority"`
    Config             string  `json:"config"`
}

// Channel type constants.
const (
    ChannelTypeOpenAI     = 1
    ChannelTypeAPI2D      = 2
    ChannelTypeAzure      = 3
    ChannelTypeClaudeAPI  = 4
    ChannelTypeBard       = 5
    ChannelTypePaLM       = 6
    ChannelTypeZhipu      = 7
    ChannelTypeAli        = 8
    ChannelTypeBaidu      = 9
    ChannelTypeTencent    = 10
)

// Channel status constants.
const (
    ChannelStatusUnknown      = 0
    ChannelStatusEnabled      = 1
    ChannelStatusDisabled     = 2
    ChannelStatusAutoDisabled = 3
)

// ChannelService provides channel-related database operations.
type ChannelService struct {
    db *gorm.DB
}

func NewChannelService(db *gorm.DB) *ChannelService {
    return &ChannelService{db: db}
}

// CreateChannel stores a new channel and marks it as enabled.
func (cs *ChannelService) CreateChannel(channel *Channel) error {
    channel.CreatedTime = time.Now().Unix()
    channel.Status = ChannelStatusEnabled
    
    return cs.db.Create(channel).Error
}

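// GetAvailableChannels returns enabled channels in the given group (or the
// default group) that serve the model, ordered by priority and weight.
func (cs *ChannelService) GetAvailableChannels(group string, model string) ([]*Channel, error) {
    var channels []*Channel

    // "group" is a reserved SQL keyword. A map condition lets GORM quote the
    // column name, and the slice value is rendered as an IN clause.
    query := cs.db.Where("status = ?", ChannelStatusEnabled).
        Where(map[string]interface{}{"group": []string{group, "default"}})

    // Keep only channels that serve the model. LIKE is a substring match,
    // so "gpt-4" also matches a channel that lists only "gpt-4-32k".
    if model != "" {
        query = query.Where("models = '' OR models LIKE ?", "%"+model+"%")
    }

    err := query.Order("priority DESC, weight DESC").Find(&channels).Error
    return channels, err
}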

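// SelectBestChannel picks a channel by weighted random selection.
// Priority is not considered here; it only affects the order in which
// GetAvailableChannels returns the candidates.
func (cs *ChannelService) SelectBestChannel(channels []*Channel) *Channel {
    if len(channels) == 0 {
        return nil
    }

    // Sum the weights of all candidates.
    totalWeight := uint(0)
    for _, channel := range channels {
        if channel.Weight != nil {
            totalWeight += *channel.Weight
        }
    }

    if totalWeight == 0 {
        // No weights are set: choose uniformly at random.
        return channels[rand.Intn(len(channels))]
    }

    // Draw a point in [0, totalWeight) and return the channel whose
    // cumulative weight range contains it.
    randomWeight := rand.Intn(int(totalWeight))
    currentWeight := uint(0)

    for _, channel := range channels {
        if channel.Weight != nil {
            currentWeight += *channel.Weight
            if randomWeight < int(currentWeight) {
                return channel
            }
        }
    }

    return channels[0]
}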

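// TestChannel sends a small chat request through the channel and records
// the test time, the response time, and the resulting status.
func (cs *ChannelService) TestChannel(channel *Channel) error {
    // Build the test request.
    testRequest := map[string]interface{}{
        "model": "gpt-3.5-turbo",
        "messages": []map[string]string{
            {
                "role":    "user",
                "content": "Hello",
            },
        },
        "max_tokens": 10,
    }

    // Send the test request and measure how long it takes.
    startTime := time.Now()
    err := cs.sendTestRequest(channel, testRequest)
    responseTime := int(time.Since(startTime).Milliseconds())

    // Record the test time and the response time.
    channel.TestTime = time.Now().Unix()
    channel.ResponseTime = responseTime

    if err != nil {
        // The test failed: disable the channel automatically.
        channel.Status = ChannelStatusAutoDisabled
    } else {
        // The test succeeded: re-enable a channel that was auto-disabled.
        if channel.Status == ChannelStatusAutoDisabled {
            channel.Status = ChannelStatusEnabled
        }
    }

    // Report a save failure only when the test itself succeeded.
    if saveErr := cs.db.Save(channel).Error; saveErr != nil && err == nil {
        return saveErr
    }
    return err
}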

// sendTestRequest posts the test payload to the channel's chat completions endpoint.
func (cs *ChannelService) sendTestRequest(channel *Channel, request map[string]interface{}) error {
    // Default to the OpenAI endpoint unless the channel overrides the base URL.
    baseURL := "https://api.openai.com"
    if channel.BaseURL != nil && *channel.BaseURL != "" {
        baseURL = *channel.BaseURL
    }

    url := fmt.Sprintf("%s/v1/chat/completions", baseURL)

    // Serialize the payload so that it is actually sent as the request body.
    body, err := json.Marshal(request)
    if err != nil {
        return err
    }

    client := &http.Client{
        Timeout: 30 * time.Second,
    }

    // Simplified: a full implementation would build a different request
    // for each channel type.
    req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(string(body)))
    if err != nil {
        return err
    }

    req.Header.Set("Authorization", "Bearer "+channel.Key)
    req.Header.Set("Content-Type", "application/json")

    resp, err := client.Do(req)
    if err != nil {
        return err
    }
    defer resp.Body.Close()

    if resp.StatusCode != http.StatusOK {
        return fmt.Errorf("test failed with status: %d", resp.StatusCode)
    }

    return nil
}

// UpdateChannelUsage adds quota to the channel's used-quota counter.
func (cs *ChannelService) UpdateChannelUsage(channelID int, quota int64) error {
    return cs.db.Model(&Channel{}).Where("id = ?", channelID).
        Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
}

// GetChannelStats returns usage statistics for a channel.
// Simplified placeholder: a real implementation would aggregate the log
// table over the requested number of days.
func (cs *ChannelService) GetChannelStats(channelID int, days int) (map[string]interface{}, error) {
    stats := map[string]interface{}{
        "total_requests":    0,
        "total_tokens":      0,
        "success_rate":      0.0,
        "avg_response_time": 0,
    }

    return stats, nil
}

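// AutoDisableAbnormalChannels disables enabled channels whose last test is
// both stale and slow.
func (cs *ChannelService) AutoDisableAbnormalChannels() error {
    // Find enabled channels that have been tested at least once.
    var channels []*Channel
    err := cs.db.Where("status = ? AND test_time > 0", ChannelStatusEnabled).Find(&channels).Error
    if err != nil {
        return err
    }

    now := time.Now().Unix()
    for _, channel := range channels {
        // Disable a channel whose last test is more than one hour old and
        // whose recorded response time exceeds 10 seconds.
        if now-channel.TestTime > 3600 && channel.ResponseTime > 10000 {
            channel.Status = ChannelStatusAutoDisabled
            if err := cs.db.Save(channel).Error; err != nil {
                return err
            }
        }
    }

    return nil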
}

19.2.4 Request Relay Module

package controller

import (
    "bytes"
    "encoding/json"
    "fmt"
    "io"
    "net/http"
    "strings"
    "time"

    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/model"
)

// ChatCompletionRequest is the chat completion request body.
type ChatCompletionRequest struct {
    Model            string                  `json:"model"`
    Messages         []ChatCompletionMessage `json:"messages"`
    MaxTokens        *int                    `json:"max_tokens,omitempty"`
    Temperature      *float64                `json:"temperature,omitempty"`
    TopP             *float64                `json:"top_p,omitempty"`
    N                *int                    `json:"n,omitempty"`
    Stream           bool                    `json:"stream,omitempty"`
    Stop             interface{}             `json:"stop,omitempty"`
    PresencePenalty  *float64                `json:"presence_penalty,omitempty"`
    FrequencyPenalty *float64                `json:"frequency_penalty,omitempty"`
    LogitBias        map[string]interface{}  `json:"logit_bias,omitempty"`
    User             string                  `json:"user,omitempty"`
}

type ChatCompletionMessage struct {
    Role    string `json:"role"`
    Content string `json:"content"`
    Name    string `json:"name,omitempty"`
}

// ChatCompletionResponse is the chat completion response body.
type ChatCompletionResponse struct {
    ID      string                 `json:"id"`
    Object  string                 `json:"object"`
    Created int64                  `json:"created"`
    Model   string                 `json:"model"`
    Choices []ChatCompletionChoice `json:"choices"`
    Usage   ChatCompletionUsage    `json:"usage"`
}

type ChatCompletionChoice struct {
    Index        int                   `json:"index"`
    Message      ChatCompletionMessage `json:"message"`
    FinishReason string                `json:"finish_reason"`
}

type ChatCompletionUsage struct {
    PromptTokens     int `json:"prompt_tokens"`
    CompletionTokens int `json:"completion_tokens"`
    TotalTokens      int `json:"total_tokens"`
}

// RelayService forwards requests to upstream channels.
type RelayService struct {
    channelService *model.ChannelService
    tokenService   *model.TokenService
    userService    *model.UserService
}

func NewRelayService() *RelayService {
    return &RelayService{
        channelService: model.NewChannelService(model.DB),
        tokenService:   model.NewTokenService(model.DB),
        userService:    model.NewUserService(model.DB),
    }
}

// ChatCompletions handles the chat completions endpoint: it authorizes the
// request, selects a channel, relays the call, and settles the quota.
func ChatCompletions(c *gin.Context) {
    relayService := NewRelayService()

    // Fetch the token that the authentication middleware is expected to
    // have stored in the context.
    token, exists := c.Get("token")
    if !exists {
        c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
        return
    }

    // Use a checked type assertion so an unexpected value cannot panic.
    tokenInfo, ok := token.(*model.Token)
    if !ok {
        c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
        return
    }

    // Parse the request body.
    var request ChatCompletionRequest
    if err := c.ShouldBindJSON(&request); err != nil {
        c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request format"})
        return
    }

    // Check that the token may use the requested model.
    if !relayService.tokenService.CheckModelPermission(tokenInfo, request.Model) {
        c.JSON(http.StatusForbidden, gin.H{"error": "model not allowed"})
        return
    }

    // Load the token owner.
    user, err := relayService.userService.GetUserByID(tokenInfo.UserId)
    if err != nil {
        c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get user info"})
        return
    }

    // Find candidate channels and pick one.
    channels, err := relayService.channelService.GetAvailableChannels(user.Group, request.Model)
    if err != nil || len(channels) == 0 {
        c.JSON(http.StatusServiceUnavailable, gin.H{"error": "no available channels"})
        return
    }

    channel := relayService.channelService.SelectBestChannel(channels)

    // Forward the request.
    response, usage, err := relayService.forwardRequest(channel, &request)
    if err != nil {
        // The selected channel failed: try the remaining channels in order.
        for _, fallbackChannel := range channels {
            if fallbackChannel.ID != channel.ID {
                response, usage, err = relayService.forwardRequest(fallbackChannel, &request)
                if err == nil {
                    channel = fallbackChannel
                    break
                }
            }
        }

        if err != nil {
            c.JSON(http.StatusBadGateway, gin.H{"error": "all channels failed"})
            return
        }
    }

    // Compute the cost of the request.
    quota := relayService.calculateQuota(request.Model, usage)

    // Deduct the quota. Usage is only known after the upstream call, so the
    // upstream cost has already been incurred here; a stricter design would
    // also pre-check the remaining quota before forwarding.
    if err := relayService.tokenService.ConsumeTokenQuota(tokenInfo.ID, quota); err != nil {
        c.JSON(http.StatusPaymentRequired, gin.H{"error": "insufficient quota"})
        return
    }

    // Update the channel's usage counter (errors are ignored in this sketch).
    relayService.channelService.UpdateChannelUsage(channel.ID, quota)

    // Write the request log.
    relayService.logRequest(tokenInfo, channel, &request, response, usage, quota)

    c.JSON(http.StatusOK, response)
}

// forwardRequest sends the request to the upstream channel and returns the
// parsed response together with its usage figures. Streaming responses
// (request.Stream == true) are not handled in this simplified version.
func (rs *RelayService) forwardRequest(channel *model.Channel, request *ChatCompletionRequest) (*ChatCompletionResponse, *ChatCompletionUsage, error) {
    // Build the upstream URL; default to the OpenAI endpoint.
    baseURL := "https://api.openai.com"
    if channel.BaseURL != nil && *channel.BaseURL != "" {
        baseURL = *channel.BaseURL
    }

    url := fmt.Sprintf("%s/v1/chat/completions", baseURL)

    // Serialize the request.
    requestBody, err := json.Marshal(request)
    if err != nil {
        return nil, nil, err
    }

    // Create the HTTP request.
    req, err := http.NewRequest("POST", url, bytes.NewBuffer(requestBody))
    if err != nil {
        return nil, nil, err
    }

    // Set the common headers.
    req.Header.Set("Content-Type", "application/json")
    req.Header.Set("Authorization", "Bearer "+channel.Key)

    // Apply headers that depend on the channel type.
    rs.setChannelSpecificHeaders(req, channel)

    // Send the request.
    client := &http.Client{
        Timeout: 60 * time.Second,
    }

    resp, err := client.Do(req)
    if err != nil {
        return nil, nil, err
    }
    defer resp.Body.Close()

    // Read the response body.
    responseBody, err := io.ReadAll(resp.Body)
    if err != nil {
        return nil, nil, err
    }

    if resp.StatusCode != http.StatusOK {
        return nil, nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, string(responseBody))
    }

    // Parse the response.
    var response ChatCompletionResponse
    if err := json.Unmarshal(responseBody, &response); err != nil {
        return nil, nil, err
    }

    return &response, &response.Usage, nil
}

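// setChannelSpecificHeaders replaces the default Authorization header with
// the headers that specific channel types expect.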
func (rs *RelayService) setChannelSpecificHeaders(req *http.Request, channel *model.Channel) {
    switch channel.Type {
    case model.ChannelTypeAzure:
        req.Header.Set("api-key", channel.Key)
        req.Header.Del("Authorization")
    case model.ChannelTypeClaudeAPI:
        req.Header.Set("x-api-key", channel.Key)
        req.Header.Set("anthropic-version", "2023-06-01")
        req.Header.Del("Authorization")
    }
}

// calculateQuota converts token usage into quota units.
// Simplified: a real implementation would apply per-model pricing.
func (rs *RelayService) calculateQuota(model string, usage *ChatCompletionUsage) int64 {
    baseQuota := int64(usage.TotalTokens)

    // Per-model multipliers; the first matching case wins.
    multiplier := 1.0
    switch {
    case strings.Contains(model, "gpt-4"):
        multiplier = 20.0
    case strings.Contains(model, "gpt-3.5-turbo"):
        multiplier = 1.0
    case strings.Contains(model, "text-davinci"):
        multiplier = 10.0
    }

    return int64(float64(baseQuota) * multiplier)
}

// logRequest records one relayed request in the log table.
func (rs *RelayService) logRequest(token *model.Token, channel *model.Channel, request *ChatCompletionRequest, response *ChatCompletionResponse, usage *ChatCompletionUsage, quota int64) {
    // Build the log record.
    log := &model.Log{
        UserId:           token.UserId,
        CreatedAt:        time.Now().Unix(),
        Type:             model.LogTypeChatCompletion,
        Content:          fmt.Sprintf("Model: %s, Tokens: %d", request.Model, usage.TotalTokens),
        Username:         "", // Should be filled in from the user record.
        TokenName:        token.Name,
        ModelName:        request.Model,
        Quota:            quota,
        PromptTokens:     usage.PromptTokens,
        CompletionTokens: usage.CompletionTokens,
        ChannelId:        channel.ID,
    }

    // Save the log asynchronously so the response is not delayed.
    go func() {
        model.CreateLog(log)
    }()
}

19.2.5 Billing and Statistics Module

package model

import (
    "time"
    
    "gorm.io/gorm"
)

// Log is the request log model.
type Log struct {
    ID               int    `json:"id" gorm:"primaryKey"`
    UserId           int    `json:"user_id" gorm:"index"`
    CreatedAt        int64  `json:"created_at" gorm:"index"`
    Type             int    `json:"type" gorm:"index"`
    Content          string `json:"content"`
    Username         string `json:"username" gorm:"index"`
    TokenName        string `json:"token_name"`
    ModelName        string `json:"model_name" gorm:"index"`
    Quota            int64  `json:"quota"`
    PromptTokens     int    `json:"prompt_tokens"`
    CompletionTokens int    `json:"completion_tokens"`
    ChannelId        int    `json:"channel_id" gorm:"index"`
    RequestId        string `json:"request_id"`
    ResponseTime     int    `json:"response_time"`
}

// Log type constants.
const (
    LogTypeChatCompletion = 1
    LogTypeCompletion     = 2
    LogTypeEmbedding      = 3
    LogTypeModeration     = 4
    LogTypeImage          = 5
    LogTypeAudio          = 6
)

// StatService runs aggregate queries over the logs table.
type StatService struct {
    db *gorm.DB
}

func NewStatService(db *gorm.DB) *StatService {
    return &StatService{db: db}
}

// UserUsageStats summarizes one user's usage over a time range.
type UserUsageStats struct {
    UserID           int     `json:"user_id"`
    Username         string  `json:"username"`
    TotalRequests    int64   `json:"total_requests"`
    TotalTokens      int64   `json:"total_tokens"`
    TotalQuota       int64   `json:"total_quota"`
    PromptTokens     int64   `json:"prompt_tokens"`
    CompletionTokens int64   `json:"completion_tokens"`
    LastRequestTime  int64   `json:"last_request_time"`
}

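// GetUserUsageStats aggregates the logs of one user between startTime and
// endTime (Unix seconds, both inclusive).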
func (ss *StatService) GetUserUsageStats(userID int, startTime, endTime int64) (*UserUsageStats, error) {
    var stats UserUsageStats
    
    query := `
        SELECT 
            user_id,
            username,
            COUNT(*) as total_requests,
            SUM(prompt_tokens + completion_tokens) as total_tokens,
            SUM(quota) as total_quota,
            SUM(prompt_tokens) as prompt_tokens,
            SUM(completion_tokens) as completion_tokens,
            MAX(created_at) as last_request_time
        FROM logs 
        WHERE user_id = ? AND created_at BETWEEN ? AND ?
        GROUP BY user_id, username
    `
    
    err := ss.db.Raw(query, userID, startTime, endTime).Scan(&stats).Error
    return &stats, err
}

// ModelUsageStats summarizes usage per model.
type ModelUsageStats struct {
    ModelName        string  `json:"model_name"`
    TotalRequests    int64   `json:"total_requests"`
    TotalTokens      int64   `json:"total_tokens"`
    TotalQuota       int64   `json:"total_quota"`
    AvgResponseTime  float64 `json:"avg_response_time"`
    UniqueUsers      int64   `json:"unique_users"`
}

// GetModelUsageStats aggregates logs per model between startTime and endTime,
// ordered by request count.
func (ss *StatService) GetModelUsageStats(startTime, endTime int64) ([]*ModelUsageStats, error) {
    var stats []*ModelUsageStats
    
    query := `
        SELECT 
            model_name,
            COUNT(*) as total_requests,
            SUM(prompt_tokens + completion_tokens) as total_tokens,
            SUM(quota) as total_quota,
            AVG(response_time) as avg_response_time,
            COUNT(DISTINCT user_id) as unique_users
        FROM logs 
        WHERE created_at BETWEEN ? AND ?
        GROUP BY model_name
        ORDER BY total_requests DESC
    `
    
    err := ss.db.Raw(query, startTime, endTime).Scan(&stats).Error
    return stats, err
}

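// ChannelUsageStats summarizes usage per channel.
// SuccessRate is not filled by GetChannelUsageStats: the Log model has no
// status column to derive it from.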
type ChannelUsageStats struct {
    ChannelID        int     `json:"channel_id"`
    ChannelName      string  `json:"channel_name"`
    TotalRequests    int64   `json:"total_requests"`
    TotalTokens      int64   `json:"total_tokens"`
    TotalQuota       int64   `json:"total_quota"`
    SuccessRate      float64 `json:"success_rate"`
    AvgResponseTime  float64 `json:"avg_response_time"`
}

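// GetChannelUsageStats aggregates logs per channel between startTime and
// endTime and joins the channel name.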
func (ss *StatService) GetChannelUsageStats(startTime, endTime int64) ([]*ChannelUsageStats, error) {
    var stats []*ChannelUsageStats
    
    query := `
        SELECT 
            l.channel_id,
            c.name as channel_name,
            COUNT(*) as total_requests,
            SUM(l.prompt_tokens + l.completion_tokens) as total_tokens,
            SUM(l.quota) as total_quota,
            AVG(l.response_time) as avg_response_time
        FROM logs l
        LEFT JOIN channels c ON l.channel_id = c.id
        WHERE l.created_at BETWEEN ? AND ?
        GROUP BY l.channel_id, c.name
        ORDER BY total_requests DESC
    `
    
    err := ss.db.Raw(query, startTime, endTime).Scan(&stats).Error
    return stats, err
}

// DailyStats holds the totals for one calendar day.
type DailyStats struct {
    Date             string  `json:"date"`
    TotalRequests    int64   `json:"total_requests"`
    TotalTokens      int64   `json:"total_tokens"`
    TotalQuota       int64   `json:"total_quota"`
    UniqueUsers      int64   `json:"unique_users"`
    AvgResponseTime  float64 `json:"avg_response_time"`
}

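// GetDailyStats returns per-day totals for the last `days` days.
// The query uses MySQL-only functions (FROM_UNIXTIME, UNIX_TIMESTAMP,
// DATE_SUB), so it will not run unchanged on PostgreSQL or SQLite.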
func (ss *StatService) GetDailyStats(days int) ([]*DailyStats, error) {
    var stats []*DailyStats
    
    query := `
        SELECT 
            DATE(FROM_UNIXTIME(created_at)) as date,
            COUNT(*) as total_requests,
            SUM(prompt_tokens + completion_tokens) as total_tokens,
            SUM(quota) as total_quota,
            COUNT(DISTINCT user_id) as unique_users,
            AVG(response_time) as avg_response_time
        FROM logs 
        WHERE created_at >= UNIX_TIMESTAMP(DATE_SUB(NOW(), INTERVAL ? DAY))
        GROUP BY DATE(FROM_UNIXTIME(created_at))
        ORDER BY date DESC
    `
    
    err := ss.db.Raw(query, days).Scan(&stats).Error
    return stats, err
}

// RealtimeStats is a snapshot of activity over the last minute.
type RealtimeStats struct {
    OnlineUsers      int64   `json:"online_users"`
    RequestsPerMin   int64   `json:"requests_per_min"`
    TokensPerMin     int64   `json:"tokens_per_min"`
    AvgResponseTime  float64 `json:"avg_response_time"`
    ErrorRate        float64 `json:"error_rate"`
}

// GetRealtimeStats reports activity over the last minute.
// OnlineUsers and ErrorRate stay at zero: this simplified version has no
// data source for them.
func (ss *StatService) GetRealtimeStats() (*RealtimeStats, error) {
    stats := &RealtimeStats{}
    
    // Number of requests in the last minute
    oneMinuteAgo := time.Now().Unix() - 60
    err := ss.db.Model(&Log{}).Where("created_at >= ?", oneMinuteAgo).Count(&stats.RequestsPerMin).Error
    if err != nil {
        return nil, err
    }
    
    // Tokens in the last minute. COALESCE replaces the NULL that SUM
    // returns for an empty window with 0.
    var tokenSum struct {
        Total int64
    }
    err = ss.db.Model(&Log{}).Select("COALESCE(SUM(prompt_tokens + completion_tokens), 0) as total").
        Where("created_at >= ?", oneMinuteAgo).Scan(&tokenSum).Error
    if err != nil {
        return nil, err
    }
    stats.TokensPerMin = tokenSum.Total
    
    // Average response time in the last minute (0 when there were no requests)
    var avgTime struct {
        Avg float64
    }
    err = ss.db.Model(&Log{}).Select("COALESCE(AVG(response_time), 0) as avg").
        Where("created_at >= ?", oneMinuteAgo).Scan(&avgTime).Error
    if err != nil {
        return nil, err
    }
    stats.AvgResponseTime = avgTime.Avg
    
    return stats, nil
}

// CreateLog inserts one log record using the package-level DB handle.
func CreateLog(log *Log) error {
    return DB.Create(log).Error
}

// CleanupOldLogs deletes log records older than `days` days.
func (ss *StatService) CleanupOldLogs(days int) error {
    cutoffTime := time.Now().AddDate(0, 0, -days).Unix()
    return ss.db.Where("created_at < ?", cutoffTime).Delete(&Log{}).Error
}
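
GetDailyStats relies on MySQL date functions, while the chapter describes the data layer as supporting several databases. One database-independent alternative is to select the raw rows for the time window and bucket them by day in Go. The sketch below assumes that approach; `LogRecord`, `DayTotal`, and `AggregateByDayUTC` are hypothetical names, and days are computed in UTC for simplicity.

```go
package main

import (
	"fmt"
	"sort"
	"time"
)

// LogRecord carries only the Log fields needed for daily totals (hypothetical type).
type LogRecord struct {
	CreatedAt        int64 // Unix seconds
	PromptTokens     int
	CompletionTokens int
	Quota            int64
	UserID           int
}

// DayTotal mirrors the shape of DailyStats without the response-time average.
type DayTotal struct {
	Date        string
	Requests    int64
	Tokens      int64
	Quota       int64
	UniqueUsers int64
}

// AggregateByDayUTC groups records by UTC calendar day, newest day first.
func AggregateByDayUTC(records []LogRecord) []DayTotal {
	totals := make(map[string]*DayTotal)
	users := make(map[string]map[int]struct{})
	for _, r := range records {
		day := time.Unix(r.CreatedAt, 0).UTC().Format("2006-01-02")
		t, ok := totals[day]
		if !ok {
			t = &DayTotal{Date: day}
			totals[day] = t
			users[day] = make(map[int]struct{})
		}
		t.Requests++
		t.Tokens += int64(r.PromptTokens + r.CompletionTokens)
		t.Quota += r.Quota
		users[day][r.UserID] = struct{}{}
	}
	out := make([]DayTotal, 0, len(totals))
	for day, t := range totals {
		t.UniqueUsers = int64(len(users[day]))
		out = append(out, *t)
	}
	// ISO dates sort correctly as strings.
	sort.Slice(out, func(i, j int) bool { return out[i].Date > out[j].Date })
	return out
}

func main() {
	records := []LogRecord{
		{CreatedAt: 0, PromptTokens: 10, CompletionTokens: 5, Quota: 15, UserID: 1},
		{CreatedAt: 3600, PromptTokens: 20, CompletionTokens: 0, Quota: 20, UserID: 1},
		{CreatedAt: 86400, PromptTokens: 1, CompletionTokens: 1, Quota: 2, UserID: 2},
	}
	for _, d := range AggregateByDayUTC(records) {
		fmt.Printf("%+v\n", d)
	}
}
```

Aggregating in the application moves more rows over the wire than a SQL GROUP BY, so this trade-off only makes sense for modest windows or when portability matters more than query cost.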

19.3 Monitoring and Alerting

19.3.1 Performance Monitoring

package monitor

import (
    "context"
    "log"
    "runtime"
    "sync"
    "time"
    
    "github.com/prometheus/client_golang/prometheus"
    "github.com/prometheus/client_golang/prometheus/promauto"
)

// Metrics exported to Prometheus.
var (
    // HTTP request metrics
    httpRequestsTotal = promauto.NewCounterVec(
        prometheus.CounterOpts{
            Name: "http_requests_total",
            Help: "Total number of HTTP requests",
        },
        []string{"method", "endpoint", "status"},
    )
    
    httpRequestDuration = promauto.NewHistogramVec(
        prometheus.HistogramOpts{
            Name:    "http_request_duration_seconds",
            Help:    "HTTP request duration in seconds",
            Buckets: prometheus.DefBuckets,
        },
        []string{"method", "endpoint"},
    )
    
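    // Business metrics. A per-user label such as user_id creates one time
    // series per user, which can become expensive with many users.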
    tokenUsageTotal = promauto.NewCounterVec(
        prometheus.CounterOpts{
            Name: "token_usage_total",
            Help: "Total token usage",
        },
        []string{"user_id", "model"},
    )
    
    channelRequestsTotal = promauto.NewCounterVec(
        prometheus.CounterOpts{
            Name: "channel_requests_total",
            Help: "Total requests per channel",
        },
        []string{"channel_id", "status"},
    )
    
    // System metrics
    goroutinesCount = promauto.NewGauge(
        prometheus.GaugeOpts{
            Name: "goroutines_count",
            Help: "Number of goroutines",
        },
    )
    
    memoryUsage = promauto.NewGauge(
        prometheus.GaugeOpts{
            Name: "memory_usage_bytes",
            Help: "Memory usage in bytes",
        },
    )
)

// MonitorService periodically collects system metrics.
type MonitorService struct {
    ctx    context.Context
    cancel context.CancelFunc
    wg     sync.WaitGroup
}

func NewMonitorService() *MonitorService {
    ctx, cancel := context.WithCancel(context.Background())
    return &MonitorService{
        ctx:    ctx,
        cancel: cancel,
    }
}

// Start launches the background collector.
func (ms *MonitorService) Start() {
    ms.wg.Add(1)
    go ms.collectSystemMetrics()
    
    log.Println("Monitor service started")
}

// Stop cancels the collector and waits for it to exit.
func (ms *MonitorService) Stop() {
    ms.cancel()
    ms.wg.Wait()
    log.Println("Monitor service stopped")
}

// collectSystemMetrics samples runtime metrics every 10 seconds until the
// context is cancelled.
func (ms *MonitorService) collectSystemMetrics() {
    defer ms.wg.Done()
    
    ticker := time.NewTicker(10 * time.Second)
    defer ticker.Stop()
    
    for {
        select {
        case <-ms.ctx.Done():
            return
        case <-ticker.C:
            // Number of goroutines
            goroutinesCount.Set(float64(runtime.NumGoroutine()))
            
            // Bytes of heap memory currently allocated (MemStats.Alloc)
            var m runtime.MemStats
            runtime.ReadMemStats(&m)
            memoryUsage.Set(float64(m.Alloc))
        }
    }
}

// RecordHTTPRequest records the count and latency of one HTTP request.
func RecordHTTPRequest(method, endpoint, status string, duration time.Duration) {
    httpRequestsTotal.WithLabelValues(method, endpoint, status).Inc()
    httpRequestDuration.WithLabelValues(method, endpoint).Observe(duration.Seconds())
}

// RecordTokenUsage adds the tokens a user consumed on a model.
func RecordTokenUsage(userID, model string, tokens int64) {
    tokenUsageTotal.WithLabelValues(userID, model).Add(float64(tokens))
}

// RecordChannelRequest counts one request sent to a channel, labeled by status.
func RecordChannelRequest(channelID, status string) {
    channelRequestsTotal.WithLabelValues(channelID, status).Inc()
}
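
RecordHTTPRequest needs the status code and the duration of each request, which a handler wrapper can capture. The following is a minimal sketch using only net/http. `Instrument`, `statusRecorder`, `RecordFunc`, and the `observe` demo helper are hypothetical names; in the chapter's Gin-based setup the same idea would live in a Gin middleware, which is not shown here.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"strconv"
	"time"
)

// RecordFunc has the same parameter list as RecordHTTPRequest in the listing above.
type RecordFunc func(method, endpoint, status string, duration time.Duration)

// statusRecorder remembers the status code written by the wrapped handler.
type statusRecorder struct {
	http.ResponseWriter
	status int
}

func (r *statusRecorder) WriteHeader(code int) {
	r.status = code
	r.ResponseWriter.WriteHeader(code)
}

// Instrument wraps next and reports every request to record. The endpoint
// label is passed in explicitly (a route pattern, not the raw URL path) to
// keep the number of label values small.
func Instrument(endpoint string, next http.Handler, record RecordFunc) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		// 200 is what net/http sends when the handler never calls WriteHeader.
		rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
		start := time.Now()
		next.ServeHTTP(rec, req)
		record(req.Method, endpoint, strconv.Itoa(rec.status), time.Since(start))
	})
}

// observe is a demo helper: it sends one in-memory request through an
// instrumented handler that replies with code (0 means "do not call
// WriteHeader") and returns the labels that were recorded.
func observe(method, endpoint string, code int) (gotMethod, gotEndpoint, gotStatus string) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		if code != 0 {
			w.WriteHeader(code)
		}
		fmt.Fprint(w, "done")
	})
	wrapped := Instrument(endpoint, handler, func(m, e, s string, _ time.Duration) {
		gotMethod, gotEndpoint, gotStatus = m, e, s
	})
	wrapped.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(method, endpoint, nil))
	return
}

func main() {
	fmt.Println(observe("POST", "/v1/chat/completions", http.StatusTooManyRequests))
	fmt.Println(observe("GET", "/health", 0))
}
```

One caveat of this simple wrapper: because `statusRecorder` only embeds `http.ResponseWriter`, a type assertion to `http.Flusher` on it fails, so streaming handlers would need the wrapper to forward `Flush` as well.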

19.3.2 Alerting

package alert

import (
    "bytes"
    "encoding/json"
    "log"
    "net/http"
    "time"
)

// AlertLevel is the severity of an alert.
type AlertLevel int

const (
    AlertLevelInfo AlertLevel = iota
    AlertLevelWarning
    AlertLevelError
    AlertLevelCritical
)

func (al AlertLevel) String() string {
    switch al {
    case AlertLevelInfo:
        return "INFO"
    case AlertLevelWarning:
        return "WARNING"
    case AlertLevelError:
        return "ERROR"
    case AlertLevelCritical:
        return "CRITICAL"
    default:
        return "UNKNOWN"
    }
}

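// AlertMessage is the payload that is logged and sent to webhooks.
// Level is serialized to JSON as its integer value.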
type AlertMessage struct {
    Level       AlertLevel `json:"level"`
    Title       string     `json:"title"`
    Description string     `json:"description"`
    Timestamp   time.Time  `json:"timestamp"`
    Source      string     `json:"source"`
    Tags        []string   `json:"tags"`
}

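// AlertRule couples a condition with the alert raised while it holds.
// Cooldown is the minimum interval between two alerts for the same rule.
type AlertRule struct {
    Name          string
    Condition     func() bool
    Level         AlertLevel
    Message       string
    Cooldown      time.Duration
    LastTriggered time.Time
}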

// AlertService evaluates rules and dispatches alerts.
type AlertService struct {
    rules    []*AlertRule
    webhooks []string
    enabled  bool
}

func NewAlertService() *AlertService {
    return &AlertService{
        rules:   make([]*AlertRule, 0),
        enabled: true,
    }
}

// AddRule registers an alert rule.
func (as *AlertService) AddRule(rule *AlertRule) {
    as.rules = append(as.rules, rule)
}

// AddWebhook registers a webhook URL that receives alerts.
func (as *AlertService) AddWebhook(url string) {
    as.webhooks = append(as.webhooks, url)
}

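// CheckRules evaluates every rule once and sends an alert for each rule
// whose condition holds and whose cooldown has elapsed.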
func (as *AlertService) CheckRules() {
    if !as.enabled {
        return
    }
    
    for _, rule := range as.rules {
        // Skip rules that are still cooling down
        if time.Since(rule.LastTriggered) < rule.Cooldown {
            continue
        }
        
        // Evaluate the condition
        if rule.Condition() {
            alert := &AlertMessage{
                Level:       rule.Level,
                Title:       rule.Name,
                Description: rule.Message,
                Timestamp:   time.Now(),
                Source:      "new-api",
                Tags:        []string{"monitoring"},
            }
            
            as.SendAlert(alert)
            rule.LastTriggered = time.Now()
        }
    }
}

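// SendAlert logs the alert and delivers it to every configured webhook.
func (as *AlertService) SendAlert(alert *AlertMessage) {
    log.Printf("[ALERT] %s: %s - %s", alert.Level, alert.Title, alert.Description)
    
    // Deliver to each webhook in its own goroutine so that one slow
    // endpoint does not delay the others.
    for _, webhook := range as.webhooks {
        go as.sendWebhook(webhook, alert)
    }
}

// webhookClient has a timeout. http.Post uses http.DefaultClient, which has
// none, so an unresponsive endpoint could keep delivery goroutines alive.
var webhookClient = &http.Client{Timeout: 10 * time.Second}

// sendWebhook POSTs the alert as JSON to a single webhook URL.
func (as *AlertService) sendWebhook(url string, alert *AlertMessage) {
    payload, err := json.Marshal(alert)
    if err != nil {
        log.Printf("Failed to marshal alert: %v", err)
        return
    }
    
    resp, err := webhookClient.Post(url, "application/json", bytes.NewReader(payload))
    if err != nil {
        log.Printf("Failed to send webhook to %s: %v", url, err)
        return
    }
    defer resp.Body.Close()
    
    // Any 2xx status counts as success; receivers often answer 202 or 204.
    if resp.StatusCode < 200 || resp.StatusCode > 299 {
        log.Printf("Webhook %s returned status %d", url, resp.StatusCode)
    }
}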

// SetupDefaultRules registers the predefined alert rules.
// Every condition is a placeholder that returns false.
func (as *AlertService) SetupDefaultRules() {
    // High error rate
    as.AddRule(&AlertRule{
        Name: "High Error Rate",
        Condition: func() bool {
            // Should query the actual error rate
            return false // simplified
        },
        Level:    AlertLevelError,
        Message:  "Error rate is above 5%",
        Cooldown: 5 * time.Minute,
    })
    
    // High response time
    as.AddRule(&AlertRule{
        Name: "High Response Time",
        Condition: func() bool {
            // Should query the actual response time
            return false // simplified
        },
        Level:    AlertLevelWarning,
        Message:  "Average response time is above 2 seconds",
        Cooldown: 5 * time.Minute,
    })
    
    // Channel offline
    as.AddRule(&AlertRule{
        Name: "Channel Offline",
        Condition: func() bool {
            // Should check the channel status
            return false // simplified
        },
        Level:    AlertLevelCritical,
        Message:  "One or more channels are offline",
        Cooldown: 10 * time.Minute,
    })
}

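// Start launches the periodic rule check (once per minute).
// The ticker and its goroutine run until the process exits; this simplified
// version offers no way to stop them.
func (as *AlertService) Start() {
    ticker := time.NewTicker(1 * time.Minute)
    go func() {
        for range ticker.C {
            as.CheckRules()
        }
    }()
}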
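
The rule loop above cannot be stopped, and the rule list is not protected against concurrent use. A minimal sketch of one way to address both follows, assuming a context for cancellation and a mutex around the rules. `Checker`, `Rule`, `CheckOnce`, `Run`, and `demoCooldown` are hypothetical names, and the clock is injected so that the cooldown logic can be exercised without waiting.

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// Rule is a trimmed-down alert rule (hypothetical type).
type Rule struct {
	Name      string
	Condition func() bool
	Cooldown  time.Duration

	lastTriggered time.Time
}

// Checker evaluates rules under a mutex and can be stopped via a context.
type Checker struct {
	mu    sync.Mutex
	rules []*Rule
	now   func() time.Time      // injectable clock
	fire  func(ruleName string) // called for every triggered rule
}

func NewChecker(now func() time.Time, fire func(string)) *Checker {
	return &Checker{now: now, fire: fire}
}

func (c *Checker) Add(r *Rule) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.rules = append(c.rules, r)
}

// CheckOnce evaluates all rules and returns how many fired.
// Callbacks run while the mutex is held, so they should be quick.
func (c *Checker) CheckOnce() int {
	c.mu.Lock()
	defer c.mu.Unlock()
	now := c.now()
	fired := 0
	for _, r := range c.rules {
		if !r.lastTriggered.IsZero() && now.Sub(r.lastTriggered) < r.Cooldown {
			continue // still cooling down
		}
		if r.Condition() {
			r.lastTriggered = now
			c.fire(r.Name)
			fired++
		}
	}
	return fired
}

// Run checks the rules every interval until ctx is cancelled.
func (c *Checker) Run(ctx context.Context, interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			c.CheckOnce()
		}
	}
}

// demoCooldown checks an always-true rule with a 5-minute cooldown at
// t, t+1min and t+6min on a fake clock and returns the fired counts.
func demoCooldown() []int {
	clock := time.Unix(1700000000, 0)
	c := NewChecker(func() time.Time { return clock }, func(string) {})
	c.Add(&Rule{Name: "always", Condition: func() bool { return true }, Cooldown: 5 * time.Minute})
	var fired []int
	fired = append(fired, c.CheckOnce())
	clock = clock.Add(1 * time.Minute)
	fired = append(fired, c.CheckOnce())
	clock = clock.Add(5 * time.Minute)
	fired = append(fired, c.CheckOnce())
	return fired
}

func main() {
	fmt.Println(demoCooldown())

	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	c := NewChecker(time.Now, func(name string) { fmt.Println("alert:", name) })
	c.Add(&Rule{Name: "demo", Condition: func() bool { return true }, Cooldown: time.Hour})
	c.Run(ctx, 10*time.Millisecond) // returns once the context expires
}
```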

19.4 Deployment and Operations

19.4.1 Deployment Architecture

Figure 5: New-API production deployment architecture

19.4.2 Docker Deployment

Multi-stage Dockerfile

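# Dockerfile
FROM golang:1.21-alpine AS builder

# Set the working directory
WORKDIR /app

# Install build tools, CA certificates, and time zone data
RUN apk add --no-cache git ca-certificates tzdata

# Copy the module files first so the dependency layer can be cached
COPY go.mod go.sum ./

# Download dependencies
RUN go mod download

# Copy the source code
COPY . .

# Build a statically linked binary
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
    -ldflags='-w -s -extldflags "-static"' \
    -a -installsuffix cgo -o main .

# Runtime stage
FROM scratch

# Copy only the files the binary needs from the builder stage
COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
COPY --from=builder /usr/share/zoneinfo /usr/share/zoneinfo
COPY --from=builder /app/main /app/main

# Set the time zone
ENV TZ=Asia/Shanghai

# Expose the service port
EXPOSE 3000

# Health check. A scratch image has no shell or curl, so the binary has to
# perform the check itself; the "health" subcommand is assumed to be
# implemented by the application.
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
    CMD ["/app/main", "health"]

# Run the application
ENTRYPOINT ["/app/main"]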
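
The HEALTHCHECK above invokes the application binary with a `health` argument, because a scratch image ships no external tools. A minimal sketch of such a subcommand follows. It assumes the server listens on port 3000 and exposes `/health`, as in the manifests in this chapter; `probe` and `isHealthy` are hypothetical helpers, and the real server start-up is omitted.

```go
package main

import (
	"fmt"
	"net/http"
	"os"
	"time"
)

// isHealthy reports whether an HTTP status code counts as healthy (any 2xx).
func isHealthy(status int) bool {
	return status >= 200 && status < 300
}

// probe returns nil when url answers with a 2xx status within timeout.
func probe(url string, timeout time.Duration) error {
	client := &http.Client{Timeout: timeout}
	resp, err := client.Get(url)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if !isHealthy(resp.StatusCode) {
		return fmt.Errorf("unexpected status %d", resp.StatusCode)
	}
	return nil
}

func main() {
	// "main health" is what the Docker HEALTHCHECK runs: exit code 0 marks
	// the container healthy, any other exit code marks it unhealthy.
	if len(os.Args) > 1 && os.Args[1] == "health" {
		if err := probe("http://127.0.0.1:3000/health", 2*time.Second); err != nil {
			fmt.Fprintln(os.Stderr, "unhealthy:", err)
			os.Exit(1)
		}
		return
	}

	// Normal server start-up would go here.
	fmt.Println("server start-up omitted in this sketch")
}
```

The probe timeout (2 seconds here) should stay below the `--timeout` given to HEALTHCHECK so that the command reports failure itself instead of being killed.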

Docker Compose configuration

# docker-compose.yml
version: '3.8'

services:
  new-api:
    build: 
      context: .
      dockerfile: Dockerfile
    ports:
      - "3000:3000"
    environment:
      - DATABASE_URL=mysql://root:${MYSQL_ROOT_PASSWORD}@mysql:3306/new_api?charset=utf8mb4&parseTime=True&loc=Local
      - REDIS_URL=redis://redis:6379/0
      - SESSION_SECRET=${SESSION_SECRET}
      - JWT_SECRET=${JWT_SECRET}
      - LOG_LEVEL=info
      - GIN_MODE=release
    depends_on:
      mysql:
        condition: service_healthy
      redis:
        condition: service_healthy
    restart: unless-stopped
    volumes:
      - ./logs:/app/logs
      - ./config:/app/config:ro
    networks:
      - new-api-network
    deploy:
      resources:
        limits:
          cpus: '1.0'
          memory: 512M
        reservations:
          cpus: '0.5'
          memory: 256M

  mysql:
    image: mysql:8.0
    environment:
      - MYSQL_ROOT_PASSWORD=${MYSQL_ROOT_PASSWORD}
      - MYSQL_DATABASE=new_api
      - MYSQL_USER=${MYSQL_USER}
      - MYSQL_PASSWORD=${MYSQL_PASSWORD}
    ports:
      - "3306:3306"
    volumes:
      - mysql_data:/var/lib/mysql
      - ./init.sql:/docker-entrypoint-initdb.d/init.sql:ro
      - ./mysql.cnf:/etc/mysql/conf.d/mysql.cnf:ro
    restart: unless-stopped
    networks:
      - new-api-network
    healthcheck:
      test: ["CMD", "mysqladmin", "ping", "-h", "localhost"]
      timeout: 20s
      retries: 10

  redis:
    image: redis:7-alpine
    ports:
      - "6379:6379"
    volumes:
      - redis_data:/data
      - ./redis.conf:/usr/local/etc/redis/redis.conf:ro
    command: redis-server /usr/local/etc/redis/redis.conf
    restart: unless-stopped
    networks:
      - new-api-network
    healthcheck:
      test: ["CMD", "redis-cli", "ping"]
      interval: 30s
      timeout: 3s
      retries: 5

  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
      - "443:443"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf:ro
      - ./ssl:/etc/nginx/ssl:ro
      - nginx_logs:/var/log/nginx
    depends_on:
      - new-api
    restart: unless-stopped
    networks:
      - new-api-network

  prometheus:
    image: prom/prometheus:latest
    ports:
      - "9090:9090"
    volumes:
      - ./prometheus.yml:/etc/prometheus/prometheus.yml:ro
      - prometheus_data:/prometheus
    command:
      - '--config.file=/etc/prometheus/prometheus.yml'
      - '--storage.tsdb.path=/prometheus'
      - '--web.console.libraries=/etc/prometheus/console_libraries'
      - '--web.console.templates=/etc/prometheus/consoles'
    restart: unless-stopped
    networks:
      - new-api-network

  grafana:
    image: grafana/grafana:latest
    ports:
      - "3001:3000"
    environment:
      - GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_PASSWORD}
    volumes:
      - grafana_data:/var/lib/grafana
      - ./grafana/dashboards:/etc/grafana/provisioning/dashboards:ro
      - ./grafana/datasources:/etc/grafana/provisioning/datasources:ro
    restart: unless-stopped
    networks:
      - new-api-network

volumes:
  mysql_data:
  redis_data:
  prometheus_data:
  grafana_data:
  nginx_logs:

networks:
  new-api-network:
    driver: bridge
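
The DATABASE_URL above is written as a URL (`mysql://user:password@host:port/db?params`), whereas the go-sql-driver/mysql driver used by GORM's MySQL dialect expects a DSN of the form `user:password@tcp(host:port)/db?params`. How New-API itself reads this variable is not shown in this chapter, so the conversion below is only a sketch under the assumption that the application must translate one form into the other; `MySQLDSN` is a hypothetical helper.

```go
package main

import (
	"errors"
	"fmt"
	"net/url"
	"strings"
)

// MySQLDSN converts a mysql:// URL into a go-sql-driver/mysql style DSN.
func MySQLDSN(raw string) (string, error) {
	u, err := url.Parse(raw)
	if err != nil {
		return "", err
	}
	if u.Scheme != "mysql" {
		return "", fmt.Errorf("unsupported scheme %q", u.Scheme)
	}
	if u.User == nil || u.Host == "" {
		return "", errors.New("database URL needs a user and a host")
	}
	password, _ := u.User.Password()
	dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s",
		u.User.Username(), password, u.Host, strings.TrimPrefix(u.Path, "/"))
	if u.RawQuery != "" {
		dsn += "?" + u.RawQuery // driver parameters are passed through unchanged
	}
	return dsn, nil
}

func main() {
	dsn, err := MySQLDSN("mysql://root:secret@mysql:3306/new_api?charset=utf8mb4&parseTime=True&loc=Local")
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(dsn)
}
```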

19.4.3 Kubernetes Deployment

Deployment flow

Figure 6: Kubernetes deployment flow

Namespace and configuration

# namespace.yaml
apiVersion: v1
kind: Namespace
metadata:
  name: new-api
  labels:
    name: new-api
---
# configmap.yaml
apiVersion: v1
kind: ConfigMap
metadata:
  name: new-api-config
  namespace: new-api
data:
  app.yaml: |
    server:
      port: 3000
      mode: release
    database:
      max_idle_conns: 10
      max_open_conns: 100
      conn_max_lifetime: 3600
    redis:
      pool_size: 10
      min_idle_conns: 5
    log:
      level: info
      format: json

Application deployment manifests

# deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: new-api
  namespace: new-api
  labels:
    app: new-api
spec:
  replicas: 3
  strategy:
    type: RollingUpdate
    rollingUpdate:
      maxSurge: 1
      maxUnavailable: 0
  selector:
    matchLabels:
      app: new-api
  template:
    metadata:
      labels:
        app: new-api
      annotations:
        prometheus.io/scrape: "true"
        prometheus.io/port: "3000"
        prometheus.io/path: "/metrics"
    spec:
      containers:
      - name: new-api
        image: new-api:latest
        ports:
        - containerPort: 3000
          name: http
        env:
        - name: DATABASE_URL
          valueFrom:
            secretKeyRef:
              name: new-api-secrets
              key: database-url
        - name: REDIS_URL
          valueFrom:
            secretKeyRef:
              name: new-api-secrets
              key: redis-url
        - name: JWT_SECRET
          valueFrom:
            secretKeyRef:
              name: new-api-secrets
              key: jwt-secret
        volumeMounts:
        - name: config
          mountPath: /app/config
          readOnly: true
        resources:
          requests:
            memory: "256Mi"
            cpu: "250m"
          limits:
            memory: "512Mi"
            cpu: "500m"
        livenessProbe:
          httpGet:
            path: /health
            port: 3000
          initialDelaySeconds: 30
          periodSeconds: 10
          timeoutSeconds: 5
          failureThreshold: 3
        readinessProbe:
          httpGet:
            path: /ready
            port: 3000
          initialDelaySeconds: 5
          periodSeconds: 5
          timeoutSeconds: 3
          failureThreshold: 3
      volumes:
      - name: config
        configMap:
          name: new-api-config
---
apiVersion: v1
kind: Service
metadata:
  name: new-api-service
  namespace: new-api
  labels:
    app: new-api
spec:
  selector:
    app: new-api
  ports:
  - protocol: TCP
    port: 80
    targetPort: 3000
    name: http
  type: ClusterIP
---
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: new-api-ingress
  namespace: new-api
  annotations:
    cert-manager.io/cluster-issuer: letsencrypt-prod
    # ingress-nginx rate limiting: at most 100 requests per minute per client IP
    nginx.ingress.kubernetes.io/limit-rpm: "100"
spec:
  ingressClassName: nginx
  tls:
  - hosts:
    - api.example.com
    secretName: new-api-tls
  rules:
  - host: api.example.com
    http:
      paths:
      - path: /
        pathType: Prefix
        backend:
          service:
            name: new-api-service
            port:
              number: 80
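
The Deployment above probes two different paths: `/health` for liveness and `/ready` for readiness. A minimal sketch of handlers that keep the two apart is shown below, using only the standard library. `Probes`, `NewMux`, and the `statusOf` demo helper are hypothetical names; in the chapter's Gin router the same logic would be registered as Gin handlers.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"sync/atomic"
)

// Probes holds the readiness flag shared by the probe handlers.
type Probes struct {
	ready atomic.Bool // requires Go 1.19 or later
}

// SetReady is called with true once start-up work (for example opening the
// database) has finished, and with false when shutdown begins.
func (p *Probes) SetReady(v bool) { p.ready.Store(v) }

// Live answers the liveness probe: the process is up and serving HTTP.
func (p *Probes) Live(w http.ResponseWriter, _ *http.Request) {
	w.WriteHeader(http.StatusOK)
}

// Ready answers the readiness probe: 503 tells Kubernetes to stop routing
// traffic to this pod without restarting it.
func (p *Probes) Ready(w http.ResponseWriter, _ *http.Request) {
	if !p.ready.Load() {
		http.Error(w, "not ready", http.StatusServiceUnavailable)
		return
	}
	w.WriteHeader(http.StatusOK)
}

// NewMux registers both probe endpoints on the paths used in the manifest.
func NewMux(p *Probes) *http.ServeMux {
	mux := http.NewServeMux()
	mux.HandleFunc("/health", p.Live)
	mux.HandleFunc("/ready", p.Ready)
	return mux
}

// statusOf is a demo helper that sends one in-memory GET request.
func statusOf(h http.Handler, path string) int {
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, path, nil))
	return rec.Code
}

func main() {
	p := &Probes{}
	mux := NewMux(p)
	fmt.Println("ready before start-up:", statusOf(mux, "/ready"))
	p.SetReady(true)
	fmt.Println("ready after start-up: ", statusOf(mux, "/ready"))
	fmt.Println("live:                 ", statusOf(mux, "/health"))
}
```

Keeping liveness independent of downstream dependencies avoids restart loops when only the database is unavailable; readiness is the probe that should reflect such dependencies.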

19.4.4 Operations and Monitoring

Metrics hierarchy

Figure 7: Architecture of the monitoring metrics hierarchy

Log management

package logging

import (
    "context"
    "os"
    "time"
    
    "github.com/sirupsen/logrus"
    "gopkg.in/natefinch/lumberjack.v2"
)

// LogLevel names a log severity.
type LogLevel string

const (
    LogLevelDebug LogLevel = "debug"
    LogLevelInfo  LogLevel = "info"
    LogLevelWarn  LogLevel = "warn"
    LogLevelError LogLevel = "error"
    LogLevelFatal LogLevel = "fatal"
)

// LogFields carries structured log fields.
type LogFields map[string]interface{}

// LogConfig holds the logger settings.
type LogConfig struct {
    Level      LogLevel `yaml:"level"`
    Format     string   `yaml:"format"` // json, text
    Output     string   `yaml:"output"` // stdout, file
    Filename   string   `yaml:"filename"`
    MaxSize    int      `yaml:"max_size"`    // MB
    MaxBackups int      `yaml:"max_backups"`
    MaxAge     int      `yaml:"max_age"`     // days
    Compress   bool     `yaml:"compress"`
}

// Logger wraps a logrus logger together with its configuration.
type Logger struct {
    logger *logrus.Logger
    config *LogConfig
}

// NewLogger builds a Logger from config.
func NewLogger(config *LogConfig) *Logger {
    logger := logrus.New()
    
    // Set the log level (falls back to info if the value is invalid)
    level, err := logrus.ParseLevel(string(config.Level))
    if err != nil {
        level = logrus.InfoLevel
    }
    logger.SetLevel(level)
    
    // Set the log format
    if config.Format == "json" {
        logger.SetFormatter(&logrus.JSONFormatter{
            TimestampFormat: time.RFC3339,
        })
    } else {
        logger.SetFormatter(&logrus.TextFormatter{
            FullTimestamp:   true,
            TimestampFormat: time.RFC3339,
        })
    }
    
    // Set the output: a rotating file via lumberjack, or stdout
    if config.Output == "file" && config.Filename != "" {
        logger.SetOutput(&lumberjack.Logger{
            Filename:   config.Filename,
            MaxSize:    config.MaxSize,
            MaxBackups: config.MaxBackups,
            MaxAge:     config.MaxAge,
            Compress:   config.Compress,
        })
    } else {
        logger.SetOutput(os.Stdout)
    }
    
    return &Logger{
        logger: logger,
        config: config,
    }
}

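// LogRequest writes one structured entry per HTTP request.
// ctx is accepted for future use and is not read here.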
func (l *Logger) LogRequest(ctx context.Context, method, path string, statusCode int, duration time.Duration, fields LogFields) {
    entry := l.logger.WithFields(logrus.Fields{
        "type":        "request",
        "method":      method,
        "path":        path,
        "status_code": statusCode,
        "duration_ms": duration.Milliseconds(),
        "timestamp":   time.Now().Format(time.RFC3339),
    })
    
    // Add caller-supplied fields
    for k, v := range fields {
        entry = entry.WithField(k, v)
    }
    
    // Choose the log level from the status code
    if statusCode >= 500 {
        entry.Error("HTTP request completed with server error")
    } else if statusCode >= 400 {
        entry.Warn("HTTP request completed with client error")
    } else {
        entry.Info("HTTP request completed successfully")
    }
}

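// LogBusiness writes a business event at the given level.
// LogLevelFatal maps to logrus Fatal, which exits the process after logging.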
func (l *Logger) LogBusiness(level LogLevel, event string, fields LogFields) {
    entry := l.logger.WithFields(logrus.Fields{
        "type":      "business",
        "event":     event,
        "timestamp": time.Now().Format(time.RFC3339),
    })
    
    for k, v := range fields {
        entry = entry.WithField(k, v)
    }
    
    switch level {
    case LogLevelDebug:
        entry.Debug(event)
    case LogLevelInfo:
        entry.Info(event)
    case LogLevelWarn:
        entry.Warn(event)
    case LogLevelError:
        entry.Error(event)
    case LogLevelFatal:
        entry.Fatal(event)
    }
}

Performance optimization strategies

Figure 8: Performance optimization strategies

package optimization

import (
    "sync"
    "time"
)

// PoolConfig holds database connection pool settings.
type PoolConfig struct {
    MaxIdleConns    int           `yaml:"max_idle_conns"`
    MaxOpenConns    int           `yaml:"max_open_conns"`
    ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime"`
    ConnMaxIdleTime time.Duration `yaml:"conn_max_idle_time"`
}

// GoroutinePool is a fixed-size worker pool with a bounded job queue.
type GoroutinePool struct {
    workers   chan chan func()
    jobQueue  chan func()
    quit      chan bool
    wg        sync.WaitGroup
    maxWorker int
}

// NewGoroutinePool creates a pool with maxWorker workers and a queue of
// maxQueue jobs, and starts it.
func NewGoroutinePool(maxWorker int, maxQueue int) *GoroutinePool {
    pool := &GoroutinePool{
        workers:   make(chan chan func(), maxWorker),
        jobQueue:  make(chan func(), maxQueue),
        quit:      make(chan bool),
        maxWorker: maxWorker,
    }
    
    pool.start()
    return pool
}

// start launches the workers and the dispatcher.
func (p *GoroutinePool) start() {
    for i := 0; i < p.maxWorker; i++ {
        worker := &Worker{
            workerPool: p.workers,
            jobChannel: make(chan func()),
            quit:       make(chan bool),
        }
        worker.start()
    }
    
    go p.dispatch()
}

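// dispatch hands each queued job to the next idle worker.
// It waits while all workers are busy, so the number of goroutines stays
// bounded by maxWorker instead of growing by one for every queued job.
func (p *GoroutinePool) dispatch() {
    for {
        select {
        case job := <-p.jobQueue:
            select {
            case jobChannel := <-p.workers:
                jobChannel <- job
            case <-p.quit:
                return
            }
        case <-p.quit:
            return
        }
    }
}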

// Submit queues a job; it blocks while the queue is full.
func (p *GoroutinePool) Submit(job func()) {
    p.jobQueue <- job
}

// Worker runs the jobs handed to it by the dispatcher.
type Worker struct {
    workerPool chan chan func()
    jobChannel chan func()
    quit       chan bool
}

// start registers the worker as idle, runs one job, and repeats until quit
// is signaled.
func (w *Worker) start() {
    go func() {
        for {
            w.workerPool <- w.jobChannel
            
            select {
            case job := <-w.jobChannel:
                job()
            case <-w.quit:
                return
            }
        }
    }()
}
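
The pool above has no shutdown path: nothing signals the `quit` channels and the `wg` field is never used. For comparison, a shorter pool built on a single shared job channel is sketched below, with a `Stop` method that drains the queue and waits for the workers. `Pool`, `NewPool`, and `Stop` are hypothetical names, and this is a different design from the dispatcher-based pool above, not a drop-in replacement.

```go
package main

import (
	"fmt"
	"sync"
)

// Pool is a fixed-size worker pool: all workers receive from one channel.
type Pool struct {
	jobs chan func()
	wg   sync.WaitGroup
}

// NewPool starts `workers` goroutines that share a queue of size `queue`.
func NewPool(workers, queue int) *Pool {
	p := &Pool{jobs: make(chan func(), queue)}
	p.wg.Add(workers)
	for i := 0; i < workers; i++ {
		go func() {
			defer p.wg.Done()
			// The loop ends when the channel is closed and drained.
			for job := range p.jobs {
				job()
			}
		}()
	}
	return p
}

// Submit queues a job and blocks while the queue is full.
// It must not be called after Stop: sending on a closed channel panics.
func (p *Pool) Submit(job func()) {
	p.jobs <- job
}

// Stop closes the queue, lets the workers finish all queued jobs, and
// returns once every worker has exited.
func (p *Pool) Stop() {
	close(p.jobs)
	p.wg.Wait()
}

func main() {
	pool := NewPool(4, 16)
	results := make(chan int, 10)
	for i := 0; i < 10; i++ {
		n := i // capture the current value for the closure
		pool.Submit(func() { results <- n * n })
	}
	pool.Stop()
	close(results)

	sum := 0
	for v := range results {
		sum += v
	}
	fmt.Println("sum of squares:", sum)
}
```

Letting the Go runtime pick the receiving worker removes the need for a separate dispatcher goroutine and per-worker channels.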

Troubleshooting and recovery

Figure 9: Troubleshooting and recovery flow

package recovery

import (
    "context"
    "fmt"
    "log"
    "net/http"
    "time"
)

// HealthChecker runs a set of named health checks.
// timeout applies to each check individually.
type HealthChecker struct {
    checks  map[string]HealthCheck
    timeout time.Duration
}

// HealthCheck is implemented by every dependency check.
type HealthCheck interface {
    Check(ctx context.Context) error
    Name() string
}

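// DatabaseHealthCheck checks the database connection.
type DatabaseHealthCheck struct {
    db interface{} // database handle
}

func (dhc *DatabaseHealthCheck) Check(ctx context.Context) error {
    // A real check would run a trivial query, for example:
    // return db.PingContext(ctx)
    return nil // simplified: always healthy
}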

func (dhc *DatabaseHealthCheck) Name() string {
    return "database"
}

// RedisHealthCheck checks the Redis connection.
type RedisHealthCheck struct {
    client interface{} // Redis client
}

func (rhc *RedisHealthCheck) Check(ctx context.Context) error {
    // A real check would send a PING, for example:
    // return client.Ping(ctx).Err()
    return nil // simplified: always healthy
}

func (rhc *RedisHealthCheck) Name() string {
    return "redis"
}

// NewHealthChecker creates a checker that applies timeout to each check.
func NewHealthChecker(timeout time.Duration) *HealthChecker {
    return &HealthChecker{
        checks:  make(map[string]HealthCheck),
        timeout: timeout,
    }
}

// AddCheck registers a check under its name.
func (hc *HealthChecker) AddCheck(check HealthCheck) {
    hc.checks[check.Name()] = check
}

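// CheckAll runs every registered check in turn, each with its own timeout,
// and returns the result per check name (nil means healthy).
func (hc *HealthChecker) CheckAll(ctx context.Context) map[string]error {
    results := make(map[string]error)
    
    for name, check := range hc.checks {
        // A separate name keeps the parent ctx from being shadowed.
        checkCtx, cancel := context.WithTimeout(ctx, hc.timeout)
        results[name] = check.Check(checkCtx)
        cancel()
    }
    
    return results
}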

// HealthHandler is the HTTP health endpoint: 200 when every check passes,
// 503 otherwise.
func (hc *HealthChecker) HealthHandler(w http.ResponseWriter, r *http.Request) {
    results := hc.CheckAll(r.Context())
    
    allHealthy := true
    for _, err := range results {
        if err != nil {
            allHealthy = false
            break
        }
    }
    
    if allHealthy {
        w.WriteHeader(http.StatusOK)
        fmt.Fprintf(w, "OK")
    } else {
        w.WriteHeader(http.StatusServiceUnavailable)
        fmt.Fprintf(w, "Service Unavailable")
        
        // Log the details of each failed check
        for name, err := range results {
            if err != nil {
                log.Printf("Health check failed for %s: %v", name, err)
            }
        }
    }
}

Figure 10: System monitoring metrics architecture

Docker Compose monitoring configuration

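version: '3.8'

services:
  new-api:
    image: new-api:latest
    ports:
      - "3000:3000"
    environment:
      - DATABASE_URL=mysql://user:password@mysql:3306/newapi
      - REDIS_URL=redis://redis:6379
    depends_on:
      - mysql
      - redis
    restart: unless-stopped
    networks:
      - new-api-network

  # mysql and redis are referenced by depends_on and by the volumes declared
  # at the bottom, so they must be defined for the file to validate.
  mysql:
    image: mysql:8.0
    environment:
      - MYSQL_ROOT_PASSWORD=${MYSQL_ROOT_PASSWORD}
      - MYSQL_DATABASE=newapi
      - MYSQL_USER=user
      - MYSQL_PASSWORD=password
    volumes:
      - mysql_data:/var/lib/mysql
    restart: unless-stopped
    networks:
      - new-api-network

  redis:
    image: redis:7-alpine
    volumes:
      - redis_data:/data
    restart: unless-stopped
    networks:
      - new-api-network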

  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
      - "443:443"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf
      - ./ssl:/etc/nginx/ssl
    depends_on:
      - new-api
    restart: unless-stopped
    networks:
      - new-api-network

  prometheus:
    image: prom/prometheus
    ports:
      - "9090:9090"
    volumes:
      - ./prometheus.yml:/etc/prometheus/prometheus.yml
      - prometheus_data:/prometheus
    command:
      - '--config.file=/etc/prometheus/prometheus.yml'
      - '--storage.tsdb.path=/prometheus'
      - '--web.console.libraries=/etc/prometheus/console_libraries'
      - '--web.console.templates=/etc/prometheus/consoles'
    restart: unless-stopped
    networks:
      - new-api-network

  grafana:
    image: grafana/grafana
    ports:
      - "3001:3000"
    environment:
      - GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_PASSWORD}
    volumes:
      - grafana_data:/var/lib/grafana
      - ./grafana/dashboards:/etc/grafana/provisioning/dashboards
      - ./grafana/datasources:/etc/grafana/provisioning/datasources
    restart: unless-stopped
    networks:
      - new-api-network

volumes:
  mysql_data:
  redis_data:
  prometheus_data:
  grafana_data:

networks:
  new-api-network:
    driver: bridge

19.4.5 Kubernetes Production Configuration

# k8s-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: new-api
  labels:
    app: new-api
spec:
  replicas: 3
  selector:
    matchLabels:
      app: new-api
  template:
    metadata:
      labels:
        app: new-api
    spec:
      containers:
      - name: new-api
        image: new-api:latest
        ports:
        - containerPort: 3000
        env:
        - name: DATABASE_URL
          valueFrom:
            secretKeyRef:
              name: new-api-secrets
              key: database-url
        - name: REDIS_URL
          valueFrom:
            secretKeyRef:
              name: new-api-secrets
              key: redis-url
        - name: JWT_SECRET
          valueFrom:
            secretKeyRef:
              name: new-api-secrets
              key: jwt-secret
        resources:
          requests:
            memory: "256Mi"
            cpu: "250m"
          limits:
            memory: "512Mi"
            cpu: "500m"
        livenessProbe:
          httpGet:
            path: /health
            port: 3000
          initialDelaySeconds: 30
          periodSeconds: 10
        readinessProbe:
          httpGet:
            path: /ready
            port: 3000
          initialDelaySeconds: 5
          periodSeconds: 5
---
apiVersion: v1
kind: Service
metadata:
  name: new-api-service
spec:
  selector:
    app: new-api
  ports:
  - protocol: TCP
    port: 80
    targetPort: 3000
  type: LoadBalancer
---
apiVersion: v1
kind: Secret
metadata:
  name: new-api-secrets
type: Opaque
data:
  database-url: <base64-encoded-database-url>
  redis-url: <base64-encoded-redis-url>
  jwt-secret: <base64-encoded-jwt-secret>
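
The values under `data:` in the Secret above are placeholders that have to be replaced with base64-encoded strings. A small sketch of producing that block from plain values follows; `SecretData` is a hypothetical helper and the sample value is made up. Kubernetes also accepts plain-text values under `stringData:`, which avoids the manual encoding step.

```go
package main

import (
	"encoding/base64"
	"fmt"
	"sort"
	"strings"
)

// SecretData renders the data: block of a Secret manifest, base64-encoding
// every value. Keys are sorted so that the output is stable.
func SecretData(values map[string]string) string {
	keys := make([]string, 0, len(values))
	for k := range values {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var b strings.Builder
	b.WriteString("data:\n")
	for _, k := range keys {
		encoded := base64.StdEncoding.EncodeToString([]byte(values[k]))
		fmt.Fprintf(&b, "  %s: %s\n", k, encoded)
	}
	return b.String()
}

func main() {
	// Sample value only; real secrets should not be hard-coded.
	fmt.Print(SecretData(map[string]string{
		"jwt-secret": "secret",
	}))
}
```

Base64 is an encoding, not encryption: anyone who can read the Secret object can decode the values, so access to it still has to be restricted.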

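19.5 Chapter Summary

This chapter examined the core architecture of the New-API project and the implementation of its key modules:

  1. Project architecture: a layered design with controller, service, and data-access layers

  2. User management: registration, login, permission control, and quota management

  3. Token management: API key generation, validation, and quota control

  4. Channel management: multi-channel access, load balancing, and failover

  5. Request relay: request routing and protocol conversion

  6. Billing and statistics: usage statistics and quota accounting

  7. Monitoring and alerting: Prometheus metrics and a custom alerting service

  8. Deployment and operations: deployment with Docker and Kubernetes

Studying New-API shows how an enterprise-grade Go application approaches code organization, error handling, performance optimization, and security.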

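19.6 Exercises

  1. Add support for a new channel type (for example Baidu ERNIE Bot).

  2. Add user groups so that users in different groups can use different channels.

  3. Implement request caching to reduce calls to upstream APIs.

  4. Add rate limiting for API calls.

  5. Implement database read/write splitting.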

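With this chapter's analysis in hand, you should have a working understanding of the architecture, implementation, and operational practices of an enterprise-grade Go application, and a solid basis for building production systems.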
