本文整理汇总了 Python 中 tvm.get_global_func 函数的典型用法代码示例。如果您正苦于以下问题：Python get_global_func 函数的具体用法？Python get_global_func 怎么用？Python get_global_func 使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了 get_global_func 函数的 20 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的 Python 代码示例。
示例1: verify
def verify(target="llvm",
           algorithm=nnpack.ConvolutionAlgorithm.AUTO,
           with_bias=True):
    """Run the NNPACK convolution with pre-transformed weights and compare with NumPy.

    Returns early (skipping the check) when ``target`` is not enabled, when the
    NNPACK extern function is not registered, or when ``nnpack.is_available()``
    is false.

    NOTE(review): ``data``, ``kernel``, ``bias``, the ``*shape`` values, ``PAD``,
    ``STRIDE``, ``BATCH``/``IC``/``IH``/``IW`` and ``np_conv`` are not defined in
    this excerpt; presumably they come from the enclosing test -- confirm.

    Parameters
    ----------
    target : str
        Build target handed to ``tvm.build``.
    algorithm : nnpack.ConvolutionAlgorithm
        Algorithm used for both the weight transform and the convolution.
    with_bias : bool
        When false, ``None`` is passed as the convolution bias and a zero
        array is supplied for the bias argument at run time.
    """
    if not tvm.module.enabled(target):
        print("skip because %s is not enabled..." % target)
        return
    extern_fn = tvm.get_global_func(
        "tvm.contrib.nnpack.fully_connected_inference", True)
    if not extern_fn:
        print("skip because extern function is not available")
        return
    if not nnpack.is_available():
        return

    ctx = tvm.cpu(0)
    transformed_kernel = nnpack.convolution_inference_weight_transform(
        kernel, algorithm=algorithm)
    conv_bias = bias if with_bias else None
    output = nnpack.convolution_inference_without_weight_transform(
        data,
        transformed_kernel,
        conv_bias,
        [PAD, PAD, PAD, PAD],
        [STRIDE, STRIDE],
        algorithm=algorithm)
    sched = tvm.create_schedule(output.op)
    func = tvm.build(sched, [data, kernel, bias, output], target)

    # Host-side inputs; the bias placeholder is always fed, zeros when unused.
    na = np.random.uniform(size=dshape).astype(data.dtype)
    nb = np.random.uniform(size=kshape).astype(kernel.dtype)
    if with_bias:
        nc = np.random.uniform(size=bshape).astype(bias.dtype)
    else:
        nc = np.zeros(bshape, dtype=bias.dtype)

    ta, tb, tc = (tvm.nd.array(host, ctx) for host in (na, nb, nc))
    td = tvm.nd.array(np.zeros(oshape, dtype=output.dtype), ctx)
    func(ta, tb, tc, td)

    expected = np_conv(np.reshape(na, (BATCH, IC, IH, IW)), nb, PAD, STRIDE)
    expected = expected + nc.reshape(1, bshape[0], 1, 1)
    tvm.testing.assert_allclose(
        td.asnumpy(), expected.reshape(BATCH, IC, IH, IW), rtol=1e-5)
开发者ID:LANHUIYING,项目名称:tvm,代码行数:35,代码来源:test_nnpack.py
示例2: check
def check(factor):
    """Tensorize ``z`` with ``intrin_vadd(factor)`` and check the inferred region and body.

    NOTE(review): ``x``, ``y``, ``z`` and ``intrin_vadd`` are not defined in
    this excerpt; presumably they come from the enclosing test -- confirm.

    Parameters
    ----------
    factor : int
        Split factor for the first axis of ``z``; also passed to ``intrin_vadd``.
    """
    sched = tvm.create_schedule(z.op)
    outer, inner = sched[z].split(z.op.axis[0], factor=factor)
    intrin = intrin_vadd(factor)
    sched[z].tensorize(inner, intrin)
    sched = sched.normalize()
    bounds = tvm.schedule.InferBound(sched)

    infer_region = tvm.get_global_func("test.op.InferTensorizeRegion")
    out_dom, in_dom = infer_region(sched[z], bounds)
    equal = tvm.ir_pass.Equal
    assert equal(out_dom[z.op.axis[0]].extent, factor)
    assert equal(out_dom[z.op.axis[0]].min, outer * factor)
    assert equal(in_dom.items()[0][1][0].extent, factor)

    match_body = tvm.get_global_func("test.op.MatchTensorizeBody")
    body = match_body(sched[z], out_dom, in_dom, intrin)
    simplify = tvm.ir_pass.CanonicalSimplify
    assert equal(simplify(body[0]), simplify(intrin.op.body[0]))

    # The results below are not inspected; the calls only need to succeed.
    stmt = tvm.schedule.ScheduleOps(sched, bounds)
    tvm.lower(sched, [x, y, z])
