package channels
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"time"

	"your-project/plugin"
)
// BaiduErniePlugin is a channel plugin for Baidu's ERNIE Bot (Wenxin Yiyan)
// chat API.
type BaiduErniePlugin struct {
	// Static metadata reported by Name, Version and Description.
	name, version, description string
	// Lifecycle state; written by Start and Stop, read by Status.
	status plugin.PluginStatus

	// Credentials and endpoint root, populated by Initialize
	// (baseURL also has a default set by NewBaiduErniePlugin).
	apiKey, secretKey, baseURL string
	// HTTP client shared by token and chat requests.
	client *http.Client

	// Cached OAuth access token and the instant after which it is refreshed.
	// NOTE(review): these are written by the background refresh goroutine and
	// read by request methods with no lock, i.e. unsynchronized concurrent access.
	accessToken string
	tokenExpiry time.Time
}
// NewBaiduErniePlugin returns a Baidu ERNIE plugin in the stopped state,
// pointed at https://aip.baidubce.com and using an HTTP client with a
// 30-second timeout. Credentials are supplied later via Initialize.
func NewBaiduErniePlugin() *BaiduErniePlugin {
	httpClient := &http.Client{Timeout: 30 * time.Second}

	bp := &BaiduErniePlugin{
		name:        "baidu-ernie",
		version:     "1.0.0",
		description: "Baidu ERNIE channel plugin",
		baseURL:     "https://aip.baidubce.com",
		status:      plugin.PluginStatusStopped,
		client:      httpClient,
	}
	return bp
}
// Name returns the plugin identifier ("baidu-ernie" when created by
// NewBaiduErniePlugin).
func (bp *BaiduErniePlugin) Name() string {
	return bp.name
}
// Version returns the plugin version string ("1.0.0" when created by
// NewBaiduErniePlugin).
func (bp *BaiduErniePlugin) Version() string {
	return bp.version
}
// Description returns the human-readable plugin description.
func (bp *BaiduErniePlugin) Description() string {
	return bp.description
}
// Initialize applies the plugin configuration.
//
// Required keys: "api_key" and "secret_key", both non-empty strings.
// Optional key: "base_url", a non-empty string that overrides the default
// endpoint root. If a required key is missing, empty or not a string, an
// error is returned and no field on bp is modified.
func (bp *BaiduErniePlugin) Initialize(config map[string]interface{}) error {
	// A failed type assertion yields "", so a single emptiness test covers
	// "absent", "wrong type" and "empty string" alike.
	key, _ := config["api_key"].(string)
	if key == "" {
		return fmt.Errorf("api_key is required")
	}
	secret, _ := config["secret_key"].(string)
	if secret == "" {
		return fmt.Errorf("secret_key is required")
	}

	bp.apiKey, bp.secretKey = key, secret

	if override, _ := config["base_url"].(string); override != "" {
		bp.baseURL = override
	}
	return nil
}
// Start fetches an initial access token and, on success, marks the plugin
// running and launches tokenRefreshLoop in a new goroutine; that loop exits
// when ctx is done.
//
// If the token cannot be obtained, the status becomes PluginStatusError and
// the cause is returned wrapped, so callers can inspect it with errors.Is/As.
func (bp *BaiduErniePlugin) Start(ctx context.Context) error {
	bp.status = plugin.PluginStatusStarting

	if err := bp.refreshAccessToken(); err != nil {
		bp.status = plugin.PluginStatusError
		return fmt.Errorf("failed to get access token: %w", err)
	}

	bp.status = plugin.PluginStatusRunning
	go bp.tokenRefreshLoop(ctx)
	return nil
}
// Stop marks the plugin as stopped, which makes ChatCompletion and
// HealthCheck reject further calls. It always returns nil.
// NOTE(review): Stop does not cancel the context given to Start, so it does
// not itself terminate the background token-refresh goroutine.
func (bp *BaiduErniePlugin) Stop() error {
	bp.status = plugin.PluginStatusStopped
	return nil
}
// Status reports the current lifecycle state, as last set by
// NewBaiduErniePlugin, Start or Stop.
func (bp *BaiduErniePlugin) Status() plugin.PluginStatus {
	return bp.status
}
// SupportedModels lists the model names this plugin accepts. A new slice is
// built on every call, so callers may modify the result freely.
func (bp *BaiduErniePlugin) SupportedModels() []string {
	models := make([]string, 0, 6)
	models = append(models,
		"ernie-bot",
		"ernie-bot-turbo",
		"ernie-bot-4",
		"ernie-3.5-8k",
		"ernie-3.5-8k-0205",
		"ernie-3.5-4k-0205",
	)
	return models
}
// ChatCompletion sends a single (non-streaming) chat request to the ERNIE
// endpoint selected by request.Model and converts the reply into the
// plugin's standard response shape.
//
// The cached access token is refreshed first if bp.tokenExpiry has passed.
// An error is returned when:
//   - the plugin is not running or request is nil;
//   - the token refresh, marshalling, or HTTP exchange fails;
//   - the endpoint answers with a non-200 status;
//   - the body carries a non-zero error_code. Baidu reports many API errors
//     with HTTP 200, so the status code alone is not sufficient.
//
// Underlying errors are wrapped with %w.
//
// NOTE(review): bp.status, bp.accessToken and bp.tokenExpiry are read here
// without synchronization while tokenRefreshLoop may write them.
func (bp *BaiduErniePlugin) ChatCompletion(ctx context.Context, request *plugin.ChatRequest) (*plugin.ChatResponse, error) {
	if bp.status != plugin.PluginStatusRunning {
		return nil, fmt.Errorf("plugin is not running")
	}
	if request == nil {
		return nil, fmt.Errorf("request is nil")
	}

	// Lazily renew the access token once it is (nearly) expired.
	if time.Now().After(bp.tokenExpiry) {
		if err := bp.refreshAccessToken(); err != nil {
			return nil, fmt.Errorf("failed to refresh access token: %w", err)
		}
	}

	baiduRequest := bp.buildBaiduRequest(request)
	// This method decodes exactly one JSON document; a streamed (SSE) reply
	// could never be parsed below, so streaming is never requested here.
	baiduRequest.Stream = false

	reqBody, err := json.Marshal(baiduRequest)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	// endpoint carries no credentials, so it is what error messages show.
	endpoint := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s",
		bp.baseURL, bp.getModelEndpoint(request.Model))
	reqURL := endpoint + "?access_token=" + url.QueryEscape(bp.accessToken)

	// net/http and net/url errors are *url.Error and embed the full request
	// URL; swap in the token-free endpoint before the error can be logged.
	redact := func(err error) error {
		if ue, ok := err.(*url.Error); ok {
			ue.URL = endpoint
		}
		return err
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(reqBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", redact(err))
	}
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := bp.client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", redact(err))
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(respBody))
	}

	var baiduResponse BaiduChatResponse
	if err := json.Unmarshal(respBody, &baiduResponse); err != nil {
		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
	}
	if baiduResponse.ErrorCode != 0 {
		// Baidu documents 110 / 111 as "access token invalid" / "access token
		// expired": forget the cached expiry so the next call fetches a new one.
		if baiduResponse.ErrorCode == 110 || baiduResponse.ErrorCode == 111 {
			bp.tokenExpiry = time.Time{}
		}
		return nil, fmt.Errorf("API returned error %d: %s", baiduResponse.ErrorCode, baiduResponse.ErrorMsg)
	}

	return bp.convertResponse(&baiduResponse, request), nil
}
// CalculateQuota returns the quota to charge for one completion. ERNIE is
// billed per token, so this is response.Usage.TotalTokens. A nil response
// (e.g. from a failed call) costs nothing instead of panicking.
// The request argument is not used.
func (bp *BaiduErniePlugin) CalculateQuota(request *plugin.ChatRequest, response *plugin.ChatResponse) int64 {
	if response == nil {
		return 0
	}
	return int64(response.Usage.TotalTokens)
}
// HealthCheck verifies the plugin end to end. It fails immediately if the
// plugin is not running. Otherwise it issues a tiny real completion
// ("Hello" to ernie-bot, capped at 10 output tokens) through ChatCompletion
// and returns that call's error, or nil on success.
func (bp *BaiduErniePlugin) HealthCheck(ctx context.Context) error {
	if bp.status != plugin.PluginStatusRunning {
		return fmt.Errorf("plugin is not running")
	}

	probe := &plugin.ChatRequest{
		Model:     "ernie-bot",
		MaxTokens: 10,
		Messages:  []plugin.Message{{Role: "user", Content: "Hello"}},
	}
	if _, err := bp.ChatCompletion(ctx, probe); err != nil {
		return err
	}
	return nil
}
// BaiduChatRequest is the JSON body POSTed to the ERNIE chat endpoint.
// Optional fields left at their zero value are omitted from the payload.
type BaiduChatRequest struct {
	Messages    []BaiduMessage `json:"messages"`
	Temperature float64        `json:"temperature,omitempty"`
	TopP        float64        `json:"top_p,omitempty"`
	// MaxTokens is serialized under Baidu's "max_output_tokens" key.
	MaxTokens int  `json:"max_output_tokens,omitempty"`
	Stream    bool `json:"stream,omitempty"`
}

