本文整理汇总了Python中matplotlib.pyplot.rc_context函数的典型用法代码示例。如果您正苦于以下问题：Python rc_context函数的具体用法？Python rc_context怎么用？Python rc_context使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了rc_context函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: _fig_size_cntx
def _fig_size_cntx(fig, fig_size_inches, tight_layout):
    """Resize a figure in a context.

    The figure's size and tight-layout flag are changed on entry and always
    restored on exit, and ``savefig.bbox`` is set to ``'tight'`` or
    ``'standard'`` (matching *tight_layout*) for the duration of the context.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The figure to resize.
    fig_size_inches : tuple
        The (width, height) to use in the context, as accepted by
        ``Figure.set_size_inches``. If None, the size is not changed.
    tight_layout : boolean
        When True, tight layout is used.

    Yields
    ------
    fig : matplotlib.figure.Figure
        The same figure, temporarily resized.
    """
    # NOTE(review): the yield inside try/finally implies this generator is
    # wrapped with contextlib.contextmanager; the decorator is not visible here.
    orig_size = fig.get_size_inches()
    orig_layout = fig.get_tight_layout()
    rc_params = {'savefig.bbox': 'tight' if tight_layout else 'standard'}
    try:
        # Mutate inside the try so that a failure part-way through (e.g. in
        # set_tight_layout) still restores the original size and layout.
        if fig_size_inches is not None:
            fig.set_size_inches(*fig_size_inches)
        fig.set_tight_layout(tight_layout)
        with plt.rc_context(rc_params):
            yield fig
    finally:
        fig.set_size_inches(*orig_size)
        fig.set_tight_layout(orig_layout)
开发者ID:soft-matter,项目名称:pims,代码行数:28,代码来源:display.py
示例2: filterstats
def filterstats(input_fn, output_dir, topn=None,
                maxeerates=None, maxns=None):
    """Write read-filtering statistics tables and a summary plot.

    Files written to *output_dir*:
    ``filterstats_minlen.txt`` and ``filterstats_trunclen.txt`` (tab-separated,
    floats with 3 decimals, no index) and ``filterstats_plot.png``.

    :param input_fn: input file forwarded to ``_stats``.
    :param output_dir: existing directory that receives the outputs.
    :param topn: forwarded to ``_stats``.
    :param maxeerates: maximum expected-error rates forwarded to ``_stats``;
        ``None`` (default) means ``[0.25, 0.5, 0.75, 1, 1.25, 1.5]``.
    :param maxns: forwarded to ``_stats``.
    :raises ValueError: if *output_dir* is not an existing directory.
    """
    # A fresh list per call: a list default would be shared (and mutable)
    # across every call of this function.
    if maxeerates is None:
        maxeerates = [0.25, 0.5, 0.75, 1, 1.25, 1.5]
    if not os.path.isdir(output_dir):
        raise ValueError("directory {} does not exist".format(output_dir))
    minlen_fn = os.path.join(output_dir, "filterstats_minlen.txt")
    trunclen_fn = os.path.join(output_dir, "filterstats_trunclen.txt")
    plot_fn = os.path.join(output_dir, "filterstats_plot.png")
    minlen, trunclen = _stats(
        input_fn=input_fn,
        topn=topn,
        maxeerates=maxeerates,
        maxns=maxns)
    minlen.to_csv(minlen_fn, sep="\t", float_format="%.3f", index=False)
    trunclen.to_csv(trunclen_fn, sep="\t", float_format="%.3f", index=False)
    # Custom rc: smaller tick/label/legend fonts. "svg.fonttype": "none"
    # keeps text as real text (not converted to paths) in SVG output.
    rc = {
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "axes.labelsize": 10,
        "legend.fontsize": 10,
        "svg.fonttype": "none"}
    with plt.rc_context(rc=rc):
        _plot(minlen, trunclen, plot_fn)
开发者ID:compmetagen,项目名称:micca,代码行数:30,代码来源:_filterstats.py
示例3: stats
def stats(input_fn, output_dir, topn=None):
    """Write read length/quality statistics tables and their plots.

    For each of the three tables returned by ``_stats`` (length distribution,
    quality distribution, quality summary) a tab-separated ``.txt`` file
    (floats with 3 decimals, no index) and a ``_plot.png`` file are written
    to *output_dir*.

    :param input_fn: input file forwarded to ``_stats``.
    :param output_dir: existing directory that receives the outputs.
    :param topn: forwarded to ``_stats``.
    :raises ValueError: if *output_dir* is not an existing directory.
    """
    if not os.path.isdir(output_dir):
        raise ValueError("directory {} does not exist".format(output_dir))

    len_dist, qual_dist, qual_summ = _stats(input_fn=input_fn, topn=topn)
    # (table, text file name, plot file name, plotting helper)
    outputs = [
        (len_dist, "stats_lendist.txt", "stats_lendist_plot.png",
         _plot_len_dist),
        (qual_dist, "stats_qualdist.txt", "stats_qualdist_plot.png",
         _plot_qual_dist),
        (qual_summ, "stats_qualsumm.txt", "stats_qualsumm_plot.png",
         _plot_qual_summ),
    ]
    for frame, table_name, _plot_name, _plotter in outputs:
        frame.to_csv(os.path.join(output_dir, table_name), sep="\t",
                     float_format="%.3f", index=False)

    # Smaller fonts for the plots; "svg.fonttype": "none" keeps text as real
    # text (not paths) in SVG output.
    plot_rc = {
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "axes.labelsize": 10,
        "legend.fontsize": 10,
        "svg.fonttype": "none",
    }
    with plt.rc_context(rc=plot_rc):
        for frame, _table_name, plot_name, plotter in outputs:
            plotter(frame, os.path.join(output_dir, plot_name))
