• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python pyplot.rc_context函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中matplotlib.pyplot.rc_context函数的典型用法代码示例。如果您正苦于以下问题：Python rc_context函数的具体用法？Python rc_context怎么用？Python rc_context使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了rc_context函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: _fig_size_cntx

def _fig_size_cntx(fig, fig_size_inches, tight_layout):
    """Resize a figure in a context

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The figure to resize
    fig_size_inches : tuple
        The (height, width) to use in the context. If None, the size
        is not changed
    tight_layout : boolean
        When True, tight layout is used.
    """
    orig_size = fig.get_size_inches()
    orig_layout = fig.get_tight_layout()
    if fig_size_inches is not None:
        fig.set_size_inches(*fig_size_inches)
    fig.set_tight_layout(tight_layout)
    if tight_layout:
        rc_params = {'savefig.bbox': 'tight'}
    else:
        rc_params = {'savefig.bbox': 'standard'}
    try:
        with plt.rc_context(rc_params):
            yield fig
    finally:
        fig.set_size_inches(*orig_size)
        fig.set_tight_layout(orig_layout)
开发者ID:soft-matter,项目名称:pims,代码行数:28,代码来源:display.py


示例2: filterstats

def filterstats(input_fn, output_dir, topn=None,
                maxeerates=None, maxns=None):
    """Compute filtering statistics and write them as tables and a plot.

    The statistics come from ``_stats`` and are written into *output_dir*
    as ``filterstats_minlen.txt`` and ``filterstats_trunclen.txt``
    (tab-separated, floats with 3 decimals, no index). The plot
    ``filterstats_plot.png`` is drawn by ``_plot``.

    Parameters
    ----------
    input_fn
        Input passed to ``_stats`` (presumably a file name).
    output_dir : str
        Existing directory that receives the output files.
    topn : optional
        Passed to ``_stats`` unchanged.
    maxeerates : list, optional
        Passed to ``_stats``; defaults to ``[0.25, 0.5, 0.75, 1, 1.25, 1.5]``.
    maxns : optional
        Passed to ``_stats`` unchanged.

    Raises
    ------
    ValueError
        If *output_dir* is not an existing directory.
    """
    # Build the default per call: a list literal used as the default value
    # would be one object shared by every call, so any mutation of it
    # would leak into later calls.
    if maxeerates is None:
        maxeerates = [0.25, 0.5, 0.75, 1, 1.25, 1.5]

    if not os.path.isdir(output_dir):
        raise ValueError("directory {} does not exist".format(output_dir))

    minlen_fn = os.path.join(output_dir, "filterstats_minlen.txt")
    trunclen_fn = os.path.join(output_dir, "filterstats_trunclen.txt")
    plot_fn = os.path.join(output_dir, "filterstats_plot.png")

    minlen, trunclen = _stats(
        input_fn=input_fn,
        topn=topn,
        maxeerates=maxeerates,
        maxns=maxns)

    minlen.to_csv(minlen_fn, sep="\t", float_format="%.3f", index=False)
    trunclen.to_csv(trunclen_fn, sep="\t", float_format="%.3f", index=False)

    # Custom rc: smaller tick/label/legend fonts, and "svg.fonttype": "none"
    # so text is written as text (not converted to paths) in SVG output.
    rc = {
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "axes.labelsize": 10,
        "legend.fontsize": 10,
        "svg.fonttype": "none"}

    with plt.rc_context(rc=rc):
        _plot(minlen, trunclen, plot_fn)
开发者ID:compmetagen,项目名称:micca,代码行数:30,代码来源:_filterstats.py


示例3: stats

def stats(input_fn, output_dir, topn=None):
    """Write length/quality statistics tables and plots for *input_fn*.

    ``_stats`` provides three tables: length distribution, quality
    distribution and quality summary. Each is written to *output_dir* as a
    tab-separated file (floats with 3 decimals, no index) and plotted to a
    matching ``*_plot.png`` file.

    Raises:
        ValueError: if *output_dir* is not an existing directory.
    """
    if not os.path.isdir(output_dir):
        raise ValueError("directory {} does not exist".format(output_dir))

    def _out(name):
        # Path of an output file inside output_dir.
        return os.path.join(output_dir, name)

    len_dist, qual_dist, qual_summ = _stats(input_fn=input_fn, topn=topn)

    outputs = (
        (len_dist, "stats_lendist.txt"),
        (qual_dist, "stats_qualdist.txt"),
        (qual_summ, "stats_qualsumm.txt"),
    )
    for frame, name in outputs:
        frame.to_csv(_out(name), sep="\t", float_format="%.3f", index=False)

    # Smaller tick/label/legend fonts; "svg.fonttype": "none" keeps text as
    # text (rather than paths) in SVG output.
    plot_rc = {
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "axes.labelsize": 10,
        "legend.fontsize": 10,
        "svg.fonttype": "none",
    }

    with plt.rc_context(rc=plot_rc):
        _plot_len_dist(len_dist, _out("stats_lendist_plot.png"))
        _plot_qual_dist(qual_dist, _out("stats_qualdist_plot.png"))
        _plot_qual_summ(qual_summ, _out("stats_qualsumm_plot.png"))
开发者ID:compmetagen,项目名称:micca,代码行数:32,代码来源:_stats.py


示例4: plot_alpha

