本文整理汇总了Python中tvm.var函数的典型用法代码示例。如果您正苦于以下问题：Python var函数的具体用法？Python var怎么用？Python var使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了var函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: test_simplify_if_then_else
def test_simplify_if_then_else():
    """Exercise canonical simplification of conditional expressions.

    ``CanonicalChecker`` is defined outside this block; ``verify(a, b)`` is
    presumably an assertion that ``a`` canonically simplifies to ``b`` --
    confirm against its definition.
    """
    checker = CanonicalChecker()
    x = tvm.var("x")
    y = tvm.var("y")

    # Simplification that takes the enclosing conditions into account.
    linear = x * 4 + y
    wrapped = (linear - 466036) % 24528
    nested = tvm.if_then_else(24512 <= wrapped, (wrapped - 24512) % 16, x)
    res = tvm.if_then_else(linear >= 466036, nested, y)

    reduced = linear - 4
    expected_nested = tvm.if_then_else(
        tvm.expr.LE(24512, reduced % 24528), reduced % 16, x)
    expected = tvm.if_then_else(
        tvm.expr.LE(466036, linear), expected_nested, y)
    checker.verify(res, expected)

    # Only the branch guarded by the condition can be simplified.
    guard = tvm.all(x >= -1, y >= 0)
    res = tvm.expr.Select(guard, (x + y + 100) % 3, (x + 100) % 3)
    expected = tvm.expr.Select(guard, (x + y + 1) % 3, (x + 100) % 3)
    checker.verify(res, checker.analyzer.canonical_simplify(expected))

    # Inner condition `x / 3 > 2` under the outer `x >= 10`: expected to
    # collapse to the true branch.
    res = tvm.expr.Select(
        x >= 10, tvm.if_then_else(x / 3 > 2, x, 0), 0)
    expected = tvm.expr.Select(x >= 10, x, 0)
    checker.verify(res, checker.analyzer.canonical_simplify(expected))

    # Inner condition `x / 3 < 2` under the outer `x >= 10`: the whole
    # expression is expected to simplify to the constant 0.
    res = tvm.expr.Select(
        x >= 10, tvm.if_then_else(x / 3 < 2, x, 0), 0)
    checker.verify(res, 0)
开发者ID:bddppq,项目名称:tvm,代码行数:28,代码来源:test_arith_canonical_simplify.py
示例2: test_infer_type_leaky_relu
def test_infer_type_leaky_relu():
    """Check type inference and numerical results of ``relay.nn.leaky_relu``.

    The first half uses a symbolic ``(n, c, h, w)`` shape and only checks the
    printed attributes and the inferred type.  The second half uses a concrete
    shape and compares the "graph" and "debug" executors against a NumPy
    reference on every ``(target, ctx)`` pair yielded by ``ctx_list()``.
    """
    n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
    y = relay.nn.leaky_relu(x, alpha=0.1)
    # This membership test used to be a bare expression whose result was
    # discarded; it must be asserted to check anything.
    assert "alpha=0.1" in y.astext()
    yy = relay.ir_pass.infer_type(y)
    assert yy.checked_type == relay.TensorType((n, c, h, w), "float32")

    shape = (1, 5, 10, 10)
    dtype = "float32"
    x = relay.var("x", relay.TensorType(shape, dtype))
    z = relay.nn.leaky_relu(x, alpha=0.1)
    assert "alpha=0.1" in z.astext()
    yy = relay.ir_pass.infer_type(z)
    assert yy.checked_type == relay.TensorType(shape, dtype)

    func = relay.Function([x], z)
    x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype)
    # Reference: identity for positive inputs, scaled by alpha otherwise.
    ref_res = np.where(x_data > 0, x_data, x_data * 0.1)

    for target, ctx in ctx_list():
        intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
        intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
        op_res1 = intrp1.evaluate(func)(x_data)
        tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5)
        op_res2 = intrp2.evaluate(func)(x_data)
        tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
开发者ID:LANHUIYING,项目名称:tvm,代码行数:26,代码来源:test_op_level3.py
示例3: test_equality
def test_equality():
    """Check the truth value of ``==`` / ``!=`` on expressions.

    Two distinct variables must not compare equal, and an expression must
    not compare unequal to itself.
    """
    lhs, rhs = tvm.var('a'), tvm.var('b')
    eq_expr = lhs == rhs
    assert not eq_expr
    ne_expr = eq_expr != eq_expr
    assert not ne_expr
