// Package doc implements LLM-driven extraction of API specifications from
// uploaded PDF documents.
package doc
|
||
|
||
import (
|
||
"app/cfg"
|
||
M "app/models"
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"log"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/invopop/jsonschema"
|
||
"github.com/ledongthuc/pdf"
|
||
"github.com/openai/openai-go"
|
||
"github.com/openai/openai-go/option"
|
||
"github.com/xeipuuv/gojsonschema"
|
||
)
|
||
|
||
// Parameter describes one API parameter extracted from a document.
// The jsonschema struct tags (descriptions in Chinese) drive the JSON Schema
// that constrains the LLM's structured output; they must not be altered.
type Parameter struct {
	// Name is the parameter name (required).
	Name string `json:"name" jsonschema:"required,description=参数名称"`
	// Location is where the parameter appears: query, header, body, or path.
	Location string `json:"location" jsonschema:"enum=query,enum=header,enum=body,enum=path,description=参数位置"`
	// Type is the value type; defaults to "string" per the schema tag.
	Type string `json:"type,omitempty" jsonschema:"default=string,description=参数类型"`
	// ParamType is the body content type; only meaningful when Location is
	// "body" (the tag instructs the model to leave it empty otherwise).
	ParamType string `json:"param_type,omitempty" jsonschema:"enum=application/json,enum=text/plain,enum=multipart/form-data,enum=application/x-www-form-urlencoded,enum=,description=只有position为body时有效,其他时候为空即可"`
	// Description is a human-readable explanation of the parameter (required).
	Description string `json:"description" jsonschema:"required,description=参数描述"`
	// Value is an example value; the tag instructs the model to emit a plain
	// string (objects serialized to a JSON string, no concatenation/escaping).
	Value string `json:"value,omitempty" jsonschema:"description=参数示例值,统一用字符串形式,不要用json,如果是对象,需要转换成json字符串,一定要是纯字符串,绝对不要有字符串加字符串这种字符串间的拼接运算,还要注意字符串中不要给{}[]加转义字符,这样会导致错误"`
}
// Response describes an example HTTP response for an extracted API.
type Response struct {
	// StatusCode is the HTTP status code (required).
	StatusCode int `json:"status_code" jsonschema:"required,description=HTTP状态码"`
	// Example is a sample response body; the tag instructs the model to emit
	// a plain JSON string (no concatenation, no extra escaping of {}[]).
	Example string `json:"example" jsonschema:"description=响应示例,是json字符串,一定要是纯字符串,绝对不要有字符串加字符串这种字符串间的拼接运算,还要注意字符串中不要给{}[]加转义字符,这样会导致错误"`
}
// APISpec represents an API specification extracted from a document.
type APISpec struct {
	// Name is the API's display name (required).
	Name string `json:"name" jsonschema:"required,description=API名称"`
	// Description summarizes what the API does (required).
	Description string `json:"description" jsonschema:"required,description=API描述"`
	// Inputs lists request parameters.
	Inputs []Parameter `json:"inputs" jsonschema:"description=输入参数列表"`
	// Outputs lists response parameters.
	Outputs []Parameter `json:"outputs" jsonschema:"description=输出参数列表"`
	// Method is the HTTP method, restricted to the enum in the tag (required).
	Method string `json:"method" jsonschema:"required,enum=GET,enum=POST,enum=PUT,enum=DELETE,enum=PATCH,description=HTTP方法"`
	// Path is the API's URL path (required).
	Path string `json:"path" jsonschema:"required,description=API路径"`
	// BodyType is the request body content type; defaults to application/json.
	BodyType string `json:"body_type,omitempty" jsonschema:"default=application/json,description=请求体类型"`
	// Response is the example response for this API (required).
	Response Response `json:"response" jsonschema:"required,description=API响应信息"`
}
// APISpecList represents a list of API specifications — the top-level shape
// the LLM must produce each turn. The model pages large documents: it sets
// HasMore=true when more APIs remain for a follow-up turn.
type APISpecList struct {
	// APIs is this batch of extracted API specifications (required).
	APIs []APISpec `json:"apis" jsonschema:"required,description=API列表"`
	// HasMore signals that the model has further APIs to emit next turn.
	HasMore bool `json:"has_more" jsonschema:"required,description=是否还有更多API需要继续在下一次输出中继续输出"`
	// AnalysisPercent is the model's self-estimated progress (0-100); per the
	// tag it is not necessarily monotonic and only reaches 100 when done.
	AnalysisPercent int `json:"analysis_percent" jsonschema:"required,description=估计的分析进度百分比,范围是0-100,不一定递增,可以调整,比如发现增长过快的时候可以减少,只有全部解析完毕才可以到100"`
}
// ChatHistory represents the multi-turn conversation state replayed to the
// model on every request.
type ChatHistory struct {
	// Messages is the full ordered message list (user and assistant turns).
	Messages []openai.ChatCompletionMessageParamUnion
	// DocContent is intended to hold the source document text.
	// NOTE(review): never assigned in this file — confirm it is used elsewhere.
	DocContent string
	// Schema is intended to hold the JSON schema string.
	// NOTE(review): never assigned in this file — confirm it is used elsewhere.
	Schema string
}
// setupLogger creates and configures the logger
|
||
func setupLogger() *log.Logger {
|
||
if err := os.MkdirAll("logs", 0755); err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
|
||
timestamp := time.Now().Format("20060102_150405")
|
||
logFilename := filepath.Join("logs", fmt.Sprintf("llm_output_%s.log", timestamp))
|
||
|
||
file, err := os.OpenFile(logFilename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||
if err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
|
||
return log.New(file, "", log.LstdFlags)
|
||
}
|
||
|
||
// DocumentAnalyzer handles document analysis: it owns the LLM client and a
// dedicated file logger for raw model output.
type DocumentAnalyzer struct {
	// client is the OpenAI-compatible chat client (DashScope endpoint).
	client *openai.Client
	// logger writes schema and raw LLM responses to a timestamped log file.
	logger *log.Logger
}
// NewDocumentAnalyzer creates a new DocumentAnalyzer instance
|
||
func NewDocumentAnalyzer() (*DocumentAnalyzer, error) {
|
||
client := openai.NewClient(
|
||
option.WithAPIKey("sk-0213c70194624703a1d0d80e0f762b0e"),
|
||
option.WithBaseURL("https://dashscope.aliyuncs.com/compatible-mode/v1/"),
|
||
// option.WithBaseURL("http://127.0.0.1:11434/"),
|
||
)
|
||
fmt.Println("API client initialized successfully")
|
||
return &DocumentAnalyzer{
|
||
client: client,
|
||
logger: setupLogger(),
|
||
}, nil
|
||
}
|
||
|
||
// generateJSONSchema generates JSON schema from the APISpecList struct
|
||
func generateJSONSchema() ([]byte, error) {
|
||
reflector := jsonschema.Reflector{
|
||
RequiredFromJSONSchemaTags: true,
|
||
AllowAdditionalProperties: true,
|
||
DoNotReference: true,
|
||
}
|
||
|
||
schema := reflector.Reflect(&APISpecList{})
|
||
|
||
return json.MarshalIndent(schema, "", " ")
|
||
}
|
||
|
||
// validateJSON validates the JSON response against the schema
|
||
func validateJSON(data []byte) error {
|
||
schema, err := generateJSONSchema()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to generate schema: %v", err)
|
||
}
|
||
|
||
schemaLoader := gojsonschema.NewBytesLoader(schema)
|
||
documentLoader := gojsonschema.NewBytesLoader(data)
|
||
|
||
result, err := gojsonschema.Validate(schemaLoader, documentLoader)
|
||
if err != nil {
|
||
return fmt.Errorf("validation error: %v", err)
|
||
}
|
||
|
||
if !result.Valid() {
|
||
var errors []string
|
||
for _, desc := range result.Errors() {
|
||
errors = append(errors, desc.String())
|
||
}
|
||
return fmt.Errorf("invalid JSON: %v", errors)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// cleanJSONResponse cleans and validates the LLM response
|
||
func cleanJSONResponse(response string) (string, error) {
|
||
// Find the first { and last }
|
||
start := 0
|
||
end := len(response)
|
||
|
||
for i := 0; i < len(response); i++ {
|
||
if response[i] == '{' {
|
||
start = i
|
||
break
|
||
}
|
||
}
|
||
|
||
for i := len(response) - 1; i >= 0; i-- {
|
||
if response[i] == '}' {
|
||
end = i + 1
|
||
break
|
||
}
|
||
}
|
||
|
||
if start >= end {
|
||
return "", fmt.Errorf("invalid JSON structure")
|
||
}
|
||
|
||
jsonStr := response[start:end]
|
||
|
||
// Validate the JSON
|
||
if err := validateJSON([]byte(jsonStr)); err != nil {
|
||
return "", fmt.Errorf("JSON validation failed: %v", err)
|
||
}
|
||
|
||
return jsonStr, nil
|
||
}
|
||
|
||
func (da *DocumentAnalyzer) repairJSON(ctx context.Context, malformedJSON string, originalError error) (string, error) {
|
||
prompt := fmt.Sprintf(`Fix this malformed JSON that had error: %v
|
||
|
||
JSON to fix:
|
||
%s
|
||
|
||
Return only the fixed JSON with no explanations.`, originalError, malformedJSON)
|
||
|
||
var responseBuilder strings.Builder
|
||
|
||
stream := da.client.Chat.Completions.NewStreaming(ctx, openai.ChatCompletionNewParams{
|
||
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
|
||
openai.UserMessage(prompt),
|
||
}),
|
||
Model: openai.F("qwen-plus"),
|
||
})
|
||
|
||
for stream.Next() {
|
||
chunk := stream.Current()
|
||
for _, choice := range chunk.Choices {
|
||
responseBuilder.WriteString(choice.Delta.Content)
|
||
}
|
||
}
|
||
|
||
if err := stream.Err(); err != nil {
|
||
return "", fmt.Errorf("failed to send message: %v", err)
|
||
}
|
||
|
||
fixed := responseBuilder.String()
|
||
if fixed == "" {
|
||
return "", fmt.Errorf("no repair response")
|
||
}
|
||
|
||
return fixed, nil
|
||
}
|
||
|
||
// extractAPIs extracts API information from document content
|
||
func (da *DocumentAnalyzer) extractAPIs(ctx context.Context, docContent string, doc *M.Doc) error {
|
||
schema, err := generateJSONSchema()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to generate schema: %v", err)
|
||
}
|
||
da.logger.Printf("JSON Schema:\n%s", schema)
|
||
|
||
// 初始化对话历史
|
||
history := &ChatHistory{
|
||
Messages: make([]openai.ChatCompletionMessageParamUnion, 0),
|
||
}
|
||
|
||
// 初始化对话,发送文档内容和schema
|
||
initialPrompt := fmt.Sprintf(`你是一个API文档分析助手。你的任务是从文档中提取API信息。
|
||
请从以下文档中提取所有API接口信息,包括页面、功能操作名称、功能描述、请求类型、接口地址、输入和输出参数。
|
||
请确保完整提取每个接口的所有信息,不要遗漏或截断,如果有不确定的项,可以设置为空。
|
||
如果文档内容较多,你可以分批次输出,每次输出一部分API信息,并设置has_more为true表示还有更多API需要在下一次输出。
|
||
当所有API都输出完成时,设置has_more为false。
|
||
|
||
文档内容:
|
||
%s
|
||
|
||
JSON Schema:
|
||
%s
|
||
|
||
请严格按照schema格式输出,确保是有效的JSON格式。输出的JSON必须符合以上schema的规范。`, docContent, string(schema))
|
||
|
||
totalAPIs := 0 // 用于跟踪总共处理的API数量
|
||
maxRetries := 10
|
||
|
||
// 添加初始消息到历史记录
|
||
history.Messages = append(history.Messages, openai.UserMessage(initialPrompt))
|
||
|
||
maxInitRetries := 3
|
||
var responseBuilder strings.Builder
|
||
var streamErr error
|
||
|
||
// 发送初始消息并重试
|
||
for i := 0; i < maxInitRetries; i++ {
|
||
responseBuilder.Reset()
|
||
stream := da.client.Chat.Completions.NewStreaming(ctx, openai.ChatCompletionNewParams{
|
||
Messages: openai.F(history.Messages),
|
||
Model: openai.F("qwen-plus"),
|
||
})
|
||
|
||
for stream.Next() {
|
||
chunk := stream.Current()
|
||
for _, choice := range chunk.Choices {
|
||
responseBuilder.WriteString(choice.Delta.Content)
|
||
}
|
||
}
|
||
|
||
if streamErr = stream.Err(); streamErr == nil {
|
||
break
|
||
}
|
||
time.Sleep(time.Second)
|
||
}
|
||
|
||
if streamErr != nil {
|
||
da.logger.Printf("Failed to send initial message after retries: %v", streamErr)
|
||
return fmt.Errorf("failed to send initial message: %v", streamErr)
|
||
}
|
||
|
||
response := responseBuilder.String()
|
||
// 添加模型的响应到历史记录
|
||
history.Messages = append(history.Messages, openai.AssistantMessage(response))
|
||
|
||
for {
|
||
// sleep 30s to avoid rate limiting
|
||
time.Sleep(time.Second * 30)
|
||
|
||
da.logger.Printf("LLM Response:\n%s", response)
|
||
|
||
// Clean and validate the response
|
||
cleanedJSON, err := cleanJSONResponse(response)
|
||
if err != nil {
|
||
// Try to repair the JSON
|
||
da.logger.Printf("JSON repair attempt: %v", err)
|
||
fixed, repairErr := da.repairJSON(ctx, response, err)
|
||
if repairErr != nil {
|
||
return fmt.Errorf("JSON repair failed: %v (original: %v)", repairErr, err)
|
||
}
|
||
da.logger.Printf("JSON repaired:\n%s", fixed)
|
||
cleanedJSON, err = cleanJSONResponse(fixed)
|
||
if err != nil {
|
||
return fmt.Errorf("JSON validation failed: %v", err)
|
||
}
|
||
}
|
||
// Parse the response
|
||
var result APISpecList
|
||
if err := json.Unmarshal([]byte(cleanedJSON), &result); err != nil {
|
||
return fmt.Errorf("failed to parse LLM response: %v", err)
|
||
}
|
||
|
||
// 处理这一批次的APIs
|
||
for _, api := range result.APIs {
|
||
totalAPIs++
|
||
|
||
// 创建新的endpoint
|
||
endpoint := &M.Endpoint{
|
||
DocID: doc.ID,
|
||
Name: api.Name,
|
||
Path: api.Path,
|
||
Method: api.Method,
|
||
Description: api.Description,
|
||
BodyType: api.BodyType,
|
||
Merged: false,
|
||
Node: "proxy",
|
||
}
|
||
|
||
if err := cfg.DB().Create(endpoint).Error; err != nil {
|
||
return fmt.Errorf("failed to create endpoint: %v", err)
|
||
}
|
||
|
||
// 创建响应记录
|
||
response := &M.Response{
|
||
EndpointID: endpoint.ID,
|
||
StatusCode: api.Response.StatusCode,
|
||
Example: api.Response.Example,
|
||
Name: "Default Response", // 默认值
|
||
ContentType: "application/json", // 默认值
|
||
Description: "API Response", // 默认值
|
||
}
|
||
|
||
if err := cfg.DB().Create(response).Error; err != nil {
|
||
return fmt.Errorf("failed to create response: %v", err)
|
||
}
|
||
|
||
// 存储参数
|
||
var jsonParams []*Parameter
|
||
var otherParams []*Parameter
|
||
|
||
for _, param := range api.Inputs {
|
||
jsonParams = append(jsonParams, ¶m)
|
||
otherParams = append(otherParams, ¶m)
|
||
}
|
||
|
||
// 如果有application/json类型的参数,将它们合并
|
||
if len(jsonParams) > 0 {
|
||
mergedValue := make(map[string]interface{})
|
||
var descriptions []string
|
||
|
||
for _, param := range jsonParams {
|
||
if param.Value != "" {
|
||
mergedValue[param.Name] = param.Value
|
||
}
|
||
if param.Description != "" {
|
||
descriptions = append(descriptions, param.Name+": "+param.Description)
|
||
}
|
||
}
|
||
// 创建合并后的参数
|
||
mergedValueJSON, _ := json.Marshal(mergedValue)
|
||
mergedParam := &M.Parameter{
|
||
EndpointID: endpoint.ID,
|
||
Name: "", // 空名称
|
||
Type: "body",
|
||
ParamType: "string",
|
||
Required: true,
|
||
Description: strings.Join(descriptions, "; "),
|
||
Example: "",
|
||
Value: string(mergedValueJSON),
|
||
}
|
||
|
||
if err := cfg.DB().Create(mergedParam).Error; err != nil {
|
||
return fmt.Errorf("failed to create merged json parameter: %v", err)
|
||
}
|
||
}
|
||
|
||
// 处理其他非application/json参数
|
||
locations := []string{"query", "path", "body"}
|
||
for _, location := range locations {
|
||
for _, param := range otherParams {
|
||
parameter := &M.Parameter{
|
||
EndpointID: endpoint.ID,
|
||
Name: param.Name,
|
||
Type: location,
|
||
ParamType: param.Type,
|
||
Required: true,
|
||
Description: param.Description,
|
||
Example: "",
|
||
Value: param.Value,
|
||
}
|
||
|
||
if err := cfg.DB().Create(parameter).Error; err != nil {
|
||
return fmt.Errorf("failed to create parameter: %v", err)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 更新文档处理进度
|
||
progress := result.AnalysisPercent
|
||
|
||
if err := cfg.DB().Model(doc).Update("analysis_percent", progress).Error; err != nil {
|
||
da.logger.Printf("Failed to update progress: %v", err)
|
||
}
|
||
}
|
||
|
||
// 如果没有更多API要处理,退出循环
|
||
if !result.HasMore {
|
||
break
|
||
}
|
||
|
||
// 使用chat history继续对话
|
||
followUpPrompt := `请继续提取剩余的API信息,保持相同的输出格式。
|
||
请记住:
|
||
1. 不要重复之前已输出的接口
|
||
2. 如果还有更多内容,添加 "has_more": true
|
||
3. 如果已经输出完所有内容,添加 "has_more": false
|
||
4. 请严格按照schema格式输出,确保是有效、完整的JSON格式。输出的JSON必须符合以上schema的规范,比如string类型的值要注意转义字符的使用。`
|
||
|
||
history.Messages = append(history.Messages, openai.UserMessage(followUpPrompt))
|
||
|
||
// 重试逻辑
|
||
for retry := 0; retry < maxRetries; retry++ {
|
||
responseBuilder.Reset()
|
||
stream := da.client.Chat.Completions.NewStreaming(ctx, openai.ChatCompletionNewParams{
|
||
Messages: openai.F(history.Messages),
|
||
Model: openai.F("qwen-plus"),
|
||
})
|
||
|
||
for stream.Next() {
|
||
chunk := stream.Current()
|
||
for _, choice := range chunk.Choices {
|
||
responseBuilder.WriteString(choice.Delta.Content)
|
||
}
|
||
}
|
||
|
||
if streamErr = stream.Err(); streamErr == nil {
|
||
response = responseBuilder.String()
|
||
history.Messages = append(history.Messages, openai.AssistantMessage(response))
|
||
break
|
||
}
|
||
|
||
da.logger.Printf("Attempt %d failed: %v, retrying...", retry+1, streamErr)
|
||
time.Sleep(time.Second * 30)
|
||
}
|
||
|
||
if streamErr != nil {
|
||
return fmt.Errorf("failed to send follow-up message after %d retries: %v", maxRetries, streamErr)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func readPDFLedong(path string) (string, error) {
|
||
f, r, err := pdf.Open(path)
|
||
// remember close file
|
||
defer func() {
|
||
err := f.Close()
|
||
if err != nil {
|
||
return
|
||
}
|
||
}()
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
var buf bytes.Buffer
|
||
b, err := r.GetPlainText()
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
buf.ReadFrom(b)
|
||
return buf.String(), nil
|
||
}
|
||
|
||
// DocProcessor handles document processing queue and analysis.
type DocProcessor struct {
	// queue holds pending document tasks (buffered).
	queue chan *DocTask
	// workers is how many goroutines drain the queue.
	workers int
	// wg tracks running workers.
	wg sync.WaitGroup
	// analyzer performs the LLM-based API extraction.
	analyzer *DocumentAnalyzer
	// cancelTasks maps doc ID -> channel closed to cancel that doc's task.
	cancelTasks map[string]chan struct{}
	// taskMutex guards cancelTasks.
	taskMutex sync.RWMutex
}
// DocTask represents a document processing task: the document record and the
// on-disk path of the file to analyze.
type DocTask struct {
	// Doc is the database record being processed.
	Doc *M.Doc
	// FilePath is the path to the uploaded PDF.
	FilePath string
}
var (
	// processor is the process-wide DocProcessor singleton.
	processor *DocProcessor
	// processorOnce guards one-time initialization of processor.
	processorOnce sync.Once
)
// GetDocProcessor returns a singleton instance of DocProcessor
|
||
func GetDocProcessor() *DocProcessor {
|
||
processorOnce.Do(func() {
|
||
analyzer, err := NewDocumentAnalyzer()
|
||
if err != nil {
|
||
log.Fatalf("Failed to create document analyzer: %v", err)
|
||
}
|
||
processor = &DocProcessor{
|
||
queue: make(chan *DocTask, 100), // Buffer size of 100
|
||
workers: 1, // Number of concurrent workers
|
||
analyzer: analyzer,
|
||
cancelTasks: make(map[string]chan struct{}),
|
||
}
|
||
processor.Start()
|
||
})
|
||
return processor
|
||
}
|
||
|
||
func (p *DocProcessor) CancelDocProcessing(docID string) {
|
||
p.taskMutex.Lock()
|
||
defer p.taskMutex.Unlock()
|
||
|
||
if cancel, exists := p.cancelTasks[docID]; exists {
|
||
close(cancel)
|
||
delete(p.cancelTasks, docID)
|
||
}
|
||
}
|
||
|
||
// Start initializes the worker pool
|
||
func (p *DocProcessor) Start() {
|
||
for i := 0; i < p.workers; i++ {
|
||
p.wg.Add(1)
|
||
go p.worker()
|
||
}
|
||
}
|
||
|
||
// AddTask adds a new document to the processing queue
|
||
func (p *DocProcessor) AddTask(doc *M.Doc, filePath string) {
|
||
p.taskMutex.Lock()
|
||
// 为这个任务创建新的取消通道
|
||
cancelChan := make(chan struct{})
|
||
p.cancelTasks[doc.ID] = cancelChan
|
||
p.taskMutex.Unlock()
|
||
|
||
task := &DocTask{
|
||
Doc: doc,
|
||
FilePath: filePath,
|
||
}
|
||
p.queue <- task
|
||
}
|
||
|
||
// worker processes documents from the queue
|
||
func (p *DocProcessor) worker() {
|
||
defer p.wg.Done()
|
||
|
||
for task := range p.queue {
|
||
err := p.processDocument(task)
|
||
if err != nil {
|
||
log.Printf("Error processing document %s: %v", task.Doc.ID, err)
|
||
// Update document status to error
|
||
cfg.DB().Model(task.Doc).Updates(map[string]interface{}{
|
||
"analysis_completed": true,
|
||
"analysis_error": err.Error(),
|
||
})
|
||
}
|
||
}
|
||
}
|
||
|
||
// processDocument handles the actual document processing
|
||
func (p *DocProcessor) processDocument(task *DocTask) error {
|
||
defer func() {
|
||
p.taskMutex.Lock()
|
||
delete(p.cancelTasks, task.Doc.ID)
|
||
p.taskMutex.Unlock()
|
||
}()
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
defer cancel()
|
||
go func() {
|
||
p.taskMutex.RLock()
|
||
cancelChan := p.cancelTasks[task.Doc.ID]
|
||
p.taskMutex.RUnlock()
|
||
|
||
select {
|
||
case <-cancelChan:
|
||
cancel() // 收到取消信号时取消上下文
|
||
case <-ctx.Done():
|
||
}
|
||
}()
|
||
|
||
doc := task.Doc
|
||
|
||
// Update initial status
|
||
if err := cfg.DB().Model(doc).Updates(map[string]interface{}{
|
||
"analysis_completed": false,
|
||
"analysis_percent": 0,
|
||
}).Error; err != nil {
|
||
return fmt.Errorf("failed to update initial status: %v", err)
|
||
}
|
||
|
||
// Read PDF content
|
||
content, err := readPDFLedong(task.FilePath)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to read PDF: %v", err)
|
||
}
|
||
|
||
// Extract and process APIs
|
||
err = p.analyzer.extractAPIs(ctx, content, doc)
|
||
if err != nil {
|
||
if ctx.Err() != nil {
|
||
// 如果是因为取消导致的错误,记录日志但不返回错误
|
||
log.Printf("Document processing cancelled for doc ID: %s", task.Doc.ID)
|
||
return nil
|
||
}
|
||
return fmt.Errorf("failed to extract APIs: %v", err)
|
||
}
|
||
|
||
// Update final status
|
||
return cfg.DB().Model(doc).Updates(map[string]interface{}{
|
||
"analysis_completed": true,
|
||
"analysis_percent": 100,
|
||
}).Error
|
||
}
|