def plot_alpha(metadata, category, hue):
    """Box-plot the 'Alpha diversity' column per value of *category*.

    Boxes are split by *hue* and coloured with the 'cubehelix' palette. The
    figure is 4 inches wide per distinct value of ``metadata[category]`` and
    8 inches tall. Drawing happens inside a temporary rc context built from
    seaborn's "darkgrid" axes style and "notebook" context (font_scale=2).
    """
    import seaborn as sns
    style = dict(sns.axes_style("darkgrid"),
                 **sns.plotting_context("notebook", font_scale=2))
    with plt.rc_context(style):
        width = len(metadata[category].unique())
        plt.figure(figsize=(width*4, 8))
        # DataFrame.sort() was removed in pandas 0.20; sort_values(by=...)
        # is the equivalent sort by a column.
        sns.boxplot(x=category, y='Alpha diversity',
                    data=metadata.sort_values(category), hue=hue,
                    palette='cubehelix')
开发者ID:jairideout,项目名称:q2d2,代码行数:8,代码来源:wui.py


示例5: minorticksubplot

    def minorticksubplot(xminor, yminor, i):
        """Add subplot *i* of a 2x2 grid under minor-tick rc overrides.

        The subplot is created while ``xtick.minor.visible`` and
        ``ytick.minor.visible`` are set to *xminor* and *yminor*. Afterwards
        the subplot must have minor ticks on exactly the requested axes.
        """
        overrides = {'xtick.minor.visible': xminor,
                     'ytick.minor.visible': yminor}
        with plt.rc_context(rc=overrides):
            ax = fig.add_subplot(2, 2, i)

        # Checked after the rc context has exited, so this verifies that the
        # Axes kept the setting rather than reading the live rcParams.
        x_has_minor = len(ax.xaxis.get_minor_ticks()) > 0
        assert x_has_minor == xminor
        y_has_minor = len(ax.yaxis.get_minor_ticks()) > 0
        assert y_has_minor == yminor
开发者ID:dstansby,项目名称:matplotlib,代码行数:8,代码来源:test_ticker.py


示例6: test_get_color_cycle

    def test_get_color_cycle(self):
        """utils.get_color_cycle() returns the colours of axes.prop_cycle."""
        # mpl_ge_150 presumably flags matplotlib >= 1.5.0, where the
        # prop_cycle/cycler API exists; otherwise there is nothing to check.
        if not mpl_ge_150:
            return
        expected = [(1., 0., 0.), (0, 1., 0.)]
        cycle = plt.cycler(color=expected)
        with plt.rc_context({"axes.prop_cycle": cycle}):
            assert utils.get_color_cycle() == expected
开发者ID:bicycle1885,项目名称:seaborn,代码行数:8,代码来源:test_palettes.py


示例7: view

    def view(self, test=False):
        """Draw the plot, inside AttrConf.MPL_STYLE when styling is enabled.

        Concatenated plots are drawn with ``_plot_concat()``; otherwise
        ``_plot(permute)`` is used.

        :param test: when True, force styling on and set
            ``AttrConf.MPL_STYLE["interactive"]`` to False before drawing.
            This mutates the shared ``AttrConf.MPL_STYLE`` dict.
        """
        if test:
            self._attr["style"] = True
            AttrConf.MPL_STYLE["interactive"] = False

        # Choose the drawing routine first; whether it runs inside the custom
        # rc context is decided separately below.
        if self._attr["concat"]:
            draw = self._plot_concat
        else:
            def draw():
                self._plot(self._attr["permute"])

        if not self._attr["style"]:
            draw()
            return

        with plt.rc_context(AttrConf.MPL_STYLE):
            draw()
开发者ID:phil-chen,项目名称:trappy,代码行数:19,代码来源:LinePlot.py


示例8: create_icon_axes

def create_icon_axes(fig, ax_position, lw_bars, lw_grid, lw_border, rgrid):
    """
    Add the polar "radar" Axes of the matplotlib logo icon to *fig*.

    Seven bars with fixed heights and widths are drawn. Each is coloured
    from the ``jet`` colormap according to its height, with alpha 0.6. Tick
    labels are hidden, the radial limit is fixed at 9, and a white
    background patch is added underneath. The Axes border colour
    (``MPL_BLUE``) and line width come from a temporary rc context.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The figure to draw into.
    ax_position : (float, float, float, float)
        The position of the created Axes in figure coordinates as
        (x, y, width, height).
    lw_bars : float
        The linewidth of the bars.
    lw_grid : float
        The linewidth of the grid.
    lw_border : float
        The linewidth of the Axes border.
    rgrid : array-like
        Positions of the radial grid.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The created polar Axes.
    """
    n_bars = 7
    full_circle = 2. * np.pi
    angles = np.arange(0.0, full_circle, full_circle / n_bars)
    heights = np.array([2, 6, 8, 7, 4, 5, 8])
    bar_widths = np.pi / 4 * np.array([0.4, 0.4, 0.6, 0.8, 0.2, 0.5, 0.3])

    border_rc = {'axes.edgecolor': MPL_BLUE, 'axes.linewidth': lw_border}
    with plt.rc_context(border_rc):
        ax = fig.add_axes(ax_position, projection='polar')
        ax.set_axisbelow(True)

        bars = ax.bar(angles, heights, width=bar_widths, bottom=0.0,
                      align='edge', edgecolor='0.3', lw=lw_bars)
        for height, bar in zip(heights, bars):
            # RGB from jet for the scaled height, alpha fixed at 0.6.
            rgb = cm.jet(height / 10.)[:3]
            bar.set_facecolor((*rgb, 0.6))

        ax.tick_params(labelbottom=False, labeltop=False,
                       labelleft=False, labelright=False)
        ax.grid(lw=lw_grid, color='0.9')
        ax.set_rmax(9)
        ax.set_yticks(rgrid)

        # Visible background: in polar coordinates this rectangle covers the
        # full circle out to radius 9.58 (a bit beyond rmax). It sits at
        # zorder 0 and is not clipped.
        ax.add_patch(Rectangle((0, 0), full_circle, 9.58,
                               facecolor='white', zorder=0,
                               clip_on=False, in_layout=False))
    return ax
