From 5bf07076e71b1ad6e255e7e5480b73fd08aa48d9 Mon Sep 17 00:00:00 2001 From: iOrange <1606894729@qq.com> Date: Tue, 12 May 2026 18:03:24 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=9A=E5=A2=9E=E5=8A=A0=E4=BE=9B?= =?UTF-8?q?=E5=BA=94=E5=95=86/API=E7=94=A8=E9=87=8F=E7=9B=91=E6=8E=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增供应商和脱敏 API Key 的生图统计记录,并在监控端增加供应商详情页、饼图和最近活动供应商列。 --- __init__.py | 8 +- adapters/base.py | 69 +++++++-- adapters/generic.py | 36 ++++- adapters/volcengine.py | 18 +++ dashboard.py | 205 +++++++++++++++++++++++-- independent_generator.py | 16 +- nodes.py | 40 ++++- tests/test_base_adapter.py | 28 ++++ tests/test_dashboard_provider_usage.py | 108 +++++++++++++ tests/test_independent_generator.py | 54 +++++++ tests/test_usage_tracker.py | 46 ++++++ usage_tracker.py | 76 ++++++++- 12 files changed, 661 insertions(+), 43 deletions(-) create mode 100644 tests/test_dashboard_provider_usage.py diff --git a/__init__.py b/__init__.py index 5d64308..89aff5f 100644 --- a/__init__.py +++ b/__init__.py @@ -959,6 +959,7 @@ async def on_batch_complete(batch_idx, total, batch_previews): from .image_utils import detect_aspect_ratio as _dar all_preview_images = [] all_params_hashes = [] + all_provider_usage = [] total_boxes = len(blurred_b64_list) total_calls = total_boxes * batch_count _progress_total_override = total_calls @@ -1055,14 +1056,17 @@ async def on_batch_complete(batch_idx, total, batch_previews): all_preview_images.extend(box_result["preview_images"]) elif "error" in box_result: result["error"] = box_result["error"] + all_provider_usage.extend(box_result.get("provider_usage", [])) if box_result.get("params_hash"): all_params_hashes.append(box_result["params_hash"]) # Update combined result, preserving any existing error messages + from .usage_tracker import aggregate_provider_usage as _aggregate_provider_usage result["success"] = bool(all_preview_images) result["preview_images"] = all_preview_images result["params_hash"] = "_".join(all_params_hashes) if all_params_hashes else "" + result["provider_usage"] = _aggregate_provider_usage(all_provider_usage) else: # Merge: blurred image(s) + all reference images merged_images = blurred_b64_list + reference_images_base64 @@ -1107,6 +1111,7 @@ async def on_batch_complete(batch_idx, total, batch_previews): images_saved=_saved_count, success=result.get("success", False), providers_tried=[], + provider_usage=result.get("provider_usage", []), error_message=result.get("error", ""), duration_seconds=_usage_duration, ) @@ -1241,7 +1246,8 @@ async def on_batch_complete(batch_idx, total, batch_previews): images_generated=len(_preview_images), images_saved=_saved_count, success=result.get("success", False), - providers_tried=[], # Aggregated per-batch, not available in result dict + providers_tried=[], # Derived from provider_usage by UsageTracker + provider_usage=result.get("provider_usage", []), error_message=result.get("error", ""), duration_seconds=_usage_duration, ) diff --git a/adapters/base.py b/adapters/base.py index b35c4f5..f3057fd 100644 --- a/adapters/base.py +++ b/adapters/base.py @@ -28,6 +28,7 @@ class APIResponse: task_id: str = "" status: str = "" # pending, processing, success, failed providers_tried: List[str] = field(default_factory=list) # Usage tracking: providers attempted + provider_usage: List[Dict[str, Any]] = field(default_factory=list) # Safe provider/key image counts @dataclass @@ -81,42 +82,76 @@ def api_key(self) -> str: 2. api_keys: ["key1", "key2"] (plain list) 3. api_key: "single_key" (original) """ + return self.select_api_key_info().get("key", "") + + def select_api_key_info(self) -> Dict[str, Any]: + """Select an API key and return full key plus a safe display label. + + The returned ``key`` is for request signing only. ``label`` is safe for + usage logs and dashboards because it never contains the complete key. + """ import random import time keys_config = self.provider.get("api_keys", []) if keys_config and isinstance(keys_config, list): - # Build list of active keys - active_keys = [] - for item in keys_config: + # Build list of active keys with stable config-order numbering. + active_infos = [] + for idx, item in enumerate(keys_config, start=1): if isinstance(item, dict): if item.get("enabled", True) and item.get("key"): - active_keys.append(item["key"]) + active_infos.append(self._make_key_info( + item["key"], + idx, + str(item.get("name") or "").strip(), + )) elif isinstance(item, str) and item: - active_keys.append(item) + active_infos.append(self._make_key_info(item, idx, "")) - if active_keys: + if active_infos: # Filter out blacklisted keys (always use CLASS-level dict, never instance) now = time.time() bad = APIAdapter._bad_keys # explicit class reference - good_keys = [k for k in active_keys if k not in bad or bad[k] < now] - + # Clean expired entries IN-PLACE (don't reassign to avoid shadowing) expired = [k for k, t in bad.items() if t < now] for k in expired: del bad[k] - + + good_infos = [ + info for info in active_infos + if info["key"] not in bad or bad[info["key"]] < now + ] + # Use good keys if available, fall back to all keys if all blacklisted - pool = good_keys if good_keys else active_keys - key = random.choice(pool) - # Key numbering: stable index based on config order (1-based) - key_num = active_keys.index(key) + 1 if key in active_keys else "?" + pool = good_infos if good_infos else active_infos + info = random.choice(pool) provider_name = self.provider.get("display_name", self.provider.get("name", "?")) - skipped = len(active_keys) - len(good_keys) + skipped = len(active_infos) - len(good_infos) skip_info = f", {skipped} blacklisted" if skipped else "" - print(f"[APIAdapter] Key#{key_num} selected for {provider_name} ({len(pool)}/{len(active_keys)} keys{skip_info})") - return key - return self.provider.get("api_key", "") + print(f"[APIAdapter] Key#{info['index']} selected for {provider_name} ({len(pool)}/{len(active_infos)} keys{skip_info})") + return dict(info) + + key = self.provider.get("api_key", "") + return self._make_key_info(key, 1, "") if key else {"key": "", "index": 0, "name": "", "label": ""} + + def _make_key_info(self, key: str, index: int, name: str = "") -> Dict[str, Any]: + safe_name = str(name or "").strip() or f"Key#{index}" + return { + "key": key, + "index": index, + "name": safe_name, + "label": self._format_safe_key_label(key, index=index, name=safe_name), + } + + @staticmethod + def _format_safe_key_label(key: str, index: int = 1, name: str = "") -> str: + safe_name = str(name or "").strip() or f"Key#{index}" + key_str = str(key or "") + if not key_str: + return safe_name + suffix = key_str[-6:] + return f"{safe_name} · ****{suffix}" # Store active_keys list for key number lookup in blacklist_key _last_active_keys: list = [] diff --git a/adapters/generic.py b/adapters/generic.py index bc7f336..b0cef75 100644 --- a/adapters/generic.py +++ b/adapters/generic.py @@ -75,7 +75,9 @@ def build_request(self, params: Dict, mode: str = "text2img") -> Dict: def _build_openai_request(self, params: Dict, mode: str = "text2img") -> Dict: """Build request for OpenAI-compatible APIs.""" # Pick API key ONCE (self.api_key returns random key each call) - _current_key = self.api_key or "" + key_info = self.select_api_key_info() + _current_key = key_info.get("key", "") or "" + key_label = key_info.get("label", "") used_key = _current_key # Track for blacklisting on failure endpoint_path = self.mode_config.get("endpoint", "") method = self.mode_config.get("method", "POST") @@ -86,6 +88,7 @@ def _build_openai_request(self, params: Dict, mode: str = "text2img") -> Dict: # Build headers - support Account auth mode auth_type = self.endpoint.get("auth_type", "api") if auth_type == "account": + key_label = "Account Token" # Account mode: use X-Auth-T token from Account singleton try: from ..account import Account @@ -94,6 +97,7 @@ def _build_openai_request(self, params: Dict, mode: str = "text2img") -> Dict: except Exception: headers = {"Authorization": f"Bearer {_current_key}"} elif auth_type == "service_account": + key_label = "Service Account" # Vertex AI proper: OAuth2 Bearer token from Service Account JSON try: from ..vertex_sa_auth import get_access_token @@ -219,6 +223,9 @@ def _build_openai_request(self, params: Dict, mode: str = "text2img") -> Dict: "method": method, "headers": headers, "_used_key": used_key, + "_provider": self.provider.get("name", "unknown"), + "_provider_label": self.provider.get("display_name") or self.provider.get("name", "unknown"), + "_key_label": key_label, } if content_type == "application/json": @@ -298,7 +305,9 @@ def _build_gemini_request(self, params: Dict, mode: str = "text2img") -> Dict: # Pick API key ONCE for the entire request (URL + Files API + tracking) # CRITICAL: self.api_key is a property that returns a random key each call! - _current_key = self.api_key or "" + key_info = self.select_api_key_info() + _current_key = key_info.get("key", "") or "" + key_label = key_info.get("label", "") used_key = _current_key # Track for blacklisting on failure endpoint_path = self.mode_config.get("endpoint", "") @@ -327,6 +336,7 @@ def _build_gemini_request(self, params: Dict, mode: str = "text2img") -> Dict: auth_header_format = self.endpoint.get("auth_header_format", "bearer") if auth_type == "account": + key_label = "Account Token" try: from ..account import Account account = Account.get_instance() @@ -369,6 +379,8 @@ def _build_gemini_request(self, params: Dict, mode: str = "text2img") -> Dict: from ..vertex_sa_auth import get_random_sa_token sa_token, sa_project_id, sa_name = get_random_sa_token(self.provider) if sa_token: + safe_sa_name = str(sa_name or "Service Account").replace("\\", "/").rsplit("/", 1)[-1] + key_label = f"Service Account · {safe_sa_name}" headers = { "Authorization": f"Bearer {sa_token}", "Content-Type": "application/json" @@ -566,9 +578,28 @@ def _build_gemini_request(self, params: Dict, mode: str = "text2img") -> Dict: "headers": headers, "json": payload, "_used_key": used_key, + "_provider": self.provider.get("name", "unknown"), + "_provider_label": self.provider.get("display_name") or self.provider.get("name", "unknown"), + "_key_label": key_label, } + def _attach_provider_usage(self, result: APIResponse, request_info: Dict[str, Any]) -> APIResponse: + """Attach safe provider/key image counts to a successful response.""" + if not result.success: + return result + gen_count = len(result.images or []) + if gen_count <= 0: + return result + + result.provider_usage.append({ + "provider": request_info.get("_provider") or self.provider.get("name", "unknown"), + "provider_label": request_info.get("_provider_label") or self.provider.get("display_name") or self.provider.get("name", "unknown"), + "key_label": request_info.get("_key_label") or "未记录Key", + "gen": gen_count, + }) + return result + def _prepare_images_base64(self, params: Dict) -> Dict: """ Convert uploaded files to base64 data URLs for Chat API format. @@ -1074,6 +1105,7 @@ def execute(self, params: Dict, mode: str = "text2img", except Exception: pass + self._attach_provider_usage(result, request_info) return result except requests.Timeout: diff --git a/adapters/volcengine.py b/adapters/volcengine.py index 4886fe6..7a417e3 100644 --- a/adapters/volcengine.py +++ b/adapters/volcengine.py @@ -407,6 +407,24 @@ def execute(self, params: Dict, mode: str = "inpaint") -> APIResponse: if img_bytes: result.images.append(img_bytes) + self._attach_provider_usage(result) + return result + + def _attach_provider_usage(self, result: APIResponse) -> APIResponse: + """Attach safe Volcengine access-key image counts to successful image results.""" + if not result.success or not result.images: + return result + provider = self.provider.get("name", "volcengine") + result.provider_usage.append({ + "provider": provider, + "provider_label": self.provider.get("display_name") or provider, + "key_label": self._format_safe_key_label( + self.provider.get("access_key", ""), + index=1, + name="AccessKey", + ), + "gen": len(result.images), + }) return result def _poll_for_result(self, task_id: str, timeout: int = 120) -> APIResponse: diff --git a/dashboard.py b/dashboard.py index 6a3f797..6105a07 100644 --- a/dashboard.py +++ b/dashboard.py @@ -120,6 +120,7 @@ def compute_stats(records: list, date_filter: str = "today", "dur_avg": 0, "dur_min": 0, "dur_max": 0, "machines": [], "models": {}, "nodes": {}, "recent": [], "timeline": {}, "machine_list": [], + "providers_usage": [], } if not records: return empty @@ -183,6 +184,46 @@ def compute_stats(records: list, date_filter: str = "today", node = r.get("node", "unknown") node_map[node] = node_map.get(node, 0) + 1 + # Per-provider/key image counts. Only direct provider_usage records are used. + provider_map = {} + for r in records: + entries = r.get("provider_usage", []) + if not isinstance(entries, list): + continue + for entry in entries: + if not isinstance(entry, dict): + continue + try: + gen = int(entry.get("gen", 0)) + except (TypeError, ValueError): + continue + if gen <= 0: + continue + provider = str(entry.get("provider") or entry.get("provider_label") or "unknown") + provider_label = str(entry.get("provider_label") or provider) + key_label = str(entry.get("key_label") or "未记录Key") + if provider not in provider_map: + provider_map[provider] = { + "provider": provider, + "provider_label": provider_label, + "gen": 0, + "keys": {}, + } + p = provider_map[provider] + p["gen"] += gen + p["keys"][key_label] = p["keys"].get(key_label, 0) + gen + + providers_usage = [] + for provider_data in provider_map.values(): + keys = [ + {"key_label": key_label, "gen": gen} + for key_label, gen in provider_data.pop("keys").items() + ] + keys.sort(key=lambda x: (-x["gen"], x["key_label"])) + provider_data["keys"] = keys + providers_usage.append(provider_data) + providers_usage.sort(key=lambda x: (-x["gen"], x["provider_label"])) + # Timeline aggregation # If viewing a single day (today or date_exact), group by hour. # If viewing week/all, group by day. @@ -228,6 +269,7 @@ def compute_stats(records: list, date_filter: str = "today", "nodes": node_map, "recent": recent, "timeline": sorted_tl, + "providers_usage": providers_usage, } @@ -406,6 +448,32 @@ def save_machine_notes(data_dir: str, notes: dict): .recent-table::-webkit-scrollbar { width:6px; } .recent-table::-webkit-scrollbar-thumb { background:var(--border); border-radius:3px; } + /* Provider usage */ + .provider-usage-list { + display:grid; grid-template-columns:repeat(auto-fit, minmax(260px, 1fr)); + gap:12px; + } + .provider-card { + background:var(--card-bg); border:1px solid var(--border); border-radius:10px; + padding:14px; + } + .provider-card-head { + display:flex; justify-content:space-between; align-items:center; gap:10px; + font-weight:700; margin-bottom:10px; padding-bottom:8px; + border-bottom:1px solid var(--border); + } + .provider-total { color:var(--green); white-space:nowrap; } + .provider-key-row { + display:flex; justify-content:space-between; gap:10px; + font-size:0.82rem; color:var(--text-dim); padding:5px 0; + } + .provider-key-name { overflow:hidden; text-overflow:ellipsis; white-space:nowrap; } + .provider-key-gen { color:var(--text); font-weight:600; white-space:nowrap; } + .empty-block { + color:var(--text-dim); background:var(--card-bg); border:1px solid var(--border); + border-radius:10px; padding:18px; text-align:center; + } + /* Machine detail grid (Tab 2) */ .machine-grid { display:grid; grid-template-columns:repeat(auto-fill, minmax(260px, 1fr)); @@ -455,6 +523,7 @@ def save_machine_notes(data_dir: str, notes: dict):
@@ -511,7 +580,7 @@ def save_machine_notes(data_dir: str, notes: dict):| 时间 | 机器 | 节点 | 模型 | 批次 | -生成 | 保存 | 状态 (鼠标悬浮看原因) | 耗时 | +供应商 | 生成 | 保存 | 状态 (鼠标悬浮看原因) | 耗时 |
|---|