-
Notifications
You must be signed in to change notification settings - Fork 28
Add OpenAI Images provider support #53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fcb38a5
363c0cd
f1545ad
2569b56
75dd7d1
b054c49
39da894
fd85e39
944ac9c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,233 @@ | ||
| package provider | ||
|
|
||
| import ( | ||
| "bytes" | ||
| "context" | ||
| "encoding/json" | ||
| "fmt" | ||
| "image-gen-service/internal/diagnostic" | ||
| "image-gen-service/internal/model" | ||
| "io" | ||
| "net/http" | ||
| "strings" | ||
| "time" | ||
| ) | ||
|
|
||
| type OpenAIImageProvider struct { | ||
| *OpenAIProvider | ||
| } | ||
|
|
||
// openAIImagesGenerationRequest is the JSON payload posted to the
// "/images/generations" endpoint.
type openAIImagesGenerationRequest struct {
	// Model is the resolved model identifier.
	Model string `json:"model"`
	// Prompt is the trimmed text prompt.
	Prompt string `json:"prompt"`
	// Size is always serialized (no omitempty); the request builder defaults it to "auto".
	Size string `json:"size"`
	// Quality is left out of the JSON when empty.
	Quality string `json:"quality,omitempty"`
	// N is the number of images requested; left out of the JSON when zero.
	N int `json:"n,omitempty"`
}
|
|
||
| func NewOpenAIImageProvider(config *model.ProviderConfig) (*OpenAIImageProvider, error) { | ||
| base, err := NewOpenAIProvider(config) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
| return &OpenAIImageProvider{OpenAIProvider: base}, nil | ||
| } | ||
|
|
||
| func (p *OpenAIImageProvider) Name() string { | ||
| return "openai-image" | ||
| } | ||
|
|
||
| func (p *OpenAIImageProvider) ValidateParams(params map[string]interface{}) error { | ||
| prompt, _ := params["prompt"].(string) | ||
| if strings.TrimSpace(prompt) == "" { | ||
| return fmt.Errorf("prompt 不能为空") | ||
| } | ||
| if raw, ok := params["reference_images"].([]interface{}); ok && len(raw) > 0 { | ||
| return fmt.Errorf("OpenAI Images 当前仅支持文本生图") | ||
| } | ||
|
|
||
| count, ok := toInt(params["count"]) | ||
| if !ok { | ||
| count = 1 | ||
| } | ||
| if count < 1 || count > 10 { | ||
| return fmt.Errorf("count/n 必须介于 1 和 10 之间") | ||
| } | ||
|
|
||
| size, _ := params["size"].(string) | ||
| switch strings.TrimSpace(strings.ToLower(size)) { | ||
| case "", "auto", "1024x1024", "1024x1536", "1536x1024": | ||
| default: | ||
| return fmt.Errorf("size 仅支持 auto、1024x1024、1024x1536、1536x1024") | ||
| } | ||
|
|
||
| quality, _ := params["quality"].(string) | ||
| switch strings.TrimSpace(strings.ToLower(quality)) { | ||
| case "", "auto", "low", "medium", "high": | ||
| default: | ||
| return fmt.Errorf("quality 仅支持 auto、low、medium、high") | ||
| } | ||
|
|
||
| return nil | ||
| } | ||
|
|
||
| func (p *OpenAIImageProvider) Generate(ctx context.Context, params map[string]interface{}) (*ProviderResult, error) { | ||
| modelID := ResolveModelID(ModelResolveOptions{ | ||
| ProviderName: p.Name(), | ||
| Purpose: PurposeImage, | ||
| Params: params, | ||
| Config: p.config, | ||
| }).ID | ||
| if modelID == "" { | ||
| return nil, fmt.Errorf("缺少 model_id 参数") | ||
| } | ||
|
|
||
| reqBody, promptPreview, err := p.buildImagesGenerationRequestBody(modelID, params) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| diagnostic.Logf(params, "request_prepare", | ||
| "provider=%s model=%s size=%q quality=%q count=%d prompt_hash=%s prompt_preview=%q", | ||
| p.Name(), | ||
| modelID, | ||
| reqBody.Size, | ||
| reqBody.Quality, | ||
| reqBody.N, | ||
| diagnostic.PromptHash(promptPreview), | ||
| diagnostic.Preview(promptPreview, 160), | ||
| ) | ||
|
|
||
| respBytes, headers, err := p.doImagesGenerationRequest(ctx, reqBody, params) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| images, summary, err := p.extractImages(ctx, respBytes) | ||
|
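There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggestion: This provider reuses the generic image extraction path, which will follow `data[].url` values and issue server-side HTTP GETs without any host allowlist. For the new `/images/generations` flow this path is commonly exercised, so a malicious or compromised upstream/proxy response can trigger SSRF to internal services. Restrict this provider to `b64_json` responses (no remote fetch) or strictly validate allowed URL hosts/schemes before downloading.
Severity Level: Critical 🚨
- ❌ Image generation endpoint can probe internal HTTP services.
- ❌ Backend may fetch cloud metadata or admin endpoints.
- ⚠️ Non-image internal responses may be stored as images.
Steps of Reproduction ✅
1. Start the backend server by running `backend/cmd/server/main.go`; the route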
`v1.POST("/tasks/generate", api.GenerateHandler)` is registered at
`backend/cmd/server/main.go:274-276`.
2. In the database `provider_config` table, configure provider `openai-image` so that its
`APIBase` points to an attacker-controlled OpenAI-compatible service; this base URL is
consumed by `NewOpenAIProvider` in `backend/internal/provider/openai.go:27-44` (stored as
`p.apiBase`) and `openai-image` is seeded in
`backend/internal/provider/provider.go:69-80`.
3. Implement the fake upstream's `/v1/images/generations` endpoint so that for any image
request it returns HTTP 200 with a body like
`{"data":[{"url":"http://127.0.0.1:8000/internal-admin"}]}` (no `b64_json` field) so that
only the `url` path is exercised.
4. From a client, call `POST /v1/tasks/generate` (handled by `GenerateHandler` in
`backend/internal/api/handlers.go:399-479`) with JSON such as
`{"provider":"openai-image","modelId":"test-model","params":{"prompt":"test
prompt","count":1}}`; the handler validates params, persists a `model.Task`, and enqueues
a `worker.Task` to `worker.Pool` at `backend/internal/api/handlers.go:31-49`.
5. The worker picks up the task in `processTask`
(`backend/internal/worker/pool.go:135-188`), looks up the provider via
`provider.GetProvider(task.TaskModel.ProviderName)` at
`backend/internal/worker/pool.go:175`, and in a goroutine calls `p.Generate(ctx,
task.Params)` at `backend/internal/worker/pool.go:26-27` with `ProviderName` set to
`"openai-image"`.
6. Inside `OpenAIImageProvider.Generate`
(`backend/internal/provider/openai_image.go:74-122`), the code calls
`p.doImagesGenerationRequest` to `p.apiBase + "/images/generations"` and then invokes
`images, summary, err := p.extractImages(ctx, respBytes)` at
`backend/internal/provider/openai_image.go:106`.
7. `extractImages` in `backend/internal/provider/openai.go:297-346` parses the JSON, sees
a non-empty `data` array, and calls `p.extractImagesFromData(ctx, data)` at
`openai.go:305-309`, which iterates each element and, finding only a `url` field, executes
`imgBytes, err := p.fetchImage(ctx, url)` at `openai.go:365-371`.
8. `fetchImage` (`backend/internal/provider/openai.go:438-467`) issues an HTTP GET using
`http.NewRequestWithContext(ctx, http.MethodGet, url, nil)` and `p.httpClient.Do(req)`
without any host, IP range, or scheme validation, so the backend server makes an outbound
request directly to `http://127.0.0.1:8000/internal-admin`. Observing this request on the
internal service or via network capture confirms that attacker-controlled `data[].url`
values from the upstream response cause server-side requests to arbitrary internal URLs
(SSRF).
Fix in Cursor | Fix in VSCode Claude
(Use Cmd/Ctrl + Click for best experience)
Prompt for AI Agent 🤖
This is a comment left during a code review.
**Path:** backend/internal/provider/openai_image.go
**Line:** 106:106
**Comment:**
*Ssrf: This provider reuses the generic image extraction path, which will follow `data[].url` values and issue server-side HTTP GETs without any host allowlist. For the new `/images/generations` flow this path is commonly exercised, so a malicious or compromised upstream/proxy response can trigger SSRF to internal services. Restrict this provider to `b64_json` responses (no remote fetch) or strictly validate allowed URL hosts/schemes before downloading.
Validate the correctness of the flagged issue. If correct, How can I resolve this? If you propose a fix, implement it and please make it concise.
Once fix is implemented, also check other comments on the same PR, and ask user if the user wants to fix the rest of the comments as well. if said yes, then fetch all the comments validate the correctness and implement a minimal fix |
||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| requestID := extractRequestIDFromHeaders(headers) | ||
| diagnostic.Logf(params, "response_summary", | ||
| "provider=%s model=%s data_count=%d choice_count=%d image_count=%d request_id=%s", | ||
| p.Name(), | ||
| modelID, | ||
| summary.DataCount, | ||
| summary.ChoiceCount, | ||
| len(images), | ||
| requestID, | ||
| ) | ||
|
|
||
| return &ProviderResult{ | ||
| Images: images, | ||
| Metadata: map[string]interface{}{ | ||
| "provider": p.Name(), | ||
| "model": modelID, | ||
| "type": "image", | ||
| "request_id": requestID, | ||
| "oneapi_request": strings.TrimSpace(headers.Get("X-Oneapi-Request-Id")), | ||
| }, | ||
| }, nil | ||
| } | ||
|
|
||
| func (p *OpenAIImageProvider) buildImagesGenerationRequestBody(modelID string, params map[string]interface{}) (*openAIImagesGenerationRequest, string, error) { | ||
| prompt, _ := params["prompt"].(string) | ||
| prompt = strings.TrimSpace(prompt) | ||
| if prompt == "" { | ||
| return nil, "", fmt.Errorf("缺少 prompt 参数") | ||
| } | ||
|
|
||
| body := &openAIImagesGenerationRequest{ | ||
| Model: modelID, | ||
| Prompt: prompt, | ||
| Size: "auto", | ||
| N: 1, | ||
| } | ||
| if size, _ := params["size"].(string); strings.TrimSpace(size) != "" { | ||
| body.Size = strings.TrimSpace(strings.ToLower(size)) | ||
| } | ||
| if quality, _ := params["quality"].(string); strings.TrimSpace(quality) != "" { | ||
| body.Quality = strings.TrimSpace(strings.ToLower(quality)) | ||
| } | ||
| if count, ok := toInt(params["count"]); ok && count >= 1 && count <= 10 { | ||
| body.N = count | ||
| } | ||
|
|
||
| return body, prompt, nil | ||
| } | ||
|
|
||
| func (p *OpenAIImageProvider) doImagesGenerationRequest(ctx context.Context, body *openAIImagesGenerationRequest, params map[string]interface{}) ([]byte, http.Header, error) { | ||
| payloadBytes, err := json.Marshal(body) | ||
| if err != nil { | ||
| return nil, nil, fmt.Errorf("序列化 OpenAI Images 请求失败: %w", err) | ||
| } | ||
|
|
||
| requestURL := strings.TrimRight(strings.TrimSpace(p.apiBase), "/") + "/images/generations" | ||
| diagnostic.Logf(params, "request_payload", | ||
| "url=%s body=%q", | ||
| diagnostic.RedactSensitive(requestURL), | ||
| diagnostic.RedactSensitive(string(payloadBytes)), | ||
| ) | ||
|
|
||
| maxRetries := providerMaxRetries(p.config) | ||
| var elapsed time.Duration | ||
| resp, _, err := doRequestWithRetry(ctx, params, p.Name(), maxRetries, func(attempt int) (*http.Response, error) { | ||
| req, buildErr := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewReader(payloadBytes)) | ||
| if buildErr != nil { | ||
| return nil, fmt.Errorf("构建 OpenAI Images 请求失败: %w", buildErr) | ||
| } | ||
|
|
||
| req.Header.Set("Content-Type", "application/json") | ||
| req.Header.Set("Accept", "application/json") | ||
| req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(p.config.APIKey)) | ||
| req.Header.Set("Connection", "close") | ||
| if strings.TrimSpace(p.userAgent) != "" { | ||
| req.Header.Set("User-Agent", p.userAgent) | ||
| } | ||
|
|
||
| startedAt := time.Now() | ||
| resp, doErr := p.httpClient.Do(req) | ||
| elapsed = time.Since(startedAt) | ||
| return resp, doErr | ||
| }) | ||
| if err != nil { | ||
| return nil, nil, fmt.Errorf("doRequest: error sending request: %w", err) | ||
| } | ||
| defer resp.Body.Close() | ||
|
|
||
| respBody, err := io.ReadAll(resp.Body) | ||
| if err != nil { | ||
| return nil, resp.Header.Clone(), fmt.Errorf("读取 OpenAI Images 响应失败: %w", err) | ||
| } | ||
|
|
||
| requestID := extractRequestIDFromHeaders(resp.Header) | ||
| diagnostic.Logf(params, "response_headers", | ||
| "status=%s elapsed=%s request_id=%s headers=%q", | ||
| resp.Status, | ||
| elapsed, | ||
| requestID, | ||
| diagnostic.Preview(strings.Join(headerLines(resp.Header), " | "), 1000), | ||
| ) | ||
| diagnostic.Logf(params, "response_body", | ||
| "status=%s elapsed=%s request_id=%s body=%q", | ||
| resp.Status, | ||
| elapsed, | ||
| requestID, | ||
| diagnostic.RedactSensitive(string(respBody)), | ||
| ) | ||
|
|
||
| if resp.StatusCode < 200 || resp.StatusCode >= 300 { | ||
| bodyPreview := diagnostic.Preview(parseOpenAIError(respBody), 1200) | ||
| if requestID == "" { | ||
| requestID = diagnostic.ExtractRequestID(string(respBody)) | ||
| } | ||
| return nil, resp.Header.Clone(), fmt.Errorf("OpenAI HTTP %d request_id=%s body=%s", resp.StatusCode, requestID, bodyPreview) | ||
| } | ||
|
|
||
| if len(respBody) == 0 { | ||
| return nil, resp.Header.Clone(), fmt.Errorf("接口未返回内容") | ||
| } | ||
|
|
||
| return respBody, resp.Header.Clone(), nil | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggestion: The default model returned for the `openai` provider is a Gemini model ID, which does not match the OpenAI API contract. When request/config model is absent, this fallback will send an invalid model to OpenAI and cause generation requests to fail (typically 400 model-not-found). Use an OpenAI-compatible default model here (for example `gpt-image-1`) to keep resolver behavior consistent with provider type. [api mismatch]
Severity Level: Critical 🚨
Steps of Reproduction ✅
Fix in Cursor | Fix in VSCode Claude
(Use Cmd/Ctrl + Click for best experience)
Prompt for AI Agent 🤖