开发者ID:QuLogic,项目名称:matplotlib,代码行数:53,代码来源:logos2.py


示例9: prepare

 def prepare(self):
     """Create the 7x7 inch figure and its 'ispectrum'/'scatter' axes.

     Seaborn's global style is set to "ticks" and its context to "paper".
     The figure is built under ``plot_params`` on a 3x2 GridSpec:
     'scatter' spans the top two rows and 'ispectrum' the bottom row, both
     across all columns. The GridSpec is kept as ``self.gs``.
     """
     sns.set_style('ticks')
     sns.set_context('paper')
     with plt.rc_context(plot_params):
         self.fig = plt.figure(figsize=(7, 7))
         grid = plt.GridSpec(3, 2)
         spectrum_ax = self.fig.add_subplot(grid[2, :])
         scatter_ax = self.fig.add_subplot(grid[:2, :])
         self.ax = {'ispectrum': spectrum_ax, 'scatter': scatter_ax}
     self.gs = grid
开发者ID:fabiansinz,项目名称:efish_locking,代码行数:12,代码来源:figure_locking.py


示例10: view

    def view(self, test=False):
        """Resolve and draw the plot, inside AttrConf.MPL_STYLE if styled.

        Permutation is only requested when plots are not concatenated.

        :param test: when True, force styling on and set
            ``AttrConf.MPL_STYLE["interactive"]`` to False before drawing.
            This mutates the shared ``AttrConf.MPL_STYLE`` dict.
        """
        if test:
            self._attr["style"] = True
            AttrConf.MPL_STYLE["interactive"] = False

        attrs = self._attr
        permute = attrs["permute"] and not attrs["concat"]
        if not attrs["style"]:
            self._resolve(permute, attrs["concat"])
            return

        with plt.rc_context(AttrConf.MPL_STYLE):
            self._resolve(permute, attrs["concat"])
开发者ID:JaviMerino,项目名称:trappy,代码行数:13,代码来源:StaticPlot.py


示例11: stats

def stats(input_fn, output_dir, step=100, replace=False, seed=0):
    """Summarise an OTU table and plot its rarefaction curves.

    The table is read from *input_fn* with ``micca.table.read``. Its columns
    are treated as samples and its rows as OTUs. The following files are
    written into *output_dir*:

    * ``tablestats_samplesumm.txt``: per sample, the total count ("Depth"),
      the number of OTUs with a non-zero count ("NOTU") and the number of
      OTUs with a count of exactly 1 ("NSingle"), sorted by depth.
    * ``tablestats_otusumm.txt``: per OTU, the total count ("N") and the
      number of samples it occurs in ("NSample"), sorted by N descending.
    * ``tablestats_rarecurve.txt``: the output of ``micca.table.rarecurve``
      (*step*, *replace* and *seed* are passed through).
    * ``tablestats_rarecurve_plot.png``: a plot of those curves at 300 dpi.

    Raises
    ------
    ValueError
        If *output_dir* is not an existing directory.
    """
    if not os.path.isdir(output_dir):
        raise ValueError("directory {} does not exist".format(output_dir))

    sample_summ_fn = os.path.join(output_dir, "tablestats_samplesumm.txt")
    otu_summ_fn = os.path.join(output_dir, "tablestats_otusumm.txt")
    rarecurve_fn = os.path.join(output_dir, "tablestats_rarecurve.txt")
    rarecurve_plot_fn = os.path.join(output_dir, "tablestats_rarecurve_plot.png")

    table = micca.table.read(input_fn)

    # sample summary (column-wise reductions -> one row per sample)
    sample_summ = pd.DataFrame({
        "Depth": table.sum(),
        "NOTU": (table > 0).sum(),
        "NSingle": (table == 1).sum()},
        columns=["Depth", "NOTU", "NSingle"])
    sample_summ.index.name = "Sample"
    sample_summ.sort_values(by="Depth", inplace=True)
    sample_summ.to_csv(sample_summ_fn, sep='\t')

    # OTU summary (row-wise reductions -> one row per OTU)
    otu_summ = pd.DataFrame({
        "N": table.sum(axis=1),
        "NSample": (table > 0).sum(axis=1)},
        columns=["N", "NSample"])
    otu_summ.index.name = "OTU"
    otu_summ.sort_values(by="N", inplace=True, ascending=False)
    otu_summ.to_csv(otu_summ_fn, sep='\t')

    # rarefaction curves
    rarecurve = micca.table.rarecurve(table, step=step, replace=replace,
                                      seed=seed)
    rarecurve.to_csv(rarecurve_fn, sep='\t', float_format="%.0f", na_rep="NA")

    # Custom rc: smaller tick/label/legend fonts, and "svg.fonttype": "none"
    # so text is written as text (not converted to paths) in SVG output.
    rc = {
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "axes.labelsize": 10,
        "legend.fontsize": 10,
        "svg.fonttype": "none"}

    with plt.rc_context(rc=rc):
        fig = plt.figure(figsize=(10, 6))
        try:
            plt.plot(rarecurve.index, rarecurve.to_numpy(), color="k")
            plt.xlabel("Depth")
            plt.ylabel("#OTUs")
            fig.savefig(rarecurve_plot_fn, dpi=300, bbox_inches='tight',
                        format="png")
        finally:
            # pyplot keeps every figure alive until it is explicitly closed,
            # so without this each call would leak one open figure.
            plt.close(fig)