开发者ID:LANHUIYING,项目名称:tvm,代码行数:7,代码来源:test_lang_basic.py
示例4: test_parallel_alloc
def test_parallel_alloc():
    """Check where ``StorageRewrite`` leaves an allocation made in a parallel loop.

    Two IR fragments are built with the IR builder; after the pass, the
    ``Allocate`` node is expected at a fixed nesting depth in each.
    """
    # Fragment 1: parallel loop -> serial loop -> allocate + update.
    builder = tvm.ir_builder.create()
    extent = tvm.var("n")
    with builder.for_range(0, extent, name="i", for_type="parallel") as i:
        with builder.for_range(0, 10, name="j") as j:
            buf = builder.allocate("float32", extent, name="A", scope="global")
            buf[j] = buf[j] + 2

    rewritten = tvm.ir_pass.StorageRewrite(builder.get())
    assert isinstance(rewritten.body.body, tvm.stmt.Allocate)

    # Fragment 2: same loops wrapped in an outer loop that carries a
    # "parallel_launch_point" pragma attribute.
    builder = tvm.ir_builder.create()
    extent = tvm.var("n")
    with builder.for_range(0, extent, name="t") as i:
        builder.scope_attr(
            tvm.const(1, "int32"),
            "pragma_scope",
            tvm.make.StringImm("parallel_launch_point"),
        )
        with builder.for_range(0, extent, name="i", for_type="parallel") as i:
            with builder.for_range(0, 10, name="j") as j:
                buf = builder.allocate(
                    "float32", extent, name="A", scope="global")
                buf[j] = buf[j] + 2

    rewritten = tvm.ir_pass.StorageRewrite(builder.get())
    assert isinstance(rewritten.body.body.body.body, tvm.stmt.Allocate)
