• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Golang blas64.Gemm函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Golang中github.com/gonum/blas/blas64.Gemm函数的典型用法代码示例。如果您正苦于以下问题：Golang Gemm函数的具体用法？Golang Gemm怎么用？Golang Gemm使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了Gemm函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Golang代码示例。

示例1: constructH

// constructH explicitly forms the block reflector H from the elementary
// reflectors stored in v and their scalar factors tau.
//
// When store is lapack.ColumnWise the reflector vectors are the columns of v
// (v is m×k); when store is lapack.RowWise they are the rows of v (v is k×m).
// v is indexed through its Stride, so padded storage is handled. Each
// elementary reflector is H_i = I - tau[i] * v_i * v_iᵀ. For
// direct == lapack.Forward the result is H = H_0 * H_1 * ... * H_{k-1};
// otherwise it is H = H_{k-1} * ... * H_1 * H_0.
//
// The returned matrix is m×m with Stride m.
func constructH(tau []float64, v blas64.General, store lapack.StoreV, direct lapack.Direct) blas64.General {
	m := v.Rows
	k := v.Cols
	if store == lapack.RowWise {
		m, k = k, m
	}
	h := blas64.General{
		Rows:   m,
		Cols:   m,
		Stride: m,
		Data:   make([]float64, m*m),
	}
	for i := 0; i < m; i++ {
		h.Data[i*m+i] = 1
	}

	// Scratch storage reused for every reflector: the reflector vector, the
	// elementary reflector H_i, and a copy of the running product (the
	// product is computed from the copy so Gemm never writes into an input).
	vec := blas64.Vector{
		Inc:  1,
		Data: make([]float64, m),
	}
	hi := blas64.General{
		Rows:   m,
		Cols:   m,
		Stride: m,
		Data:   make([]float64, m*m),
	}
	hcopy := blas64.General{
		Rows:   m,
		Cols:   m,
		Stride: m,
		Data:   make([]float64, m*m),
	}
	for i := 0; i < k; i++ {
		// Extract the i-th reflector vector, honoring v's row stride.
		if store == lapack.ColumnWise {
			for j := 0; j < m; j++ {
				vec.Data[j] = v.Data[j*v.Stride+i]
			}
		} else {
			for j := 0; j < m; j++ {
				vec.Data[j] = v.Data[i*v.Stride+j]
			}
		}

		// hi = I - tau[i] * v * vᵀ
		for j := range hi.Data {
			hi.Data[j] = 0
		}
		for j := 0; j < m; j++ {
			hi.Data[j*m+j] = 1
		}
		blas64.Ger(-tau[i], vec, vec, hi)

		copy(hcopy.Data, h.Data)
		if direct == lapack.Forward {
			// H = H * H_i in forward mode.
			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hcopy, hi, 0, h)
		} else {
			// H = H_i * H in backward mode.
			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hi, hcopy, 0, h)
		}
	}
	return h
}
开发者ID:jacobxk,项目名称:lapack,代码行数:60,代码来源:general.go


示例2: convolveR

// convolveR writes into out the directional derivative of the convolution
// output (judging by the autofunc.RVector argument, the R-operator of
// convolve).
//
// It computes
//
//	outR = inMatR * Fᵀ + inMat * F_Rᵀ + biasR
//
// where inMat/inMatR are the unrolled in/inR, F is the filter matrix, and the
// F_R and biasR terms are only added when v carries a direction for the
// filter variable or the biases respectively. out.Data is overwritten.
func (c *ConvLayer) convolveR(v autofunc.RVector, in, inR linalg.Vector, out *Tensor3) {
	inMat := c.inputToMatrix(in)
	inMatR := c.inputToMatrix(inR)

	// The filter weights are stored densely, one filter per row, so the row
	// stride equals the column count. It must not be taken from inMat, whose
	// stride describes a different buffer.
	filterMat := blas64.General{
		Rows:   c.FilterCount,
		Cols:   inMat.Cols,
		Stride: inMat.Cols,
		Data:   c.FilterVar.Vector,
	}
	outMat := blas64.General{
		Rows:   out.Width * out.Height,
		Cols:   out.Depth,
		Stride: out.Depth,
		Data:   out.Data,
	}
	blas64.Gemm(blas.NoTrans, blas.Trans, 1, inMatR, filterMat, 0, outMat)

	if filterRV, ok := v[c.FilterVar]; ok {
		filterMatR := blas64.General{
			Rows:   c.FilterCount,
			Cols:   inMat.Cols,
			Stride: inMat.Cols,
			Data:   filterRV,
		}
		// Accumulate (beta = 1) on top of the inMatR term.
		blas64.Gemm(blas.NoTrans, blas.Trans, 1, inMat, filterMatR, 1, outMat)
	}

	if biasRV, ok := v[c.Biases]; ok {
		// Add the bias direction to every output row.
		biasVec := blas64.Vector{Inc: 1, Data: biasRV}
		for i := 0; i < len(out.Data); i += outMat.Cols {
			outRow := out.Data[i : i+outMat.Cols]
			outVec := blas64.Vector{Inc: 1, Data: outRow}
			blas64.Axpy(len(outRow), 1, biasVec, outVec)
		}
	}
}
开发者ID:unixpickle,项目名称:weakai,代码行数:35,代码来源:conv_layer.go


示例3: propagateSingle

// propagateSingle back-propagates one sample through the convolution layer.
//
// upstream is the gradient with respect to the layer output, viewed as an
// (OutputWidth*OutputHeight)×OutputDepth row-major matrix. If downstream is
// non-nil it is filled with the gradient with respect to input. If grad has
// an entry for the filter variable, upstreamᵀ * inputMatrix is added to it.
func (c *convLayerResult) propagateSingle(input, upstream, downstream linalg.Vector,
	grad autofunc.Gradient) {
	upstreamMat := blas64.General{
		Rows:   c.Layer.OutputWidth() * c.Layer.OutputHeight(),
		Cols:   c.Layer.OutputDepth(),
		Stride: c.Layer.OutputDepth(),
		Data:   upstream,
	}

	// The filter weights (and their gradient) are dense: one row per filter,
	// FilterWidth*FilterHeight*InputDepth entries per row.
	filterCount := len(c.Layer.Filters)
	filterLen := c.Layer.FilterWidth * c.Layer.FilterHeight * c.Layer.InputDepth

	if downstream != nil {
		// The matrix from inputToMatrix only serves as a correctly shaped
		// destination here: Gemm with beta = 0 overwrites its contents.
		colGrad := c.Layer.inputToMatrix(input)
		weights := blas64.General{
			Rows:   filterCount,
			Cols:   filterLen,
			Stride: filterLen,
			Data:   c.Layer.FilterVar.Vector,
		}
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, upstreamMat, weights, 0, colGrad)

		// Convert the column-form gradient back to a tensor and hand it out.
		inputGrad := NewTensor3Col(c.Layer.InputWidth, c.Layer.InputHeight,
			c.Layer.InputDepth, colGrad.Data, c.Layer.FilterWidth,
			c.Layer.FilterHeight, c.Layer.Stride)
		copy(downstream, inputGrad.Data)
	}

	filterGrad, wanted := grad[c.Layer.FilterVar]
	if !wanted {
		return
	}
	inputMat := c.Layer.inputToMatrix(input)
	gradMat := blas64.General{
		Rows:   filterCount,
		Cols:   filterLen,
		Stride: filterLen,
		Data:   filterGrad,
	}
	// beta = 1: accumulate into the existing filter gradient.
	blas64.Gemm(blas.Trans, blas.NoTrans, 1, upstreamMat, inputMat, 1, gradMat)
}
开发者ID:unixpickle,项目名称:weakai,代码行数:35,代码来源:conv_layer.go


