cmsml.tensorflow#

Classes, functions and tools for efficiently working with TensorFlow.

Functions:

import_tf([log_level, autograph_verbosity])

Imports TensorFlow and returns a 3-tuple containing the module itself, the v1 compatibility API, and the package version.

save_frozen_graph(path, obj[, ...])

Extracts a TensorFlow graph from an object obj and saves it at path.

save_graph(path, obj[, ...])

Deprecated.

load_frozen_graph(path[, create_session, ...])

Reads a saved TensorFlow graph from path and returns it.

load_graph(path[, create_session, ...])

Deprecated.

write_graph_summary(graph, summary_dir, **kwargs)

Writes the summary of a graph to a directory summary_dir using a tf.summary.FileWriter (v1) or tf.summary.create_file_writer (v2).

load_model(model_path)

Load and return the SavedModel stored at model_path.

load_graph_def(model_path[, serving_key])

Loads the model saved at model_path and returns its GraphDef.

get_graph_ops(graph_def[, node_def_number])

Extracts all ops from a graph_def and returns them as a list.

Classes:

OpsData([devices])

AOT compilation has two requirements to work:

import_tf(log_level: int | str = 'WARNING', autograph_verbosity: int = 3) tuple[module, module | None, tuple[int]][source]#

Imports TensorFlow and returns a 3-tuple containing the module itself, the v1 compatibility API (i.e. the TensorFlow module itself if v1 is the primary installed version), and the package version as a 3-tuple of integers. Example:

tf, tf1, tf_version = import_tf()

At some point in the future, when v1 support is fully removed from TensorFlow, the second tuple element might be None.

The verbosity of logs printed by TensorFlow and AutoGraph can be controlled through log_level and autograph_verbosity.
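The returned version triple lends itself to simple tuple comparisons for gating version-specific code paths. A minimal sketch of this pattern (parse_version and the literal version string are illustrative, not part of cmsml):

```python
def parse_version(version_str):
    # Turn a version string like "2.13.0" into a tuple of ints,
    # mirroring the shape of the third element returned by import_tf().
    return tuple(int(part) for part in version_str.split(".")[:3])

tf_version = parse_version("2.13.0")  # illustrative value
if tf_version >= (2, 0, 0):
    pass  # safe to rely on v2-only APIs in this branch
```

Tuple comparison is lexicographic, so (2, 13, 0) >= (2, 0, 0) holds without any string handling at the call site.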

save_frozen_graph(path: str, obj: Any, variables_to_constants: bool = False, output_names: list[str] | None = None, *args, **kwargs) None[source]#

Extracts a TensorFlow graph from an object obj and saves it at path. The graph is optionally transformed into a simpler representation with all its variables converted to constants when variables_to_constants is True. The saved file contains the graph as a protobuf. The accepted types of obj greatly depend on the available API versions.

When the v1 API is found (which is also the case when tf.compat.v1 is available in v2), Graph, GraphDef and Session objects are accepted. However, when variables_to_constants is True, obj must be a session and output_names should refer to names of operations whose subgraphs are extracted (usually just one).

For TensorFlow v2, obj can also be a compiled Keras model, or either a polymorphic or concrete function as returned by tf.function. Polymorphic functions must either have a defined input signature (tf.function(input_signature=(...,))) or accept no arguments in the first place. See the TensorFlow documentation on concrete functions for more info.

args and kwargs are forwarded to tf.train.write_graph (v1) or tf.io.write_graph (v2).
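Since output_names refers to operation names rather than tensor names, a tensor name such as "dense/BiasAdd:0" has to be stripped of its output index first. A hypothetical helper (not part of cmsml) illustrating the "<op_name>:<output_index>" naming convention:

```python
def to_op_name(tensor_name):
    # TensorFlow tensor names follow "<op_name>:<output_index>", so
    # dropping everything after the last ":" yields the operation name.
    return tensor_name.rsplit(":", 1)[0]

# e.g. collect the operation names to pass as output_names
output_names = [to_op_name("dense/BiasAdd:0")]
```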

save_graph(path: str, obj: Any, variables_to_constants: bool = False, output_names: list[str] | None = None, *args, **kwargs) None[source]#

Deprecated. Please use save_frozen_graph().

load_frozen_graph(path: str, create_session: bool | None = None, session_kwargs: dict | None = None, as_text: bool | None = None) Graph | tuple[Graph, Session][source]#

Reads a saved TensorFlow graph from path and returns it. When create_session is True, a session object (compatible with the v1 API) is created and returned as the second value of a 2-tuple. The default value of create_session is True when TensorFlow v1 is detected, and False otherwise. In case a session is created, session_kwargs are forwarded to the session constructor as keyword arguments when set. When as_text is either True or None, and the file extension is ".pbtxt" or ".pb.txt", the content of the file at path is expected to be a human-readable text file. Otherwise, it is read as a binary protobuf file. Example:

graph = load_frozen_graph("path/to/model.pb", create_session=False)

graph, session = load_frozen_graph("path/to/model.pb", create_session=True)
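The as_text default described above amounts to a file-extension check. A sketch of that rule (infer_as_text is a hypothetical helper, not part of the cmsml API):

```python
def infer_as_text(path):
    # Mirror the documented default: ".pbtxt" and ".pb.txt" files are
    # treated as human-readable text protobufs, everything else as binary.
    return path.endswith((".pbtxt", ".pb.txt"))
```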

load_graph(path: str, create_session: bool | None = None, session_kwargs: dict | None = None, as_text: bool | None = None) Graph | tuple[Graph, Session][source]#

Deprecated. Please use load_frozen_graph().

write_graph_summary(graph: Graph | str, summary_dir: str, **kwargs) None[source]#

Writes the summary of a graph to a directory summary_dir using a tf.summary.FileWriter (v1) or tf.summary.create_file_writer (v2). This summary can be used later on to visualize the graph via TensorBoard. graph can be either a graph object or a path to a protobuf file. In the latter case, load_frozen_graph() is used and all kwargs are forwarded.

Note

When used with TensorFlow v1, eager mode must be disabled.

load_model(model_path: str) Model[source]#

Load and return the SavedModel stored at model_path. If the model was saved using Keras, it is loaded through the Keras SavedModel API; otherwise, TensorFlow's SavedModel API is used.
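One plausible way the Keras/TensorFlow dispatch can be detected: recent Keras versions place an extra keras_metadata.pb file next to saved_model.pb inside the SavedModel directory. This is an assumption about the on-disk format, not a documented cmsml detail:

```python
import os

def looks_like_keras_saved_model(model_path):
    # Assumption: Keras (TF >= 2.5) writes "keras_metadata.pb" into the
    # SavedModel directory; plain TensorFlow SavedModels do not.
    return os.path.exists(os.path.join(model_path, "keras_metadata.pb"))
```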

load_graph_def(model_path: str, serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY) GraphDef[source]#

Loads the model saved at model_path and returns its GraphDef. Supported input types are TensorFlow and Keras SavedModels, as well as frozen graphs.
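The supported input types can be told apart on disk: a SavedModel is a directory containing a saved_model.pb file, while a frozen graph is a single protobuf file. A hypothetical dispatch sketch (not cmsml's actual implementation):

```python
import os

def input_kind(model_path):
    # A SavedModel is a directory that contains "saved_model.pb";
    # a frozen graph is a single protobuf file.
    if os.path.isdir(model_path):
        if os.path.exists(os.path.join(model_path, "saved_model.pb")):
            return "saved_model"
        raise ValueError("not a SavedModel directory")
    return "frozen_graph"
```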

class OpsData(devices: tuple[str] | None = None)[source]#
AOT compilation has two requirements to work:
  1. the outcome of an op's kernel needs to be deterministic

  2. the op's kernel needs to have an XLA implementation.

TensorFlow can return a markdown table containing all XLA-compatible ops. This class is a wrapper to create this table and subsequently parse it.

Methods:

read_ops_table([device])

Generates a markdown table for device and returns it.

parse_ops_table([table, device])

Reads a given markdown table generated with 'tf2xla_supported_ops' and returns a dictionary containing all ops with an XLA implementation.

classmethod read_ops_table(device: str = 'cpu') str[source]#

Generates a markdown table for device and returns it.

classmethod parse_ops_table(table: str | None = None, *, device: str = 'cpu') dict[str, dict][source]#

Reads a given markdown table generated with 'tf2xla_supported_ops' and returns a dictionary containing all ops with an XLA implementation. For a given table, the device information is ignored and instead extracted from the table itself. If no table is given, one is generated for the given device.
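A minimal sketch of what parsing such a table can look like (the two-column table literal and the simplified parser are assumptions for illustration; the real tf2xla output and cmsml's parser handle more columns and edge cases):

```python
def parse_xla_ops_table(table):
    # Parse a markdown table of XLA-supported ops into a dict mapping
    # op name -> list of remaining column contents.
    ops = {}
    for line in table.strip().splitlines():
        line = line.strip()
        if not line.startswith("|"):
            continue
        cells = [c.strip() for c in line.strip("|").split("|")]
        # skip the header row and the "---" separator row
        if cells[0].lower() == "operator" or set(cells[0]) <= set("- :"):
            continue
        ops[cells[0].strip("`")] = cells[1:]
    return ops

table = """
| Operator | Type Constraint |
| --- | --- |
| `Abs` | `T={double,float,int32,int64}` |
| `Add` | `T={double,float,int32,int64}` |
"""
```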

get_graph_ops(graph_def: GraphDef, node_def_number: int = 0) list[str][source]#

Extracts all ops from a graph_def and returns them as a list. If the graph contains multiple FunctionDef instances, set node_def_number to specify from which one the ops should be extracted.
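The extraction logic can be sketched over a plain-dict stand-in for the protobuf (the dict layout and the convention that index 0 selects the main graph while higher indices select FunctionDefs are assumptions for illustration, not cmsml's actual implementation):

```python
def get_graph_ops(graph_def, node_def_number=0):
    # graph_def is modeled as a dict mimicking the protobuf layout:
    #   {"node": [...], "library": {"function": [{"node_def": [...]}, ...]}}
    # Index 0 selects the main graph, 1..n the FunctionDefs.
    node_lists = [graph_def.get("node", [])]
    for func in graph_def.get("library", {}).get("function", []):
        node_lists.append(func.get("node_def", []))
    # collect the distinct op types of the selected node list
    return sorted({node["op"] for node in node_lists[node_def_number]})

graph_def = {
    "node": [{"op": "Placeholder"}, {"op": "MatMul"}, {"op": "Add"}],
    "library": {"function": [{"node_def": [{"op": "Relu"}, {"op": "Mul"}]}]},
}
```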