-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcontext_utils.py
More file actions
259 lines (216 loc) · 8.93 KB
/
Copy pathcontext_utils.py
File metadata and controls
259 lines (216 loc) · 8.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
"""
Context management utilities: token estimation, message pruning, history summarisation.
"""
from __future__ import annotations
import logging
import re
from typing import Any
from config import PROVIDERS
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Matches inline base64 image data URIs so they can be stripped before token counting.
# The payload must be at least 100 base64 characters long, so short strings that merely
# start with "data:image/...;base64," are left alone.
_B64_IMG_RE = re.compile(r"data:image/[^;]+;base64,[A-Za-z0-9+/=]{100,}")
def estimate_tet_tokens(text: str) -> int:
    """
    Estimate the token count of a piece of text, erring on the high side.

    English averages roughly 4 characters per token, while code and German
    are closer to 2.5-3. Dividing by 3 over-estimates slightly, which keeps
    requests safely under the provider's limit.

    Falsy input (empty string, None, 0) counts as zero tokens; anything else
    is converted with ``str()`` before measuring.
    """
    return len(str(text)) // 3 if text else 0
def estimate_tokens(content: str | list[dict[str, Any]]) -> int:
    """
    Estimate the token cost of a message's content.

    Handles both legacy strings and multimodal part lists (Gradio 6 style).

    * Strings: inline base64 image data URIs are replaced by a short
      placeholder before counting, and each one is charged a flat 1000
      tokens, so a generated image doesn't inflate the estimate to 100k+.
      The remaining text is counted at 3 characters per token.
    * Lists: every ``{"type": "text"}`` part is counted exactly like a plain
      string (including the base64 stripping), and every
      ``{"type": "image_url"}`` part is charged a flat 1000 tokens. Parts
      that are not dicts, and text parts whose ``"text"`` is missing, empty
      or not a string, contribute nothing instead of raising.
    * Anything else (e.g. ``None``) counts as 0 tokens.
    """
    if isinstance(content, str):
        stripped, n_images = _B64_IMG_RE.subn("[IMAGE]", content)
        return len(stripped) // 3 + n_images * 1000

    if not isinstance(content, list):
        return 0

    total = 0
    for part in content:
        # Malformed parts (plain strings, None, ...) must not crash the estimate.
        if not isinstance(part, dict):
            continue
        part_type = part.get("type")
        if part_type == "text":
            text = part.get("text")
            if isinstance(text, str) and text:
                # Reuse the string branch so embedded data URIs are handled
                # the same way as in legacy string content.
                total += estimate_tokens(text)
        elif part_type == "image_url":
            total += 1000
    return total
def get_model_context_limit(provider: str, model: str) -> int:
    """
    Return the context window size (in tokens) for a provider/model pair.

    If the provider's entry in ``PROVIDERS`` has a ``"context_limits"``
    mapping, the model is looked up there. Otherwise a built-in per-provider
    default is used. In both cases an unknown key falls back to 4096.
    """
    fallback_limit = 4096
    provider_cfg: dict = PROVIDERS.get(provider, {})  # type: ignore[assignment]

    if "context_limits" not in provider_cfg:
        # No per-model table configured: use the coarse per-provider default.
        provider_defaults = {
            "Scaleway": 32000,
            "Nebius": 128000,
            "Mistral": 128000,
            "OpenRouter": 128000,
            "Groq": 8192,
            "Poe": 128000,
            "Deepgram": 16384,
            "AssemblyAI": 16384,
        }
        return provider_defaults.get(provider, fallback_limit)

    return provider_cfg["context_limits"].get(model, fallback_limit)
def check_content_fits_context(
    content: str, provider: str, model: str, reserve_tokens: int = 1000
) -> tuple[bool, int, int]:
    """
    Tell whether ``content`` fits in the model's context window.

    ``reserve_tokens`` is subtracted from the model's limit before comparing,
    leaving headroom for the response.

    Returns a ``(fits, estimated_tokens, max_tokens)`` tuple, where
    ``max_tokens`` is the model's full limit (the reserve is not subtracted
    from the reported value).
    """
    token_estimate = estimate_tokens(content)
    window_size = get_model_context_limit(provider, model)
    fits = token_estimate <= window_size - reserve_tokens
    return fits, token_estimate, window_size
