我就废话不多说了,大家还是直接看代码吧~
import tensorflow as tf n1 = tf.constant(2) n2 = tf.constant(3) n3 = tf.constant(4) def cond1(i, a, b): return i < n1 def cond2(i, a, b): return i < n2 def cond3(i, a, b): return i < n3 def body(i, a, b): return i + 1, b, a + b i1, a1, b1 = tf.while_loop(cond1, body, (2, 1, 1)) i2, a2, b2 = tf.while_loop(cond2, body, (2, 1, 1)) i3, a3, b3 = tf.while_loop(cond3, body, (2, 1, 1)) sess = tf.Session() print(sess.run(i1)) print(sess.run(a1)) print(sess.run(b1)) print("-") print(sess.run(i2)) print(sess.run(a2)) print(sess.run(b2)) print("-") print(sess.run(i3)) print(sess.run(a3)) print(sess.run(b3))
print结果:
2 1 1 - 3 1 2 - 4 2 3
可见body函数返回的三个变量又传给了body
补充知识:tensorflow在tf.while_loop循环(非一般循环)中使用操纵变量该怎么做
代码(操纵全局变量)
xiaojie=1 i=tf.constant(0,dtype=tf.int32) batch_len=tf.constant(10,dtype=tf.int32) loop_cond = lambda a,b: tf.less(a,batch_len) #yy=tf.Print(batch_len,[batch_l<b>本文来源gao@!dai!ma.com搞$$代^@码!网</b>en],"batch_len:") yy=tf.constant(0) loop_vars=[i,yy] def _recurrence(i,yy): c=tf.constant(2,dtype=tf.int32) x=tf.multiply(i,c) global xiaojie xiaojie=xiaojie+1 print_info=tf.Print(x,[x],"x:") yy=yy+print_info i=tf.add(i,1) # print (xiaojie) return i,yy i,yy=tf.while_loop(loop_cond,_recurrence,loop_vars,parallel_iterations=1)#可以批处理 sess = tf.Session() print (sess.run(i)) print (xiaojie)
输出的是10和2。
也就是xiaojie只被修改了一次。
这个时候,在_recurrence循环体中添加语句
print (xiaojie)
会输出2。而且只输出一次。具体为什么,最后总结的时候再解释。
代码(操纵类成员变量)class RNN_Model():
def __init__(self): self.xiaojie=1 def test_RNN(self): i=tf.constant(0,dtype=tf.int32) batch_len=tf.constant(10,dtype=tf.int32) loop_cond = lambda a,b: tf.less(a,batch_len) #yy=tf.Print(batch_len,[batch_len],"batch_len:") yy=tf.constant(0) loop_vars=[i,yy] def _recurrence(i,yy): c=tf.constant(2,dtype=tf.int32) x=tf.multiply(i,c) self.xiaojie=self.xiaojie+1 print_info=tf.Print(x,[x],"x:") yy=yy+print_info i=tf.add(i,1) print ("_recurrence:",self.xiaojie) return i,yy i,yy=tf.while_loop(loop_cond,_recurrence,loop_vars,parallel_iterations=1)#可以批处理 sess = tf.Session() sess.run(yy) print (self.xiaojie) if __name__ == "__main__": model = RNN_Model()#构建树,并且构建词典 model.test_RNN()
输出是:
_recurrence: 2 10 2
tf.while_loop操纵全局变量和类成员变量总结
为什么_recurrence中定义的print操作只执行一次呢,这是因为_recurrence中的print相当于一种对代码的定义,直接在定义的过程中就执行了。所以,可以看到输出是在sess.run之前的。但是,定义的其它操作就是数据流图中的操作,需要在sess.run中执行。
就必须在sess.run中执行。但是,全局变量xiaojie也好,还是类成员变量xiaojie也好。其都不是图中的内容。因此,tf.while_loop执行的是tensorflow计算图中的循环,对于不是在计算图中的,就不会参与循环。注意:而且必须是与loop_vars中指定的变量存在数据依赖关系的tensor才可以!此外,即使是依赖关系,也必须是_recurrence循环体中return出的变量,才会真正的变化。比如,见下面的self.L。总之,想操纵变量,就要传入loop_vars!