I have a loop whose body runs about 200 times. Each iteration does a sophisticated calculation and then, for debugging, I want to produce a heatmap of an NxM matrix. But generating this heatmap is unbearably slow and significantly slows down an already slow algorithm.

My code is along these lines:

import numpy
import matplotlib.pyplot as plt
for i in range(200):
    matrix = complex_calculation()
    plt.set_cmap("gray")
    plt.imshow(matrix)
    plt.savefig("frame{0}.png".format(i))

The matrix, a numpy array, is not huge: 300 x 600 doubles. If I skip saving the figure and update an on-screen plot instead, it is even slower.

Surely I must be abusing pyplot. (Matlab can do this, no problem.) How do I speed this up?

+2  A: 

Try putting plt.clf() in the loop to clear the current figure:

for i in range(200):
    matrix = complex_calculation()
    plt.set_cmap("gray")
    plt.imshow(matrix)
    plt.savefig("frame{0}.png".format(i))
    plt.clf()

If you don't do this, every plt.imshow call adds another image to the same axes, so each savefig has to redraw all the earlier frames, and the loop slows down as the figure takes up more and more memory.
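To see the buildup for yourself, here is a small sketch that counts the images attached to the axes with and without clearing. The helper name count_images, the 30 x 60 random matrices, and the Agg backend are all just assumptions for illustration, not part of the original code.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np


def count_images(clear_each_time, n_frames=5):
    """Draw n_frames matrices; return how many images the axes hold at the end."""
    fig = plt.figure()
    for _ in range(n_frames):
        if clear_each_time:
            fig.clf()  # drop everything drawn so far
        ax = fig.gca()  # reuses the existing axes, or creates one after clf()
        ax.imshow(np.random.rand(30, 60), cmap="gray")
    n_images = len(fig.gca().images)
    plt.close(fig)  # release the figure once we are done with it
    return n_images


without_clear = count_images(clear_each_time=False)
with_clear = count_images(clear_each_time=True)
print("images without clearing:", without_clear)
print("images with clearing:", with_clear)
```

Without clearing, the count grows by one per iteration; with clearing, it stays at one.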

unutbu
Still slow, but at least it's bearable now.
carl
+2  A: 

I think this is a bit faster:

import matplotlib.pyplot as plt
from matplotlib import cm
fig = plt.figure()
ax = fig.add_axes([0.1,0.1,0.8,0.8])
for i in range(200):
    matrix = complex_calculation()
    ax.imshow(matrix, cmap=cm.gray)
    fig.savefig("frame{0}.png".format(i))

plt.imshow calls gca, which calls gcf, which checks whether a figure exists and creates one if not. By instantiating the figure yourself first, you skip all of that.
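Note that the loop above still adds a new image to the axes on every iteration. A further variation, not from the original answer, is to create the image once and then swap its data with set_data. A minimal sketch, assuming a stand-in complex_calculation that returns random data, and saving to in-memory buffers so it runs anywhere (pass a filename such as "frame{0}.png".format(i) to fig.savefig to write to disk instead):

```python
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)


def complex_calculation():
    # Hypothetical stand-in for the real computation: a 300 x 600 matrix of doubles
    return rng.random((300, 600))


fig, ax = plt.subplots()
image = None
frames = []
for i in range(3):  # the question uses 200 iterations
    matrix = complex_calculation()
    if image is None:
        # First frame: create the image artist once
        image = ax.imshow(matrix, cmap="gray")
    else:
        # Later frames: reuse the artist and only replace its data
        image.set_data(matrix)
        # Rescale the color limits, since set_data keeps the old ones
        image.set_clim(matrix.min(), matrix.max())
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    frames.append(buf.getvalue())
plt.close(fig)
```

Whether this is noticeably faster will depend on your setup, but the axes only ever hold one image, so there is nothing to clear between frames.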

Steve