示例4: testDlarfx

// testDlarfx checks impl.Dlarfx against an explicitly formed product.
//
// It builds the reflector H = I - tau*v*vᵀ (of order m for blas.Left, n
// otherwise), computes the expected H*C or C*H with Gemm, then applies
// Dlarfx to C in place and compares the two. Work space is passed only when
// H has order greater than 10.
func testDlarfx(t *testing.T, impl Dlarfxer, side blas.Side, m, n, extra int, rnd *rand.Rand) {
	const tol = 1e-13

	c := randomGeneral(m, n, n+extra, rnd)
	cWant := randomGeneral(m, n, n+extra, rnd)
	tau := rnd.NormFloat64()

	left := side == blas.Left

	// order is the size of H; workLen is the length of work given to Dlarfx.
	order, workLen := n, m
	if left {
		order, workLen = m, n
	}

	v := randomSlice(order, rnd)
	h := eye(order, order+extra)
	vec := blas64.Vector{Inc: 1, Data: v}
	blas64.Ger(-tau, vec, vec, h)
	if left {
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, c, 0, cWant)
	} else {
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, c, h, 0, cWant)
	}

	var work []float64
	if h.Rows > 10 {
		work = make([]float64, workLen)
	}

	impl.Dlarfx(side, m, n, v, tau, c.Data, c.Stride, work)

	prefix := fmt.Sprintf("Case side=%v, m=%v, n=%v, extra=%v", side, m, n, extra)

	// Nothing outside the m×n part of C may be touched.
	if !generalOutsideAllNaN(c) {
		t.Errorf("%v: out-of-range write to C\n%v", prefix, c.Data)
	}
	if !equalApproxGeneral(c, cWant, tol) {
		t.Errorf("%v: unexpected C\n%v", prefix, c.Data)
	}
}
开发者ID:rawlingsj,项目名称:gofabric8,代码行数:48,代码来源:dlarfx.go


示例5: Mul

// Mul takes the matrix product of a and b, placing the result in the receiver.
//
// See the Muler interface for more information.
func (m *Dense) Mul(a, b Matrix) {
	ar, ac := a.Dims()
	br, bc := b.Dims()

	if ac != br {
		panic(ErrShape)
	}

	m.reuseAs(ar, bc)

	// When the receiver is one of the operands, compute into a workspace and
	// copy the result into m on return.
	var w *Dense
	if m != a && m != b {
		w = m
	} else {
		w = getWorkspace(ar, bc, false)
		defer func() {
			m.Copy(w)
			putWorkspace(w)
		}()
	}

	// Fast path: both operands expose raw storage, so use BLAS directly.
	if a, ok := a.(RawMatrixer); ok {
		if b, ok := b.(RawMatrixer); ok {
			amat, bmat := a.RawMatrix(), b.RawMatrix()
			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, amat, bmat, 0, w.Mat)
			return
		}
	}

	// Both operands can extract rows/columns: each entry is a dot product.
	// Row r of a is fetched once per output row, not once per output entry.
	if a, ok := a.(Vectorer); ok {
		if b, ok := b.(Vectorer); ok {
			row := make([]float64, ac)
			col := make([]float64, br)
			for r := 0; r < ar; r++ {
				aRow := blas64.Vector{Inc: 1, Data: a.Row(row, r)}
				dataTmp := w.Mat.Data[r*w.Mat.Stride : r*w.Mat.Stride+bc]
				for c := 0; c < bc; c++ {
					dataTmp[c] = blas64.Dot(ac,
						aRow,
						blas64.Vector{Inc: 1, Data: b.Col(col, c)},
					)
				}
			}
			return
		}
	}

	// Generic fallback through element access.
	row := make([]float64, ac)
	for r := 0; r < ar; r++ {
		for i := range row {
			row[i] = a.At(r, i)
		}
		for c := 0; c < bc; c++ {
			var v float64
			for i, e := range row {
				v += e * b.At(i, c)
			}
			w.Mat.Data[r*w.Mat.Stride+c] = v
		}
	}
}
开发者ID:drewlanenga,项目名称:matrix,代码行数:62,代码来源:dense_arithmetic.go


示例6: testDorghr

// testDorghr checks impl.Dorghr against an explicitly formed Q.
//
// A random matrix is reduced with Dgehrd, Q = H_ilo * H_{ilo+1} * ... *
// H_{ihi-1} is built from the stored reflectors, and Dorghr's output in A is
// compared with it. The workspace is either the queried optimum (optwork)
// or the minimum max(1, ihi-ilo).
func testDorghr(t *testing.T, impl Dorghrer, n, ilo, ihi, extra int, optwork bool, rnd *rand.Rand) {
	const tol = 1e-14

	// Reduce A; afterwards column j below the subdiagonal holds reflector j.
	a := randomGeneral(n, n, n+extra, rnd)
	var tau []float64
	if n > 1 {
		tau = nanSlice(n - 1)
	}
	work := nanSlice(max(1, n)) // Minimum work for Dgehrd.
	impl.Dgehrd(n, ilo, ihi, a.Data, a.Stride, tau, work, len(work))

	// Build the expected Q one reflector at a time.
	q := eye(n, n)
	qCopy := cloneGeneral(q)
	for j := ilo; j < ihi; j++ {
		h := eye(n, n)
		refl := make([]float64, n)
		refl[j+1] = 1
		for i := j + 2; i <= ihi; i++ {
			refl[i] = a.Data[i*a.Stride+j]
		}
		v := blas64.Vector{Inc: 1, Data: refl}
		blas64.Ger(-tau[j], v, v, h)
		copy(qCopy.Data, q.Data)
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qCopy, h, 0, q)
	}

	if !optwork {
		work = nanSlice(max(1, ihi-ilo))
	} else {
		work = nanSlice(1)
		impl.Dorghr(n, ilo, ihi, a.Data, a.Stride, tau, work, -1)
		work = nanSlice(int(work[0]))
	}
	impl.Dorghr(n, ilo, ihi, a.Data, a.Stride, tau, work, len(work))

	prefix := fmt.Sprintf("Case n=%v, ilo=%v, ihi=%v, extra=%v, optwork=%v", n, ilo, ihi, extra, optwork)
	if !generalOutsideAllNaN(a) {
		t.Errorf("%v: out-of-range write to A\n%v", prefix, a.Data)
	}
	if !isOrthonormal(a) {
		t.Errorf("%v: A is not orthogonal\n%v", prefix, a.Data)
	}
	for i := 0; i < n; i++ {
		for j := 0; j < n; j++ {
			got := a.Data[i*a.Stride+j]
			want := q.Data[i*q.Stride+j]
			if math.Abs(got-want) > tol {
				t.Errorf("%v: unexpected value of A[%v,%v]. want %v, got %v", prefix, i, j, want, got)
			}
		}
	}
}
开发者ID:rawlingsj,项目名称:gofabric8,代码行数:56,代码来源:dorghr.go


