Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,7 @@ async def on_batch_complete(batch_idx, total, batch_previews):
from .image_utils import detect_aspect_ratio as _dar
all_preview_images = []
all_params_hashes = []
all_provider_usage = []
total_boxes = len(blurred_b64_list)
total_calls = total_boxes * batch_count
_progress_total_override = total_calls
Expand Down Expand Up @@ -1055,14 +1056,17 @@ async def on_batch_complete(batch_idx, total, batch_previews):
all_preview_images.extend(box_result["preview_images"])
elif "error" in box_result:
result["error"] = box_result["error"]
all_provider_usage.extend(box_result.get("provider_usage", []))

if box_result.get("params_hash"):
all_params_hashes.append(box_result["params_hash"])

# Update combined result, preserving any existing error messages
from .usage_tracker import aggregate_provider_usage as _aggregate_provider_usage
result["success"] = bool(all_preview_images)
result["preview_images"] = all_preview_images
result["params_hash"] = "_".join(all_params_hashes) if all_params_hashes else ""
result["provider_usage"] = _aggregate_provider_usage(all_provider_usage)
else:
# Merge: blurred image(s) + all reference images
merged_images = blurred_b64_list + reference_images_base64
Expand Down Expand Up @@ -1107,6 +1111,7 @@ async def on_batch_complete(batch_idx, total, batch_previews):
images_saved=_saved_count,
success=result.get("success", False),
providers_tried=[],
provider_usage=result.get("provider_usage", []),
error_message=result.get("error", ""),
duration_seconds=_usage_duration,
)
Expand Down Expand Up @@ -1241,7 +1246,8 @@ async def on_batch_complete(batch_idx, total, batch_previews):
images_generated=len(_preview_images),
images_saved=_saved_count,
success=result.get("success", False),
providers_tried=[], # Aggregated per-batch, not available in result dict
providers_tried=[], # Derived from provider_usage by UsageTracker
provider_usage=result.get("provider_usage", []),
error_message=result.get("error", ""),
duration_seconds=_usage_duration,
)
Expand Down
69 changes: 52 additions & 17 deletions adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class APIResponse:
task_id: str = ""
status: str = "" # pending, processing, success, failed
providers_tried: List[str] = field(default_factory=list) # Usage tracking: providers attempted
provider_usage: List[Dict[str, Any]] = field(default_factory=list) # Safe provider/key image counts


@dataclass
Expand Down Expand Up @@ -81,42 +82,76 @@ def api_key(self) -> str:
2. api_keys: ["key1", "key2"] (plain list)
3. api_key: "single_key" (original)
"""
return self.select_api_key_info().get("key", "")

def select_api_key_info(self) -> Dict[str, Any]:
"""Select an API key and return full key plus a safe display label.

