10#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
11#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
24template<
typename Dimensions,
typename LhsXprType,
typename RhsXprType,
typename OutputKernelType>
53template<
typename Dimensions,
typename LhsXprType,
typename RhsXprType,
typename OutputKernelType>
59template<
typename Dimensions,
typename LhsXprType,
typename RhsXprType,
typename OutputKernelType>
65template<
typename Indices_,
typename LeftArgType_,
typename RightArgType_,
typename OutputKernelType_,
typename Device_>
78template <
typename LhsScalar,
typename RhsScalar>
82 template <
typename Device>
90 BlockSizes
sz = ComputeLhsRhsBlockSizes(bm, bk, bn);
91 char*
block_mem =
static_cast<char*
>(d.allocate(
sz.lhs_size +
sz.rhs_size));
98 template <
typename Device>
108 BlockSizes
sz = ComputeLhsRhsBlockSizes(bm, bk, bn);
130 template <
typename Device>
145 sz.lhs_size = divup<Index>(bm * bk *
sizeof(LhsScalar), align) * align;
146 sz.rhs_size = divup<Index>(bn * bk *
sizeof(RhsScalar), align) * align;
179template <
typename ResScalar,
typename LhsScalar,
typename RhsScalar,
180 typename StorageIndex,
typename OutputMapper,
typename LhsMapper,
189 StorageIndex
bm_, StorageIndex
bk_, StorageIndex
bn_)
190 : m(m_), k(k_), n(n_), bm(
bm_), bk(
bk_), bn(
bn_) {}
204 LhsScalar, StorageIndex,
typename LhsMapper::SubMapper, Traits::mr,
205 Traits::LhsProgress,
typename Traits::LhsPacket4Packing,
ColMajor>
209 typename RhsMapper::SubMapper, Traits::nr,
218 template <
typename Device>
224 template <
typename Device>
226 Device& d,
const StorageIndex
num_lhs,
const StorageIndex
num_rhs,
229 return BlockMemAllocator::allocateSlices(
233 template <
typename Device>
235 BlockMemAllocator::deallocate(d,
handle);
240 const StorageIndex
depth,
const StorageIndex
rows) {
247 const StorageIndex
depth,
const StorageIndex
cols) {
254 const StorageIndex
depth,
const StorageIndex
cols,
255 const ResScalar
alpha,
const ResScalar beta) {
269 const StorageIndex m;
270 const StorageIndex k;
271 const StorageIndex n;
272 const StorageIndex bm;
273 const StorageIndex bk;
274 const StorageIndex bn;
310 template <
typename Index,
typename Scalar>
324template<
typename Indices,
typename LhsXprType,
typename RhsXprType,
typename OutputKernelType = const NoOpOutputKernel>
336 const LhsXprType& lhs,
const RhsXprType& rhs,
const Indices& dims,
337 const OutputKernelType& output_kernel = OutputKernelType())
364template<
typename Derived>
423 op.lhsExpression(), op.rhsExpression()), device),
425 op.rhsExpression(), op.lhsExpression()), device),
431 YOU_MADE_A_PROGRAMMING_MISTAKE);
447 eval_op_indices[
i].first = op.
indices()[
i].first;
448 eval_op_indices[
i].second = op.
indices()[
i].second;
470 eigen_assert(eval_op_indices[
j].first != eval_op_indices[
i].first &&
472 "contraction axes should be unique");
473 if (eval_op_indices[
j].first < eval_op_indices[
i].first) {
482 lhs_strides[
i+1] = lhs_strides[
i] * eval_left_dims[
i];
488 rhs_strides[
i+1] = rhs_strides[
i] * eval_right_dims[
i];
505 Index nocontract_idx = 0;
509 bool contracting =
false;
511 if (eval_op_indices[
j].first ==
i) {
536 bool contracting =
false;
539 if (eval_op_indices[
j].
second ==
i) {
566 Index left = eval_op_indices[
i].first;
567 Index right = eval_op_indices[
i].second;
571 "Contraction axes must be same size");
581 if (
i > 0 && right < eval_op_indices[
i-1].
second) {
618#ifdef EIGEN_USE_THREADS
619 template <
typename EvalSubExprsCallback>
622 m_leftImpl.evalSubExprsIfNeededAsync(
nullptr, [
this, done, dest](
bool) {
623 m_rightImpl.evalSubExprsIfNeededAsync(
nullptr, [
this, done, dest](
bool) {
625 evalToAsync(dest, [done]() { done(
false); });
629 evalToAsync(
m_result, [done]() { done(
true); });
636#ifndef TENSOR_CONTRACTION_DISPATCH
637#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \
638 if (this->m_lhs_inner_dim_contiguous) { \
639 if (this->m_rhs_inner_dim_contiguous) { \
640 if (this->m_rhs_inner_dim_reordered) { \
641 METHOD<true, true, true, ALIGNMENT> ARGS; \
643 METHOD<true, true, false, ALIGNMENT> ARGS; \
646 if (this->m_rhs_inner_dim_reordered) { \
647 METHOD<true, false, true, ALIGNMENT> ARGS; \
649 METHOD<true, false, false, ALIGNMENT> ARGS; \
653 if (this->m_rhs_inner_dim_contiguous) { \
654 if (this->m_rhs_inner_dim_reordered) { \
655 METHOD<false, true, true, ALIGNMENT> ARGS; \
657 METHOD<false, true, false, ALIGNMENT> ARGS; \
660 if (this->m_rhs_inner_dim_reordered) { \
661 METHOD<false, false, true, ALIGNMENT> ARGS; \
663 METHOD<false, false, false, ALIGNMENT> ARGS; \
669#ifndef TENSOR_CONTRACTION_ASYNC_DISPATCH
670#define TENSOR_CONTRACTION_ASYNC_DISPATCH(METHOD, DONE, ALIGNMENT, ARGS, FN) \
671 if (this->m_lhs_inner_dim_contiguous) { \
672 if (this->m_rhs_inner_dim_contiguous) { \
673 if (this->m_rhs_inner_dim_reordered) { \
674 (new METHOD<DONE, true, true, true, ALIGNMENT> ARGS)->FN; \
676 (new METHOD<DONE, true, true, false, ALIGNMENT> ARGS)->FN; \
679 if (this->m_rhs_inner_dim_reordered) { \
680 (new METHOD<DONE, true, false, true, ALIGNMENT> ARGS)->FN; \
682 (new METHOD<DONE, true, false, false, ALIGNMENT> ARGS)->FN; \
686 if (this->m_rhs_inner_dim_contiguous) { \
687 if (this->m_rhs_inner_dim_reordered) { \
688 (new METHOD<DONE, false, true, true, ALIGNMENT> ARGS)->FN; \
690 (new METHOD<DONE, false, true, false, ALIGNMENT> ARGS)->FN; \
693 if (this->m_rhs_inner_dim_reordered) { \
694 (new METHOD<DONE, false, false, true, ALIGNMENT> ARGS)->FN; \
696 (new METHOD<DONE, false, false, false, ALIGNMENT> ARGS)->FN; \
703 static_cast<const Derived*
>(
this)->
template evalProduct<Unaligned>(buffer);
706#ifdef EIGEN_USE_THREADS
707 template <
typename EvalToCallback>
708 void evalToAsync(
Scalar* buffer, EvalToCallback done)
const {
709 static_cast<const Derived*
>(
this)
710 ->
template evalProductAsync<EvalToCallback, Unaligned>(buffer,
715 template <
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
716 bool rhs_inner_dim_reordered,
int Alignment>
719 this->
template evalGemv<lhs_inner_dim_contiguous,
720 rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
723 this->
template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
724 rhs_inner_dim_reordered, Alignment>(buffer);
728 template <
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered,
int Alignment>
729 #if !defined(EIGEN_HIPCC)
747 lhs_inner_dim_contiguous,
748 false, lhs_alignment> LhsMapper;
753 rhs_inner_dim_contiguous,
754 rhs_inner_dim_reordered, rhs_alignment> RhsMapper;
762 const Index resIncr(1);
769 buffer, resIncr,
alpha);
774 static_cast<Index>(1));
777 template <
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered,
int Alignment>
778 #if !defined(EIGEN_HIPCC)
785 rhs_inner_dim_contiguous,
786 rhs_inner_dim_reordered,
787 Alignment,
true>(buffer, 0, k, 1);
790 template <
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
791 bool rhs_inner_dim_reordered,
int Alignment>
795 rhs_inner_dim_reordered, Alignment,
796 false>(buffer, k_start, k_end,
800 template <
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered,
int Alignment,
bool use_output_kernel>
804 const Index k_slice = k_end - k_start;
825 lhs_inner_dim_contiguous,
831 rhs_inner_dim_contiguous,
832 rhs_inner_dim_reordered,
Unaligned> RhsMapper;
837 Scalar, LhsScalar, RhsScalar,
Index, OutputMapper, LhsMapper, RhsMapper>
838 TensorContractionKernel;
847 OutputMapper output(buffer,
m);
852 blocking(k_slice,
m,
n, num_threads);
853 const Index kc = blocking.kc();
857 typedef typename TensorContractionKernel::LhsBlock LhsBlock;
858 typedef typename TensorContractionKernel::RhsBlock RhsBlock;
863 TensorContractionKernel kernel(
m, k_slice,
n, mc, kc, nc);
865 typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
866 const BlockMemHandle packed_mem =
867 kernel.allocate(this->
m_device, &blockA, &blockB);
871 if (!TensorContractionKernel::HasBeta) {
875 for(
Index i2=0; i2<
m; i2+=mc)
878 for (
Index k2 = k_start; k2 < k_end; k2 += kc) {
881 kernel.packLhs(&blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);
886 const Scalar beta = (TensorContractionKernel::HasBeta && k2 == k_start)
891 for (
Index j2 = 0; j2 <
n; j2 += nc) {
894 kernel.packRhs(&blockB, rhs.getSubMapper(k2, j2), actual_kc,
899 const OutputMapper output_mapper = output.getSubMapper(i2, j2);
900 kernel.invoke(output_mapper, blockA, blockB, actual_mc, actual_kc,
901 actual_nc,
alpha, beta);
904 if (use_output_kernel && k2 + kc >= k_end) {
906 actual_mc, actual_nc);
912 kernel.deallocate(this->
m_device, packed_mem);
933 template<
int LoadMode>
971template<
typename Indices,
typename LeftArgType,
typename RightArgType,
typename OutputKernelType,
typename Device>
974 TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> > {
997 static const int LDims =
999 static const int RDims =
1007 static const int NumDims = LDims + RDims - 2 * ContractDims;
1013 Base(op, device) { }
1015 template <
int Alignment>
Matrix3f m
Definition AngleAxis_mimic_euler.cpp:1
int n
Definition BiCGSTAB_simple.cpp:1
int i
Definition BiCGSTAB_step_by_step.cpp:9
#define EIGEN_ALWAYS_INLINE
Definition Macros.h:932
#define EIGEN_UNUSED_VARIABLE(var)
Definition Macros.h:1076
#define EIGEN_DEVICE_FUNC
Definition Macros.h:976
#define EIGEN_DONT_INLINE
Definition Macros.h:940
#define eigen_assert(x)
Definition Macros.h:1037
#define EIGEN_STRONG_INLINE
Definition Macros.h:917
#define EIGEN_STATIC_ASSERT(CONDITION, MSG)
Definition StaticAssert.h:127
#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS)
Definition TensorContraction.h:637
#define EIGEN_DEVICE_REF
Definition TensorMacros.h:50
int rows
Definition Tutorial_commainit_02.cpp:1
int cols
Definition Tutorial_commainit_02.cpp:1
Scalar Scalar int size
Definition benchVecAdd.cpp:17
SCALAR Scalar
Definition bench_gemm.cpp:46
The tensor base class.
Definition TensorBase.h:973
Definition TensorContraction.h:326
EIGEN_DEVICE_FUNC const OutputKernelType & outputKernel() const
Definition TensorContraction.h:354
const OutputKernelType m_output_kernel
Definition TensorContraction.h:360
Eigen::internal::traits< TensorContractionOp >::Index Index
Definition TensorContraction.h:333
Eigen::internal::nested< TensorContractionOp >::type Nested
Definition TensorContraction.h:331
EIGEN_DEVICE_FUNC const internal::remove_all< typename RhsXprType::Nested >::type & rhsExpression() const
Definition TensorContraction.h:351
Eigen::internal::traits< TensorContractionOp >::StorageKind StorageKind
Definition TensorContraction.h:332
EIGEN_DEVICE_FUNC const internal::remove_all< typename LhsXprType::Nested >::type & lhsExpression() const
Definition TensorContraction.h:347
EIGEN_DEVICE_FUNC const Indices & indices() const
Definition TensorContraction.h:342
Eigen::internal::traits< TensorContractionOp >::Scalar Scalar
Definition TensorContraction.h:328
const Indices m_indices
Definition TensorContraction.h:359
LhsXprType::Nested m_lhs_xpr
Definition TensorContraction.h:357
RhsXprType::Nested m_rhs_xpr
Definition TensorContraction.h:358
internal::gebp_traits< typename LhsXprType::CoeffReturnType, typename RhsXprType::CoeffReturnType >::ResScalar CoeffReturnType
Definition TensorContraction.h:330
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(const LhsXprType &lhs, const RhsXprType &rhs, const Indices &dims, const OutputKernelType &output_kernel=OutputKernelType())
Definition TensorContraction.h:335
Definition TensorCostModel.h:25
Definition EmulateArray.h:21
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE std::size_t size()
Definition EmulateArray.h:44
Definition TensorBlock.h:617
Definition TensorContractionBlocking.h:25
Definition TensorRef.h:81
Definition GeneralBlockPanelKernel.h:419
Definition XprHelper.h:110
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
Definition gnuplot_common_settings.hh:12
@ Unaligned
Definition Constants.h:233
@ Aligned
Definition Constants.h:240
@ ColMajor
Definition Constants.h:319
@ RowMajor
Definition Constants.h:321
RealScalar alpha
Definition level1_cplx_impl.h:147
DenseIndex ret
Definition level1_cplx_impl.h:44
@ ShardByCol
Definition TensorContractionBlocking.h:19
@ Lhs
Definition TensorContractionMapper.h:19
@ Rhs
Definition TensorContractionMapper.h:18
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T maxi(const T &x, const T &y)
Definition MathFunctions.h:1091
EIGEN_STRONG_INLINE void swap(T &a, T &b)
Definition Meta.h:766
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T mini(const T &x, const T &y)
Definition MathFunctions.h:1083
Namespace containing all symbols from the Eigen library.
Definition bench_norm.cpp:85
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const T1 & choose(Cond< true >, const T1 &first, const T2 &)
Definition TensorMeta.h:18
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:74
Definition BandTriangularSolver.h:13
real function second()
SECOND returns nothing
Definition second_NONE.f:39
Definition TensorMeta.h:15
Definition TensorDimensions.h:263
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex TotalSize() const
Definition TensorDimensions.h:271
Definition Constants.h:507
Definition TensorContraction.h:294
EIGEN_ALWAYS_INLINE void operator()(const internal::blas_data_mapper< Scalar, Index, ColMajor > &output_mapper, const TensorContractionParams &params, Index i, Index j, Index num_rows, Index num_cols) const
Definition TensorContraction.h:311
Definition TensorMeta.h:50
Definition TensorForwardDeclarations.h:37
Definition TensorContraction.h:366
XprType::CoeffReturnType CoeffReturnType
Definition TensorContraction.h:376
static const int LDims
Definition TensorContraction.h:407
TensorEvaluator< EvalRightArgType, Device > RightEvaluatorType
Definition TensorContraction.h:405
EIGEN_STRONG_INLINE void cleanup()
Definition TensorContraction.h:915
Index m_i_size
Definition TensorContraction.h:956
DSizes< Index, NumDims > Dimensions
Definition TensorContraction.h:418
EIGEN_STRONG_INLINE TensorContractionEvaluatorBase(const XprType &op, const Device &device)
Definition TensorContraction.h:421
StorageMemory< Scalar, Device > Storage
Definition TensorContraction.h:378
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
Definition TensorContraction.h:934
right_nocontract_t m_j_strides
Definition TensorContraction.h:952
internal::traits< Derived >::Device Device
Definition TensorContraction.h:371
right_nocontract_t m_right_nocontract_strides
Definition TensorContraction.h:954
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions & dimensions() const
Definition TensorContraction.h:603
PacketType< CoeffReturnType, Device >::type PacketReturnType
Definition TensorContraction.h:377
internal::traits< Derived >::LeftArgType LeftArgType
Definition TensorContraction.h:368
contract_t m_right_contracting_strides
Definition TensorContraction.h:945
const Device EIGEN_DEVICE_REF m_device
Definition TensorContraction.h:964
array< Index, RDims - ContractDims > right_nocontract_t
Definition TensorContraction.h:416
internal::conditional< static_cast< int >(Layout)==static_cast< int >(ColMajor), LeftArgType, RightArgType >::type EvalLeftArgType
Definition TensorContraction.h:400
TensorEvaluator< EvalLeftArgType, Device > LeftEvaluatorType
Definition TensorContraction.h:404
left_nocontract_t m_left_nocontract_strides
Definition TensorContraction.h:953
EvaluatorPointerType m_result
Definition TensorContraction.h:966
EIGEN_DEVICE_FUNC void evalGemv(Scalar *buffer) const
Definition TensorContraction.h:732
XprType::Index Index
Definition TensorContraction.h:375
Storage::Type EvaluatorPointerType
Definition TensorContraction.h:379
bool m_rhs_inner_dim_reordered
Definition TensorContraction.h:949
internal::traits< Derived >::RightArgType RightArgType
Definition TensorContraction.h:369
bool m_rhs_inner_dim_contiguous
Definition TensorContraction.h:948
EIGEN_DEVICE_FUNC void evalGemmPartial(Scalar *buffer, Index k_start, Index k_end, int num_threads) const
Definition TensorContraction.h:801
contract_t m_left_contracting_strides
Definition TensorContraction.h:944
EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data)
Definition TensorContraction.h:605
TensorContractionOp< Indices, LeftArgType, RightArgType, OutputKernelType > XprType
Definition TensorContraction.h:373
EIGEN_DEVICE_FUNC void evalTo(Scalar *buffer) const
Definition TensorContraction.h:702
Index m_j_size
Definition TensorContraction.h:957
TensorEvaluator< EvalRightArgType, Device > m_rightImpl
Definition TensorContraction.h:963
internal::conditional< static_cast< int >(Layout)==static_cast< int >(ColMajor), RightArgType, LeftArgType >::type EvalRightArgType
Definition TensorContraction.h:402
static const int NumDims
Definition TensorContraction.h:412
internal::remove_const< typename XprType::Scalar >::type Scalar
Definition TensorContraction.h:374
Index m_k_size
Definition TensorContraction.h:958
@ PreferBlockAccess
Definition TensorContraction.h:385
@ PacketAccess
Definition TensorContraction.h:383
@ RawAccess
Definition TensorContraction.h:388
@ CoordAccess
Definition TensorContraction.h:387
@ IsAligned
Definition TensorContraction.h:382
@ BlockAccess
Definition TensorContraction.h:384
@ Layout
Definition TensorContraction.h:386
internal::TensorBlockNotImplemented TensorBlock
Definition TensorContraction.h:392
void evalProductSequential(Scalar *buffer) const
Definition TensorContraction.h:717
OutputKernelType m_output_kernel
Definition TensorContraction.h:965
TensorContractionParams m_tensor_contraction_params
Definition TensorContraction.h:960
bool m_lhs_inner_dim_contiguous
Definition TensorContraction.h:947
Dimensions m_dimensions
Definition TensorContraction.h:941
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool) const
Definition TensorContraction.h:929
array< Index, ContractDims > contract_t
Definition TensorContraction.h:414
internal::traits< Derived >::OutputKernelType OutputKernelType
Definition TensorContraction.h:370
TensorEvaluator< EvalLeftArgType, Device > m_leftImpl
Definition TensorContraction.h:962
array< Index, LDims - ContractDims > left_nocontract_t
Definition TensorContraction.h:415
EIGEN_DEVICE_FUNC void evalGemm(Scalar *buffer) const
Definition TensorContraction.h:781
contract_t m_k_strides
Definition TensorContraction.h:943
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EvaluatorPointerType data() const
Definition TensorContraction.h:938
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
Definition TensorContraction.h:925
static const int RDims
Definition TensorContraction.h:409
EIGEN_DEVICE_FUNC void evalGemmPartialWithoutOutputKernel(Scalar *buffer, Index k_start, Index k_end, int num_threads) const
Definition TensorContraction.h:792
static const int ContractDims
Definition TensorContraction.h:411
internal::traits< Derived >::Indices Indices
Definition TensorContraction.h:367
left_nocontract_t m_i_strides
Definition TensorContraction.h:951
Definition TensorContraction.h:281
bool swapped_arguments
Definition TensorContraction.h:284
internal::conditional< static_cast< int >(Layout)==static_cast< int >(ColMajor), LeftArgType, RightArgType >::type EvalLeftArgType
Definition TensorContraction.h:993
internal::conditional< static_cast< int >(Layout)==static_cast< int >(ColMajor), RightArgType, LeftArgType >::type EvalRightArgType
Definition TensorContraction.h:995
TensorContractionOp< Indices, LeftArgType, RightArgType, OutputKernelType > XprType
Definition TensorContraction.h:978
XprType::Index Index
Definition TensorContraction.h:980
PacketType< CoeffReturnType, Device >::type PacketReturnType
Definition TensorContraction.h:982
TensorEvaluator< const TensorContractionOp< Indices, LeftArgType, RightArgType, OutputKernelType >, Device > Self
Definition TensorContraction.h:975
DSizes< Index, NumDims > Dimensions
Definition TensorContraction.h:1010
array< Index, LDims - ContractDims > left_nocontract_t
Definition TensorContraction.h:1004
TensorEvaluator(const XprType &op, const Device &device)
Definition TensorContraction.h:1012
XprType::CoeffReturnType CoeffReturnType
Definition TensorContraction.h:981
TensorContractionEvaluatorBase< Self > Base
Definition TensorContraction.h:976
internal::remove_const< typename XprType::Scalar >::type Scalar
Definition TensorContraction.h:979
void evalProduct(Scalar *buffer) const
Definition TensorContraction.h:1016
array< Index, ContractDims > contract_t
Definition TensorContraction.h:1003
array< Index, RDims - ContractDims > right_nocontract_t
Definition TensorContraction.h:1005
A cost model used to limit the number of threads used for evaluating tensor expression.
Definition TensorEvaluator.h:29
EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType dest)
Definition TensorEvaluator.h:75
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions & dimensions() const
Definition TensorEvaluator.h:73
EIGEN_STRONG_INLINE void cleanup()
Definition TensorEvaluator.h:92
@ Layout
Definition TensorEvaluator.h:50
Derived::Index Index
Definition TensorEvaluator.h:30
Definition TensorContraction.h:79
static EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device &d, const Index bm, const Index bk, const Index bn, LhsScalar **lhs_block, RhsScalar **rhs_block)
Definition TensorContraction.h:83
static EIGEN_DEVICE_FUNC void deallocate(Device &d, BlockMemHandle handle)
Definition TensorContraction.h:131
void * BlockMemHandle
Definition TensorContraction.h:80
static EIGEN_DEVICE_FUNC BlockMemHandle allocateSlices(Device &d, const Index bm, const Index bk, const Index bn, const Index num_lhs, const Index num_rhs, const Index num_slices, std::vector< LhsScalar * > *lhs_blocks, std::vector< RhsScalar * > *rhs_blocks)
Definition TensorContraction.h:99
Definition TensorContraction.h:182
LhsScalar * LhsBlock
Definition TensorContraction.h:193
RhsScalar * RhsBlock
Definition TensorContraction.h:194
EIGEN_DEVICE_FUNC BlockMemHandle allocateSlices(Device &d, const StorageIndex num_lhs, const StorageIndex num_rhs, const StorageIndex num_slices, std::vector< LhsBlock > *lhs_blocks, std::vector< RhsBlock > *rhs_blocks)
Definition TensorContraction.h:225
EIGEN_DEVICE_FUNC TensorContractionKernel(StorageIndex m_, StorageIndex k_, StorageIndex n_, StorageIndex bm_, StorageIndex bk_, StorageIndex bn_)
Definition TensorContraction.h:188
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packLhs(LhsBlock *lhsBlock, const typename LhsMapper::SubMapper &data_mapper, const StorageIndex depth, const StorageIndex rows)
Definition TensorContraction.h:238
internal::gebp_kernel< LhsScalar, RhsScalar, StorageIndex, OutputMapper, Traits::mr, Traits::nr, false, false > GebpKernel
Definition TensorContraction.h:216
internal::gemm_pack_lhs< LhsScalar, StorageIndex, typename LhsMapper::SubMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor > LhsPacker
Definition TensorContraction.h:206
@ HasBeta
Definition TensorContraction.h:185
internal::gemm_pack_rhs< RhsScalar, StorageIndex, typename RhsMapper::SubMapper, Traits::nr, ColMajor > RhsPacker
Definition TensorContraction.h:211
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void invoke(const OutputMapper &output_mapper, const LhsBlock &lhsBlock, const RhsBlock &rhsBlock, const StorageIndex rows, const StorageIndex depth, const StorageIndex cols, const ResScalar alpha, const ResScalar beta)
Definition TensorContraction.h:251
static EIGEN_DEVICE_FUNC void deallocate(Device &d, BlockMemHandle handle)
Definition TensorContraction.h:234
BlockMemAllocator::BlockMemHandle BlockMemHandle
Definition TensorContraction.h:199
internal::gebp_traits< LhsScalar, RhsScalar > Traits
Definition TensorContraction.h:201
EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device &d, LhsBlock *lhs_block, RhsBlock *rhs_block)
Definition TensorContraction.h:219
TensorContractionBlockMemAllocator< LhsScalar, RhsScalar > BlockMemAllocator
Definition TensorContraction.h:198
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packRhs(RhsBlock *rhsBlock, const typename RhsMapper::SubMapper &data_mapper, const StorageIndex depth, const StorageIndex cols)
Definition TensorContraction.h:245
const TensorContractionOp< Dimensions, LhsXprType, RhsXprType, OutputKernelType > & type
Definition TensorContraction.h:56
Definition XprHelper.h:332
Definition GeneralBlockPanelKernel.h:1058
TensorContractionOp< Dimensions, LhsXprType, RhsXprType, OutputKernelType > type
Definition TensorContraction.h:62
Definition TensorTraits.h:175
promote_storage_type< typename traits< LhsXprType >::StorageKind, typename traits< RhsXprType >::StorageKind >::ret StorageKind
Definition TensorContraction.h:32
remove_reference< LhsNested >::type _LhsNested
Definition TensorContraction.h:37
RhsXprType::Nested RhsNested
Definition TensorContraction.h:36
gebp_traits< typename remove_const< typename LhsXprType::Scalar >::type, typename remove_const< typename RhsXprType::Scalar >::type >::ResScalar Scalar
Definition TensorContraction.h:29
remove_reference< RhsNested >::type _RhsNested
Definition TensorContraction.h:38
LhsXprType::Nested LhsNested
Definition TensorContraction.h:35
conditional< Pointer_type_promotion< typename LhsXprType::Scalar, Scalar >::val, typename traits< LhsXprType >::PointerType, typename traits< RhsXprType >::PointerType >::type PointerType
Definition TensorContraction.h:46
promote_index_type< typename traits< LhsXprType >::Index, typename traits< RhsXprType >::Index >::type Index
Definition TensorContraction.h:34
Device_ Device
Definition TensorContraction.h:71
RightArgType_ RightArgType
Definition TensorContraction.h:69
LeftArgType_ LeftArgType
Definition TensorContraction.h:68
Indices_ Indices
Definition TensorContraction.h:67
OutputKernelType_ OutputKernelType
Definition TensorContraction.h:70
Definition ForwardDeclarations.h:17
Definition GenericPacketMath.h:133
std::ptrdiff_t j
Definition tut_arithmetic_redux_minmax.cpp:2