示例7: QFromQR

// QFromQR extracts the m×m orthonormal matrix Q from a QR decomposition.
//
// Q is formed explicitly as H_0 * H_1 * ... * H_{c-1}, where c is the number
// of columns of the factorized matrix and H_i = I - tau[i] * v_i * v_iᵀ. The
// vector v_i is zero above position i, one at position i, and below that
// holds the entries stored under the diagonal in column i of qr.qr.
func (m *Dense) QFromQR(qr *QR) {
	r, c := qr.qr.Dims()
	m.reuseAs(r, r)

	// Set Q = I.
	for i := 0; i < r; i++ {
		for j := 0; j < i; j++ {
			m.mat.Data[i*m.mat.Stride+j] = 0
		}
		m.mat.Data[i*m.mat.Stride+i] = 1
		for j := i + 1; j < r; j++ {
			m.mat.Data[i*m.mat.Stride+j] = 0
		}
	}

	// Construct Q from the elementary reflectors.
	h := blas64.General{
		Rows:   r,
		Cols:   r,
		Stride: r,
		Data:   make([]float64, r*r),
	}
	// qCopy holds the previous Q so Gemm never writes into one of its
	// inputs. It goes back to the workspace pool once Q is built.
	qCopy := getWorkspace(r, r, false)
	defer putWorkspace(qCopy)
	v := blas64.Vector{
		Inc:  1,
		Data: make([]float64, r),
	}
	for i := 0; i < c; i++ {
		// Set h = I.
		for k := range h.Data {
			h.Data[k] = 0
		}
		for j := 0; j < r; j++ {
			h.Data[j*r+j] = 1
		}

		// Set the vector data as the elementary reflector.
		for j := 0; j < i; j++ {
			v.Data[j] = 0
		}
		v.Data[i] = 1
		for j := i + 1; j < r; j++ {
			v.Data[j] = qr.qr.mat.Data[j*qr.qr.mat.Stride+i]
		}

		// Compute the multiplication matrix and apply it on the right.
		blas64.Ger(-qr.tau[i], v, v, h)
		qCopy.Copy(m)
		blas64.Gemm(blas.NoTrans, blas.NoTrans,
			1, qCopy.mat, h,
			0, m.mat)
	}
}
开发者ID:yonglehou,项目名称:matrix,代码行数:54,代码来源:qr.go


示例8: QFromLQ

// QFromLQ extracts the n×n orthonormal matrix Q from an LQ decomposition.
//
// Q is formed explicitly as H_{r-1} * ... * H_1 * H_0, where r is the number
// of rows of the factorized matrix and H_i = I - tau[i] * v_i * v_iᵀ. The
// vector v_i is zero before position i, one at position i, and after that
// holds the entries stored to the right of the diagonal in row i of lq.lq.
func (m *Dense) QFromLQ(lq *LQ) {
	r, c := lq.lq.Dims()
	m.reuseAs(c, c)

	// Set Q = I.
	for i := 0; i < c; i++ {
		for j := 0; j < i; j++ {
			m.mat.Data[i*m.mat.Stride+j] = 0
		}
		m.mat.Data[i*m.mat.Stride+i] = 1
		for j := i + 1; j < c; j++ {
			m.mat.Data[i*m.mat.Stride+j] = 0
		}
	}

	// Construct Q from the elementary reflectors.
	h := blas64.General{
		Rows:   c,
		Cols:   c,
		Stride: c,
		Data:   make([]float64, c*c),
	}
	// qCopy holds the previous Q so Gemm never writes into one of its
	// inputs. It goes back to the workspace pool once Q is built.
	qCopy := getWorkspace(c, c, false)
	defer putWorkspace(qCopy)
	v := blas64.Vector{
		Inc:  1,
		Data: make([]float64, c),
	}
	for i := 0; i < r; i++ {
		// Set h = I.
		for k := range h.Data {
			h.Data[k] = 0
		}
		for j := 0; j < c; j++ {
			h.Data[j*c+j] = 1
		}

		// Set the vector data as the elementary reflector.
		for j := 0; j < i; j++ {
			v.Data[j] = 0
		}
		v.Data[i] = 1
		for j := i + 1; j < c; j++ {
			v.Data[j] = lq.lq.mat.Data[i*lq.lq.mat.Stride+j]
		}

		// Compute the multiplication matrix and apply it on the left.
		blas64.Ger(-lq.tau[i], v, v, h)
		qCopy.Copy(m)
		blas64.Gemm(blas.NoTrans, blas.NoTrans,
			1, h, qCopy.mat,
			0, m.mat)
	}
}
开发者ID:yonglehou,项目名称:matrix,代码行数:54,代码来源:lq.go


示例9: Eval

// Eval returns a matrix literal holding the product Left * Right.
//
// Both operands are evaluated and converted to blas64.General, and the
// product is written into a freshly allocated General.
func (m1 *Mul) Eval() MatrixLiteral {
	// TODO: instead of converting both sides with AsGeneral, type switch on
	// the evaluated literals and dispatch to specialized routines.
	lhs, rhs := m1.Left.Eval(), m1.Right.Eval()
	a, b := lhs.AsGeneral(), rhs.AsGeneral()

	rows, cols := m1.Dims()
	product := blas64.General{
		Rows:   rows,
		Cols:   cols,
		Stride: cols,
		Data:   make([]float64, rows*cols),
	}
	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, a, b, 0, product)
	return &General{product}
}
开发者ID:jonlawlor,项目名称:matrixexp,代码行数:21,代码来源:mul.go


