def _eval_expand_trig(self, **hints):
    """Expand ``tan`` of a sum or of an integer multiple of an argument.

    * ``tan(a + b + ...)`` is rewritten with the tangent addition formula:
      the numerator is ``e1 - e3 + e5 - ...`` and the denominator is
      ``e0 - e2 + e4 - ...``, where ``e_i`` is the i-th elementary symmetric
      polynomial of the (recursively expanded) tangents of the summands.
    * ``tan(k*t)`` with an integer ``k > 1`` is rewritten as
      ``im(P)/re(P)`` where ``P = (1 + I*tan(t))**k``.

    Any other argument is returned unchanged as ``tan(arg)``.
    """
    arg = self.args[0]
    if arg.is_Add:
        from sympy import symmetric_poly
        n = len(arg.args)
        # Expanded tangent of every summand.
        TX = [tan(term, evaluate=False)._eval_expand_trig()
              for term in arg.args]
        # Placeholder symbols Y_i stand in for tan(term_i) while the
        # symmetric polynomials are built; they are substituted back below.
        Yg = numbered_symbols('Y')
        Y = [next(Yg) for _ in xrange(n)]
        # p[0] collects odd degrees (numerator), p[1] even degrees
        # (denominator); (-1)**((i % 4)//2) yields the sign pattern
        # +, +, -, -, +, ... for i = 0, 1, 2, 3, 4, ...
        p = [0, 0]
        for i in xrange(n + 1):
            p[1 - i % 2] += symmetric_poly(i, Y)*(-1)**((i % 4)//2)
        # Materialize the pairs so subs receives a reusable sequence.
        return (p[0]/p[1]).subs(list(zip(Y, TX)))
    else:
        coeff, terms = arg.as_coeff_Mul(rational=True)
        if coeff.is_Integer and coeff > 1:
            I = S.ImaginaryUnit
            z = C.Symbol('dummy', real=True)
            # (1 + I*tan(t))**k has argument k*t, so tan(k*t) = im/re.
            P = ((1 + I*z)**coeff).expand()
            return (C.im(P)/C.re(P)).subs([(z, tan(terms))])
    return tan(arg)
def cse(exprs, symbols=None, optimizations=None):
""" Perform common subexpression elimination on an expression.
Parameters:
exprs : list of sympy expressions, or a single sympy expression
The expressions to reduce.
symbols : infinite iterator yielding unique Symbols
The symbols used to label the common subexpressions which are pulled
out. The `numbered_symbols` generator is useful. The default is a stream
of symbols of the form "x0", "x1", etc. This must be an infinite
iterator.
optimizations : list of (callable, callable) pairs, optional
The (preprocessor, postprocessor) pairs. If not provided,
`sympy.simplify.cse.cse_optimizations` is used.
Returns:
replacements : list of (Symbol, expression) pairs
All of the common subexpressions that were replaced. Subexpressions
earlier in this list might show up in subexpressions later in this list.
reduced_exprs : list of sympy expressions
The reduced expressions with all of the replacements above.
"""
if symbols is None:
symbols = numbered_symbols()
else:
# In case we get passed an iterable with an __iter__ method instead of
# an actual iterator.
symbols = iter(symbols)
seen_subexp = set()
muls = set()
adds = set()
to_eliminate = []
to_eliminate_ops_count = []
if optimizations is None:
# Pull out the default here just in case there are some weird
# manipulations of the module-level list in some other thread.
optimizations = list(cse_optimizations)
# Handle the case if just one expression was passed.
if isinstance(exprs, Basic):
exprs = [exprs]
# Preprocess the expressions to give us better optimization opportunities.
exprs = [preprocess_for_cse(e, optimizations) for e in exprs]
# Find all of the repeated subexpressions.
def insert(subtree):
'''This helper will insert the subtree into to_eliminate while
maintaining the ordering by op count and will skip the insertion
if subtree is already present.'''
ops_count = subtree.count_ops()
index_to_insert = bisect.bisect(to_eliminate_ops_count, ops_count)
# all i up to this index have op count <= the current op count
# so check that subtree is not yet present from this index down
# (if necessary) to zero.
for i in xrange(index_to_insert - 1, -1, -1):
if to_eliminate_ops_count[i] == ops_count and \
subtree == to_eliminate[i]:
return # already have it
to_eliminate_ops_count.insert(index_to_insert, ops_count)
to_eliminate.insert(index_to_insert, subtree)
for expr in exprs:
pt = preorder_traversal(expr)
for subtree in pt:
if subtree.is_Atom:
# Exclude atoms, since there is no point in renaming them.
continue
if subtree in seen_subexp:
insert(subtree)
pt.skip()
continue
if subtree.is_Mul:
muls.add(subtree)
elif subtree.is_Add:
adds.add(subtree)
seen_subexp.add(subtree)
# process adds - any adds that weren't repeated might contain
# subpatterns that are repeated, e.g. x+y+z and x+y have x+y in common
adds = [set(a.args) for a in adds]
for i in xrange(len(adds)):
for j in xrange(i + 1, len(adds)):
com = adds[i].intersection(adds[j])
if len(com) > 1:
insert(Add(*com))
# remove this set of symbols so it doesn't appear again
adds[i] = adds[i].difference(com)
adds[j] = adds[j].difference(com)
for k in xrange(j + 1, len(adds)):
if not com.difference(adds[k]):
adds[k] = adds[k].difference(com)
# process muls - any muls that weren't repeated might contain
#.........这里部分代码省略.........
开发者ID:Jerryy,项目名称:sympy,代码行数:101,代码来源:cse_main.py
示例6: extract_sub_expressions
def extract_sub_expressions(self, cache_prefix='cache', sub_prefix='sub', prefix='XoXoXoX'):
    """Replace common sub-expressions of ``self.expression_list`` by named symbols.

    Runs ``sym.cse`` over ``self.expression_list`` and sorts every extracted
    sub-expression into one of two groups:

    * cacheable -- it depends on a symbol in ``self.cacheable_vars`` or on an
      earlier cacheable sub-expression; renamed ``<cache_prefix>0``,
      ``<cache_prefix>1``, ...
    * parameter-dependent -- everything else; renamed ``<sub_prefix>0``, ...

    Side effects on ``self``:

    * ``self.variables[cache_prefix]`` and ``self.variables[sub_prefix]`` are
      reset and filled with the new symbols (created via ``sym.var``).
    * Each expression addressed by ``self.expression_keys`` in
      ``self.expressions`` is replaced by its reduced form, written in terms
      of the new symbols. This only happens when cse found something.
      NOTE(review): assumes ``self.expression_list`` is aligned
      index-by-index with ``self.expression_keys`` -- confirm with caller.
    * ``self.expressions['update_cache']`` and
      ``self.expressions['parameters_changed']`` map each new symbol's name
      to the sub-expression it stands for.

    :param cache_prefix: name prefix for cacheable sub-expression symbols.
    :param sub_prefix: name prefix for the remaining sub-expression symbols.
    :param prefix: prefix of the temporary symbols handed to ``cse``.
    """
    # Do the common sub expression elimination.
    common_sub_expressions, expression_substituted_list = sym.cse(self.expression_list, numbered_symbols(prefix=prefix))
    self.variables[cache_prefix] = []
    self.variables[sub_prefix] = []

    # Sort out any sub-expression that depends on something that scales with
    # data size. cse lists a sub-expression before any later one that uses
    # it, so one forward pass propagates cacheability through dependencies.
    cacheable_list = []
    params_change_list = []
    for var, expr in common_sub_expressions:
        depends_on_cacheable = any(
            e.is_Symbol and (e in cacheable_list or e in self.cacheable_vars)
            for e in expr.atoms())
        if depends_on_cacheable:
            cacheable_list.append(var)
        else:
            params_change_list.append(var)

    # Give every temporary cse symbol its final name, keyed by the temp name.
    replace_dict = {}
    for i, var in enumerate(cacheable_list):
        sym_var = sym.var(cache_prefix + str(i))
        self.variables[cache_prefix].append(sym_var)
        replace_dict[var.name] = sym_var
    for i, var in enumerate(params_change_list):
        sym_var = sym.var(sub_prefix + str(i))
        self.variables[sub_prefix].append(sym_var)
        replace_dict[var.name] = sym_var

    # (temporary symbol, final symbol) pairs in cse order.
    renames = [(var, replace_dict[var.name]) for var, _ in common_sub_expressions]

    def _rename(expr):
        # Swap every temporary cse symbol in expr for its final symbol.
        for old, new in renames:
            expr = expr.subs(old, new)
        return expr

    # Replace the original expressions with their reduced forms, writing each
    # stored expression exactly once. Without any sub-expressions the stored
    # expressions are left as they are.
    if common_sub_expressions:
        for expr, keys in zip(expression_substituted_list, self.expression_keys):
            setInDict(self.expressions, keys, _rename(expr))

    # Record what each new symbol stands for, split by cacheability.
    self.expressions['parameters_changed'] = {}
    self.expressions['update_cache'] = {}
    cacheable_set = set(cacheable_list)
    for var, expr in common_sub_expressions:
        target = 'update_cache' if var in cacheable_set else 'parameters_changed'
        self.expressions[target][replace_dict[var.name].name] = _rename(expr)
开发者ID:Imdrail,项目名称:GPy,代码行数:57,代码来源:symbolic.py
示例7: cgen_ncomp
#.........这里部分代码省略.........
tpf = (xFj - xPj)/(xTj - xPj)
xP = [(((xF[i]/ppf)*(beta[i]**(NT+1) - 1))/(beta[i]**(NT+1) - beta[i]**(-NP))) \
for i in r]
xT = [(((xF[i]/tpf)*(1 - beta[i]**(-NP)))/(beta[i]**(NT+1) - beta[i]**(-NP))) \
for i in r]
rfeed = xFj / xF[k]
rprod = xPj / xP[k]
rtail = xTj / xT[k]
# setup constraint equations
numer = [ppf*xP[i]*log(rprod) + tpf*xT[i]*log(rtail) - xF[i]*log(rfeed) for i in r]
denom = [log(beta[j]) * ((beta[i] - 1.0)/(beta[i] + 1.0)) for i in r]
LoverF = sum([n/d for n, d in zip(numer, denom)])
SWUoverF = -1.0 * sum(numer)
SWUoverP = SWUoverF / ppf
prod_constraint = (xPj/xFj)*ppf - (beta[j]**(NT+1) - 1)/\
(beta[j]**(NT+1) - beta[j]**(-NP))
tail_constraint = (xTj/xFj)*(sum(xT)) - (1 - beta[j]**(-NP))/\
(beta[j]**(NT+1) - beta[j]**(-NP))
#xp_constraint = 1.0 - sum(xP)
#xf_constraint = 1.0 - sum(xF)
#xt_constraint = 1.0 - sum(xT)
# This is NT(NP,...) and is correct!
#nt_closed = solve(prod_constraint, NT)[0]
# However, this is NT(NP,...) rewritten (by hand) to minimize the number of NP
# and M* instances in the expression. Luckily this is only depends on the key
# component and remains general no matter the number of components.
nt_closed = (-MW[0]*log(alpha) + Mstar*log(alpha) + log(xTj) + log((-1.0 + xPj/\
xF[0])/(xPj - xTj)) - log(alpha**(NP*(MW[0] - Mstar))*(xF[0]*xPj - xPj*xTj)/\
(-xF[0]*xPj + xF[0]*xTj) + 1))/((MW[0] - Mstar)*log(alpha))
# new expression for normalized flow rate
# NOTE: not needed, solved below
#loverf = LoverF.xreplace({NT: nt_closed})
# Define the constraint equation with which to solve NP. This is chosen such to
# minimize the number of ops in the derivatives (and thus np_closed). Other,
# more verbose possibilities are commented out.
#np_constraint = (xP[j]/sum(xP) - xPj).xreplace({NT: nt_closed})
#np_constraint = (xP[j]- sum(xP)*xPj).xreplace({NT: nt_closed})
#np_constraint = (xT[j]/sum(xT) - xTj).xreplace({NT: nt_closed})
np_constraint = (xT[j] - sum(xT)*xTj).xreplace({NT: nt_closed})
# get closed form approximation of NP via symbolic derivatives
stat = _aggstatus(stat, " order-{0} NP approximation".format(nporder), aggstat)
d0NP = np_constraint.xreplace({NP: NP0})
d1NP = diff(np_constraint, NP, 1).xreplace({NP: NP0})
if 1 == nporder:
np_closed = NP0 - d1NP / d0NP
elif 2 == nporder:
d2NP = diff(np_constraint, NP, 2).xreplace({NP: NP0})/2.0
# taylor series polynomial coefficients, grouped by order
# f(x) = ax**2 + bx + c
a = d2NP
b = d1NP - 2*NP0*d2NP
c = d0NP - NP0*d1NP + NP0*NP0*d2NP
# quadratic eq. (minus only)
#np_closed = (-b - sqrt(b**2 - 4*a*c)) / (2*a)
# However, we need to break up this expr as follows to prevent
# a floating point arithmetic bug if b**2 - 4*a*c is very close
# to zero but happens to be negative. LAME!!!
np_2a = 2*a
np_sqrt_base = b**2 - 4*a*c
np_closed = (-NP_b - sqrt(NP_sqrt_base)) / (NP_2a)
else:
raise ValueError("nporder must be 1 or 2")
# generate cse for writing out
msg = " minimizing ops by eliminating common sub-expressions"
stat = _aggstatus(stat, msg, aggstat)
exprstages = [Eq(NP_b, b), Eq(NP_2a, np_2a),
# fix for floating point sqrt() error
Eq(NP_sqrt_base, np_sqrt_base), Eq(NP_sqrt_base, Abs(NP_sqrt_base)),
Eq(NP1, np_closed), Eq(NT1, nt_closed).xreplace({NP: NP1})]
cse_stages = cse(exprstages, numbered_symbols('n'))
exprothers = [Eq(LpF, LoverF), Eq(PpF, ppf), Eq(TpF, tpf),
Eq(SWUpF, SWUoverF), Eq(SWUpP, SWUoverP)] + \
[Eq(*z) for z in zip(xPi, xP)] + [Eq(*z) for z in zip(xTi, xT)]
exprothers = [e.xreplace({NP: NP1, NT: NT1}) for e in exprothers]
cse_others = cse(exprothers, numbered_symbols('g'))
exprops = count_ops(exprstages + exprothers)
cse_ops = count_ops(cse_stages + cse_others)
msg = " reduced {0} ops to {1}".format(exprops, cse_ops)
stat = _aggstatus(stat, msg, aggstat)
# create function body
ccode, repnames = cse_to_c(*cse_stages, indent=6, debug=debug)
ccode_others, repnames_others = cse_to_c(*cse_others, indent=6, debug=debug)
ccode += ccode_others
repnames |= repnames_others
msg = " completed in {0:.3G} s".format(time.time() - start_time)
stat = _aggstatus(stat, msg, aggstat)
if aggstat:
print(stat)
return ccode, repnames, stat
def _eval_rewrite_as_sqrt(self, arg):
    """Rewrite ``cos(arg)`` using square roots where a radical form is known.

    Only arguments of the form ``r*pi`` with a non-integer rational ``r`` are
    handled:

    * the denominator of ``r`` is one of the tabulated Fermat primes: use a
      Chebyshev polynomial of the tabulated ``cos(pi/p)``;
    * the denominator is even: apply the half-angle formula recursively;
    * the denominator is a product of distinct tabulated Fermat primes:
      split ``r`` into partial fractions and expand ``cos`` of the sum.

    Returns ``None`` when no rewrite is found.
    """
    _EXPAND_INTS = False

    def migcdex(x):
        # Recursive calculation of the gcd and a linear combination for a
        # sequence of integers.
        # Given (x1, x2, x3)
        # returns (y1, y2, y3, g)
        # such that g is the gcd and x1*y1 + x2*y2 + x3*y3 - g = 0.
        # Note that this is only one such linear combination.
        if len(x) == 1:
            return (1, x[0])
        if len(x) == 2:
            return igcdex(x[0], x[-1])
        g = migcdex(x[1:])
        u, v, h = igcdex(x[0], g[-1])
        return tuple([u] + [v*i for i in g[0:-1]] + [h])

    def ipartfrac(r, factors=None):
        # Split the rational r = p/q into rationals whose sum is r, one per
        # (given or computed) factor of q.
        if isinstance(r, int):
            return r
        assert isinstance(r, C.Rational)
        n = r.q
        if n == 1:
            # r is an integer: it is its own (single-term) decomposition.
            return [r]
        # Floor division keeps the cofactors integral even under true division.
        if factors is None:
            a = [n//x**y for x, y in factorint(r.q).items()]
        else:
            a = [n//x for x in factors]
        if len(a) == 1:
            return [r]
        h = migcdex(a)
        ans = [r.p*C.Rational(i*j, r.q) for i, j in zip(h[:-1], a)]
        assert r == sum(ans)
        return ans

    pi_coeff = _pi_coeff(arg)
    if pi_coeff is None:
        return None

    assert not pi_coeff.is_integer, "should have been simplified already"

    if not pi_coeff.is_Rational:
        return None

    # cos(pi/p) in radicals for the Fermat primes p handled here.
    cst_table_some = {
        3: S.Half,
        5: (sqrt(5) + 1)/4,
        17: sqrt((15 + sqrt(17))/32 + sqrt(2)*(sqrt(17 - sqrt(17)) +
            sqrt(sqrt(2)*(-8*sqrt(17 + sqrt(17)) - (1 - sqrt(17))
            *sqrt(17 - sqrt(17))) + 6*sqrt(17) + 34))/32)
        # 65537 and 257 are the only other known Fermat primes
        # Please add if you would like them
    }

    def fermatCoords(n):
        # If n is a product of distinct primes from cst_table_some, return
        # those primes as a tuple; otherwise return False.
        assert isinstance(n, int)
        assert n > 0
        if n == 1 or 0 == n % 2:
            return False
        primes = dict.fromkeys(cst_table_some, 0)
        assert 1 not in primes
        for p_i in primes:
            while 0 == n % p_i:
                n = n//p_i
                primes[p_i] += 1
        if 1 != n:
            return False
        if max(primes.values()) > 1:
            return False
        return tuple([p for p in primes if primes[p] == 1])

    if pi_coeff.q in cst_table_some:
        # cos(p*pi/q) = T_p(cos(pi/q)), T_p the Chebyshev polynomial.
        return C.chebyshevt(pi_coeff.p, cst_table_some[pi_coeff.q]).expand()

    if 0 == pi_coeff.q % 2:  # recursively remove powers of 2
        # Half-angle formula: cos(t) = +/- sqrt((1 + cos(2*t))/2).
        narg = (pi_coeff*2)*S.Pi
        nval = cos(narg)
        if nval is None:
            return None
        nval = nval.rewrite(sqrt)
        if not _EXPAND_INTS:
            if (isinstance(nval, cos) or isinstance(-nval, cos)):
                return None
        # Pick the sign of cos(pi_coeff*pi) from the half-period it lies in.
        x = (2*pi_coeff + 1)/2
        sign_cos = (-1)**((-1 if x < 0 else 1)*int(abs(x)))
        return sign_cos*sqrt( (1 + nval)/2 )

    FC = fermatCoords(pi_coeff.q)
    if FC:
        decomp = ipartfrac(pi_coeff, FC)
        X = [(x[1], x[0]*S.Pi) for x in zip(decomp, numbered_symbols('z'))]
        pcls = cos(sum([x[0] for x in X]))._eval_expand_trig().subs(X)
        return pcls.rewrite(sqrt)

    if _EXPAND_INTS:
        decomp = ipartfrac(pi_coeff)
        X = [(x[1], x[0]*S.Pi) for x in zip(decomp, numbered_symbols('z'))]
        pcls = cos(sum([x[0] for x in X]))._eval_expand_trig().subs(X)
        return pcls
    return None
def cse(exprs, symbols=None, optimizations=None, postprocess=None):
""" Perform common subexpression elimination on an expression.
Parameters
==========
exprs : list of sympy expressions, or a single sympy expression
The expressions to reduce.
symbols : infinite iterator yielding unique Symbols
The symbols used to label the common subexpressions which are pulled
out. The ``numbered_symbols`` generator is useful. The default is a
stream of symbols of the form "x0", "x1", etc. This must be an infinite
iterator.
optimizations : list of (callable, callable) pairs, optional
The (preprocessor, postprocessor) pairs. If not provided,
``sympy.simplify.cse.cse_optimizations`` is used.
postprocess : a function which accepts the two return values of cse and
returns the desired form of output from cse, e.g. if you want the
replacements reversed the function might be the following lambda:
lambda r, e: return reversed(r), e
Returns
=======
replacements : list of (Symbol, expression) pairs
All of the common subexpressions that were replaced. Subexpressions
earlier in this list might show up in subexpressions later in this list.
reduced_exprs : list of sympy expressions
The reduced expressions with all of the replacements above.
"""
from sympy.matrices import Matrix
if symbols is None:
symbols = numbered_symbols()
else:
# In case we get passed an iterable with an __iter__ method instead of
# an actual iterator.
symbols = iter(symbols)
seen_subexp = set()
muls = set()
adds = set()
to_eliminate = set()
if optimizations is None:
# Pull out the default here just in case there are some weird
# manipulations of the module-level list in some other thread.
optimizations = list(cse_optimizations)
# Handle the case if just one expression was passed.
if isinstance(exprs, Basic):
exprs = [exprs]
# Preprocess the expressions to give us better optimization opportunities.
reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]
# Find all of the repeated subexpressions.
for expr in reduced_exprs:
if not isinstance(expr, Basic):
continue
pt = preorder_traversal(expr)
for subtree in pt:
inv = 1/subtree if subtree.is_Pow else None
if subtree.is_Atom or iterable(subtree) or inv and inv.is_Atom:
# Exclude atoms, since there is no point in renaming them.
continue
if subtree in seen_subexp:
if inv and _coeff_isneg(subtree.exp):
# save the form with positive exponent
subtree = inv
to_eliminate.add(subtree)
pt.skip()
continue
if inv and inv in seen_subexp:
if _coeff_isneg(subtree.exp):
# save the form with positive exponent
subtree = inv
to_eliminate.add(subtree)
pt.skip()
continue
elif subtree.is_Mul:
muls.add(subtree)
elif subtree.is_Add:
adds.add(subtree)
seen_subexp.add(subtree)
# process adds - any adds that weren't repeated might contain
# subpatterns that are repeated, e.g. x+y+z and x+y have x+y in common
adds = [set(a.args) for a in ordered(adds)]
for i in xrange(len(adds)):
for j in xrange(i + 1, len(adds)):
com = adds[i].intersection(adds[j])
if len(com) > 1:
to_eliminate.add(Add(*com))
# remove this set of symbols so it doesn't appear again
#.........这里部分代码省略.........
开发者ID:QuaBoo,项目名称:sympy,代码行数:101,代码来源:cse_main.py
示例10: cse
def cse(exprs, symbols=None, optimizations=None):
    """ Perform common subexpression elimination on an expression.

    Parameters:

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The `numbered_symbols` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an
        infinite iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        `sympy.simplify.cse.cse_optimizations` is used.

    Returns:

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this
        list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.
    """
    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)
    seen_subexp = set()
    # Ordered list of repeated subtrees plus a set mirror for O(1)
    # duplicate checks (the list alone made each check a linear scan).
    to_eliminate = []
    to_eliminate_set = set()
    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]
    # Preprocess the expressions to give us better optimization opportunities.
    exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions. Postorder traversal records
    # children before their parents, so inner subexpressions come first.
    for expr in exprs:
        for subtree in postorder_traversal(expr):
            if subtree.args == ():
                # Exclude atoms, since there is no point in renaming them.
                continue
            if subtree in seen_subexp and subtree not in to_eliminate_set:
                to_eliminate.append(subtree)
                to_eliminate_set.add(subtree)
            seen_subexp.add(subtree)

    # Substitute symbols for all of the repeated subexpressions.
    replacements = []
    reduced_exprs = list(exprs)
    for i in range(len(to_eliminate)):
        subtree = to_eliminate[i]
        sym = next(symbols)
        replacements.append((sym, subtree))
        # Make the substitution in all of the target expressions.
        reduced_exprs = [e.subs(subtree, sym) for e in reduced_exprs]
        # Make the substitution in all of the later candidates, so they are
        # expressed in terms of the symbols introduced so far.
        for j in range(i + 1, len(to_eliminate)):
            to_eliminate[j] = to_eliminate[j].subs(subtree, sym)

    # Postprocess the expressions to return the expressions to canonical form.
    replacements = [(sym, postprocess_for_cse(subtree, optimizations))
                    for sym, subtree in replacements]
    reduced_exprs = [postprocess_for_cse(e, optimizations) for e in reduced_exprs]
    return replacements, reduced_exprs
def cse(exprs, symbols=None, optimizations=None, postprocess=None,
order='canonical', ignore=()):
""" Perform common subexpression elimination on an expression.
Parameters
==========
exprs : list of sympy expressions, or a single sympy expression
The expressions to reduce.
symbols : infinite iterator yielding unique Symbols
The symbols used to label the common subexpressions which are pulled
out. The ``numbered_symbols`` generator is useful. The default is a
stream of symbols of the form "x0", "x1", etc. This must be an
infinite iterator.
optimizations : list of (callable, callable) pairs
The (preprocessor, postprocessor) pairs of external optimization
functions. Optionally 'basic' can be passed for a set of predefined
basic optimizations. Such 'basic' optimizations were used by default
in old implementation, however they can be really slow on larger
expressions. Now, no pre or post optimizations are made by default.
postprocess : a function which accepts the two return values of cse and
returns the desired form of output from cse, e.g. if you want the
replacements reversed the function might be the following lambda:
lambda r, e: return reversed(r), e
order : string, 'none' or 'canonical'
The order by which Mul and Add arguments are processed. If set to
'canonical', arguments will be canonically ordered. If set to 'none',
ordering will be faster but dependent on expressions hashes, thus
machine dependent and variable. For large expressions where speed is a
concern, use the setting order='none'.
ignore : iterable of Symbols
Substitutions containing any Symbol from ``ignore`` will be ignored.
Returns
=======
replacements : list of (Symbol, expression) pairs
All of the common subexpressions that were replaced. Subexpressions
earlier in this list might show up in subexpressions later in this
list.
reduced_exprs : list of sympy expressions
The reduced expressions with all of the replacements above.
Examples
========
>>> from sympy import cse, SparseMatrix
>>> from sympy.abc import x, y, z, w
>>> cse(((w + x + y + z)*(w + y + z))/(w + x)**3)
([(x0, w + y + z)], [x0*(x + x0)/(w + x)**3])
Note that currently, y + z will not get substituted if -y - z is used.
>>> cse(((w + x + y + z)*(w - y - z))/(w + x)**3)
([(x0, w + x)], [(w - y - z)*(x0 + y + z)/x0**3])
List of expressions with recursive substitutions:
>>> m = SparseMatrix([x + y, x + y + z])
>>> cse([(x+y)**2, x + y + z, y + z, x + z + y, m])
([(x0, x + y), (x1, x0 + z)], [x0**2, x1, y + z, x1, Matrix([
[x0],
[x1]])])
Note: the type and mutability of input matrices is retained.
>>> isinstance(_[1][-1], SparseMatrix)
True
The user may disallow substitutions containing certain symbols:
>>> cse([y**2*(x + 1), 3*y**2*(x + 1)], ignore=(y,))
([(x0, x + 1)], [x0*y**2, 3*x0*y**2])
"""
from sympy.matrices import (MatrixBase, Matrix, ImmutableMatrix,
SparseMatrix, ImmutableSparseMatrix)
# Handle the case if just one expression was passed.
if isinstance(exprs, (Basic, MatrixBase)):
exprs = [exprs]
copy = exprs
temp = []
for e in exprs:
if isinstance(e, (Matrix, ImmutableMatrix)):
temp.append(Tuple(*e._mat))
elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
temp.append(Tuple(*e._smat.items()))
else:
temp.append(e)
exprs = temp
del temp
if optimizations is None:
optimizations = list()
elif optimizations == 'basic':
optimizations = basic_optimizations
# Preprocess the expressions to give us better optimization opportunities.
reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]
#.........这里部分代码省略.........
def cse(exprs, symbols=None, optimizations=None, postprocess=None):
""" Perform common subexpression elimination on an expression.
Parameters
==========
exprs : list of sympy expressions, or a single sympy expression
The expressions to reduce.
symbols : infinite iterator yielding unique Symbols
The symbols used to label the common subexpressions which are pulled
out. The ``numbered_symbols`` generator is useful. The default is a
stream of symbols of the form "x0", "x1", etc. This must be an infinite
iterator.
optimizations : list of (callable, callable) pairs, optional
The (preprocessor, postprocessor) pairs. If not provided,
``sympy.simplify.cse.cse_optimizations`` is used.
postprocess : a function which accepts the two return values of cse and
returns the desired form of output from cse, e.g. if you want the
replacements reversed the function might be the following lambda:
lambda r, e: return reversed(r), e
Returns
=======
replacements : list of (Symbol, expression) pairs
All of the common subexpressions that were replaced. Subexpressions
earlier in this list might show up in subexpressions later in this list.
reduced_exprs : list of sympy expressions
The reduced expressions with all of the replacements above.
"""
from sympy.matrices import Matrix
if symbols is None:
symbols = numbered_symbols()
else:
# In case we get passed an iterable with an __iter__ method instead of
# an actual iterator.
symbols = iter(symbols)
tmp_symbols = numbered_symbols('_csetmp')
subexp_iv = dict()
muls = set()
adds = set()
if optimizations is None:
# Pull out the default here just in case there are some weird
# manipulations of the module-level list in some other thread.
optimizations = list(cse_optimizations)
# Handle the case if just one expression was passed.
if isinstance(exprs, Basic):
exprs = [exprs]
# Preprocess the expressions to give us better optimization opportunities.
prep_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]
# Find all subexpressions.
def _parse(expr):
if expr.is_Atom:
# Exclude atoms, since there is no point in renaming them.
return expr
if iterable(expr):
return expr
subexpr = type(expr)(*map(_parse, expr.args))
if subexpr in subexp_iv:
return subexp_iv[subexpr]
if subexpr.is_Mul:
muls.add(subexpr)
elif subexpr.is_Add:
adds.add(subexpr)
ivar = next(tmp_symbols)
subexp_iv[subexpr] = ivar
return ivar
tmp_exprs = list()
for expr in prep_exprs:
if isinstance(expr, Basic):
tmp_exprs.append(_parse(expr))
else:
tmp_exprs.append(expr)
# process adds - any adds that weren't repeated might contain
# subpatterns that are repeated, e.g. x+y+z and x+y have x+y in common
adds = list(ordered(adds))
addargs = [set(a.args) for a in adds]
for i in xrange(len(addargs)):
for j in xrange(i + 1, len(addargs)):
com = addargs[i].intersection(addargs[j])
if len(com) > 1:
add_subexp = Add(*com)
diff_add_i = addargs[i].difference(com)
diff_add_j = addargs[j].difference(com)
#.........这里部分代码省略.........
def cse(exprs, symbols=None, optimizations=None, postprocess=None,
        order='canonical'):
    """ Perform common subexpression elimination on an expression.

    Parameters
    ==========

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce. An empty list yields ``([], [])``.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an
        infinite iterator. Symbols already present in ``exprs`` are skipped.
    optimizations : list of (callable, callable) pairs
        The (preprocessor, postprocessor) pairs of external optimization
        functions. Optionally 'basic' can be passed for a set of predefined
        basic optimizations. Such 'basic' optimizations were used by default
        in old implementation, however they can be really slow on larger
        expressions. Now, no pre or post optimizations are made by default.
    postprocess : a function which accepts the two return values of cse and
        returns the desired form of output from cse, e.g. if you want the
        replacements reversed the function might be the following lambda:
        lambda r, e: (reversed(r), e)
    order : string, 'none' or 'canonical'
        The order by which Mul and Add arguments are processed. If set to
        'canonical', arguments will be canonically ordered. If set to 'none',
        ordering will be faster but dependent on expressions hashes, thus
        machine dependent and variable. For large expressions where speed is a
        concern, use the setting order='none'.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this
        list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.
    """
    from sympy.matrices import Matrix

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]

    if optimizations is None:
        optimizations = list()
    elif optimizations == 'basic':
        optimizations = basic_optimizations

    # Preprocess the expressions to give us better optimization opportunities.
    reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Symbols already used by the input must not be handed out again.
    # set().union(...) (rather than set.union(...)) also works when there
    # are no expressions at all.
    excluded_symbols = set().union(*[expr.atoms(Symbol)
                                     for expr in reduced_exprs])

    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)

    symbols = filter_symbols(symbols, excluded_symbols)

    # Find other optimization opportunities.
    opt_subs = opt_cse(reduced_exprs, order)

    # Main CSE algorithm.
    replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs,
                                           order)

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations)
                     for e in reduced_exprs]

    # A Matrix input was processed element-wise; rebuild it in its shape.
    if isinstance(exprs, Matrix):
        reduced_exprs = [Matrix(exprs.rows, exprs.cols, reduced_exprs)]
    if postprocess is None:
        return replacements, reduced_exprs
    return postprocess(replacements, reduced_exprs)
开发者ID:B-Rich,项目名称:sympy,代码行数:86,代码来源:cse_main.py
示例16: test_numbered_symbols
def test_numbered_symbols():
    # numbered_symbols(cls=Dummy) must yield Dummy instances. The builtin
    # next() works on any iterator, unlike the Python-2-only .next() method.
    s = numbered_symbols(cls=Dummy)
    assert isinstance(next(s), Dummy)
请发表评论