开发者ID:compmetagen,项目名称:micca,代码行数:51,代码来源:table.py


示例12: prepare

 def prepare(self):
     """Create the 7x7 inch figure and its axes on a 3x12 GridSpec.

     Seaborn's global style is set to "ticks" and its context to "paper".
     The layout is built under ``plot_params``:

     * 'spectrum': top two rows, columns 4 onwards;
     * 'ISI', 'cycle', 'vs_freq': rows 0, 1 and 2 of the first 4 columns;
     * 'cycle_ampl': a twin-x Axes of 'cycle';
     * 'circ' and 'contrast': bottom row, columns 4-7 and 8 onwards.
     """
     sns.set_style('ticks')
     sns.set_context('paper')
     with plt.rc_context(plot_params):
         self.fig = plt.figure(figsize=(7, 7))
         grid = plt.GridSpec(3, 12)
         add = self.fig.add_subplot
         axes = {}
         axes['spectrum'] = add(grid[:2, 4:])
         axes['ISI'] = add(grid[0, :4])
         axes['cycle'] = add(grid[1, :4])
         axes['vs_freq'] = add(grid[2, :4])
         self.ax = axes
         axes['cycle_ampl'] = axes['cycle'].twinx()
         axes['circ'] = add(grid[2, 4:8])
         axes['contrast'] = add(grid[2, 8:])
开发者ID:fabiansinz,项目名称:efish_locking,代码行数:15,代码来源:figure_pyramidals.py


示例13: prepare

    def prepare(self):
        """Create the 7x7 inch figure and its axes on a 3x2 GridSpec.

        Seaborn's global style is set to "ticks" and its context to "paper".
        The layout is built under ``plot_params``:

        * 'scatter_base': top row; 'scatter': bottom row (both full width);
        * 'EOD' and 'ISI': middle row, left and right column;
        * 'EOD_ampl': a twin-x Axes of 'EOD'.

        The GridSpec is kept as ``self.gs``.
        """
        sns.set_style('ticks')
        sns.set_context('paper')
        with plt.rc_context(plot_params):
            self.fig = plt.figure(figsize=(7, 7))
            grid = plt.GridSpec(3, 2)
            add = self.fig.add_subplot
            axes = {}
            axes['scatter'] = add(grid[-1, :])
            axes['ISI'] = add(grid[1, 1])
            axes['EOD'] = add(grid[1, 0])
            self.ax = axes
            axes['scatter_base'] = add(grid[0, :])
            axes['EOD_ampl'] = axes['EOD'].twinx()
        self.gs = grid
开发者ID:fabiansinz,项目名称:efish_locking,代码行数:16,代码来源:figure_intro_punit.py


示例14: plot_dendrogram

def plot_dendrogram(clustering, size=10):
    """Draw the dendrogram of *clustering* with cluster-coloured links.

    Args:
        clustering: mapping providing 'linkage', 'labels' and 'reorder_vec'.
            These are passed to ``get_dendrogram_color_fun`` to obtain the
            link colouring function.
        size: figure width in inches. The height is ``0.6 * size`` and the
            dendrogram line width is ``0.125 * size``.
    """
    link_matrix = clustering['linkage']
    cluster_labels = clustering['labels']
    color_link, _ = get_dendrogram_color_fun(
        link_matrix, clustering['reorder_vec'], cluster_labels)

    # New figure in seaborn's "white" axes style; line width scales with size.
    with sns.axes_style('white'):
        plt.figure(figsize=(size, size * .6))
        with plt.rc_context({'lines.linewidth': size * .125}):
            dendrogram(link_matrix, link_color_func=color_link,
                       orientation='top')
开发者ID:IanEisenberg,项目名称:Self_Regulation_Ontology,代码行数:16,代码来源:individual_structure.py


示例15: plot_stacked_bar

def plot_stacked_bar(df):
    """Draw a stacked bar chart with one stacked segment per row of *df*.

    Each column of *df* becomes one bar at x positions 0..n-1. Each row
    (index entry) becomes one segment, stacked in index order with a
    randomly chosen hex colour. Legend entries list the row's per-column
    values (2 decimals) followed by its index label; the legend is placed
    to the right of the Axes. The y-axis is fixed to [0, 1], so values are
    presumably fractions summing to at most 1 per column (TODO confirm with
    callers).

    Drawing happens inside a temporary rc context built from seaborn's
    "darkgrid" axes style and "notebook" context (font_scale=1.8).
    """
    import seaborn as sns
    style = dict(sns.axes_style("darkgrid"),
                 **sns.plotting_context("notebook", font_scale=1.8))
    with plt.rc_context(style):
        _, ax = plt.subplots(1, figsize=(10, 10))
        x = list(range(len(df.columns)))
        bottom = np.array([0] * len(df.columns))
        cat_percents = []
        for id_ in df.index:
            row = df.loc[id_]
            color = '#' + ''.join(np.random.choice(list('ABCDEF123456789'), 6))
            ax.bar(x, row, color=color, bottom=bottom, align='center')
            bottom = row + bottom
            # Use a distinct loop name: reusing ``x`` here would clobber the
            # bar positions under Python 2's leaky comprehension scope.
            cat_percents.append(''.join(
                "[{0:.2f}] ".format(value) for value in row.tolist()))

        # str() so that non-string index labels do not break the join.
        legend_labels = [' '.join((percents, str(label)))
                         for percents, label in zip(cat_percents,
                                                    df.index.tolist())]

        ax.set_xticks(x)
        ax.set_xticklabels(df.columns.tolist())
        ax.set_ylim([0, 1])
        ax.legend(legend_labels, loc='center left', bbox_to_anchor=(1, 0.5))