开发者ID:bddppq,项目名称:tvm,代码行数:26,代码来源:test_pass_storage_rewrite.py
示例5: test_scan_group
def test_scan_group():
    """Check stage grouping inside a scan schedule.

    Covers the implicit group of the scan body, creating a nested group,
    attaching that group, and the error raised when a stage is attached
    outside its group.
    """
    m = tvm.var("m")
    n = tvm.var("n")
    shape = (m, n)
    # NOTE: lambda parameter names are left untouched; they may be used to
    # name the generated axes.
    x = tvm.compute(shape, lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder(shape)
    s_init = tvm.compute((1, n), lambda _, i: x[0, i])

    s_update1 = tvm.compute(shape, lambda t, i: s_state[t-1, i] + x[t, i])
    s_update2 = tvm.compute(shape, lambda t, i: s_update1[t, i] + 1)
    s_update3 = tvm.compute(shape, lambda t, i: s_update2[t, i] + 1)
    scan_out = tvm.scan(s_init, s_update3, s_state, inputs=x)

    sched = tvm.create_schedule(scan_out.op)
    assert sched[s_update1].group is not None
    assert sched[s_update2].group == sched[s_update1].group

    # Attaching within the same group is valid.
    sched[s_update1].compute_at(sched[s_update2], s_update2.op.axis[1])

    # Create a new group holding s_update2 and s_update1.
    inner_group = sched.create_group(outputs=s_update2, inputs=[s_state, x])
    assert inner_group.group is not None
    assert inner_group.group == sched[s_update3].group
    assert sched[s_update2].group == inner_group
    assert sched[s_update1].group == inner_group

    inner_group.compute_at(sched[s_update3], s_update3.op.axis[1])
    assert inner_group.attach_stage == sched[s_update3]

    # Attaching a grouped stage outside of its group must raise.
    try:
        sched[s_update2].compute_at(sched[s_init], s_init.op.axis[0])
    except tvm.TVMError:
        pass
    else:
        assert False
开发者ID:LANHUIYING,项目名称:tvm,代码行数:31,代码来源:test_lang_group.py
示例6: test_mix_index
def test_mix_index():
    """Check ``Analyzer.modular_set`` on mixed arithmetic index expressions.

    Each case is ``(expression, expected coeff, expected base)``.
    """
    a = tvm.var("a")
    b = tvm.var("b")
    analyzer = tvm.arith.Analyzer()

    cases = [
        (a * 4 + b * 6 + 7, 2, 1),
        ((a * 4 + 1) * (b * 8 + 3), 4, 3),
        ((a * 4 + 1) / (b * 8 + 3), 1, 0),
        ((a * 4 + 1) * (b * 8 / 4), 2, 0),
        ((a * 12 + 1) - (b * 3 * 7 + 2), 3, 2),
        (a * 12 + tvm.min(b * 3 * 7, 2), 1, 0),
    ]
    for expr, want_coeff, want_base in cases:
        mod_set = analyzer.modular_set(expr)
        assert mod_set.coeff == want_coeff
        assert mod_set.base == want_base
开发者ID:bddppq,项目名称:tvm,代码行数:27,代码来源:test_arith_modular_set.py
示例7: test_multivariate
def test_multivariate():
    """Check ``DetectLinearEquation`` with several candidate variables.

    Covers linear expressions, non-linear expressions (empty result), a
    variable list that does not occur in the expression, and an empty
    variable list.
    """
    v = [tvm.var("v%d" % i) for i in range(4)]
    b = tvm.var("b")
    detect = tvm.arith.DetectLinearEquation
    simplify = tvm.ir_pass.Simplify

    coeffs = detect(v[0] * (b + 4) + v[0] + v[1] * 8, v)
    assert tvm.ir_pass.Equal(simplify(coeffs[0]), b + 5)
    assert coeffs[1].value == 8

    # Product of two candidate variables: not linear, nothing is returned.
    coeffs = detect(v[0] * (b + 4) + v[0] + v[1] * 8 * v[2], v)
    assert len(coeffs) == 0

    # Square of a candidate variable: not linear either.
    coeffs = detect(v[0] * (b + 4) + v[0] + v[1] * 8 * v[1] + v[3], v)
    assert len(coeffs) == 0

    coeffs = detect(((v[0] * b + v[1]) * 8 + v[2] + 1) * 2, v)
    assert coeffs[1].value == 16
    assert coeffs[2].value == 2
    assert coeffs[len(coeffs) - 1].value == 2

    # The candidate variable does not occur: zero coefficient, and the whole
    # expression is returned as the remaining term.
    difference = v[0] - v[1]
    coeffs = detect(difference, [v[2]])
    assert coeffs[0].value == 0
    assert simplify(coeffs[1] - (v[0] - v[1])).value == 0

    # No candidate variables: a single entry equal to the expression itself.
    coeffs = detect(v[0] - v[1], [])
    assert len(coeffs) == 1
    assert simplify(coeffs[0] - (v[0] - v[1])).value == 0
开发者ID:bddppq,项目名称:tvm,代码行数:25,代码来源:test_arith_detect_linear_equation.py
示例8: test_reduce_functions
def test_reduce_functions():
    """Run ``verify_reduce`` over every reduction operator and shape case.

    ``verify_reduce`` is defined outside this block; the argument order
    ``(funcs, data_shape, axis, keepdims, exclude, output_shape)`` is inferred
    from the call sites here -- confirm against its definition.
    """
    def _with_keepdims(func):
        """Wrap ``func`` so it accepts ``keepdims`` by reshaping the result."""
        def _wrapper(data, axis=None, keepdims=False):
            if not keepdims:
                return func(data, axis=axis)
            if axis is None:
                # Full reduction: every dimension collapses to size 1.
                out_shape = [1] * len(data.shape)
            else:
                # Only the first entry of the axis sequence is used.
                axis = axis[0]
                out_shape = list(data.shape)
                out_shape[axis] = 1
            return func(data, axis=axis).reshape(out_shape)
        return _wrapper

    d1, d2, d3, d4 = (tvm.var(name) for name in ("d1", "d2", "d3", "d4"))

    # Each entry pairs a Relay operator with its NumPy reference.
    reducer_pairs = [
        [relay.sum, np.sum],
        [relay.max, np.max],
        [relay.min, np.min],
        [relay.mean, np.mean],
        [relay.prod, np.prod],
        [relay.argmin, _with_keepdims(np.argmin)],
        [relay.argmax, _with_keepdims(np.argmax)],
    ]
    cases = [
        ((d1, d2, d3, d4), (2,), True, False, (d1, d2, 1, d4)),
        ((d1, d2, d3), (1,), True, False, (d1, 1, d3)),
        ((d1, d2, d3), None, True, False, (1, 1, 1)),
        ((d1, d2, d3), (0, 1), True, False, (1, 1, d3)),
        ((2, 3, 4), (1,), True, False, (2, 1, 4)),
        ((2, 3, 4), (0, 1, 2), False, False, ()),
        ((4, 4, 3), None, False, True, ()),
        ((4, 4, 3), (0, 2), False, False, (4,)),
        ((128, 24, 128), (0, 1), False, False, (128,)),
        ((128, 24, 128), (0, 2), False, False, (24,)),
        ((128, 24, 128), (0, 1), True, False, (1, 1, 128)),
        ((128, 24, 128), (0, 2), True, False, (1, 24, 1)),
    ]
    for pair in reducer_pairs:
        for case in cases:
            verify_reduce(pair, *case)
开发者ID:LANHUIYING,项目名称:tvm,代码行数:35,代码来源:test_op_level4.py
示例9: test_strided_slice
def test_strided_slice():
    """Check type inference and results of ``relay.strided_slice``."""
    def verify(dshape, begin, end, strides, output, test_ref=True):
        """Build one strided_slice function and check it.

        ``output`` is the expected result shape (skipped when falsy).  When
        ``test_ref`` is false only the type-level checks run.
        """
        data = relay.var("x", relay.TensorType(dshape, "float32"))
        sliced = relay.strided_slice(
            data, begin=begin, end=end, strides=strides)
        func = relay.ir_pass.infer_type(relay.Function([data], sliced))

        text = func.astext()
        assert "begin=" in text
        assert "end=" in text
        if output:
            assert func.body.checked_type == relay.ty.TensorType(output, "float32")
        if not test_ref:
            return

        x_data = np.random.uniform(size=dshape).astype("float32")
        ref_res = topi.testing.strided_slice_python(
            x_data, begin, end, strides)
        for target, ctx in ctx_list():
            intrp = relay.create_executor("graph", ctx=ctx, target=target)
            op_res = intrp.evaluate(func)(x_data)
            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)

    # d3 and d4 are created but not referenced by any case below.
    d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")

    cases = [
        # Symbolic shape: type check only (test_ref=False).
        ((d1, d2, 3), [None, None, 1], [None, None, 2], None, (d1, d2, 1), False),
        ((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2)),
        ((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3)),
        ((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)),
        ((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2], (1, 2, 2)),
        ((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)),
        ((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)),
        ((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3)),
        ((3, 4, 3), [1, 1, 0], [4, 4], None, (2, 3, 3)),
        ((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3)),
    ]
    for case in cases:
        verify(*case)