示例10: convolve

// convolve applies the layer's filters to in and writes the result, plus the
// biases, into out.
//
// inputToMatrix unrolls in into a matrix with one row per output position
// (its row count must equal out.Width*out.Height for the Gemm below), so the
// convolution is the single product inMat * filtersᵀ. out.Data is
// overwritten.
func (c *ConvLayer) convolve(in linalg.Vector, out *Tensor3) {
	inMat := c.inputToMatrix(in)

	// The filter weights are stored densely, one filter per row, so the row
	// stride equals the column count. It must not be taken from inMat, whose
	// stride describes a different buffer.
	filterMat := blas64.General{
		Rows:   c.FilterCount,
		Cols:   inMat.Cols,
		Stride: inMat.Cols,
		Data:   c.FilterVar.Vector,
	}
	outMat := blas64.General{
		Rows:   out.Width * out.Height,
		Cols:   out.Depth,
		Stride: out.Depth,
		Data:   out.Data,
	}
	blas64.Gemm(blas.NoTrans, blas.Trans, 1, inMat, filterMat, 0, outMat)

	// Add the bias vector to every output row.
	biasVec := blas64.Vector{Inc: 1, Data: c.Biases.Vector}
	for i := 0; i < len(out.Data); i += outMat.Cols {
		outRow := out.Data[i : i+outMat.Cols]
		outVec := blas64.Vector{Inc: 1, Data: outRow}
		blas64.Axpy(len(outRow), 1, biasVec, outVec)
	}
}
开发者ID:unixpickle,项目名称:weakai,代码行数:23,代码来源:conv_layer.go


示例11: dlatrdCheckDecomposition

// dlatrdCheckDecomposition checks that the first nb rows have been successfully
// reduced.
//
// It forms ans = Qᵀ * aGen * Q and inspects nb rows of it: the last nb rows
// (highest index first) for blas.Upper, the first nb rows for blas.Lower. In
// each inspected row the diagonal must match a, the off-diagonal on one side
// must match e while a holds 1 at that position, and every entry outside the
// tridiagonal band must be zero, all within 1e-10. The off-diagonal on the
// other side is not checked. The t and tau arguments are not used.
func dlatrdCheckDecomposition(t *testing.T, uplo blas.Uplo, n, nb int, e, tau, a []float64, lda int, aGen, q blas64.General) bool {
	const tol = 1e-10
	// far reports whether x and y differ by more than tol; a NaN difference
	// does not count as far.
	far := func(x, y float64) bool { return math.Abs(x-y) > tol }

	newSquare := func() blas64.General {
		return blas64.General{Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n)}
	}
	tmp := newSquare()
	ans := newSquare()

	// ans = Qᵀ * A * Q
	blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aGen, 0, tmp)
	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, ans)

	if uplo == blas.Upper {
		for i := n - 1; i >= n-nb; i-- {
			for j := 0; j < n; j++ {
				got := ans.Data[i*ans.Stride+j]
				switch {
				case i == j:
					if far(got, a[i*lda+j]) {
						return false
					}
				case i == j-1:
					// Superdiagonal: a must hold 1 and ans must match e[i].
					if far(a[i*lda+j], 1) || far(got, e[i]) {
						return false
					}
				case i == j+1:
					// Subdiagonal is not checked in the upper case.
				default:
					if far(got, 0) {
						return false
					}
				}
			}
		}
		return true
	}

	for i := 0; i < nb; i++ {
		for j := 0; j < n; j++ {
			got := ans.Data[i*ans.Stride+j]
			switch {
			case i == j:
				if far(got, a[i*lda+j]) {
					return false
				}
			case i == j-1:
				// Superdiagonal is not checked in the lower case.
			case i == j+1:
				// Subdiagonal: a must hold 1 and ans must match e[i-1].
				if far(a[i*lda+j], 1) || far(got, e[i-1]) {
					return false
				}
			default:
				if far(got, 0) {
					return false
				}
			}
		}
	}
	return true
}
开发者ID:rawlingsj,项目名称:gofabric8,代码行数:73,代码来源:dlatrd.go


示例12: DlarfbTest


//.........这里部分代码省略.........
						a := make([]float64, ma*lda)
						for i := 0; i < ma; i++ {
							for j := 0; j < lda; j++ {
								a[i*lda+j] = rnd.Float64()
							}
						}
						k := min(ma, na)

						// H is always ma x ma
						var m, n, rowsWork int
						switch {
						default:
							panic("not implemented")
						case side == blas.Left:
							m = test.ma
							n = test.cdim
							rowsWork = n
						case side == blas.Right:
							m = test.cdim
							n = test.ma
							rowsWork = m
						}

						// Use dgeqr2 to find the v vectors
						tau := make([]float64, na)
						work := make([]float64, na)
						impl.Dgeqr2(ma, k, a, lda, tau, work)

						// Correct the v vectors based on the direct and store
						vMatTmp := extractVMat(ma, na, a, lda, lapack.Forward, lapack.ColumnWise)
						vMat := constructVMat(vMatTmp, store, direct)
						v := vMat.Data
						ldv := vMat.Stride

						// Use dlarft to find the t vector
						ldt := test.ldt
						if ldt == 0 {
							ldt = k
						}
						tm := make([]float64, k*ldt)

						impl.Dlarft(direct, store, ma, k, v, ldv, tau, tm, ldt)

						// Generate c matrix
						ldc := test.ldc
						if ldc == 0 {
							ldc = n
						}
						c := make([]float64, m*ldc)
						for i := 0; i < m; i++ {
							for j := 0; j < ldc; j++ {
								c[i*ldc+j] = rnd.Float64()
							}
						}
						cCopy := make([]float64, len(c))
						copy(cCopy, c)

						ldwork := k
						work = make([]float64, rowsWork*k)

						// Call Dlarfb with this information
						impl.Dlarfb(side, trans, direct, store, m, n, k, v, ldv, tm, ldt, c, ldc, work, ldwork)

						h := constructH(tau, vMat, store, direct)

						cMat := blas64.General{
							Rows:   m,
							Cols:   n,
							Stride: ldc,
							Data:   make([]float64, m*ldc),
						}
						copy(cMat.Data, cCopy)
						ans := blas64.General{
							Rows:   m,
							Cols:   n,
							Stride: ldc,
							Data:   make([]float64, m*ldc),
						}
						copy(ans.Data, cMat.Data)
						switch {
						default:
							panic("not implemented")
						case side == blas.Left && trans == blas.NoTrans:
							blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, cMat, 0, ans)
						case side == blas.Left && trans == blas.Trans:
							blas64.Gemm(blas.Trans, blas.NoTrans, 1, h, cMat, 0, ans)
						case side == blas.Right && trans == blas.NoTrans:
							blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMat, h, 0, ans)
						case side == blas.Right && trans == blas.Trans:
							blas64.Gemm(blas.NoTrans, blas.Trans, 1, cMat, h, 0, ans)
						}
						if !floats.EqualApprox(ans.Data, c, 1e-14) {
							t.Errorf("Cas %v mismatch. Want %v, got %v.", cas, ans.Data, c)
						}
					}
				}
			}
		}
	}
}
开发者ID:rawlingsj,项目名称:gofabric8,代码行数:101,代码来源:dlarfb.go


