本文整理汇总了Python中pylab.pyplot函数的典型用法代码示例。如果您正苦于以下问题：Python pyplot函数的具体用法？Python pyplot怎么用？Python pyplot使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。（注：下列示例中调用的 pyplot(...) 通常是 pylab.plot 的别名，例如示例11中的 `from pylab import plot as pyplot`。）
在下文中一共展示了pyplot函数的17个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: geweke_plot
def geweke_plot(data, name, format='png', suffix='-diagnostic', path='./', fontmap=None,
                verbose=1):
    """Generate Geweke (1992) diagnostic plots and save them to disk.

    :Arguments:
      data: array-like
        Sequence of (first iteration, z-score) pairs; it is transposed
        into the x and y coordinates of the scatter plot.
      name: string
        Used in the y-axis label and in the output file name.
      format (optional): string
        Graphic output format (defaults to png).
      suffix (optional): string
        Filename suffix (defaults to "-diagnostic").
      path (optional): string
        Directory for the saved plot; created (with any missing parents)
        if it does not exist.
      fontmap (optional): dict
        Accepted for interface compatibility; not used by this function.
      verbose (optional): int
        Accepted for interface compatibility; not used by this function.

    The plot is saved as <path><name><suffix>.<format>.
    """
    # Generate new scatter plot
    figure()
    x, y = transpose(data)
    scatter(x.tolist(), y.tolist())
    # Plot options
    xlabel('First iteration', fontsize='x-small')
    ylabel('Z-score for %s' % name, fontsize='x-small')
    # Reference lines at +/- 2 sd from zero
    pyplot((nmin(x), nmax(x)), (2, 2), '--')
    pyplot((nmin(x), nmax(x)), (-2, -2), '--')
    # Keep the +/- 2 band visible, widening the range if the data exceed it
    ylim(min(-2.5, nmin(y)), max(2.5, nmax(y)))
    xlim(0, nmax(x))
    # Save to file. os.makedirs also creates missing parent directories,
    # where os.mkdir would raise for a nested path such as "out/plots".
    if not os.path.exists(path):
        os.makedirs(path)
    if not path.endswith('/'):
        path += '/'
    savefig("%s%s%s.%s" % (path, name, suffix, format))
开发者ID:CosmologyTaskForce,项目名称:pymc,代码行数:29,代码来源:Matplot.py
示例2: display
def display(self, xaxis, alpha, new=True):
    """Draw the credible-interval envelope on the current figure.

    Usage: E.display(xaxis, alpha = .8)

    :Arguments:
      xaxis: axis values. In 1-D they are the x coordinates; otherwise
        xaxis[0] and xaxis[1] are passed to contourf.
      alpha: opacity of the 1-D fill or line.
      new: when true, open a new figure before drawing.

    :Note: In 1-D a positive self.mass fills the region between self.lo and
      self.hi using self.mass as the gray level; otherwise self.value is
      drawn as a black line labelled 'median'.
    """
    if new:
        figure()

    if self.ndim != 1:
        # Contour plots: either the single value surface, or lo/hi side by side.
        if not self.mass > 0.:
            contourf(xaxis[0], xaxis[1], self.value, cmap=cm.bone)
            colorbar()
            return
        for panel, bound in ((1, self.lo), (2, self.hi)):
            subplot(1, 2, panel)
            contourf(xaxis[0], xaxis[1], bound, cmap=cm.bone)
            colorbar()
        return

    if not self.mass > 0.:
        pyplot(xaxis, self.value, 'k-', alpha=alpha, label='median')
        return

    # Walk the lower bound forwards and the upper bound backwards so the
    # polygon closes around the interval.
    outline_x = concatenate((xaxis, xaxis[::-1]))
    outline_y = concatenate((self.lo, self.hi[::-1]))
    fill(outline_x, outline_y, facecolor='%f' % self.mass, alpha=alpha,
         label='centered CI ' + str(self.mass))
