# Source code for cmsml.tensorflow.aot

# coding: utf-8

"""
Tools and objects for working with AOT / XLA.
"""

from __future__ import annotations

import sys
import re
from subprocess import PIPE

from cmsml.util import interruptable_popen
from cmsml.tensorflow.tools import import_tf

tf = import_tf()[0]

from tensorflow.core.framework.graph_pb2 import GraphDef


class OpsData(object):
    """
    AOT needs two requirements to work:

        1. the outcome of an ops-kernel needs to be deterministic
        2. the ops-kernel needs to have an XLA implementation

    TensorFlow can return a markdown table containing all XLA compatible ops.
    This class is a wrapper to create this table and consequently read it.
    """

    # mapping of short device names to the XLA jit device identifiers used by tf2xla
    device_ids = {
        "cpu": "XLA_CPU_JIT",
        "gpu": "XLA_GPU_JIT",
    }

    def __init__(self, devices: tuple[str] | None = None) -> None:
        """
        Sets an iterable of *devices* for which the XLA operations table should be generated.
        """
        super().__init__()

        # store operation data in a nested dict: op name -> device -> allowed types
        self._ops = {}

        # normalize devices to a tuple and determine the ops
        if not devices:
            devices = ()
        elif not isinstance(devices, (list, tuple, set)):
            devices = (devices,)
        self._determine_ops(devices)

    @classmethod
    def _assert_device_supported(cls, device: str) -> None:
        # raise when *device* is not a key of device_ids
        if device not in cls.device_ids:
            raise ValueError(
                f"{device} not in supported devices {list(cls.device_ids.keys())}",
            )

    @classmethod
    def read_ops_table(
        cls,
        device: str = "cpu",
    ) -> str:
        """
        Generates a markdown table for *device* and returns it.

        Raises an exception when the underlying ``tf2xla_supported_ops`` command fails.
        """
        cls._assert_device_supported(device)

        # tf2xla_supported_ops prints the table,
        # catch the stdout stream and decode it into a str
        cmd = f"tf2xla_supported_ops --device={cls.device_ids[device]}"
        code, out, _ = interruptable_popen(cmd, stdout=PIPE, executable="/bin/bash", shell=True)
        if code != 0:
            raise Exception(f"tf2xla_supported_ops command failed with exit code {code}")

        return out

    @classmethod
    def parse_ops_table(
        cls,
        table: str | None = None,
        *,
        device: str = "cpu",
    ) -> dict[str, dict]:
        """
        Reads a markdown table generated with ``tf2xla_supported_ops`` and returns a dictionary
        containing all ops with an XLA implementation. *table* is interpreted as the path of a file
        containing the table; when given, the *device* argument is ignored and the device is
        instead extracted from the table header. When no *table* is given, one is generated for
        *device*.
        """
        cls._assert_device_supported(device)

        # generate the table when empty, otherwise read it from the given path
        if not table:
            table = cls.read_ops_table(device)
        else:
            with open(table, "r") as txt_file:
                table = txt_file.read()

        # split into lines
        lines = table.splitlines()

        # the first line contains the device information
        for device, device_id in cls.device_ids.items():
            if device_id in lines[0]:
                break
        else:
            raise ValueError(f"no device string found in table header '{lines[0]}'")

        # read op infos from table lines
        ops = {}
        content_started = False
        cre = re.compile(r"^\`([^\`]+)\`\s+\|\s*(.*)$")
        for line in lines[1:]:
            line = line.strip()

            # find the beginning of the table, marked by the header separator
            if not content_started:
                if line.startswith("---"):
                    content_started = True
                continue

            # an empty line marks the end of the table
            if not line:
                break

            # parse the line
            m = cre.match(line)
            if not m:
                print(f"error parsing table line: {line}", file=sys.stderr)
                continue
            op_name, allowed_types = m.groups()
            # strip markdown decoration from the type constraints
            allowed_types = allowed_types.replace("`", "").replace("<br>", "")

            # save op data
            ops[op_name] = {
                "name": op_name,
                "device": device,
                "allowed_types": allowed_types,
            }

        return ops

    def _determine_ops(self, devices: tuple[str] | None = None) -> None:
        """
        Merges the tables of multiple *devices* into one dictionary, stored in *_ops*. When no
        devices are given, all supported devices are used.

        WARNING: since it is not possible to see from which TensorFlow version a markdown table
        was generated, try not to mix tables from different versions.
        """
        if not devices:
            devices = tuple(self.device_ids.keys())

        # read the op dictionary per device
        all_op_dicts = [
            self.parse_ops_table(device=device)
            for device in devices
        ]

        # merge them into a nested dict: op name -> device -> allowed types
        ops = {}
        for op_dicts in all_op_dicts:
            for op_data in op_dicts.values():
                op_name = op_data["name"]
                if op_name not in ops:
                    ops[op_name] = {}
                ops[op_name][op_data["device"]] = op_data["allowed_types"]

        self._ops = ops

    def _get_unique_ops(self, device: str | None = None) -> set[str]:
        # names of ops that have an implementation for *device*,
        # or for any device when *device* is None
        if device is not None:
            self._assert_device_supported(device)
        return {
            op_name
            for op_name, op_data in self._ops.items()
            if device is None or op_data.get(device)
        }

    @property
    def cpu_ops(self) -> set[str]:
        # names of XLA compatible ops with a CPU implementation
        return self._get_unique_ops("cpu")

    @property
    def gpu_ops(self) -> set[str]:
        # names of XLA compatible ops with a GPU implementation
        return self._get_unique_ops("gpu")

    @property
    def ops(self) -> dict[str, dict]:
        # the full nested dict of op data, mapping op name -> device -> allowed types
        return self._ops

    def __len__(self) -> int:
        # number of ops
        return len(self._ops)

    def __getitem__(self, key: str) -> dict:
        return self._ops[key]

    def keys(self) -> list[str]:
        return list(self._ops.keys())

    def values(self) -> list[dict]:
        return list(self._ops.values())

    def items(self) -> list[tuple[str, dict]]:
        return list(self._ops.items())

    def get(self, *args, **kwargs) -> dict | None:
        # forwarded to dict.get, so the default value may also be returned
        return self._ops.get(*args, **kwargs)
def get_graph_ops(graph_def: GraphDef, node_def_number: int = 0) -> list[str]:
    """
    Extracts all op names from a *graph_def* and returns them as a list, free of duplicates and
    ordered by first occurrence. If there are multiple ``FunctionDef`` instances in the graph's
    function library, set *node_def_number* to specify from which of them the ops are extracted.

    An *AttributeError* is raised when *node_def_number* does not correspond to a ``FunctionDef``.
    """
    # the function library is empty for plain graph.pb files, but filled for SavedModels
    num_funcs = len(graph_def.library.function)
    if num_funcs == 0:
        node_def = graph_def.node
    else:
        if node_def_number >= num_funcs:
            raise AttributeError(
                f"node_def_number {node_def_number} does not match amount of {num_funcs} "
                "FunctionDef objects in graph",
            )
        node_def = graph_def.library.function[node_def_number].node_def

    # deduplicate in O(n) while preserving the order of first appearance
    return list(dict.fromkeys(node.op for node in node_def))