Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions ds4_kvstore.c
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,24 @@ static bool kv_cache_incoming_supersedes_continued(
return !strcmp(prefix_sha, e->sha);
}

static bool kv_cache_incoming_protects_prefix(
const ds4_kvstore_entry *e,
const ds4_kvstore_eviction_context *incoming) {
if (!e || !incoming || !incoming->protect_text) return false;
if (e->text_bytes == 0 || e->text_bytes > SIZE_MAX) return false;
if ((size_t)e->text_bytes > incoming->protect_text_len) return false;
if (e->model_id != incoming->model_id) return false;
if (incoming->reject_different_quant &&
e->quant_bits != incoming->quant_bits)
return false;
if (incoming->ctx_size > e->ctx_size) return false;

char prefix_sha[41];
ds4_kvstore_sha1_bytes_hex(incoming->protect_text,
(size_t)e->text_bytes, prefix_sha);
return !strcmp(prefix_sha, e->sha);
}

static bool kv_cache_reason_is_anchor(uint8_t reason) {
return reason == DS4_KVSTORE_REASON_COLD ||
reason == DS4_KVSTORE_REASON_EVICT ||
Expand Down Expand Up @@ -555,6 +573,7 @@ double ds4_kvstore_entry_eviction_score(
score *= KV_CACHE_CONTINUED_PREFIX_MIN_FACTOR +
KV_CACHE_CONTINUED_PREFIX_HIT_FACTOR * h;
}
if (kv_cache_incoming_protects_prefix(e, incoming)) score *= 1.0e12;
return score;
}