开发者ID:compmetagen,项目名称:micca,代码行数:32,代码来源:_stats.py
示例4: plot_alpha
def plot_alpha(metadata, category, hue):
    """Draw 'Alpha diversity' boxplots of *metadata* grouped by *category*.

    The figure is 4 inches wide per distinct *category* value and 8 inches
    tall, drawn with seaborn's darkgrid style and notebook context
    (font_scale=2).

    :param metadata: DataFrame with an 'Alpha diversity' column and the
        *category* (and *hue*) columns.
    :param category: column used for the x axis; rows are sorted by it.
    :param hue: column passed to seaborn as ``hue``.
    """
    import seaborn as sns
    style = dict(sns.axes_style("darkgrid"),
                 **sns.plotting_context("notebook", font_scale=2))
    with plt.rc_context(style):
        n_groups = len(metadata[category].unique())
        plt.figure(figsize=(n_groups * 4, 8))
        # DataFrame.sort was removed from pandas; sort_values is its replacement.
        sns.boxplot(x=category, y='Alpha diversity',
                    data=metadata.sort_values(by=category), hue=hue,
                    palette='cubehelix')
开发者ID:jairideout,项目名称:q2d2,代码行数:8,代码来源:wui.py
示例5: minorticksubplot
def minorticksubplot(xminor, yminor, i):
    """Add subplot *i* of a 2x2 grid with the given minor-tick rc settings.

    Asserts that the new axes has minor ticks on x (resp. y) exactly when
    *xminor* (resp. *yminor*) is True. Uses the enclosing scope's ``fig``.
    """
    minor_rc = {'xtick.minor.visible': xminor,
                'ytick.minor.visible': yminor}
    with plt.rc_context(rc=minor_rc):
        ax = fig.add_subplot(2, 2, i)
    has_xminor = len(ax.xaxis.get_minor_ticks()) > 0
    assert has_xminor == xminor
    has_yminor = len(ax.yaxis.get_minor_ticks()) > 0
    assert has_yminor == yminor
开发者ID:dstansby,项目名称:matplotlib,代码行数:8,代码来源:test_ticker.py
示例6: test_get_color_cycle
def test_get_color_cycle(self):
    """utils.get_color_cycle() returns the colors of the active prop cycle.

    Skipped (returns immediately) unless ``mpl_ge_150`` is true, presumably
    because ``axes.prop_cycle`` needs matplotlib >= 1.5.
    """
    if not mpl_ge_150:
        return
    expected = [(1., 0., 0.), (0, 1., 0.)]
    cycle = plt.cycler(color=expected)
    with plt.rc_context({"axes.prop_cycle": cycle}):
        result = utils.get_color_cycle()
        assert result == expected
开发者ID:bicycle1885,项目名称:seaborn,代码行数:8,代码来源:test_palettes.py
示例7: view
def view(self, test=False):
    """Displays the graph.

    Draws either the concatenated plot or the (optionally permuted) plot,
    inside ``plt.rc_context(AttrConf.MPL_STYLE)`` when styling is enabled.

    :param test: when True, force styling on and set the shared
        ``AttrConf.MPL_STYLE["interactive"]`` to False.
    """
    if test:
        self._attr["style"] = True
        AttrConf.MPL_STYLE["interactive"] = False

    def draw():
        # The concatenated plot takes no permute option.
        if self._attr["concat"]:
            self._plot_concat()
        else:
            self._plot(self._attr["permute"])

    if not self._attr["style"]:
        draw()
        return
    with plt.rc_context(AttrConf.MPL_STYLE):
        draw()
