本文整理汇总了Python中pylab.ones函数的典型用法代码示例。如果您正苦于以下问题：Python ones函数的具体用法？Python ones怎么用？Python ones使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了ones函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: validate_once
def validate_once(true_cf=None, true_std=None, std_bias=None, save=False, dir='', i=0):
    """
    Simulate estimates for known cause fractions, fit two models and score them.

    Generates a simulated data set for the provided true cause fractions, fits
    the bad model and the latent simplex model to it, and calculates quality
    metrics for both fits with ``calc_quality_metrics``.

    Parameters
    ----------
    true_cf : list of arrays, optional
        True cause fractions. Default: ``[pl.ones(3)/3.0, pl.ones(3)/3.0]``.
    true_std : array, optional
        True standard deviations. Default: ``0.01*pl.ones(3)``.
    std_bias : list of float, optional
        Bias multipliers for the standard deviations. Default: ``[1., 1., 1.]``.
    save : bool
        If True, write both metric records to CSV files in ``dir`` and return
        None; otherwise return them.
    dir : str
        Output directory, used for the CSV files and by ``retrieve_estimates``.
    i : int
        Replicate index embedded in the output file names.

    Returns
    -------
    (bad_model_metrics, latent_simplex_metrics) if ``save`` is False, else None.
    """
    # Defaults are built per call: array/list defaults in the signature are
    # created once and would be shared by every call (and corrupted for all
    # later calls if any callee modified them in place).
    if true_cf is None:
        true_cf = [pl.ones(3) / 3.0, pl.ones(3) / 3.0]
    if true_std is None:
        true_std = 0.01 * pl.ones(3)
    if std_bias is None:
        std_bias = [1., 1., 1.]

    # generate simulation data
    X = data.sim_data_for_validation(1000, true_cf, true_std, std_bias)

    # fit bad model, calculate fit metrics
    bad_model = models.bad_model(X)
    bad_model_metrics = calc_quality_metrics(true_cf, true_std, std_bias, bad_model)
    retrieve_estimates(bad_model, True, 'bad_model', dir, i)

    # fit latent simplex model, calculate fit metrics
    m, latent_simplex = models.fit_latent_simplex(X)
    latent_simplex_metrics = calc_quality_metrics(true_cf, true_std, std_bias, latent_simplex)
    retrieve_estimates(latent_simplex, True, 'latent_simplex', dir, i)

    # either write results to disk or return them
    # NOTE(review): pl.rec2csv only exists in old matplotlib releases.
    if save:
        pl.rec2csv(bad_model_metrics, '%s/metrics_bad_model_%i.csv' % (dir, i))
        pl.rec2csv(latent_simplex_metrics, '%s/metrics_latent_simplex_%i.csv' % (dir, i))
    else:
        return bad_model_metrics, latent_simplex_metrics
开发者ID:aflaxman,项目名称:pymc-cod-correct,代码行数:25,代码来源:validate_models.py
示例2: tempo_search
def tempo_search(db, Key, tempo):
    """
    ::

        Static tempo-invariant search
        Returns search results for query resampled over a range of tempos.

    The query features and powers stored under ``Key`` are resampled by
    ``1.0 / tempo``, and the query's seqStart/seqLength are scaled by the same
    factor. ``db`` is queried with a temporary copy of its query configuration.
    The original ``db.configQuery`` is always restored, even if the query raises.

    Returns the sorted search result, or None if ``db.configCheck()`` fails.
    """
    if not db.configCheck():
        print("Failed configCheck in query spec.")
        print(db.configQuery)
        return None
    prop = 1.0 / tempo  # the proportion of original samples required for new tempo
    qconf = db.configQuery.copy()
    X = db.retrieve_datum(Key)
    P = db.retrieve_datum(Key, powers=True)
    # Resample the mean-removed features, then add the per-column mean back.
    X_m = pylab.mat(X.mean(0))
    X_resamp = pylab.array(adb.resample_vector(X - pylab.mat(pylab.ones(X.shape[0])).T * X_m, prop))
    X_resamp += pylab.mat(pylab.ones(X_resamp.shape[0])).T * X_m
    P_resamp = pylab.array(adb.resample_vector(P, prop))
    # Express the query window in the resampled time base.
    qconf["seqStart"] = int(pylab.around(qconf["seqStart"] * prop))
    qconf["seqLength"] = int(pylab.around(qconf["seqLength"] * prop))
    saved_conf = db.configQuery
    db.configQuery = qconf
    try:
        res = db.query_data(featData=X_resamp, powerData=P_resamp)
        res_resorted = adb.sort_search_result(res.rawData)
    finally:
        # Restore the caller's configuration even if the query fails.
        db.configQuery = saved_conf
    return res_resorted