def summarize_turns(turns: list[dict[str, Any]], client: Any, model: str) -> str:
    """
    Ask the LLM for a brief factual summary of old conversation turns.

    The summary is used as a compact replacement for dropped history.

    Each turn is flattened to ``ROLE: text`` (at most 400 characters of text
    per turn). Multimodal list content is reduced to its text parts, inline
    base64 images become ``[image]``, and leading tool-status banner lines
    (an emoji followed by an italic ``*...*`` line and a blank line) are
    removed.

    Turns with missing or ``None`` content, ``None`` text parts, non-dict
    parts or a missing role are tolerated, so one malformed turn no longer
    makes the whole summary fail.

    Args:
        turns: Message dicts with ``"role"`` and ``"content"`` keys.
        client: Object exposing ``chat.completions.create(...)``; ``None``
            disables summarisation.
        model: Model identifier passed through to the client.

    Returns:
        The summary text, or ``""`` if there is nothing to summarise, no
        client was given, or the request failed (failures are logged, never
        raised).
    """
    if not turns or client is None:
        return ""
    try:
        lines = []
        for turn in turns:
            role = str(turn.get("role") or "?")
            content = turn.get("content")
            if content is None:
                # e.g. assistant tool-call messages carry no text content
                content = ""
            elif isinstance(content, list):
                content = " ".join(
                    part.get("text") or ""
                    for part in content
                    if isinstance(part, dict) and part.get("type") == "text"
                )
            elif not isinstance(content, str):
                content = str(content)
            content = re.sub(r"data:image/[^;]+;base64,[A-Za-z0-9+/=]{100,}", "[image]", content)
            content = re.sub(
                r"^(?:🔍|🌐|🧮|🎨|🐍|🔬) \*[^\n]*\*\n\n", "", content, flags=re.MULTILINE
            )
            lines.append(f"{role.upper()}: {content[:400]}")
        transcript = "\n".join(lines)
        resp = client.chat.completions.create(
            model=model,
            messages=[
                {
                    "role": "system",
                    "content": (
                        "Summarise the following conversation excerpt in 3–5 bullet points. "
                        "Be factual, concise, in the same language as the conversation. "
                        "Preserve key facts, numbers, and decisions."
                    ),
                },
                {"role": "user", "content": transcript},
            ],
            max_tokens=300,
            stream=False,
        )
        summary = resp.choices[0].message.content or ""
        print(f"[CTX] 📝 Summarised {len(turns)} old turns → {len(summary)} chars")
        return summary
    except Exception as e:
        # Deliberately best-effort: summarisation must never break the chat flow.
        logger.warning("[CTX] summarize_turns failed: %s", e)
        return ""
def _message_preview(content: Any, width: int = 30) -> str:
    """Return a short preview of a message's content for the CLI log."""
    if isinstance(content, list):
        text_parts = [
            part.get("text") or ""
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        ]
        return (text_parts[0][:width] if text_parts else "[Image/Other]") + "..."
    return str(content)[:width].replace("\n", " ") + "..."


def prune_messages(
    messages: list[dict[str, Any]], model_limit: int, max_output_tokens: int = 1000
) -> list[dict[str, Any]]:
    """
    Trim conversation history to fit the model's input budget, with verbose CLI logging.

    The input budget is ``model_limit - max_output_tokens - 500`` (safety
    buffer); if that comes out below 1000 it is forced to 2000.

    Selection rules:
      * A leading system prompt is always kept.
      * The last message is always kept, even when it is the only message
        and exceeds the budget.
      * Remaining messages are considered newest-first. An assistant message
        is accepted or rejected together with the message directly before it,
        so turn pairs are never split. A message that does not fit is
        skipped, and older messages are still considered.

    Args:
        messages: Chat messages (dicts with ``"role"`` and ``"content"``).
        model_limit: Total context window of the model, in tokens.
        max_output_tokens: Tokens reserved for the model's response.

    Returns:
        The kept messages, in their original order.
    """
    safety_buffer = 500
    safe_input_limit = model_limit - max_output_tokens - safety_buffer
    if safe_input_limit < 1000:
        safe_input_limit = 2000
        print(f"[CTX] ⚠️ Warning: calculated input limit was too low. Forced to {safe_input_limit}")
    print("\n[CTX] 📊 CONTEXT CALCULATION:")
    print(f"[CTX] Model Limit: {model_limit}")
    print(f"[CTX] Output Reserve: -{max_output_tokens}")
    print(f"[CTX] Safety Buffer: -{safety_buffer}")
    print("[CTX] ==========================")
    print(f"[CTX] AVAILABLE INPUT BUDGET: {safe_input_limit} tokens")

    current_tokens = 0
    kept: set[int] = set()  # sets give O(1) membership checks in the loop below

    # Always keep system prompt
    if messages and messages[0].get("role") == "system":
        t = estimate_tokens(messages[0].get("content", ""))
        current_tokens += t
        kept.add(0)
        print(f"[CTX] + {t:5d} tok | (Required) System Prompt")

    # Always keep the latest message. This also covers a conversation that
    # consists of a single non-system message, which must never be dropped.
    if messages:
        last_idx = len(messages) - 1
        if last_idx not in kept:
            t = estimate_tokens(messages[last_idx].get("content", ""))
            current_tokens += t
            kept.add(last_idx)
            print(f"[CTX] + {t:5d} tok | (Required) Latest User Message")

    # Fill budget with history newest→oldest, keeping turn pairs together
    skipped: set[int] = set()  # partners already decided together with their assistant reply
    for i in range(len(messages) - 1, -1, -1):
        if i in kept or i in skipped:
            continue
        msg = messages[i]
        role = msg.get("role", "?")
        content = msg.get("content", "")
        preview = _message_preview(content)
        msg_tokens = estimate_tokens(content)

        # An assistant reply is budgeted together with the message before it.
        pair_idx = None
        pair_tokens = 0
        if role == "assistant" and i > 0 and i - 1 not in kept and i - 1 not in skipped:
            pair_tokens = estimate_tokens(messages[i - 1].get("content", ""))
            pair_idx = i - 1

        total_needed = msg_tokens + pair_tokens
        if current_tokens + total_needed <= safe_input_limit:
            current_tokens += total_needed
            kept.add(i)
            if pair_idx is not None:
                kept.add(pair_idx)
                skipped.add(pair_idx)
            print(f"[CTX] + {msg_tokens:5d} tok | (Kept) {role}: {preview}")
        else:
            print(f"[CTX] - {msg_tokens:5d} tok | (SKIP) {role}: {preview} [budget]")
            if pair_idx is not None:
                # Dropping the reply drops its partner too, so no orphan remains.
                skipped.add(pair_idx)

    final_messages = [messages[i] for i in sorted(kept)]
    print("[CTX] ==========================")
    print(f"[CTX] Total Used: {current_tokens} / {safe_input_limit}")
    print(
        f"[CTX] Messages: {len(final_messages)} kept, {len(messages) - len(final_messages)} dropped.\n"
    )
    return final_messages