示例13: DlarfTest


//.........这里部分代码省略.........

			lastr: 0,
			lastc: 1,

			tau: 2,
		},
		{
			m:   10,
			n:   10,
			ldc: 10,

			incv:  4,
			lastv: 6,

			lastr: 9,
			lastc: 8,

			tau: 2,
		},
	} {
		// Construct a random matrix.
		c := make([]float64, test.ldc*test.m)
		for i := 0; i <= test.lastr; i++ {
			for j := 0; j <= test.lastc; j++ {
				c[i*test.ldc+j] = rand.Float64()
			}
		}
		cCopy := make([]float64, len(c))
		copy(cCopy, c)
		cCopy2 := make([]float64, len(c))
		copy(cCopy2, c)

		// Test with side right.
		sz := max(test.m, test.n) // so v works for both right and left side.
		v := make([]float64, test.incv*sz+1)
		// Fill with nonzero entries up until lastv.
		for i := 0; i <= test.lastv; i++ {
			v[i*test.incv] = rand.Float64()
		}
		// Construct h explicitly to compare.
		h := make([]float64, test.n*test.n)
		for i := 0; i < test.n; i++ {
			h[i*test.n+i] = 1
		}
		hMat := blas64.General{
			Rows:   test.n,
			Cols:   test.n,
			Stride: test.n,
			Data:   h,
		}
		vVec := blas64.Vector{
			Inc:  test.incv,
			Data: v,
		}
		blas64.Ger(-test.tau, vVec, vVec, hMat)

		// Apply multiplication (2nd copy is to avoid aliasing).
		cMat := blas64.General{
			Rows:   test.m,
			Cols:   test.n,
			Stride: test.ldc,
			Data:   cCopy,
		}
		cMat2 := blas64.General{
			Rows:   test.m,
			Cols:   test.n,
			Stride: test.ldc,
			Data:   cCopy2,
		}
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMat2, hMat, 0, cMat)

		// cMat now stores the true answer. Compare with the function call.
		work := make([]float64, sz)
		impl.Dlarf(blas.Right, test.m, test.n, v, test.incv, test.tau, c, test.ldc, work)
		if !floats.EqualApprox(c, cMat.Data, 1e-14) {
			t.Errorf("Dlarf mismatch right, case %v. Want %v, got %v", i, cMat.Data, c)
		}

		// Test on the left side.
		copy(c, cCopy2)
		copy(cCopy, c)
		// Construct h.
		h = make([]float64, test.m*test.m)
		for i := 0; i < test.m; i++ {
			h[i*test.m+i] = 1
		}
		hMat = blas64.General{
			Rows:   test.m,
			Cols:   test.m,
			Stride: test.m,
			Data:   h,
		}
		blas64.Ger(-test.tau, vVec, vVec, hMat)
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hMat, cMat2, 0, cMat)
		impl.Dlarf(blas.Left, test.m, test.n, v, test.incv, test.tau, c, test.ldc, work)
		if !floats.EqualApprox(c, cMat.Data, 1e-14) {
			t.Errorf("Dlarf mismatch left, case %v. Want %v, got %v", i, cMat.Data, c)
		}
	}
}
开发者ID:RomainVabre,项目名称:origin,代码行数:101,代码来源:dlarf.go


示例14: checkPLU

// checkPLU checks that the PLU factorization contained in factorized matches
// the original matrix contained in original.
//
// factorized holds the m×n result of an LU factorization with row stride
// lda: the strictly lower part is L (its unit diagonal is implied) and the
// upper part, including the diagonal, is U. Row i was swapped with row
// ipiv[i]. ok is the factorization's success flag and must be false exactly
// when U has a zero diagonal entry. P*L*U is compared with original (same
// layout as factorized) to within tol; when print is true, failure messages
// include both data slices.
func checkPLU(t *testing.T, ok bool, m, n, lda int, ipiv []int, factorized, original []float64, tol float64, print bool) {
	mn := min(m, n)

	var hasZeroDiagonal bool
	for i := 0; i < mn; i++ {
		if factorized[i*lda+i] == 0 {
			hasZeroDiagonal = true
			break
		}
	}
	if hasZeroDiagonal && ok {
		t.Error("Has a zero diagonal but returned ok")
	}
	if !hasZeroDiagonal && !ok {
		t.Error("Non-zero diagonal but returned !ok")
	}

	// Split factorized into the m×mn unit lower trapezoid L and the mn×n
	// upper trapezoid U.
	l := make([]float64, m*mn)
	ldl := mn
	u := make([]float64, mn*n)
	ldu := n
	for i := 0; i < m; i++ {
		for j := 0; j < n; j++ {
			v := factorized[i*lda+j]
			switch {
			case i == j:
				l[i*ldl+i] = 1
				u[i*ldu+i] = v
			case i > j:
				l[i*ldl+j] = v
			case i < j:
				u[i*ldu+j] = v
			}
		}
	}

	LU := blas64.General{
		Rows:   m,
		Cols:   n,
		Stride: n,
		Data:   make([]float64, m*n),
	}
	U := blas64.General{
		Rows:   mn,
		Cols:   n,
		Stride: ldu,
		Data:   u,
	}
	L := blas64.General{
		Rows:   m,
		Cols:   mn,
		Stride: ldl,
		Data:   l,
	}
	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, L, U, 0, LU)

	// Build P by applying the recorded row interchanges, last first, to the
	// identity.
	p := make([]float64, m*m)
	ldp := m
	for i := 0; i < m; i++ {
		p[i*ldp+i] = 1
	}
	for i := len(ipiv) - 1; i >= 0; i-- {
		v := ipiv[i]
		blas64.Swap(m,
			blas64.Vector{Inc: 1, Data: p[i*ldp:]},
			blas64.Vector{Inc: 1, Data: p[v*ldp:]})
	}
	P := blas64.General{
		Rows:   m,
		Cols:   m,
		Stride: m,
		Data:   p,
	}

	// Start from a copy of factorized so entries past column n (the lda
	// padding) carry over unchanged into the comparison with original.
	aComp := blas64.General{
		Rows:   m,
		Cols:   n,
		Stride: lda,
		Data:   make([]float64, m*lda),
	}
	copy(aComp.Data, factorized)
	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, P, LU, 0, aComp)
	if !floats.EqualApprox(aComp.Data, original, tol) {
		if print {
			t.Errorf("PLU multiplication does not match original matrix.\nWant: %v\nGot: %v", original, aComp.Data)
			return
		}
		t.Error("PLU multiplication does not match original matrix.")
	}
}
开发者ID:jacobxk,项目名称:lapack,代码行数:89,代码来源:dgetf2.go