开发者ID:jairideout,项目名称:q2d2,代码行数:20,代码来源:wui.py


示例16: render

    def render(self):
        """
        Render the figure using the custom matplotlib rc settings.

        The grid is created first and then drawn according to
        ``self.animation["type"]``: ``False`` renders a static figure and
        ``"gif"`` renders a GIF animation. Any other type, including the not
        yet implemented ``"animation"``, aborts and returns ``None``.

        :returns: A :mod:`matplotlib` figure object, or ``None`` for an
            unsupported animation type.
        """
        rc = custom_mpl.custom_rc(rc=self.custom_mpl_rc)
        with plt.rc_context(rc=rc):
            figure, axes = self._render_grid()

            kind = self.animation["type"]
            if kind is False:
                self._render_no_animation(axes)
            elif kind == "gif":
                self._render_gif_animation(figure, axes)
            else:
                # TODO: "animation" is not implemented; other types are
                # unsupported. Both bail out without a figure.
                return None

            # Tighten the layout with custom padding.
            # TODO: this messes up animations.
            figure.tight_layout(pad=1)
        return figure
开发者ID:Phyks,项目名称:replot,代码行数:24,代码来源:figure.py


示例17: plot_dendrogram

def plot_dendrogram(loading, clustering, title=None, 
                    break_lines=True, drop_list=None, double_drop_list=None,
                    absolute_loading=False,  size=4.6,  dpi=300, 
                    filename=None):
    """ Plots HCA results as dendrogram with loadings underneath
    
    Args:
        loading: pandas df, a results EFA loading matrix; its index must
            contain the variables in clustering['clustered_df'].columns
        clustering: dict-like HCA result; this function reads the keys
            'linkage', 'clustered_df', 'labels', 'reorder_vec' and
            (optionally) 'cluster_names'
        title (optional): str, title to plot
        break_lines: whether to separate EFA heatmap based on clusters, default=True
        drop_list (optional): list of cluster indices to drop the cluster label
        double_drop_list (optional): list of cluster indices to drop the cluster label twice
        absolute_loading: whether to plot the absolute loading value, default False
        size: figure width in inches (height is 0.6 * size); also scales
            line widths, font sizes and tick padding
        dpi: resolution, presumably used when saving the plot
        filename (optional): if set, presumably where to save the plot
        
    NOTE(review): title, drop_list, double_drop_list, dpi and filename are
    not used in the portion of the body shown here.
    """


    # number of loading columns (presumably factors); scales font sizes below
    c = loading.shape[1]
    # extract cluster vars
    link = clustering['linkage']
    DVs = clustering['clustered_df'].columns
    # loadings re-ordered to match the clustered variable order
    ordered_loading = loading.loc[DVs]
    if absolute_loading:
        ordered_loading = abs(ordered_loading)
    # get cluster sizes (labels are assumed to be 1-based consecutive cluster ids)
    labels=clustering['labels']
    cluster_sizes = [np.sum(labels==(i+1)) for i in range(max(labels))]
    link_function, colors = get_dendrogram_color_fun(link, clustering['reorder_vec'],
                                                     labels)
    
    # set figure properties
    figsize = (size, size*.6)
    # set up axes' size (fractions of the figure): the heatmap height grows
    # by 0.035 per loading column
    heatmap_height = ordered_loading.shape[1]*.035
    heat_size = [.1, heatmap_height]
    dendro_size=[np.sum(heat_size), .3]
    # set up plot axes as [left, bottom, width, height]; the dendrogram sits
    # directly on top of the heatmap, the colorbar to the heatmap's right
    dendro_size = [.15,dendro_size[0], .78, dendro_size[1]]
    heatmap_size = [.15,heat_size[0],.78,heat_size[1]]
    cbar_size = [.935,heat_size[0],.015,heat_size[1]]
    # transpose so heatmap rows are loading columns and heatmap columns are
    # variables, aligned with the dendrogram leaves
    ordered_loading = ordered_loading.T

    with sns.axes_style('white'):
        fig = plt.figure(figsize=figsize)
        ax1 = fig.add_axes(dendro_size) 
        # **********************************
        # plot dendrogram
        # **********************************
        with plt.rc_context({'lines.linewidth': size*.125}):
            dendrogram(link, ax=ax1, link_color_func=link_function,
                       orientation='top')
        # change axis properties: hide x labels, the y axis and all spines
        ax1.tick_params(axis='x', which='major', labelsize=14,
                        labelbottom=False)
        ax1.get_yaxis().set_visible(False)
        ax1.spines['top'].set_visible(False)
        ax1.spines['right'].set_visible(False)
        ax1.spines['bottom'].set_visible(False)
        ax1.spines['left'].set_visible(False)
        # **********************************
        # plot loadings as heatmap below
        # **********************************
        ax2 = fig.add_axes(heatmap_size)
        cbar_ax = fig.add_axes(cbar_size)
        # symmetric colour range from the largest absolute loading
        max_val = np.max(abs(loading.values))
        # round up to the next multiple of .25
        max_val = ceil(max_val*4)/4
        sns.heatmap(ordered_loading, ax=ax2, 
                    cbar=True, cbar_ax=cbar_ax,
                    yticklabels=True,
                    xticklabels=True,
                    vmax =  max_val, vmin = -max_val,
                    cbar_kws={'orientation': 'vertical',
                              'ticks': [-max_val, 0, max_val]},
                    cmap=sns.diverging_palette(220,15,n=100,as_cmap=True))
        ax2.set_yticklabels(ax2.get_yticklabels(), rotation=0)
        # font size shrinks as the number of loading columns (c) grows
        ax2.tick_params(axis='y', labelsize=size*heat_size[1]*30/c, pad=size/4, length=0)            
        # format cbar axis
        cbar_ax.set_yticklabels([format_num(-max_val), 0, format_num(max_val)])
        cbar_ax.tick_params(labelsize=size*heat_size[1]*25/c, length=0, pad=size/2)
        cbar_ax.set_ylabel('Factor Loading', rotation=-90, 
                       fontsize=size*heat_size[1]*30/c, labelpad=size*2)
        # add lines to heatmap to distinguish clusters
        if break_lines == True:
            xlim = ax2.get_xlim(); 
            ylim = ax2.get_ylim()
            # x width of one variable column
            step = xlim[1]/len(labels)
            # right edge of each cluster; the last edge (plot border) is skipped
            cluster_breaks = [i*step for i in np.cumsum(cluster_sizes)]
            ax2.vlines(cluster_breaks[:-1], ylim[0], ylim[1], linestyles='dashed',
                       linewidth=size*.1, colors=[.5,.5,.5], zorder=10)
        # **********************************
        # plot cluster names
        # **********************************
        # start column and (approximate) centre column of each cluster
        beginnings = np.hstack([[0],np.cumsum(cluster_sizes)[:-1]])
        centers = beginnings+np.array(cluster_sizes)//2+.5
        offset = .07
        if 'cluster_names' in clustering.keys():
            ax2.tick_params(axis='x', reset=True, top=False, bottom=False, width=size/8, length=0)