def _split_long_paragraph(para: str, max_chars: int) -> list[str]:
    """Losslessly cut one oversized paragraph into pieces of at most ``max_chars``."""
    pieces: list[str] = []
    buf = ""
    # Split after every ". " so each sentence keeps its own terminator.
    for sentence in re.split(r"(?<=\. )", para):
        # A single sentence longer than the limit has to be hard-sliced.
        while len(sentence) > max_chars:
            if buf:
                pieces.append(buf)
                buf = ""
            pieces.append(sentence[:max_chars])
            sentence = sentence[max_chars:]
        if len(buf) + len(sentence) > max_chars:
            pieces.append(buf)
            buf = sentence
        else:
            buf += sentence
    if buf:
        pieces.append(buf)
    return pieces


def split_content_into_chunks(text: str, max_tokens: int = 4000, overlap: int = 200) -> list[str]:
    """
    Split text into overlapping chunks that fit within a token limit.

    Sizes are converted at 3 characters per token, the same conservative
    ratio the token estimators in this module use, so a chunk produced here
    is never estimated above ``max_tokens``.

    Paragraphs (separated by blank lines) are packed greedily into chunks. A
    paragraph that is itself larger than the limit is cut at sentence
    boundaries (and hard-sliced if a single sentence is still too long), so
    no text is ever dropped. Every chunk after the first starts with up to
    ``overlap`` tokens from the tail of the previous chunk, shortened when
    necessary so the chunk stays within the limit.

    Args:
        text: The text to split.
        max_tokens: Maximum estimated size of each chunk, in tokens.
        overlap: Tokens of the previous chunk's tail to repeat at the start
            of the next chunk; values <= 0 disable the overlap.

    Returns:
        The list of chunks in order; empty if ``text`` is empty.
    """
    if not text:
        return []

    chars_per_token = 3
    max_chars = max(1, max_tokens * chars_per_token)
    overlap_chars = max(0, overlap * chars_per_token)

    # Normalise the input into pieces that individually fit in one chunk.
    pieces: list[str] = []
    for para in text.split("\n\n"):
        if len(para) <= max_chars:
            pieces.append(para)
        else:
            pieces.extend(_split_long_paragraph(para, max_chars))

    chunks: list[str] = []
    current_chunk: list[str] = []
    current_size = 0  # exact length of "\n\n".join(current_chunk)
    for piece in pieces:
        added = len(piece) + (2 if current_chunk else 0)
        if current_chunk and current_size + added > max_chars:
            chunks.append("\n\n".join(current_chunk))
            # Only carry over as much of the previous tail as still fits
            # next to the new piece and its separator.
            tail_len = min(overlap_chars, max_chars - len(piece) - 2)
            if tail_len > 0:
                overlap_text = chunks[-1][-tail_len:]
                current_chunk = [overlap_text, piece]
                current_size = len(overlap_text) + 2 + len(piece)
            else:
                current_chunk = [piece]
                current_size = len(piece)
        else:
            current_chunk.append(piece)
            current_size += added

    if current_chunk:
        chunks.append("\n\n".join(current_chunk))
    return chunks