示例15: Dgelq2Test

// Dgelq2Test checks impl.Dgelq2 on wide, square and tall matrices, with and
// without row padding (lda > n). For each case it rebuilds Q from the
// returned reflectors, verifies that Q has orthonormal rows, and checks that
// L*Q reproduces the original matrix (padding included).
func Dgelq2Test(t *testing.T, impl Dgelq2er) {
	for c, test := range []struct {
		m, n, lda int
	}{
		{1, 1, 0},
		{2, 2, 0},
		{3, 2, 0},
		{2, 3, 0},
		{1, 12, 0},
		{2, 6, 0},
		{3, 4, 0},
		{4, 3, 0},
		{6, 2, 0},
		{12, 1, 0}, // Tall counterpart of {1, 12, 0}.
		{1, 1, 20},
		{2, 2, 20},
		{3, 2, 20},
		{2, 3, 20},
		{1, 12, 20},
		{2, 6, 20},
		{3, 4, 20},
		{4, 3, 20},
		{6, 2, 20},
		{12, 1, 20}, // Tall counterpart of {1, 12, 20}.
	} {
		n := test.n
		m := test.m
		lda := test.lda
		if lda == 0 {
			lda = test.n
		}
		k := min(m, n)
		tau := make([]float64, k)
		for i := range tau {
			tau[i] = rand.Float64()
		}
		work := make([]float64, m)
		for i := range work {
			work[i] = rand.Float64()
		}
		a := make([]float64, m*lda)
		for i := 0; i < m*lda; i++ {
			a[i] = rand.Float64()
		}
		aCopy := make([]float64, len(a))
		copy(aCopy, a)
		impl.Dgelq2(m, n, a, lda, tau, work)

		Q := constructQ("LQ", m, n, a, lda, tau)

		// Check that Q is orthonormal.
		for i := 0; i < Q.Rows; i++ {
			nrm := blas64.Nrm2(Q.Cols, blas64.Vector{Inc: 1, Data: Q.Data[i*Q.Stride:]})
			if math.Abs(nrm-1) > 1e-14 {
				t.Errorf("Q not normal. Norm is %v", nrm)
			}
			for j := 0; j < i; j++ {
				dot := blas64.Dot(Q.Rows,
					blas64.Vector{Inc: 1, Data: Q.Data[i*Q.Stride:]},
					blas64.Vector{Inc: 1, Data: Q.Data[j*Q.Stride:]},
				)
				if math.Abs(dot) > 1e-14 {
					t.Errorf("Q not orthogonal. Dot is %v", dot)
				}
			}
		}

		// Extract the m×n lower trapezoid L.
		L := blas64.General{
			Rows:   m,
			Cols:   n,
			Stride: n,
			Data:   make([]float64, m*n),
		}
		for i := 0; i < m; i++ {
			for j := 0; j <= min(i, n-1); j++ {
				L.Data[i*L.Stride+j] = a[i*lda+j]
			}
		}

		// ans starts as a copy of the original so the lda padding matches.
		ans := blas64.General{
			Rows:   m,
			Cols:   n,
			Stride: lda,
			Data:   make([]float64, m*lda),
		}
		copy(ans.Data, aCopy)
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, L, Q, 0, ans)
		if !floats.EqualApprox(aCopy, ans.Data, 1e-14) {
			t.Errorf("Case %v, LQ mismatch. Want %v, got %v.", c, aCopy, ans.Data)
		}
	}
}
开发者ID:RomainVabre,项目名称:origin,代码行数:92,代码来源:dgelq2.go


示例16: testDgebak

// testDgebak checks impl.Dgebak for one job/side combination.
//
// It builds a diagonal scaling D (and D^{-1}) on rows ilo..ihi and a
// permutation P from random column swaps, records both in scale, runs Dgebak
// on a copy of v, and compares the result with P*D*V (side Right) or
// P*D^{-1}*V (otherwise).
func testDgebak(t *testing.T, impl Dgebaker, job lapack.Job, side blas.Side, ilo, ihi int, v blas64.General, rnd *rand.Rand) {
	const tol = 1e-15
	n := v.Rows
	m := v.Cols
	extra := v.Stride - v.Cols

	// Create D and D^{-1} by generating random scales between ilo and ihi.
	d := eye(n, n)
	dinv := eye(n, n)
	scale := nanSlice(n)
	if job == lapack.Scale || job == lapack.PermuteScale {
		if ilo == ihi {
			scale[ilo] = 1
		} else {
			for i := ilo; i <= ihi; i++ {
				scale[i] = 2 * rnd.Float64()
				d.Data[i*d.Stride+i] = scale[i]
				dinv.Data[i*dinv.Stride+i] = 1 / scale[i]
			}
		}
	}

	// Create P by generating random column swaps. The column exchanged with
	// column i is recorded (as a float) in scale[i].
	p := eye(n, n)
	swapCols := func(i, j int) {
		blas64.Swap(n,
			blas64.Vector{Inc: p.Stride, Data: p.Data[i:]},
			blas64.Vector{Inc: p.Stride, Data: p.Data[j:]})
	}
	if job == lapack.Permute || job == lapack.PermuteScale {
		for i := n - 1; i > ihi; i-- {
			j := rnd.Intn(i + 1)
			scale[i] = float64(j)
			swapCols(i, j)
		}
		for i := 0; i < ilo; i++ {
			j := i + rnd.Intn(ihi-i+1)
			scale[i] = float64(j)
			swapCols(i, j)
		}
	}

	got := cloneGeneral(v)
	impl.Dgebak(job, side, n, ilo, ihi, scale, m, got.Data, got.Stride)

	prefix := fmt.Sprintf("Case job=%v, side=%v, n=%v, ilo=%v, ihi=%v, m=%v, extra=%v",
		job, side, n, ilo, ihi, m, extra)

	if !generalOutsideAllNaN(got) {
		t.Errorf("%v: out-of-range write to V\n%v", prefix, got.Data)
	}

	// Compute D*V or D^{-1}*V and store into dv.
	dv := zeros(n, m, m)
	if side == blas.Right {
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, d, v, 0, dv)
	} else {
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, dinv, v, 0, dv)
	}
	// Compute P*D*V or P*D^{-1}*V and store into want.
	want := zeros(n, m, m)
	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, p, dv, 0, want)

	if !equalApproxGeneral(want, got, tol) {
		t.Errorf("%v: unexpected value of V", prefix)
	}
}
开发者ID:rawlingsj,项目名称:gofabric8,代码行数:65,代码来源:dgebak.go