开发者ID:phil-chen,项目名称:trappy,代码行数:19,代码来源:LinePlot.py
示例8: create_icon_axes
def create_icon_axes(fig, ax_position, lw_bars, lw_grid, lw_border, rgrid):
    """
    Create a polar axes containing the matplotlib radar plot.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The figure to draw into.
    ax_position : (float, float, float, float)
        The position of the created Axes in figure coordinates as
        (x, y, width, height).
    lw_bars : float
        The linewidth of the bars.
    lw_grid : float
        The linewidth of the grid.
    lw_border : float
        The linewidth of the Axes border.
    rgrid : array-like
        Positions of the radial grid.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The created Axes.
    """
    icon_rc = {'axes.edgecolor': MPL_BLUE,
               'axes.linewidth': lw_border}
    with plt.rc_context(icon_rc):
        ax = fig.add_axes(ax_position, projection='polar')
        ax.set_axisbelow(True)

        n_bars = 7
        full_turn = 2. * np.pi
        angles = np.arange(0.0, full_turn, full_turn / n_bars)
        heights = np.array([2, 6, 8, 7, 4, 5, 8])
        widths = np.pi / 4 * np.array([0.4, 0.4, 0.6, 0.8, 0.2, 0.5, 0.3])
        bar_patches = ax.bar(angles, heights, width=widths, bottom=0.0,
                             align='edge', edgecolor='0.3', lw=lw_bars)
        # Face color: RGB from the jet colormap at height / 10, alpha 0.6.
        for height, patch in zip(heights, bar_patches):
            rgb = cm.jet(height / 10.)[:3]
            patch.set_facecolor((*rgb, 0.6))

        ax.tick_params(labelbottom=False, labeltop=False,
                       labelleft=False, labelright=False)
        ax.grid(lw=lw_grid, color='0.9')
        ax.set_rmax(9)
        ax.set_yticks(rgrid)

        # White background patch reaching slightly past the radial limit (9).
        ax.add_patch(Rectangle((0, 0), full_turn, 9.58,
                               facecolor='white', zorder=0,
                               clip_on=False, in_layout=False))
    return ax
开发者ID:QuLogic,项目名称:matplotlib,代码行数:53,代码来源:logos2.py
示例9: prepare
def prepare(self):
    """Apply the seaborn 'ticks'/'paper' style and build the figure layout.

    Creates a 7x7 inch figure on a 3x2 GridSpec with an 'ispectrum' axes in
    the bottom row and a 'scatter' axes spanning the top two rows. The
    results are stored on ``self.fig``, ``self.ax`` and ``self.gs``.
    """
    sns.set_style('ticks')
    sns.set_context('paper')
    with plt.rc_context(plot_params):
        self.fig = fig = plt.figure(figsize=(7, 7))
        grid = plt.GridSpec(3, 2)
        axes = {}
        axes['ispectrum'] = fig.add_subplot(grid[2, :])
        axes['scatter'] = fig.add_subplot(grid[:2, :])
        self.ax = axes
        self.gs = grid
开发者ID:fabiansinz,项目名称:efish_locking,代码行数:12,代码来源:figure_locking.py
示例10: view
def view(self, test=False):
    """Displays the graph.

    Delegates to ``self._resolve(permute, concat)``. Permutation is only
    requested for non-concatenated plots. Runs inside
    ``plt.rc_context(AttrConf.MPL_STYLE)`` when styling is enabled.

    :param test: when True, force styling on and set the shared
        ``AttrConf.MPL_STYLE["interactive"]`` to False.
    """
    if test:
        self._attr["style"] = True
        AttrConf.MPL_STYLE["interactive"] = False
    concat = self._attr["concat"]
    permute = self._attr["permute"] and not concat
    if self._attr["style"]:
        with plt.rc_context(AttrConf.MPL_STYLE):
            self._resolve(permute, concat)
        return
    self._resolve(permute, concat)
开发者ID:JaviMerino,项目名称:trappy,代码行数:13,代码来源:StaticPlot.py
示例11: stats
def stats(input_fn, output_dir, step=100, replace=False, seed=0):
    """Summarise an OTU table and plot its rarefaction curves.

    The table read from *input_fn* has samples as columns and OTUs as rows.
    Files written to *output_dir*:

    * ``tablestats_samplesumm.txt``: per-sample Depth (total count), NOTU
      (OTUs with count > 0) and NSingle (OTUs with count == 1), sorted by Depth.
    * ``tablestats_otusumm.txt``: per-OTU N (total count) and NSample
      (samples with count > 0), sorted by N descending.
    * ``tablestats_rarecurve.txt``: rarefaction curves (NA for missing values).
    * ``tablestats_rarecurve_plot.png``: plot of the rarefaction curves.

    :param input_fn: OTU table file read by ``micca.table.read``.
    :param output_dir: existing directory that receives the outputs.
    :param step, replace, seed: forwarded to ``micca.table.rarecurve``.
    :raises ValueError: if *output_dir* is not an existing directory.
    """
    if not os.path.isdir(output_dir):
        raise ValueError("directory {} does not exist".format(output_dir))
    sample_summ_fn = os.path.join(output_dir, "tablestats_samplesumm.txt")
    otu_summ_fn = os.path.join(output_dir, "tablestats_otusumm.txt")
    rarecurve_fn = os.path.join(output_dir, "tablestats_rarecurve.txt")
    rarecurve_plot_fn = os.path.join(output_dir, "tablestats_rarecurve_plot.png")
    table = micca.table.read(input_fn)

    # sample summary (column-wise sums: one row per sample)
    sample_summ = pd.DataFrame({
        "Depth": table.sum(),
        "NOTU": (table > 0).sum(),
        "NSingle": (table == 1).sum()},
        columns=["Depth", "NOTU", "NSingle"])
    sample_summ.index.name = "Sample"
    sample_summ.sort_values(by="Depth", inplace=True)
    sample_summ.to_csv(sample_summ_fn, sep='\t')

    # OTU summary (row-wise sums: one row per OTU)
    otu_summ = pd.DataFrame({
        "N": table.sum(axis=1),
        "NSample": (table > 0).sum(axis=1)},
        columns=["N", "NSample"])
    otu_summ.index.name = "OTU"
    otu_summ.sort_values(by="N", inplace=True, ascending=False)
    otu_summ.to_csv(otu_summ_fn, sep='\t')

    # rarefaction curves
    rarecurve = micca.table.rarecurve(table, step=step, replace=replace,
                                      seed=seed)
    rarecurve.to_csv(rarecurve_fn, sep='\t', float_format="%.0f", na_rep="NA")

    # Custom rc: smaller fonts; "svg.fonttype": "none" keeps text as real
    # text (not paths) in SVG output.
    rc = {
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "axes.labelsize": 10,
        "legend.fontsize": 10,
        "svg.fonttype": "none"}
    with plt.rc_context(rc=rc):
        fig = plt.figure(figsize=(10, 6))
        try:
            plt.plot(rarecurve.index, rarecurve.to_numpy(), color="k")
            plt.xlabel("Depth")
            plt.ylabel("#OTUs")
            fig.savefig(rarecurve_plot_fn, dpi=300, bbox_inches='tight',
                        format="png")
        finally:
            # Figures created with plt.figure stay registered in pyplot until
            # closed; without this every call leaks one open figure.
            plt.close(fig)