// BaiduMessage is one conversation turn inside a BaiduChatRequest.
type BaiduMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
// BaiduChatResponse is the JSON body returned by the ERNIE chat endpoint.
// On failure the body carries ErrorCode/ErrorMsg instead of a result.
type BaiduChatResponse struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`

	// Result holds the generated reply text.
	Result           string `json:"result"`
	IsTruncated      bool   `json:"is_truncated"`
	NeedClearHistory bool   `json:"need_clear_history"`

	Usage BaiduUsage `json:"usage"`

	ErrorCode int    `json:"error_code,omitempty"`
	ErrorMsg  string `json:"error_msg,omitempty"`
}

// BaiduUsage is the token accounting block of a BaiduChatResponse.
type BaiduUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
// AccessTokenResponse is the JSON body returned by Baidu's OAuth token
// endpoint. On success AccessToken/ExpiresIn are set; on failure the
// Error/ErrorDesc pair describes the problem.
type AccessTokenResponse struct {
	// ExpiresIn is the token lifetime as sent by the server; refreshAccessToken
	// treats it as seconds.
	AccessToken string `json:"access_token"`
	ExpiresIn   int64  `json:"expires_in"`

	Error     string `json:"error,omitempty"`
	ErrorDesc string `json:"error_description,omitempty"`
}
// refreshAccessToken exchanges the configured API key / secret key for an
// OAuth access token (client_credentials grant at {baseURL}/oauth/2.0/token).
// It stores the token on bp with an expiry pulled forward, so the token is
// renewed before the server-side lifetime ends.
//
// It returns an error when:
//   - the request fails (the error's URL is redacted because the real URL
//     contains client_secret);
//   - the body cannot be decoded (the HTTP status is reported if it was not 200);
//   - the server returns an OAuth error;
//   - no access_token is present.
//
// NOTE(review): bp.accessToken and bp.tokenExpiry are written without
// synchronization while request methods may read them concurrently.
func (bp *BaiduErniePlugin) refreshAccessToken() error {
	// tokenEndpoint carries no credentials and is what errors may show.
	tokenEndpoint := bp.baseURL + "/oauth/2.0/token"
	// url.Values escapes the credentials, which were previously concatenated raw.
	query := url.Values{
		"grant_type":    {"client_credentials"},
		"client_id":     {bp.apiKey},
		"client_secret": {bp.secretKey},
	}

	resp, err := bp.client.Post(tokenEndpoint+"?"+query.Encode(), "application/json", nil)
	if err != nil {
		// net/http returns *url.Error, whose message embeds the full URL,
		// including client_secret; replace it before the error gets logged.
		if ue, ok := err.(*url.Error); ok {
			ue.URL = tokenEndpoint
		}
		return err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return err
	}

	var tokenResp AccessTokenResponse
	if err := json.Unmarshal(body, &tokenResp); err != nil {
		if resp.StatusCode != http.StatusOK {
			return fmt.Errorf("token endpoint returned status %d: %s", resp.StatusCode, string(body))
		}
		return err
	}
	if tokenResp.Error != "" {
		return fmt.Errorf("failed to get access token: %s - %s", tokenResp.Error, tokenResp.ErrorDesc)
	}
	if tokenResp.AccessToken == "" {
		return fmt.Errorf("token endpoint returned status %d without an access_token", resp.StatusCode)
	}

	// Renew 5 minutes early. For lifetimes of 5 minutes or less, use half the
	// lifetime instead so the computed expiry does not land in the past.
	lifetime := tokenResp.ExpiresIn - 300
	if lifetime <= 0 {
		lifetime = tokenResp.ExpiresIn / 2
	}

	bp.accessToken = tokenResp.AccessToken
	bp.tokenExpiry = time.Now().Add(time.Duration(lifetime) * time.Second)
	return nil
}
// tokenRefreshLoop runs in its own goroutine (launched by Start). Once an
// hour it renews the access token if the cached expiry is less than
// 10 minutes away. Refresh failures are printed and retried on the next tick.
//
// The loop exits when ctx is done. Stop does not cancel ctx, so the loop also
// exits on the first tick after it observes PluginStatusStopped; otherwise
// it would keep refreshing tokens for a stopped plugin indefinitely.
//
// NOTE(review): if the plugin is stopped and restarted between two ticks,
// this loop sees the running status again and keeps going alongside the new
// one; a proper fix needs a cancel func stored on the struct.
func (bp *BaiduErniePlugin) tokenRefreshLoop(ctx context.Context) {
	ticker := time.NewTicker(1 * time.Hour)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if bp.status == plugin.PluginStatusStopped {
				return
			}
			if time.Now().Add(10 * time.Minute).After(bp.tokenExpiry) {
				if err := bp.refreshAccessToken(); err != nil {
					fmt.Printf("Failed to refresh access token: %v\n", err)
				}
			}
		}
	}
}
// buildBaiduRequest maps a plugin-level chat request onto Baidu's wire
// format. Each message's role and content are copied in order (the result
// is always a non-nil slice). Temperature, MaxTokens and Stream are forwarded
// unchanged. TopP is never populated here.
func (bp *BaiduErniePlugin) buildBaiduRequest(request *plugin.ChatRequest) *BaiduChatRequest {
	out := &BaiduChatRequest{
		Temperature: request.Temperature,
		MaxTokens:   request.MaxTokens,
		Stream:      request.Stream,
	}

	out.Messages = make([]BaiduMessage, 0, len(request.Messages))
	for _, m := range request.Messages {
		out.Messages = append(out.Messages, BaiduMessage{Role: m.Role, Content: m.Content})
	}
	return out
}
// convertResponse translates a Baidu chat reply into the plugin's standard
// response:
//   - a single assistant choice holding baiduResp.Result;
//   - finish reason "length" when Baidu marked the reply truncated, else "stop";
//   - token usage copied field by field.
//
// The model name is taken from the original request.
func (bp *BaiduErniePlugin) convertResponse(baiduResp *BaiduChatResponse, request *plugin.ChatRequest) *plugin.ChatResponse {
	choice := plugin.Choice{
		Index:        0,
		Message:      plugin.Message{Role: "assistant", Content: baiduResp.Result},
		FinishReason: bp.getFinishReason(baiduResp.IsTruncated),
	}

	u := baiduResp.Usage
	usage := plugin.Usage{
		PromptTokens:     u.PromptTokens,
		CompletionTokens: u.CompletionTokens,
		TotalTokens:      u.TotalTokens,
	}

	return &plugin.ChatResponse{
		ID:      baiduResp.ID,
		Object:  "chat.completion",
		Created: baiduResp.Created,
		Model:   request.Model,
		Choices: []plugin.Choice{choice},
		Usage:   usage,
	}
}
// getModelEndpoint returns the path segment of the wenxinworkshop chat URL
// for a model name. Models without a dedicated mapping, including unknown
// names, fall back to "completions".
//
// A switch replaces the map literal that was rebuilt on every call.
// NOTE(review): confirm "ernie-3.5-8k-0205" -> "ernie_bot_8k" and
// "ernie-3.5-4k-0205" -> "completions" against Baidu's current endpoint list.
func (bp *BaiduErniePlugin) getModelEndpoint(model string) string {
	switch model {
	case "ernie-bot-turbo":
		return "eb-instant"
	case "ernie-bot-4":
		return "completions_pro"
	case "ernie-3.5-8k-0205":
		return "ernie_bot_8k"
	default:
		// Covers "ernie-bot", "ernie-3.5-8k", "ernie-3.5-4k-0205" and anything unknown.
		return "completions"
	}
}
// getFinishReason maps Baidu's is_truncated flag to an OpenAI-style finish
// reason: "length" for a truncated reply, "stop" otherwise.
func (bp *BaiduErniePlugin) getFinishReason(isTruncated bool) string {
	reason := "stop"
	if isTruncated {
		reason = "length"
	}
	return reason
}