#.........这里部分代码省略.........
开发者ID:IanEisenberg,项目名称:Self_Regulation_Ontology,代码行数:101,代码来源:HCA_plots.py


示例18: plot_fluxnet_comparison_one_site

def plot_fluxnet_comparison_one_site(driver, science_test_data_dir,
                                     compare_data_dict, result_dir, plot_dir,
                                     plots_to_make, context, style, var_names,
                                     months, obs_dir, subdir):
    """Plot FLUXNET observations against VIC output for one site.

    Nothing is done unless ``check_site_files(obs_dir, subdir)`` is truthy.
    One dataset is loaded per key of *compare_data_dict*:

    * "ecflux": observations;
    * "VIC.4.2.d": VIC 4.2 reference output;
    * anything else: VIC 5 output from *result_dir*.

    A dataset whose reader raises OSError is skipped with a warning.

    When 'annual_mean_diurnal_cycle' is in *plots_to_make*, a 4-panel figure
    is saved as ``<plot_dir>/annual_mean/<lat>_<lng>.png``. It shows the
    mean value per hour of day for each entry of *var_names*, drawn with the
    per-dataset 'linewidth', 'color', 'linestyle' and 'zorder' from
    *compare_data_dict*.
    """

    if check_site_files(obs_dir, subdir):
        # get CSV file from site directory to get lat/lng for site
        lat, lng = get_fluxnet_lat_lon(obs_dir, subdir)
        print(lat, lng)

        # loop over data to compare; datasets that fail to load (OSError)
        # only produce a warning and are left out of ``data``
        data = {}
        for key, items in compare_data_dict.items():

            if key == "ecflux":
                try:
                    # load Ameriflux data
                    data[key] = read_fluxnet_obs(subdir,
                                                 science_test_data_dir,
                                                 items)
                except OSError:
                    warnings.warn(
                        "this %s site does not have data" % subdir)

            elif key == "VIC.4.2.d":
                try:
                    # load VIC 4.2 simulations
                    data[key] = read_vic_42_output(lat, lng,
                                                   science_test_data_dir,
                                                   items)

                except OSError:
                    warnings.warn(
                        "this site has a lat/lng precision issue")

            else:
                try:
                    # load VIC 5 simulations
                    data[key] = read_vic_5_output(lat, lng,
                                                  result_dir,
                                                  items)
                except OSError:
                    warnings.warn(
                        "this site has a lat/lng precision issue")

        # make figures

        # plot preferences: font size for axis labels, output resolution
        fs = 15
        dpi = 150

        if 'annual_mean_diurnal_cycle' in plots_to_make:

            # make annual mean diurnal cycle plots
            with plt.rc_context(dict(sns.axes_style(style),
                                     **sns.plotting_context(context))):
                f, axarr = plt.subplots(4, 1, figsize=(8, 8), sharex=True)

                # NOTE(review): axarr has exactly 4 panels, so var_names is
                # assumed to hold at most 4 entries (IndexError otherwise).
                for i, (vic_var, variable_name) in enumerate(
                        var_names.items()):

                    # calculate annual mean diurnal cycle for each
                    # DataFrame
                    # NOTE(review): df[vic_var] raises KeyError for a dataset
                    # lacking vic_var, so the `if vic_var in d` filter below
                    # never gets a chance to skip it - confirm every dataset
                    # provides every variable.
                    annual_mean = {}
                    for key, df in data.items():
                        annual_mean[key] = pd.DataFrame(
                            df[vic_var].groupby(df.index.hour).mean())

                    df = pd.DataFrame(
                        {key: d[vic_var] for key, d in annual_mean.items()
                         if vic_var in d})

                    # NOTE(review): DataFrame.iteritems() was removed in
                    # pandas 2.0; df.items() is the equivalent.
                    for key, series in df.iteritems():
                        series.plot(
                            linewidth=compare_data_dict[key]['linewidth'],
                            ax=axarr[i],
                            color=compare_data_dict[key]['color'],
                            linestyle=compare_data_dict[key]['linestyle'],
                            zorder=compare_data_dict[key]['zorder'])

                    axarr[i].legend(loc='upper left')
                    axarr[i].set_ylabel(
                        '%s ($W/{m^2}$)' % variable_name,
                        size=fs)
                    axarr[i].set_xlabel('Time of Day (Hour)', size=fs)
                    axarr[i].set_xlim([0, 24])
                    axarr[i].xaxis.set_ticks(np.arange(0, 24, 3))

                # save plot as <plot_dir>/annual_mean/<lat>_<lng>.png
                plotname = '%s_%s.png' % (lat, lng)
                os.makedirs(os.path.join(plot_dir, 'annual_mean'),
                            exist_ok=True)
                savepath = os.path.join(plot_dir, 'annual_mean', plotname)
                plt.savefig(savepath, bbox_inches='tight', dpi=dpi)

                plt.clf()
                plt.close()

        if 'monthly_mean_diurnal_cycle' in plots_to_make:

#.........这里部分代码省略.........
开发者ID:BramDr,项目名称:VIC,代码行数:101,代码来源:test_utils.py


示例19: plot_snotel_comparison_one_site

def plot_snotel_comparison_one_site(
        driver, science_test_data_dir,
        compare_data_dict,
        result_dir, plot_dir,
        plots_to_make,
        plot_variables, context, style, filename):
    """Plot SNOTEL observations against VIC output for one site.

    The site latitude and longitude are parsed from *filename*: the third
    and fourth '_'-separated fields, with '.txt' stripped from the
    longitude. One dataset is loaded per key of *compare_data_dict*:

    * "snotel": observations;
    * "VIC.4.2.d": VIC 4.2 reference output;
    * anything else: VIC 5 output from *result_dir*.

    When 'water_year' is in *plots_to_make*, one figure per entry of
    *plot_variables* (variable -> units) is saved as
    ``<plot_dir>/<variable>/<lat>_<lng>.png``. Each dataset is drawn with
    its 'linewidth', 'color', 'linestyle' and 'zorder' from
    *compare_data_dict*. *driver* is not used here.
    """
    print(plots_to_make)

    # get lat/lng from filename
    file_split = filename.split('_')
    lng = file_split[3].split('.txt')[0]
    lat = file_split[2]
    print('Plotting {} {}'.format(lat, lng))

    # read every dataset to compare
    data = {}
    for key, items in compare_data_dict.items():
        if key == "snotel":
            data[key] = read_snotel_swe_obs(filename,
                                            science_test_data_dir,
                                            items)
        elif key == "VIC.4.2.d":
            data[key] = read_vic_42_output(lat, lng,
                                           science_test_data_dir,
                                           items)
        else:
            data[key] = read_vic_5_output(lat, lng,
                                          result_dir,
                                          items)

    # The plot selection does not depend on the variable, so check it once.
    if 'water_year' not in plots_to_make:
        return

    for plot_variable, units in plot_variables.items():
        with plt.rc_context(dict(sns.axes_style(style),
                                 **sns.plotting_context(context))):
            fig, ax = plt.subplots(figsize=(10, 10))

            # only datasets that actually provide this variable
            df = pd.DataFrame({key: d[plot_variable] for key, d in
                               data.items() if plot_variable in d})

            # DataFrame.iteritems() was removed in pandas 2.0; items() is the
            # long-standing equivalent.
            for key, series in df.items():
                line_opts = compare_data_dict[key]
                series.plot(
                    use_index=True,
                    linewidth=line_opts['linewidth'],
                    ax=ax,
                    color=line_opts['color'],
                    linestyle=line_opts['linestyle'],
                    zorder=line_opts['zorder'])

            ax.legend(loc='upper left')
            ax.set_ylabel("%s [%s]" % (plot_variable, units))

            # save figure
            os.makedirs(os.path.join(plot_dir, plot_variable),
                        exist_ok=True)
            plotname = '%s_%s.png' % (lat, lng)
            savepath = os.path.join(plot_dir, plot_variable, plotname)
            plt.savefig(savepath, bbox_inches='tight')
            print(savepath)
            plt.clf()
            plt.close()
开发者ID:BramDr,项目名称:VIC,代码行数:69,代码来源:test_utils.py


示例20: draw_termite_plot