示例17: Dgeqr2Test

// Dgeqr2Test checks impl.Dgeqr2 on wide, square and tall matrices, with and
// without row padding (lda > n). For each case it rebuilds Q from the
// returned reflectors, verifies that Q is orthonormal, and checks that Q*R
// reproduces the original matrix (padding included).
func Dgeqr2Test(t *testing.T, impl Dgeqr2er) {
	for c, test := range []struct {
		m, n, lda int
	}{
		{1, 1, 0},
		{2, 2, 0},
		{3, 2, 0},
		{2, 3, 0},
		{1, 12, 0},
		{2, 6, 0},
		{3, 4, 0},
		{4, 3, 0},
		{6, 2, 0},
		{12, 1, 0},
		{1, 1, 20},
		{2, 2, 20},
		{3, 2, 20},
		{2, 3, 20},
		{1, 12, 20},
		{2, 6, 20},
		{3, 4, 20},
		{4, 3, 20},
		{6, 2, 20},
		{12, 1, 20},
	} {
		n := test.n
		m := test.m
		lda := test.lda
		if lda == 0 {
			lda = test.n
		}
		a := make([]float64, m*lda)
		for i := range a {
			a[i] = rand.Float64()
		}
		aCopy := make([]float64, len(a))
		k := min(m, n)
		tau := make([]float64, k)
		for i := range tau {
			tau[i] = rand.Float64()
		}
		work := make([]float64, n)
		for i := range work {
			work[i] = rand.Float64()
		}
		copy(aCopy, a)
		impl.Dgeqr2(m, n, a, lda, tau, work)

		// Test that the QR factorization has completed successfully. Compute
		// Q based on the vectors.
		q := constructQ("QR", m, n, a, lda, tau)

		// Check that q is orthonormal.
		for i := 0; i < m; i++ {
			qi := blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]}
			nrm := blas64.Nrm2(m, qi)
			if math.Abs(nrm-1) > 1e-14 {
				t.Errorf("Case %v, q not normal", c)
			}
			for j := 0; j < i; j++ {
				dot := blas64.Dot(m, qi, blas64.Vector{Inc: 1, Data: q.Data[j*q.Stride:]})
				if math.Abs(dot) > 1e-14 {
					t.Errorf("Case %v, q not orthogonal", c)
				}
			}
		}

		// Check that A = Q * R.
		r := blas64.General{
			Rows:   m,
			Cols:   n,
			Stride: n,
			Data:   make([]float64, m*n),
		}
		for i := 0; i < m; i++ {
			for j := i; j < n; j++ {
				r.Data[i*n+j] = a[i*lda+j]
			}
		}
		// atmp starts as a copy of a so its lda padding matches aCopy's.
		atmp := blas64.General{
			Rows:   m,
			Cols:   n,
			Stride: lda,
			Data:   make([]float64, m*lda),
		}
		copy(atmp.Data, a)
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, r, 0, atmp)
		if !floats.EqualApprox(atmp.Data, aCopy, 1e-14) {
			t.Errorf("Case %v, Q*R != a", c)
		}
	}
}
开发者ID:RomainVabre,项目名称:origin,代码行数:90,代码来源:dgeqr2.go


示例18: constructQPBidiagonal


//.........这里部分代码省略.........
		}
	}

	if vect == lapack.ApplyQ {
		if m >= n {
			for i := 0; i < m; i++ {
				for j := 0; j <= min(nb-1, i); j++ {
					if i == j {
						v.Data[i*ldv+j] = 1
						continue
					}
					v.Data[i*ldv+j] = a[i*lda+j]
				}
			}
		} else {
			for i := 1; i < m; i++ {
				for j := 0; j <= min(nb-1, i-1); j++ {
					if i-1 == j {
						v.Data[i*ldv+j] = 1
						continue
					}
					v.Data[i*ldv+j] = a[i*lda+j]
				}
			}
		}
	} else {
		if m < n {
			for i := 0; i < nb; i++ {
				for j := i; j < n; j++ {
					if i == j {
						v.Data[i*ldv+j] = 1
						continue
					}
					v.Data[i*ldv+j] = a[i*lda+j]
				}
			}
		} else {
			for i := 0; i < nb; i++ {
				for j := i + 1; j < n; j++ {
					if j-1 == i {
						v.Data[i*ldv+j] = 1
						continue
					}
					v.Data[i*ldv+j] = a[i*lda+j]
				}
			}
		}
	}

	// The variable name is a computation of Q, but the algorithm is mostly the
	// same for computing P (just with different data).
	qMat := blas64.General{
		Rows:   sz,
		Cols:   sz,
		Stride: sz,
		Data:   make([]float64, sz*sz),
	}
	hMat := blas64.General{
		Rows:   sz,
		Cols:   sz,
		Stride: sz,
		Data:   make([]float64, sz*sz),
	}
	// set Q to I
	for i := 0; i < sz; i++ {
		qMat.Data[i*qMat.Stride+i] = 1
	}
	for i := 0; i < nb; i++ {
		qCopy := blas64.General{Rows: qMat.Rows, Cols: qMat.Cols, Stride: qMat.Stride, Data: make([]float64, len(qMat.Data))}
		copy(qCopy.Data, qMat.Data)

		// Set g and h to I
		for i := 0; i < sz; i++ {
			for j := 0; j < sz; j++ {
				if i == j {
					hMat.Data[i*sz+j] = 1
				} else {
					hMat.Data[i*sz+j] = 0
				}
			}
		}
		var vi blas64.Vector
		// H -= tauQ[i] * v[i] * v[i]^t
		if vect == lapack.ApplyQ {
			vi = blas64.Vector{
				Inc:  v.Stride,
				Data: v.Data[i:],
			}
		} else {
			vi = blas64.Vector{
				Inc:  1,
				Data: v.Data[i*v.Stride:],
			}
		}
		blas64.Ger(-tau[i], vi, vi, hMat)
		// Q = Q * G[1]
		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qCopy, hMat, 0, qMat)
	}
	return qMat
}
开发者ID:rawlingsj,项目名称:gofabric8,代码行数:101,代码来源:general.go