开发者ID:compmetagen,项目名称:micca,代码行数:51,代码来源:table.py
示例12: prepare
def prepare(self):
    """Apply the seaborn 'ticks'/'paper' style and build the figure layout.

    Creates a 7x7 inch figure on a 3x12 GridSpec and stores the axes in
    ``self.ax``:
    'spectrum' (top two rows, right part), 'ISI', 'cycle' and 'vs_freq'
    (left column), 'cycle_ampl' (twin y axis of 'cycle'), and 'circ' and
    'contrast' (bottom row, right part).
    """
    sns.set_style('ticks')
    sns.set_context('paper')
    with plt.rc_context(plot_params):
        self.fig = fig = plt.figure(figsize=(7, 7))
        grid = plt.GridSpec(3, 12)
        axes = {}
        axes['spectrum'] = fig.add_subplot(grid[:2, 4:])
        axes['ISI'] = fig.add_subplot(grid[0, :4])
        axes['cycle'] = fig.add_subplot(grid[1, :4])
        axes['vs_freq'] = fig.add_subplot(grid[2, :4])
        self.ax = axes
        axes['cycle_ampl'] = axes['cycle'].twinx()
        axes['circ'] = fig.add_subplot(grid[2, 4:8])
        axes['contrast'] = fig.add_subplot(grid[2, 8:])
开发者ID:fabiansinz,项目名称:efish_locking,代码行数:15,代码来源:figure_pyramidals.py
示例13: prepare
def prepare(self):
    """Apply the seaborn 'ticks'/'paper' style and build the figure layout.

    Creates a 7x7 inch figure on a 3x2 GridSpec and stores the axes in
    ``self.ax``:
    'scatter' (bottom row), 'ISI' and 'EOD' (middle row), 'scatter_base'
    (top row), and 'EOD_ampl' (twin y axis of 'EOD'). The GridSpec is kept
    in ``self.gs``.
    """
    sns.set_style('ticks')
    sns.set_context('paper')
    with plt.rc_context(plot_params):
        self.fig = fig = plt.figure(figsize=(7, 7))
        grid = plt.GridSpec(3, 2)
        axes = {}
        axes['scatter'] = fig.add_subplot(grid[-1, :])
        axes['ISI'] = fig.add_subplot(grid[1, 1])
        axes['EOD'] = fig.add_subplot(grid[1, 0])
        self.ax = axes
        axes['scatter_base'] = fig.add_subplot(grid[0, :])
        axes['EOD_ampl'] = axes['EOD'].twinx()
        self.gs = grid
开发者ID:fabiansinz,项目名称:efish_locking,代码行数:16,代码来源:figure_intro_punit.py
示例14: plot_dendrogram
def plot_dendrogram(clustering, size=10):
    """Draw a color-coded dendrogram of an HCA clustering in a new figure.

    :param clustering: dict-like result; reads 'linkage', 'labels' and
        'reorder_vec'.
    :param size: figure width in inches (height is ``size * .6``); the
        dendrogram line width is ``size * .125``.
    """
    linkage_matrix = clustering['linkage']
    link_function, _colors = get_dendrogram_color_fun(
        linkage_matrix, clustering['reorder_vec'], clustering['labels'])
    with sns.axes_style('white'):
        plt.figure(figsize=(size, size * .6))
        # line width scales with the figure size
        with plt.rc_context({'lines.linewidth': size * .125}):
            dendrogram(linkage_matrix, link_color_func=link_function,
                       orientation='top')
开发者ID:IanEisenberg,项目名称:Self_Regulation_Ontology,代码行数:16,代码来源:individual_structure.py
示例15: plot_stacked_bar
def plot_stacked_bar(df):
    """Draw a stacked bar chart: one bar per column, one segment per row.

    Each row of *df* gets a random hex color. The legend label of each row
    is its values formatted as ``[0.xx]`` followed by the row's index label.
    The y axis is limited to [0, 1].
    """
    import seaborn as sns
    style = dict(sns.axes_style("darkgrid"),
                 **sns.plotting_context("notebook", font_scale=1.8))
    with plt.rc_context(style):
        _, ax = plt.subplots(1, figsize=(10, 10))
        positions = list(range(len(df.columns)))
        bottom = np.array([0] * len(df.columns))
        hex_digits = list('ABCDEF123456789')
        cat_percents = []
        for row_id in df.index:
            row = df.loc[row_id]
            color = '#' + ''.join(np.random.choice(hex_digits, 6))
            ax.bar(positions, row, color=color, bottom=bottom, align='center')
            bottom = row + bottom
            cat_percents.append(
                ''.join(["[{0:.2f}] ".format(value) for value in row.tolist()]))
        legend_labels = [' '.join(pair)
                         for pair in zip(cat_percents, df.index.tolist())]
        ax.set_xticks(positions)
        ax.set_xticklabels(df.columns.tolist())
        ax.set_ylim([0, 1])
        ax.legend(legend_labels, loc='center left', bbox_to_anchor=(1, 0.5))
开发者ID:jairideout,项目名称:q2d2,代码行数:20,代码来源:wui.py
示例16: render
def render(self):
    """
    Actually render the figure.

    Builds the grid under the custom matplotlib rc context, renders it
    according to ``self.animation["type"]``, then applies
    ``tight_layout(pad=1)``.

    :returns: A :mod:`matplotlib` figure object, or ``None`` when the
        animation type is ``"animation"`` (not implemented yet) or unknown.
    """
    rc = custom_mpl.custom_rc(rc=self.custom_mpl_rc)
    with plt.rc_context(rc=rc):
        figure, axes = self._render_grid()
        animation_type = self.animation["type"]
        if animation_type is False:
            self._render_no_animation(axes)
        elif animation_type == "gif":
            self._render_gif_animation(figure, axes)
        else:
            # TODO: the "animation" type is not implemented; other values
            # are unsupported.
            return None
        # Tight layout with custom padding. TODO: messes up animations
        figure.tight_layout(pad=1)
        return figure