开发者ID:kitefishlabs,项目名称:BregmanToolkit,代码行数:29,代码来源:audiodb.py
示例3: plotInit
def plotInit(Plotting, Elements):
    """
    Create the interactive figure used for live plotting and return its artist.

    Plotting == 2: a quiver plot on the grid of distinct element positions.
    Each ``i.xy`` is read through ``.real``/``.imag``, i.e. as a complex
    number x + iy. All arrows start with U = V = 1.
    Otherwise: a single red marker at the origin, axes [-2, 2] with the y
    axis running top-to-bottom.

    Returns the quiver object (Plotting == 2) or the Line2D marker.
    """
    if Plotting == 2:
        loc = [i.xy for i in Elements]
        x = sorted(set(i.real for i in loc))
        # NOTE(review): -10 looks like a sentinel position; list.remove raises
        # ValueError if no element sits at x == -10.
        x.remove(-10)
        y = sorted(set(i.imag for i in loc))
        X, Y = pylab.meshgrid(x, y)
        # X and Y are ndarrays from meshgrid, so use their .shape instead of
        # relying on a bare ``shape`` function being imported elsewhere.
        U = pylab.ones(X.shape)
        V = pylab.ones(Y.shape)
        pylab.ion()
        fig, ax = pylab.subplots(1, 1)
        graph = ax.quiver(X, Y, U, V)
        pylab.draw()
    else:
        pylab.ion()
        graph, = pylab.plot(1, 'ro', markersize=2)
        x = 2
        pylab.axis([-x, x, x, -x])  # y limits reversed: y grows downwards
        # Line2D data must be a sequence; newer matplotlib rejects scalars.
        graph.set_xdata([0])
        graph.set_ydata([0])
        pylab.draw()
    return graph
开发者ID:devyeshtandon,项目名称:ParticleMethods,代码行数:28,代码来源:Plotting.py
示例4: filter2d
def filter2d(x, y, axes=None, algos=None):
    """
    Perform 2D data filtration by the selected axes.

    In:
        x : array-like, X vector
        y : array-like, Y vector (same length as x)
        axes : list, names of the axes whose values decide which points are
               filtered out: 'x', 'y' or both. Default ['y'].
        algos : list, filtering algorithms passed to ``filter1d``.
                Default ['2sigma'].
    Out:
        xnew : ndarray (float), filtered X
        ynew : ndarray (float), filtered Y
    Raises:
        ValueError if x and y differ in length.
    """
    if axes is None:
        axes = ['y']
    if algos is None:
        algos = ['2sigma']
    xnew = pl.array(x, dtype='float')
    ynew = pl.array(y, dtype='float')
    if len(xnew) != len(ynew):
        raise ValueError("x and y must have the same length")
    mask_x = pl.ones(len(xnew), dtype='bool')
    mask_y = pl.ones(len(ynew), dtype='bool')
    if 'y' in axes:
        mask_y = filter1d(y, algos=algos)
    if 'x' in axes:
        mask_x = filter1d(x, algos=algos)
    # Select points with a boolean mask. Multiplying by the mask and then
    # dropping zeros also discarded genuine 0.0 values, and could leave x and y
    # with different lengths.
    mask = pl.asarray(mask_x * mask_y, dtype='bool')
    return xnew[mask], ynew[mask]
开发者ID:DanielEColi,项目名称:fnatool,代码行数:30,代码来源:common.py
示例5: example
def example():
    """Demo: draw two fake data sets with ``percentile_box_plot`` and show them.

    The fake-data recipe follows the matplotlib boxplot demo
    (http://matplotlib.sourceforge.net/pyplots/boxplot_demo.py).
    """
    from pylab import rand, ones, concatenate
    import matplotlib.pyplot as plt

    def fake_column(center_level):
        # 50 uniform points in [0, 100), 25 points at center_level, then 10 high
        # and 10 low outliers, stacked into a single (N, 1) column.
        pieces = (
            rand(50) * 100,
            ones(25) * center_level,
            rand(10) * 100 + 100,
            rand(10) * -100,
        )
        column = concatenate(pieces, 0)
        column.shape = (-1, 1)
        return column

    columns = [fake_column(50), fake_column(40)]

    figure = plt.figure()
    axis = figure.add_subplot(1, 1, 1)
    axis.set_xlim(0, 4)
    percentile_box_plot(axis, columns, [2, 3])
    plt.show()
