本文整理汇总了Golang中github.com/gonum/blas/blas64.Gemm函数的典型用法代码示例。如果您正苦于以下问题:Golang Gemm函数的具体用法?Golang Gemm怎么用?Golang Gemm使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了Gemm函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Golang代码示例。
示例1: constructH
// constructH forms the explicit m×m matrix H as the product of k elementary
// reflectors H_i = I - tau[i] * v_i * v_i^T.
//
// When store is lapack.ColumnWise the reflector vectors are the columns of v
// (m = v.Rows, k = v.Cols); for any other store value they are read as the
// rows of v (m = v.Cols, k = v.Rows). When direct is lapack.Forward each
// reflector is applied on the right (H = H * H_i); otherwise it is applied on
// the left (H = H_i * H).
//
// Elements of v are addressed through v.Stride, so v may carry row padding
// (Stride > Cols).
func constructH(tau []float64, v blas64.General, store lapack.StoreV, direct lapack.Direct) blas64.General {
	m := v.Rows
	k := v.Cols
	if store == lapack.RowWise {
		m, k = k, m
	}

	// H starts out as the identity.
	h := blas64.General{
		Rows:   m,
		Cols:   m,
		Stride: m,
		Data:   make([]float64, m*m),
	}
	for i := 0; i < m; i++ {
		h.Data[i*m+i] = 1
	}

	// Scratch storage reused by every iteration.
	vec := blas64.Vector{
		Inc:  1,
		Data: make([]float64, m),
	}
	hi := blas64.General{
		Rows:   m,
		Cols:   m,
		Stride: m,
		Data:   make([]float64, m*m),
	}
	hcopy := blas64.General{
		Rows:   m,
		Cols:   m,
		Stride: m,
		Data:   make([]float64, m*m),
	}

	for i := 0; i < k; i++ {
		// Gather the i-th reflector vector. The row pitch of v is v.Stride,
		// not v.Cols.
		if store == lapack.ColumnWise {
			for j := 0; j < m; j++ {
				vec.Data[j] = v.Data[j*v.Stride+i]
			}
		} else {
			for j := 0; j < m; j++ {
				vec.Data[j] = v.Data[i*v.Stride+j]
			}
		}

		// Reset hi to the identity; Ger accumulates into it.
		for j := range hi.Data {
			hi.Data[j] = 0
		}
		for j := 0; j < m; j++ {
			hi.Data[j*m+j] = 1
		}
		// hi = I - tau * v * v^T
		blas64.Ger(-tau[i], vec, vec, hi)

		// Gemm must not alias its output with an input, so multiply from a copy.
		copy(hcopy.Data, h.Data)
		if direct == lapack.Forward {
			// H = H * H_I in forward mode
			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hcopy, hi, 0, h)
		} else {
			// H = H_I * H in backward mode
			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hi, hcopy, 0, h)
		}
	}
	return h
}
开发者ID:jacobxk,项目名称:lapack,代码行数:60,代码来源:general.go
示例2: convolveR
// convolveR writes into out the sum of up to three terms:
//
//   - inputToMatrix(inR) times the transposed filter matrix (always),
//   - inputToMatrix(in) times the transposed matrix built from v[c.FilterVar],
//     when v has an entry for c.FilterVar,
//   - v[c.Biases] added to every row of the output, when v has an entry for
//     c.Biases.
//
// out.Data is viewed as a (Width*Height)×Depth matrix with row pitch Depth.
func (c *ConvLayer) convolveR(v autofunc.RVector, in, inR linalg.Vector, out *Tensor3) {
	patches := c.inputToMatrix(in)
	patchesR := c.inputToMatrix(inR)

	filters := blas64.General{
		Rows:   c.FilterCount,
		Cols:   patches.Cols,
		Stride: patches.Stride,
		Data:   c.FilterVar.Vector,
	}
	result := blas64.General{
		Rows:   out.Width * out.Height,
		Cols:   out.Depth,
		Stride: out.Depth,
		Data:   out.Data,
	}

	// beta = 0: this product overwrites whatever out held before.
	blas64.Gemm(blas.NoTrans, blas.Trans, 1, patchesR, filters, 0, result)

	if filterRV, ok := v[c.FilterVar]; ok {
		// Same shape as the filter matrix, different backing data.
		filtersR := filters
		filtersR.Data = filterRV
		// beta = 1: accumulate on top of the first product.
		blas64.Gemm(blas.NoTrans, blas.Trans, 1, patches, filtersR, 1, result)
	}

	if biasRV, ok := v[c.Biases]; ok {
		bias := blas64.Vector{Inc: 1, Data: biasRV}
		rowLen := result.Cols
		for start := 0; start < len(out.Data); start += rowLen {
			row := out.Data[start : start+rowLen]
			blas64.Axpy(len(row), 1, bias, blas64.Vector{Inc: 1, Data: row})
		}
	}
}
开发者ID:unixpickle,项目名称:weakai,代码行数:35,代码来源:conv_layer.go
示例3: propagateSingle
// propagateSingle back-propagates one sample through the layer.
//
// upstream is viewed as an (OutputWidth*OutputHeight)×OutputDepth matrix.
// When downstream is non-nil it receives the data of the Tensor3 built by
// NewTensor3Col from upstream times the filter matrix. When grad has an entry
// for the layer's FilterVar, upstream^T times the input matrix is accumulated
// into that entry.
func (c *convLayerResult) propagateSingle(input, upstream, downstream linalg.Vector,
	grad autofunc.Gradient) {
	layer := c.Layer

	up := blas64.General{
		Rows:   layer.OutputWidth() * layer.OutputHeight(),
		Cols:   layer.OutputDepth(),
		Stride: layer.OutputDepth(),
		Data:   upstream,
	}

	if downstream != nil {
		// The matrix returned by inputToMatrix serves as the destination here:
		// beta = 0, so its contents are overwritten by the product.
		patchGrad := layer.inputToMatrix(input)
		patchLen := layer.FilterWidth * layer.FilterHeight * layer.InputDepth
		filters := blas64.General{
			Rows:   len(layer.Filters),
			Cols:   patchLen,
			Stride: patchLen,
			Data:   layer.FilterVar.Vector,
		}
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, up, filters, 0, patchGrad)

		folded := NewTensor3Col(layer.InputWidth, layer.InputHeight,
			layer.InputDepth, patchGrad.Data, layer.FilterWidth,
			layer.FilterHeight, layer.Stride)
		copy(downstream, folded.Data)
	}

	filterGrad, wantFilterGrad := grad[layer.FilterVar]
	if !wantFilterGrad {
		return
	}
	patches := layer.inputToMatrix(input)
	patchLen := layer.FilterWidth * layer.FilterHeight * layer.InputDepth
	dest := blas64.General{
		Rows:   len(layer.Filters),
		Cols:   patchLen,
		Stride: patchLen,
		Data:   filterGrad,
	}
	// beta = 1: add to the gradient already stored in grad.
	blas64.Gemm(blas.Trans, blas.NoTrans, 1, up, patches, 1, dest)
}
开发者ID:unixpickle,项目名称:weakai,代码行数:35,代码来源:conv_layer.go
示例4: testDlarfx
// testDlarfx checks impl.Dlarfx against an explicitly formed reflector.
//
// It builds H = I - tau*v*v^T, computes the reference H*C (side == blas.Left)
// or C*H (otherwise) with Gemm, runs Dlarfx on C in place, and reports an
// error if C differs from the reference by more than tol or if
// generalOutsideAllNaN reports a write outside the m×n part of C.
func testDlarfx(t *testing.T, impl Dlarfxer, side blas.Side, m, n, extra int, rnd *rand.Rand) {
	const tol = 1e-13

	left := side == blas.Left

	c := randomGeneral(m, n, n+extra, rnd)
	cWant := randomGeneral(m, n, n+extra, rnd)
	tau := rnd.NormFloat64()

	// H acts on the rows of C from the left and on its columns from the right.
	order := n
	if left {
		order = m
	}
	v := randomSlice(order, rnd)
	h := eye(order, order+extra)

	vVec := blas64.Vector{Inc: 1, Data: v}
	blas64.Ger(-tau, vVec, vVec, h)

	switch {
	case left:
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, c, 0, cWant)
	default:
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, c, h, 0, cWant)
	}

	// Allocate work only if H has order > 10; otherwise it stays nil.
	var work []float64
	if h.Rows > 10 {
		workLen := m
		if left {
			workLen = n
		}
		work = make([]float64, workLen)
	}

	impl.Dlarfx(side, m, n, v, tau, c.Data, c.Stride, work)

	prefix := fmt.Sprintf("Case side=%v, m=%v, n=%v, extra=%v", side, m, n, extra)

	// Check any invalid modifications of c.
	if !generalOutsideAllNaN(c) {
		t.Errorf("%v: out-of-range write to C\n%v", prefix, c.Data)
	}
	if !equalApproxGeneral(c, cWant, tol) {
		t.Errorf("%v: unexpected C\n%v", prefix, c.Data)
	}
}
开发者ID:rawlingsj,项目名称:gofabric8,代码行数:48,代码来源:dlarfx.go
示例5: Mul
// Mul takes the matrix product of a and b, placing the result in the receiver.
//
// See the Muler interface for more information.
//
// Mul panics with ErrShape if the number of columns of a differs from the
// number of rows of b. When the receiver is also one of the operands the
// product is computed into a temporary workspace and copied back on return.
func (m *Dense) Mul(a, b Matrix) {
	ar, ac := a.Dims()
	br, bc := b.Dims()
	if ac != br {
		panic(ErrShape)
	}

	m.reuseAs(ar, bc)

	// w is the matrix actually written to. It is m itself unless m aliases an
	// operand, in which case a workspace is used and copied into m on exit.
	var w *Dense
	if m != a && m != b {
		w = m
	} else {
		w = getWorkspace(ar, bc, false)
		defer func() {
			m.Copy(w)
			putWorkspace(w)
		}()
	}

	// Fast path: both operands expose raw storage, so hand off to BLAS.
	if aRaw, ok := a.(RawMatrixer); ok {
		if bRaw, ok := b.(RawMatrixer); ok {
			amat, bmat := aRaw.RawMatrix(), bRaw.RawMatrix()
			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, amat, bmat, 0, w.Mat)
			return
		}
	}

	// Both operands can produce whole rows/columns: one dot product per element.
	if aVec, ok := a.(Vectorer); ok {
		if bVec, ok := b.(Vectorer); ok {
			row := make([]float64, ac)
			col := make([]float64, br)
			for r := 0; r < ar; r++ {
				dataTmp := w.Mat.Data[r*w.Mat.Stride : r*w.Mat.Stride+bc]
				// Row r of a does not depend on c, so extract it once per row
				// rather than once per output element.
				rowVec := blas64.Vector{Inc: 1, Data: aVec.Row(row, r)}
				for c := 0; c < bc; c++ {
					dataTmp[c] = blas64.Dot(ac,
						rowVec,
						blas64.Vector{Inc: 1, Data: bVec.Col(col, c)},
					)
				}
			}
			return
		}
	}

	// Generic fallback through At.
	row := make([]float64, ac)
	for r := 0; r < ar; r++ {
		for i := range row {
			row[i] = a.At(r, i)
		}
		for c := 0; c < bc; c++ {
			var v float64
			for i, e := range row {
				v += e * b.At(i, c)
			}
			w.Mat.Data[r*w.Mat.Stride+c] = v
		}
	}
}
开发者ID:drewlanenga,项目名称:matrix,代码行数:62,代码来源:dense_arithmetic.go
示例6: testDorghr
// testDorghr checks impl.Dorghr against a Q assembled explicitly from the
// reflectors produced by impl.Dgehrd.
//
// A random n×n matrix (row pitch n+extra) is reduced with Dgehrd. The
// reference Q is the product, for j = ilo..ihi-1, of I - tau[j]*v*v^T where
// v[j+1] = 1 and v[j+2..ihi] come from column j of the reduced matrix. Dorghr
// then overwrites the matrix with its own Q, which must be orthonormal, must
// not touch the padding, and must agree with the reference to within tol.
// When optwork is set the workspace size is taken from a query call
// (lwork = -1); otherwise max(1, ihi-ilo) is used.
func testDorghr(t *testing.T, impl Dorghrer, n, ilo, ihi, extra int, optwork bool, rnd *rand.Rand) {
	const tol = 1e-14

	// Construct the matrix A with elementary reflectors and scalar factors tau.
	a := randomGeneral(n, n, n+extra, rnd)
	var tau []float64
	if n > 1 {
		tau = nanSlice(n - 1)
	}
	work := nanSlice(max(1, n)) // Minimum work for Dgehrd.
	impl.Dgehrd(n, ilo, ihi, a.Data, a.Stride, tau, work, len(work))

	// Extract Q for later comparison.
	q := eye(n, n)
	qPrev := cloneGeneral(q)
	for col := ilo; col < ihi; col++ {
		refl := eye(n, n)
		v := blas64.Vector{
			Inc:  1,
			Data: make([]float64, n),
		}
		v.Data[col+1] = 1
		for row := col + 2; row <= ihi; row++ {
			v.Data[row] = a.Data[row*a.Stride+col]
		}
		blas64.Ger(-tau[col], v, v, refl)
		// Multiply from a copy: q is both an input and the output.
		copy(qPrev.Data, q.Data)
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qPrev, refl, 0, q)
	}

	if !optwork {
		work = nanSlice(max(1, ihi-ilo))
	} else {
		// Workspace query: the size is returned in work[0].
		work = nanSlice(1)
		impl.Dorghr(n, ilo, ihi, a.Data, a.Stride, tau, work, -1)
		work = nanSlice(int(work[0]))
	}
	impl.Dorghr(n, ilo, ihi, a.Data, a.Stride, tau, work, len(work))

	prefix := fmt.Sprintf("Case n=%v, ilo=%v, ihi=%v, extra=%v, optwork=%v", n, ilo, ihi, extra, optwork)
	if !generalOutsideAllNaN(a) {
		t.Errorf("%v: out-of-range write to A\n%v", prefix, a.Data)
	}
	if !isOrthonormal(a) {
		t.Errorf("%v: A is not orthogonal\n%v", prefix, a.Data)
	}
	for i := 0; i < n; i++ {
		for j := 0; j < n; j++ {
			got := a.Data[i*a.Stride+j]
			want := q.Data[i*q.Stride+j]
			if math.Abs(got-want) > tol {
				t.Errorf("%v: unexpected value of A[%v,%v]. want %v, got %v", prefix, i, j, want, got)
			}
		}
	}
}
开发者ID:rawlingsj,项目名称:gofabric8,代码行数:56,代码来源:dorghr.go
示例7: QFromQR
// QFromQR extracts the m×m orthonormal matrix Q from a QR decomposition.
//
// The receiver is resized to r×r (r being the row count of the stored
// factorization), set to the identity, and then multiplied on the right, once
// per column k of the factorization, by I - tau[k]*v*v^T. Here v is zero above
// position k, one at position k, and taken from column k of the factorization
// below it.
func (m *Dense) QFromQR(qr *QR) {
	rows, cols := qr.qr.Dims()
	m.reuseAs(rows, rows)

	// Set Q = I.
	for i := 0; i < rows; i++ {
		for j := 0; j < rows; j++ {
			var e float64
			if i == j {
				e = 1
			}
			m.mat.Data[i*m.mat.Stride+j] = e
		}
	}

	// Construct Q from the elementary reflectors.
	refl := blas64.General{
		Rows:   rows,
		Cols:   rows,
		Stride: rows,
		Data:   make([]float64, rows*rows),
	}
	// NOTE(review): this workspace is never handed back within this method;
	// confirm whether it should be released once the loop is done.
	prev := getWorkspace(rows, rows, false)
	v := blas64.Vector{
		Inc:  1,
		Data: make([]float64, rows),
	}
	for k := 0; k < cols; k++ {
		// Set refl = I.
		for idx := range refl.Data {
			refl.Data[idx] = 0
		}
		for d := 0; d < rows; d++ {
			refl.Data[d*rows+d] = 1
		}

		// Set the vector data as the elementary reflector.
		for j := 0; j < k; j++ {
			v.Data[j] = 0
		}
		v.Data[k] = 1
		for j := k + 1; j < rows; j++ {
			v.Data[j] = qr.qr.mat.Data[j*qr.qr.mat.Stride+k]
		}

		// Compute the multiplication matrix.
		blas64.Ger(-qr.tau[k], v, v, refl)

		// Q = Q * refl, multiplying from a copy because the output is m itself.
		prev.Copy(m)
		blas64.Gemm(blas.NoTrans, blas.NoTrans,
			1, prev.mat, refl,
			0, m.mat)
	}
}
开发者ID:yonglehou,项目名称:matrix,代码行数:54,代码来源:qr.go
示例8: QFromLQ
// QFromLQ extracts the n×n orthonormal matrix Q from an LQ decomposition.
//
// The receiver is resized to c×c (c being the column count of the stored
// factorization), set to the identity, and then multiplied on the left, once
// per row k of the factorization, by I - tau[k]*v*v^T. Here v is zero before
// position k, one at position k, and taken from row k of the factorization
// after it.
func (m *Dense) QFromLQ(lq *LQ) {
	rows, cols := lq.lq.Dims()
	m.reuseAs(cols, cols)

	// Set Q = I.
	for i := 0; i < cols; i++ {
		for j := 0; j < cols; j++ {
			var e float64
			if i == j {
				e = 1
			}
			m.mat.Data[i*m.mat.Stride+j] = e
		}
	}

	// Construct Q from the elementary reflectors.
	refl := blas64.General{
		Rows:   cols,
		Cols:   cols,
		Stride: cols,
		Data:   make([]float64, cols*cols),
	}
	// NOTE(review): this workspace is never handed back within this method;
	// confirm whether it should be released once the loop is done.
	prev := getWorkspace(cols, cols, false)
	v := blas64.Vector{
		Inc:  1,
		Data: make([]float64, cols),
	}
	for k := 0; k < rows; k++ {
		// Set refl = I.
		for idx := range refl.Data {
			refl.Data[idx] = 0
		}
		for d := 0; d < cols; d++ {
			refl.Data[d*cols+d] = 1
		}

		// Set the vector data as the elementary reflector.
		for j := 0; j < k; j++ {
			v.Data[j] = 0
		}
		v.Data[k] = 1
		for j := k + 1; j < cols; j++ {
			v.Data[j] = lq.lq.mat.Data[k*lq.lq.mat.Stride+j]
		}

		// Compute the multiplication matrix.
		blas64.Ger(-lq.tau[k], v, v, refl)

		// Q = refl * Q, multiplying from a copy because the output is m itself.
		prev.Copy(m)
		blas64.Gemm(blas.NoTrans, blas.NoTrans,
			1, refl, prev.mat,
			0, m.mat)
	}
}
开发者ID:yonglehou,项目名称:matrix,代码行数:54,代码来源:lq.go
示例9: Eval
// Eval returns a matrix literal.
//
// Both operands are evaluated and converted with AsGeneral, and their product
// is computed by Gemm into a freshly allocated matrix with the dimensions
// reported by m1.Dims.
func (m1 *Mul) Eval() MatrixLiteral {
	// This should be replaced with a call to Eval on each side, and then a type
	// switch to handle the various matrix literals.
	lhsLit := m1.Left.Eval()
	rhsLit := m1.Right.Eval()
	lhs := lhsLit.AsGeneral()
	rhs := rhsLit.AsGeneral()

	rows, cols := m1.Dims()
	var prod blas64.General
	prod.Rows = rows
	prod.Cols = cols
	prod.Stride = cols
	prod.Data = make([]float64, rows*cols)

	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, lhs, rhs, 0, prod)
	return &General{prod}
}
开发者ID:jonlawlor,项目名称:matrixexp,代码行数:21,代码来源:mul.go
示例10: convolve
// convolve writes into out the product of inputToMatrix(in) with the
// transposed filter matrix, then adds c.Biases.Vector to every row.
//
// out.Data is viewed as a (Width*Height)×Depth matrix with row pitch Depth.
func (c *ConvLayer) convolve(in linalg.Vector, out *Tensor3) {
	patches := c.inputToMatrix(in)

	filters := blas64.General{
		Rows:   c.FilterCount,
		Cols:   patches.Cols,
		Stride: patches.Stride,
		Data:   c.FilterVar.Vector,
	}
	result := blas64.General{
		Rows:   out.Width * out.Height,
		Cols:   out.Depth,
		Stride: out.Depth,
		Data:   out.Data,
	}
	// beta = 0: the product overwrites whatever out held before.
	blas64.Gemm(blas.NoTrans, blas.Trans, 1, patches, filters, 0, result)

	bias := blas64.Vector{Inc: 1, Data: c.Biases.Vector}
	rowLen := result.Cols
	for start := 0; start < len(out.Data); start += rowLen {
		row := out.Data[start : start+rowLen]
		blas64.Axpy(len(row), 1, bias, blas64.Vector{Inc: 1, Data: row})
	}
}
开发者ID:unixpickle,项目名称:weakai,代码行数:23,代码来源:conv_layer.go
示例11: dlatrdCheckDecomposition
// dlatrdCheckDecomposition checks that the first nb rows have been successfully
// reduced.
//
// It forms Q^T * aGen * Q and inspects nb of its rows: the last nb rows when
// uplo is blas.Upper, the first nb rows otherwise. In each inspected row the
// diagonal entry must match a, the entry one position off the diagonal (to the
// right for Upper, to the left otherwise) must match the corresponding element
// of e while a holds 1 at that position, the entry on the opposite side of the
// diagonal is not checked, and all remaining entries must be zero. Every
// comparison uses an absolute tolerance of 1e-10. The t and tau arguments are
// not used.
func dlatrdCheckDecomposition(t *testing.T, uplo blas.Uplo, n, nb int, e, tau, a []float64, lda int, aGen, q blas64.General) bool {
	const tol = 1e-10
	differ := func(x, y float64) bool {
		return math.Abs(x-y) > tol
	}
	square := func() blas64.General {
		return blas64.General{
			Rows:   n,
			Cols:   n,
			Stride: n,
			Data:   make([]float64, n*n),
		}
	}

	// Compute Q^T * A * Q.
	qtA := square()
	qtAq := square()
	blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aGen, 0, qtA)
	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qtA, q, 0, qtAq)

	// Compare with T.
	if uplo == blas.Upper {
		for i := n - 1; i >= n-nb; i-- {
			for j := 0; j < n; j++ {
				got := qtAq.Data[i*qtAq.Stride+j]
				switch j - i {
				case 0:
					if differ(got, a[i*lda+j]) {
						return false
					}
				case 1:
					// Superdiagonal: a must hold 1 and the value must equal e[i].
					if differ(a[i*lda+j], 1) || differ(got, e[i]) {
						return false
					}
				case -1:
					// Subdiagonal: not checked.
				default:
					if differ(got, 0) {
						return false
					}
				}
			}
		}
		return true
	}

	for i := 0; i < nb; i++ {
		for j := 0; j < n; j++ {
			got := qtAq.Data[i*qtAq.Stride+j]
			switch j - i {
			case 0:
				if differ(got, a[i*lda+j]) {
					return false
				}
			case 1:
				// Superdiagonal: not checked.
			case -1:
				// Subdiagonal: a must hold 1 and the value must equal e[i-1].
				if differ(a[i*lda+j], 1) || differ(got, e[i-1]) {
					return false
				}
			default:
				if differ(got, 0) {
					return false
				}
			}
		}
	}
	return true
}
开发者ID:rawlingsj,项目名称:gofabric8,代码行数:73,代码来源:dlatrd.go
示例12: DlarfbTest
//.........这里部分代码省略.........
a := make([]float64, ma*lda)
for i := 0; i < ma; i++ {
for j := 0; j < lda; j++ {
a[i*lda+j] = rnd.Float64()
}
}
k := min(ma, na)
// H is always ma x ma
var m, n, rowsWork int
switch {
default:
panic("not implemented")
case side == blas.Left:
m = test.ma
n = test.cdim
rowsWork = n
case side == blas.Right:
m = test.cdim
n = test.ma
rowsWork = m
}
// Use dgeqr2 to find the v vectors
tau := make([]float64, na)
work := make([]float64, na)
impl.Dgeqr2(ma, k, a, lda, tau, work)
// Correct the v vectors based on the direct and store
vMatTmp := extractVMat(ma, na, a, lda, lapack.Forward, lapack.ColumnWise)
vMat := constructVMat(vMatTmp, store, direct)
v := vMat.Data
ldv := vMat.Stride
// Use dlarft to find the t vector
ldt := test.ldt
if ldt == 0 {
ldt = k
}
tm := make([]float64, k*ldt)
impl.Dlarft(direct, store, ma, k, v, ldv, tau, tm, ldt)
// Generate c matrix
ldc := test.ldc
if ldc == 0 {
ldc = n
}
c := make([]float64, m*ldc)
for i := 0; i < m; i++ {
for j := 0; j < ldc; j++ {
c[i*ldc+j] = rnd.Float64()
}
}
cCopy := make([]float64, len(c))
copy(cCopy, c)
ldwork := k
work = make([]float64, rowsWork*k)
// Call Dlarfb with this information
impl.Dlarfb(side, trans, direct, store, m, n, k, v, ldv, tm, ldt, c, ldc, work, ldwork)
h := constructH(tau, vMat, store, direct)
cMat := blas64.General{
Rows: m,
Cols: n,
Stride: ldc,
Data: make([]float64, m*ldc),
}
copy(cMat.Data, cCopy)
ans := blas64.General{
Rows: m,
Cols: n,
Stride: ldc,
Data: make([]float64, m*ldc),
}
copy(ans.Data, cMat.Data)
switch {
default:
panic("not implemented")
case side == blas.Left && trans == blas.NoTrans:
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, cMat, 0, ans)
case side == blas.Left && trans == blas.Trans:
blas64.Gemm(blas.Trans, blas.NoTrans, 1, h, cMat, 0, ans)
case side == blas.Right && trans == blas.NoTrans:
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMat, h, 0, ans)
case side == blas.Right && trans == blas.Trans:
blas64.Gemm(blas.NoTrans, blas.Trans, 1, cMat, h, 0, ans)
}
if !floats.EqualApprox(ans.Data, c, 1e-14) {
t.Errorf("Cas %v mismatch. Want %v, got %v.", cas, ans.Data, c)
}
}
}
}
}
}
}
开发者ID:rawlingsj,项目名称:gofabric8,代码行数:101,代码来源:dlarfb.go
示例13: DlarfTest
//.........这里部分代码省略.........
lastr: 0,
lastc: 1,
tau: 2,
},
{
m: 10,
n: 10,
ldc: 10,
incv: 4,
lastv: 6,
lastr: 9,
lastc: 8,
tau: 2,
},
} {
// Construct a random matrix.
c := make([]float64, test.ldc*test.m)
for i := 0; i <= test.lastr; i++ {
for j := 0; j <= test.lastc; j++ {
c[i*test.ldc+j] = rand.Float64()
}
}
cCopy := make([]float64, len(c))
copy(cCopy, c)
cCopy2 := make([]float64, len(c))
copy(cCopy2, c)
// Test with side right.
sz := max(test.m, test.n) // so v works for both right and left side.
v := make([]float64, test.incv*sz+1)
// Fill with nonzero entries up until lastv.
for i := 0; i <= test.lastv; i++ {
v[i*test.incv] = rand.Float64()
}
// Construct h explicitly to compare.
h := make([]float64, test.n*test.n)
for i := 0; i < test.n; i++ {
h[i*test.n+i] = 1
}
hMat := blas64.General{
Rows: test.n,
Cols: test.n,
Stride: test.n,
Data: h,
}
vVec := blas64.Vector{
Inc: test.incv,
Data: v,
}
blas64.Ger(-test.tau, vVec, vVec, hMat)
// Apply multiplication (2nd copy is to avoid aliasing).
cMat := blas64.General{
Rows: test.m,
Cols: test.n,
Stride: test.ldc,
Data: cCopy,
}
cMat2 := blas64.General{
Rows: test.m,
Cols: test.n,
Stride: test.ldc,
Data: cCopy2,
}
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMat2, hMat, 0, cMat)
// cMat now stores the true answer. Compare with the function call.
work := make([]float64, sz)
impl.Dlarf(blas.Right, test.m, test.n, v, test.incv, test.tau, c, test.ldc, work)
if !floats.EqualApprox(c, cMat.Data, 1e-14) {
t.Errorf("Dlarf mismatch right, case %v. Want %v, got %v", i, cMat.Data, c)
}
// Test on the left side.
copy(c, cCopy2)
copy(cCopy, c)
// Construct h.
h = make([]float64, test.m*test.m)
for i := 0; i < test.m; i++ {
h[i*test.m+i] = 1
}
hMat = blas64.General{
Rows: test.m,
Cols: test.m,
Stride: test.m,
Data: h,
}
blas64.Ger(-test.tau, vVec, vVec, hMat)
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hMat, cMat2, 0, cMat)
impl.Dlarf(blas.Left, test.m, test.n, v, test.incv, test.tau, c, test.ldc, work)
if !floats.EqualApprox(c, cMat.Data, 1e-14) {
t.Errorf("Dlarf mismatch left, case %v. Want %v, got %v", i, cMat.Data, c)
}
}
}
开发者ID:RomainVabre,项目名称:origin,代码行数:101,代码来源:dlarf.go
示例14: checkPLU
// checkPLU checks that the PLU factorization contained in factorize matches
// the original matrix contained in original.
//
// ok is the success flag returned by the routine under test; it must be false
// exactly when the factorized matrix has a zero on its diagonal. factorized
// holds L (unit diagonal, strictly lower part) and U (upper part including the
// diagonal) packed into one m×n matrix with row pitch lda, and ipiv holds the
// row interchanges. The product P*L*U is compared element-wise with original
// to within tol. When print is true the mismatch message includes both
// matrices.
func checkPLU(t *testing.T, ok bool, m, n, lda int, ipiv []int, factorized, original []float64, tol float64, print bool) {
	var hasZeroDiagonal bool
	for i := 0; i < min(m, n); i++ {
		if factorized[i*lda+i] == 0 {
			hasZeroDiagonal = true
			break
		}
	}
	if hasZeroDiagonal && ok {
		t.Error("Has a zero diagonal but returned ok")
	}
	if !hasZeroDiagonal && !ok {
		t.Error("Non-zero diagonal but returned !ok")
	}

	// Check that the LU decomposition is correct.
	// Unpack L (m×mn) and U (mn×n) from the combined storage.
	mn := min(m, n)
	l := make([]float64, m*mn)
	ldl := mn
	u := make([]float64, mn*n)
	ldu := n
	for i := 0; i < m; i++ {
		for j := 0; j < n; j++ {
			v := factorized[i*lda+j]
			switch {
			case i == j:
				l[i*ldl+i] = 1
				u[i*ldu+i] = v
			case i > j:
				l[i*ldl+j] = v
			case i < j:
				u[i*ldu+j] = v
			}
		}
	}

	LU := blas64.General{
		Rows:   m,
		Cols:   n,
		Stride: n,
		Data:   make([]float64, m*n),
	}
	U := blas64.General{
		Rows:   mn,
		Cols:   n,
		Stride: ldu,
		Data:   u,
	}
	L := blas64.General{
		Rows:   m,
		Cols:   mn,
		Stride: ldl,
		Data:   l,
	}
	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, L, U, 0, LU)

	// Build P by applying the recorded row swaps, last to first, to an identity.
	p := make([]float64, m*m)
	ldp := m
	for i := 0; i < m; i++ {
		p[i*ldp+i] = 1
	}
	for i := len(ipiv) - 1; i >= 0; i-- {
		v := ipiv[i]
		// Keyed fields keep the literals valid if blas64.Vector gains or
		// reorders fields, and satisfy go vet's composites check.
		blas64.Swap(m,
			blas64.Vector{Inc: 1, Data: p[i*ldp:]},
			blas64.Vector{Inc: 1, Data: p[v*ldp:]})
	}
	P := blas64.General{
		Rows:   m,
		Cols:   m,
		Stride: m,
		Data:   p,
	}

	// Seed the result with factorized so that any row padding (lda > n) is
	// carried over from it; Gemm only writes the m×n part.
	aComp := blas64.General{
		Rows:   m,
		Cols:   n,
		Stride: lda,
		Data:   make([]float64, m*lda),
	}
	copy(aComp.Data, factorized)
	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, P, LU, 0, aComp)
	if !floats.EqualApprox(aComp.Data, original, tol) {
		if print {
			t.Errorf("PLU multiplication does not match original matrix.\nWant: %v\nGot: %v", original, aComp.Data)
			return
		}
		t.Error("PLU multiplication does not match original matrix.")
	}
}
开发者ID:jacobxk,项目名称:lapack,代码行数:89,代码来源:dgetf2.go
示例15: Dgelq2Test
// Dgelq2Test checks impl.Dgelq2 on a table of m×n shapes, each with a packed
// layout (lda == 0 in the table means lda = n) and with lda = 20.
//
// For every case a random matrix is factorized, Q is rebuilt with
// constructQ("LQ", ...), its rows are checked for unit norm and mutual
// orthogonality, and L*Q is compared with the original matrix. All tolerances
// are 1e-14.
func Dgelq2Test(t *testing.T, impl Dgelq2er) {
	// NOTE(review): {1, 12, 0} and {1, 12, 20} each appear twice below; the
	// second occurrence may have been meant as {12, 1, ...} — confirm.
	shapes := []struct {
		m, n, lda int
	}{
		{1, 1, 0},
		{2, 2, 0},
		{3, 2, 0},
		{2, 3, 0},
		{1, 12, 0},
		{2, 6, 0},
		{3, 4, 0},
		{4, 3, 0},
		{6, 2, 0},
		{1, 12, 0},
		{1, 1, 20},
		{2, 2, 20},
		{3, 2, 20},
		{2, 3, 20},
		{1, 12, 20},
		{2, 6, 20},
		{3, 4, 20},
		{4, 3, 20},
		{6, 2, 20},
		{1, 12, 20},
	}
	for c, test := range shapes {
		m, n := test.m, test.n
		lda := test.lda
		if lda == 0 {
			lda = test.n
		}

		// tau and work are filled with random values before the call under test.
		tau := make([]float64, min(m, n))
		for i := range tau {
			tau[i] = rand.Float64()
		}
		work := make([]float64, m)
		for i := range work {
			work[i] = rand.Float64()
		}
		a := make([]float64, m*lda)
		for i := range a {
			a[i] = rand.Float64()
		}
		aCopy := make([]float64, len(a))
		copy(aCopy, a)

		impl.Dgelq2(m, n, a, lda, tau, work)

		Q := constructQ("LQ", m, n, a, lda, tau)

		// Check that Q is orthonormal
		for i := 0; i < Q.Rows; i++ {
			rowI := blas64.Vector{Inc: 1, Data: Q.Data[i*Q.Stride:]}
			nrm := blas64.Nrm2(Q.Cols, rowI)
			if math.Abs(nrm-1) > 1e-14 {
				t.Errorf("Q not normal. Norm is %v", nrm)
			}
			for j := 0; j < i; j++ {
				rowJ := blas64.Vector{Inc: 1, Data: Q.Data[j*Q.Stride:]}
				dot := blas64.Dot(Q.Rows, rowI, rowJ)
				if math.Abs(dot) > 1e-14 {
					t.Errorf("Q not orthogonal. Dot is %v", dot)
				}
			}
		}

		// L is the lower-trapezoidal part of the factorized matrix.
		L := blas64.General{
			Rows:   m,
			Cols:   n,
			Stride: n,
			Data:   make([]float64, m*n),
		}
		for i := 0; i < m; i++ {
			last := min(i, n-1)
			for j := 0; j <= last; j++ {
				L.Data[i*L.Stride+j] = a[i*lda+j]
			}
		}

		// Seed the product with the original matrix so any row padding
		// (lda > n) matches in the full-slice comparison below.
		ans := blas64.General{
			Rows:   m,
			Cols:   n,
			Stride: lda,
			Data:   make([]float64, m*lda),
		}
		copy(ans.Data, aCopy)
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, L, Q, 0, ans)
		if !floats.EqualApprox(aCopy, ans.Data, 1e-14) {
			t.Errorf("Case %v, LQ mismatch. Want %v, got %v.", c, aCopy, ans.Data)
		}
	}
}
开发者ID:RomainVabre,项目名称:origin,代码行数:92,代码来源:dgelq2.go
示例16: testDgebak
// testDgebak checks impl.Dgebak against an explicitly formed product.
//
// It builds a diagonal scaling D (and its inverse) with random entries in
// rows ilo..ihi when job includes scaling, and a permutation P from random
// column swaps outside ilo..ihi when job includes permuting, recording both in
// scale the way Dgebak consumes them. The expected result is P*D*V when side
// is blas.Right and P*D^{-1}*V otherwise. The output of Dgebak on a clone of v
// must match it to within tol and must not touch the row padding of v.
func testDgebak(t *testing.T, impl Dgebaker, job lapack.Job, side blas.Side, ilo, ihi int, v blas64.General, rnd *rand.Rand) {
	const tol = 1e-15

	n := v.Rows
	m := v.Cols
	extra := v.Stride - v.Cols

	// Create D and D^{-1} by generating random scales between ilo and ihi.
	d := eye(n, n)
	dinv := eye(n, n)
	scale := nanSlice(n)
	if job == lapack.Scale || job == lapack.PermuteScale {
		if ilo == ihi {
			scale[ilo] = 1
		} else {
			for i := ilo; i <= ihi; i++ {
				scale[i] = 2 * rnd.Float64()
				d.Data[i*d.Stride+i] = scale[i]
				dinv.Data[i*dinv.Stride+i] = 1 / scale[i]
			}
		}
	}

	// Create P by generating random column swaps.
	p := eye(n, n)
	if job == lapack.Permute || job == lapack.PermuteScale {
		// Make up some random permutations.
		// Columns of p are addressed as vectors with increment p.Stride. Keyed
		// fields keep the literals valid if blas64.Vector changes and satisfy
		// go vet's composites check.
		for i := n - 1; i > ihi; i-- {
			scale[i] = float64(rnd.Intn(i + 1))
			blas64.Swap(n,
				blas64.Vector{Inc: p.Stride, Data: p.Data[i:]},
				blas64.Vector{Inc: p.Stride, Data: p.Data[int(scale[i]):]})
		}
		for i := 0; i < ilo; i++ {
			scale[i] = float64(i + rnd.Intn(ihi-i+1))
			blas64.Swap(n,
				blas64.Vector{Inc: p.Stride, Data: p.Data[i:]},
				blas64.Vector{Inc: p.Stride, Data: p.Data[int(scale[i]):]})
		}
	}

	got := cloneGeneral(v)
	impl.Dgebak(job, side, n, ilo, ihi, scale, m, got.Data, got.Stride)

	prefix := fmt.Sprintf("Case job=%v, side=%v, n=%v, ilo=%v, ihi=%v, m=%v, extra=%v",
		job, side, n, ilo, ihi, m, extra)

	if !generalOutsideAllNaN(got) {
		t.Errorf("%v: out-of-range write to V\n%v", prefix, got.Data)
	}

	// Compute D*V or D^{-1}*V and store into dv.
	dv := zeros(n, m, m)
	if side == blas.Right {
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, d, v, 0, dv)
	} else {
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, dinv, v, 0, dv)
	}
	// Compute P*D*V or P*D^{-1}*V and store into want.
	want := zeros(n, m, m)
	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, p, dv, 0, want)
	if !equalApproxGeneral(want, got, tol) {
		t.Errorf("%v: unexpected value of V", prefix)
	}
}
开发者ID:rawlingsj,项目名称:gofabric8,代码行数:65,代码来源:dgebak.go
示例17: Dgeqr2Test
// Dgeqr2Test checks impl.Dgeqr2 on a table of m×n shapes, each with a packed
// layout (lda == 0 in the table means lda = n) and with lda = 20.
//
// For every case a random matrix is factorized, Q is rebuilt with
// constructQ("QR", ...), its rows are checked for unit norm and mutual
// orthogonality, and Q*R is compared with the original matrix. All tolerances
// are 1e-14.
func Dgeqr2Test(t *testing.T, impl Dgeqr2er) {
	for c, test := range []struct {
		m, n, lda int
	}{
		{1, 1, 0},
		{2, 2, 0},
		{3, 2, 0},
		{2, 3, 0},
		{1, 12, 0},
		{2, 6, 0},
		{3, 4, 0},
		{4, 3, 0},
		{6, 2, 0},
		{12, 1, 0},
		{1, 1, 20},
		{2, 2, 20},
		{3, 2, 20},
		{2, 3, 20},
		{1, 12, 20},
		{2, 6, 20},
		{3, 4, 20},
		{4, 3, 20},
		{6, 2, 20},
		{12, 1, 20},
	} {
		n := test.n
		m := test.m
		lda := test.lda
		if lda == 0 {
			lda = test.n
		}
		a := make([]float64, m*lda)
		for i := range a {
			a[i] = rand.Float64()
		}
		aCopy := make([]float64, len(a))
		k := min(m, n)
		// tau and work are filled with random values before the call under test.
		tau := make([]float64, k)
		for i := range tau {
			tau[i] = rand.Float64()
		}
		work := make([]float64, n)
		for i := range work {
			work[i] = rand.Float64()
		}
		copy(aCopy, a)
		impl.Dgeqr2(m, n, a, lda, tau, work)

		// Test that the QR factorization has completed successfully. Compute
		// Q based on the vectors.
		q := constructQ("QR", m, n, a, lda, tau)

		// Check that q is orthonormal
		for i := 0; i < m; i++ {
			rowI := blas64.Vector{Inc: 1, Data: q.Data[i*m:]}
			nrm := blas64.Nrm2(m, rowI)
			if math.Abs(nrm-1) > 1e-14 {
				t.Errorf("Case %v, q not normal", c)
			}
			for j := 0; j < i; j++ {
				rowJ := blas64.Vector{Inc: 1, Data: q.Data[j*m:]}
				dot := blas64.Dot(m, rowI, rowJ)
				if math.Abs(dot) > 1e-14 {
					// Report the case index c, not the row index i.
					t.Errorf("Case %v, q not orthogonal", c)
				}
			}
		}

		// Check that A = Q * R
		// R is the upper-trapezoidal part of the factorized matrix.
		r := blas64.General{
			Rows:   m,
			Cols:   n,
			Stride: n,
			Data:   make([]float64, m*n),
		}
		for i := 0; i < m; i++ {
			for j := i; j < n; j++ {
				r.Data[i*n+j] = a[i*lda+j]
			}
		}
		// Seed the product with a so that row padding (lda > n), which Gemm
		// leaves untouched, is carried over from a.
		atmp := blas64.General{
			Rows:   m,
			Cols:   n,
			Stride: lda,
			Data:   make([]float64, m*lda),
		}
		copy(atmp.Data, a)
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, r, 0, atmp)
		if !floats.EqualApprox(atmp.Data, aCopy, 1e-14) {
			t.Errorf("Q*R != a")
		}
	}
}
开发者ID:RomainVabre,项目名称:origin,代码行数:90,代码来源:dgeqr2.go
示例18: constructQPBidiagonal
//.........这里部分代码省略.........
}
}
if vect == lapack.ApplyQ {
if m >= n {
for i := 0; i < m; i++ {
for j := 0; j <= min(nb-1, i); j++ {
if i == j {
v.Data[i*ldv+j] = 1
continue
}
v.Data[i*ldv+j] = a[i*lda+j]
}
}
} else {
for i := 1; i < m; i++ {
for j := 0; j <= min(nb-1, i-1); j++ {
if i-1 == j {
v.Data[i*ldv+j] = 1
continue
}
v.Data[i*ldv+j] = a[i*lda+j]
}
}
}
} else {
if m < n {
for i := 0; i < nb; i++ {
for j := i; j < n; j++ {
if i == j {
v.Data[i*ldv+j] = 1
continue
}
v.Data[i*ldv+j] = a[i*lda+j]
}
}
} else {
for i := 0; i < nb; i++ {
for j := i + 1; j < n; j++ {
if j-1 == i {
v.Data[i*ldv+j] = 1
continue
}
v.Data[i*ldv+j] = a[i*lda+j]
}
}
}
}
// The variable name is a computation of Q, but the algorithm is mostly the
// same for computing P (just with different data).
qMat := blas64.General{
Rows: sz,
Cols: sz,
Stride: sz,
Data: make([]float64, sz*sz),
}
hMat := blas64.General{
Rows: sz,
Cols: sz,
Stride: sz,
Data: make([]float64, sz*sz),
}
// set Q to I
for i := 0; i < sz; i++ {
qMat.Data[i*qMat.Stride+i] = 1
}
for i := 0; i < nb; i++ {
qCopy := blas64.General{Rows: qMat.Rows, Cols: qMat.Cols, Stride: qMat.Stride, Data: make([]float64, len(qMat.Data))}
copy(qCopy.Data, qMat.Data)
// Set g and h to I
for i := 0; i < sz; i++ {
for j := 0; j < sz; j++ {
if i == j {
hMat.Data[i*sz+j] = 1
} else {
hMat.Data[i*sz+j] = 0
}
}
}
var vi blas64.Vector
// H -= tauQ[i] * v[i] * v[i]^t
if vect == lapack.ApplyQ {
vi = blas64.Vector{
Inc: v.Stride,
Data: v.Data[i:],
}
} else {
vi = blas64.Vector{
Inc: 1,
Data: v.Data[i*v.Stride:],
}
}
blas64.Ger(-tau[i], vi, vi, hMat)
// Q = Q * G[1]
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qCopy, hMat, 0, qMat)
}
return qMat
}
开发者ID:rawlingsj,项目名称:gofabric8,代码行数:101,代码来源:general.go
示例19: Mul
// Mul takes the matrix product of a and b, placing the result in the receiver.
//
// See the Muler interface for more information.
func (m *Dense) Mul(a, b Matrix) {
ar, ac := a.Dims()
br, bc := b.Dims()
if ac != br {
panic(ErrShape)
}
aU, aTrans := untranspose(a)
bU, bTrans := untranspose(b)
m.reuseAs(ar, bc)
var restore func()
if m == aU {
m, restore = m.isolatedWorkspace(aU)
defer restore()
} else if m == bU {
m, restore = m.isolatedWorkspace(bU)
defer restore()
}
aT := blas.NoTrans
if aTrans {
aT = blas.Trans
}
bT := blas.NoTrans
if bTrans {
bT = blas.Trans
}
// Some of the cases do not have a transpose option, so create
// temporary memory.
// C = A^T * B = (B^T * A)^T
// C^T = B^T * A.
if aU, ok := aU.(RawMatrixer); ok {
amat := aU.RawMatrix()
if bU, ok := bU.(RawMatrixer); ok {
bmat := bU.RawMatrix()
blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
return
}
if bU, ok := bU.(RawSymmetricer); ok {
bmat := bU.RawSymmetric()
if aTrans {
c := getWorkspace(ac, ar, false)
blas64.Symm(blas.Left, 1, bmat, amat, 0, c.mat)
strictCopy(m, c.T())
putWorkspace(c)
return
}
blas64.Symm(blas.Right, 1, bmat, amat, 0, m.mat)
return
}
if bU, ok := bU.(RawTriangular); ok {
// Trmm updates in place, so copy aU first.
bmat := bU.RawTriangular()
if aTrans {
c := getWorkspace(ac, ar, false)
var tmp Dense
tmp.SetRawMatrix(aU.RawMatrix())
c.Copy(&tmp)
bT := blas.Trans
if bTrans {
bT = blas.NoTrans
}
blas64.Trmm(blas.Left, bT, 1, bmat, c.mat)
strictCopy(m, c.T())
putWorkspace(c)
return
}
m.Copy(a)
blas64.Trmm(blas.Right, bT, 1, bmat, m.mat)
return
}
if bU, ok := bU.(*Vector); ok {
bvec := bU.RawVector()
if bTrans {
// {ar,1} x {1,bc}, which is not a vector.
// Instead, construct B as a General.
bmat := blas64.General{
Rows: bc,
Cols: 1,
Stride: bvec.Inc,
Data: bvec.Data,
}
blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
return
}
cvec := blas64.Vector{
Inc: m.mat.Stride,
Data: m.mat.Data,
}
blas64.Gemv(aT, 1, amat, bvec, 0, cvec)
return
}
}
if bU, ok := bU.(RawMatrixer); ok {
bmat := bU.RawMatrix()
if aU, ok := aU.(RawSymmetricer); ok {
//.........这里部分代码省略.........
开发者ID:rwcarlsen,项目名称:cloudlus,代码行数:101,代码来源:dense_arithmetic.go
示例20: constructQ
// constructQ constructs the Q matrix from the result of dgeqrf and dgeqr2
//
// kind selects how the reflectors are stored in a (row pitch lda): for "QR"
// reflector i lies in column i below the diagonal and Q is m×m; for "LQ" it
// lies in row i right of the diagonal and Q is n×n. Each of the min(m, n)
// reflectors I - tau[i]*v*v^T, with an implicit 1 at position i of v, is
// multiplied onto Q from the right for "QR" and from the left for "LQ".
func constructQ(kind string, m, n int, a []float64, lda int, tau []float64) blas64.General {
	isQR := kind == "QR"
	isLQ := kind == "LQ"

	var sz int
	if isQR {
		sz = m
	} else if isLQ {
		sz = n
	}

	// newSquare returns a zeroed sz×sz matrix.
	newSquare := func() blas64.General {
		return blas64.General{
			Rows:   sz,
			Cols:   sz,
			Stride: sz,
			Data:   make([]float64, sz*sz),
		}
	}

	// Q starts as the identity.
	q := newSquare()
	for d := 0; d < sz; d++ {
		q.Data[d*sz+d] = 1
	}
	prev := newSquare()
	refl := newSquare()

	reflectors := min(m, n)
	for i := 0; i < reflectors; i++ {
		// Reset refl to the identity; Ger accumulates into it.
		for idx := range refl.Data {
			refl.Data[idx] = 0
		}
		for d := 0; d < sz; d++ {
			refl.Data[d*sz+d] = 1
		}

		// Freshly allocated, so the entries before position i are already zero.
		v := blas64.Vector{
			Inc:  1,
			Data: make([]float64, sz),
		}
		v.Data[i] = 1
		for j := i + 1; j < sz; j++ {
			if isQR {
				v.Data[j] = a[lda*j+i]
			} else if isLQ {
				v.Data[j] = a[i*lda+j]
			}
		}
		blas64.Ger(-tau[i], v, v, refl)

		// Multiply q by the new reflector, reading from a copy because q is
		// also the output.
		copy(prev.Data, q.Data)
		if isQR {
			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, prev, refl, 0, q)
		} else if isLQ {
			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, refl, prev, 0, q)
		}
	}
	return q
}
开发者ID:jacobxk,项目名称:lapack,代码行数:66,代码来源:general.go
注:本文中的github.com/gonum/blas/blas64.Gemm函数示例整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论