开发者ID:CosmologyTaskForce,项目名称:pymc,代码行数:32,代码来源:Matplot.py
示例3: trace
def trace(data, name, format='png', datarange=(None, None), suffix='', path='./', rows=1, columns=1,
          num=1, last=True, fontmap=None, verbose=1):
    """
    Generates trace plot from an array of data.

    :Arguments:
        data: array or list
            Usually a trace from an MCMC sample.
        name: string
            The name of the trace.
        datarange: tuple or list
            Preferred y-range of trace (defaults to (None,None)).
        format (optional): string
            Graphic output format (defaults to png).
        suffix (optional): string
            Filename suffix.
        path (optional): string
            Specifies location for saving plots (defaults to local directory).
        rows, columns, num (optional): int
            Subplot position. When all three are 1 the plot is stand-alone:
            a new figure is opened and saved to <path><name><suffix>.<format>.
        fontmap (optional): dict
            Maps a key derived from `rows` to the tick-label font size.
        verbose (optional): int
            Print a progress message for stand-alone plots when > 0.
    """
    if fontmap is None:
        fontmap = {1: 10, 2: 8, 3: 6, 4: 5, 5: 4}
    # Stand-alone plot or subplot?
    standalone = rows == 1 and columns == 1 and num == 1
    if standalone:
        if verbose > 0:
            print_('Plotting', name)
        figure()
    subplot(rows, columns, num)
    pyplot(data.tolist())
    ylim(datarange)
    # Plot options
    title('\n\n %s trace' % name, x=0., y=1., ha='left', va='top', fontsize='small')
    # Smaller tick labels. The key was rows / 2, which is 0 (Py2) or 0.5
    # (Py3) for a single row -- a KeyError for every stand-alone plot --
    # and a non-integer key for odd row counts under Py3.
    tick_size = fontmap[max(rows // 2, 1)]
    setp(gca().get_xticklabels(), 'fontsize', tick_size)
    setp(gca().get_yticklabels(), 'fontsize', tick_size)
    if standalone:
        # makedirs also creates missing parent directories
        if not os.path.exists(path):
            os.makedirs(path)
        if not path.endswith('/'):
            path += '/'
        # Save to file
        savefig("%s%s%s.%s" % (path, name, suffix, format))
开发者ID:CosmologyTaskForce,项目名称:pymc,代码行数:60,代码来源:Matplot.py
示例4: trace
def trace(data, name, format='png', datarange=(None, None), suffix='', path='./', rows=1, columns=1, num=1, last=True, fontmap=None, verbose=1):
    """Generate a trace plot from an array of data.

    When rows, columns and num are all 1 the plot is stand-alone: a new
    figure is opened and the result is saved to <path><name><suffix>.<format>.
    Otherwise the trace is drawn into subplot (rows, columns, num) of the
    current figure and nothing is saved.

    :Arguments:
      data: array
        Values to plot (converted with data.tolist()).
      name: string
        Used as the y-axis label and in the output file name.
      datarange (optional): tuple
        Passed to ylim (defaults to (None, None)).
      last (optional): bool
        When true, label the axes ('Iteration' and name).
      fontmap (optional): dict
        Maps `rows` to the tick-label font size; defaults to
        {1: 10, 2: 8, 3: 6, 4: 5, 5: 4}.
      verbose (optional): int
        Print a progress message for stand-alone plots when > 0.
    """
    import sys
    # Fresh dict per call instead of a shared mutable default argument
    if fontmap is None:
        fontmap = {1: 10, 2: 8, 3: 6, 4: 5, 5: 4}
    # Stand-alone plot or subplot?
    standalone = rows == 1 and columns == 1 and num == 1
    if standalone:
        if verbose > 0:
            # Same output as the Py2 statement `print 'Plotting', name`,
            # but valid under both Python 2 and Python 3.
            sys.stdout.write('Plotting %s\n' % (name,))
        figure()
    subplot(rows, columns, num)
    pyplot(data.tolist())
    ylim(datarange)
    # Plot options
    if last:
        xlabel('Iteration', fontsize='x-small')
        ylabel(name, fontsize='x-small')
    # Smaller tick labels
    tick_size = fontmap[rows]
    setp(gca().get_xticklabels(), 'fontsize', tick_size)
    setp(gca().get_yticklabels(), 'fontsize', tick_size)
    if standalone:
        # makedirs also creates missing parent directories
        if not os.path.exists(path):
            os.makedirs(path)
        if not path.endswith('/'):
            path += '/'
        # Save to file
        savefig("%s%s%s.%s" % (path, name, suffix, format))
开发者ID:along1x,项目名称:pymc,代码行数:34,代码来源:Matplot.py
示例5: geweke_plot
def geweke_plot(data,
                name,
                format='png',
                suffix='-diagnostic',
                path='./',
                fontmap=None):
    '''
    Generate Geweke (1992) diagnostic plots.

    :Arguments:
        data: list
            List (or list of lists for vector-valued variables) of Geweke diagnostics, output
            from the `pymc.diagnostics.geweke` function. It is transposed into
            (first iteration, z-score) coordinates.
        name: string
            The name of the plot.
        format (optional): string
            Graphic output format (defaults to png).
        suffix (optional): string
            Filename suffix (defaults to "-diagnostic").
        path (optional): string
            Specifies location for saving plots (defaults to local directory);
            created, including missing parents, if it does not exist.
        fontmap (optional): dict
            Accepted for interface compatibility; not used by this function.
    '''
    # Generate new scatter plot
    figure()
    x, y = transpose(data)
    scatter(x.tolist(), y.tolist())
    # Plot options
    xlabel('First iteration', fontsize='x-small')
    ylabel('Z-score for %s' % name, fontsize='x-small')
    # Reference lines at +/- 2 sd from zero
    pyplot((nmin(x), nmax(x)), (2, 2), '--')
    pyplot((nmin(x), nmax(x)), (-2, -2), '--')
    # Keep the +/- 2 band visible, widening the range if the data exceed it
    ylim(min(-2.5, nmin(y)), max(2.5, nmax(y)))
    xlim(0, nmax(x))
    # Save to file; os.makedirs (unlike os.mkdir) handles nested paths
    if not os.path.exists(path):
        os.makedirs(path)
    if not path.endswith('/'):
        path += '/'
    savefig("%s%s%s.%s" % (path, name, suffix, format))
开发者ID:shfengcj,项目名称:pymc,代码行数:57,代码来源:Matplot.py
示例6: zplot
def zplot(pvalue_dict,
          name='',
          format='png',
          path='./',
          fontmap=None,
          verbose=1):
    """Plots absolute values of z-scores for model validation output from
    diagnostics.validate().

    :Arguments:
      pvalue_dict: dict
        Maps each variable to its p-values. Variable i (in iteration order)
        is drawn on row i + 1 and labelled "<var> (<number of p-values>)".
      name (optional): string
        File name prefix; the plot is saved as <path><name>-validation.<format>
        (or <path>validation.<format> when name is empty).
      format (optional): string
        Graphic output format (defaults to png).
      path (optional): string
        Directory for the saved plot; created if missing.
      fontmap (optional): dict
        Accepted for interface compatibility; not used by this function.
      verbose (optional): int
        Print a progress message when non-zero.
    """
    if verbose:
        print_('\nGenerating model validation plot')
    x, y, labels = [], [], []
    for i, var in enumerate(pvalue_dict):
        # Absolute values of the inverse standard normal of each p-value
        zvals = abs(special.ndtri(pvalue_dict[var]))
        x = append(x, zvals)
        y = append(y, ones(size(zvals)) * (i + 1))
        # %-formatting also accepts non-string keys, unlike var + "..."
        labels = append(labels, "%s (%i)" % (var, size(zvals)))
    # Spawn new figure
    figure()
    subplot(111)
    subplots_adjust(left=0.25, bottom=0.1)
    # Plot scores
    pyplot(x, y, 'o')
    # Set range on axes. len() is required here: numpy's size() of a dict
    # is 1 (a 0-d object array), which clipped every row above the second.
    ylim(0, len(pvalue_dict) + 2)
    xlim(xmin=0)
    # Tick labels for y-axis, with blank labels at both ends as margins
    yticks(arange(len(labels) + 2), append(append("", labels), ""))
    # X label
    xlabel("Absolute z transformation of p-values")
    # makedirs also creates missing parent directories
    if not os.path.exists(path):
        os.makedirs(path)
    if not path.endswith('/'):
        path += '/'
    if name:
        name += '-'
    savefig("%s%svalidation.%s" % (path, name, format))
开发者ID:shfengcj,项目名称:pymc,代码行数:54,代码来源:Matplot.py
示例7: discrepancy_plot
def discrepancy_plot(
    data, name="discrepancy", report_p=True, format="png", suffix="-gof", path="./", fontmap=None, verbose=1
):
    """Generate a goodness-of-fit deviate scatter plot and save it.

    :Arguments:
      data: array-like
        Observed/simulated deviate pairs, either as an (n, 2) array-like
        (transposed into x, y) or as a pair (x, y) of sequences.
      name, suffix, format, path (optional): strings
        The plot is saved to <path><name><suffix>.<format>; path is created
        (including parents) if missing.
      report_p (optional): bool
        Annotate the plot with p = fraction of points whose simulated
        deviate exceeds the observed one.
      fontmap (optional): dict
        Accepted for interface compatibility; not used by this function.
      verbose (optional): int
        Print a progress message when > 0.
    """
    if verbose > 0:
        print_("Plotting", name + suffix)
    # Generate new scatter plot
    figure()
    try:
        x, y = transpose(data)
    except ValueError:
        x, y = data
    scatter(x, y)
    # Plot x=y line over the data range padded by 10% on each side
    lo = nmin(ravel(data))
    hi = nmax(ravel(data))
    datarange = hi - lo
    lo -= 0.1 * datarange
    hi += 0.1 * datarange
    pyplot((lo, hi), (lo, hi))
    # Plot options
    xlabel("Observed deviates", fontsize="x-small")
    ylabel("Simulated deviates", fontsize="x-small")
    if report_p:
        # Put p-value in legend. float() avoids Python 2 integer division,
        # which truncated the reported p to 0 (or 1).
        count = sum(s > o for o, s in zip(x, y))
        text(
            lo + 0.1 * datarange,
            hi - 0.1 * datarange,
            "p=%.3f" % (float(count) / len(x)),
            horizontalalignment="center",
            fontsize=10,
        )
    # Save to file; makedirs also creates missing parent directories
    if not os.path.exists(path):
        os.makedirs(path)
    if not path.endswith("/"):
        path += "/"
    savefig("%s%s%s.%s" % (path, name, suffix, format))
开发者ID:roban,项目名称:pymc,代码行数:48,代码来源:Matplot.py
示例8: trace
def trace(
    data,
    name,
    format="png",
    datarange=(None, None),
    suffix="",
    path="./",
    rows=1,
    columns=1,
    num=1,
    last=True,
    fontmap=None,
    verbose=1,
):
    """Generate a trace plot from an array of data.

    When rows, columns and num are all 1 the plot is stand-alone: a new
    figure is opened and saved to <path><name><suffix>.<format>. Otherwise
    the trace is drawn into subplot (rows, columns, num) and not saved.

    :Arguments:
      data: array
        Values to plot (converted with data.tolist()).
      name: string
        Used in the title and the output file name.
      datarange (optional): tuple
        Passed to ylim (defaults to (None, None)).
      fontmap (optional): dict
        Maps a key derived from `rows` to the tick-label font size.
      verbose (optional): int
        Print a progress message for stand-alone plots when > 0.
    """
    if fontmap is None:
        fontmap = {1: 10, 2: 8, 3: 6, 4: 5, 5: 4}
    # Stand-alone plot or subplot?
    standalone = rows == 1 and columns == 1 and num == 1
    if standalone:
        if verbose > 0:
            print_("Plotting", name)
        figure()
    subplot(rows, columns, num)
    pyplot(data.tolist())
    ylim(datarange)
    # Plot options
    title("\n\n %s trace" % name, x=0.0, y=1.0, ha="left", va="top", fontsize="small")
    # Smaller tick labels. rows / 2 gave key 0 (Py2) or 0.5 (Py3) for one
    # row -- a KeyError for every stand-alone plot -- and float keys for
    # odd row counts; floor-divide and keep the key at least 1.
    tick_size = fontmap[max(rows // 2, 1)]
    setp(gca().get_xticklabels(), "fontsize", tick_size)
    setp(gca().get_yticklabels(), "fontsize", tick_size)
    if standalone:
        # makedirs also creates missing parent directories
        if not os.path.exists(path):
            os.makedirs(path)
        if not path.endswith("/"):
            path += "/"
        # Save to file
        savefig("%s%s%s.%s" % (path, name, suffix, format))
开发者ID:roban,项目名称:pymc,代码行数:48,代码来源:Matplot.py
示例9: test_mesh_metric
def test_mesh_metric():
    """Plot one cell of an adapted mesh together with its metric ellipse.

    Builds a 20x20 RectangleMesh on the unit square and adapts it to the
    constant tensor ((10, 0), (0, 10)). It evaluates mesh_metric2 on the
    result. For cell 20 it plots the cell's centred vertices in blue and, in
    red, the ellipse obtained from the eigen-decomposition of the cell's
    2x2 metric block (eigenvalues divided by sqrt(3)). It then prints the
    triangle area next to the scaled product of the eigenvalues and shows
    the figure.
    """
    mesh = RectangleMesh(0, 0, 1, 1, 20, 20)
    mesh = adapt(interpolate(Constant(((10., 0.), (0., 10.))),
                             TensorFunctionSpace(mesh, 'CG', 1)))
    # extract mesh metric
    MpH = mesh_metric2(mesh)
    # Plot element i
    i = 20
    t = linspace(0, 2 * pi, 101)
    ind = MpH.function_space().dofmap().cell_dofs(i)
    thecell = mesh.cells()[i]
    cell_coords = mesh.coordinates()[thecell, :]
    centerxy = cell_coords.mean(0).repeat(3).reshape([2, 3]).T
    cxy = cell_coords - centerxy
    pyplot(cxy[:, 0], cxy[:, 1], '-b')
    H = MpH.vector().gather(ind).reshape(2, 2)
    v, w = linalg.eig(H)
    v /= pysqrt(3)
    elxy = array([pycos(t), pysin(t)]).T.dot(w).dot(diag(v)).dot(w.T)
    # Overlaying on the same axes is the default; matplotlib 3.0 removed
    # hold(), so the old hold('on')/hold('off') calls are dropped.
    pyplot(elxy[:, 0], elxy[:, 1], '-r')
    axis('equal')
    # pysqrt, as above, rather than the unqualified sqrt
    tri_area = pyabs(linalg.det(array([cxy[1, :] - cxy[0, :], cxy[2, :] - cxy[0, :]]))) / 2
    ellipse_term = v[0] * v[1] * 3 * pysqrt(3) / 4
    print('triangle area: %0.6f, ellipse axis product(*3*sqrt(3)/4): %0.6f' % (tri_area, ellipse_term))
    show()
开发者ID:taupalosaurus,项目名称:pragmatic,代码行数:20,代码来源:mesh_metric2_example.py
示例10: discrepancy_plot
def discrepancy_plot(data, name, report_p=True, format='png', suffix='-gof', path='./', fontmap=None, verbose=1):
    """Generate a goodness-of-fit deviate scatter plot and save it.

    :Arguments:
      data: array-like
        Observed/simulated deviate pairs, either as an (n, 2) array-like
        (transposed into x, y) or as a pair (x, y) of sequences.
      name, suffix, format, path: strings
        The plot is saved to <path><name><suffix>.<format>; path is created
        (including parents) if missing.
      report_p (optional): bool
        Annotate the plot with p = fraction of points whose simulated
        deviate exceeds the observed one.
      fontmap (optional): dict
        Accepted for interface compatibility; not used by this function.
      verbose (optional): int
        Print a progress message when > 0.
    """
    import sys
    if verbose > 0:
        # Same output as the Py2 statement `print 'Plotting', name+suffix`,
        # but valid under both Python 2 and Python 3.
        sys.stdout.write('Plotting %s\n' % (name + suffix,))
    # Generate new scatter plot
    figure()
    try:
        x, y = transpose(data)
    except ValueError:
        x, y = data
    scatter(x, y)
    # Plot x=y line over the data range padded by 10% on each side
    lo = nmin(ravel(data))
    hi = nmax(ravel(data))
    datarange = hi - lo
    lo -= 0.1 * datarange
    hi += 0.1 * datarange
    pyplot((lo, hi), (lo, hi))
    # Plot options
    xlabel('Observed deviates', fontsize='x-small')
    ylabel('Simulated deviates', fontsize='x-small')
    if report_p:
        # Put p-value in legend. Under Python 2, count / len(x) was integer
        # division, so the reported p was always 0 (or 1).
        count = sum(s > o for o, s in zip(x, y))
        text(lo + 0.1 * datarange, hi - 0.1 * datarange,
             'p=%.3f' % (float(count) / len(x)), horizontalalignment='center',
             fontsize=10)
    # Save to file; makedirs also creates missing parent directories
    if not os.path.exists(path):
        os.makedirs(path)
    if not path.endswith('/'):
        path += '/'
    savefig("%s%s%s.%s" % (path, name, suffix, format))
开发者ID:along1x,项目名称:pymc,代码行数:38,代码来源:Matplot.py
示例11: _plot_time
def _plot_time(file_name, down_sample=1):
from pylab import plot as pyplot
from pylab import arange, xlabel, ylabel, title, grid, show
try:
down_sample = int(down_sample)
except TypeError:
print("argument down_sample must be int")
raise SystemExit
wr = wave.open(file_name, 'r')
song = _time_data(wr, down_sample=down_sample)
num_frames = wr.getnframes()
frame_rate = wr.getframerate()
t = arange(0.0, (num_frames - down_sample) / frame_rate, down_sample / frame_rate)
pyplot(t, song)
xlabel('time (s)')
ylabel('amplitude (maximum 2^8, minimum -2^8)')
title('Amplitude of track {} over time'.format(file_name))
grid(True)
show()
开发者ID:jameh,项目名称:music-entropy,代码行数:23,代码来源:music.py
示例12: display
def display(self, axes, xlab=None, ylab=None, name=None, new=True):
    """Plot the mean and the mean +/- sd bounds, then save the figure.

    :Arguments:
      axes: x values when self.ndim == 1; when self.ndim == 2, axes[0] and
        axes[1] are passed to contourf.
      xlab, ylab (optional): axis labels for the 2-D contour panels.
      name (optional): prefix for the 1-D legend labels and the 2-D panel
        titles; in 1-D it is also used as the figure title.
      new (optional): open a new figure first (default True).

    :Raises:
      ValueError: if self.ndim is neither 1 nor 2.

    The figure is saved to <self._plotpath><self.name><self.suffix>.<self._format>.
    """
    name_str = name if name else ''
    if self.ndim == 1:
        if new:
            figure()
        # All three labels now carry the same space after name_str; the
        # '+sd' and mean labels used to run straight into the name.
        pyplot(axes, self.lo, 'k-.', label=name_str + ' mean-sd')
        pyplot(axes, self.hi, 'k-.', label=name_str + ' mean+sd')
        pyplot(axes, self.mean, 'k-', label=name_str + ' mean')
        if name:
            title(name)
    elif self.ndim == 2:
        if new:
            figure(figsize=(14, 4))
        # One contour panel per surface, left to right: mean-sd, mean, mean+sd
        panels = ((self.lo, ' mean-sd'), (self.mean, ' mean'), (self.hi, ' mean+sd'))
        for position, (surface, caption) in enumerate(panels, 1):
            subplot(1, 3, position)
            contourf(axes[0], axes[1], surface, cmap=cm.bone)
            title(name_str + caption)
            if xlab:
                xlabel(xlab)
            if ylab:
                ylabel(ylab)
            colorbar()
    else:
        raise ValueError(
            'Only 1- and 2- dimensional functions can be displayed')
    savefig(
        "%s%s%s.%s" % (
            self._plotpath,
            self.name,
            self.suffix,
            self._format))
开发者ID:Gwill,项目名称:pymc,代码行数:53,代码来源:Matplot.py
示例13: summary_plot
#.........这里部分代码省略.........
value = variable.value
# Number of elements in current variable
k = size(value)
# Append variable name(s) to list
if k > 1:
names = var_str(varname, shape(value)[int(shape(value)[0]==1):])
labels += names
else:
labels.append(varname)
# labels.append('\n'.join(varname.split('_')))
# Add spacing for each chain, if more than one
e = [0] + [(chain_spacing * ((i + 2) / 2)) * (
-1) ** i for i in range(chains - 1)]
# Loop over chains
for j, quants in enumerate(data):
# Deal with multivariate nodes
if k > 1:
ravelled_quants = list(map(ravel, quants))
for i, quant in enumerate(transpose(ravelled_quants)):
q = ravel(quant)
# Y coordinate with jitter
y = -(var + i) + e[j]
if quartiles:
# Plot median
pyplot(q[2], y, 'bo', markersize=4)
# Plot quartile interval
errorbar(
x=(q[1],
q[3]),
y=(y,
y),
linewidth=2,
color="blue")
else:
# Plot median
pyplot(q[1], y, 'bo', markersize=4)
# Plot outer interval
errorbar(
x=(q[0],
q[-1]),
y=(y,
y),
linewidth=1,
color="blue")
else:
# Y coordinate with jitter
y = -var + e[j]
if quartiles:
# Plot median
pyplot(quants[2], y, 'bo', markersize=4)
# Plot quartile interval
errorbar(
开发者ID:Gwill,项目名称:pymc,代码行数:67,代码来源:Matplot.py
示例14: summary_plot
#.........这里部分代码省略.........
try:
# First try missing-value stochastic
value = variable.get_stoch_value()
except AttributeError:
# All other variable types
value = variable.value
# Number of elements in current variable
k = size(value)
# Append variable name(s) to list
if k>1:
names = var_str(varname, shape(value))
labels += names
else:
labels.append('\n'.join(varname.split('_')))
# Add spacing for each chain, if more than one
e = [0] + [(chain_spacing * ((i+2)/2))*(-1)**i for i in range(chains-1)]
# Loop over chains
for j,quants in enumerate(data):
# Deal with multivariate nodes
if k>1:
for i,q in enumerate(transpose(quants)):
# Y coordinate with jitter
y = -(var+i) + e[j]
if quartiles:
# Plot median
pyplot(q[2], y, 'bo', markersize=4)
# Plot quartile interval
errorbar(x=(q[1],q[3]), y=(y,y), linewidth=2, color="blue")
else:
# Plot median
pyplot(q[1], y, 'bo', markersize=4)
# Plot outer interval
errorbar(x=(q[0],q[-1]), y=(y,y), linewidth=1, color="blue")
else:
# Y coordinate with jitter
y = -var + e[j]
if quartiles:
# Plot median
pyplot(quants[2], y, 'bo', markersize=4)
# Plot quartile interval
errorbar(x=(quants[1],quants[3]), y=(y,y), linewidth=2, color="blue")
else:
# Plot median
pyplot(quants[1], y, 'bo', markersize=4)
# Plot outer interval
errorbar(x=(quants[0],quants[-1]), y=(y,y), linewidth=1, color="blue")
# Increment index
var += k
# Define range of y-axis
ylim(-var+0.5, -0.5)
开发者ID:along1x,项目名称:pymc,代码行数:67,代码来源:Matplot.py
示例15: pair_posterior
#.........这里部分代码省略.........
for p in nodes:
trueval[p] = None
np=len(nodes)
ns = {}
for p in nodes:
if not p.value.shape:
ns[p] = 1
else:
ns[p] = len(p.value.ravel())
index_now = -1
tracelen = {}
ravelledtrace={}
titles={}
indices={}
cum_indices={}
for p in nodes:
tracelen[p] = p.trace().shape[0]
ravelledtrace[p] = p.trace().reshape((tracelen[p],-1))
titles[p]=[]
indices[p] = []
cum_indices[p]=[]
for j in range(ns[p]):
# Should this index be included?
if mask[p]:
if not mask[p].ravel()[j]:
indices[p].append(j)
this_index=True
else:
this_index=False
else:
indices[p].append(j)
this_index=True
# If so:
if this_index:
index_now+=1
cum_indices[p].append(index_now)
# Figure out title string
if ns[p]==1:
titles[p].append(p.__name__)
else:
titles[p].append(p.__name__ + get_index_list(p.value.shape,j).__repr__())
if new:
figure(figsize = (10,10))
n = index_now+1
for p in nodes:
for j in range(len(indices[p])):
# Marginals
ax=subplot(n,n,(cum_indices[p][j])*(n+1)+1)
setp(ax.get_xticklabels(),fontsize=fontsize)
setp(ax.get_yticklabels(),fontsize=fontsize)
hist(ravelledtrace[p][:,j],normed=True,fill=False)
xlabel(titles[p][j],size=fontsize)
# Bivariates
for i in range(len(nodes)-1):
p0 = nodes[i]
for j in range(len(indices[p0])):
p0_i = indices[p0][j]
p0_ci = cum_indices[p0][j]
for k in range(i,len(nodes)):
p1=nodes[k]
if i==k:
l_range = range(j+1,len(indices[p0]))
else:
l_range = range(len(indices[p1]))
for l in l_range:
p1_i = indices[p1][l]
p1_ci = cum_indices[p1][l]
subplot_index = p0_ci*(n) + p1_ci+1
ax=subplot(n, n, subplot_index)
setp(ax.get_xticklabels(),fontsize=fontsize)
setp(ax.get_yticklabels(),fontsize=fontsize)
try:
H, x, y = histogram2d(ravelledtrace[p1][:,p1_i],ravelledtrace[p0][:,p0_i])
contourf(x,y,H,cmap=cm.bone)
except:
print 'Unable to plot histogram for ('+titles[p1][l]+','+titles[p0][j]+'):'
pyplot(ravelledtrace[p1][:,p1_i],ravelledtrace[p0][:,p0_i],'k.',markersize=1.)
axis('tight')
xlabel(titles[p1][l],size=fontsize)
ylabel(titles[p0][j],size=fontsize)
plotname = ''
for obj in nodes:
plotname += obj.__name__ + ''
if not os.path.exists(path):
os.mkdir(path)
if not path.endswith('/'):
path += '/'
savefig("%s%s%s.%s" % (path, plotname, suffix, format))
开发者ID:along1x,项目名称:pymc,代码行数:101,代码来源:Matplot.py
示例16: discrepancy_plot
def discrepancy_plot(
        data, name='discrepancy', report_p=True, format='png', suffix='-gof', path='./',
        fontmap=None, verbose=1):
    '''
    Generate goodness-of-fit deviate scatter plot.

    :Arguments:
        data: list
            List (or list of lists for vector-valued variables) of discrepancy values, output
            from the `pymc.diagnostics.discrepancy` function. Accepted either as
            (observed, simulated) pairs or as a pair (x, y) of sequences.
        name: string
            The name of the plot.
        report_p: bool
            Flag for annotating the p-value to the plot.
        format (optional): string
            Graphic output format (defaults to png).
        suffix (optional): string
            Filename suffix (defaults to "-gof").
        path (optional): string
            Specifies location for saving plots (defaults to local directory);
            created, including missing parents, if it does not exist.
        fontmap (optional): dict
            Accepted for interface compatibility; not used by this function.
        verbose (optional): int
            Print a progress message when > 0 (defaults to 1).
    '''
    # `verbose` used to be read here without being a parameter, so the call
    # raised NameError unless an unrelated global happened to exist.
    if verbose > 0:
        print_('Plotting', name + suffix)
    # Generate new scatter plot
    figure()
    try:
        x, y = transpose(data)
    except ValueError:
        x, y = data
    scatter(x, y)
    # Plot x=y line over the data range padded by 10% on each side
    lo = nmin(ravel(data))
    hi = nmax(ravel(data))
    datarange = hi - lo
    lo -= 0.1 * datarange
    hi += 0.1 * datarange
    pyplot((lo, hi), (lo, hi))
    # Plot options
    xlabel('Observed deviates', fontsize='x-small')
    ylabel('Simulated deviates', fontsize='x-small')
    if report_p:
        # Put p-value in legend; float() avoids Python 2 integer division
        count = sum(s > o for o, s in zip(x, y))
        text(lo + 0.1 * datarange, hi - 0.1 * datarange,
             'p=%.3f' % (float(count) / len(x)), horizontalalignment='center',
             fontsize=10)
    # Save to file; makedirs also creates missing parent directories
    if not os.path.exists(path):
        os.makedirs(path)
    if not path.endswith('/'):
        path += '/'
    savefig("%s%s%s.%s" % (path, name, suffix, format))
开发者ID:Gwill,项目名称:pymc,代码行数:70,代码来源:Matplot.py
示例17: adv_convergence
#.........这里部分代码省略.........
u, ps = U.split()
#SOLVE CONCENTRATION
mm = mesh_metric2(mesh)
vdir = u/sqrt(inner(u,u)+DOLFIN_EPS)
if iii == 0 or use_reform == False:
Q2 = FunctionSpace(mesh,'CG',2); c = Function(Q2)
q = TestFunction(Q2); p = TrialFunction(Q2)
newq = (q+dot(vdir,dot(mm,vdir))*inner(grad(q),vdir)) #SUPG
if use_reform:
F = newq*(fac/((1+exp(-c))**2)*exp(-c))*inner(grad(c),u)*dx
J = derivative(F,c)
bc = DirichletBC(Q2, Expression("-log("+str(float(fac)) +"/("+testsol+"+"+str(float(delta))+")-1)"), left)
# bc = DirichletBC(Q, -ln(fac/(Expression(testsol)+delta)-1), left)
problem = NonlinearVariationalProblem(F,c,bc,J)
solver = NonlinearVariationalSolver(problem)
solver.parameters["newton_solver"]["relaxation_parameter"] = relp
solver.solve()
else:
a2 = newq*inner(grad(p),u)*dx
bc = DirichletBC(Q2, Expression(testsol), left)
L2 = Constant(0.)*q*dx
solve(a2 == L2, c, bc)
if (not bool(use_adapt)) or iii == Nadapt-1:
break
um = project(sqrt(inner(u,u)),FunctionSpace(mesh,'CG',2))
H = metric_pnorm(um, eta, max_edge_ratio=1+49*(use_adapt!=2), p=2)
H2 = metric_pnorm(c, eta, max_edge_ratio=1+49*(use_adapt!=2), p=2)
#H3 = metric_pnorm(ps , eta, max_edge_ratio=1+49*(use_adapt!=2), p=2)
H4 = metric_ellipse(H,H2)
#H5 = metric_ellipse(H3,H4,mesh)
mesh = adapt(H4)
if use_reform:
Q2 = FunctionSpace(mesh,'CG',2)
c = interpolate(c,Q2)
if use_reform:
c = project(fac/(1+exp(-c))-delta,FunctionSpace(mesh,'CG',2))
L2error = bnderror(c,Expression(testsol),ds)
dofs.append(len(c.vector().array())+len(U.vector().array()))
L2errors.append(L2error)
# fid = open("DOFS_L2errors_mesh_c_CG"+str(CGorder)+outname+".mpy",'w')
# pickle.dump([dofs[0],L2errors[0],c.vector().array().min(),c.vector().array().max()-1,mesh.cells(),mesh.coordinates(),c.vector().array()],fid)
# fid.close();
log(INFO+1,"%1dX ADAPT<->SOLVE complete: DOF=%5d, error=%0.0e, min(c)=%0.0e,max(c)-1=%0.0e" % (Nadapt, dofs[len(dofs)-1], L2error,c.vector().array().min(),c.vector().array().max()-1))
# PLOT MESH + solution
figure()
testf = interpolate(c ,FunctionSpace(mesh,'CG',1))
testfe = interpolate(Expression(testsol),FunctionSpace(mesh,'CG',1))
vtx2dof = vertex_to_dof_map(FunctionSpace(mesh, "CG" ,1))
zz = testf.vector().array()[vtx2dof]; zz[zz==1] -= 1e-16
hh=tricontourf(mesh.coordinates()[:,0],mesh.coordinates()[:,1],mesh.cells(),zz,100,cmap=get_cmap('binary'))
colorbar(hh)
hold('on'); triplot(mesh.coordinates()[:,0],mesh.coordinates()[:,1],mesh.cells(),color='r',linewidth=0.5); hold('off')
axis('equal'); box('off')
# savefig(outname+'final_mesh_CG2.png',dpi=300) #; savefig('outname+final_mesh_CG2.eps',dpi=300)
#PLOT ERROR
figure()
xe = interpolate(Expression("x[0]"),FunctionSpace(mesh,'CG',1)).vector().array()
ye = interpolate(Expression("x[1]"),FunctionSpace(mesh,'CG',1)).vector().array()
I = xe - Lx/2 > -DOLFIN_EPS; I2 = ye[I].argsort()
pyplot(ye[I][I2],testf.vector().array()[I][I2]-testfe.vector().array()[I][I2],'-b'); ylabel('error')
# PLOT L2error graph
figure()
pyloglog(dofs,L2errors,'-b.',linewidth=2,markersize=16); xlabel('Degree of freedoms'); ylabel('L2 error')
# SAVE SOLUTION
dofs = array(dofs); L2errors = array(L2errors)
fid = open("DOFS_L2errors_CG"+str(CGorder)+outname+".mpy",'w')
pickle.dump([dofs,L2errors],fid)
fid.close();
# #show()
# #LOAD SAVED SOLUTIONS
# fid = open("DOFS_L2errors_CG2"+outname+".mpy",'r')
# [dofs,L2errors] = pickle.load(fid)
# fid.close()
#
# PERFORM FITS ON LAST THREE POINTS
NfitP = 5
I = array(range(len(dofs)-NfitP,len(dofs)))
slope,ints = polyfit(pylog(dofs[I]), pylog(L2errors[I]), 1)
fid = open("DOFS_L2errors_CG2_fit"+outname+".mpy",'w')
pickle.dump([dofs,L2errors,slope,ints],fid)
fid.close()
#PLOT THEM TOGETHER
if CGorderL != [2]:
fid = open("DOFS_L2errors_CG3.mpy",'r')
[dofs_old,L2errors_old] = pickle.load(fid)
fid.close()
slope2,ints2 = polyfit(pylog(dofs_old[I]), pylog(L2errors_old[I]), 1)
figure()
pyloglog(dofs,L2errors,'-b.',dofs_old,L2errors_old,'--b.',linewidth=2,markersize=16)
hold('on'); pyloglog(dofs,pyexp2(ints)*dofs**slope,'-r',dofs_old,pyexp2(ints2)*dofs_old**slope2,'--r',linewidth=1); hold('off')
xlabel('Degree of freedoms'); ylabel('L2 error')
legend(['CG2','CG3',"%0.2f*log(DOFs)" % slope, "%0.2f*log(DOFs)" % slope2]) #legend(['new data','old_data'])
# savefig('comparison.png',dpi=300) #savefig('comparison.eps');
if not noplot:
show()
开发者ID:meshadaptation,项目名称:pragmatic,代码行数:101,代码来源:adv_convergence.py
注：本文中的pylab.pyplot函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论