开发者ID:boada,项目名称:scripts,代码行数:29,代码来源:boxplot_percentile.py
示例6: jetWoGn
def jetWoGn(reverse=False):
    """
    jetWoGn(reverse=False)
    - returning a colormap similar to cm.jet, but without green.
      if reverse=True, the map starts with red instead of blue.

    Returns a 256-entry ``LinearSegmentedColormap`` named 'new_RdBl'.
    """
    m = 18  # magic number, which works fine
    # Segment lengths are used as array shapes, so they must be ints;
    # current NumPy rejects the floats that pylab.floor returns.
    m1 = int(pylab.floor(m * 0.2))
    m2 = int(pylab.floor(m * 0.2))
    m3 = int(pylab.floor(m / 2)) - m2 - m1
    b_ = pylab.hstack((0.4 * pylab.arange(m1) / (m1 - 1.) + 0.6, pylab.ones((m2 + m3,))))
    g_ = pylab.hstack((pylab.zeros((m1,)), pylab.arange(m2) / (m2 - 1.), pylab.ones((m3,))))
    r_ = pylab.hstack((pylab.zeros((m1,)), pylab.zeros((m2,)), pylab.arange(m3) / (m3 - 1.)))
    # Second half mirrors the first with red and blue swapped.
    r = pylab.hstack((r_, pylab.flipud(b_)))
    g = pylab.hstack((g_, pylab.flipud(g_)))
    b = pylab.hstack((b_, pylab.flipud(r_)))
    if reverse:
        r = pylab.flipud(r)
        g = pylab.flipud(g)
        b = pylab.flipud(b)
    ra = pylab.linspace(0.0, 1.0, m)
    # list(): under Python 3, zip is a one-shot iterator, not a sequence.
    cdict = {'red': list(zip(ra, r, r)),
             'green': list(zip(ra, g, g)),
             'blue': list(zip(ra, b, b))}
    return LinearSegmentedColormap('new_RdBl', cdict, 256)
开发者ID:garciaga,项目名称:pynmd,代码行数:32,代码来源:plot_settings.py
示例7: log_inv
def log_inv(X):
    """
    Invert a 1x1, 2x2 or 3x3 matrix whose entries are given as log values.

    The result holds the log-scale entries of the inverse. In the 2x2 case,
    negated entries are written as ``x + i*pi`` (the log of a negative
    number), so that result is complex.

    Returns a ``matrix`` for 1x1 input, otherwise an array.

    Raises:
        Exception if X is not square, if ``log_det(X)`` is NaN or -inf
        (determinant 0), or if X is larger than 3x3.
    """
    if X.shape[0] != X.shape[1]:
        raise Exception("X is not a square matrix and cannot be inverted")
    if X.shape[0] == 1:
        return matrix((-X[0, 0]))
    ldet = log_det(X)
    # ``ldet == nan`` is always False, so test NaN via self-inequality; a zero
    # determinant shows up in log space as -inf.
    ldet_c = complex(ldet)
    if ldet_c != ldet_c or ldet_c.real == float('-inf'):
        raise Exception("The determinant of X is 0, cannot calculate the inverse")
    if X.shape[0] == 2:  # X is a 2x2 matrix
        # complex dtype: the off-diagonal terms add i*pi, which a real array
        # would reject (or silently drop).
        I = (-ldet) * ones((2, 2), dtype=complex)
        I[0, 0] += X[1, 1]
        I[0, 1] += X[0, 1] + complex(0, pi)
        I[1, 0] += X[1, 0] + complex(0, pi)
        I[1, 1] += X[0, 0]
        return I
    if X.shape[0] == 3:  # X is a 3x3 matrix
        # Each entry is log(cofactor) - log(det); log_subt_exp(a, b)
        # presumably returns log(exp(a) - exp(b)).
        I = (-ldet) * ones((3, 3))
        I[0, 0] += log_subt_exp(X[1, 1] + X[2, 2], X[1, 2] + X[2, 1])
        I[0, 1] += log_subt_exp(X[0, 2] + X[2, 1], X[0, 1] + X[2, 2])
        I[0, 2] += log_subt_exp(X[0, 1] + X[1, 2], X[0, 2] + X[1, 1])
        I[1, 0] += log_subt_exp(X[2, 0] + X[1, 2], X[1, 0] + X[2, 2])
        I[1, 1] += log_subt_exp(X[0, 0] + X[2, 2], X[0, 2] + X[2, 0])
        I[1, 2] += log_subt_exp(X[0, 2] + X[1, 0], X[0, 0] + X[1, 2])
        I[2, 0] += log_subt_exp(X[1, 0] + X[2, 1], X[2, 0] + X[1, 1])
        I[2, 1] += log_subt_exp(X[2, 0] + X[0, 1], X[0, 0] + X[2, 1])
        I[2, 2] += log_subt_exp(X[0, 0] + X[1, 1], X[0, 1] + X[1, 0])
        return I
    raise Exception("log_inv is only implemented for matrices of size < 4")