The returned ``key`` is for request signing only. ``label`` is safe for
usage logs and dashboards because it never contains the complete key.
"""
import random
import time

keys_config = self.provider.get("api_keys", [])
if keys_config and isinstance(keys_config, list):
# Build list of active keys
active_keys = []
for item in keys_config:
# Build list of active keys with stable config-order numbering.
active_infos = []
for idx, item in enumerate(keys_config, start=1):
if isinstance(item, dict):
if item.get("enabled", True) and item.get("key"):
active_keys.append(item["key"])
active_infos.append(self._make_key_info(
item["key"],
idx,
str(item.get("name") or "").strip(),
))
elif isinstance(item, str) and item:
active_keys.append(item)
active_infos.append(self._make_key_info(item, idx, ""))

if active_keys:
if active_infos:
# Filter out blacklisted keys (always use CLASS-level dict, never instance)
now = time.time()
bad = APIAdapter._bad_keys # explicit class reference
good_keys = [k for k in active_keys if k not in bad or bad[k] < now]


# Clean expired entries IN-PLACE (don't reassign to avoid shadowing)
expired = [k for k, t in bad.items() if t < now]
for k in expired:
del bad[k]


good_infos = [
info for info in active_infos
if info["key"] not in bad or bad[info["key"]] < now
]

# Use good keys if available, fall back to all keys if all blacklisted
pool = good_keys if good_keys else active_keys
key = random.choice(pool)
# Key numbering: stable index based on config order (1-based)
key_num = active_keys.index(key) + 1 if key in active_keys else "?"
pool = good_infos if good_infos else active_infos
info = random.choice(pool)
provider_name = self.provider.get("display_name", self.provider.get("name", "?"))
skipped = len(active_keys) - len(good_keys)
skipped = len(active_infos) - len(good_infos)
skip_info = f", {skipped} blacklisted" if skipped else ""
print(f"[APIAdapter] Key#{key_num} selected for {provider_name} ({len(pool)}/{len(active_keys)} keys{skip_info})")
return key
return self.provider.get("api_key", "")
print(f"[APIAdapter] Key#{info['index']} selected for {provider_name} ({len(pool)}/{len(active_infos)} keys{skip_info})")
return dict(info)

key = self.provider.get("api_key", "")
return self._make_key_info(key, 1, "") if key else {"key": "", "index": 0, "name": "", "label": ""}

def _make_key_info(self, key: str, index: int, name: str = "") -> Dict[str, Any]:
safe_name = str(name or "").strip() or f"Key#{index}"
return {
"key": key,
"index": index,
"name": safe_name,
"label": self._format_safe_key_label(key, index=index, name=safe_name),
}

@staticmethod
def _format_safe_key_label(key: str, index: int = 1, name: str = "") -> str:
safe_name = str(name or "").strip() or f"Key#{index}"
key_str = str(key or "")
if not key_str:
return safe_name
suffix = key_str[-6:]
return f"{safe_name} · ****{suffix}"

# Store active_keys list for key number lookup in blacklist_key
_last_active_keys: list = []
Expand Down
36 changes: 34 additions & 2 deletions adapters/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def build_request(self, params: Dict, mode: str = "text2img") -> Dict:
def _build_openai_request(self, params: Dict, mode: str = "text2img") -> Dict:
"""Build request for OpenAI-compatible APIs."""
# Pick API key ONCE (self.api_key returns random key each call)
_current_key = self.api_key or ""
key_info = self.select_api_key_info()
_current_key = key_info.get("key", "") or ""
key_label = key_info.get("label", "")
used_key = _current_key # Track for blacklisting on failure
endpoint_path = self.mode_config.get("endpoint", "")
method = self.mode_config.get("method", "POST")
Expand All @@ -86,6 +88,7 @@ def _build_openai_request(self, params: Dict, mode: str = "text2img") -> Dict:
# Build headers - support Account auth mode
auth_type = self.endpoint.get("auth_type", "api")
if auth_type == "account":
key_label = "Account Token"
# Account mode: use X-Auth-T token from Account singleton
try:
from ..account import Account
Expand All @@ -94,6 +97,7 @@ def _build_openai_request(self, params: Dict, mode: str = "text2img") -> Dict:
except Exception:
headers = {"Authorization": f"Bearer {_current_key}"}
elif auth_type == "service_account":
key_label = "Service Account"
# Vertex AI proper: OAuth2 Bearer token from Service Account JSON
try:
from ..vertex_sa_auth import get_access_token
Expand Down Expand Up @@ -219,6 +223,9 @@ def _build_openai_request(self, params: Dict, mode: str = "text2img") -> Dict:
"method": method,
"headers": headers,
"_used_key": used_key,
"_provider": self.provider.get("name", "unknown"),
"_provider_label": self.provider.get("display_name") or self.provider.get("name", "unknown"),
"_key_label": key_label,
}

if content_type == "application/json":
Expand Down Expand Up @@ -298,7 +305,9 @@ def _build_gemini_request(self, params: Dict, mode: str = "text2img") -> Dict:

# Pick API key ONCE for the entire request (URL + Files API + tracking)
# CRITICAL: self.api_key is a property that returns a random key each call!
_current_key = self.api_key or ""
key_info = self.select_api_key_info()
_current_key = key_info.get("key", "") or ""
key_label = key_info.get("label", "")
used_key = _current_key # Track for blacklisting on failure

endpoint_path = self.mode_config.get("endpoint", "")
Expand Down Expand Up @@ -327,6 +336,7 @@ def _build_gemini_request(self, params: Dict, mode: str = "text2img") -> Dict:
auth_header_format = self.endpoint.get("auth_header_format", "bearer")

if auth_type == "account":
key_label = "Account Token"
try:
from ..account import Account
account = Account.get_instance()
Expand Down Expand Up @@ -369,6 +379,8 @@ def _build_gemini_request(self, params: Dict, mode: str = "text2img") -> Dict:
from ..vertex_sa_auth import get_random_sa_token
sa_token, sa_project_id, sa_name = get_random_sa_token(self.provider)
if sa_token:
safe_sa_name = str(sa_name or "Service Account").replace("\\", "/").rsplit("/", 1)[-1]
key_label = f"Service Account · {safe_sa_name}"
headers = {
"Authorization": f"Bearer {sa_token}",
"Content-Type": "application/json"
Expand Down Expand Up @@ -566,9 +578,28 @@ def _build_gemini_request(self, params: Dict, mode: str = "text2img") -> Dict:
"headers": headers,
"json": payload,
"_used_key": used_key,
"_provider": self.provider.get("name", "unknown"),
"_provider_label": self.provider.get("display_name") or self.provider.get("name", "unknown"),
"_key_label": key_label,
}


def _attach_provider_usage(self, result: APIResponse, request_info: Dict[str, Any]) -> APIResponse:
"""Attach safe provider/key image counts to a successful response."""
if not result.success:
return result
gen_count = len(result.images or [])
if gen_count <= 0:
return result

result.provider_usage.append({
"provider": request_info.get("_provider") or self.provider.get("name", "unknown"),
"provider_label": request_info.get("_provider_label") or self.provider.get("display_name") or self.provider.get("name", "unknown"),
"key_label": request_info.get("_key_label") or "未记录Key",
"gen": gen_count,
})
return result

def _prepare_images_base64(self, params: Dict) -> Dict:
"""
Convert uploaded files to base64 data URLs for Chat API format.
Expand Down Expand Up @@ -1074,6 +1105,7 @@ def execute(self, params: Dict, mode: str = "text2img",
except Exception:
pass

self._attach_provider_usage(result, request_info)
return result

except requests.Timeout:
Expand Down
18 changes: 18 additions & 0 deletions adapters/volcengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,24 @@ def execute(self, params: Dict, mode: str = "inpaint") -> APIResponse:
if img_bytes:
result.images.append(img_bytes)

self._attach_provider_usage(result)
return result

def _attach_provider_usage(self, result: APIResponse) -> APIResponse:
"""Attach safe Volcengine access-key image counts to successful image results."""
if not result.success or not result.images:
return result
provider = self.provider.get("name", "volcengine")
result.provider_usage.append({
"provider": provider,
"provider_label": self.provider.get("display_name") or provider,
"key_label": self._format_safe_key_label(
self.provider.get("access_key", ""),
index=1,
name="AccessKey",
),
"gen": len(result.images),
})
return result

def _poll_for_result(self, task_id: str, timeout: int = 120) -> APIResponse:
Expand Down
Loading
Loading