Source code for hatchet.readers.spotdb_reader

# Copyright 2017-2023 Lawrence Livermore National Security, LLC and other
# Hatchet Project Developers. See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: MIT

import pandas as pd

import hatchet.graphframe
from hatchet.node import Node
from hatchet.graph import Graph
from hatchet.frame import Frame
from hatchet.util.timer import Timer


def _find_child_node(node, name):
    """Return child with given name from parent node"""
    for c in node.children:
        if c.frame.get("name") == name:
            return c
    return None


class SpotDatasetReader:
    """Reads a (single-run) dataset from SpotDB"""

    def __init__(self, regionprofile, metadata, attr_info):
        """Initialize SpotDataset reader

        Args:
            regionprofile (dict): Dict mapping region names to key:value
                records with metrics. Region names are hierarchical,
                separated with '/'.
                Example: { "a/b/c": { "metric": val, ... }, ... }
            metadata (dict): Key-value run metadata for this dataset.
                Example: { "launchdate": 123456789, "figure_of_merit": 42.0 }
            attr_info (dict): Information about metric attributes. Contains,
                e.g., type and alias info. This data is optional.
                Example: { "metric": { "type": "double", "alias": "The Metric", ... }, ... }
        """
        self.regionprofile = regionprofile
        self.attr_info = attr_info
        self.metadata = metadata

        self.df_data = []
        self.roots = {}
        self.metric_columns = set()

        self.timer = Timer()
    def create_graph(self):
        """Create the graph. Fills in df_data and metric_columns."""
        self.df_data.clear()

        for pathstr, vals in self.regionprofile.items():
            # parse { "a/b/c": { "metric": val, ... }, ... } records
            if len(pathstr) == 0:
                continue

            path = pathstr.split("/")
            name = path[-1]
            node = self._create_node(path)

            metrics = {}
            for k, v in vals.items():
                info = self.attr_info.get(k, dict())
                colm = info.get("alias", k)
                type = info.get("type", "string")

                if "inclusive" in k:
                    colm += " (inc)"
                if type == "double":
                    metrics[colm] = float(v)
                elif type == "int" or type == "uint":
                    metrics[colm] = int(v)
                else:
                    metrics[colm] = v

                self.metric_columns.add(colm)

            self.df_data.append(dict({"name": name, "node": node}, **metrics))
    def read(self, default_metric="Total time (inc)"):
        """Create GraphFrame for the given Spot dataset."""
        with self.timer.phase("graph construction"):
            self.create_graph()

        graph = Graph(list(self.roots.values()))
        graph.enumerate_traverse()

        dataframe = pd.DataFrame(data=self.df_data)
        dataframe.set_index(["node"], inplace=True)

        exc_metrics = []
        inc_metrics = []
        for m in self.metric_columns:
            if "(inc)" in m:
                inc_metrics.append(m)
            else:
                exc_metrics.append(m)

        if default_metric not in dataframe.columns:
            if len(inc_metrics) > 0:
                default_metric = inc_metrics[0]
            elif len(exc_metrics) > 0:
                default_metric = exc_metrics[0]

        return hatchet.graphframe.GraphFrame(
            graph,
            dataframe,
            exc_metrics,
            inc_metrics,
            metadata=self.metadata,
            default_metric=default_metric,
        )
    def _create_node(self, path):
        """Return the node for the given path (a list of region names),
        creating it and any missing ancestors along the way."""
        parent = self.roots.get(path[0], None)
        if parent is None:
            parent = Node(Frame(name=path[0]))
            self.roots[path[0]] = parent

        node = parent
        for name in path[1:]:
            node = _find_child_node(parent, name)
            if node is None:
                node = Node(Frame(name=name), parent)
                parent.add_child(node)
            parent = node

        return node
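
# The sketch below is illustrative and not part of the original module: it
# builds a GraphFrame from a small hand-written regionprofile, showing how
# hierarchical "a/b" region names become parent/child nodes and how the
# attr_info alias plus the " (inc)" suffix determine the metric column name.
# The metric key, alias, and values are made-up placeholders; real inputs
# come from a SpotDB instance.
def _example_spot_dataset_reader():
    regionprofile = {
        "main": {"avg#inclusive#sum#time.duration": "5.0"},
        "main/solve": {"avg#inclusive#sum#time.duration": "4.2"},
    }
    attr_info = {
        "avg#inclusive#sum#time.duration": {"alias": "Avg time/rank", "type": "double"}
    }

    reader = SpotDatasetReader(regionprofile, {"launchdate": 123456789}, attr_info)
    gf = reader.read(default_metric="Avg time/rank (inc)")

    # Two nodes under one root ("main" -> "solve"), one inclusive metric column.
    assert sorted(reader.metric_columns) == ["Avg time/rank (inc)"]
    assert len(gf.dataframe) == 2
    return gf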
class SpotDBReader:
    """Import multiple runs as graph frames from a SpotDB instance"""

    def __init__(self, db_key, list_of_ids=None, default_metric="Total time (inc)"):
        """Initialize SpotDBReader

        Args:
            db_key (str or SpotDB object): locator for SpotDB instance
                This can be a SpotDB object directly, or a locator for a spot
                database, which is a string with either

                * A directory for .cali files,
                * A .sqlite file name,
                * A SQL database URL (e.g., "mysql://hostname/db")

            list_of_ids: The list of run IDs to read from the database.
                If this is None, read all runs.
            default_metric: Name of the default metric for the GraphFrames.
        """
        self.db_key = db_key
        self.list_of_ids = list_of_ids
        self.default_metric = default_metric
    def read(self):
        """Read given runs from SpotDB

        Returns:
            List of GraphFrames, one for each entry that was found
        """
        import spotdb

        if isinstance(self.db_key, str):
            db = spotdb.connect(self.db_key)
        else:
            db = self.db_key

        runs = self.list_of_ids or db.get_all_run_ids()

        regionprofiles = db.get_regionprofiles(runs)
        metadata = db.get_global_data(runs)
        attr_info = db.get_metric_attribute_metadata()

        result = []

        for run in runs:
            if run in regionprofiles:
                result.append(
                    SpotDatasetReader(
                        regionprofiles[run], metadata[run], attr_info
                    ).read(self.default_metric)
                )

        return result
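
# The sketch below is illustrative and not part of the original module: it
# reads every run from a SpotDB locator into a list of GraphFrames. The path
# is a placeholder; any locator accepted by spotdb.connect (a directory of
# .cali files, a .sqlite file, or a SQL database URL) would work, and the
# spotdb package must be installed.
def _example_spotdb_reader(db_key="/path/to/caliper/files"):
    graphframes = SpotDBReader(db_key).read()
    for gf in graphframes:
        # Each GraphFrame carries the run's global metadata and its metrics.
        print(gf.metadata.get("launchdate"), len(gf.dataframe))
    return graphframes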