示例19: Mul

// Mul takes the matrix product of a and b, placing the result in the receiver.
//
// See the Muler interface for more information.
func (m *Dense) Mul(a, b Matrix) {
	ar, ac := a.Dims()
	br, bc := b.Dims()

	if ac != br {
		panic(ErrShape)
	}

	aU, aTrans := untranspose(a)
	bU, bTrans := untranspose(b)
	m.reuseAs(ar, bc)
	var restore func()
	if m == aU {
		m, restore = m.isolatedWorkspace(aU)
		defer restore()
	} else if m == bU {
		m, restore = m.isolatedWorkspace(bU)
		defer restore()
	}
	aT := blas.NoTrans
	if aTrans {
		aT = blas.Trans
	}
	bT := blas.NoTrans
	if bTrans {
		bT = blas.Trans
	}

	// Some of the cases do not have a transpose option, so create
	// temporary memory.
	// C = A^T * B = (B^T * A)^T
	// C^T = B^T * A.
	if aU, ok := aU.(RawMatrixer); ok {
		amat := aU.RawMatrix()
		if bU, ok := bU.(RawMatrixer); ok {
			bmat := bU.RawMatrix()
			blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
			return
		}
		if bU, ok := bU.(RawSymmetricer); ok {
			bmat := bU.RawSymmetric()
			if aTrans {
				c := getWorkspace(ac, ar, false)
				blas64.Symm(blas.Left, 1, bmat, amat, 0, c.mat)
				strictCopy(m, c.T())
				putWorkspace(c)
				return
			}
			blas64.Symm(blas.Right, 1, bmat, amat, 0, m.mat)
			return
		}
		if bU, ok := bU.(RawTriangular); ok {
			// Trmm updates in place, so copy aU first.
			bmat := bU.RawTriangular()
			if aTrans {
				c := getWorkspace(ac, ar, false)
				var tmp Dense
				tmp.SetRawMatrix(aU.RawMatrix())
				c.Copy(&tmp)
				bT := blas.Trans
				if bTrans {
					bT = blas.NoTrans
				}
				blas64.Trmm(blas.Left, bT, 1, bmat, c.mat)
				strictCopy(m, c.T())
				putWorkspace(c)
				return
			}
			m.Copy(a)
			blas64.Trmm(blas.Right, bT, 1, bmat, m.mat)
			return
		}
		if bU, ok := bU.(*Vector); ok {
			bvec := bU.RawVector()
			if bTrans {
				// {ar,1} x {1,bc}, which is not a vector.
				// Instead, construct B as a General.
				bmat := blas64.General{
					Rows:   bc,
					Cols:   1,
					Stride: bvec.Inc,
					Data:   bvec.Data,
				}
				blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
				return
			}
			cvec := blas64.Vector{
				Inc:  m.mat.Stride,
				Data: m.mat.Data,
			}
			blas64.Gemv(aT, 1, amat, bvec, 0, cvec)
			return
		}
	}
	if bU, ok := bU.(RawMatrixer); ok {
		bmat := bU.RawMatrix()
		if aU, ok := aU.(RawSymmetricer); ok {
//.........这里部分代码省略.........
开发者ID:rwcarlsen,项目名称:cloudlus,代码行数:101,代码来源:dense_arithmetic.go


示例20: constructQ

// constructQ constructs the Q matrix from the result of dgeqrf and dgeqr2
// (kind "QR") or of an LQ factorization such as dgelq2 (kind "LQ").
//
// a is the m×n factorized matrix with row stride lda and tau the k = min(m, n)
// reflector scalars. Reflector i is H_i = I - tau[i] * v_i * v_iᵀ, where v_i
// is zero before position i, one at position i, and after that holds column
// i of a (QR) or row i of a (LQ). The result is Q = H_0 * ... * H_{k-1} of
// order m for "QR", and Q = H_{k-1} * ... * H_0 of order n for "LQ".
// It panics if kind is neither "QR" nor "LQ".
func constructQ(kind string, m, n int, a []float64, lda int, tau []float64) blas64.General {
	k := min(m, n)
	var sz int
	switch kind {
	case "QR":
		sz = m
	case "LQ":
		sz = n
	default:
		panic("constructQ: unknown kind " + kind)
	}

	q := blas64.General{
		Rows:   sz,
		Cols:   sz,
		Stride: sz,
		Data:   make([]float64, sz*sz),
	}
	for i := 0; i < sz; i++ {
		q.Data[i*sz+i] = 1
	}

	// Scratch storage reused for every reflector. qCopy holds the previous Q
	// so Gemm never writes into one of its inputs.
	qCopy := blas64.General{
		Rows:   q.Rows,
		Cols:   q.Cols,
		Stride: q.Stride,
		Data:   make([]float64, len(q.Data)),
	}
	h := blas64.General{
		Rows:   sz,
		Cols:   sz,
		Stride: sz,
		Data:   make([]float64, sz*sz),
	}
	vVec := blas64.Vector{
		Inc:  1,
		Data: make([]float64, sz),
	}
	for i := 0; i < k; i++ {
		// h = I
		for j := range h.Data {
			h.Data[j] = 0
		}
		for j := 0; j < sz; j++ {
			h.Data[j*sz+j] = 1
		}

		// Every entry of vVec is rewritten below, so reuse is safe.
		for j := 0; j < i; j++ {
			vVec.Data[j] = 0
		}
		vVec.Data[i] = 1
		switch kind {
		case "QR":
			for j := i + 1; j < sz; j++ {
				vVec.Data[j] = a[lda*j+i]
			}
		case "LQ":
			for j := i + 1; j < sz; j++ {
				vVec.Data[j] = a[i*lda+j]
			}
		}
		blas64.Ger(-tau[i], vVec, vVec, h)
		copy(qCopy.Data, q.Data)
		// Multiply q by the new h: on the right for QR, on the left for LQ.
		switch kind {
		case "QR":
			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qCopy, h, 0, q)
		case "LQ":
			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, qCopy, 0, q)
		}
	}
	return q
}
开发者ID:jacobxk,项目名称:lapack,代码行数:66,代码来源:general.go



注：本文中的github.com/gonum/blas/blas64.Gemm函数示例整理自GitHub/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Golang blas64.Gemv函数代码示例发布时间:2022-05-23
下一篇:
Golang blas64.Dot函数代码示例发布时间:2022-05-23
热门推荐
热门话题
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap