diff --git a/ds4_kvstore.c b/ds4_kvstore.c index 6b663b51b..e27e7a28e 100644 --- a/ds4_kvstore.c +++ b/ds4_kvstore.c @@ -523,6 +523,24 @@ static bool kv_cache_incoming_supersedes_continued( return !strcmp(prefix_sha, e->sha); } +static bool kv_cache_incoming_protects_prefix( + const ds4_kvstore_entry *e, + const ds4_kvstore_eviction_context *incoming) { + if (!e || !incoming || !incoming->protect_text) return false; + if (e->text_bytes == 0 || e->text_bytes > SIZE_MAX) return false; + if ((size_t)e->text_bytes > incoming->protect_text_len) return false; + if (e->model_id != incoming->model_id) return false; + if (incoming->reject_different_quant && + e->quant_bits != incoming->quant_bits) + return false; + if (incoming->ctx_size > e->ctx_size) return false; + + char prefix_sha[41]; + ds4_kvstore_sha1_bytes_hex(incoming->protect_text, + (size_t)e->text_bytes, prefix_sha); + return !strcmp(prefix_sha, e->sha); +} + static bool kv_cache_reason_is_anchor(uint8_t reason) { return reason == DS4_KVSTORE_REASON_COLD || reason == DS4_KVSTORE_REASON_EVICT || @@ -555,6 +573,7 @@ double ds4_kvstore_entry_eviction_score( score *= KV_CACHE_CONTINUED_PREFIX_MIN_FACTOR + KV_CACHE_CONTINUED_PREFIX_HIT_FACTOR * h; } + if (kv_cache_incoming_protects_prefix(e, incoming)) score *= 1.0e12; return score; } @@ -929,6 +948,7 @@ bool ds4_kvstore_store_live_prefix_text(ds4_kvstore *kc, const char *cache_text_override, uint8_t cache_text_ext, const char *cache_text_key, + const char *protect_text, const ds4_kvstore_trailer_hooks *hooks, char *err, size_t err_len) { @@ -1044,6 +1064,8 @@ bool ds4_kvstore_store_live_prefix_text(ds4_kvstore *kc, ds4_kvstore_eviction_context incoming = { .text = text, .text_len = text_len, + .protect_text = protect_text, + .protect_text_len = protect_text ? strlen(protect_text) : 0, .model_id = (uint8_t)model_id, .quant_bits = (uint8_t)quant_bits, .ctx_size = (uint32_t)ds4_session_ctx(session), @@ -1164,8 +1186,9 @@ bool ds4_kvstore_store_live_prefix(ds4_kvstore *kc, char *err, size_t err_len) { return ds4_kvstore_store_live_prefix_text(kc, engine, session, tokens, - store_len, reason, NULL, 0, NULL, - hooks, err, err_len); + store_len, reason, NULL, 0, NULL, + NULL, + hooks, err, err_len); } bool ds4_kvstore_maybe_store_continued(ds4_kvstore *kc, diff --git a/ds4_kvstore.h b/ds4_kvstore.h index 3a01f586f..c426070e8 100644 --- a/ds4_kvstore.h +++ b/ds4_kvstore.h @@ -82,6 +82,8 @@ typedef struct { typedef struct { const char *text; size_t text_len; + const char *protect_text; + size_t protect_text_len; uint8_t model_id; uint8_t quant_bits; uint32_t ctx_size; @@ -169,6 +171,7 @@ bool ds4_kvstore_store_live_prefix_text(ds4_kvstore *kc, const char *cache_text_override, uint8_t cache_text_ext, const char *cache_text_key, + const char *protect_text, const ds4_kvstore_trailer_hooks *hooks, char *err, size_t err_len); diff --git a/ds4_server.c b/ds4_server.c index 34a9d5084..eac97d427 100644 --- a/ds4_server.c +++ b/ds4_server.c @@ -8707,7 +8707,8 @@ static bool kv_cache_store_live_prefix_text(server *s, const ds4_tokens *tokens, int store_len, const char *reason, const char *cache_text_override, uint8_t cache_text_ext, - const char *cache_text_key) { + const char *cache_text_key, + const char *protect_text) { char err[160] = {0}; ds4_kvstore_trailer_hooks hooks = kv_cache_tool_map_hooks(s, NULL); return ds4_kvstore_store_live_prefix_text(&s->kv, s->engine, s->session, @@ -8715,16 +8716,18 @@ static bool kv_cache_store_live_prefix_text(server *s, const ds4_tokens *tokens, cache_text_override, cache_text_ext, cache_text_key, + protect_text, &hooks, err, sizeof(err)); } static bool kv_cache_store_live_prefix(server *s, const ds4_tokens *tokens, int store_len, const char *reason) { return kv_cache_store_live_prefix_text(s, tokens, store_len, reason, - NULL, 0, NULL); + NULL, 0, NULL, NULL); } -static void kv_cache_store_current(server *s, const char *reason) { +static void kv_cache_store_current_protected(server *s, const char *reason, + const char *protect_text) { const ds4_tokens *tokens = ds4_session_tokens(s->session); if (!tokens) return; @@ -8758,13 +8761,19 @@ static void kv_cache_store_current(server *s, const char *reason) { * tokenizes only the visible suffix that follows this key. */ if (visible_text) { kv_cache_store_live_prefix_text(s, tokens, tokens->len, reason, - visible_text, visible_ext, visible_key); + visible_text, visible_ext, visible_key, + protect_text); free(visible_text); } else { - kv_cache_store_live_prefix(s, tokens, tokens->len, reason); + kv_cache_store_live_prefix_text(s, tokens, tokens->len, reason, + NULL, 0, NULL, protect_text); } } +static void kv_cache_store_current(server *s, const char *reason) { + kv_cache_store_current_protected(s, reason, NULL); +} + static void kv_cache_note_store(kv_disk_cache *kc, int tokens) { ds4_kvstore_note_store(kc, tokens); } @@ -10098,7 +10107,7 @@ static void generate_job(server *s, job *j) { /* Loading a disk snapshot replaces the live Metal session. Persist the * current checkpoint first, otherwise a cache hit for an older prefix * would silently discard the newer conversation state. */ - kv_cache_store_current(s, "evict"); + kv_cache_store_current_protected(s, "evict", j->req.prompt_text); } if (cached == 0) { disk_cached = kv_cache_try_load(s, &j->req, &effective_prompt, @@ -15352,6 +15361,64 @@ static void test_kv_cache_eviction_prefers_superseded_continued_prefix(void) { rmdir(dir); } +static void test_kv_cache_eviction_protects_incoming_prefix_on_live_miss(void) { + char tmpl[] = "/tmp/ds4-kv-prefix-protect-test.XXXXXX"; + char *dir = mkdtemp(tmpl); + TEST_ASSERT(dir != NULL); + if (!dir) return; + + const char *prefix_text = "system: long stable prompt prefix"; + const char *other_text = "unrelated checkpoint"; + const char *old_live_text = "system: long stable prompt prefix\nassistant: failed streamed tail"; + const char *incoming_prompt_text = "system: long stable prompt prefix\nuser: replayed visible tail"; + const uint64_t protected_payload = 2000000u; + const uint64_t other_payload = 2048u; + const uint64_t extra_bytes = 4096u; + + test_kv_text_stub_file(dir, prefix_text, KV_REASON_CONTINUED, + 81920, protected_payload); + test_kv_text_stub_file(dir, other_text, KV_REASON_COLD, + 1024, other_payload); + + char prefix_sha[41], other_sha[41]; + sha1_bytes_hex(prefix_text, strlen(prefix_text), prefix_sha); + sha1_bytes_hex(other_text, strlen(other_text), other_sha); + char prefix_name[44], other_name[44]; + snprintf(prefix_name, sizeof(prefix_name), "%.40s.kv", prefix_sha); + snprintf(other_name, sizeof(other_name), "%.40s.kv", other_sha); + char *prefix_path = path_join(dir, prefix_name); + char *other_path = path_join(dir, other_name); + + kv_disk_cache kc = {0}; + kc.enabled = true; + kc.dir = xstrdup(dir); + kc.opt = kv_cache_default_options(); + uint64_t protected_file_bytes = + KV_CACHE_FIXED_HEADER + 4u + strlen(prefix_text) + protected_payload; + kc.budget_bytes = extra_bytes + protected_file_bytes + 16u; + ds4_kvstore_eviction_context incoming = { + .text = old_live_text, + .text_len = strlen(old_live_text), + .protect_text = incoming_prompt_text, + .protect_text_len = strlen(incoming_prompt_text), + .model_id = 0, + .quant_bits = 2, + .ctx_size = 32768, + .reject_different_quant = false, + }; + kv_cache_evict(&kc, NULL, extra_bytes, &incoming); + + TEST_ASSERT(access(prefix_path, F_OK) == 0); + TEST_ASSERT(access(other_path, F_OK) != 0); + + kv_cache_close(&kc); + unlink(prefix_path); + unlink(other_path); + free(prefix_path); + free(other_path); + rmdir(dir); +} + static void test_kv_cache_eviction_keeps_smaller_context_prefix(void) { char tmpl[] = "/tmp/ds4-kv-prefix-ctx-test.XXXXXX"; char *dir = mkdtemp(tmpl); @@ -15854,6 +15921,7 @@ static void ds4_server_unit_tests_run(void) { test_kv_cache_eviction_makes_room_before_store(); test_kv_cache_eviction_ignores_oversize_incoming(); test_kv_cache_eviction_prefers_superseded_continued_prefix(); + test_kv_cache_eviction_protects_incoming_prefix_on_live_miss(); test_kv_cache_eviction_keeps_smaller_context_prefix(); test_kv_cache_eviction_score_decays_stale_hits(); test_kv_cache_eviction_decayed_hits_tie_break_by_age();