From cdbd81faa5e82f22acca598f432d000b729994d0 Mon Sep 17 00:00:00 2001 From: Gareth Ellis Date: Thu, 4 Jun 2026 13:01:52 +0200 Subject: [PATCH 1/2] Reduce metrics store HTTP round-trips by batching per-task queries Introduce private query-builder methods and static response parsers to EsMetricsStore, plus an msearch() method on both EsClient and EsMetricsStore. GlobalStatsCalculator dispatches to _call_batched() when the store supports msearch, collapsing all per-task queries (~14 per task) into a single _msearch HTTP call. Cluster-level metrics continue to use the sequential helpers so there is a single source of truth for those metric names. GlobalStats gains a cluster_name field for future multi-cluster use. Co-Authored-By: Claude Sonnet 4.6 --- esrally/metrics.py | 416 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 319 insertions(+), 97 deletions(-) diff --git a/esrally/metrics.py b/esrally/metrics.py index ec9f5212f..004415732 100644 --- a/esrally/metrics.py +++ b/esrally/metrics.py @@ -119,6 +119,9 @@ def index(self, *, index, item, id=None, use_data_streams): def search(self, index, body): return self.guarded(self._client.search, index=index, body=body) + def msearch(self, body): + return self.guarded(self._client.msearch, body=body) + def guarded(self, target, *args, **kwargs): # pylint: disable=import-outside-toplevel import elasticsearch @@ -1244,133 +1247,58 @@ def _add(self, doc): with self._docs_lock: self._docs.append(doc) - def _get(self, name, task, operation_type, sample_type, node_name, mapper): - query = { - "query": self._query_by_name(name, task, operation_type, sample_type, node_name), - "track_total_hits": True, - "size": 10000, - } + def _get(self, name, task, operation_type, sample_type, node_name, cluster_name, mapper): + query = self._build_values_query(name, task, operation_type, sample_type, node_name, cluster_name) self.logger.debug("Issuing get against index=[%s], query=[%s].", self._index_handler.index_name(self._race_timestamp), query) result = self._client.search(index=self._index_handler.index_name(self._race_timestamp), body=query) es_count = result["hits"]["total"]["value"] self.logger.debug("Metrics query found [%s] results.", es_count) if es_count != len(result["hits"]["hits"]): self.logger.warning("Metrics query returned [%d] out of [%s] matching docs.", len(result["hits"]["hits"]), es_count) - return [mapper(v["_source"]) for v in result["hits"]["hits"]] + return self._parse_values_response(result, mapper) def get_one( - self, name, sample_type=None, node_name=None, task=None, mapper=lambda doc: doc["value"], sort_key=None, sort_reverse=False + self, name, sample_type=None, node_name=None, task=None, cluster_name=None, mapper=lambda doc: doc["value"], sort_key=None, sort_reverse=False ): - order = "desc" if sort_reverse else "asc" - query = { - "query": self._query_by_name(name, task, None, sample_type, node_name), - "size": 1, - } - if sort_key: - query["sort"] = [{sort_key: {"order": order}}] + query = self._build_one_query(name, task, sample_type, node_name, cluster_name, sort_key, sort_reverse) self.logger.debug("Issuing get against index=[%s], query=[%s].", self._index_handler.index_name(self._race_timestamp), query) result = self._client.search(index=self._index_handler.index_name(self._race_timestamp), body=query) - hits = result["hits"]["total"] - # Elasticsearch 7.0+ - if isinstance(hits, dict): - hits = hits["value"] - self.logger.debug("Metrics query produced [%s] results.", hits) - if hits > 0: - return mapper(result["hits"]["hits"][0]["_source"]) - else: - return None + self.logger.debug("Metrics query produced [%s] results.", result["hits"]["total"]) + return self._parse_one_response(result, mapper) - def get_error_rate(self, task, operation_type=None, sample_type=None): - query = { - "query": self._query_by_name("service_time", task, operation_type, sample_type, None), - "size": 0, - "aggs": { - "error_rate": { - "terms": { - "field": "meta.success", - }, - }, - }, - } + def get_error_rate(self, task, operation_type=None, sample_type=None, cluster_name=None): + query = self._build_error_rate_query(task, operation_type, sample_type, cluster_name) self.logger.debug( "Issuing get_error_rate against index=[%s], query=[%s]", self._index_handler.index_name(self._race_timestamp), query ) result = self._client.search(index=self._index_handler.index_name(self._race_timestamp), body=query) - buckets = result["aggregations"]["error_rate"]["buckets"] - self.logger.debug("Query returned [%d] buckets.", len(buckets)) - count_success = 0 - count_errors = 0 - for bucket in buckets: - k = bucket["key_as_string"] - doc_count = int(bucket["doc_count"]) - self.logger.debug("Processing key [%s] with [%d] docs.", k, doc_count) - if k == "true": - count_success = doc_count - elif k == "false": - count_errors = doc_count - else: - self.logger.warning("Unrecognized bucket key [%s] with [%d] docs.", k, doc_count) + self.logger.debug("Query returned [%d] buckets.", len(result["aggregations"]["error_rate"]["buckets"])) + return self._parse_error_rate_response(result) - if count_errors == 0: - return 0.0 - elif count_success == 0: - return 1.0 - else: - return count_errors / (count_errors + count_success) - - def get_stats(self, name, task=None, operation_type=None, sample_type=None): + def get_stats(self, name, task=None, operation_type=None, sample_type=None, cluster_name=None): """ Gets standard statistics for the given metric name. :return: A metric_stats structure. For details please refer to https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations-metrics-stats-aggregation.html """ - query = { - "query": self._query_by_name(name, task, operation_type, sample_type, None), - "size": 0, - "aggs": { - "metric_stats": { - "stats": { - "field": "value", - }, - }, - }, - } + query = self._build_stats_query(name, task, operation_type, sample_type, cluster_name) self.logger.debug("Issuing get_stats against index=[%s], query=[%s]", self._index_handler.index_name(self._race_timestamp), query) result = self._client.search(index=self._index_handler.index_name(self._race_timestamp), body=query) - return result["aggregations"]["metric_stats"] + return self._parse_stats_response(result) - def get_percentiles(self, name, task=None, operation_type=None, sample_type=None, percentiles=None): + def get_percentiles(self, name, task=None, operation_type=None, sample_type=None, percentiles=None, cluster_name=None): if percentiles is None: percentiles = [99, 99.9, 100] - query = { - "query": self._query_by_name(name, task, operation_type, sample_type, None), - "size": 0, - "aggs": { - "percentile_stats": { - "percentiles": { - "field": "value", - "percents": percentiles, - }, - }, - }, - } + query = self._build_percentiles_query(name, task, operation_type, sample_type, percentiles, cluster_name) self.logger.debug( "Issuing get_percentiles against index=[%s], query=[%s]", self._index_handler.index_name(self._race_timestamp), query ) result = self._client.search(index=self._index_handler.index_name(self._race_timestamp), body=query) - hits = result["hits"]["total"] - # Elasticsearch 7.0+ - if isinstance(hits, dict): - hits = hits["value"] - self.logger.debug("get_percentiles produced %d hits", hits) - if hits > 0: - raw = result["aggregations"]["percentile_stats"]["values"] - return collections.OrderedDict(sorted(raw.items(), key=lambda t: float(t[0]))) - else: - return None + self.logger.debug("get_percentiles produced %d hits", result["hits"]["total"]) + return self._parse_percentiles_response(result) - def _query_by_name(self, name, task, operation_type, sample_type, node_name): + def _query_by_name(self, name, task, operation_type, sample_type, node_name, cluster_name=None): q = { "bool": { "filter": [ @@ -1421,6 +1349,113 @@ def _query_by_name(self, name, task, operation_type, sample_type, node_name): ) return q + # ------------------------------------------------------------------ + # Private query builders – return a query body dict ready for search + # ------------------------------------------------------------------ + + def _build_values_query(self, name, task, operation_type, sample_type, node_name, cluster_name): + return { + "query": self._query_by_name(name, task, operation_type, sample_type, node_name, cluster_name), + "track_total_hits": True, + "size": 10000, + } + + def _build_one_query(self, name, task, sample_type, node_name, cluster_name, sort_key=None, sort_reverse=False): + query = { + "query": self._query_by_name(name, task, None, sample_type, node_name, cluster_name), + "size": 1, + } + if sort_key: + query["sort"] = [{sort_key: {"order": "desc" if sort_reverse else "asc"}}] + return query + + def _build_stats_query(self, name, task, operation_type, sample_type, cluster_name): + return { + "query": self._query_by_name(name, task, operation_type, sample_type, None, cluster_name), + "size": 0, + "aggs": {"metric_stats": {"stats": {"field": "value"}}}, + } + + def _build_error_rate_query(self, task, operation_type, sample_type, cluster_name): + return { + "query": self._query_by_name("service_time", task, operation_type, sample_type, None, cluster_name), + "size": 0, + "aggs": {"error_rate": {"terms": {"field": "meta.success"}}}, + } + + def _build_percentiles_query(self, name, task, operation_type, sample_type, percents, cluster_name): + return { + "query": self._query_by_name(name, task, operation_type, sample_type, None, cluster_name), + "size": 0, + "aggs": {"percentile_stats": {"percentiles": {"field": "value", "percents": percents}}}, + } + + # ------------------------------------------------------------------ + # Static response parsers – extract results from a search response + # ------------------------------------------------------------------ + + @staticmethod + def _parse_values_response(result, mapper): + return [mapper(v["_source"]) for v in result["hits"]["hits"]] + + @staticmethod + def _parse_one_response(result, mapper): + hits = result["hits"]["total"] + if isinstance(hits, dict): + hits = hits["value"] + return mapper(result["hits"]["hits"][0]["_source"]) if hits > 0 else None + + @staticmethod + def _parse_stats_response(result): + return result["aggregations"]["metric_stats"] + + @staticmethod + def _parse_error_rate_response(result): + buckets = result["aggregations"]["error_rate"]["buckets"] + count_success = count_errors = 0 + for bucket in buckets: + k = bucket["key_as_string"] + n = int(bucket["doc_count"]) + if k == "true": + count_success = n + elif k == "false": + count_errors = n + if count_errors == 0: + return 0.0 + if count_success == 0: + return 1.0 + return count_errors / (count_errors + count_success) + + @staticmethod + def _parse_percentiles_response(result): + hits = result["hits"]["total"] + if isinstance(hits, dict): + hits = hits["value"] + if hits > 0: + raw = result["aggregations"]["percentile_stats"]["values"] + return collections.OrderedDict(sorted(raw.items(), key=lambda t: float(t[0]))) + return None + + # ------------------------------------------------------------------ + # Multi-search: execute many queries in one HTTP round-trip + # ------------------------------------------------------------------ + + def msearch(self, requests): + """Execute multiple search queries in a single ``_msearch`` HTTP call. + + :param requests: A list of ``(query_body, response_mapper)`` pairs. + :return: A list of mapped results in the same order as *requests*. + """ + if not requests: + return [] + index = self._index_handler.index_name(self._race_timestamp) + body = [] + for query_body, _ in requests: + body.append({"index": index}) + body.append(query_body) + response = self._client.msearch(body=body) + return [mapper(r) for (_, mapper), r in zip(requests, response["responses"])] + def to_externalizable(self, clear=False): # no need for an externalizable representation - stores everything directly return None @@ -2402,8 +2437,17 @@ def __init__(self, store, track, challenge): self.track = track self.challenge = challenge - def __call__(self): - result = GlobalStats() + # Full percentile set — used in _call_batched so we issue one request regardless of sample size, + # then trim the response to the correct subset when processing. + _MAX_PERCENTS = [50, 90, 99, 99.9, 99.99, 100] + + def __call__(self, cluster_name=None): + if hasattr(self.store, "msearch"): + return self._call_batched(cluster_name) + return self._call_sequential(cluster_name) + + def _call_sequential(self, cluster_name=None): + result = GlobalStats(cluster_name=cluster_name) for tasks in self.challenge.schedule: for task in tasks: @@ -2491,6 +2535,183 @@ def __call__(self): return result + def _call_batched(self, cluster_name=None): + """Batch all per-task queries into a single ``_msearch`` HTTP call, then + call the existing sequential helpers for cluster-level metrics so there + is a single source of truth for those metric names.""" + store = self.store + requests = [] # list of (query_body, response_mapper) + normal = SampleType.Normal + + def add(query_body, mapper): + idx = len(requests) + requests.append((query_body, mapper)) + return idx + + # --- Per-task request slots --- + task_entries = [] # list of (task_obj, task_name, op_type, slot_dict) + for tasks_in_step in self.challenge.schedule: + for task in tasks_in_step: + t = task.name + op = task.operation.type + s = {} + s["error_rate"] = add( + store._build_error_rate_query(t, op, normal, cluster_name), + EsMetricsStore._parse_error_rate_response, + ) + s["duration"] = add( + store._build_one_query("service_time", t, None, None, cluster_name, sort_key="relative-time"), + lambda r: EsMetricsStore._parse_one_response(r, lambda doc: doc["relative-time"]), + ) + # throughput: one stats query covers min/mean/max; separate percentiles[50] for median + s["tput_stats"] = add( + store._build_stats_query("throughput", t, op, normal, cluster_name), + EsMetricsStore._parse_stats_response, + ) + s["tput_pcts"] = add( + store._build_percentiles_query("throughput", t, op, normal, [50.0], cluster_name), + EsMetricsStore._parse_percentiles_response, + ) + s["tput_unit"] = add( + store._build_one_query("throughput", t, None, None, cluster_name), + lambda r: EsMetricsStore._parse_one_response(r, lambda doc: doc["unit"]), + ) + # latency / service_time / processing_time: request full percentile set; trim on response + for metric in ("latency", "service_time", "processing_time"): + s[f"{metric}_stats"] = add( + store._build_stats_query(metric, t, op, normal, cluster_name), + EsMetricsStore._parse_stats_response, + ) + s[f"{metric}_pcts"] = add( + store._build_percentiles_query(metric, t, op, normal, self._MAX_PERCENTS, cluster_name), + EsMetricsStore._parse_percentiles_response, + ) + s[f"{metric}_unit"] = add( + store._build_one_query(metric, t, None, None, cluster_name), + lambda r: EsMetricsStore._parse_one_response(r, lambda doc: doc["unit"]), + ) + task_entries.append((task, t, op, s)) + + # --- Execute task-level requests in one HTTP call --- + responses = store.msearch(requests) + + # --- Build result from batch responses --- + result = GlobalStats(cluster_name=cluster_name) + + for task, t, op, s in task_entries: + error_rate = responses[s["error_rate"]] + duration = responses[s["duration"]] + + tput_stats = responses[s["tput_stats"]] + tput_pcts = responses[s["tput_pcts"]] + tput_unit = responses[s["tput_unit"]] + if tput_stats and tput_stats["count"] > 0 and tput_pcts: + throughput = { + "min": tput_stats["min"], + "mean": tput_stats["avg"], + "median": tput_pcts.get("50.0"), + "max": tput_stats["max"], + "unit": tput_unit, + } + else: + throughput = {"min": None, "mean": None, "median": None, "max": None, "unit": tput_unit} + + latency_results = {} + for metric in ("latency", "service_time", "processing_time"): + mstats = responses[s[f"{metric}_stats"]] + mpcts = responses[s[f"{metric}_pcts"]] + munit = responses[s[f"{metric}_unit"]] + sample_size = mstats["count"] if mstats else 0 + if sample_size > 0 and mpcts: + desired = {str(float(p)) for p in percentiles_for_sample_size(sample_size)} + ordered = collections.OrderedDict( + (encode_float_key(k), v) for k, v in mpcts.items() if k in desired + ) + ordered["mean"] = mstats["avg"] + ordered["unit"] = munit + latency_results[metric] = ordered + else: + latency_results[metric] = {} + + if task.operation.include_in_reporting or error_rate > 0: + result.add_op_metrics( + t, + task.operation.name, + throughput, + latency_results["latency"], + latency_results["service_time"], + latency_results["processing_time"], + error_rate, + duration, + self.merge(self.track.meta_data, self.challenge.meta_data, task.operation.meta_data, task.meta_data), + ) + + # --- Cluster-level metrics: delegate to sequential helpers (single source of truth) --- + self.logger.debug("Gathering indexing metrics.") + result.total_time = self.sum("indexing_total_time") + result.total_time_per_shard = self.shard_stats("indexing_total_time") + result.indexing_throttle_time = self.sum("indexing_throttle_time") + result.indexing_throttle_time_per_shard = self.shard_stats("indexing_throttle_time") + result.merge_time = self.sum("merges_total_time") + result.merge_time_per_shard = self.shard_stats("merges_total_time") + result.merge_count = self.sum("merges_total_count") + result.refresh_time = self.sum("refresh_total_time") + result.refresh_time_per_shard = self.shard_stats("refresh_total_time") + result.refresh_count = self.sum("refresh_total_count") + result.flush_time = self.sum("flush_total_time") + result.flush_time_per_shard = self.shard_stats("flush_total_time") + result.flush_count = self.sum("flush_total_count") + result.merge_throttle_time = self.sum("merges_total_throttled_time") + result.merge_throttle_time_per_shard = self.shard_stats("merges_total_throttled_time") + + self.logger.debug("Gathering ML max processing times.") + result.ml_processing_time = self.ml_processing_time_stats() + + self.logger.debug("Gathering garbage collection metrics.") + result.young_gc_time = self.sum("node_total_young_gen_gc_time") + result.young_gc_count = self.sum("node_total_young_gen_gc_count") + result.old_gc_time = self.sum("node_total_old_gen_gc_time") + result.old_gc_count = self.sum("node_total_old_gen_gc_count") + result.zgc_cycles_gc_time = self.sum("node_total_zgc_cycles_gc_time") + result.zgc_cycles_gc_count = self.sum("node_total_zgc_cycles_gc_count") + result.zgc_pauses_gc_time = self.sum("node_total_zgc_pauses_gc_time") + result.zgc_pauses_gc_count = self.sum("node_total_zgc_pauses_gc_count") + + self.logger.debug("Gathering segment memory metrics.") + result.memory_segments = self.median("segments_memory_in_bytes") + result.memory_doc_values = self.median("segments_doc_values_memory_in_bytes") + result.memory_terms = self.median("segments_terms_memory_in_bytes") + result.memory_norms = self.median("segments_norms_memory_in_bytes") + result.memory_points = self.median("segments_points_memory_in_bytes") + result.memory_stored_fields = self.median("segments_stored_fields_memory_in_bytes") + result.dataset_size = self.sum("dataset_size_in_bytes") + result.store_size = self.sum("store_size_in_bytes") + result.translog_size = self.sum("translog_size_in_bytes") + median_segment_count = self.median("segments_count") + result.segment_count = int(median_segment_count) if median_segment_count is not None else median_segment_count + + self.logger.debug("Gathering transform processing times.") + result.total_transform_processing_times = self.total_transform_metric("total_transform_processing_time") + result.total_transform_index_times = self.total_transform_metric("total_transform_index_time") + result.total_transform_search_times = self.total_transform_metric("total_transform_search_time") + result.total_transform_throughput = self.total_transform_metric("total_transform_throughput") + + self.logger.debug("Gathering Ingest Pipeline metrics.") + result.ingest_pipeline_cluster_count = self.sum("ingest_pipeline_cluster_count") + result.ingest_pipeline_cluster_time = self.sum("ingest_pipeline_cluster_time") + result.ingest_pipeline_cluster_failed = self.sum("ingest_pipeline_cluster_failed") + + self.logger.debug("Gathering disk usage metrics.") + result.disk_usage_total = self.disk_usage("disk_usage_total") + result.disk_usage_inverted_index = self.disk_usage("disk_usage_inverted_index") + result.disk_usage_stored_fields = self.disk_usage("disk_usage_stored_fields") + result.disk_usage_doc_values = self.disk_usage("disk_usage_doc_values") + result.disk_usage_points = self.disk_usage("disk_usage_points") + result.disk_usage_norms = self.disk_usage("disk_usage_norms") + result.disk_usage_term_vectors = self.disk_usage("disk_usage_term_vectors") + + return result + def merge(self, *args): # This is similar to dict(collections.ChainMap(args)) except that we skip `None` in our implementation. result = {} @@ -2614,7 +2835,8 @@ def single_latency(self, task, operation_type, metric_name="latency"): class GlobalStats: - def __init__(self, d=None): + def __init__(self, d=None, cluster_name=None): + self.cluster_name = cluster_name self.op_metrics = self.v(d, "op_metrics", default=[]) self.total_time = self.v(d, "total_time") self.total_time_per_shard = self.v(d, "total_time_per_shard", default={}) From 8f7d315586fe4d93e42829b7796c0a3e36c414e7 Mon Sep 17 00:00:00 2001 From: Gareth Ellis Date: Thu, 4 Jun 2026 13:46:39 +0200 Subject: [PATCH 2/2] Lint --- esrally/metrics.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/esrally/metrics.py b/esrally/metrics.py index 004415732..e33388019 100644 --- a/esrally/metrics.py +++ b/esrally/metrics.py @@ -1258,7 +1258,15 @@ def _get(self, name, task, operation_type, sample_type, node_name, cluster_name, return self._parse_values_response(result, mapper) def get_one( - self, name, sample_type=None, node_name=None, task=None, cluster_name=None, mapper=lambda doc: doc["value"], sort_key=None, sort_reverse=False + self, + name, + sample_type=None, + node_name=None, + task=None, + cluster_name=None, + mapper=lambda doc: doc["value"], + sort_key=None, + sort_reverse=False, ): query = self._build_one_query(name, task, sample_type, node_name, cluster_name, sort_key, sort_reverse) self.logger.debug("Issuing get against index=[%s], query=[%s].", self._index_handler.index_name(self._race_timestamp), query) @@ -2624,9 +2632,7 @@ def add(query_body, mapper): sample_size = mstats["count"] if mstats else 0 if sample_size > 0 and mpcts: desired = {str(float(p)) for p in percentiles_for_sample_size(sample_size)} - ordered = collections.OrderedDict( - (encode_float_key(k), v) for k, v in mpcts.items() if k in desired - ) + ordered = collections.OrderedDict((encode_float_key(k), v) for k, v in mpcts.items() if k in desired) ordered["mean"] = mstats["avg"] ordered["unit"] = munit latency_results[metric] = ordered