开发者ID:Phyks,项目名称:replot,代码行数:24,代码来源:figure.py
示例17: plot_dendrogram
def plot_dendrogram(loading, clustering, title=None,
break_lines=True, drop_list=None, double_drop_list=None,
absolute_loading=False, size=4.6, dpi=300,
filename=None):
""" Plots HCA results as dendrogram with loadings underneath
Args:
loading: pandas df, a results EFA loading matrix; rows are looked up by the
variable names in clustering['clustered_df'].columns
clustering: dict-like HCA result; keys read here are 'linkage',
'clustered_df', 'labels', 'reorder_vec' and optionally 'cluster_names'
title (optional): str, title to plot
break_lines: whether to separate EFA heatmap based on clusters, default=True
drop_list (optional): list of cluster indices to drop the cluster label
double_drop_list (optional): list of cluster indices to drop the cluster label twice
absolute_loading: whether to plot the absolute loading value, default False
size: figure width in inches (height is size*.6); line widths, paddings and
font sizes are scaled from it, default 4.6
dpi: default 300; presumably the save resolution (not used in the visible part)
filename (optional): presumably where to save the plot (saving code not visible here)
"""
# number of loading columns (presumably the EFA factors)
c = loading.shape[1]
# extract cluster vars
link = clustering['linkage']
DVs = clustering['clustered_df'].columns
# reorder loading rows to the dendrogram's variable order
ordered_loading = loading.loc[DVs]
if absolute_loading:
ordered_loading = abs(ordered_loading)
# get cluster sizes
labels=clustering['labels']
# NOTE(review): assumes cluster labels are the integers 1..max(labels)
cluster_sizes = [np.sum(labels==(i+1)) for i in range(max(labels))]
link_function, colors = get_dendrogram_color_fun(link, clustering['reorder_vec'],
labels)
# set figure properties
figsize = (size, size*.6)
# set up axes' size (heatmap height grows with the number of loading columns)
heatmap_height = ordered_loading.shape[1]*.035
heat_size = [.1, heatmap_height]
dendro_size=[np.sum(heat_size), .3]
# set up plot axes as [left, bottom, width, height] in figure coordinates;
# the dendrogram sits directly above the heatmap, the colorbar to its right
dendro_size = [.15,dendro_size[0], .78, dendro_size[1]]
heatmap_size = [.15,heat_size[0],.78,heat_size[1]]
cbar_size = [.935,heat_size[0],.015,heat_size[1]]
# transpose so variables run along the heatmap's x axis under the dendrogram
ordered_loading = ordered_loading.T
with sns.axes_style('white'):
fig = plt.figure(figsize=figsize)
ax1 = fig.add_axes(dendro_size)
# **********************************
# plot dendrogram
# **********************************
with plt.rc_context({'lines.linewidth': size*.125}):
dendrogram(link, ax=ax1, link_color_func=link_function,
orientation='top')
# change axis properties
ax1.tick_params(axis='x', which='major', labelsize=14,
labelbottom=False)
ax1.get_yaxis().set_visible(False)
ax1.spines['top'].set_visible(False)
ax1.spines['right'].set_visible(False)
ax1.spines['bottom'].set_visible(False)
ax1.spines['left'].set_visible(False)
# **********************************
# plot loadings as heatmap below
# **********************************
ax2 = fig.add_axes(heatmap_size)
cbar_ax = fig.add_axes(cbar_size)
# symmetric color range taken from the unordered, non-absolute loadings
max_val = np.max(abs(loading.values))
# round up to the next multiple of .25
max_val = ceil(max_val*4)/4
sns.heatmap(ordered_loading, ax=ax2,
cbar=True, cbar_ax=cbar_ax,
yticklabels=True,
xticklabels=True,
vmax = max_val, vmin = -max_val,
cbar_kws={'orientation': 'vertical',
'ticks': [-max_val, 0, max_val]},
cmap=sns.diverging_palette(220,15,n=100,as_cmap=True))
ax2.set_yticklabels(ax2.get_yticklabels(), rotation=0)
# label sizes shrink as the number of loading columns (c) grows
ax2.tick_params(axis='y', labelsize=size*heat_size[1]*30/c, pad=size/4, length=0)
# format cbar axis
cbar_ax.set_yticklabels([format_num(-max_val), 0, format_num(max_val)])
cbar_ax.tick_params(labelsize=size*heat_size[1]*25/c, length=0, pad=size/2)
cbar_ax.set_ylabel('Factor Loading', rotation=-90,
fontsize=size*heat_size[1]*30/c, labelpad=size*2)
# add lines to heatmap to distinguish clusters
if break_lines == True:
xlim = ax2.get_xlim();
ylim = ax2.get_ylim()
# x-axis width of one variable column
step = xlim[1]/len(labels)
cluster_breaks = [i*step for i in np.cumsum(cluster_sizes)]
# the last break is the right edge of the heatmap, so it is not drawn
ax2.vlines(cluster_breaks[:-1], ylim[0], ylim[1], linestyles='dashed',
linewidth=size*.1, colors=[.5,.5,.5], zorder=10)
# **********************************
# plot cluster names
# **********************************
# start index and (approximate) center x position of each cluster
beginnings = np.hstack([[0],np.cumsum(cluster_sizes)[:-1]])
centers = beginnings+np.array(cluster_sizes)//2+.5
offset = .07
if 'cluster_names' in clustering.keys():
ax2.tick_params(axis='x', reset=True, top=False, bottom=False, width=size/8, length=0)
#......... part of the code is omitted here .........
开发者ID:IanEisenberg,项目名称:Self_Regulation_Ontology,代码行数:101,代码来源:HCA_plots.py
示例18: plot_fluxnet_comparison_one_site
def plot_fluxnet_comparison_one_site(driver, science_test_data_dir,
compare_data_dict, result_dir, plot_dir,
plots_to_make, context, style, var_names,
months, obs_dir, subdir):
# Plot observed ("ecflux") data against VIC 4.2.d and VIC 5 simulations for
# one FLUXNET site directory and save PNGs under plot_dir. Only the
# 'annual_mean_diurnal_cycle' plot is visible in this excerpt; the
# 'monthly_mean_diurnal_cycle' branch is omitted.
# compare_data_dict maps a data-source key to its reader arguments and
# plotting options ('linewidth', 'color', 'linestyle', 'zorder').
# var_names maps VIC variable names to display names; four are expected,
# one per subplot row (TODO confirm against caller).
if check_site_files(obs_dir, subdir):
# get CSV file from site directory to get lat/lng for site
lat, lng = get_fluxnet_lat_lon(obs_dir, subdir)
print(lat, lng)
# loop over data to compare; a source that raises OSError while loading is
# skipped with a warning and is absent from `data`
data = {}
for key, items in compare_data_dict.items():
if key == "ecflux":
try:
# load Ameriflux data
data[key] = read_fluxnet_obs(subdir,
science_test_data_dir,
items)
except OSError:
warnings.warn(
"this %s site does not have data" % subdir)
elif key == "VIC.4.2.d":
try:
# load VIC 4.2 simulations
data[key] = read_vic_42_output(lat, lng,
science_test_data_dir,
items)
except OSError:
warnings.warn(
"this site has a lat/lng precision issue")
else:
try:
# load VIC 5 simulations
data[key] = read_vic_5_output(lat, lng,
result_dir,
items)
except OSError:
warnings.warn(
"this site has a lat/lng precision issue")
# make figures
# plot preferences
fs = 15
dpi = 150
if 'annual_mean_diurnal_cycle' in plots_to_make:
# make annual mean diurnal cycle plots
with plt.rc_context(dict(sns.axes_style(style),
**sns.plotting_context(context))):
f, axarr = plt.subplots(4, 1, figsize=(8, 8), sharex=True)
for i, (vic_var, variable_name) in enumerate(
var_names.items()):
# calculate annual mean diurnal cycle for each
# DataFrame (mean per hour of day; assumes a datetime index)
annual_mean = {}
for key, df in data.items():
annual_mean[key] = pd.DataFrame(
df[vic_var].groupby(df.index.hour).mean())
df = pd.DataFrame(
{key: d[vic_var] for key, d in annual_mean.items()
if vic_var in d})
# NOTE(review): DataFrame.iteritems was removed in pandas 2.0;
# use df.items() on newer pandas.
for key, series in df.iteritems():
series.plot(
linewidth=compare_data_dict[key]['linewidth'],
ax=axarr[i],
color=compare_data_dict[key]['color'],
linestyle=compare_data_dict[key]['linestyle'],
zorder=compare_data_dict[key]['zorder'])
axarr[i].legend(loc='upper left')
axarr[i].set_ylabel(
'%s ($W/{m^2}$)' % variable_name,
size=fs)
axarr[i].set_xlabel('Time of Day (Hour)', size=fs)
axarr[i].set_xlim([0, 24])
axarr[i].xaxis.set_ticks(np.arange(0, 24, 3))
# save plot to <plot_dir>/annual_mean/<lat>_<lng>.png
plotname = '%s_%s.png' % (lat, lng)
os.makedirs(os.path.join(plot_dir, 'annual_mean'),
exist_ok=True)
savepath = os.path.join(plot_dir, 'annual_mean', plotname)
plt.savefig(savepath, bbox_inches='tight', dpi=dpi)
plt.clf()
plt.close()
if 'monthly_mean_diurnal_cycle' in plots_to_make:
#......... part of the code is omitted here .........
开发者ID:BramDr,项目名称:VIC,代码行数:101,代码来源:test_utils.py
示例19: plot_snotel_comparison_one_site
def plot_snotel_comparison_one_site(
        driver, science_test_data_dir,
        compare_data_dict,
        result_dir, plot_dir,
        plots_to_make,
        plot_variables, context, style, filename):
    """Plot SNOTEL observations against VIC simulations for one site.

    The site's latitude and longitude are parsed from *filename*: the third
    and fourth '_'-separated fields, with the '.txt' suffix removed. For each
    variable in *plot_variables*, when 'water_year' is in *plots_to_make*, a
    figure is saved to ``<plot_dir>/<variable>/<lat>_<lng>.png``.

    :param driver: unused here.
    :param compare_data_dict: maps a data-source key ("snotel", "VIC.4.2.d",
        or anything else for VIC 5 output) to its reader arguments and
        plotting options ('linewidth', 'color', 'linestyle', 'zorder').
    :param plot_variables: maps variable name to units (used in the y label).
    :param context, style: seaborn plotting context and axes style names.
    """
    print(plots_to_make)
    # get lat/lng from filename
    file_split = re.split('_', filename)
    lng = file_split[3].split('.txt')[0]
    lat = file_split[2]
    print('Plotting {} {}'.format(lat, lng))
    # loop over data to compare
    data = {}
    for key, items in compare_data_dict.items():
        # read in data
        if key == "snotel":
            data[key] = read_snotel_swe_obs(filename,
                                            science_test_data_dir,
                                            items)
        elif key == "VIC.4.2.d":
            data[key] = read_vic_42_output(lat, lng,
                                           science_test_data_dir,
                                           items)
        else:
            data[key] = read_vic_5_output(lat, lng,
                                          result_dir,
                                          items)
    # loop over variables to plot
    for plot_variable, units in plot_variables.items():
        if 'water_year' in plots_to_make:
            with plt.rc_context(dict(sns.axes_style(style),
                                     **sns.plotting_context(context))):
                fig, ax = plt.subplots(figsize=(10, 10))
                df = pd.DataFrame({key: d[plot_variable] for key, d in
                                   data.items() if plot_variable in d})
                # DataFrame.iteritems was removed in pandas 2.0; items() is
                # the equivalent (column name, Series) iterator.
                for key, series in df.items():
                    series.plot(
                        use_index=True,
                        linewidth=compare_data_dict[key]['linewidth'],
                        ax=ax,
                        color=compare_data_dict[key]['color'],
                        linestyle=compare_data_dict[key]['linestyle'],
                        zorder=compare_data_dict[key]['zorder'])
                ax.legend(loc='upper left')
                ax.set_ylabel("%s [%s]" % (plot_variable, units))
                # save figure
                os.makedirs(os.path.join(plot_dir, plot_variable),
                            exist_ok=True)
                plotname = '%s_%s.png' % (lat, lng)
                savepath = os.path.join(plot_dir, plot_variable, plotname)
                plt.savefig(savepath, bbox_inches='tight')
                print(savepath)
                plt.clf()
                plt.close()
