TensorFlow核心组件系列之Graph的底层机制探索

2023-05-3016:47:23人工智能与大数据Comments828 views字数 6636阅读模式

Graph(计算图)是TensorFlow的核心组件。在TensorFlow中,Graph承担着重要的角色,用于表示深度学习模型的计算过程和数据流动。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

Graph的底层机制涉及TensorFlow框架的前端和后端系统,它们协同工作以构建和执行计算图。前端系统负责构造计算图,定义模型的结构和操作,而后端系统则提供运行时环境,负责执行计算图中的操作。通过深入研究Graph的底层机制,我们可以揭示TensorFlow框架的内部工作原理,为模型的设计和优化提供指导。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

下面我们先来简单介绍下Graph:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

在TensorFlow中,Graph是由节点(Node)和边(Edge)组成的有向无环图。节点(Node)代表计算操作,而边(Edge)则表示数据流动。每个节点可以接收多个输入边和输出边,形成数据的传递和转换关系。这种节点和边的组织结构能够清晰地表示模型的计算过程和依赖关系,方便进行模型优化和并行计算。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

在TensorFlow的前端系统中,使用Operation来表示图中的节点实例,而Tensor则表示图中的边实例,用来连接Op节点。Op包含了节点的计算逻辑和操作类型,而Tensor承载了数据并在节点之间传递。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

在后端C++系统中,TensorFlow 定义Edge和Node来表示计算图,此外边缘(Edge)和张量(Tensor)之间也存在着相互调用的关系。Edge是连接Node的连接线,用于传递数据和建立计算依赖关系,而Tensor则是在Edge上携带数据的载体。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

Tensor在TF后端系统中维护了底层数据的指针和形状信息,通过引用计数的方式进行数据的生命周期管理。这种设计使得TensorFlow能够实现延迟计算和内存复用,提高了计算效率和资源利用率。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

1. 计算图Graph的功能特性

首先,Graph的特性之一是灵活性。TensorFlow的计算图允许用户自由定义模型的结构和操作,从而满足各种复杂的深度学习需求。通过添加、删除或修改节点和边,我们可以灵活地设计和调整计算图,以适应不同任务和模型架构的要求。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

其次,Graph的特性之一是计算图的优化。TensorFlow提供了丰富的优化技术,通过对计算图进行优化,可以提高模型的性能和效率。例如,通过图剪枝技术可以去除无用的节点和边,减少计算和内存消耗。同时,使用并行计算技术可以将计算图中的操作并行执行,加快模型的训练和推断速度。此外,还可以使用量化技术对计算图中的张量进行精度压缩,降低模型的存储和计算成本。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

此外,Graph还为模型的可视化和调试提供了支持。TensorFlow提供了可视化工具,可以将计算图以图形化的方式呈现,帮助开发者直观地理解模型的结构和数据流动。通过分析计算图,我们可以定位和解决潜在的问题,提高模型的稳定性和可靠性。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

2. 前端(Python)Graph的定义

一个 Graph 对象将包含一系列 Operation 对象,表示计算单元 的集合。同时,它间接持有一系列 Tensor 对象,表示数据单元的集合。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

下面我们先来看看前端Python端Graph的定义:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

class Graph(object):
       def __init__(self):
         self._lock = threading.Lock()
         self._nodes_by_id = dict()    # GUARDED_BY(self._lock)
         self._next_id_counter = 0     # GUARDED_BY(self._lock)
         self._nodes_by_name = dict()  # GUARDED_BY(self._lock)
         self._registered_ops = op_def_registry.get_registered_ops()

        def _add_op(self, op):
                self._check_not_finalized()
                if not isinstance(op, (Tensor, Operation)):
                  raise TypeError("op must be a Tensor or Operation: %s" % op)
                with self._lock:
                  # pylint: disable=protected-access
                  if op._id in self._nodes_by_id:
                    raise ValueError("cannot add an op with id %d as it already "
                                     "exists in the graph" % op._id)
                  if op.name in self._nodes_by_name:
                    raise ValueError("cannot add op with name %s as that name "
                                     "is already used" % op.name)
                  self._nodes_by_id[op._id] = op
                  self._nodes_by_name[op.name] = op
                  self._version = max(self._version, op._id)

         def add_to_collection(name, value):
                      get_default_graph().add_to_collection(name, value)

         # 替换线程默认图
               def as_default(self):
                   return _default_graph_stack.get_controller(self)

               # 栈式管理,push pop
               @tf_contextlib.contextmanager
               def get_controller(self, default):
                    try:
                      context.context_stack.push(default.building_function, default.as_default)
                    finally:
                      context.context_stack.pop()

         def create_op(
                  self,
                  op_type,
                  inputs,
                  dtypes=None,  # pylint: disable=redefined-outer-name
                  input_types=None,
                  name=None,
                  attrs=None,
                  op_def=None,
                  compute_shapes=True,
                  compute_device=True):
                        for idx, a in enumerate(inputs):
                          if not isinstance(a, Tensor):
                            raise TypeError("Input #%d is not a tensor: %s" % (idx, a))
                        return self._create_op_internal(op_type, inputs, dtypes, input_types, name,
                                                        attrs, op_def, compute_device)

从源码可以看出,为了快速索引图中的节点信息,在当前Graph的作用域内为每个 Operation 分配唯一的 id, 并在Graph中存储 _nodes_by_id 的数据字典。同时,为了可以根据节点的名字快速索引节点信息,在Graph中也存储了 _nodes_by_name 的数据字典。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

在图构造期,OP 通过 OP 构造器创建,最终被添加至当前的 Graph 实例中。当图被冻 结后,便不能往图中追加节点了,使得 Graph 实例在多线程中被安全地共享。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

堆栈管理的默认图文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

此外从图管理可以看出,默认图采用了堆栈的管理方式,通过push pop操作进行管理。而当前的默认图就是栈顶的图。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

我们来举个例子:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

print tf.get_default_graph()

with tf.Graph().as_default() as g:
    print tf.get_default_graph()

print tf.get_default_graph()

<tensorflow.python.framework.ops.Graph object at 0x106329fd0>
<tensorflow.python.framework.ops.Graph object at 0x18205cc0d0>
<tensorflow.python.framework.ops.Graph object at 0x10d025fd0>

从上面可以看出,当我们在作用域内创建新图并将其作为默认图,但在我们退出作用域后,又变为原来的默认图。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

图的创建工厂文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

下面我们来说明下图是如何被创建出来的?文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

当 Client 使用 OP 构造器创建一个 Operation 实例时,将最 终调用 Graph.create_op 方法,将该 Operation 实例注册到该图实例中。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

也就是说,一方面,Graph 充当 Operation 的工厂,负责 Operation 的创建职责;另一 方面,Graph 充当 Operation 的仓库,负责 Operation 的存储,检索,转换等操作。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

TensorFlow核心组件系列之Graph的底层机制探索

这个过程常称为计算图的构造。在计算图的构造期间,并不会触发运行时的 OP 运算, 它仅仅描述计算节点之间的依赖关系,并构建 DAG 图,对整个计算过程做整体规划。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

此外,TF还提供了GraphKey类,进行更加方便管理和检索节点信息:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

class GraphKeys(object):
    GLOBAL_VARIABLES = "variables"
  # Key to collect local variables that are local to the machine and are not
  # saved/restored.
  LOCAL_VARIABLES = "local_variables"
  # optimizers.
  TRAINABLE_VARIABLES = "trainable_variables"
  SAVERS = "savers"
  # Key to collect weights
  WEIGHTS = "weights"
  # Key to collect biases
  BIASES = "biases"
  # Key to collect activations
  ACTIVATIONS = "activations"
  # Key to collect update_ops
  UPDATE_OPS = "update_ops"
  # Key to collect losses
  LOSSES = "losses"
  ...

  # Key to indicate various ops.
  INIT_OP = "init_op"
  LOCAL_INIT_OP = "local_init_op"
  SUMMARY_OP = "summary_op"
  GLOBAL_STEP = "global_step"

  # Used to count the number of evaluations performed during a single evaluation
  # run.
  EVAL_STEP = "eval_step"
  TRAIN_OP = "train_op"

  # Key for control flow context.
  COND_CONTEXT = "cond_context"
  WHILE_CONTEXT = "while_context"

  # Used to store v2 summary names.
  _SUMMARY_COLLECTION = "_SUMMARY_V2"

  # List of all collections that keep track of variables.
  _VARIABLE_COLLECTIONS = [
      GLOBAL_VARIABLES,
      LOCAL_VARIABLES,
      METRIC_VARIABLES,
      MODEL_VARIABLES,
      TRAINABLE_VARIABLES,
      MOVING_AVERAGE_VARIABLES,
      CONCATENATED_VARIABLES,
      TRAINABLE_RESOURCE_VARIABLES,
  ]

# 用户要快速的检索某类变量可以通过这样的语句
all_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

用户可以通过例如tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)语句快速检索变量,也可以自定义分组。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

3. 后端(C++)Graph数据结构

Graph(计算图) 就是节点与边的集合。计算图是一个 DAG 图,计算图的执行过程将按照 DAG 的拓扑排序,依次启动 OP 的运算。其中,如果存在多个入度为 0 的节点,TensorFlow 运行时可以实现并发,同时执行多个 OP 的运算,提高执行效率。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

class Graph {
     private:
      // 所有已知的op计算函数的注册表
      FunctionLibraryDefinition ops_;

      // GraphDef版本号
      const std::unique_ptr<VersionDef> versions_;

      // 节点node列表,通过id来访问
      std::vector<Node*> nodes_;

      // node个数
      int64 num_nodes_ = 0;

      // 边edge列表,通过id来访问
      std::vector<Edge*> edges_;

      // graph中非空edge的数目
      int num_edges_ = 0;

      // 已分配了内存,但还没使用的node和edge
      std::vector<Node*> free_nodes_;
      std::vector<Edge*> free_edges_;

     const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
       auto e = AllocEdge();
       e->src_ = source;
       e->dst_ = dest;
       e->src_output_ = x;
       e->dst_input_ = y;
       CHECK(source->out_edges_.insert(e).second);
       CHECK(dest->in_edges_.insert(e).second);
       edges_.push_back(e);
       edge_set_.insert(e);
       return e;
        }
 }

后端中的Graph主要成员也是节点node和边edge。节点node为计算算子Operation,边为算子所需要的数据,或者代表节点间的依赖关系。边Edge的持有它的源节点和目标节点的指针,从而将两个节点连接起来。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

由Node和Edge,即可以组成图Graph,通过任何节点和任何边,都可以遍历完整图。Graph执行计算时,按照拓扑结构,依次执行每个Node的op计算,最终即可得到输出结果。入度为0的节点,也就是依赖数据已经准备好的节点,可以并发执行,从而提高运行效率。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

系统中存在默认的Graph,初始化Graph时,会添加一个Source节点和Sink节点。Source表示Graph的起始节点,Sink为终止节点。Source的id为0,Sink的id为1,其他节点id均大于1。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

此外,Graph数据结构提供了一系列方法来构建和修改计算图。通过这些方法,我们可以添加和删除节点,连接节点之间的边,以及设置节点的属性和操作。例如,可以通过调用graph->AddNode()方法来添加一个新的节点,并通过node->AddInputEdge()方法来添加输入边。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

4. 总结

前端(Python)Graph的定义主要包括:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

  1. Graph对象包含一系列Operation对象和Tensor对象,用于表示计算单元和数据单元的集合。
  2. 通过为每个Operation分配唯一的id并存储在_nodes_by_id字典中,以及根据节点名字存储在_nodes_by_name字典中,实现快速索引和检索节点信息。
  3. 在图构造期间,通过OP构造器创建Operation,并将其添加到当前的Graph实例中。
  4. 默认图采用堆栈管理方式,通过push和pop操作进行管理,当前默认图是栈顶的图。

后端(C++)Graph数据结构的定义主要包括:文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

  1. 后端中的Graph成员包括是节点node和边edge,在数据结构中分别维护了节点node和边edge的vector数组,是一个典型的图定义数据结构。
  2. 默认的Graph在初始化时会添加一个Source节点和一个Sink节点。Source节点表示计算图的起始节点,而Sink节点表示计算图的终止节点。Source节点的id为0,Sink节点的id为1,其他节点的id都大于1。
  3. Graph数据结构提供了一系列方法来构建和修改计算图,例如AddEdge与`AddNode`

此外,前端的GAG图生成后是通过protobuf的序列化与反序列化发送到后端进行运行与优化的。文章源自菜鸟学院-https://www.cainiaoxueyuan.com/ai/43460.html

  • 本站内容整理自互联网,仅提供信息存储空间服务,以方便学习之用。如对文章、图片、字体等版权有疑问,请在下方留言,管理员看到后,将第一时间进行处理。
  • 转载请务必保留本文链接:https://www.cainiaoxueyuan.com/ai/43460.html

Comment

匿名网友 填写信息

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定