This article from Intel describes how to do exactly the 3x8 case that you want.
That article covers the float case. If you want int32, you'll need to cast the outputs since there's no integer version of _mm256_shuffle_ps().
Copying their solution verbatim:
float *p; // address of first vector
__m128 *m = (__m128*) p;
__m256 m03;
__m256 m14;
__m256 m25;
m03 = _mm256_castps128_ps256(m[0]); // load lower halves
m14 = _mm256_castps128_ps256(m[1]);
m25 = _mm256_castps128_ps256(m[2]);
m03 = _mm256_insertf128_ps(m03 ,m[3],1); // load upper halves
m14 = _mm256_insertf128_ps(m14 ,m[4],1);
m25 = _mm256_insertf128_ps(m25 ,m[5],1);
__m256 xy = _mm256_shuffle_ps(m14, m25, _MM_SHUFFLE( 2,1,3,2)); // upper x's and y's
__m256 yz = _mm256_shuffle_ps(m03, m14, _MM_SHUFFLE( 1,0,2,1)); // lower y's and z's
__m256 x = _mm256_shuffle_ps(m03, xy , _MM_SHUFFLE( 2,0,3,0));
__m256 y = _mm256_shuffle_ps(yz , xy , _MM_SHUFFLE( 3,1,2,0));
__m256 z = _mm256_shuffle_ps(yz , m25, _MM_SHUFFLE( 3,0,3,1));
So this is 11 instructions (6 loads, 5 shuffles).
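For the int32 case mentioned above, here is a minimal sketch of the same shuffle sequence with the casts added. The function name `deinterleave_i32_3x8` and the output-array interface are my own assumptions, not part of the Intel article. Unlike the verbatim code, it uses unaligned loads and stores so the pointers need no particular alignment. The cast intrinsics only reinterpret the register and should not generate instructions.

```cpp
#include <immintrin.h>
#include <cstdint>

// Hypothetical helper (not from the Intel article): splits 8 packed
// {x,y,z} int32 triples (24 values at p) into x[8], y[8], z[8].
#if defined(__GNUC__) || defined(__clang__)
__attribute__((target("avx")))   // lets this compile without -mavx on GCC/Clang
#endif
void deinterleave_i32_3x8(const std::int32_t* p,
                          std::int32_t* x, std::int32_t* y, std::int32_t* z) {
    const __m128i* m = reinterpret_cast<const __m128i*>(p);

    // Load lower halves, then insert upper halves (all unaligned loads).
    __m256i i03 = _mm256_castsi128_si256(_mm_loadu_si128(m + 0));
    __m256i i14 = _mm256_castsi128_si256(_mm_loadu_si128(m + 1));
    __m256i i25 = _mm256_castsi128_si256(_mm_loadu_si128(m + 2));
    i03 = _mm256_insertf128_si256(i03, _mm_loadu_si128(m + 3), 1);
    i14 = _mm256_insertf128_si256(i14, _mm_loadu_si128(m + 4), 1);
    i25 = _mm256_insertf128_si256(i25, _mm_loadu_si128(m + 5), 1);

    // Reinterpret as float vectors so _mm256_shuffle_ps() can be used.
    // Shuffles only move bits around, so the integer values are preserved.
    __m256 m03 = _mm256_castsi256_ps(i03);
    __m256 m14 = _mm256_castsi256_ps(i14);
    __m256 m25 = _mm256_castsi256_ps(i25);

    // Same five shuffles as the float version.
    __m256 xy = _mm256_shuffle_ps(m14, m25, _MM_SHUFFLE(2, 1, 3, 2));
    __m256 yz = _mm256_shuffle_ps(m03, m14, _MM_SHUFFLE(1, 0, 2, 1));
    __m256 xv = _mm256_shuffle_ps(m03, xy,  _MM_SHUFFLE(2, 0, 3, 0));
    __m256 yv = _mm256_shuffle_ps(yz,  xy,  _MM_SHUFFLE(3, 1, 2, 0));
    __m256 zv = _mm256_shuffle_ps(yz,  m25, _MM_SHUFFLE(3, 0, 3, 1));

    // Cast the outputs back to integer vectors and store them.
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(x), _mm256_castps_si256(xv));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(y), _mm256_castps_si256(yv));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(z), _mm256_castps_si256(zv));
}
```

If you'd rather keep the results in registers, return the three `_mm256_castps_si256(...)` values instead of storing them.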
In the general case, it's possible to do an S x W transpose in O(S*log(W)) instructions, where:
S is the stride
W is the SIMD width
Assuming the existence of 2-vector permutes and half-vector insert-loads, the formula becomes:
(S x W load-permute) <= S * (lg(W) + 1) instructions
This count ignores reg-reg moves. For degenerate cases like 3 x 4, it may be possible to do better.
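As a quick sanity check of the bound, it can be evaluated for the two cases in this answer. This is only a sketch; the helper names `lg` and `load_permute_bound` are made up for illustration, and W is assumed to be a power of two.

```cpp
// Hypothetical helper: log2 of a power-of-two SIMD width w.
constexpr unsigned lg(unsigned w) { return w <= 1 ? 0 : 1 + lg(w / 2); }

// Upper bound from the formula above: S * (lg(W) + 1) instructions.
constexpr unsigned load_permute_bound(unsigned s, unsigned w) {
    return s * (lg(w) + 1);
}

// 3 x 8 (AVX): bound is 12, and the Intel sequence uses 11.
static_assert(load_permute_bound(3, 8) == 12, "3 x 8 bound");
// 3 x 16 (AVX512): bound is 15, matching 6 loads + 3 shuffles + 6 blends.
static_assert(load_permute_bound(3, 16) == 15, "3 x 16 bound");
```

So the 3 x 8 sequence comes in one instruction under the bound, while the 3 x 16 sequence meets it exactly.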
Here's the 3 x 16 load-transpose with AVX512 (6 loads, 3 shuffles, 6 blends):
FORCE_INLINE void transpose_f32_16x3_forward_AVX512(
const float T[48],
__m512& r0, __m512& r1, __m512& r2
){
__m512 a0, a1, a2;
// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
// 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
a0 = _mm512_castps256_ps512(_mm256_loadu_ps(T + 0));
a1 = _mm512_castps256_ps512(_mm256_loadu_ps(T + 8));
a2 = _mm512_castps256_ps512(_mm256_loadu_ps(T + 16));
a0 = _mm512_insertf32x8(a0, ((const __m256*)T)[3], 1);
a1 = _mm512_insertf32x8(a1, ((const __m256*)T)[4], 1);
a2 = _mm512_insertf32x8(a2, ((const __m256*)T)[5], 1);
// 0 1 2 3 4 5 6 7 24 25 26 27 28 29 30 31
// 8 9 10 11 12 13 14 15 32 33 34 35 36 37 38 39
// 16 17 18 19 20 21 22 23 40 41 42 43 44 45 46 47
r0 = _mm512_mask_blend_ps(0xf0f0, a0, a1);
r1 = _mm512_permutex2var_ps(a0, _mm512_setr_epi32( 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15, 24, 25, 26, 27), a2);
r2 = _mm512_mask_blend_ps(0xf0f0, a1, a2);
// 0 1 2 3 12 13 14 15 24 25 26 27 36 37 38 39
// 4 5 6 7 16 17 18 19 28 29 30 31 40 41 42 43
// 8 9 10 11 20 21 22 23 32 33 34 35 44 45 46 47
a0 = _mm512_mask_blend_ps(0xcccc, r0, r1);
a1 = _mm512_shuffle_ps(r0, r2, 78);
a2 = _mm512_mask_blend_ps(0xcccc, r1, r2);
// 0 1 6 7 12 13 18 19 24 25 30 31 36 37 42 43
// 2 3 8 9 14 15 20 21 26 27 32 33 38 39 44 45
// 4 5 10 11 16 17 22 23 28 29 34 35 40 41 46 47
r0 = _mm512_mask_blend_ps(0xaaaa, a0, a1);
r1 = _mm512_permutex2var_ps(a0, _mm512_setr_epi32( 1, 16, 3, 18, 5, 20, 7, 22, 9, 24, 11, 26, 13, 28, 15, 30), a2);
r2 = _mm512_mask_blend_ps(0xaaaa, a1, a2);
// 0 3 6 9 12 15 18 21 24 27 30 33 36 39 42 45
// 1 4 7 10 13 16 19 22 25 28 31 34 37 40 43 46
// 2 5 8 11 14 17 20 23 26 29 32 35 38 41 44 47
}
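To check the function above (or your own variant of it), a scalar reference is handy. The sketch below only restates what the final comment rows say: element i of output k is T[3*i + k]. The name `transpose_f32_16x3_forward_scalar` and the array-based outputs are assumptions for illustration.

```cpp
#include <cstddef>

// Hypothetical scalar reference for the 3 x 16 load-transpose:
// r0 gets elements 0,3,6,...; r1 gets 1,4,7,...; r2 gets 2,5,8,...
void transpose_f32_16x3_forward_scalar(
    const float T[48],
    float r0[16], float r1[16], float r2[16]
){
    for (std::size_t i = 0; i < 16; ++i) {
        r0[i] = T[3 * i + 0];
        r1[i] = T[3 * i + 1];
        r2[i] = T[3 * i + 2];
    }
}
```

Filling T with 0..47 and storing the AVX512 results with _mm512_storeu_ps() lets you compare the two directly against the layout shown in the comments.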
The inverse 3 x 16 transpose-store will be left as an exercise to the reader.
The pattern is not at all trivial to see since S = 3 is somewhat degenerate. But if you can see the pattern, you'll be able to generalize this to any odd integer S as well as any power-of-two W.