Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions backend/internal/api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ func buildConfigSnapshot(providerName, modelID string, params map[string]interfa
} else if v, ok := params["image_size"].(string); ok && v != "" {
snapshot["imageSize"] = v
}
if v, ok := params["size"].(string); ok && strings.TrimSpace(v) != "" {
snapshot["size"] = strings.TrimSpace(v)
}
if v, ok := params["quality"].(string); ok && strings.TrimSpace(v) != "" {
snapshot["quality"] = strings.TrimSpace(v)
}

// count 可能是 float64(JSON 解析)或 int(服务内部)
if v, ok := params["count"].(int); ok && v > 0 {
Expand Down Expand Up @@ -132,7 +138,7 @@ func fetchProviderConfig(providerName string) *model.ProviderConfig {

func defaultTimeoutSecondsForProvider(providerName string) int {
switch providerName {
case "gemini", "openai":
case "gemini", "openai", "openai-image":
return 500
default:
return 150
Expand All @@ -141,7 +147,7 @@ func defaultTimeoutSecondsForProvider(providerName string) int {

func providerDefaultMaxRetries(providerName string) int {
switch providerName {
case "gemini", "openai":
case "gemini", "openai", "openai-image":
return 1
default:
return 1
Expand Down
6 changes: 3 additions & 3 deletions backend/internal/model/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ func InitDB(dbPath string) {

// 兼容旧版本默认超时(0/60s)记录:按 Provider 类型修复到对应默认值
if err := DB.Model(&ProviderConfig{}).
Where("provider_name IN ? AND (timeout_seconds <= 0 OR timeout_seconds = ?)", []string{"gemini", "openai"}, 60).
Where("provider_name IN ? AND (timeout_seconds <= 0 OR timeout_seconds = ?)", []string{"gemini", "openai", "openai-image"}, 60).
Update("timeout_seconds", 500).Error; err != nil {
log.Printf("更新生图默认超时失败: %v", err)
}
if err := DB.Model(&ProviderConfig{}).
Where("provider_name NOT IN ? AND (timeout_seconds <= 0 OR timeout_seconds = ?)", []string{"gemini", "openai"}, 60).
Where("provider_name NOT IN ? AND (timeout_seconds <= 0 OR timeout_seconds = ?)", []string{"gemini", "openai", "openai-image"}, 60).
Update("timeout_seconds", 150).Error; err != nil {
log.Printf("更新对话默认超时失败: %v", err)
}
Expand All @@ -75,7 +75,7 @@ func InitDB(dbPath string) {

func defaultTimeoutForProvider(providerName string) time.Duration {
switch providerName {
case "gemini", "openai":
case "gemini", "openai", "openai-image":
return 500 * time.Second
default:
return 150 * time.Second
Expand Down
2 changes: 1 addition & 1 deletion backend/internal/provider/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func NewGeminiProvider(config *model.ProviderConfig) (*GeminiProvider, error) {
func (p *GeminiProvider) newHTTPClient() *http.Client {
return &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: false,
ForceAttemptHTTP2: false,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: false,
MinVersion: tls.VersionTLS12,
Expand Down
6 changes: 6 additions & 0 deletions backend/internal/provider/model_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,5 +87,11 @@ func defaultModelForProvider(providerName string, purpose ModelPurpose) string {
if purpose == PurposeChat || name == "openai-chat" {
return "gemini-3-flash-preview"
}
if name == "openai-image" {
return "gpt-image-1"
}
if name == "openai" {
return "gemini-3-pro-image-preview"
}
Comment on lines +93 to +95

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: The default model returned for the openai provider is a Gemini model ID, which does not match the OpenAI API contract. When request/config model is absent, this fallback will send an invalid model to OpenAI and cause generation requests to fail (typically 400 model-not-found). Use an OpenAI-compatible default model here (for example gpt-image-1) to keep resolver behavior consistent with provider type. [api mismatch]

Severity Level: Critical 🚨
- ❌ /api/v1/tasks/generate fails when provider=openai without model.
- ❌ OpenAIProvider sends Gemini model name to OpenAI /chat API.
- ⚠️ Users see 4xx model-not-found errors for OpenAI provider.
Steps of Reproduction ✅
1. Start the backend server so `InitProviders` runs in
`backend/internal/provider/provider.go:69-21`, which ensures a default `ProviderConfig`
row exists for `ProviderName = "openai"` without any `Models` configured (lines 69-21 show
`defaultProviders := []string{"gemini", "openai", "openai-image"}` and
`model.DB.Create(&model.ProviderConfig{ ProviderName: name, ... })`).

2. Call the HTTP endpoint `POST /api/v1/tasks/generate` defined in
`backend/cmd/server/main.go:17`, sending a JSON body that sets `"provider": "openai"` and
omits both `model_id` and any `params.model`/`params.model_id` (this is bound into
`GenerateRequest` and handled by `GenerateHandler` in
`backend/internal/api/handlers.go:400`).

3. Inside `GenerateHandler` (`backend/internal/api/handlers.go:29-37`),
`provider.ResolveModelID` is invoked with `ProviderName: req.Provider` (i.e. `"openai"`)
and `Purpose: provider.PurposeImage`; because the request has no model and the provider
config has no `Models`, `defaultModelForProvider` in
`backend/internal/provider/model_resolver.go:85-96` runs and, for `name == "openai"`,
returns `"gemini-3-pro-image-preview"`, so `modelID` becomes this Gemini model and
`req.Params["model_id"] = modelID` is set.

4. The task is queued and later executed by the worker: `worker.Pool` calls
`p.Generate(ctx, task.Params)` in `backend/internal/worker/pool.go:225-27`, where `p` is
the `OpenAIProvider` created in `backend/internal/provider/provider.go:74-78`;
`OpenAIProvider.Generate` in `backend/internal/provider/openai.go:83-88` resolves
`modelID` from `params["model_id"]`, then posts to `p.apiBase + "/chat/completions"` with
`"model": "gemini-3-pro-image-preview"` (see `buildChatRequestBody` at `openai.go:35-39`
and `doChatRequest` at `openai.go:98-104`), causing the OpenAI-compatible backend to
reject the request with a 4xx "model not found" error because the model name is not a
valid OpenAI/OpenAI-compat model for that endpoint.

Fix in Cursor | Fix in VSCode Claude

(Use Cmd/Ctrl + Click for best experience)

Prompt for AI Agent 🤖
This is a comment left during a code review.

**Path:** backend/internal/provider/model_resolver.go
**Line:** 93:95
**Comment:**
	*Api Mismatch: The default model returned for the `openai` provider is a Gemini model ID, which does not match the OpenAI API contract. When request/config model is absent, this fallback will send an invalid model to OpenAI and cause generation requests to fail (typically 400 model-not-found). Use an OpenAI-compatible default model here (for example `gpt-image-1`) to keep resolver behavior consistent with provider type.

Validate the correctness of the flagged issue. If correct, How can I resolve this? If you propose a fix, implement it and please make it concise.
Once fix is implemented, also check other comments on the same PR, and ask user if the user wants to fix the rest of the comments as well. if said yes, then fetch all the comments validate the correctness and implement a minimal fix
👍 | 👎

return "gemini-3-pro-image-preview"
}
233 changes: 233 additions & 0 deletions backend/internal/provider/openai_image.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
package provider

import (
"bytes"
"context"
"encoding/json"
"fmt"
"image-gen-service/internal/diagnostic"
"image-gen-service/internal/model"
"io"
"net/http"
"strings"
"time"
)

type OpenAIImageProvider struct {
*OpenAIProvider
}

type openAIImagesGenerationRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Size string `json:"size"`
Quality string `json:"quality,omitempty"`
N int `json:"n,omitempty"`
}

func NewOpenAIImageProvider(config *model.ProviderConfig) (*OpenAIImageProvider, error) {
base, err := NewOpenAIProvider(config)
if err != nil {
return nil, err
}
return &OpenAIImageProvider{OpenAIProvider: base}, nil
}

func (p *OpenAIImageProvider) Name() string {
return "openai-image"
}

func (p *OpenAIImageProvider) ValidateParams(params map[string]interface{}) error {
prompt, _ := params["prompt"].(string)
if strings.TrimSpace(prompt) == "" {
return fmt.Errorf("prompt 不能为空")
}
if raw, ok := params["reference_images"].([]interface{}); ok && len(raw) > 0 {
return fmt.Errorf("OpenAI Images 当前仅支持文本生图")
}

count, ok := toInt(params["count"])
if !ok {
count = 1
}
if count < 1 || count > 10 {
return fmt.Errorf("count/n 必须介于 1 和 10 之间")
}

size, _ := params["size"].(string)
switch strings.TrimSpace(strings.ToLower(size)) {
case "", "auto", "1024x1024", "1024x1536", "1536x1024":
default:
return fmt.Errorf("size 仅支持 auto、1024x1024、1024x1536、1536x1024")
}

quality, _ := params["quality"].(string)
switch strings.TrimSpace(strings.ToLower(quality)) {
case "", "auto", "low", "medium", "high":
default:
return fmt.Errorf("quality 仅支持 auto、low、medium、high")
}

return nil
}

func (p *OpenAIImageProvider) Generate(ctx context.Context, params map[string]interface{}) (*ProviderResult, error) {
modelID := ResolveModelID(ModelResolveOptions{
ProviderName: p.Name(),
Purpose: PurposeImage,
Params: params,
Config: p.config,
}).ID
if modelID == "" {
return nil, fmt.Errorf("缺少 model_id 参数")
}

reqBody, promptPreview, err := p.buildImagesGenerationRequestBody(modelID, params)
if err != nil {
return nil, err
}

diagnostic.Logf(params, "request_prepare",
"provider=%s model=%s size=%q quality=%q count=%d prompt_hash=%s prompt_preview=%q",
p.Name(),
modelID,
reqBody.Size,
reqBody.Quality,
reqBody.N,
diagnostic.PromptHash(promptPreview),
diagnostic.Preview(promptPreview, 160),
)

respBytes, headers, err := p.doImagesGenerationRequest(ctx, reqBody, params)
if err != nil {
return nil, err
}

images, summary, err := p.extractImages(ctx, respBytes)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: This provider reuses the generic image extraction path, which will follow data[].url values and issue server-side HTTP GETs without any host allowlist. For the new /images/generations flow this path is commonly exercised, so a malicious or compromised upstream/proxy response can trigger SSRF to internal services. Restrict this provider to b64_json responses (no remote fetch) or strictly validate allowed URL hosts/schemes before downloading. [ssrf]

Severity Level: Critical 🚨
- ❌ Image generation endpoint can probe internal HTTP services.
- ❌ Backend may fetch cloud metadata or admin endpoints.
- ⚠️ Non-image internal responses may be stored as images.
Steps of Reproduction ✅
1. Start the backend server by running `backend/cmd/server/main.go`; the route
`v1.POST("/tasks/generate", api.GenerateHandler)` is registered at
`backend/cmd/server/main.go:274-276`.

2. In the database `provider_config` table, configure provider `openai-image` so that its
`APIBase` points to an attacker-controlled OpenAI-compatible service; this base URL is
consumed by `NewOpenAIProvider` in `backend/internal/provider/openai.go:27-44` (stored as
`p.apiBase`) and `openai-image` is seeded in
`backend/internal/provider/provider.go:69-80`.

3. Implement the fake upstream's `/v1/images/generations` endpoint so that for any image
request it returns HTTP 200 with a body like
`{"data":[{"url":"http://127.0.0.1:8000/internal-admin"}]}` (no `b64_json` field) so that
only the `url` path is exercised.

4. From a client, call `POST /v1/tasks/generate` (handled by `GenerateHandler` in
`backend/internal/api/handlers.go:399-479`) with JSON such as
`{"provider":"openai-image","modelId":"test-model","params":{"prompt":"test
prompt","count":1}}`; the handler validates params, persists a `model.Task`, and enqueues
a `worker.Task` to `worker.Pool` at `backend/internal/api/handlers.go:31-49`.

5. The worker picks up the task in `processTask`
(`backend/internal/worker/pool.go:135-188`), looks up the provider via
`provider.GetProvider(task.TaskModel.ProviderName)` at
`backend/internal/worker/pool.go:175`, and in a goroutine calls `p.Generate(ctx,
task.Params)` at `backend/internal/worker/pool.go:26-27` with `ProviderName` set to
`"openai-image"`.

6. Inside `OpenAIImageProvider.Generate`
(`backend/internal/provider/openai_image.go:74-122`), the code calls
`p.doImagesGenerationRequest` to `p.apiBase + "/images/generations"` and then invokes
`images, summary, err := p.extractImages(ctx, respBytes)` at
`backend/internal/provider/openai_image.go:106`.

7. `extractImages` in `backend/internal/provider/openai.go:297-346` parses the JSON, sees
a non-empty `data` array, and calls `p.extractImagesFromData(ctx, data)` at
`openai.go:305-309`, which iterates each element and, finding only a `url` field, executes
`imgBytes, err := p.fetchImage(ctx, url)` at `openai.go:365-371`.

8. `fetchImage` (`backend/internal/provider/openai.go:438-467`) issues an HTTP GET using
`http.NewRequestWithContext(ctx, http.MethodGet, url, nil)` and `p.httpClient.Do(req)`
without any host, IP range, or scheme validation, so the backend server makes an outbound
request directly to `http://127.0.0.1:8000/internal-admin`. Observing this request on the
internal service or via network capture confirms that attacker-controlled `data[].url`
values from the upstream response cause server-side requests to arbitrary internal URLs
(SSRF).

Fix in Cursor | Fix in VSCode Claude

(Use Cmd/Ctrl + Click for best experience)

Prompt for AI Agent 🤖
This is a comment left during a code review.

**Path:** backend/internal/provider/openai_image.go
**Line:** 106:106
**Comment:**
	*Ssrf: This provider reuses the generic image extraction path, which will follow `data[].url` values and issue server-side HTTP GETs without any host allowlist. For the new `/images/generations` flow this path is commonly exercised, so a malicious or compromised upstream/proxy response can trigger SSRF to internal services. Restrict this provider to `b64_json` responses (no remote fetch) or strictly validate allowed URL hosts/schemes before downloading.

Validate the correctness of the flagged issue. If correct, How can I resolve this? If you propose a fix, implement it and please make it concise.
Once fix is implemented, also check other comments on the same PR, and ask user if the user wants to fix the rest of the comments as well. if said yes, then fetch all the comments validate the correctness and implement a minimal fix
👍 | 👎

if err != nil {
return nil, err
}

requestID := extractRequestIDFromHeaders(headers)
diagnostic.Logf(params, "response_summary",
"provider=%s model=%s data_count=%d choice_count=%d image_count=%d request_id=%s",
p.Name(),
modelID,
summary.DataCount,
summary.ChoiceCount,
len(images),
requestID,
)

return &ProviderResult{
Images: images,
Metadata: map[string]interface{}{
"provider": p.Name(),
"model": modelID,
"type": "image",
"request_id": requestID,
"oneapi_request": strings.TrimSpace(headers.Get("X-Oneapi-Request-Id")),
},
}, nil
}

func (p *OpenAIImageProvider) buildImagesGenerationRequestBody(modelID string, params map[string]interface{}) (*openAIImagesGenerationRequest, string, error) {
prompt, _ := params["prompt"].(string)
prompt = strings.TrimSpace(prompt)
if prompt == "" {
return nil, "", fmt.Errorf("缺少 prompt 参数")
}

body := &openAIImagesGenerationRequest{
Model: modelID,
Prompt: prompt,
Size: "auto",
N: 1,
}
if size, _ := params["size"].(string); strings.TrimSpace(size) != "" {
body.Size = strings.TrimSpace(strings.ToLower(size))
}
if quality, _ := params["quality"].(string); strings.TrimSpace(quality) != "" {
body.Quality = strings.TrimSpace(strings.ToLower(quality))
}
if count, ok := toInt(params["count"]); ok && count >= 1 && count <= 10 {
body.N = count
}

return body, prompt, nil
}

func (p *OpenAIImageProvider) doImagesGenerationRequest(ctx context.Context, body *openAIImagesGenerationRequest, params map[string]interface{}) ([]byte, http.Header, error) {
payloadBytes, err := json.Marshal(body)
if err != nil {
return nil, nil, fmt.Errorf("序列化 OpenAI Images 请求失败: %w", err)
}

requestURL := strings.TrimRight(strings.TrimSpace(p.apiBase), "/") + "/images/generations"
diagnostic.Logf(params, "request_payload",
"url=%s body=%q",
diagnostic.RedactSensitive(requestURL),
diagnostic.RedactSensitive(string(payloadBytes)),
)

maxRetries := providerMaxRetries(p.config)
var elapsed time.Duration
resp, _, err := doRequestWithRetry(ctx, params, p.Name(), maxRetries, func(attempt int) (*http.Response, error) {
req, buildErr := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewReader(payloadBytes))
if buildErr != nil {
return nil, fmt.Errorf("构建 OpenAI Images 请求失败: %w", buildErr)
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(p.config.APIKey))
req.Header.Set("Connection", "close")
if strings.TrimSpace(p.userAgent) != "" {
req.Header.Set("User-Agent", p.userAgent)
}

startedAt := time.Now()
resp, doErr := p.httpClient.Do(req)
elapsed = time.Since(startedAt)
return resp, doErr
})
if err != nil {
return nil, nil, fmt.Errorf("doRequest: error sending request: %w", err)
}
defer resp.Body.Close()

respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, resp.Header.Clone(), fmt.Errorf("读取 OpenAI Images 响应失败: %w", err)
}

requestID := extractRequestIDFromHeaders(resp.Header)
diagnostic.Logf(params, "response_headers",
"status=%s elapsed=%s request_id=%s headers=%q",
resp.Status,
elapsed,
requestID,
diagnostic.Preview(strings.Join(headerLines(resp.Header), " | "), 1000),
)
diagnostic.Logf(params, "response_body",
"status=%s elapsed=%s request_id=%s body=%q",
resp.Status,
elapsed,
requestID,
diagnostic.RedactSensitive(string(respBody)),
)

if resp.StatusCode < 200 || resp.StatusCode >= 300 {
bodyPreview := diagnostic.Preview(parseOpenAIError(respBody), 1200)
if requestID == "" {
requestID = diagnostic.ExtractRequestID(string(respBody))
}
return nil, resp.Header.Clone(), fmt.Errorf("OpenAI HTTP %d request_id=%s body=%s", resp.StatusCode, requestID, bodyPreview)
}

if len(respBody) == 0 {
return nil, resp.Header.Clone(), fmt.Errorf("接口未返回内容")
}

return respBody, resp.Header.Clone(), nil
}
6 changes: 4 additions & 2 deletions backend/internal/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ var (

func defaultTimeoutSeconds(providerName string) int {
switch providerName {
case "gemini", "openai":
case "gemini", "openai", "openai-image":
return 500
default:
return 150
Expand Down Expand Up @@ -66,7 +66,7 @@ func InitProviders() error {
defer initMu.Unlock()

// 0. 确保基础 Provider 至少存在于数据库中(即使没有配置文件)
defaultProviders := []string{"gemini", "openai"}
defaultProviders := []string{"gemini", "openai", "openai-image"}
for _, name := range defaultProviders {
var count int64
model.DB.Model(&model.ProviderConfig{}).Where("provider_name = ?", name).Count(&count)
Expand Down Expand Up @@ -135,6 +135,8 @@ func InitProviders() error {
p, err = NewGeminiProvider(&cfg)
case "openai":
p, err = NewOpenAIProvider(&cfg)
case "openai-image":
p, err = NewOpenAIImageProvider(&cfg)
default:
log.Printf("未知的 Provider 类型: %s", cfg.ProviderName)
continue
Expand Down
26 changes: 13 additions & 13 deletions backend/internal/worker/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,19 +199,19 @@ func (wp *WorkerPool) processTask(task *Task) {
len([]rune(task.TaskModel.Prompt)),
)

if err := wp.optimizePromptForTask(ctx, task); err != nil {
log.Printf("任务 %s 自动优化提示词失败,终止生图: %v", task.TaskModel.TaskID, err)
diagnostic.Logf(task.Params, "prompt_optimize_failed",
"mode=%s provider=%s model=%s err=%q fallback=%t",
task.TaskModel.PromptOptimizeMode,
promptopt.ExtractProvider(task.Params),
promptopt.ExtractModel(task.Params),
err.Error(),
false,
)
wp.failTask(task, fmt.Errorf("提示词优化失败: %w", err))
return
}
if err := wp.optimizePromptForTask(ctx, task); err != nil {
log.Printf("任务 %s 自动优化提示词失败,终止生图: %v", task.TaskModel.TaskID, err)
diagnostic.Logf(task.Params, "prompt_optimize_failed",
"mode=%s provider=%s model=%s err=%q fallback=%t",
task.TaskModel.PromptOptimizeMode,
promptopt.ExtractProvider(task.Params),
promptopt.ExtractModel(task.Params),
err.Error(),
false,
)
wp.failTask(task, fmt.Errorf("提示词优化失败: %w", err))
return
}

done := make(chan generateResult, 1)
go func() {
Expand Down
Loading
Loading