Expand Down Expand Up @@ -929,6 +948,7 @@ bool ds4_kvstore_store_live_prefix_text(ds4_kvstore *kc,
const char *cache_text_override,
uint8_t cache_text_ext,
const char *cache_text_key,
const char *protect_text,
const ds4_kvstore_trailer_hooks *hooks,
char *err,
size_t err_len) {
Expand Down Expand Up @@ -1044,6 +1064,8 @@ bool ds4_kvstore_store_live_prefix_text(ds4_kvstore *kc,
ds4_kvstore_eviction_context incoming = {
.text = text,
.text_len = text_len,
.protect_text = protect_text,
.protect_text_len = protect_text ? strlen(protect_text) : 0,
.model_id = (uint8_t)model_id,
.quant_bits = (uint8_t)quant_bits,
.ctx_size = (uint32_t)ds4_session_ctx(session),
Expand Down Expand Up @@ -1164,8 +1186,9 @@ bool ds4_kvstore_store_live_prefix(ds4_kvstore *kc,
char *err,
size_t err_len) {
return ds4_kvstore_store_live_prefix_text(kc, engine, session, tokens,
store_len, reason, NULL, 0, NULL,
hooks, err, err_len);
store_len, reason, NULL, 0, NULL,
NULL,
hooks, err, err_len);
}

bool ds4_kvstore_maybe_store_continued(ds4_kvstore *kc,
Expand Down
3 changes: 3 additions & 0 deletions ds4_kvstore.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ typedef struct {
typedef struct {
const char *text;
size_t text_len;
const char *protect_text;
size_t protect_text_len;
uint8_t model_id;
uint8_t quant_bits;
uint32_t ctx_size;
Expand Down Expand Up @@ -169,6 +171,7 @@ bool ds4_kvstore_store_live_prefix_text(ds4_kvstore *kc,
const char *cache_text_override,
uint8_t cache_text_ext,
const char *cache_text_key,
const char *protect_text,
const ds4_kvstore_trailer_hooks *hooks,
char *err,
size_t err_len);
Expand Down
80 changes: 74 additions & 6 deletions ds4_server.c
Original file line number Diff line number Diff line change
Expand Up @@ -8707,24 +8707,27 @@ static bool kv_cache_store_live_prefix_text(server *s, const ds4_tokens *tokens,
int store_len, const char *reason,
const char *cache_text_override,
uint8_t cache_text_ext,
const char *cache_text_key) {
const char *cache_text_key,
const char *protect_text) {
char err[160] = {0};
ds4_kvstore_trailer_hooks hooks = kv_cache_tool_map_hooks(s, NULL);
return ds4_kvstore_store_live_prefix_text(&s->kv, s->engine, s->session,
tokens, store_len, reason,
cache_text_override,
cache_text_ext,
cache_text_key,
protect_text,
&hooks, err, sizeof(err));
}

static bool kv_cache_store_live_prefix(server *s, const ds4_tokens *tokens,
int store_len, const char *reason) {
return kv_cache_store_live_prefix_text(s, tokens, store_len, reason,
NULL, 0, NULL);
NULL, 0, NULL, NULL);
}

static void kv_cache_store_current(server *s, const char *reason) {
static void kv_cache_store_current_protected(server *s, const char *reason,
const char *protect_text) {
const ds4_tokens *tokens = ds4_session_tokens(s->session);
if (!tokens) return;

Expand Down Expand Up @@ -8758,13 +8761,19 @@ static void kv_cache_store_current(server *s, const char *reason) {
* tokenizes only the visible suffix that follows this key. */
if (visible_text) {
kv_cache_store_live_prefix_text(s, tokens, tokens->len, reason,
visible_text, visible_ext, visible_key);
visible_text, visible_ext, visible_key,
protect_text);
free(visible_text);
} else {
kv_cache_store_live_prefix(s, tokens, tokens->len, reason);
kv_cache_store_live_prefix_text(s, tokens, tokens->len, reason,
NULL, 0, NULL, protect_text);
}
}

static void kv_cache_store_current(server *s, const char *reason) {
kv_cache_store_current_protected(s, reason, NULL);
}

static void kv_cache_note_store(kv_disk_cache *kc, int tokens) {
ds4_kvstore_note_store(kc, tokens);
}
Expand Down Expand Up @@ -10098,7 +10107,7 @@ static void generate_job(server *s, job *j) {
/* Loading a disk snapshot replaces the live Metal session. Persist the
* current checkpoint first, otherwise a cache hit for an older prefix
* would silently discard the newer conversation state. */
kv_cache_store_current(s, "evict");
kv_cache_store_current_protected(s, "evict", j->req.prompt_text);
}
if (cached == 0) {
disk_cached = kv_cache_try_load(s, &j->req, &effective_prompt,
Expand Down Expand Up @@ -15352,6 +15361,64 @@ static void test_kv_cache_eviction_prefers_superseded_continued_prefix(void) {
rmdir(dir);
}

static void test_kv_cache_eviction_protects_incoming_prefix_on_live_miss(void) {
char tmpl[] = "/tmp/ds4-kv-prefix-protect-test.XXXXXX";
char *dir = mkdtemp(tmpl);
TEST_ASSERT(dir != NULL);
if (!dir) return;

const char *prefix_text = "system: long stable prompt prefix";
const char *other_text = "unrelated checkpoint";
const char *old_live_text = "system: long stable prompt prefix\nassistant: failed streamed tail";
const char *incoming_prompt_text = "system: long stable prompt prefix\nuser: replayed visible tail";
const uint64_t protected_payload = 2000000u;
const uint64_t other_payload = 2048u;
const uint64_t extra_bytes = 4096u;

test_kv_text_stub_file(dir, prefix_text, KV_REASON_CONTINUED,
81920, protected_payload);
test_kv_text_stub_file(dir, other_text, KV_REASON_COLD,
1024, other_payload);

char prefix_sha[41], other_sha[41];
sha1_bytes_hex(prefix_text, strlen(prefix_text), prefix_sha);
sha1_bytes_hex(other_text, strlen(other_text), other_sha);
char prefix_name[44], other_name[44];
snprintf(prefix_name, sizeof(prefix_name), "%.40s.kv", prefix_sha);
snprintf(other_name, sizeof(other_name), "%.40s.kv", other_sha);
char *prefix_path = path_join(dir, prefix_name);
char *other_path = path_join(dir, other_name);

kv_disk_cache kc = {0};
kc.enabled = true;
kc.dir = xstrdup(dir);
kc.opt = kv_cache_default_options();
uint64_t protected_file_bytes =
KV_CACHE_FIXED_HEADER + 4u + strlen(prefix_text) + protected_payload;
kc.budget_bytes = extra_bytes + protected_file_bytes + 16u;
ds4_kvstore_eviction_context incoming = {
.text = old_live_text,
.text_len = strlen(old_live_text),
.protect_text = incoming_prompt_text,
.protect_text_len = strlen(incoming_prompt_text),
.model_id = 0,
.quant_bits = 2,
.ctx_size = 32768,
.reject_different_quant = false,
};
kv_cache_evict(&kc, NULL, extra_bytes, &incoming);

TEST_ASSERT(access(prefix_path, F_OK) == 0);
TEST_ASSERT(access(other_path, F_OK) != 0);

kv_cache_close(&kc);
unlink(prefix_path);
unlink(other_path);
free(prefix_path);
free(other_path);
rmdir(dir);
}

static void test_kv_cache_eviction_keeps_smaller_context_prefix(void) {
char tmpl[] = "/tmp/ds4-kv-prefix-ctx-test.XXXXXX";
char *dir = mkdtemp(tmpl);
Expand Down Expand Up @@ -15854,6 +15921,7 @@ static void ds4_server_unit_tests_run(void) {
test_kv_cache_eviction_makes_room_before_store();
test_kv_cache_eviction_ignores_oversize_incoming();
test_kv_cache_eviction_prefers_superseded_continued_prefix();
test_kv_cache_eviction_protects_incoming_prefix_on_live_miss();
test_kv_cache_eviction_keeps_smaller_context_prefix();
test_kv_cache_eviction_score_decays_stale_hits();
test_kv_cache_eviction_decayed_hits_tie_break_by_age();
Expand Down