开发者ID:bddppq,项目名称:tvm,代码行数:18,代码来源:test_schedule_tensorize.py
示例3: test_conv2d
def test_conv2d():
    """Compare MIOpen ``conv2d_forward`` on ROCm with topi's ``conv2d_nchw``.

    Returns early (skipping the check) when the ``rocm`` target is not enabled
    or the MIOpen setup function is not registered.
    """
    in_channel, out_channel = 3, 64
    filter_h = filter_w = 3
    pad_h = pad_w = 1
    stride_h = stride_w = 1
    dilation_h = dilation_w = 1
    xshape = [1, in_channel, 128, 128]

    if not tvm.module.enabled("rocm"):
        print("skip because rocm is not enabled...")
        return
    miopen_setup = tvm.get_global_func("tvm.contrib.miopen.conv2d.setup", True)
    if not miopen_setup:
        print("skip because miopen is not enabled...")
        return

    wshape = (out_channel, in_channel, filter_h, filter_w)
    X = tvm.placeholder(xshape, name='X')
    W = tvm.placeholder(wshape, name='W')
    Y = miopen.conv2d_forward(
        X, W,
        stride_h, stride_w,
        pad_h, pad_w,
        dilation_h, dilation_w,
        conv_mode=0)
    yshape = [dim.value for dim in Y.shape]

    # Imported only after the skip checks above have passed.
    import topi
    # NOTE(review): only the schedule creation is placed under the target
    # scope; the original indentation was lost -- confirm.
    with tvm.target.create("rocm -libs=miopen"):
        s = topi.generic.schedule_extern(Y)

    def verify():
        ctx = tvm.rocm(0)

        def random_array(shape):
            # float32 values drawn uniformly from [-1, 1), placed on ``ctx``.
            host = np.random.uniform(-1, 1, shape).astype(np.float32)
            return tvm.nd.array(host, ctx)

        f = tvm.build(s, [X, W, Y], "rocm", target_host="llvm", name="conv2d")
        x = random_array(xshape)
        w = random_array(wshape)
        y = random_array(yshape)
        f(x, w, y)

        # Reference result from the topi implementation of the same convolution.
        Y_ref = topi.nn.conv2d_nchw(X, W, (stride_h, stride_w), (pad_h, pad_w))
        with tvm.target.rocm():
            s_ref = topi.generic.schedule_conv2d_nchw([Y_ref])
        f_ref = tvm.build(s_ref, [X, W, Y_ref], "rocm")
        y_ref = random_array(yshape)
        f_ref(x, w, y_ref)

        print("Max abs diff:", np.max(np.abs(y.asnumpy() - y_ref.asnumpy())))
        tvm.testing.assert_allclose(y.asnumpy(), y_ref.asnumpy(), atol=1e-3)

    verify()
开发者ID:LANHUIYING,项目名称:tvm,代码行数:56,代码来源:test_miopen.py
示例4: stats
def stats():
    """Get the current profiler statistics.

    Calls the ``vta.simulator.profiler_status`` global function and decodes
    its result with ``json.loads``.

    Returns
    -------
    stats : dict
        Current profiler statistics
    """
    profiler_status = tvm.get_global_func("vta.simulator.profiler_status")
    return json.loads(profiler_status())