开发者ID:issfangks,项目名称:milo-lab,代码行数:33,代码来源:log_matrix.py
示例8: sample
def sample(self, model, evidence):
    """
    Sample the latent surfaces by forward-filtering backward-sampling with the
    cached Kalman filter ``self._kalman``.

    ``evidence`` must provide 'z' (observations), 'T' (per-observation type,
    compared against 1 and 2), 'surfaces', 'sigma_g' and 'sigma_h'.
    ``model.known_params`` must provide 'mu_h', 'phi', 'sigma_z_g' and
    'sigma_z_h'. ``model.hyper_params`` must provide 'prior_mu_g',
    'prior_cov_g', 'prior_mu_h' and 'prior_cov_h'.

    Returns whatever ``forward_filter_backward_sample(kalman, y)`` returns.
    """
    z = evidence['z']
    T, surfaces, sigma_g, sigma_h = [evidence[var] for var in ['T', 'surfaces', 'sigma_g', 'sigma_h']]
    mu_h, phi, sigma_z_g, sigma_z_h = [model.known_params[var] for var in ['mu_h', 'phi', 'sigma_z_g', 'sigma_z_h']]
    prior_mu_g, prior_cov_g = [model.hyper_params[var] for var in ['prior_mu_g', 'prior_cov_g']]
    prior_mu_h, prior_cov_h = [model.hyper_params[var] for var in ['prior_mu_h', 'prior_cov_h']]
    # Series length. This used to be len(g), but no ``g`` exists in this
    # method; z (indexed alongside T below) is the series being filtered.
    n = len(z)
    # Column 0 holds the T == 1 observations, column 1 the T == 2 ones;
    # every other entry is masked as missing.
    y = ma.asarray(ones((n, 2)) * nan)
    if sum(T == 1) > 0:
        y[T == 1, 0] = z[T == 1]
    if sum(T == 2) > 0:
        y[T == 2, 1] = z[T == 2]
    y[isnan(y)] = ma.masked
    # State (g, h): g is a random walk, h an AR(1) process with mean mu_h
    # (offset mu_h * (1 - phi)). The observation rows are g and g + h.
    kalman = self._kalman
    kalman.initial_state_mean = [prior_mu_g[0], prior_mu_h[0]]
    kalman.initial_state_covariance = diag([prior_cov_g[0, 0], prior_cov_h[0, 0]])
    kalman.transition_matrices = [[1, 0], [0, phi]]
    kalman.transition_offsets = ones((n, 2)) * [0, mu_h * (1 - phi)]
    kalman.transition_covariance = [[sigma_g ** 2, 0], [0, sigma_h ** 2]]
    kalman.observation_matrices = [[1, 0], [1, 1]]
    kalman.observation_covariance = [[sigma_z_g ** 2, 0], [0, sigma_z_h ** 2]]
    sampled_surfaces = forward_filter_backward_sample(kalman, y)
    return sampled_surfaces
开发者ID:bwallin,项目名称:thesis-code,代码行数:26,代码来源:model_simulation_epsilon.py
示例9: run_on_cluster
def run_on_cluster(dir='../data', true_cf=None, true_std=None, std_bias=None, reps=5, tag=''):
    """
    Run validate_once ``reps`` times in parallel on the cluster, then combine.

    Writes true_cf, true_std and std_bias to CSV files in ``dir``. Submits one
    qsub job per replicate (cluster_shell.sh running cluster_validate.py, which
    runs validate_once for that iteration). Then submits a final job, held
    until all replicates finish, running cluster_validate_combine.py (which
    runs combine_output and cleans up the temp files). Requires that
    cluster_shell.sh, cluster_validate.py and cluster_validate_combine.py all
    exist. ``tag`` is added to the job names so this function can run several
    times simultaneously without job-name conflicts.

    Defaults: true_cf = [pl.ones(3)/3.0, pl.ones(3)/3.0],
    true_std = 0.01*pl.ones(3), std_bias = [1., 1., 1.].
    """
    # Build defaults per call instead of sharing one array/list across calls.
    if true_cf is None:
        true_cf = [pl.ones(3) / 3.0, pl.ones(3) / 3.0]
    if true_std is None:
        true_std = 0.01 * pl.ones(3)
    if std_bias is None:
        std_bias = [1., 1., 1.]
    # Unpacking checks that true_cf is 2-D (rows x causes); the values are unused.
    _n_rows, _n_causes = pl.array(true_cf).shape
    if not os.path.exists(dir):
        os.mkdir(dir)

    # write true_cf, true_std and std_bias to file
    data.rec2csv_2d(pl.array(true_cf), '%s/truth_cf.csv' % (dir))
    data.rec2csv_2d(pl.array(true_std), '%s/truth_std.csv' % (dir))
    data.rec2csv_2d(pl.array([std_bias]), '%s/truth_bias.csv' % (dir))

    # Submit one job per replicate. Argument lists (no shell) pass dir and tag
    # through verbatim instead of having /bin/sh re-parse them.
    all_names = []
    for i in range(reps):
        name = 'cc%s_%i' % (tag, i)
        subprocess.call(['qsub', '-cwd', '-N', name, 'cluster_shell.sh',
                         'cluster_validate.py', '%i' % i, dir])
        all_names.append(name)

    # Submit the combine/clean-up job, held until every replicate job is done.
    hold = ['-hold_jid', ','.join(all_names)] if all_names else []
    subprocess.call(['qsub', '-cwd'] + hold + ['-N', 'cc%s_comb' % tag, 'cluster_shell.sh',
                                               'cluster_validate_combine.py', '%i' % reps, dir])