开发者ID:BramDr,项目名称:VIC,代码行数:69,代码来源:test_utils.py
示例20: draw_termite_plot
def draw_termite_plot(values_mat, col_labels, row_labels,
highlight_cols=None, highlight_colors=None,
save=False):
"""
Make a "termite" plot, typically used for assessing topic models with a tabular
layout that promotes comparison of terms both within and across topics.
Args:
values_mat (``np.ndarray`` or matrix): matrix of values with shape
(# row labels, # col labels) used to size the dots on the grid
col_labels (seq[str]): labels used to identify x-axis ticks on the grid
row_labels(seq[str]): labels used to identify y-axis ticks on the grid
highlight_cols (int or seq[int], optional): indices for columns
to visually highlight in the plot with contrasting colors
highlight_colors (tuple of 2-tuples): each 2-tuple corresponds to a pair
of (light/dark) matplotlib-friendly colors used to highlight a single
column; if not specified (default), a good set of 6 pairs are used
save (str, optional): give the full /path/to/fname on disk to save figure
Returns:
``matplotlib.axes.Axes.axis``: axis on which termite plot is plotted
Raises:
ValueError: if more columns are selected for highlighting than colors
or if any of the inputs' dimensions don't match
References:
.. Chuang, Jason, Christopher D. Manning, and Jeffrey Heer. "Termite:
Visualization techniques for assessing textual topic models."
Proceedings of the International Working Conference on Advanced
Visual Interfaces. ACM, 2012.
.. seealso:: :func:`TopicModel.termite_plot <textacy.tm.TopicModel.termite_plot>`
"""
# `plt` is presumably bound by an optional module-level import; a NameError
# here means matplotlib is unavailable (TODO confirm against module imports)
try:
plt
except NameError:
raise ImportError(
'matplotlib is not installed, so textacy.viz won\'t work; install it \
individually, or along with textacy via `pip install textacy[viz]`')
n_rows, n_cols = values_mat.shape
# largest value; dot areas are scaled relative to it
max_val = np.max(values_mat)
if n_rows != len(row_labels):
msg = "values_mat and row_labels dimensions don't match: {} vs. {}".format(
n_rows, len(row_labels))
raise ValueError(msg)
if n_cols != len(col_labels):
msg = "values_mat and col_labels dimensions don't match: {} vs. {}".format(
n_cols, len(col_labels))
raise ValueError(msg)
if highlight_colors is None:
highlight_colors = COLOR_PAIRS
if highlight_cols is not None:
# a single int is normalised to a 1-tuple
if isinstance(highlight_cols, int):
highlight_cols = (highlight_cols,)
elif len(highlight_cols) > len(highlight_colors):
msg = 'no more than {} columns may be highlighted at once'.format(
len(highlight_colors))
raise ValueError(msg)
# map each highlighted column index to its (light, dark) color pair
# NOTE(review): this indexes COLOR_PAIRS rather than highlight_colors, so
# caller-supplied highlight_colors are only used for the length check
# above and are otherwise ignored -- likely a bug.
highlight_colors = {hc: COLOR_PAIRS[i]
for i, hc in enumerate(highlight_cols)}
with plt.rc_context(RC_PARAMS):
# figure size grows sub-linearly with the number of columns/rows
fig, ax = plt.subplots(figsize=(pow(n_cols, 0.8), pow(n_rows, 0.66)))
_ = ax.set_yticks(range(n_rows))
yticklabels = ax.set_yticklabels(row_labels,
fontsize=14, color='gray')
if highlight_cols is not None:
# color a row label dark when one of the highlighted columns holds
# that row's (positive) maximum among the highlighted columns
for i, ticklabel in enumerate(yticklabels):
max_tick_val = max(values_mat[i, hc] for hc in highlight_cols)
for hc in highlight_cols:
if max_tick_val > 0 and values_mat[i, hc] == max_tick_val:
ticklabel.set_color(highlight_colors[hc][1])
ax.get_xaxis().set_ticks_position('top')
_ = ax.set_xticks(range(n_cols))
xticklabels = ax.set_xticklabels(col_labels,
fontsize=14, color='gray',
rotation=30, ha='left')
if highlight_cols is not None:
# highlighted columns: dark tick label, light half-transparent gridline
gridlines = ax.get_xgridlines()
for i, ticklabel in enumerate(xticklabels):
if i in highlight_cols:
ticklabel.set_color(highlight_colors[i][1])
gridlines[i].set_color(highlight_colors[i][0])
gridlines[i].set_alpha(0.5)
# one scatter column per matrix column; dot area is 600 * value / max_val
for col_ind in range(n_cols):
if highlight_cols is not None and col_ind in highlight_cols:
ax.scatter([col_ind for _ in range(n_rows)],
[i for i in range(n_rows)],
s=600 * (values_mat[:, col_ind] / max_val),
alpha=0.5, linewidth=1,
color=highlight_colors[col_ind][0],
edgecolor=highlight_colors[col_ind][1])
else:
ax.scatter([col_ind for _ in range(n_rows)],
#......... part of the code is omitted here .........
开发者ID:chartbeat-labs,项目名称:textacy,代码行数:101,代码来源:termite.py
注:本文中的matplotlib.pyplot.rc_context函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。 |
请发表评论