浅笑の博客

我们的征途是星辰大海


  • 首页

  • 标签29

  • 分类6

  • 归档47

  • 留言板

  • 搜索

TensorFlow学习笔记2:图(Graph)与会话(Session)机制

发表于 2019-08-06 分类于 深度学习 , Python Valine: 本文字数: 3.4k


计算图是TensorFlow的核心概念,使用图(Graph)来表示计算任务,由节点和边组成。TensorFlow由前端负责构建计算图,后端负责执行计算图。
为了执行图的计算,图必须在会话(Session)里面启动,会话将图的操作分发到CPU、GPU等设备上执行。
下面将介绍如何在TensorFlow里面创建会话、图以及基本操作。

图(Graph)

TensorFlow Python库已经有一个默认图 (default graph),如果没有创建新的计算图,则默认情况下是在这个default graph里面创建节点和边。
在图里面添加节点非常方便。例如现在要创建这样的计算图,两个张量相加,如下图:

代码如下:

1
2
3
4
import tensorflow as tf 
a=tf.constant([1.0,2.0], name='a')
b=tf.constant([3.0,4.0], name='b')
result = tf.add(a,b)

现在默认图就有了三个节点,两个constant(),和一个add()。
为了真正使两个张量相加并得到结果,就必须在会话里面启动这个图。

会话(Session)

会话的创建

要启动计算图,首先要创建一个Session对象。
使用tf.Session()创建会话,调用run()函数执行计算图。如果没有传入任何创建参数,会话构造器将启动默认图。如果要指定某个计算图,则传入计算图参数(如g1),则创建会话方式为tf.Session(graph=g1)创建会话(Session)主要有以下三种方式:

  1. 创建一个会话

    1
    2
    3
    4
    5
    6
    7
    8
    #启动默认图
    sess=tf.Session()
    result_value = sess.run(result)
    print(result_value)
    # ==> [4.0 6.0]

    # 任务完成, 关闭会话.
    sess.close()
  2. 创建一个会话

    1
    2
    3
    4
    with tf.Session() as sess:
    result_value = sess.run(result)
    print(result_value)
    # ==> [4.0 6.0]
  3. 创建一个默认会话

    1
    2
    3
    4
    sess=tf.Session()
    with sess.as_default():
    result_value = result.eval()
    print(result_value)

当指定默认会话后,可以通过tf.Tensor.eval函数来计算一个张量的取值。

  1. 创建一个交互式会话
    在交互式环境下(例如IPython),使用设置默认会话的方式来获取张量的取值更加方便,TensorFlow提供了一种在交互式环境下直接构建默认会话的函数:tf.InteractiveSession,该函数会自动将生成的会话注册为默认会话,使用 tf.Tensor.eval()代替 Session.run(),代码如下:
    1
    2
    3
    4
    sess= tf.InteractiveSession()
    result_value = result.eval()
    print(result_value)
    sess.close()

Fetch(取回)

在使用sess.run( )运行图时,我们可以传入fetches,用于取回某些操作或tensor的输出内容。fetches可以是list,tuple,namedtuple,dict中的任意一个。fetches可以是一个列表,在op的一次运行中一起获得(而不是逐个去获取 tensor)多个tensor值。

1
2
c = sess.run(a)            # fetches可以为单个数a
d = sess.run([a, b]) # fetches可以为一个列表[a, b]

Feed(注入)

TensorFlow提供了feed注入机制, 它可以临时替代graph中任意op操作的输入tensor,可以对graph中任何操作提交补丁(直接插入一个tensor)。
feed机制只在调用它的方法内有效,方法结束,feed就会消失。最常见的用例是把某些特殊操作为feed注入的对象。你可以提供数据feed_dict,作为sess.run( )调用的参数。使用tf.placeholder( ),为某些操作的输入创建占位符。

1
2
3
4
5
6
7
8
9
10
11
12
13
import tensorflow as tf
import numpy as np

x = np.ones((2, 3))
y = np.ones((3, 2))

input1 = tf.placeholder(tf.int32)
input2 = tf.placeholder(tf.int32)

output = tf.matmul(input1, input2)

with tf.Session() as sess:
print(sess.run(output, feed_dict = {input1:x, input2:y}))

如果没有正确提供tf.placeholder( ),feed操作将产生错误。注意,feed注入的值不能是tf的tensor对象,应该是Python常量、字符串、列表、numpy ndarrays,或者TensorHandles。

构建多个计算图

在TensorFlow中可以构建多个计算图,计算图之间的张量和运算是不会共享的,通过这种方式,可以在同个项目中构建多个网络模型,而相互之间不会受影响。
使用tf.Graph()函数构建图,构建多个计算图的方式如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# 构建计算图g1
g1=tf.Graph()
with g1.as_default():
# 在计算图g1中定义变量'v',并设置初始值为0。
v=tf.get_variable('v',initializer=tf.zeros_initializer()(shape = [1]))

# 构建计算图g2
g2=tf.Graph()
with g2.as_default():
# 在计算图g2中定义变量'v',并设置初始值微1。
v=tf.get_variable('v',initializer=tf.ones_initializer()(shape = [1]))

# 在计算图g1中读取变量'v'的取值
with tf.Session(graph=g1) as sess:
tf.global_variables_initializer().run()
with tf.variable_scope('',reuse=True):
print(sess.run(tf.get_variable('v')))
# 输出结果[0.]

# 在计算图g2中读取变量'v'的取值
with tf.Session(graph=g2) as sess:
tf.global_variables_initializer().run()
with tf.variable_scope('',reuse=True):
print(sess.run(tf.get_variable('v')))
# 输出结果[1.]。

指定运行设备

如果电脑有多个GPU,可以在图、会话中指定要运行的设备

  • 在图中指定运行设备

    1
    2
    3
    4
    g=tf.Graph()
    # 指定计算运行的设备。
    with g.device('/gpu:0'):
    result=tf.add(a,b)
  • 在会话中指定运行设备

    1
    2
    3
    with tf.Session() as sess:
    with tf.device("/gpu:0"):
    result=tf.add(a,b)

运行的设备用字符串进行标识,目前支持的设备包括:
“/cpu:0”: 机器的 CPU
“/gpu:0”: 机器的第一个 GPU,如果有的话
“/gpu:1”: 机器的第二个 GPU,以此类推

TensorFlow
TensorFlow学习笔记1:张量与变量
TensorFlow学习笔记3:激励函数
Zheng Yujie

Zheng Yujie

C++/Python/深度学习
47 日志
6 分类
29 标签
目录
  1. 1. 图(Graph)
  2. 2. 会话(Session)
    1. 2.1. 会话的创建
    2. 2.2. Fetch(取回)
    3. 2.3. Feed(注入)
  3. 3. 构建多个计算图
  4. 4. 指定运行设备
© 2019 Zheng Yujie | 全站共199k字
浙ICP备 - 19035016号
0%