开发者ID:aflaxman,项目名称:pymc-cod-correct,代码行数:32,代码来源:validate_models.py
示例10: __init__
def __init__(self):
    """Initialise unit activations to 1 and the two weight matrices at random.

    Layer sizes come from the class attributes ``NN.ni``, ``NN.nh`` and ``NN.no``.
    """
    n_in, n_hidden, n_out = NN.ni, NN.nh, NN.no

    # activations of the input, hidden and output units
    self.ai = ones(n_in)
    self.ah = ones(n_hidden)
    self.ao = ones(n_out)

    # input->hidden and hidden->output weights; randomizeMatrix presumably
    # fills them in place within the given bounds (its return value is unused)
    self.wi = zeros((n_in, n_hidden))
    self.wo = zeros((n_hidden, n_out))
    randomizeMatrix(self.wi, -0.2, 0.2)
    randomizeMatrix(self.wo, -2.0, 2.0)
开发者ID:mfbx9da4,项目名称:neuron-astrocyte-networks,代码行数:8,代码来源:neuralnetwork.py
示例11: allocate
def allocate(self, n):
    """Allocate space for the internal state variables.

    `n` is the maximum sequence length that can be processed. Every buffer is
    NaN-filled: the gate/cell buffers and ``state``/``output`` are (n, ns), and
    ``source`` is (n, na), with ``(ni, ns, na) = self.dims``.
    """
    _, ns, na = self.dims
    buffer_names = ("cix", "ci", "gix", "gi", "gox", "go", "gfx", "gf",
                    "state", "output")
    for buffer_name in buffer_names:
        setattr(self, buffer_name, nan * ones((n, ns)))
    self.source = nan * ones((n, na))
开发者ID:dwohlfahrt,项目名称:ocropy,代码行数:9,代码来源:minilstm.py
示例12: getParamCovMat
def getParamCovMat(prefix, dlogpower=2, theoconstmult=1., dlogfilenames=None, volume=256.**3, startki=0, endki=0, veff=None):
    """
    Calculates the parameter covariance matrix from the power spectrum
    covariance matrix and derivative terms found under ``prefix``.

    Reads ``prefix + 'pnl.dat'`` (column 0: k, column 1: power),
    ``prefix + 'covar.dat'`` (power covariance), and one derivative file per
    parameter (column 1 is used). Writes ``prefix + 'normcovar.dat'``.

    Parameters
    ----------
    prefix : str, prepended to every file name
    dlogpower, theoconstmult : unused; kept for interface compatibility
    dlogfilenames : list of str, one derivative file per parameter
        (default ['dlogpnldloga.dat']); an empty name leaves that parameter's
        derivative at 1
    volume : float, volume used for every bin when ``veff`` has <= 1 entry;
        -1 means "derive it as (pi / k[0])**3"
    startki : int, first k bin used
    endki : int, end bin; 0 means "all bins"
    veff : sequence of per-bin effective volumes (default [0.]); if it has
        <= 1 entry, ``volume`` is used instead

    Returns
    -------
    (k[1:], paramCovMat[:, :, 1:]), where paramCovMat[:, :, ki] is the inverse
    of the Fisher matrix built from bins 0..ki.
    """
    if dlogfilenames is None:
        dlogfilenames = ['dlogpnldloga.dat']
    if veff is None:
        veff = [0.]
    nparams = len(dlogfilenames)
    kpnl = M.load(prefix + 'pnl.dat')
    k = kpnl[startki:, 0]
    nk = len(k)
    if endki == 0:
        endki = nk
    pnl = M.array(kpnl[startki:, 1], M.Float64)
    covarwhole = M.load(prefix + 'covar.dat')
    covar = covarwhole[startki:, startki:]
    # Resolve the volume == -1 sentinel before it is used. It was previously
    # resolved only after sqrt_veff had been computed from volume = -1.
    if volume == -1.:
        volume = (M.pi / k[0]) ** 3
    if len(veff) > 1:
        sqrt_veff = M.sqrt(veff[startki:])
    else:
        sqrt_veff = M.sqrt(volume * M.ones(nk))
    dlogs = M.reshape(M.ones(nparams * nk, M.Float64), (nparams, nk))
    paramFishMat = M.reshape(M.zeros(nparams * nparams * (endki - startki), M.Float64),
                             (nparams, nparams, endki - startki))
    paramCovMat = paramFishMat * 0.
    # Derivative terms, one row per parameter.
    for param in range(nparams):
        if len(dlogfilenames[param]) > 0:
            dlogs[param, :] = M.load(prefix + dlogfilenames[param])[startki:, 1]
    # Covariance normalised by P(k_i) * P(k_j).
    normcovar = M.zeros(M.shape(covar), M.Float64)
    for i in range(nk):
        normcovar[i, :] = covar[i, :] / (pnl * pnl[i])
    M.save(prefix + 'normcovar.dat', normcovar)
    # NOTE(review): k, pnl and covar are already sliced from startki, yet this
    # loop stops at endki - startki; with endki == 0 (-> nk) and startki > 0
    # the last startki bins are never used. Confirm this is intended.
    for ki in range(1, endki - startki):
        # The inverse only depends on ki, so compute it once per cutoff.
        inv_normcovar = M.inverse(normcovar[:ki + 1, :ki + 1])
        for p1 in range(nparams):
            for p2 in range(nparams):
                paramFishMat[p1, p2, ki] = M.sum(M.sum(
                    inv_normcovar *
                    M.outerproduct(dlogs[p1, :ki + 1] * sqrt_veff[:ki + 1],
                                   dlogs[p2, :ki + 1] * sqrt_veff[:ki + 1])))
        paramCovMat[:, :, ki] = M.inverse(paramFishMat[:, :, ki])
    return k[1:], paramCovMat[:, :, 1:]
开发者ID:JohanComparat,项目名称:pyLPT,代码行数:55,代码来源:info.py
示例13: getDR
def getDR(self):
    """
    Return the dynamic range in dB: the smoothed log magnitude spectrum minus
    the noise level.

    The noise level is the RMS magnitude of the FFT of the first entry of
    ``self._tdData.getAllPrecNoise()``. ``20*log10(self.getFAbs())`` is
    smoothed with a 5-point moving average ('valid' mode). Its first and last
    values are then repeated so the result keeps the spectrum's length (for
    spectra of at least 5 points).
    """
    width = 5
    pad = (width - 1) // 2  # points a 'valid' convolution loses at each end

    noise_spectrum = abs(py.fft(self._tdData.getAllPrecNoise()[0]))
    noiselevel = py.sqrt(py.mean(noise_spectrum ** 2))

    kernel = py.ones(int(width)) / float(width)
    smoothed = py.convolve(20 * py.log10(self.getFAbs()), kernel, 'valid')

    edge = py.ones((pad,))
    smoothed = py.concatenate((smoothed[0] * edge, smoothed, smoothed[-1] * edge))
    return smoothed - 20 * py.log10(noiselevel)
开发者ID:DavidJahn86,项目名称:terapy,代码行数:11,代码来源:TeraData.py
示例14: X_obs
def X_obs(pi=pi, sigma=sigma, value=X):
    """Normal log-likelihood of the observations ``value``.

    ``pi`` and ``sigma`` are flattened and repeated across N rows of J*T
    entries, so that every observation row shares the same means and standard
    deviations. ``sigma`` is passed to ``mc.normal_like`` as a precision
    (sigma**-2).
    """
    grid = pl.ones([N, J * T])
    mu = (grid * pl.array(pi).ravel()).ravel()
    tau = (grid * pl.array(sigma).ravel()).ravel() ** -2
    observed = pl.array(value).ravel()
    return mc.normal_like(observed, mu, tau)
# NOTE(review): extraction fragment. The ``def`` line (and any decorator) of
# the function these lines belong to is missing from this listing, so the
# ``return`` below appears outside a function. Free names (N, value, pi,
# beta, sigma) presumably come from that missing signature or its enclosing scope.
# Per row n: normal log-likelihood of value[n] with mean pi + beta and
# precision sigma**-2.
logp = pl.zeros(N)
for n in range(N):
    logp[n] = mc.normal_like(pl.array(value[n]).ravel(),
                             pl.array(pi+beta).ravel(),
                             pl.array(sigma).ravel()**-2)
# Combine the rows as logsum(logp - log N); presumably log-sum-exp, i.e. the
# log of the average per-row likelihood — confirm mc.flib.logsum semantics.
return mc.flib.logsum(logp - pl.log(N))
开发者ID:ldwyerlindgren,项目名称:pymc-cod-correct,代码行数:12,代码来源:models.py
示例15: __init__
def __init__(self, r_floop=0.5, z_floop=0.0,
             i_p_coil_filename='hitpops.05.txt',
             tris_filename='hitpops.05.t3d'):
    """
    Load plasma-coil data and the triangular mesh.

    r_floop, z_floop : stored unchanged on the instance.
    i_p_coil_filename : comma-separated table read with ``P.loadtxt``; its
        columns 0, 1 and 2 become r, z and current, and column 3 times 6.28e7
        becomes ``beta``. Only rows with non-zero current are kept in the
        non-"_full" attributes.
    tris_filename : mesh file read with ``t3dinp`` into rzt, tris and pt.
    """
    self.r_floop = r_floop
    self.z_floop = z_floop

    # read equilibrium file
    table = P.loadtxt(i_p_coil_filename, delimiter=',', dtype=fdtype)
    self.i_p_coils = table
    self.r_p_coils_full = table[:, 0]
    self.z_p_coils_full = table[:, 1]
    # ??? what is this scale factor, something to do with mu_0 ???
    self.beta = table[:, 3] * 6.28e7
    self.i_p_coils_full = table[:, 2]

    # keep only the coils whose current is non-zero
    active = P.where(self.i_p_coils_full != 0.0)
    self.r_p_coils = self.r_p_coils_full[active]
    self.z_p_coils = self.z_p_coils_full[active]
    self.i_p_coils = self.i_p_coils_full[active]
    self.n_p_coils = len(self.r_p_coils)

    # widths of 0.05 in r and z, one filament in each direction per coil
    self.r_p_widths = P.ones(self.n_p_coils, dtype=fdtype) * 0.05
    self.z_p_widths = 1.0 * self.r_p_widths
    self.n_r_p_filaments = P.ones(self.n_p_coils, dtype=idtype)
    self.n_z_p_filaments = 1 * self.n_r_p_filaments

    # read in triangle unstructured mesh information
    self.rzt, self.tris, self.pt = t3dinp(tris_filename)
开发者ID:zchmlk,项目名称:Coil-GUI,代码行数:52,代码来源:plasma_coil_object.py
示例16: __convertToFloats__
def __convertToFloats__(self, signal, annotation, time):
    """
    Drop every position whose signal (or annotation) value is not float-like,
    and convert the remaining signal/annotation values to float arrays.

    Assumes ``time`` is already numeric; it is only filtered, not converted.
    ``annotation`` and ``time`` may be None, and are returned as None then.

    Returns (signal, annotation, time) restricted to the kept positions.
    """
    keep = pl.ones(len(signal), dtype=bool)
    # ``is None``: comparing an ndarray with ``== None`` is elementwise and
    # cannot be used as an ``if`` condition.
    if annotation is None:
        entities = zip(signal)
    else:
        entities = zip(signal, annotation)
    for idx, values in enumerate(entities):
        for value in values:
            try:
                pl.float64(value)  # check if it can be converted to float
            except ValueError:
                keep[idx] = False  # the value is NOT like float type
                break
    true_floats = pl.nonzero(keep)  # indexes of the positions to keep
    signal = signal[true_floats].astype(float)
    if annotation is not None:
        annotation = annotation[true_floats].astype(float)
    if time is not None:
        time = time[true_floats]
    return signal, annotation, time
开发者ID:TEAM-HRA,项目名称:hra_suite,代码行数:27,代码来源:data_vector_file_data_source.py
示例17: parse_task_object_data
def parse_task_object_data(bhv):
    """
    Convert all the task objects into image data and parse their initial positions.

    Each ``bhv['TaskObject'][n][0]`` is a description string such as
    ``fix(x,y)`` or ``pic(name,x,y)``:

    * ``fix``: a 5x5x3 array of ones (white square) stands in for the fixation
      point, and every argument is parsed as a number.
    * ``pic``: the first argument names an entry of ``bhv['Stimuli']['Pic']``
      whose ``'Data'`` is divided by 255 (matplotlib needs [0, 1]); the rest
      are parsed as numbers. If no entry matches, a 4x4x3 array of zeros is
      used and an error is logged.

    Returns (objects, positions): a list of image arrays, and an array of the
    numeric arguments with one row per task object.

    Raises ValueError for any other object type. Previously such objects
    silently reused the preceding object's image/position, or raised NameError
    when they came first.
    """
    obj_data = bhv['Stimuli']['Pic']  # Only handling pics now
    obj_r = re.compile(r"(\w+)\(")  # Regexp to find task object description
    args_r = re.compile(r"([-.\w]+)[,\)]")  # Regexp to extract arguments
    objects = []
    initial_pos = []
    for entry in bhv['TaskObject']:
        description = entry[0]
        oname = obj_r.findall(description)[0]
        args = args_r.findall(description)
        if oname == 'fix':
            odata = pylab.ones((5, 5, 3), dtype=float)  # Arbitrary square for FP
            p = [float(a) for a in args]
        elif oname == 'pic':
            picname = args[0]  # First one is object name
            p = [float(a) for a in args[1:]]
            for obj in obj_data:
                if obj['Name'] == picname:
                    odata = obj['Data'] / 255.0  # matplotlib needs [0,1]
                    break
            else:
                odata = pylab.zeros((4, 4, 3))
                logger.error('Could not find object')
        else:
            raise ValueError('Unsupported task object type %r in %r' % (oname, description))
        objects.append(odata)
        initial_pos.append(p)
    return objects, pylab.array(initial_pos)
开发者ID:kghose,项目名称:neurapy,代码行数:31,代码来源:moviemaker.py
示例18: datagen
def datagen(N):
    """
    Produces N pairs of training data and desired output.

    Each sample of training data is the row (-1, y1, y2). The leading -1
    corresponds to interpreting the threshold as the first element of the
    weight vector. y1 and y2 are two nonlinear features of a random point
    (x1, x2) with x2 in [0, 1). The teacher label is the sign of a uniform
    draw minus 0.5; for negative labels x1 is shifted from [0, 1) to [-1, 0).

    Returns (x, teacher): x has shape (N, 3), teacher has shape (N,).
    """
    def fun1(x1, x2):
        return -2 * x1 ** 3 - x2 + .5 * x1 ** 2

    def fun2(x1, x2):
        return x1 ** 2 * x2 + 2 * x1 * x2 + 1

    rarr1 = rand(1, N)
    rarr2 = rand(1, N)
    teacher = sign(rand(1, N) - .5)
    # Shift x1 for the negative class. The former ``idminus = -idplus`` was
    # unused and raises TypeError on boolean arrays in current NumPy.
    negative = (teacher < 0)
    rarr1[negative] = rarr1[negative] - 1
    y1 = fun1(rarr1, rarr2)
    y2 = fun2(rarr1, rarr2)
    x = transpose(concatenate((-ones((1, N)), y1, y2)))
    return x, teacher[0]
开发者ID:albert4git,项目名称:aTest,代码行数:29,代码来源:datagen.py
示例19: _istftm
def _istftm(self, X_hat=None, Phi_hat=None, pvoc=False, usewin=True, resamp=None):
    """
    ::

        Inverse short-time Fourier transform magnitude. Make a signal from a |STFT| transform.
        Uses phases from self.STFT if Phi_hat is None.

        Inputs:
            X_hat - N/2+1 magnitude STFT [None=abs(self.STFT)]
            Phi_hat - N/2+1 phase STFT [None=angle(self.STFT)]
            pvoc - whether to use phase vocoder [False]
            usewin - whether to use a scaled Hann window for overlap-add
                     (otherwise a rectangular window) [True]
            resamp - if set, resample the synthesis window to round(nfft*resamp) points [None]

        Returns:
            x_hat - estimated signal, also stored in self.x_hat
                    (None if no STFT has been computed yet)
    """
    if not self._have_stft:
        return None
    X_hat = P.np.abs(self.STFT) if X_hat is None else P.np.abs(X_hat)
    if pvoc:
        # NOTE(review): the phase-vocoder path is expected to set self.X_hat itself.
        self._pvoc(X_hat, Phi_hat, pvoc)
    else:
        Phi_hat = P.angle(self.STFT) if Phi_hat is None else Phi_hat
        self.X_hat = X_hat * P.exp(1j * Phi_hat)
    if usewin:
        self.win = P.hanning(self.nfft)
        # Scale by nhop / (nfft * sum(win**2)).
        self.win *= 1.0 / ((float(self.nfft) * (self.win ** 2).sum()) / self.nhop)
    else:
        self.win = P.ones(self.nfft)
    if resamp:
        self.win = sig.resample(self.win, int(P.np.round(self.nfft * resamp)))
    # Kept in case the check has side effects; its return value was never used.
    self._check_feature_params()
    self.x_hat = self._overlap_add(P.real(self.nfft * P.irfft(self.X_hat.T)), usewin=usewin, resamp=resamp)
    if self.verbosity:
        print("Extracted iSTFTM->self.x_hat")
    return self.x_hat
开发者ID:StevenLOL,项目名称:BregmanToolkit,代码行数:35,代码来源:features_base.py
示例20: rank_by_distance_bhatt
def rank_by_distance_bhatt(self, qkeys, ikeys, rkeys, dists):
    """
    ::

        Reduce timbre-channel distances to ranks list by ground-truth key indices
        Bhattacharyya distance on timbre-channel probabilities and Kullback distances

    Inputs:
        qkeys - query keys, one per timbre channel
        ikeys - ikeys[t_chan][i] is the i-th include key of channel t_chan
        rkeys - rkeys[t_chan] is the list of result keys of channel t_chan
        dists - dists[t_chan][j] is the pre-computed distance of result key j
    Returns:
        (ranks_list, rdists): the rank of every ground-truth key in the sorted
        distances, and the absolute distance per database key (inf for keys
        not in the include list).
    Raises:
        error.BregmanError if an include key cannot be matched in a channel's results.
    """
    # timbre-channel search using pre-computed distances
    ranks_list = []
    t_keys, _t_lens = self.get_adb_lists(0)
    rdists = pylab.ones(len(t_keys)) * float('inf')
    qk = self._get_probs_tc(qkeys)
    for i in range(len(ikeys[0])):  # number of include keys
        ikey = []
        dk = pylab.zeros(self.timbre_channels)
        for t_chan in range(self.timbre_channels):  # timbre channels
            ikey.append(ikeys[t_chan][i])
            try:
                # dataset include-key match, then its pre-computed distance
                i_idx = rkeys[t_chan].index(ikey[t_chan])
                dk[t_chan] = dists[t_chan][i_idx]
            except Exception:
                # Narrowed from a bare ``except:`` so KeyboardInterrupt and
                # SystemExit propagate; output matches the old Py2 print.
                print("Key not found in result list:  %s for query: %s" % (ikey, qkeys[t_chan]))
                raise error.BregmanError()
        rk = self._get_probs_tc(ikey)
        a_idx = t_keys.index(ikey[0])  # audiodb include-key index
        rdists[a_idx] = distance.bhatt(pylab.sqrt(pylab.absolute(dk)), pylab.sqrt(pylab.absolute(qk * rk)))
    # search for the index of the relevant keys
    rdists = pylab.absolute(rdists)
    sort_idx = pylab.argsort(rdists)  # Sort fields into database order
    for r in self.ground_truth:  # relevant keys
        ranks_list.append(pylab.where(sort_idx == r)[0][0])  # Rank of the relevant key
    return ranks_list, rdists
开发者ID:BinRoot,项目名称:BregmanToolkit,代码行数:35,代码来源:evaluate.py
注：本文中的pylab.ones函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论