开发者ID:bddppq,项目名称:tvm,代码行数:32,代码来源:test_op_level4.py
示例10: test_min_max_bound
def test_min_max_bound():
    """Check constant integer bounds of ``min``/``max`` under various ranges."""
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x"), tvm.var("y")
    make_bound = tvm.arith.ConstIntBound

    # x in [-9, 11], y in [4, 10]
    analyzer.update(x, make_bound(-9, 11))
    analyzer.update(y, make_bound(4, 10))
    bound = analyzer.const_int_bound(tvm.min(x, y))
    assert bound.min_value == -9
    assert bound.max_value == 10

    # x unbounded, y in [4, 10]
    analyzer.update(x, make_bound(bound.NEG_INF, bound.POS_INF), override=True)
    analyzer.update(y, make_bound(4, 10), override=True)
    bound = analyzer.const_int_bound(tvm.min(x, y))
    assert bound.min_value == bound.NEG_INF
    assert bound.max_value == 10

    bound = analyzer.const_int_bound(tvm.max(x, y))
    assert bound.min_value == 4
    assert bound.max_value == bound.POS_INF

    # x in [1, +inf), y in [4, 10]
    analyzer.update(x, make_bound(1, bound.POS_INF), override=True)
    analyzer.update(y, make_bound(4, 10), override=True)
    bound = analyzer.const_int_bound(tvm.max(x, y))
    assert bound.min_value == 4
    assert bound.max_value == bound.POS_INF
