testFlow/api/doc/doc_processor.go
Wyle.Gong-巩文昕 67b0ad2723 init
2025-04-22 16:42:48 +08:00

620 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package doc
import (
"app/cfg"
M "app/models"
"bytes"
"context"
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/invopop/jsonschema"
"github.com/ledongthuc/pdf"
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
"github.com/xeipuuv/gojsonschema"
)
// Parameter describes one API input or output parameter extracted from a
// document. The jsonschema tags feed the schema that is sent to the LLM as
// part of the prompt, so their (Chinese) descriptions are runtime data —
// do not translate or edit them.
type Parameter struct {
	// Name is the parameter's name (required).
	Name string `json:"name" jsonschema:"required,description=参数名称"`
	// Location restricts where the parameter travels: query, header, body or path.
	Location string `json:"location" jsonschema:"enum=query,enum=header,enum=body,enum=path,description=参数位置"`
	// Type is the parameter's data type; defaults to string.
	Type string `json:"type,omitempty" jsonschema:"default=string,description=参数类型"`
	// ParamType is the content type; per the tag it is only meaningful when
	// Location is "body", otherwise it should be empty.
	ParamType string `json:"param_type,omitempty" jsonschema:"enum=application/json,enum=text/plain,enum=multipart/form-data,enum=application/x-www-form-urlencoded,enum=,description=只有position为body时有效,其他时候为空即可"`
	// Description is a human-readable description (required).
	Description string `json:"description" jsonschema:"required,description=参数描述"`
	// Value is an example value; the tag instructs the model to emit a plain
	// string (objects serialized as JSON text, no concatenation or escaping).
	Value string `json:"value,omitempty" jsonschema:"description=参数示例值统一用字符串形式不要用json如果是对象需要转换成json字符串一定要是纯字符串绝对不要有字符串加字符串这种字符串间的拼接运算还要注意字符串中不要给{}[]加转义字符,这样会导致错误"`
}
// Response describes the example response of an API as extracted by the LLM.
// The jsonschema tag text is part of the prompt contract — do not edit it.
type Response struct {
	// StatusCode is the HTTP status code of the example response (required).
	StatusCode int `json:"status_code" jsonschema:"required,description=HTTP状态码"`
	// Example is the response body example; the tag instructs the model to
	// emit a plain JSON string without concatenation or extra escaping.
	Example string `json:"example" jsonschema:"description=响应示例是json字符串一定要是纯字符串绝对不要有字符串加字符串这种字符串间的拼接运算还要注意字符串中不要给{}[]加转义字符,这样会导致错误"`
}
// APISpec represents a single API specification extracted from a document.
// The jsonschema tags drive the schema embedded in the LLM prompt — their
// Chinese descriptions are runtime data and must stay as-is.
type APISpec struct {
	// Name is the API's display name (required).
	Name string `json:"name" jsonschema:"required,description=API名称"`
	// Description summarizes what the API does (required).
	Description string `json:"description" jsonschema:"required,description=API描述"`
	// Inputs lists the request parameters.
	Inputs []Parameter `json:"inputs" jsonschema:"description=输入参数列表"`
	// Outputs lists the response parameters.
	Outputs []Parameter `json:"outputs" jsonschema:"description=输出参数列表"`
	// Method is the HTTP verb, restricted to the enum values (required).
	Method string `json:"method" jsonschema:"required,enum=GET,enum=POST,enum=PUT,enum=DELETE,enum=PATCH,description=HTTP方法"`
	// Path is the API's URL path (required).
	Path string `json:"path" jsonschema:"required,description=API路径"`
	// BodyType is the request body content type; defaults to application/json.
	BodyType string `json:"body_type,omitempty" jsonschema:"default=application/json,description=请求体类型"`
	// Response is the example API response (required).
	Response Response `json:"response" jsonschema:"required,description=API响应信息"`
}
// APISpecList represents one batch of API specifications returned by the LLM.
// Extraction is paged: HasMore signals that another conversational turn is
// needed to receive the remaining APIs.
type APISpecList struct {
	// APIs is the batch of extracted API specifications (required).
	APIs []APISpec `json:"apis" jsonschema:"required,description=API列表"`
	// HasMore is true when the model still has more APIs to emit next turn.
	HasMore bool `json:"has_more" jsonschema:"required,description=是否还有更多API需要继续在下一次输出中继续输出"`
	// AnalysisPercent is the model's self-reported progress estimate (0-100);
	// per the tag it need not be monotonic and reaches 100 only at the end.
	AnalysisPercent int `json:"analysis_percent" jsonschema:"required,description=估计的分析进度百分比范围是0-100不一定递增可以调整比如发现增长过快的时候可以减少只有全部解析完毕才可以到100"`
}
// ChatHistory represents the running multi-turn conversation with the LLM.
type ChatHistory struct {
	// Messages holds the alternating user/assistant messages sent so far.
	Messages []openai.ChatCompletionMessageParamUnion
	// DocContent is the document text under analysis.
	// NOTE(review): never written in this file — confirm it is used elsewhere.
	DocContent string
	// Schema is the JSON schema given to the model.
	// NOTE(review): never written in this file — confirm it is used elsewhere.
	Schema string
}
// setupLogger creates and configures the logger. It ensures a local "logs"
// directory exists and opens (appending) a timestamped log file inside it.
// Any failure is fatal: this logger is required for the analyzer to run.
func setupLogger() *log.Logger {
	if err := os.MkdirAll("logs", 0755); err != nil {
		log.Fatal(err)
	}
	name := fmt.Sprintf("llm_output_%s.log", time.Now().Format("20060102_150405"))
	out, err := os.OpenFile(filepath.Join("logs", name), os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
	if err != nil {
		log.Fatal(err)
	}
	// The file handle intentionally stays open for the process lifetime.
	return log.New(out, "", log.LstdFlags)
}
// DocumentAnalyzer handles document analysis via an OpenAI-compatible LLM.
type DocumentAnalyzer struct {
	// client talks to the OpenAI-compatible chat-completions endpoint.
	client *openai.Client
	// logger records raw LLM output and diagnostics to a file under logs/.
	logger *log.Logger
}
// NewDocumentAnalyzer creates a new DocumentAnalyzer instance wired to the
// DashScope OpenAI-compatible endpoint, with its file logger initialized.
//
// The API key is read from the DASHSCOPE_API_KEY environment variable when
// set; the embedded literal is kept only as a backward-compatible fallback.
// NOTE(review): a credential committed to source control should be rotated
// and the fallback removed entirely.
func NewDocumentAnalyzer() (*DocumentAnalyzer, error) {
	apiKey := os.Getenv("DASHSCOPE_API_KEY")
	if apiKey == "" {
		apiKey = "sk-0213c70194624703a1d0d80e0f762b0e" // fallback; see NOTE above
	}
	client := openai.NewClient(
		option.WithAPIKey(apiKey),
		option.WithBaseURL("https://dashscope.aliyuncs.com/compatible-mode/v1/"),
		// option.WithBaseURL("http://127.0.0.1:11434/"),
	)
	fmt.Println("API client initialized successfully")
	return &DocumentAnalyzer{
		client: client,
		logger: setupLogger(),
	}, nil
}
// generateJSONSchema reflects the APISpecList struct into an indented JSON
// schema document. Requiredness comes from the jsonschema struct tags, extra
// properties are allowed, and definitions are inlined (no $ref).
func generateJSONSchema() ([]byte, error) {
	r := jsonschema.Reflector{
		RequiredFromJSONSchemaTags: true,
		AllowAdditionalProperties:  true,
		DoNotReference:             true,
	}
	return json.MarshalIndent(r.Reflect(&APISpecList{}), "", " ")
}
// validateJSON validates data against the JSON schema generated from
// APISpecList. It returns nil on success, or an error carrying every
// validation failure message.
func validateJSON(data []byte) error {
	schema, err := generateJSONSchema()
	if err != nil {
		return fmt.Errorf("failed to generate schema: %v", err)
	}
	result, err := gojsonschema.Validate(
		gojsonschema.NewBytesLoader(schema),
		gojsonschema.NewBytesLoader(data),
	)
	if err != nil {
		return fmt.Errorf("validation error: %v", err)
	}
	if result.Valid() {
		return nil
	}
	var msgs []string
	for _, desc := range result.Errors() {
		msgs = append(msgs, desc.String())
	}
	return fmt.Errorf("invalid JSON: %v", msgs)
}
// cleanJSONResponse trims any surrounding prose from an LLM reply, keeping the
// span from the first '{' through the last '}', and validates the remainder
// against the APISpecList schema. It returns the extracted JSON string, or an
// error when no brace-delimited region exists or validation fails.
//
// Bug fix: the original manual scan left start=0 / end=len(response) when a
// brace was absent, silently treating arbitrary text as the JSON candidate;
// a missing brace is now an explicit error.
func cleanJSONResponse(response string) (string, error) {
	start := strings.IndexByte(response, '{')
	end := strings.LastIndexByte(response, '}')
	if start == -1 || end == -1 || start > end {
		return "", fmt.Errorf("invalid JSON structure")
	}
	jsonStr := response[start : end+1]
	// Validate the extracted candidate against the schema.
	if err := validateJSON([]byte(jsonStr)); err != nil {
		return "", fmt.Errorf("JSON validation failed: %v", err)
	}
	return jsonStr, nil
}
// repairJSON asks the model to re-emit malformedJSON with the reported error
// fixed, streaming the reply into a single string. It returns the repaired
// text, or an error when the stream fails or the model returns nothing.
func (da *DocumentAnalyzer) repairJSON(ctx context.Context, malformedJSON string, originalError error) (string, error) {
	prompt := fmt.Sprintf(`Fix this malformed JSON that had error: %v
JSON to fix:
%s
Return only the fixed JSON with no explanations.`, originalError, malformedJSON)
	var sb strings.Builder
	stream := da.client.Chat.Completions.NewStreaming(ctx, openai.ChatCompletionNewParams{
		Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
			openai.UserMessage(prompt),
		}),
		Model: openai.F("qwen-plus"),
	})
	// Accumulate every streamed delta into one buffer.
	for stream.Next() {
		for _, choice := range stream.Current().Choices {
			sb.WriteString(choice.Delta.Content)
		}
	}
	if err := stream.Err(); err != nil {
		return "", fmt.Errorf("failed to send message: %v", err)
	}
	if sb.Len() == 0 {
		return "", fmt.Errorf("no repair response")
	}
	return sb.String(), nil
}
// extractAPIs drives a multi-turn LLM conversation that extracts every API
// described in docContent and persists each one — endpoint, response record,
// and parameters — under the given doc. The model pages its output: it sets
// has_more=true to request another turn, and reports an analysis_percent that
// is written back to the doc as progress. Malformed JSON replies are sent
// back to the model once for repair before giving up. Returns the first
// streaming, validation, or database error encountered.
//
// The Chinese prompt literals are runtime data sent to the model verbatim —
// do not edit them.
func (da *DocumentAnalyzer) extractAPIs(ctx context.Context, docContent string, doc *M.Doc) error {
	schema, err := generateJSONSchema()
	if err != nil {
		return fmt.Errorf("failed to generate schema: %v", err)
	}
	da.logger.Printf("JSON Schema:\n%s", schema)
	// Initialize the conversation history.
	history := &ChatHistory{
		Messages: make([]openai.ChatCompletionMessageParamUnion, 0),
	}
	// First turn: send the document content together with the JSON schema.
	initialPrompt := fmt.Sprintf(`你是一个API文档分析助手。你的任务是从文档中提取API信息。
请从以下文档中提取所有API接口信息包括页面、功能操作名称、功能描述、请求类型、接口地址、输入和输出参数。
请确保完整提取每个接口的所有信息,不要遗漏或截断,如果有不确定的项,可以设置为空。
如果文档内容较多你可以分批次输出每次输出一部分API信息并设置has_more为true表示还有更多API需要在下一次输出。
当所有API都输出完成时设置has_more为false。
文档内容:
%s
JSON Schema:
%s
请严格按照schema格式输出,确保是有效的JSON格式。输出的JSON必须符合以上schema的规范。`, docContent, string(schema))
	totalAPIs := 0 // running count of APIs persisted across all batches
	maxRetries := 10
	history.Messages = append(history.Messages, openai.UserMessage(initialPrompt))
	maxInitRetries := 3
	var responseBuilder strings.Builder
	var streamErr error
	// Send the initial message, retrying transient streaming failures.
	for i := 0; i < maxInitRetries; i++ {
		responseBuilder.Reset()
		stream := da.client.Chat.Completions.NewStreaming(ctx, openai.ChatCompletionNewParams{
			Messages: openai.F(history.Messages),
			Model:    openai.F("qwen-plus"),
		})
		for stream.Next() {
			chunk := stream.Current()
			for _, choice := range chunk.Choices {
				responseBuilder.WriteString(choice.Delta.Content)
			}
		}
		if streamErr = stream.Err(); streamErr == nil {
			break
		}
		time.Sleep(time.Second)
	}
	if streamErr != nil {
		da.logger.Printf("Failed to send initial message after retries: %v", streamErr)
		return fmt.Errorf("failed to send initial message: %v", streamErr)
	}
	response := responseBuilder.String()
	// Record the model's reply so follow-up turns carry full context.
	history.Messages = append(history.Messages, openai.AssistantMessage(response))
	for {
		// Sleep 30s between turns to avoid rate limiting.
		time.Sleep(time.Second * 30)
		da.logger.Printf("LLM Response:\n%s", response)
		// Clean and validate the reply; on failure, ask the model to repair it
		// once, then re-validate.
		cleanedJSON, err := cleanJSONResponse(response)
		if err != nil {
			da.logger.Printf("JSON repair attempt: %v", err)
			fixed, repairErr := da.repairJSON(ctx, response, err)
			if repairErr != nil {
				return fmt.Errorf("JSON repair failed: %v (original: %v)", repairErr, err)
			}
			da.logger.Printf("JSON repaired:\n%s", fixed)
			cleanedJSON, err = cleanJSONResponse(fixed)
			if err != nil {
				return fmt.Errorf("JSON validation failed: %v", err)
			}
		}
		// Parse this batch of APIs.
		var result APISpecList
		if err := json.Unmarshal([]byte(cleanedJSON), &result); err != nil {
			return fmt.Errorf("failed to parse LLM response: %v", err)
		}
		// Persist every API in the batch.
		for _, api := range result.APIs {
			totalAPIs++
			// Create the endpoint record.
			endpoint := &M.Endpoint{
				DocID:       doc.ID,
				Name:        api.Name,
				Path:        api.Path,
				Method:      api.Method,
				Description: api.Description,
				BodyType:    api.BodyType,
				Merged:      false,
				Node:        "proxy",
			}
			if err := cfg.DB().Create(endpoint).Error; err != nil {
				return fmt.Errorf("failed to create endpoint: %v", err)
			}
			// Create the response record.
			response := &M.Response{
				EndpointID:  endpoint.ID,
				StatusCode:  api.Response.StatusCode,
				Example:     api.Response.Example,
				Name:        "Default Response", // default value
				ContentType: "application/json", // default value
				Description: "API Response",     // default value
			}
			if err := cfg.DB().Create(response).Error; err != nil {
				return fmt.Errorf("failed to create response: %v", err)
			}
			// Collect pointers to the input parameters.
			var jsonParams []*Parameter
			var otherParams []*Parameter
			for _, param := range api.Inputs {
				// Bug fix: shadow the loop variable. Taking &param of the
				// shared iteration variable (pre-Go 1.22 semantics) made every
				// stored pointer reference the LAST input parameter.
				param := param
				jsonParams = append(jsonParams, &param)
				otherParams = append(otherParams, &param)
			}
			// NOTE(review): every input lands in BOTH lists, and the location
			// loop below stores each one once per location regardless of
			// param.Location, so parameters are persisted multiple times.
			// Confirm whether filtering by ParamType/Location was intended.
			// Merge JSON-style parameters into a single body parameter.
			if len(jsonParams) > 0 {
				mergedValue := make(map[string]interface{})
				var descriptions []string
				for _, param := range jsonParams {
					if param.Value != "" {
						mergedValue[param.Name] = param.Value
					}
					if param.Description != "" {
						descriptions = append(descriptions, param.Name+": "+param.Description)
					}
				}
				// Create the merged body parameter.
				mergedValueJSON, _ := json.Marshal(mergedValue)
				mergedParam := &M.Parameter{
					EndpointID:  endpoint.ID,
					Name:        "", // intentionally empty name
					Type:        "body",
					ParamType:   "string",
					Required:    true,
					Description: strings.Join(descriptions, "; "),
					Example:     "",
					Value:       string(mergedValueJSON),
				}
				if err := cfg.DB().Create(mergedParam).Error; err != nil {
					return fmt.Errorf("failed to create merged json parameter: %v", err)
				}
			}
			// Store the remaining parameters per location.
			locations := []string{"query", "path", "body"}
			for _, location := range locations {
				for _, param := range otherParams {
					parameter := &M.Parameter{
						EndpointID:  endpoint.ID,
						Name:        param.Name,
						Type:        location,
						ParamType:   param.Type,
						Required:    true,
						Description: param.Description,
						Example:     "",
						Value:       param.Value,
					}
					if err := cfg.DB().Create(parameter).Error; err != nil {
						return fmt.Errorf("failed to create parameter: %v", err)
					}
				}
			}
			// Write the model-reported analysis progress back to the doc;
			// a failure here is logged but not fatal.
			progress := result.AnalysisPercent
			if err := cfg.DB().Model(doc).Update("analysis_percent", progress).Error; err != nil {
				da.logger.Printf("Failed to update progress: %v", err)
			}
		}
		// No more APIs to extract: done.
		if !result.HasMore {
			break
		}
		// Continue the conversation using the accumulated chat history.
		followUpPrompt := `请继续提取剩余的API信息保持相同的输出格式。
请记住:
1. 不要重复之前已输出的接口
2. 如果还有更多内容,添加 "has_more": true
3. 如果已经输出完所有内容,添加 "has_more": false
4. 请严格按照schema格式输出确保是有效、完整的JSON格式。输出的JSON必须符合以上schema的规范比如string类型的值要注意转义字符的使用。`
		history.Messages = append(history.Messages, openai.UserMessage(followUpPrompt))
		// Retry the follow-up turn on stream errors, backing off 30s each time.
		for retry := 0; retry < maxRetries; retry++ {
			responseBuilder.Reset()
			stream := da.client.Chat.Completions.NewStreaming(ctx, openai.ChatCompletionNewParams{
				Messages: openai.F(history.Messages),
				Model:    openai.F("qwen-plus"),
			})
			for stream.Next() {
				chunk := stream.Current()
				for _, choice := range chunk.Choices {
					responseBuilder.WriteString(choice.Delta.Content)
				}
			}
			if streamErr = stream.Err(); streamErr == nil {
				response = responseBuilder.String()
				history.Messages = append(history.Messages, openai.AssistantMessage(response))
				break
			}
			da.logger.Printf("Attempt %d failed: %v, retrying...", retry+1, streamErr)
			time.Sleep(time.Second * 30)
		}
		if streamErr != nil {
			return fmt.Errorf("failed to send follow-up message after %d retries: %v", maxRetries, streamErr)
		}
	}
	return nil
}
// readPDFLedong extracts the plain text of the PDF at path using the
// ledongthuc/pdf reader.
//
// Bug fix: the deferred f.Close() used to be registered BEFORE the error
// check, so a failed pdf.Open dereferenced a nil file handle on return.
// The ignored buf.ReadFrom error is now checked as well.
func readPDFLedong(path string) (string, error) {
	f, r, err := pdf.Open(path)
	if err != nil {
		return "", err
	}
	// Best-effort close; the file is read-only so the error is not actionable.
	defer f.Close()
	b, err := r.GetPlainText()
	if err != nil {
		return "", err
	}
	var buf bytes.Buffer
	if _, err := buf.ReadFrom(b); err != nil {
		return "", err
	}
	return buf.String(), nil
}
// DocProcessor handles document processing queue and analysis. Tasks enter
// via AddTask, are drained by a fixed pool of worker goroutines, and can be
// cancelled mid-flight through CancelDocProcessing.
type DocProcessor struct {
	// queue buffers pending document tasks for the workers.
	queue chan *DocTask
	// workers is the number of concurrent worker goroutines started by Start.
	workers int
	// wg tracks the worker goroutines.
	wg sync.WaitGroup
	// analyzer performs the LLM-backed API extraction.
	analyzer *DocumentAnalyzer
	// cancelTasks maps doc ID -> channel closed to request cancellation.
	cancelTasks map[string]chan struct{}
	// taskMutex guards cancelTasks.
	taskMutex sync.RWMutex
}
// DocTask represents a document processing task: the doc record to update
// and the path of the PDF file to analyze.
type DocTask struct {
	Doc      *M.Doc
	FilePath string
}
// Package-level singleton state for GetDocProcessor.
var (
	processor     *DocProcessor
	processorOnce sync.Once // guards one-time construction of processor
)
// GetDocProcessor returns a singleton instance of DocProcessor, lazily
// constructing it (and starting its worker pool) on first call. Construction
// failure is fatal because the processor is required for document analysis.
func GetDocProcessor() *DocProcessor {
	processorOnce.Do(func() {
		a, err := NewDocumentAnalyzer()
		if err != nil {
			log.Fatalf("Failed to create document analyzer: %v", err)
		}
		p := &DocProcessor{
			queue:       make(chan *DocTask, 100), // Buffer size of 100
			workers:     1,                        // Number of concurrent workers
			analyzer:    a,
			cancelTasks: make(map[string]chan struct{}),
		}
		p.Start()
		processor = p
	})
	return processor
}
// CancelDocProcessing requests cancellation of the task handling the document
// identified by docID: it closes the task's cancel channel (observed by
// processDocument) and forgets it. A no-op when no such task is registered.
func (p *DocProcessor) CancelDocProcessing(docID string) {
	p.taskMutex.Lock()
	defer p.taskMutex.Unlock()
	ch, ok := p.cancelTasks[docID]
	if !ok {
		return
	}
	close(ch)
	delete(p.cancelTasks, docID)
}
// Start initializes the worker pool, launching p.workers goroutines that
// drain the task queue.
func (p *DocProcessor) Start() {
	p.wg.Add(p.workers)
	for n := 0; n < p.workers; n++ {
		go p.worker()
	}
}
// AddTask adds a new document to the processing queue. A fresh cancel channel
// is registered for the document BEFORE enqueueing, so CancelDocProcessing
// can reach the task even while it is still waiting in the queue. Blocks if
// the queue buffer is full.
func (p *DocProcessor) AddTask(doc *M.Doc, filePath string) {
	p.taskMutex.Lock()
	p.cancelTasks[doc.ID] = make(chan struct{})
	p.taskMutex.Unlock()
	p.queue <- &DocTask{Doc: doc, FilePath: filePath}
}
// worker drains the queue, processing documents one at a time until the queue
// is closed. When processing fails, the error is logged and the document is
// marked completed with the error message recorded.
func (p *DocProcessor) worker() {
	defer p.wg.Done()
	for task := range p.queue {
		if err := p.processDocument(task); err != nil {
			log.Printf("Error processing document %s: %v", task.Doc.ID, err)
			// Record the failure on the doc; the update error is ignored
			// (best-effort status write).
			cfg.DB().Model(task.Doc).Updates(map[string]interface{}{
				"analysis_completed": true,
				"analysis_error":     err.Error(),
			})
		}
	}
}
// processDocument handles the actual document processing: it reads the PDF at
// task.FilePath, runs LLM-based API extraction for task.Doc, and keeps the
// doc's analysis status columns up to date. A cancellation requested through
// CancelDocProcessing aborts the extraction via context cancellation and is
// reported as success (nil).
func (p *DocProcessor) processDocument(task *DocTask) error {
	// Always forget this task's cancel channel once processing ends,
	// whatever the outcome.
	defer func() {
		p.taskMutex.Lock()
		delete(p.cancelTasks, task.Doc.ID)
		p.taskMutex.Unlock()
	}()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	// Bridge the per-document cancel channel (closed by CancelDocProcessing)
	// to context cancellation. The ctx.Done() arm lets this goroutine exit
	// when processing finishes normally.
	// NOTE(review): if the task was already cancelled, cancelChan is nil here
	// and the select blocks until ctx.Done() — harmless, but confirm intended.
	go func() {
		p.taskMutex.RLock()
		cancelChan := p.cancelTasks[task.Doc.ID]
		p.taskMutex.RUnlock()
		select {
		case <-cancelChan:
			cancel() // cancel the context when the cancel signal arrives
		case <-ctx.Done():
		}
	}()
	doc := task.Doc
	// Update initial status: analysis in progress, 0%.
	if err := cfg.DB().Model(doc).Updates(map[string]interface{}{
		"analysis_completed": false,
		"analysis_percent":   0,
	}).Error; err != nil {
		return fmt.Errorf("failed to update initial status: %v", err)
	}
	// Read PDF content.
	content, err := readPDFLedong(task.FilePath)
	if err != nil {
		return fmt.Errorf("failed to read PDF: %v", err)
	}
	// Extract and persist the APIs described in the document.
	err = p.analyzer.extractAPIs(ctx, content, doc)
	if err != nil {
		if ctx.Err() != nil {
			// The error was caused by cancellation: log it but do not
			// report a failure.
			log.Printf("Document processing cancelled for doc ID: %s", task.Doc.ID)
			return nil
		}
		return fmt.Errorf("failed to extract APIs: %v", err)
	}
	// Update final status: analysis complete at 100%.
	return cfg.DB().Model(doc).Updates(map[string]interface{}{
		"analysis_completed": true,
		"analysis_percent":   100,
	}).Error
}