开发者ID:LANHUIYING,项目名称:tvm,代码行数:10,代码来源:simulator.py
示例5: verify
def verify(A, B, C, target="llvm"):
    """Build schedule ``s1`` for Metal and run it once on random inputs.

    Returns early when the ``tvm.contrib.mps.conv2d`` extern function is not
    registered. The result is not compared against a reference.

    NOTE(review): ``s1``, ``n``, ``h``, ``w``, ``ci``, ``co``, ``kh``, ``kw``
    and ``stride`` are not defined in this excerpt; presumably they come from
    the enclosing test -- confirm.

    Parameters
    ----------
    A, B, C
        Arguments of the built function; their ``dtype`` is used for the arrays.
    target : str
        Not used by the body; the build target is fixed to ``"metal"``.
    """
    mps_conv2d = tvm.get_global_func("tvm.contrib.mps.conv2d", True)
    if not mps_conv2d:
        print("skip because extern function is not available")
        return

    ctx = tvm.metal(0)
    func = tvm.build(s1, [A, B, C], "metal")
    data_np = np.random.uniform(size=(n, h, w, ci)).astype(A.dtype)
    a = tvm.nd.array(data_np, ctx)
    weight_np = np.random.uniform(size=(co, kh, kw, ci)).astype(B.dtype)
    b = tvm.nd.array(weight_np, ctx)
    out_shape = (n, h // stride, w // stride, co)
    c = tvm.nd.array(np.zeros(out_shape, dtype=C.dtype), ctx)
    func(a, b, c)
开发者ID:LANHUIYING,项目名称:tvm,代码行数:10,代码来源:test_mps.py
示例6: test_get_global
def test_get_global():
    """Register a packed function, fetch it by name and check the round trip."""
    expected_args = (10, 10.0, "hello")

    # Registered in the global function table under its own name.
    @tvm.register_func
    def my_packed_func(*args):
        assert tuple(args) == expected_args
        return 10

    # Look the function up again through the global table and call it.
    fetched = tvm.get_global_func("my_packed_func")
    assert isinstance(fetched, tvm.Function)
    result = fetched(*expected_args)
    assert result == 10
开发者ID:bddppq,项目名称:tvm,代码行数:12,代码来源:test_runtime_packed_func.py
示例7: test_conv2d
def test_conv2d():
    """Build and run cuDNN ``conv2d_forward`` once on random CUDA inputs.

    Returns early (skipping the run) when the ``cuda`` target is not enabled
    or the cuDNN output-shape function is not registered. The output is not
    compared against a reference.
    """
    in_channel, out_channel = 3, 32
    filter_h = filter_w = 3
    pad_h = pad_w = 1
    stride_h = stride_w = 1
    dilation_h = dilation_w = 1
    xshape = [4, 3, 32, 32]

    if not tvm.module.enabled("cuda"):
        print("skip because cuda is not enabled...")
        return
    cudnn_shape_fn = tvm.get_global_func(
        "tvm.contrib.cudnn.conv2d.output_shape", True)
    if not cudnn_shape_fn:
        print("skip because cudnn is not enabled...")
        return

    wshape = cudnn.conv2d_w_shape(in_channel, out_channel, filter_h, filter_w)
    X = tvm.placeholder(xshape, name='X')
    W = tvm.placeholder(wshape, name='W')
    Y = cudnn.conv2d_forward(
        X, W,
        stride_h, stride_w,
        pad_h, pad_w,
        dilation_h, dilation_w,
        conv_mode=1,
        tensor_format=0,
        algo=1)
    yshape = [dim.value for dim in Y.shape]
    s = tvm.create_schedule(Y.op)

    def verify():
        ctx = tvm.gpu(0)

        def random_array(shape):
            # float32 values drawn uniformly from [-1, 1), placed on ``ctx``.
            host = np.random.uniform(-1, 1, shape).astype(np.float32)
            return tvm.nd.array(host, ctx)

        f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv2d")
        x = random_array(xshape)
        w = random_array(wshape)
        y = random_array(yshape)
        f(x, w, y)

    verify()
开发者ID:bddppq,项目名称:tvm,代码行数:52,代码来源:test_cudnn.py
示例8: check_rfactor_no_reset_multi_reduction
def check_rfactor_no_reset_multi_reduction(factor, rfactor):
    """Tensorize ``C`` with a no-reset gemv intrinsic under a twice-split reduction.

    NOTE(review): ``A``, ``B``, ``C`` and ``intrin_gemv_no_reset`` are not
    defined in this excerpt; presumably they come from the enclosing test --
    confirm.

    Parameters
    ----------
    factor : int
        Split factor for the second axis of ``C``.
    rfactor : int
        Split factor for the reduction axis of ``C``.
    """
    sched = tvm.create_schedule(C.op)
    x, y = C.op.axis
    rk = C.op.reduce_axis[0]
    stage = sched[C]
    yo, yi = stage.split(y, factor=factor)
    ro, ri = stage.split(rk, factor=rfactor)
    roo, roi = stage.split(ro, factor=2)
    stage.reorder(yo, roo, roi, yi, ri)
    gemv = intrin_gemv_no_reset(factor, rfactor)
    stage.tensorize(yi, gemv)
    sched = sched.normalize()
    bounds = tvm.schedule.InferBound(sched)

    infer_region = tvm.get_global_func("test.op.InferTensorizeRegion")
    out_dom, in_dom = infer_region(sched[C], bounds)
    equal = tvm.ir_pass.Equal
    assert equal(out_dom[x].extent, 1)
    assert equal(out_dom[y].extent, factor)
    assert equal(out_dom[y].min, yo * factor)

    match_body = tvm.get_global_func("test.op.MatchTensorizeBody")
    body = match_body(sched[C], out_dom, in_dom, gemv)
    simplify = tvm.ir_pass.CanonicalSimplify
    assert equal(simplify(body[0]), simplify(gemv.op.body[0]))

    # The results below are not inspected; the calls only need to succeed.
    stmt = tvm.schedule.ScheduleOps(sched, bounds)
    tvm.lower(sched, [A, B, C])
开发者ID:bddppq,项目名称:tvm,代码行数:23,代码来源:test_schedule_tensorize.py
示例9: verify
def verify(target="llvm"):
    """Fill an array through schedule ``s`` and check its sample mean and std.

    Returns early when ``target`` is not enabled or the
    ``tvm.contrib.random.normal`` extern function is not registered.

    NOTE(review): ``s``, ``A``, ``m`` and ``n`` are not defined in this
    excerpt; presumably they come from the enclosing test -- confirm.

    Parameters
    ----------
    target : str
        Build target handed to ``tvm.build``.
    """
    if not tvm.module.enabled(target):
        print("skip because %s is not enabled..." % target)
        return
    normal_fn = tvm.get_global_func("tvm.contrib.random.normal", True)
    if not normal_fn:
        print("skip because extern function is not available")
        return

    ctx = tvm.cpu(0)
    func = tvm.build(s, [A], target)
    out = tvm.nd.array(np.zeros((m, n), dtype=A.dtype), ctx)
    func(out)

    samples = out.asnumpy()
    # Mean must be within 1e-2 of 3 and standard deviation within 1e-2 of 4.
    mean_error = abs(np.mean(samples) - 3)
    assert mean_error < 1e-2
    std_error = abs(np.std(samples) - 4)
    assert std_error < 1e-2
开发者ID:LANHUIYING,项目名称:tvm,代码行数:14,代码来源:test_random.py
示例10: verify
def verify(target="rocm"):
    """Run the built matmul on ROCm and compare the result with ``np.dot``.

    Returns early when ``target`` is not enabled or the
    ``tvm.contrib.rocblas.matmul`` extern function is not registered.

    NOTE(review): ``s``, ``A``, ``B``, ``C``, ``n``, ``l`` and ``m`` are not
    defined in this excerpt; presumably they come from the enclosing test --
    confirm.

    Parameters
    ----------
    target : str
        Build target handed to ``tvm.build``.
    """
    if not tvm.module.enabled(target):
        print("skip because %s is not enabled..." % target)
        return
    matmul_fn = tvm.get_global_func("tvm.contrib.rocblas.matmul", True)
    if not matmul_fn:
        print("skip because extern function is not available")
        return

    ctx = tvm.rocm(0)
    func = tvm.build(s, [A, B, C], target)
    lhs = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
    rhs = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
    out = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
    func(lhs, rhs, out)

    actual = out.asnumpy()
    expected = np.dot(lhs.asnumpy(), rhs.asnumpy())
    tvm.testing.assert_allclose(actual, expected, rtol=1e-5)
开发者ID:LANHUIYING,项目名称:tvm,代码行数:15,代码来源:test_rocblas.py
示例11: verify
def verify(target="llvm"):
    """Run the built matmul-plus-bias on CPU and compare with ``np.dot(a, b) + bias``.

    Returns early when ``target`` is not enabled or the
    ``tvm.contrib.cblas.matmul`` extern function is not registered.

    NOTE(review): ``s``, ``A``, ``B``, ``D``, ``bias``, ``n``, ``l`` and ``m``
    are not defined in this excerpt; presumably they come from the enclosing
    test -- confirm.

    Parameters
    ----------
    target : str
        Build target handed to ``tvm.build``.
    """
    if not tvm.module.enabled(target):
        print("skip because %s is not enabled..." % target)
        return
    matmul_fn = tvm.get_global_func("tvm.contrib.cblas.matmul", True)
    if not matmul_fn:
        print("skip because extern function is not avalable")
        return

    ctx = tvm.cpu(0)
    func = tvm.build(s, [A, B, D, bias], target)
    lhs = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
    rhs = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
    out = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
    bias_value = 10.0  # scalar passed for the ``bias`` argument
    func(lhs, rhs, out, bias_value)

    actual = out.asnumpy()
    expected = np.dot(lhs.asnumpy(), rhs.asnumpy()) + bias_value
    np.testing.assert_allclose(actual, expected, rtol=1e-5)
开发者ID:gwli,项目名称:tvm,代码行数:16,代码来源:test_cblas.py
示例12: test_get_callback_with_node
def test_get_callback_with_node():
    """Pass a node and a callback through a registered global function and back."""
    node = tvm.convert(10)

    def passthrough(y):
        # The received object must carry a different handle than ``node``.
        assert y.handle != node.handle
        return y

    callback = tvm.convert(passthrough)

    # Registered in the global function table under its own name.
    @tvm.register_func
    def my_callback_with_node(y, f):
        assert y == node
        return f(y)

    # Look the function up again through the global table and call it.
    fetched = tvm.get_global_func("my_callback_with_node")
    assert isinstance(fetched, tvm.Function)
    result = fetched(node, callback)
    assert result.value == 10
开发者ID:bddppq,项目名称:tvm,代码行数:18,代码来源:test_runtime_packed_func.py
示例13: verify
def verify(target="llvm"):
    """Run the built NNPACK function on CPU and compare with ``np_conv``.

    Returns early when ``target`` is not enabled or the
    ``tvm.contrib.nnpack.fully_connected_inference`` extern function is not
    registered.

    NOTE(review): ``s``, ``data``, ``kernel``, ``bias``, ``output``, the
    ``*shape`` values, ``PAD`` and ``np_conv`` are not defined in this excerpt;
    presumably they come from the enclosing test -- confirm.

    Parameters
    ----------
    target : str
        Build target handed to ``tvm.build``.
    """
    if not tvm.module.enabled(target):
        print("skip because %s is not enabled..." % target)
        return
    extern_fn = tvm.get_global_func(
        "tvm.contrib.nnpack.fully_connected_inference", True)
    if not extern_fn:
        print("skip because extern function is not avalable")
        return

    ctx = tvm.cpu(0)
    func = tvm.build(s, [data, kernel, bias, output], target)

    # Random data and kernel; the bias is all zeros.
    na = np.random.uniform(size=dshape).astype(data.dtype)
    nb = np.random.uniform(size=kshape).astype(kernel.dtype)
    nc = np.zeros(bshape, dtype=bias.dtype)
    ta, tb, tc = (tvm.nd.array(host, ctx) for host in (na, nb, nc))
    td = tvm.nd.array(np.zeros(oshape, dtype=output.dtype), ctx)
    func(ta, tb, tc, td)

    expected = np_conv(na, nb, PAD)
    np.testing.assert_allclose(td.asnumpy(), expected, rtol=1e-5)
开发者ID:gwli,项目名称:tvm,代码行数:21,代码来源:test_nnpack.py
示例14: test_env_func
def test_env_func():
    """Check that an env func survives JSON round trips, alone and as a node attribute."""
    @tvm.register_func("test.env_func")
    def test(x):
        return x + 1

    # The lookup result is not used further; the call itself must succeed.
    f = tvm.get_global_func("test.env_func")

    env_func = tvm.get_env_func("test.env_func")
    assert env_func.name == "test.env_func"

    # Round trip of the env func on its own.
    serialized = tvm.save_json([env_func])
    restored = tvm.load_json(serialized)[0]
    assert restored.name == env_func.name
    assert restored(1) == 2
    assert restored.func(1) == 2

    # Round trip of a node that holds the env func as an attribute.
    attrs = tvm.make.node(
        "attrs.TestAttrs", name="xx", padding=(3,4), func=restored)
    assert attrs.name == "xx"
    assert attrs.padding[0].value == 3
    assert attrs.padding[1].value == 4
    assert attrs.axis == 10
    reloaded = tvm.load_json(tvm.save_json(attrs))
    assert isinstance(reloaded.func, tvm.container.EnvFunc)
    assert reloaded.func(10) == 11
开发者ID:bddppq,项目名称:tvm,代码行数:22,代码来源:test_lang_reflection.py
示例15: save_tensors
def save_tensors(params):
    """Serialize a parameter dictionary to binary bytes.

    The resulting bytes can be loaded by the GraphModule with API
    "load_params".

    Parameters
    ----------
    params : dict of str to NDArray
        The parameter dictionary.

    Returns
    -------
    param_bytes: bytearray
        Serialized parameters.
    """
    serialize = tvm.get_global_func("_save_param_dict")
    # The packed function takes a flat argument list: name, array, name, array, ...
    flat_args = []
    for name, value in params.items():
        flat_args.extend((name, tvm.nd.array(value)))
    return serialize(*flat_args)
开发者ID:LANHUIYING,项目名称:tvm,代码行数:23,代码来源:debug_result.py
示例16: save_param_dict
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Helper utility to save parameter dict"""
import tvm
_save_param_dict = tvm.get_global_func("nnvm.compiler._save_param_dict")
_load_param_dict = tvm.get_global_func("nnvm.compiler._load_param_dict")
def save_param_dict(params):
"""Save parameter dictionary to binary bytes.
The result binary bytes can be loaded by the
GraphModule with API "load_params".
Parameters
----------
params : dict of str to NDArray
The parameter dictionary.
Returns
-------
开发者ID:bddppq,项目名称:tvm,代码行数:31,代码来源:param_dict.py
示例17: GraphKey
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Compiler engine interface to internal engine
You can get the engine singleton at ``nnvm.compiler.engine``
"""
import tvm
_list_cache_items = tvm.get_global_func("nnvm.compiler.ListCacheItems")
_clear_cache = tvm.get_global_func("nnvm.compiler.ClearCache")
_get_cache_item = tvm.get_global_func("nnvm.compiler.GetCacheItem")
_set_cache_item = tvm.get_global_func("nnvm.compiler.SetCacheItem")
_graph_key_get_graph = tvm.get_global_func("nnvm.compiler.GraphKeyGetGraph")
_make_graph_key = tvm.get_global_func("nnvm.compiler.MakeGraphKey")
@tvm.register_node
class GraphKey(tvm.node.NodeBase):
    """Key of a graph compilation context."""

    @property
    def graph(self):
        """Result of calling ``_graph_key_get_graph`` on this key."""
        key_graph = _graph_key_get_graph(self)
        return key_graph
@tvm.register_node
开发者ID:bddppq,项目名称:tvm,代码行数:31,代码来源:compile_engine.py
示例18: check_graph_equal
out_dtype: list of tuple
Dtype of outputs
"""
graph = graph_attr.set_dtype_inputs(graph, dtype)
graph = graph.apply("InferType")
dtype = graph.json_attr("dtype")
index = graph.index
input_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
for x in index.input_names]
output_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
for x in index.output_entries]
return input_dtype, output_dtype
_deep_compare = tvm.get_global_func("nnvm.graph.DeepCompare")
def check_graph_equal(grapha, graphb, compare_variable_attrs=False):
"""Check if two graphs have equal structure.
Parameters
----------
grapha : Graph
The first graph
graphb : Graph
The second graph
compare_variable_attrs : bool, optional
Whether we want to compare attributes(names) on variables.
Usually it is safe to skip it unless we want input name
开发者ID:bddppq,项目名称:tvm,代码行数:30,代码来源:graph_util.py
示例19: AttrDict
# pylint: disable=invalid-name
"""Attr dictionary object used by schedule functions"""
import tvm
_dict_get = tvm.get_global_func("nnvm.compiler._dict_get")
_dict_size = tvm.get_global_func("nnvm.compiler._dict_size")
_dict_keys = tvm.get_global_func("nnvm.compiler._dict_keys")
class AttrDict(object):
"""Attribute dictionary in nnvm.
Used by python registration of compute and schedule function.
AttrDict is passed as the first argument to schedule and compute function.
"""
_tvm_tcode = 18
def __init__(self, handle):
self.handle = handle
def __del__(self):
tvm.nd.free_extension_handle(self.handle, 18)
@property
def _tvm_handle(self):
return self.handle.value
def __getitem__(self, key):
return _dict_get(self, key)
def keys(self):
"""Get list of keys in the dict.
开发者ID:masa-ito-fj,项目名称:nnvm,代码行数:31,代码来源:attr_dict.py
示例20: program_fpga
def program_fpga(file_name):
    """Download the bitstream named ``file_name`` and log the action.

    The name is first mapped to a path by the ``tvm.rpc.server.workpath``
    global function; that path is used to construct ``Bitstream``.

    Parameters
    ----------
    file_name : str
        Bitstream file name passed to the workpath function.
    """
    resolve_path = tvm.get_global_func("tvm.rpc.server.workpath")
    Bitstream(resolve_path(file_name)).download()
    logging.info("Program FPGA with %s", file_name)
开发者ID:LANHUIYING,项目名称:tvm,代码行数:5,代码来源:rpc_server.py
注：本文中的 tvm.get_global_func 函数示例由纯净天空整理自 GitHub/MSDocs 等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的 License；未经允许，请勿转载。
请发表评论