这篇文章主要介绍了基于tensorflow for循环 while循环案例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
我就废话不多说了,大家还是直接看代码吧~
import tensorflow as tf n1 = tf.constant(2) n2 = tf.constant(3) n3 = tf.constant(4) def cond1(i, a, b): return i <n1 def cond2(i, a, b): return i </div><p>print结果:</p><div class="gaodaimacode"><pre class="prettyprint linenums"> 2 1 1 - 3 1 2 - 4 2 3
可见body函数返回的三个变量又传给了body
补充知识:tensorflow在tf.while_loop循环(非一般循环)中使用操纵变量该怎么做
代码(操纵全局变量)
xiaojie=1 i=tf.constant(0,dtype=tf.int32) batch_len=tf.constant(10,dtype=tf.int32) loop_cond = lambda a,b: tf.less(a,batch_len) #yy=tf.Print(batch_len,[batch_len],"batch_len:") yy=tf.constant(0) loop_vars=[i,yy] def _recurrence(i,yy): c=tf.constant(2,dtype=tf.int32) x=tf.multiply(i,c) global xiaojie xiaojie=xiaojie+1 print_info=tf.Print(x,[x],"x:") yy=yy+print_info i=tf.add(i,1) # print (xiaojie) return i,yy i,yy=tf.while_loop(loop_cond,_recurrence,loop_vars,parallel_iterations=1)#可以批处理 sess = tf.Session() print (sess.run(i)) print (xiaojie)
输出的是10和2。
也就是xiaojie只被修改了一次。
这个时候,在_recurrence循环体中添加语句
print (xiaojie)
会输出2。而且只输出一次。具体为什么,最后总结的时候再解释。
代码(操纵类成员变量)class RNN_Model():
def __init__(self): self.xiaojie=1 def test_RNN(self): i=tf.constant(0,dtype=tf.int32) batch_len=tf.constant(10,dtype=tf.int32) loop_cond = lambda a,b: tf.less(a,batch_len) #yy=tf.Print(batch_len,[batch_len],"batch_len:") yy=tf.constant(0) loop_vars=[i,yy] def _recurrence(i,yy): c=tf.constant(2,dtype=tf.int32) x=tf.multiply(i,c) self.xiaojie=self.xiaojie+1 print_info=tf.Print(x,[x],"x:") yy=yy+print_info i=tf.add(i,1) print ("_recurrence:",self.xiaojie) return i,yy i,yy=tf.while_loop(loop_cond,_recurrence,loop_vars,parallel_iterations=1)#可以批处理 sess = tf.Session() sess.run(yy) print (self.xiaojie) if __name__ == "__main__": model = RNN_Model()#构建树,并且构建词典 model.test_RNN()
输出是:
_recurrence: 2 10 2
tf.while_loop操纵全局变量和类成员变量总结
为什么_recurrence中定义的print操作只执行一次呢,这是因为_recurrence中的print相当于一种对代码的定义,直接在定义的过程中就执行了。所以,可以看到输出是在sess.run之前的。但是,定义的其它操作就是数据流图中的操作,需要在sess.run中执行。
来源gaodaima#com搞(代@码网就必须在sess.run中执行。但是,全局变量xiaojie也好,还是类成员变量xiaojie也好。其都不是图中的内容。因此,tf.while_loop执行的是tensorflow计算图中的循环,对于不是在计算图中的,就不会参与循环。注意:而且必须是与loop_vars中指定的变量存在数据依赖关系的tensor才可以!此外,即使是依赖关系,也必须是_recurrence循环体中return出的变量,才会真正的变化。比如,见下面的self.L。总之,想操纵变量,就要传入loop_vars!
如果对一个变量没有修改,就可以直接在循环中以操纵类成员变量或者全局变量的方式只读。
self.L与loop_vars中变量有依赖关系,但是并没有真正被修改。
#IIII通过计算将非叶子节点的词向量也放入nodes_tensor中。 iiii=tf.constant(0,dtype=tf.int32) loop____cond = lambda a,b,c,d,e: tf.less(a,self.sentence_length-1)#iiii的范围是0到sl-2。注意,不包括sl-1。这是因为只需要计算sentence_length-1次,就能构建出一颗树 loop____vars=[iiii,columnLinesOfL,node_tensors_cost_tensor,nodes_tensor,tfPrint] def ____recurrence(iiii,columnLinesOfL,node_tensors_cost_tensor,nodes_tensor,tfPrint):#循环的目的是实现Greedy算法 ### #Greedy的主要目标就是确立树结构。 ### c1 = self.L[:,0:columnLinesOfL-1]#这段代码是从RvNN的matlab的源码中复制过来的,但是Matlab的下标是从1开始,并且Matlab中1:2就是1和2,而python中1:2表示的是1,不包括2,所以,有很大的不同。 c2 = self.L[:,1:columnLinesOfL] c=tf.concat([c1,c2],axis=0) p=tf.tanh(tf.matmul(self.W1,c)+tf.tile(self.b1,[1,columnLinesOfL-1])) p_normalization=self.normalization(p) y=tf.tanh(tf.matmul(self.U,p_normalization)+tf.tile(self.bs,[1,columnLinesOfL-1]))#根据Matlab中的源码来的,即重构后,也有一个激活的过程。 #将Y矩阵拆分成上下部分之后,再分别进行标准化。 columnlines_y=columnLinesOfL-1 (y1,y2)=self.split_by_row(y,columnlines_y) y1_normalization=self.normalization(y1) y2_normalization=self.normalization(y2) #论文中提出一种计算重构误差时要考虑的权重信息。具体见论文,这里暂时不实现。 #这个权重是可以修改的。 alpha_cat=1 bcat=1 #计算重构误差矩阵 ## constant1=tf.constant([[1.0,2.0,3.0],[4.0,5.0,6.0],[7.0,8.0,9.0]]) ## constant2=tf.constant([[1.0,2.0,3.0],[1.0,4.0,2.0],[1.0,6.0,1.0]]) ## constructionErrorMatrix=self.constructionError(constant1,constant2,alpha_cat,bcat) y1c1=tf.subtract(y1_normalization,c1) y2c2=tf.subtract(y2_normalization,c2) constructionErrorMatrix=self.constructionError(y1c1,y2c2,alpha_cat,bcat) ################################################################################ print_info=tf.Print(iiii,[iiii],"\niiii:")#专门为了调试用,输出相关信息。 tfPrint=print_info+tfPrint print_info=tf.Print(columnLinesOfL,[columnLinesOfL],"\nbefore modify. columnLinesOfL:")#专门为了调试用,输出相关信息。 tfPrint=print_info+tfPrint print_info=tf.Print(constructionErrorMatrix,[constructionErrorMatrix],"\nbefore modify. constructionErrorMatrix:",summ以上就是基于tensorflow for循环 while循环案例的详细内容,更多请关注gaodaima搞代码网其它相关文章!