from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
import socket
print("Running on computer: ", socket.gethostname())
# To support both python 2 and python 3
from __future__ import division, print_function, unicode_literals
# Common imports
import tensorflow as tf
import numpy as np
import os
# to make this notebook's output stable across runs
def reset_graph(seed=42):
    tf.reset_default_graph()
    tf.set_random_seed(seed)
    np.random.seed(seed)
from IPython.display import clear_output, Image, display, HTML
def strip_consts(graph_def, max_const_size=32):
    """Strip large constant values from graph_def."""
    strip_def = tf.GraphDef()
    for n0 in graph_def.node:
        n = strip_def.node.add()
        n.MergeFrom(n0)
        if n.op == 'Const':
            tensor = n.attr['value'].tensor
            size = len(tensor.tensor_content)
            if size > max_const_size:
                tensor.tensor_content = b"<stripped %d bytes>" % size
    return strip_def
def show_graph(graph_def, max_const_size=32):
    """Visualize TensorFlow graph."""
    if hasattr(graph_def, 'as_graph_def'):
        graph_def = graph_def.as_graph_def()
    strip_def = strip_consts(graph_def, max_const_size=max_const_size)
    code = """
        <script>
          function load() {{
            document.getElementById("{id}").pbtxt = {data};
          }}
        </script>
        <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
        <div style="height:600px">
          <tf-graph-basic id="{id}"></tf-graph-basic>
        </div>
    """.format(data=repr(str(strip_def)), id='graph'+str(np.random.rand()))
    iframe = """
        <iframe seamless style="width:800px;height:620px;border:1" srcdoc="{}"></iframe>
    """.format(code.replace('"', '&quot;'))  # escape quotes so the HTML survives inside srcdoc
    display(HTML(iframe))
How to compute?
$$(a+b)*(b+1)$$
The formula gives us instructions for how to compute it:
add $a$ and $b$ (let's call the intermediate result $c$)
add $b$ and $1$ (let's call the intermediate result $d$)
multiply $c$ and $d$ (let's call the result $e$)
To actually execute the instructions, we need to provide values for the placeholders $a$ and $b$.
The computation graph describes the schematics of the computation.
A node can be evaluated if its parents have values (if necessary, after evaluating the parents).
In Python, execution happens immediately when the commands occur.
reset_graph()
try:
    del a, b
except NameError:
    pass
#!/usr/bin/python
(a+b)*(b+1)
#!/usr/bin/python
a = 2 # provide values for placeholders 'a' and 'b'
b = 1
(a+b)*(b+1) # now it works
reset_graph()
del a,b
# load tensorflow
import tensorflow as tf
# build a computation graph
a = tf.placeholder(tf.float32, name="a")
b = tf.placeholder(tf.float32, name="b")
y = tf.constant(1., name="one")
c = tf.add(a, b, name="c")
d = tf.add(b, y, name="d")
e = tf.multiply(c, d, name="e")
print(a)
print(b)
print(y)
print(c)
print(d)
print(e)
show_graph(tf.get_default_graph())
Two abstraction layers, important to keep apart: 1. the tensorflow graph, 2. python objects (= variables)
reset_graph()
# tensorflow: build a computation graph
a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
c = a+b # convenient shorthand notation.
d = b+1 # note: operations on graph nodes almost
e = c*d # always create new nodes in the graph
print(a) # a is a python variable, its value is a node in tensorflow's graph
print(b) # b is a python variable, its value is a node in tensorflow's graph
print(c) # c is a python variable, etc.
print(d)
print(e)
show_graph(tf.get_default_graph())
reset_graph()
del a,b,c,d,e
# tensorflow: build a computation graph
a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
e = (a+b)*(b+1) # creates multiple nodes in the graph, one per operation
print(a) # a is a python variable, its value is a node in tensorflow's graph
print(b) # b is a python variable, its value is a node in tensorflow's graph
# intermediate results have graph nodes, but no python variable pointing to them
print(e)
show_graph(tf.get_default_graph())
reset_graph()
# graph names can be arbitrary, independent of python
a = tf.placeholder(tf.float32, name="spam")
b = tf.placeholder(tf.float32, name="ham")
e = tf.multiply(a+b,b+1, name="eggs") # creates multiple nodes, one per operation
print(a)
print(b)
print(e)
show_graph(tf.get_default_graph())
reset_graph()
del a,b,e
# python names can be arbitrary, independent of graph
spam = tf.placeholder(tf.float32)
ham = tf.placeholder(tf.float32)
eggs = (spam+ham)*(ham+1)
print(spam)
print(ham)
print(eggs)
show_graph(tf.get_default_graph())
reset_graph()
del spam,ham,eggs
# tensorflow
import tensorflow as tf
a = tf.placeholder(tf.float32, name="a") # duplicate node names are
b = tf.placeholder(tf.float32, name="a") # automatically de-duplicated
e = tf.multiply(a+b,b+1, name="a")
print(a)
print(b)
print(e)
show_graph(tf.get_default_graph())
reset_graph()
del a,b,e
# error: after reset_graph() and del, the python name 'e' no longer exists
sess = tf.Session() # open a Session
value_of_e = sess.run(e) # evaluate node 'e'
print("value of e is ", value_of_e)
sess.close() # release resources after use
# build graph
a = tf.placeholder(tf.float32, name="a")
b = tf.placeholder(tf.float32, name="b")
c = a+b
d = b+1
e = c*d
# now evaluate
sess = tf.Session() # open a Session
value_of_e = sess.run(e) # evaluate node 'e'
print("value of e is ", value_of_e)
sess.close() # release resources after use
InvalidArgumentError Traceback (most recent call last)
<ipython-input-96-b2928f055672> in <module>()
7 sess = tf.Session() # open a Session
8
----> 9 value_of_e = sess.run(e) # evaluate node 'e'
10 print("value of e is ", value_of_e)
11
...
1338 except KeyError:
1339 pass
-> 1340 raise type(e)(node_def, op, message)
1341
1342 def _extend_graph(self):
InvalidArgumentError: You must feed a value for placeholder tensor 'b' with dtype float
[[Node: b = Placeholder[dtype=DT_FLOAT, shape=<unknown>, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
To evaluate a node, its parents must have values, or must themselves be evaluable.
...same graph as before...
sess = tf.Session() # open a Session
value_of_e = sess.run(e, feed_dict={a: 2, b: 1}) # evaluate node 'e' for some a/b
print("e evaluated for a=2 and b=1 is ", value_of_e)
sess.close() # be a good citizen, release resources after use
sess = tf.Session()
value_of_e = sess.run(e, feed_dict={a: 3, b: 1})
print("e evaluated for a=3 and b=1 is ", value_of_e)
sess.close()
sess = tf.Session()
value = sess.run(e, feed_dict={a: 2, b: 1, c: -1}) # we can specify values for other
print("e evaluated for c=-1 is ", value) # nodes, not just placeholders
value = sess.run(e, feed_dict={b: 1, c: -1}) # if c has a value, a isn't needed
print("e evaluated for c=-1 is ", value)
sess.close()
sess = tf.Session()
othervalue = sess.run(e, feed_dict={c: -1}) # only c is now enough to evaluate e
print("e evaluated for c=-1 is ", othervalue)
sess.close()
# 'with' construct closes sess automatically after use
with tf.Session() as sess:
    value = sess.run(e, feed_dict={a: 2, b: 1})
    print("e=", value)
reset_graph()
del a,b,c,d,e
# what does this code do?
a = tf.placeholder(tf.float32, name="a")
b = a+1
b = tf.log(b, name="b") # we can re-use python variables
b = b+1
print(a)
print(b)
with tf.Session() as sess:
    value = sess.run(b, feed_dict={a: 2.})
    print("value=", value)
reset_graph()
# what does this code do?
a = tf.placeholder(tf.float32, name="a")
a = a+1
a = tf.log(a, name="b") # we can re-use python variables
a = a+1
with tf.Session() as sess:
    value = sess.run(a, feed_dict={a: 2.})
    print("value=", value)
print(a)
reset_graph()
del a
Calling a python function will execute it immediately. If a function operates on tensors, this typically means it adds nodes to the graph.
def f(x):
    return x + 1
a = tf.placeholder(tf.float32, name="a")
b = f(a)
b = tf.log(b)
b = f(b)
with tf.Session() as sess:
    value = sess.run(b, feed_dict={a: 2.})
    print("value=", value)
reset_graph()
del a,b
reset_graph()
a = tf.constant(2, name="a")
b = tf.constant(1, name="b")
c = a+b
d = b+1
e = c*d
with tf.Session() as sess:
    value_of_c = sess.run(c)
    print("c=", value_of_c)
    value_of_d = sess.run(d)
    print("d=", value_of_d)
    value_of_e = sess.run(e) # recomputes c and d
    print("e=", value_of_e)
with tf.Session() as sess:
    value_of_c, value_of_d, value_of_e = sess.run([c, d, e])
    # c and d are evaluated only once
    print("c=", value_of_c)
    print("d=", value_of_d)
    print("e=", value_of_e)
# get tensorflow graph as python object
g=tf.get_default_graph()
print( g.get_operations() ) # quick check what's inside
# one cannot delete individual nodes from the graph
# but one can delete all, starting again with an empty graph
tf.reset_default_graph()
g=tf.get_default_graph()
print( g.get_operations() ) # nothing inside
reset_graph()
a = tf.constant(5)
b = tf.constant(2)
with tf.Session() as sess:
    for i in range(10):
        value = sess.run(a+b)
        print("i=", i, "i*value=", i*value)
# each construct 'a+b' creates a corresponding node in the graph
show_graph(tf.get_default_graph())
reset_graph()
a = tf.constant(5)
b = tf.constant(2)
g = tf.get_default_graph()
g.finalize() # allow no more changes to the graph
with tf.Session() as sess:
    for i in range(10):
        value = sess.run(a+b) # error: would create a new node in the finalized graph
        print("i=", i, "i*value=", i*value)
reset_graph()
a = tf.constant(5)
b = tf.constant(2)
a_plus_b = a+b
with tf.Session() as sess:
    for i in range(10):
        value = sess.run(a_plus_b)
        print("i=", i, "i*value=", i*value)
show_graph(tf.get_default_graph())
reset_graph()
a = tf.Variable(0, name="var") # new variable, to be initialized with value 0
b = a+1
inc = tf.assign(a, b, name="assignment") # shorthand: a.assign(b)
# this creates a node called 'assignment'
# evaluating 'assignment' causes the value of 'b' to be assigned to 'a'
# 'inc' is a python variable that points to the assignment node
print(a)
print(b)
print(inc)
print()
with tf.Session() as sess:
    sess.run(a.initializer) # initialize variable 'a'
    for i in range(5):
        value = sess.run(inc)
        print("step ", i, "value=", value)
show_graph(tf.get_default_graph())
reset_graph()
a = tf.Variable(0, name="a") # new variable, to be initialized with value 0
b = tf.add(a, 1) # value of 'a' plus 1
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(a.initializer) # initialize variable 'a'
    for i in range(5):
        value = sess.run(b) # evaluating 'b' does not change 'a'
        print("step ", i, "value=", value)
reset_graph()
a = tf.Variable(0, name="a") # new variable
with tf.Session() as sess:
    sess.run(a.initializer) # assign value 0 to variable 'a'
    for i in range(5):
        a = a+1 # creates a new node in every step: a+1, a+1+1, ...
        value = sess.run(a)
        print("step ", i, "value=", value)
# in a real program, this would run out of memory quickly!
show_graph(tf.get_default_graph())
reset_graph()
a = tf.Variable(0, name="a")
inc = a.assign(a+1) # convenient shorthand for tf.assign(a,...)
with tf.Session() as sess:
    sess.run(a.initializer) # initialize variable 'a'
    for i in range(5):
        value = sess.run(inc)
        print("step ", i, "value=", value)
reset_graph()
# with a numeric constant, as before
a = tf.Variable(1, name="a")
print(a)
# with a numpy object
random_numpy = np.random.randn()
b = tf.Variable(random_numpy, name="b")
print(b)
# with a (here generated) tensorflow constant
random_tensor = tf.random_normal([], mean=1, stddev=0.5)
c = tf.Variable(random_tensor, name="c")
print(c)
# with another tensorflow variable
d = tf.Variable(a, name="d")
print(d)
# every call creates a new variable (accessing old ones: later)
e = tf.Variable(1, name="a")
print(e)
init = tf.global_variables_initializer() # creates a new node in the graph
with tf.Session() as sess:
    sess.run(init)
    res_a, res_b, res_c, res_d, res_e = sess.run([a, b, c, d, e])
    print("a=", res_a)
    print("b=", res_b)
    print("c=", res_c)
    print("d=", res_d)
    print("e=", res_e)
reset_graph()
a = tf.placeholder(tf.float32,[5,3,2]) # 3-dim tensor of size 5x3x2
print(a)
# with a list of numbers -> vector
digits = tf.constant([3,1,4]) # constant 1-dim tensor with entries [3,1,4]
print(digits)
# initialize a variable with a numpy object
random_numpy = np.random.randn(3,3,1,3) # 4-dim tensor of size 3x3x1x3
c = tf.Variable(random_numpy)
print(c)
# most functions operate componentwise on tensors
d = tf.exp(c+1)
print(d)
# special functions change the shape, e.g. "reductions"
e = tf.reduce_max(d) # find maximum of all entries
f = tf.reduce_sum(d, axis=0) # sum along first axis
print(e)
print(f)
# e.g. get a tensor's shape
s = c.shape # result is a python object: a TensorShape of length 4
print(s)
# e.g. get a tensor's shape
s = tf.shape(c) # result is a node in the graph that holds a vector of length 4
print(s)
# e.g. get a tensor's shape
s = tf.shape(c) # result is a node in the graph that holds a vector of length 4
with tf.Session() as sess:
    val = sess.run(s) # evaluate
    print(val)
reset_graph()
# to support both python 2 and python 3
try:
    from urllib.request import urlopen # python 3
except ImportError:
    from urllib2 import urlopen # python 2
# let's load some data in numpy
fid = urlopen("https://cvml.ist.ac.at/courses/DLWT_W17/data/Xtrain.txt")
Xdata = np.loadtxt(fid)
fid = urlopen("https://cvml.ist.ac.at/courses/DLWT_W17/data/Ytrain.txt")
Ydata = np.loadtxt(fid).reshape(-1, 1)
print("data shape = ", Xdata.shape)
print("labels shape = ", Ydata.shape)
Xtrn,Ytrn = Xdata[::2],Ydata[::2] # half of the points for training
ntrn,dim = Xtrn.shape
Xval,Yval = Xdata[1::2],Ydata[1::2] # rest for model evaluation
$$\text{inputs: } X=\begin{pmatrix}x_1\\x_2\\\vdots\\x_n\end{pmatrix}\in\mathbb{R}^{n\times d} \qquad \text{outputs: } Y=\begin{pmatrix}y_1\\y_2\\\vdots\\y_n\end{pmatrix}\in\mathbb{R}^{n\times 1}$$
$$w^\ast=\operatorname{argmin}_w \frac{1}{n}\sum_{i=1}^n (w^\top x_i - y_i)^2$$
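Setting the gradient of this objective to zero yields the normal equations, which the numpy code below solves directly:
$$X^\top X\, w^\ast = X^\top Y \qquad\Longrightarrow\qquad w^\ast=(X^\top X)^{-1} X^\top Y$$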
# numpy solution
XtX = np.dot(Xtrn.T, Xtrn)
XtY = np.dot(Xtrn.T, Ytrn)
w = np.dot(np.linalg.inv(XtX), XtY)
print("w[:5]=", w[:5].T)
$$\text{loss of any $w$:}\qquad L(w) = \frac{1}{n}\sum_{i=1}^n (w^\top x_i - y_i)^2$$
pred = np.dot(Xtrn, w)
loss = np.mean(np.square(pred-Ytrn))
print("loss on training data =", loss)
pred_new = np.dot(Xval, w)
loss_new = np.mean(np.square(pred_new-Yval))
print("loss on new data=", loss_new)
reset_graph()
The same in tensorflow:
# define graph
X = tf.constant(Xtrn, name="Xtrn")
Y = tf.constant(Ytrn, name="Ytrn")
Xt = tf.transpose(X, name="Xtranspose")
XtX = tf.matmul(Xt, X, name="XtX")
XtY = tf.matmul(Xt, Y, name="XtY")
w = tf.matmul(tf.matrix_inverse(XtX), XtY, name="w")
print("w:", w)
pred = tf.matmul(X, w, name="pred")
loss = tf.reduce_mean(tf.square(pred-Y), name="loss") # mean of tensor elements
# evaluate graph
with tf.Session() as sess:
    w_value, loss_value = sess.run([w, loss])
    print("w[:5]=", w_value[:5].T)
    print("loss on training data =", loss_value)
show_graph(tf.get_default_graph())
# additional graph nodes
Xnew = tf.constant(Xval, name="Xnew")
Ynew = tf.constant(Yval, name="Ynew")
prednew = tf.matmul(Xnew, w, name="prednew")
lossnew = tf.reduce_mean(tf.square(prednew-Ynew), name="lossnew")
with tf.Session() as sess:
    loss_val = sess.run(lossnew)
    print("loss on new data=", loss_val)
reset_graph()
X = tf.placeholder(dtype=tf.float64, name="Xtrn")
Y = tf.placeholder(dtype=tf.float64, name="Ytrn")
Xt = tf.transpose(X, name="Xtranspose")
XtX = tf.matmul(Xt, X, name="XtX")
XtY = tf.matmul(Xt, Y, name="XtY")
w = tf.matmul(tf.matrix_inverse(XtX), XtY, name="w")
pred = tf.matmul(X, w, name="pred")
loss = tf.reduce_mean(tf.square(pred-Y),name="loss")
with tf.Session() as sess:
    w_value = sess.run(w, feed_dict={X: Xtrn, Y: Ytrn})
    loss_train = sess.run(loss, feed_dict={X: Xtrn, Y: Ytrn})
    print("w[:5]=", w_value[:5].T)
    print("loss on training data =", loss_train)
    loss_val = sess.run(loss, feed_dict={X: Xval, Y: Yval})
    print("loss on validation data =", loss_val)
Wrong! Evaluating 'loss' on Xval recomputes $w$ from the new X/Y data, because $w$ is itself just a graph node that depends on the placeholders!
reset_graph()
Good use of a variable: keep a parameter between evaluations
X = tf.placeholder(dtype=tf.float64, name="X")
Y = tf.placeholder(dtype=tf.float64, name="Y")
w = tf.Variable(tf.zeros(shape=(dim,1), dtype=tf.float64), name="w")
Xt = tf.transpose(X, name="Xtranspose")
XtX = tf.matmul(Xt, X, name="XtX")
XtY = tf.matmul(Xt, Y, name="XtY")
compute_w = w.assign( tf.matmul(tf.matrix_inverse(XtX), XtY, name="w") )
pred = tf.matmul(X, w, name="pred")
loss = tf.reduce_mean(tf.square(pred-Y), name="loss")
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    w_value = sess.run(compute_w, feed_dict={X: Xtrn, Y: Ytrn})
    print("w[:5]=", w_value[:5].T)
    loss_train = sess.run(loss, feed_dict={X: Xtrn, Y: Ytrn})
    print("loss on training data =", loss_train)
    loss_val = sess.run(loss, feed_dict={X: Xval, Y: Yval})
    print("loss on validation data =", loss_val)
show_graph(tf.get_default_graph())
reset_graph()
nsteps = 10000
eta = 0.005
X = tf.placeholder(dtype=tf.float32, name="X")
Y = tf.placeholder(dtype=tf.float32, name="Y")
w = tf.Variable(tf.zeros(shape=(dim,1), dtype=tf.float32), name="w")
pred = tf.matmul(X, w, name="pred")
residual = pred-Y
loss = tf.reduce_mean(tf.square(residual), name="loss")
Xt = tf.transpose(X, name="Xtranspose")
gradient_of_loss = 2./ntrn * tf.matmul(Xt, residual) # derived on paper!
update_w = w.assign( w - eta*gradient_of_loss )
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init) # init variables
    for i in range(nsteps):
        sess.run(update_w, feed_dict={X: Xtrn, Y: Ytrn}) # execute update to w
        if i % 1000 == 0:
            loss_train = sess.run(loss, feed_dict={X: Xtrn, Y: Ytrn})
            loss_val = sess.run(loss, feed_dict={X: Xval, Y: Yval})
            print("step", i, "loss(train)", loss_train, "loss(val)", loss_val)
reset_graph()
nsteps = 10000
eta = 0.005
X = tf.placeholder(dtype=tf.float32, name="X")
Y = tf.placeholder(dtype=tf.float32, name="Y")
w = tf.Variable(tf.zeros(shape=(dim,1), dtype=tf.float32), name="w")
pred = tf.matmul(X, w, name="pred")
residual = pred-Y
loss = tf.reduce_mean(tf.square(residual), name="loss")
gradient_of_loss = tf.gradients(loss, [w])[0] # <-- automatic differentiation
update_w = w.assign( w - eta*gradient_of_loss )
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init) # init variables
    for i in range(nsteps):
        sess.run(update_w, feed_dict={X: Xtrn, Y: Ytrn}) # execute update to w
        if i % 1000 == 0:
            loss_train = sess.run(loss, feed_dict={X: Xtrn, Y: Ytrn})
            loss_val = sess.run(loss, feed_dict={X: Xval, Y: Yval})
            print("step", i, "loss(train)", loss_train, "loss(val)", loss_val)
show_graph(tf.get_default_graph())
# numpy
def my_func(a, b):
    z = 0
    for i in range(10):
        z = a * np.cos(z + i) + z * np.sin(b - i)
    return z
# we can evaluate the function for any a,b
print(my_func(0.2, 0.3))
reset_graph()
# In tensorflow, we can not only evaluate the function,
# but also its partial derivatives, with little overhead
a = tf.Variable(0.2, name="a")
b = tf.Variable(0.3, name="b")
z = tf.constant(0.0, name="z")
for i in range(10):
    z = a * tf.cos(z + i) + z * tf.sin(b - i) # this builds a complex graph
grads = tf.gradients(z, [a, b]) # gradient of z with respect to a and b
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    val = sess.run(z)
    print("val=", val)
    g = sess.run(grads)
    print("gradient=", g)
show_graph(tf.get_default_graph())
Chain rule: The derivative of a node w.r.t. any node is the sum over all paths of the products along the paths! How to compute this efficiently?
Forward mode: compute derivative of all nodes w.r.t. a fixed one (here $b$)
Backward mode: compute derivative of a fixed node (here $e$) w.r.t. all nodes
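A quick sanity check of the path-sum rule (a minimal sketch, not part of the original graph code): in the example graph $e=(a+b)\cdot(b+1)$ there are two paths from $b$ to $e$, contributing $d$ (via $c$) and $c$ (via $d$), so $\partial e/\partial b = c + d$; backward-mode tf.gradients should agree.
reset_graph()
a = tf.placeholder(tf.float32, name="a")
b = tf.placeholder(tf.float32, name="b")
c = a + b
d = b + 1
e = c * d
grad_b = tf.gradients(e, [b])[0] # backward mode: derivative of 'e' w.r.t. 'b'
with tf.Session() as sess:
    val_c, val_d, val_grad = sess.run([c, d, grad_b], feed_dict={a: 2., b: 1.})
    print("path-sum prediction c+d =", val_c + val_d) # 3 + 2 = 5
    print("tf.gradients =", val_grad) # also 5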
practice building and running tensorflow graphs
implement Logistic Regression, i.e. learn $w$ by minimizing the logistic loss $$L(w) = \frac{1}{n}\sum_{i=1}^n \log\big(1+\exp(-y_i w^\top x_i)\big)$$ using a) fixed data, and b) data being handed in via placeholders (see the sketch after this list for a possible starting point)
try different learning rates to find one that converges faster than $\eta=0.001$
(optional) create a version with analytically computed gradients, and compare its speed
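A possible starting point for the exercise (a sketch, assuming the labels $y_i$ take values in $\{-1,+1\}$ and are stored, like Ytrn, as a column vector; not the official solution):
reset_graph()
X = tf.placeholder(dtype=tf.float32, name="X")
Y = tf.placeholder(dtype=tf.float32, name="Y") # labels in {-1,+1}, shape (n,1)
w = tf.Variable(tf.zeros(shape=(dim, 1), dtype=tf.float32), name="w")
margin = Y * tf.matmul(X, w) # componentwise y_i * w^T x_i
loss = tf.reduce_mean(tf.log(1 + tf.exp(-margin)), name="logistic_loss")
eta = 0.001
update_w = w.assign(w - eta * tf.gradients(loss, [w])[0])
Note: tf.nn.softplus(-margin) is a numerically safer way to write $\log(1+\exp(-\text{margin}))$ when margins get large.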
reset_graph()
# Creates a graph.
a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
sess = tf.Session()
# Runs the op.
options = tf.RunOptions(output_partition_graphs=True)
metadata = tf.RunMetadata()
a_val = sess.run(a, options=options, run_metadata=metadata)
# Show **where** the computation happened
print(metadata.partition_graphs)
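Another way to see device placement (a small sketch, using the standard log_device_placement session option instead of RunMetadata):
sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
a_val = sess.run(a) # the placement of every op is printed to the console log
sess.close()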
In [4]: def f():
...: A = np.random.randn(10000,10000).astype(np.float32)
...: Asquared = np.dot(A,A)
...:
...: %time f()
...:
CPU times: user 32.6 s, sys: 1.15 s, total: 33.7 s
Wall time: 33.1 s
In [5]: def f():
...: A = np.random.randn(10000,10000).astype(np.float64)
...: Asquared = np.dot(A,A)
...:
...: %time f()
...:
CPU times: user 1min 8s, sys: 1.68 s, total: 1min 9s
Wall time: 1min 8s
In [6]: with tf.device("/cpu:0"):
...: A = tf.random_normal([10000,10000], dtype=tf.float32)
...: Asquared = tf.matmul(A,A, name="A2")
...:
...: with tf.Session() as sess:
...: %time res = sess.run(Asquared)
...:
2017-11-25 12:24:04.899238: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1120]...
CPU times: user 3min 3s, sys: 1.24 s, total: 3min 4s
Wall time: 3.45 s
In [7]: with tf.device("/cpu:0"):
...: A = tf.random_normal([10000,10000], dtype=tf.float64)
...: Asquared = tf.matmul(A,A, name="A2")
...:
...: with tf.Session() as sess:
...: %time res = sess.run(Asquared)
...:
2017-11-25 12:24:26.075294: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1120]...
CPU times: user 6min 10s, sys: 2.54 s, total: 6min 12s
Wall time: 6.91 s
In [8]: with tf.device("/gpu:0"):
...: A = tf.random_normal([10000,10000], dtype=tf.float32)
...: Asquared = tf.matmul(A,A, name="A2")
...:
...: with tf.Session() as sess:
...: %time res = sess.run(Asquared)
...:
2017-11-25 12:26:26.685615: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1120]...
CPU times: user 236 ms, sys: 140 ms, total: 376 ms
Wall time: 358 ms
In [9]: with tf.device("/gpu:0"):
...: A = tf.random_normal([10000,10000], dtype=tf.float64)
...: Asquared = tf.matmul(A,A, name="A2")
...:
...: with tf.Session() as sess:
...: %time res = sess.run(Asquared)
...:
2017-11-25 12:26:43.021277: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1120]...
CPU times: user 72 ms, sys: 568 ms, total: 640 ms
Wall time: 634 ms