开发者ID:bddppq,项目名称:tvm,代码行数:25,代码来源:test_arith_const_int_bound.py
示例11: test_add_sub_bound
def test_add_sub_bound():
    """Check constant integer bounds of sums and differences of int64 vars."""
    analyzer = tvm.arith.Analyzer()
    x, y = tvm.var("x", "int64"), tvm.var("y", "int64")
    make_bound = tvm.arith.ConstIntBound

    # Nothing is known about x or y yet: the sum is unbounded on both sides.
    bound = analyzer.const_int_bound(x + y)
    assert bound.min_value == bound.NEG_INF
    assert bound.max_value == bound.POS_INF

    # x in [0, 4], y in [1, 10]
    analyzer.update(x, make_bound(0, 4))
    analyzer.update(y, make_bound(1, 10))
    bound = analyzer.const_int_bound(x + y)
    assert bound.min_value == 1
    assert bound.max_value == 14

    bound = analyzer.const_int_bound(x - y)
    assert bound.min_value == -10
    assert bound.max_value == 3

    # x in [0, +inf), y unchanged
    analyzer.update(x, make_bound(0, bound.POS_INF), override=True)
    bound = analyzer.const_int_bound(x - y)
    assert bound.min_value == -10
    assert bound.max_value == bound.POS_INF

    bound = analyzer.const_int_bound(1 - x)
    assert bound.min_value == bound.NEG_INF
    assert bound.max_value == 1
开发者ID:bddppq,项目名称:tvm,代码行数:25,代码来源:test_arith_const_int_bound.py
示例12: test_vectorize_if_then_else
def test_vectorize_if_then_else():
    """Check how ``VectorizeLoop`` treats ``tvm_if_then_else`` intrinsics.

    When the condition depends on the vectorized index, the loop is expected
    to remain a ``For``; when it depends only on the outer index, the inner
    loop is expected to be vectorized and the scalar constant broadcast.
    """
    n = tvm.var('n')
    x = tvm.var('x')  # created but not referenced below

    # Condition depends on the vectorized loop variable itself.
    builder = tvm.ir_builder.create()
    buf = builder.pointer("float32", name="A")
    with builder.for_range(0, 4, for_type="vectorize") as lane:
        buf[lane] = tvm.call_intrin(
            "float32", "tvm_if_then_else", lane > 0, buf[lane] + 1, buf[lane])
    stmt = tvm.ir_pass.VectorizeLoop(builder.get())
    assert isinstance(stmt, tvm.stmt.For)

    # Condition depends only on the outer (serial) loop variable.
    builder = tvm.ir_builder.create()
    buf = builder.pointer("float32", name="A")
    with builder.for_range(0, n) as outer:
        with builder.for_range(0, 4, for_type="vectorize") as lane:
            buf[outer * 4 + lane] = tvm.call_intrin(
                "float32", "tvm_if_then_else",
                outer > 0, buf[outer * 4 + lane], 0)
    stmt = builder.get()
    assert isinstance(stmt.body, tvm.stmt.For)
    stmt = tvm.ir_pass.VectorizeLoop(stmt)
    assert not isinstance(stmt.body, tvm.stmt.For)
    assert isinstance(stmt.body.value.args[2], tvm.expr.Broadcast)
开发者ID:bddppq,项目名称:tvm,代码行数:26,代码来源:test_pass_vectorize.py
示例13: test_buffer_vload
def test_buffer_vload():
    """Check the flat index produced by ``Buffer.vload`` with an element offset.

    For an ``(m, n)`` buffer with ``elem_offset=100``, loading at ``[2, 3]``
    must simplify to the index ``n * 2 + 103``.
    """
    rows = tvm.var('m')
    cols = tvm.var('n')
    buf = tvm.decl_buffer((rows, cols), tvm.float32, elem_offset=100)
    flat_index = tvm.ir_pass.Simplify(buf.vload([2, 3]).index)
    assert tvm.ir_pass.Equal(flat_index, cols * 2 + 103)
开发者ID:LANHUIYING,项目名称:tvm,代码行数:7,代码来源:test_lang_buffer.py
示例14: test_scan
def test_scan():
    """Build a scan computation and define three checks on its schedule graph.

    NOTE(review): in the code visible here the nested ``test_*`` helpers are
    only defined, never invoked, so this test currently exercises just the
    graph construction.  ``test_fix_pt`` also reads ``spatial_axis_`` from
    the scan tensor rather than from ``s_scan.op`` -- confirm both are
    intended before enabling the helpers.
    """
    m = tvm.var("m")
    n = tvm.var("n")
    shape = (m, n)
    # Lambda parameter names are left untouched; they may be used to name
    # the generated axes.
    x = tvm.compute(shape, lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder(shape)
    s_init = tvm.compute((1, n), lambda _, i: x[0, i], name="s_init")
    x_trans = tvm.compute(shape, lambda i, j: x[i, j] + 1, name="x_trans")
    s_up1 = tvm.compute(shape, lambda t, i: s_state[t - 1, i] + 1, name="up1")
    s_update = tvm.compute(
        shape, lambda t, i: s_up1[t, i] + x_trans[t, i], name="update")
    s_scan = tvm.scan(s_init, s_update, s_state)

    def test_getbody():
        """The scan body must consist of the scan, update and up1 ops."""
        scan_body = tvm.schedule.ScanGetBody(s_scan.op)
        assert set(scan_body) == {s_scan.op, s_update.op, s_up1.op}

    def test_attach_path():
        """Attach paths must end at the scan axis."""
        sched = tvm.create_schedule(s_scan.op)
        sched[x_trans].compute_at(sched[s_update], s_update.op.axis[0])
        attach_path = tvm.schedule.CreateAttachPath(sched)
        assert tuple(attach_path[s_update.op]) == (s_scan.op.scan_axis,)
        assert tuple(attach_path[x_trans.op]) == (
            s_update.op.axis[0], s_scan.op.scan_axis)

    def test_fix_pt():
        """The first spatial axis must be reported as a fixed point."""
        scan_body = tvm.schedule.ScanGetBody(s_scan.op)
        fix_points = tvm.schedule.ScanFixPointAnalysis(s_scan.op, scan_body)
        assert fix_points[s_scan.spatial_axis_[0]].value != 0
开发者ID:LANHUIYING,项目名称:tvm,代码行数:26,代码来源:test_schedule_graph.py
示例15: test_tensor_comm_reducer
def test_tensor_comm_reducer():
    """Build a row-wise sum using a user-defined commutative reducer.

    There are no assertions; the test passes when construction of the
    reducer and of the compute stage raises nothing.
    """
    rows = tvm.var('m')
    cols = tvm.var('n')
    A = tvm.placeholder((rows, cols), name='A')
    k = tvm.reduce_axis((0, cols), "k")
    # Lambda parameter names are left untouched; they may be used to name
    # the generated variables.
    mysum = tvm.comm_reducer(lambda x, y: x+y, lambda t: tvm.const(0, dtype=t))
    tvm.compute((rows,), lambda i: mysum(A[i, k], axis=k))
开发者ID:gwli,项目名称:tvm,代码行数:7,代码来源:test_lang_tensor.py
示例16: test_conv2d_transpose_infer_type
def test_conv2d_transpose_infer_type():
    """Check type inference of ``relay.nn.conv2d_transpose``.

    The first case infers the weight type from the operator attributes; the
    second infers the output type from an explicitly typed weight with an
    NHWC data layout.
    """
    # Symbolic batch dimension; weight type left for inference.
    batch, channels, height, width = tvm.var("n"), 10, 10, 12
    data = relay.var(
        "x", relay.TensorType((batch, channels, height, width), "float32"))
    weight = relay.var("w", relay.IncompleteType())
    y = relay.nn.conv2d_transpose(
        data, weight, kernel_size=(3, 3), padding=(1, 1), channels=15)
    assert "channels=15" in y.astext()
    yy = relay.ir_pass.infer_type(y)
    assert yy.checked_type == relay.TensorType(
        (batch, 15, 10, 12), "float32")
    assert yy.args[1].checked_type == relay.TensorType(
        (10, 15, 3, 3), "float32")

    # Output type inferred from the explicit shape of the weight.
    batch, channels, height, width = tvm.var("n"), 10, 10, 12
    data = relay.var(
        "x", relay.TensorType((batch, channels, height, width), "float32"))
    weight = relay.var("w", relay.TensorType((12, 11, 5, 5), "float32"))
    y = relay.nn.conv2d_transpose(
        data, weight, output_padding=(1, 1), channels=11, data_layout="NHWC")
    yy = relay.ir_pass.infer_type(y)
    assert yy.checked_type == relay.TensorType(
        (batch, 15, 15, 11), "float32")
开发者ID:bddppq,项目名称:tvm,代码行数:27,代码来源:test_op_level2.py
示例17: _post_order
def _post_order(op):
    """Rewrite buffer accesses to go through a separate pointer variable.

    ``rw_info`` and ``env`` come from the enclosing scope (not visible in
    this block); ``rw_info`` maps a buffer variable to its replacement
    ``<name>_ptr`` handle variable.

    * ``Load`` / ``Store``: the access is re-targeted to the replacement
      variable, which is created on first use.
    * ``Allocate``: if the buffer was accessed, the body is wrapped in a
      ``LetStmt`` binding the replacement variable to the result of the
      ``VTABufferCPUPtr`` extern call, and the ``rw_info`` entry is dropped.
      Otherwise ``None`` is returned.

    Raises
    ------
    RuntimeError
        If ``op`` is none of the three handled node types.
    """
    def _ptr_var(buffer_var):
        """Return the replacement handle for ``buffer_var``, creating it once."""
        if buffer_var not in rw_info:
            rw_info[buffer_var] = tvm.var(
                buffer_var.name + "_ptr", "handle")
        return rw_info[buffer_var]

    if isinstance(op, tvm.stmt.Allocate):
        buffer_var = op.buffer_var
        if buffer_var not in rw_info:
            # Buffer never loaded from or stored to: nothing to rewrite.
            return None
        new_var = rw_info[buffer_var]
        let_stmt = tvm.make.LetStmt(
            new_var, tvm.call_extern(
                "handle", "VTABufferCPUPtr",
                env.dev.command_handle,
                buffer_var), op.body)
        alloc = tvm.make.Allocate(
            buffer_var, op.dtype, op.extents,
            op.condition, let_stmt)
        # The mapping is only valid inside this allocation's scope.
        del rw_info[buffer_var]
        return alloc
    if isinstance(op, tvm.expr.Load):
        return tvm.make.Load(op.dtype, _ptr_var(op.buffer_var), op.index)
    if isinstance(op, tvm.stmt.Store):
        return tvm.make.Store(_ptr_var(op.buffer_var), op.value, op.index)
    raise RuntimeError("not reached")
开发者ID:bddppq,项目名称:tvm,代码行数:31,代码来源:ir_pass.py
示例18: test_flatten_infer_type
def test_flatten_infer_type():
    """Check type inference and results of ``relay.nn.batch_flatten``.

    Type inference is checked for fully symbolic, fully concrete and mixed
    shapes; execution is checked for one concrete shape on both the "graph"
    and "debug" executors against a NumPy reshape.
    """
    def _flattened_type(shape):
        """Return the inferred type of batch_flatten on a float32 tensor of ``shape``."""
        data = relay.var("x", relay.TensorType(shape, "float32"))
        flat = relay.nn.batch_flatten(data)
        return relay.ir_pass.infer_type(flat).checked_type

    d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")

    assert _flattened_type((d1, d2, d3, d4)) == relay.TensorType(
        (d1, ((d2*d3)*d4)), "float32")
    assert _flattened_type((3, 2, 4, 3)) == relay.TensorType(
        (3, 24), "float32")
    assert _flattened_type((d1, 2, d3, 3)) == relay.TensorType(
        (d1, ((2*d3)*3)), "float32")

    shape = (1, 5, 10, 10)
    o_shape = (1, 500)
    dtype = "float32"
    x = relay.var("x", relay.TensorType(shape, dtype))
    z = relay.nn.batch_flatten(x)
    inferred = relay.ir_pass.infer_type(z)
    assert inferred.checked_type == relay.TensorType(o_shape, dtype)

    func = relay.Function([x], z)
    x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype)
    ref_res = x_data.flatten().reshape(o_shape)

    for target, ctx in ctx_list():
        graph_exec = relay.create_executor("graph", ctx=ctx, target=target)
        debug_exec = relay.create_executor("debug", ctx=ctx, target=target)
        graph_out = graph_exec.evaluate(func)(x_data)
        tvm.testing.assert_allclose(graph_out.asnumpy(), ref_res, rtol=1e-5)
        debug_out = debug_exec.evaluate(func)(x_data)
        tvm.testing.assert_allclose(debug_out.asnumpy(), ref_res, rtol=1e-5)
开发者ID:bddppq,项目名称:tvm,代码行数:35,代码来源:test_op_level2.py
示例19: test_storage_sync
def test_storage_sync():
    """Check that ``ThreadSync`` inserts a storage sync for a shared-scope stage.

    A two-stage pipeline is scheduled with the intermediate stage in
    "shared" scope, lowered by hand, split into host and device parts, and
    the device function is inspected for a ``tvm_storage_sync`` call.
    """
    m = tvm.var('m')
    length = tvm.var('l')
    shape = (m, length)
    # Lambda parameter names are left untouched; they may be used to name
    # the generated axes.
    A = tvm.placeholder(shape, name='A')
    A1 = tvm.compute(shape, lambda i, j: A[i, j], name='A1')
    A2 = tvm.compute(shape, lambda i, j: A1[i, j] + 3, name='A2')

    sched = tvm.create_schedule(A2.op)
    outer, inner = sched[A2].split(A2.op.axis[0], factor=8)
    sched[A2].bind(outer, tvm.thread_axis("blockIdx.x"))
    sched[A1].compute_at(sched[A2], outer)
    sched[A1].set_scope("shared")

    bounds = tvm.schedule.InferBound(sched)
    assert isinstance(bounds, tvm.container.Map)
    stmt = tvm.schedule.ScheduleOps(sched, bounds)

    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
    A2b = tvm.decl_buffer(A2.shape, A2.dtype, name='A2')
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
    api_func = tvm.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True)

    # Entry 1 of the split result is taken as the device function.
    device_func = tvm.ir_pass.SplitHostDevice(api_func)[1]
    device_func = tvm.ir_pass.ThreadSync(device_func, "shared")
    body_list = tvm.make.stmt_list(device_func.body.body.body.body)
    assert body_list[1].value.name == "tvm_storage_sync"
开发者ID:gwli,项目名称:tvm,代码行数:26,代码来源:test_pass_storage_sync.py
示例20: test_lrn
def test_lrn():
    """Check type inference and numerical results of ``relay.nn.lrn``.

    The first half uses a symbolic ``(n, c, h, w)`` shape and only checks the
    printed attributes and the inferred type.  The second half uses a concrete
    shape and compares the "graph" and "debug" executors against
    ``topi.testing.lrn_python`` on every ``(target, ctx)`` pair yielded by
    ``ctx_list()``.
    """
    n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
    x = relay.var("x", shape=(n, c, h, w))
    y = relay.nn.lrn(x, size=10, axis=2, bias=0.5, alpha=.00001, beta=0.75)
    # This membership test used to be a bare expression whose result was
    # discarded; it must be asserted to check anything.
    assert "alpha=" in y.astext()
    yy = relay.ir_pass.infer_type(y)
    assert yy.checked_type == relay.TensorType((n, c, h, w))

    shape = (1, 5, 10, 10)
    dtype = "float32"
    x = relay.var("x", relay.TensorType(shape, dtype))
    size = 5
    axis = 1
    bias = 0.5
    alpha = .00001
    beta = 0.75
    z = relay.nn.lrn(x, size=size, axis=axis, bias=bias, alpha=alpha, beta=beta)
    yy = relay.ir_pass.infer_type(z)
    assert yy.checked_type == relay.TensorType(shape, dtype)

    func = relay.Function([x], z)
    x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype)
    ref_res = topi.testing.lrn_python(x_data, size, axis, bias, alpha, beta)

    for target, ctx in ctx_list():
        intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
        intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
        op_res1 = intrp1.evaluate(func)(x_data)
        tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5)
        op_res2 = intrp2.evaluate(func)(x_data)
        tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
开发者ID:bddppq,项目名称:tvm,代码行数:30,代码来源:test_op_level2.py
注：本文中的tvm.var函数示例由纯净天空整理自GitHub/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论