def draw_termite_plot(values_mat, col_labels, row_labels,
                      highlight_cols=None, highlight_colors=None,
                      save=False):
    """
    Make a "termite" plot, typically used for assessing topic models with a tabular
    layout that promotes comparison of terms both within and across topics.

    Args:
        values_mat (``np.ndarray`` or matrix): matrix of values with shape
            (# row labels, # col labels) used to size the dots on the grid
        col_labels (seq[str]): labels used to identify x-axis ticks on the grid
        row_labels (seq[str]): labels used to identify y-axis ticks on the grid
        highlight_cols (int or seq[int], optional): indices for columns
            to visually highlight in the plot with contrasting colors
        highlight_colors (tuple of 2-tuples): each 2-tuple corresponds to a pair
            of (light/dark) matplotlib-friendly colors used to highlight a single
            column; if not specified (default), ``COLOR_PAIRS`` is used
        save (str, optional): give the full /path/to/fname on disk to save figure

    Returns:
        ``matplotlib.axes.Axes.axis``: axis on which termite plot is plotted

    Raises:
        ImportError: if ``plt`` (matplotlib) is not available
        ValueError: if more columns are selected for highlighting than colors
            or if any of the inputs' dimensions don't match

    References:
        .. Chuang, Jason, Christopher D. Manning, and Jeffrey Heer. "Termite:
            Visualization techniques for assessing textual topic models."
            Proceedings of the International Working Conference on Advanced
            Visual Interfaces. ACM, 2012.

    .. seealso:: :func:`TopicModel.termite_plot <textacy.tm.TopicModel.termite_plot>`
    """
    # ``plt`` is a module-level name; if it is undefined (presumably because
    # the optional matplotlib import failed) raise an informative ImportError.
    try:
        plt
    except NameError:
        raise ImportError(
            'matplotlib is not installed, so textacy.viz won\'t work; install it \
            individually, or along with textacy via `pip install textacy[viz]`')
    n_rows, n_cols = values_mat.shape
    # Largest value in the matrix; dot sizes below are scaled relative to it.
    # NOTE(review): an all-zero matrix makes max_val 0 and the size division
    # below divides by zero - confirm callers never pass one.
    max_val = np.max(values_mat)

    if n_rows != len(row_labels):
        msg = "values_mat and row_labels dimensions don't match: {} vs. {}".format(
            n_rows, len(row_labels))
        raise ValueError(msg)
    if n_cols != len(col_labels):
        msg = "values_mat and col_labels dimensions don't match: {} vs. {}".format(
            n_cols, len(col_labels))
        raise ValueError(msg)

    if highlight_colors is None:
        highlight_colors = COLOR_PAIRS
    if highlight_cols is not None:
        # A single column index is normalised to a 1-tuple; the length check
        # against the available colours only runs for sequences.
        if isinstance(highlight_cols, int):
            highlight_cols = (highlight_cols,)
        elif len(highlight_cols) > len(highlight_colors):
            msg = 'no more than {} columns may be highlighted at once'.format(
                len(highlight_colors))
            raise ValueError(msg)
        # Map column index -> (light, dark) colour pair.
        # NOTE(review): this indexes COLOR_PAIRS rather than the caller's
        # highlight_colors, so custom colours passed by the caller are
        # ignored (only the length check above uses them).
        highlight_colors = {hc: COLOR_PAIRS[i]
                            for i, hc in enumerate(highlight_cols)}

    with plt.rc_context(RC_PARAMS):
        # Figure size grows sub-linearly with the number of columns and rows.
        fig, ax = plt.subplots(figsize=(pow(n_cols, 0.8), pow(n_rows, 0.66)))

        _ = ax.set_yticks(range(n_rows))
        yticklabels = ax.set_yticklabels(row_labels,
                                         fontsize=14, color='gray')
        # Colour each row label with the dark colour of the highlighted
        # column(s) holding that row's largest value among highlighted
        # columns (only when that value is positive).
        if highlight_cols is not None:
            for i, ticklabel in enumerate(yticklabels):
                max_tick_val = max(values_mat[i, hc] for hc in highlight_cols)
                for hc in highlight_cols:
                    if max_tick_val > 0 and values_mat[i, hc] == max_tick_val:
                        ticklabel.set_color(highlight_colors[hc][1])

        # Column labels along the top, rotated.
        ax.get_xaxis().set_ticks_position('top')
        _ = ax.set_xticks(range(n_cols))
        xticklabels = ax.set_xticklabels(col_labels,
                                         fontsize=14, color='gray',
                                         rotation=30, ha='left')
        # Highlighted columns get a dark tick label and a semi-transparent
        # light gridline (assumes one x gridline per tick - TODO confirm).
        if highlight_cols is not None:
            gridlines = ax.get_xgridlines()
            for i, ticklabel in enumerate(xticklabels):
                if i in highlight_cols:
                    ticklabel.set_color(highlight_colors[i][1])
                    gridlines[i].set_color(highlight_colors[i][0])
                    gridlines[i].set_alpha(0.5)

        # One dot per cell, area proportional to value / max_val (max 600).
        for col_ind in range(n_cols):
            if highlight_cols is not None and col_ind in highlight_cols:
                ax.scatter([col_ind for _ in range(n_rows)],
                           [i for i in range(n_rows)],
                           s=600 * (values_mat[:, col_ind] / max_val),
                           alpha=0.5, linewidth=1,
                           color=highlight_colors[col_ind][0],
                           edgecolor=highlight_colors[col_ind][1])
            else:
                ax.scatter([col_ind for _ in range(n_rows)],
#.........这里部分代码省略.........
开发者ID:chartbeat-labs,项目名称:textacy,代码行数:101,代码来源:termite.py



注：本文中的matplotlib.pyplot.rc_context函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python pyplot.rcdefaults函数代码示例发布时间:2022-05-27
下一篇:
Python pyplot.rc函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap