[Source Code Notes] Keras Source Code Analysis: Container

This post continues our look at the structure of the Keras source code.

In the first source code note we examined how Layer, Tensor, and Node are coupled together; the focus of this post is the directed acyclic graph (DAG) formed by a multi-layer network. The main file involved is keras/engine/topology.py, and the class under examination is Container.

The Container Object: Topological Prototype of the DAG

As mentioned in the first post, the _keras_history attribute that Keras attaches to its tensors makes it possible to reconstruct the entire computation graph from nothing but the input and output tensors. The Container object is precisely the implementation of that process.
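As a minimal sketch of what this looks like from the user's side (assuming the Keras 2 functional API; the two-layer toy model below is hypothetical), every tensor produced by a layer carries a _keras_history triple (layer, node_index, tensor_index), and Model, which subclasses Container, only needs the input and output tensors to recover the whole graph:

    from keras.layers import Input, Dense
    from keras.models import Model

    # Hypothetical toy graph: Input -> Dense(4) -> Dense(2)
    inp = Input(shape=(8,))
    hidden = Dense(4)(inp)
    out = Dense(2)(hidden)

    # Each Keras tensor remembers which layer produced it, and where
    layer, node_index, tensor_index = out._keras_history
    print(layer.name, node_index, tensor_index)   # e.g. dense_2 0 0

    # Model subclasses Container: inputs and outputs are enough to build the DAG
    model = Model(inputs=inp, outputs=out)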

Building the Computation Graph

The DAG computation graph is built when the Container object is instantiated; this mainly involves the following steps:

1) Record the Container's input- and output-side connectivity

    def __init__(self, inputs, outputs, name=None):
        # Record which layer / node / tensor each output tensor comes from
        for x in self.outputs:
            layer, node_index, tensor_index = x._keras_history
            self.output_layers.append(layer)
            self.output_layers_node_indices.append(node_index)
            self.output_layers_tensor_indices.append(tensor_index)

        # Record the same information for each input tensor
        for x in self.inputs:
            layer, node_index, tensor_index = x._keras_history
            self.input_layers.append(layer)
            self.input_layers_node_indices.append(node_index)
            self.input_layers_tensor_indices.append(tensor_index)
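A quick way to see what gets recorded, continuing the hypothetical toy model above (attribute names are those set in the excerpt):

    print(model.input_layers)    # [the InputLayer created by Input(shape=(8,))]
    print(model.output_layers)   # [the Dense(2) layer]
    print(model.input_layers_node_indices, model.output_layers_node_indices)  # [0] [0]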

2) Starting from the output tensors, recursively walk the graph backwards to build it, using a depth-first traversal. The key product of this step is the nodes_in_decreasing_depth list: the connectivity and depth information carried by these Nodes later determines the execution order of the forward pass and of training. A standalone sketch of the ordering idea follows the code below.

    def build_map_of_graph(tensor, finished_nodes, nodes_in_progress,
                           layer=None, node_index=None, tensor_index=None):
        if not layer or node_index is None or tensor_index is None:
            layer, node_index, tensor_index = tensor._keras_history
        node = layer.inbound_nodes[node_index]

        # Don't repeat work for shared subgraphs
        if node in finished_nodes:
            return
        nodes_in_progress.add(node)

        # Depth-first traversal: recurse into every tensor feeding this node
        for i in range(len(node.inbound_layers)):
            x = node.input_tensors[i]
            layer = node.inbound_layers[i]
            node_index = node.node_indices[i]
            tensor_index = node.tensor_indices[i]
            # Recursive call
            build_map_of_graph(x, finished_nodes, nodes_in_progress,
                               layer, node_index, tensor_index)

        # Post-order bookkeeping: a node is finished only after all its inputs
        finished_nodes.add(node)
        nodes_in_progress.remove(node)
        nodes_in_decreasing_depth.append(node)

    # Build the DAG backwards from the outputs
    for x in self.outputs:
        build_map_of_graph(x, finished_nodes, nodes_in_progress)
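To see why the resulting list is ordered by decreasing depth, here is a tiny standalone sketch (plain Python, not Keras code; the node names are made up) of the same post-order traversal:

    # Hypothetical DAG: each node maps to the nodes feeding into it
    graph = {'out': ['dense2'], 'dense2': ['dense1'], 'dense1': ['input'], 'input': []}

    nodes_in_decreasing_depth = []
    finished = set()

    def visit(node):
        if node in finished:
            return
        for inbound in graph[node]:
            visit(inbound)
        # A node is appended only after everything between it and the inputs,
        # so the list runs from the deepest nodes (inputs) down to depth 0 (outputs).
        finished.add(node)
        nodes_in_decreasing_depth.append(node)

    visit('out')
    print(nodes_in_decreasing_depth)  # ['input', 'dense1', 'dense2', 'out']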

3) Compute the depth of every node and index nodes and layers in the DAG by depth

    # Assign depths by walking the list back from the outputs:
    # output nodes get depth 0, and every inbound node is at least depth + 1
    for node in reversed(nodes_in_decreasing_depth):
        depth = nodes_depths.setdefault(node, 0)
        previous_depth = layers_depths.get(node.outbound_layer, 0)
        depth = max(depth, previous_depth)
        layers_depths[node.outbound_layer] = depth
        nodes_depths[node] = depth

        for i in range(len(node.inbound_layers)):
            inbound_layer = node.inbound_layers[i]
            node_index = node.node_indices[i]
            inbound_node = inbound_layer.inbound_nodes[node_index]
            previous_depth = nodes_depths.get(inbound_node, 0)
            nodes_depths[inbound_node] = max(depth + 1, previous_depth)

    # Group nodes by depth
    nodes_by_depth = {}
    for node, depth in nodes_depths.items():
        if depth not in nodes_by_depth:
            nodes_by_depth[depth] = []
        nodes_by_depth[depth].append(node)

    # Group layers by depth
    layers_by_depth = {}
    for layer, depth in layers_depths.items():
        if depth not in layers_by_depth:
            layers_by_depth[depth] = []
        layers_by_depth[depth].append(layer)

    self.layers_by_depth = layers_by_depth
    self.nodes_by_depth = nodes_by_depth
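For the hypothetical toy model above, these dictionaries end up as sketched below (the attribute names are those set in the excerpt; exact layer names will vary):

    for depth, layers in sorted(model.layers_by_depth.items()):
        print(depth, [l.name for l in layers])
    # 0 ['dense_2']   <- the output Dense(2)
    # 1 ['dense_1']   <- the hidden Dense(4)
    # 2 ['input_1']   <- the InputLayer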

4) Create a Node for the Container itself, so that the Container stays compatible with the Layer interface

    # The Container acts as the outbound layer of this Node,
    # which lets it later be called like any other layer
    self.outbound_nodes = []
    self.inbound_nodes = []
    Node(outbound_layer=self,
         inbound_layers=[],
         node_indices=[],
         tensor_indices=[],
         input_tensors=self.inputs,
         output_tensors=self.outputs,
         ...)
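This is what allows a finished model to be reused as a layer inside a larger model. A hedged sketch, continuing the hypothetical toy model with standard functional-API usage:

    # The whole model can be called on a new tensor, exactly like a layer
    new_inp = Input(shape=(8,))
    new_out = model(new_inp)          # creates a new Node whose outbound layer is the model
    bigger = Model(inputs=new_inp, outputs=new_out)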

Computation on the Graph

The computation happens in the Container's call() method, whose implementation in turn relies on the internal method run_internal_graph().

    def run_internal_graph(self, inputs, masks=None):
        if masks is None:
            masks = [None for _ in range(len(inputs))]

        # Map each already-computed reference tensor to its (tensor, mask) value,
        # seeded with the Container's own inputs
        tensor_map = {}
        for x, y, mask in zip(self.inputs, inputs, masks):
            tensor_map[str(id(x))] = (y, mask)

        depth_keys = list(self.nodes_by_depth.keys())
        depth_keys.sort(reverse=True)
        # Walk the graph by depth, from the deepest nodes (inputs) towards depth 0
        for depth in depth_keys:
            nodes = self.nodes_by_depth[depth]
            # Compute every Node at the same depth
            for node in nodes:
                layer = node.outbound_layer  # the layer this Node belongs to
                reference_input_tensors = node.input_tensors
                reference_output_tensors = node.output_tensors

                # Collect the inputs of this node that have already been computed
                computed_data = []  # list of (tensor, mask) tuples
                for x in reference_input_tensors:
                    if str(id(x)) in tensor_map:
                        computed_data.append(tensor_map[str(id(x))])

                if len(computed_data) == len(reference_input_tensors):
                    # All inputs are ready: run the computation inside the Layer
                    with K.name_scope(layer.name):
                        kwargs = node.arguments if node.arguments else {}
                        if len(computed_data) == 1:
                            computed_tensor, computed_mask = computed_data[0]
                            output_tensors = _to_list(layer.call(computed_tensor, **kwargs))
                            output_masks = _to_list(layer.compute_mask(computed_tensor, computed_mask))
                            computed_tensors = [computed_tensor]
                        else:
                            computed_tensors = [x[0] for x in computed_data]
                            computed_masks = [x[1] for x in computed_data]
                            output_tensors = _to_list(layer.call(computed_tensors, **kwargs))
                            output_masks = _to_list(layer.compute_mask(computed_tensors, computed_masks))

                        # Record the freshly computed outputs for downstream nodes
                        for x, y, mask in zip(reference_output_tensors, output_tensors, output_masks):
                            tensor_map[str(id(x))] = (y, mask)

        # Finally, read the Container's outputs off tensor_map
        output_tensors = []
        output_masks = []
        for x in self.outputs:
            tensor, mask = tensor_map[str(id(x))]
            output_tensors.append(tensor)
            output_masks.append(mask)
        return output_tensors, output_masks

As the code above shows, computation proceeds depth by depth, and the traversal of the whole graph is driven by updating variables such as computed_data, output_tensors, and tensor_map.
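Since Container.call() delegates to run_internal_graph(), calling the model on a fresh tensor is enough to trigger this depth-ordered traversal. A small hedged sketch, again with the hypothetical toy model (the identity check reflects the Node created in step 4 above):

    y = model(Input(shape=(8,)))         # Container.__call__ -> call() -> run_internal_graph()
    print(y._keras_history[0] is model)  # True: the new output's producing layer is the model itself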

Continue with the third post in the series: [Source Code Notes] Keras Source Code Analysis: Model

@ddlee
