GetFEM  5.5
gmm_blas_interface.h
Go to the documentation of this file.
1 /* -*- c++ -*- (enables emacs c++ mode) */
2 /*===========================================================================
3 
4  Copyright (C) 2003-2026 Yves Renard
5 
6  This file is a part of GetFEM
7 
8  GetFEM is free software; you can redistribute it and/or modify it
9  under the terms of the GNU Lesser General Public License as published
10  by the Free Software Foundation; either version 3 of the License, or
11  (at your option) any later version along with the GCC Runtime Library
12  Exception either version 3.1 or (at your option) any later version.
13  This program is distributed in the hope that it will be useful, but
14  WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
15  or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
16  License and GCC Runtime Library Exception for more details.
17  You should have received a copy of the GNU Lesser General Public License
18  along with this program. If not, see https://www.gnu.org/licenses/.
19 
20  As a special exception, you may use this file as it is a part of a free
21  software library without restriction. Specifically, if other files
22  instantiate templates or use macros or inline functions from this file,
23  or you compile this file and link it with other files to produce an
24  executable, this file does not by itself cause the resulting executable
25  to be covered by the GNU Lesser General Public License. This exception
26  does not however invalidate any other reasons why the executable file
27  might be covered by the GNU Lesser General Public License.
28 
29 ===========================================================================*/
30 
31 /**@file gmm_blas_interface.h
32  @author Yves Renard <Yves.Renard@insa-lyon.fr>
33  @date October 7, 2003.
34  @brief gmm interface for fortran BLAS.
35 */
36 
37 #if defined(GMM_USES_BLAS) || defined(GMM_USES_LAPACK)
38 
39 #ifndef GMM_BLAS_INTERFACE_H
40 #define GMM_BLAS_INTERFACE_H
41 
42 #include "gmm_blas.h"
43 #include "gmm_interface.h"
44 #include "gmm_matrix.h"
45 
46 namespace gmm {
47 
// Use ./configure --enable-blas-interface to activate this interface.

// Trace hook placed at the entry of every BLAS wrapper in this file.  It
// expands to nothing; swap in the commented definition to log each call.
#define GMMLAPACK_TRACE(f)
// #define GMMLAPACK_TRACE(f) cout << "function " << f << " called" << endl;

// Integer type whose address is passed for every size, increment and
// leading-dimension argument of the BLAS routines.  A plain C int unless
// WeirdNEC or GMM_USE_BLAS64_INTERFACE is defined, in which case long is used
// (presumably for BLAS builds with 64-bit INTEGERs -- confirm against the
// BLAS library actually linked).
#if !defined(WeirdNEC) && !defined(GMM_USE_BLAS64_INTERFACE)
# define BLAS_INT int
#else
# define BLAS_INT long
#endif
58 
59  /* ********************************************************************* */
60  /* Operations interfaced for T = float, double, std::complex<float> */
61  /* or std::complex<double> : */
62  /* */
63  /* vect_norm2(std::vector<T>) */
64  /* */
65  /* vect_sp(std::vector<T>, std::vector<T>) */
66  /* vect_sp(scaled(std::vector<T>), std::vector<T>) */
67  /* vect_sp(std::vector<T>, scaled(std::vector<T>)) */
68  /* vect_sp(scaled(std::vector<T>), scaled(std::vector<T>)) */
69  /* */
70  /* vect_hp(std::vector<T>, std::vector<T>) */
71  /* vect_hp(scaled(std::vector<T>), std::vector<T>) */
72  /* vect_hp(std::vector<T>, scaled(std::vector<T>)) */
73  /* vect_hp(scaled(std::vector<T>), scaled(std::vector<T>)) */
74  /* */
75  /* add(std::vector<T>, std::vector<T>) */
76  /* add(scaled(std::vector<T>, a), std::vector<T>) */
77  /* */
78  /* mult(dense_matrix<T>, dense_matrix<T>, dense_matrix<T>) */
79  /* mult(transposed(dense_matrix<T>), dense_matrix<T>, dense_matrix<T>) */
80  /* mult(dense_matrix<T>, transposed(dense_matrix<T>), dense_matrix<T>) */
81  /* mult(transposed(dense_matrix<T>), transposed(dense_matrix<T>), */
82  /* dense_matrix<T>) */
83  /* mult(conjugated(dense_matrix<T>), dense_matrix<T>, dense_matrix<T>) */
84  /* mult(dense_matrix<T>, conjugated(dense_matrix<T>), dense_matrix<T>) */
85  /* mult(conjugated(dense_matrix<T>), conjugated(dense_matrix<T>), */
86  /* dense_matrix<T>) */
87  /* */
88  /* mult(dense_matrix<T>, std::vector<T>, std::vector<T>) */
89  /* mult(transposed(dense_matrix<T>), std::vector<T>, std::vector<T>) */
90  /* mult(conjugated(dense_matrix<T>), std::vector<T>, std::vector<T>) */
91  /* mult(dense_matrix<T>, scaled(std::vector<T>), std::vector<T>) */
92  /* mult(transposed(dense_matrix<T>), scaled(std::vector<T>), */
93  /* std::vector<T>) */
94  /* mult(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
95  /* std::vector<T>) */
96  /* */
97  /* mult_add(dense_matrix<T>, std::vector<T>, std::vector<T>) */
98  /* mult_add(transposed(dense_matrix<T>), std::vector<T>, std::vector<T>) */
99  /* mult_add(conjugated(dense_matrix<T>), std::vector<T>, std::vector<T>) */
100  /* mult_add(dense_matrix<T>, scaled(std::vector<T>), std::vector<T>) */
101  /* mult_add(transposed(dense_matrix<T>), scaled(std::vector<T>), */
102  /* std::vector<T>) */
103  /* mult_add(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
104  /* std::vector<T>) */
105  /* */
106  /* mult(dense_matrix<T>, std::vector<T>, std::vector<T>, std::vector<T>) */
107  /* mult(transposed(dense_matrix<T>), std::vector<T>, std::vector<T>, */
108  /* std::vector<T>) */
109  /* mult(conjugated(dense_matrix<T>), std::vector<T>, std::vector<T>, */
110  /* std::vector<T>) */
111  /* mult(dense_matrix<T>, scaled(std::vector<T>), std::vector<T>, */
112  /* std::vector<T>) */
113  /* mult(transposed(dense_matrix<T>), scaled(std::vector<T>), */
114  /* std::vector<T>, std::vector<T>) */
115  /* mult(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
116  /* std::vector<T>, std::vector<T>) */
117  /* mult(dense_matrix<T>, std::vector<T>, scaled(std::vector<T>), */
118  /* std::vector<T>) */
119  /* mult(transposed(dense_matrix<T>), std::vector<T>, */
120  /* scaled(std::vector<T>), std::vector<T>) */
121  /* mult(conjugated(dense_matrix<T>), std::vector<T>, */
122  /* scaled(std::vector<T>), std::vector<T>) */
123  /* mult(dense_matrix<T>, scaled(std::vector<T>), scaled(std::vector<T>), */
124  /* std::vector<T>) */
125  /* mult(transposed(dense_matrix<T>), scaled(std::vector<T>), */
126  /* scaled(std::vector<T>), std::vector<T>) */
127  /* mult(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
128  /* scaled(std::vector<T>), std::vector<T>) */
129  /* */
130  /* lower_tri_solve(dense_matrix<T>, std::vector<T>, k, b) */
131  /* upper_tri_solve(dense_matrix<T>, std::vector<T>, k, b) */
132  /* lower_tri_solve(transposed(dense_matrix<T>), std::vector<T>, k, b) */
133  /* upper_tri_solve(transposed(dense_matrix<T>), std::vector<T>, k, b) */
134  /* lower_tri_solve(conjugated(dense_matrix<T>), std::vector<T>, k, b) */
135  /* upper_tri_solve(conjugated(dense_matrix<T>), std::vector<T>, k, b) */
136  /* */
137  /* rank_one_update(dense_matrix<T>, std::vector<T>, std::vector<T>) */
138  /* rank_one_update(dense_matrix<T>, scaled(std::vector<T>), */
139  /* std::vector<T>) */
140  /* rank_one_update(dense_matrix<T>, std::vector<T>, */
141  /* scaled(std::vector<T>)) */
142  /* */
143  /* ********************************************************************* */
144 
/* ********************************************************************* */
/*   Basic defines.                                                      */
/* ********************************************************************* */

// Element types handled by the interface, named after the BLAS routine
// prefixes s, d, c and z.
# define BLAS_S float
# define BLAS_D double
# define BLAS_C std::complex<float>
# define BLAS_Z std::complex<double>

// Two-field {real, imaginary} layout of a single / double precision complex
// value; used below as the by-value return type of the complex dot products.
struct FORTRAN_BLAS_C { float r, i; };
struct FORTRAN_BLAS_Z { double r, i; };
155 
// Hack due to BLAS ABI mess: depending on how the BLAS library was built,
// the complex dot products either return their result by value as a
// {r, i} struct of type `ftype` (used when GMM_BLAS_RETURN_COMPLEX_AS_ARGUMENT
// is not defined) or write it through an extra leading pointer argument
// (GMM_BLAS_RETURN_COMPLEX_AS_ARGUMENT defined).
// BLAS_CPLX_FUNC_CALL(blasname, ftype, res, args...) calls `blasname` with
// `args...` and stores the complex result in `res` in both cases.
// Each expansion is a complete braced statement, so a call site is valid
// with or without a trailing ';', and the temporary `_res` does not leak
// into (or clash within) the enclosing scope.
#if defined(GMM_BLAS_RETURN_COMPLEX_AS_ARGUMENT)
# define BLAS_CPLX_FUNC_CALL(blasname, ftype, res, ...) \
  { blasname(&res, __VA_ARGS__); }
#else
# define BLAS_CPLX_FUNC_CALL(blasname, ftype, res, ...) \
  { ftype _res = blasname(__VA_ARGS__); res = decltype(res){_res.r, _res.i}; }
#endif
164 
165  /* ********************************************************************* */
166  /* BLAS functions used. */
167  /* ********************************************************************* */
168  extern "C" {
169  void daxpy_(const BLAS_INT *n, const double *alpha, const double *x,
170  const BLAS_INT *incx, double *y, const BLAS_INT *incy);
171  void saxpy_(...); /*void daxpy_(...);*/ void caxpy_(...); void zaxpy_(...);
172  void dgemm_(const char *tA, const char *tB, const BLAS_INT *m,
173  const BLAS_INT *n, const BLAS_INT *k, const BLAS_D *alpha,
174  const BLAS_D *A, const BLAS_INT *ldA, const BLAS_D *B,
175  const BLAS_INT *ldB, const BLAS_D *beta, BLAS_D *C,
176  const BLAS_INT *ldC);
177  void sgemm_(...); /*void dgemm_(...);*/ void cgemm_(...); void zgemm_(...);
178  void sgemv_(...); void dgemv_(...); void cgemv_(...); void zgemv_(...);
179  void strsv_(...); void dtrsv_(...); void ctrsv_(...); void ztrsv_(...);
180  BLAS_S sdot_ (...); BLAS_D ddot_ (...);
181 #if defined(GMM_BLAS_RETURN_COMPLEX_AS_ARGUMENT)
182  void cdotu_(...); void zdotu_(...); void cdotc_(...); void zdotc_(...);
183 #else
184  FORTRAN_BLAS_C cdotu_(...); FORTRAN_BLAS_Z zdotu_(...);
185  // Hermitian product in {c,z}dotc is defined in reverse order than usually
186  FORTRAN_BLAS_C cdotc_(...); FORTRAN_BLAS_Z zdotc_(...);
187 #endif
188  BLAS_S snrm2_(...); BLAS_D dnrm2_(...);
189  BLAS_S scnrm2_(...); BLAS_D dznrm2_(...);
190  void sger_(...); void dger_(...); void cgerc_(...); void zgerc_(...);
191  }
192 
193 
194  /* ********************************************************************* */
195  /* vect_norm2(x). */
196  /* ********************************************************************* */
197 
198 # define nrm2_interface(blas_name, base_type) \
199  inline number_traits<base_type>::magnitude_type \
200  vect_norm2(const std::vector<base_type> &x) { \
201  GMMLAPACK_TRACE("nrm2_interface"); \
202  const BLAS_INT n=BLAS_INT(vect_size(x)), inc(1); \
203  return blas_name(&n, &x[0], &inc); \
204  }
205 
206  nrm2_interface(snrm2_, BLAS_S)
207  nrm2_interface(dnrm2_, BLAS_D)
208  nrm2_interface(scnrm2_, BLAS_C)
209  nrm2_interface(dznrm2_, BLAS_Z)
210 
211  /* ********************************************************************* */
212  /* vect_sp(x,y) = vect_hp(x,y) for real vectors */
213  /* ********************************************************************* */
214 
215 # define dot_interface(funcname, msg, blas_name, base_type) \
216  inline base_type funcname(const std::vector<base_type> &x, \
217  const std::vector<base_type> &y) { \
218  GMMLAPACK_TRACE(msg); \
219  const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
220  return blas_name(&n, &x[0], &inc, &y[0], &inc); \
221  } \
222  inline base_type funcname \
223  (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
224  const std::vector<base_type> &y) { \
225  GMMLAPACK_TRACE(msg); \
226  const std::vector<base_type> &x = *(linalg_origin(x_)); \
227  base_type a(x_.r); \
228  const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
229  return a * blas_name(&n, &x[0], &inc, &y[0], &inc); \
230  } \
231  inline base_type funcname \
232  (const std::vector<base_type> &x, \
233  const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
234  GMMLAPACK_TRACE(msg); \
235  const std::vector<base_type> &y = *(linalg_origin(y_)); \
236  base_type b(y_.r); \
237  const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
238  return b * blas_name(&n, &x[0], &inc, &y[0], &inc); \
239  } \
240  inline base_type funcname \
241  (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
242  const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
243  GMMLAPACK_TRACE(msg); \
244  const std::vector<base_type> &x = *(linalg_origin(x_)); \
245  const std::vector<base_type> &y = *(linalg_origin(y_)); \
246  base_type a(x_.r), b(y_.r); \
247  const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
248  return a*b * blas_name(&n, &x[0], &inc, &y[0], &inc); \
249  }
250 
251  dot_interface(vect_sp, "dot_interface", sdot_, BLAS_S)
252  dot_interface(vect_sp, "dot_interface", ddot_, BLAS_D)
253  dot_interface(vect_hp, "dotc_interface", sdot_, BLAS_S)
254  dot_interface(vect_hp, "dotc_interface", ddot_, BLAS_D)
255 
256  /* ********************************************************************* */
257  /* vect_sp(x,y) and vect_hp(x,y) for complex vectors */
258  /* vect_hp(x, y) = x.conj(y) (different order than in BLAS) */
259  /* switching x,y before passed to BLAS is important only for vect_hp */
260  /* ********************************************************************* */
261 
262 # define dot_interface_cplx(funcname, msg, blas_name, base_type, ftype, b) \
263  inline base_type funcname(const std::vector<base_type> &x, \
264  const std::vector<base_type> &y) { \
265  GMMLAPACK_TRACE(msg); \
266  const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); base_type res; \
267  BLAS_CPLX_FUNC_CALL(blas_name, ftype, res, \
268  &n, &y[0], &inc, &x[0], &inc) \
269  return res; \
270  } \
271  inline base_type funcname \
272  (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
273  const std::vector<base_type> &y) { \
274  GMMLAPACK_TRACE(msg); \
275  const std::vector<base_type> &x = *(linalg_origin(x_)); \
276  const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); base_type res; \
277  BLAS_CPLX_FUNC_CALL(blas_name, ftype, res, \
278  &n, &y[0], &inc, &x[0], &inc) \
279  return (x_.r)*res; \
280  } \
281  inline base_type funcname \
282  (const std::vector<base_type> &x, \
283  const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
284  GMMLAPACK_TRACE(msg); \
285  const std::vector<base_type> &y = *(linalg_origin(y_)); \
286  const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); base_type res; \
287  BLAS_CPLX_FUNC_CALL(blas_name, ftype, res, \
288  &n, &y[0], &inc, &x[0], &inc) \
289  return (b)*res; \
290  } \
291  inline base_type funcname \
292  (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
293  const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
294  GMMLAPACK_TRACE(msg); \
295  const std::vector<base_type> &x = *(linalg_origin(x_)); \
296  const std::vector<base_type> &y = *(linalg_origin(y_)); \
297  const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); base_type res; \
298  BLAS_CPLX_FUNC_CALL(blas_name, ftype, res, \
299  &n, &y[0], &inc, &x[0], &inc) \
300  return (x_.r)*(b)*res; \
301  }
302 
303  dot_interface_cplx(vect_sp, "dot_interface", cdotu_,
304  BLAS_C, FORTRAN_BLAS_C, y_.r)
305  dot_interface_cplx(vect_sp, "dot_interface", zdotu_,
306  BLAS_Z, FORTRAN_BLAS_Z, y_.r)
307  dot_interface_cplx(vect_hp, "dotc_interface", cdotc_,
308  BLAS_C, FORTRAN_BLAS_C, gmm::conj(y_.r))
309  dot_interface_cplx(vect_hp, "dotc_interface", zdotc_,
310  BLAS_Z, FORTRAN_BLAS_Z, gmm::conj(y_.r))
311 
312 
313  /* ********************************************************************* */
314  /* add(x, y). */
315  /* ********************************************************************* */
316  template<size_type N, class V1, class V2>
317  inline void add_fixed(const V1 &x, V2 &y)
318  {
319  for(size_type i = 0; i != N; ++i) y[i] += x[i];
320  }
321 
322  template<class V1, class V2>
323  inline void add_for_short_vectors(const V1 &x, V2 &y, size_type n)
324  {
325  switch(n)
326  {
327  case 1: add_fixed<1>(x, y); break;
328  case 2: add_fixed<2>(x, y); break;
329  case 3: add_fixed<3>(x, y); break;
330  case 4: add_fixed<4>(x, y); break;
331  case 5: add_fixed<5>(x, y); break;
332  case 6: add_fixed<6>(x, y); break;
333  case 7: add_fixed<7>(x, y); break;
334  case 8: add_fixed<8>(x, y); break;
335  case 9: add_fixed<9>(x, y); break;
336  case 10: add_fixed<10>(x, y); break;
337  case 11: add_fixed<11>(x, y); break;
338  case 12: add_fixed<12>(x, y); break;
339  case 13: add_fixed<13>(x, y); break;
340  case 14: add_fixed<14>(x, y); break;
341  case 15: add_fixed<15>(x, y); break;
342  case 16: add_fixed<16>(x, y); break;
343  case 17: add_fixed<17>(x, y); break;
344  case 18: add_fixed<18>(x, y); break;
345  case 19: add_fixed<19>(x, y); break;
346  case 20: add_fixed<20>(x, y); break;
347  case 21: add_fixed<21>(x, y); break;
348  case 22: add_fixed<22>(x, y); break;
349  case 23: add_fixed<23>(x, y); break;
350  case 24: add_fixed<24>(x, y); break;
351  default:
352  GMM_ASSERT2(false, "add_for_short_vectors used with unsupported size");
353  break;
354  }
355  }
356 
357  template<size_type N, class V1, class V2, class T>
358  inline void add_fixed(const V1 &x, V2 &y, const T &a)
359  {
360  for(size_type i = 0; i != N; ++i) y[i] += a*x[i];
361  }
362 
363  template<class V1, class V2, class T>
364  inline void add_for_short_vectors(const V1 &x, V2 &y, const T &a, size_type n)
365  {
366  switch(n)
367  {
368  case 1: add_fixed<1>(x, y, a); break;
369  case 2: add_fixed<2>(x, y, a); break;
370  case 3: add_fixed<3>(x, y, a); break;
371  case 4: add_fixed<4>(x, y, a); break;
372  case 5: add_fixed<5>(x, y, a); break;
373  case 6: add_fixed<6>(x, y, a); break;
374  case 7: add_fixed<7>(x, y, a); break;
375  case 8: add_fixed<8>(x, y, a); break;
376  case 9: add_fixed<9>(x, y, a); break;
377  case 10: add_fixed<10>(x, y, a); break;
378  case 11: add_fixed<11>(x, y, a); break;
379  case 12: add_fixed<12>(x, y, a); break;
380  case 13: add_fixed<13>(x, y, a); break;
381  case 14: add_fixed<14>(x, y, a); break;
382  case 15: add_fixed<15>(x, y, a); break;
383  case 16: add_fixed<16>(x, y, a); break;
384  case 17: add_fixed<17>(x, y, a); break;
385  case 18: add_fixed<18>(x, y, a); break;
386  case 19: add_fixed<19>(x, y, a); break;
387  case 20: add_fixed<20>(x, y, a); break;
388  case 21: add_fixed<21>(x, y, a); break;
389  case 22: add_fixed<22>(x, y, a); break;
390  case 23: add_fixed<23>(x, y, a); break;
391  case 24: add_fixed<24>(x, y, a); break;
392  default:
393  GMM_ASSERT2(false, "add_for_short_vectors used with unsupported size");
394  break;
395  }
396  }
397 
398 
399 # define axpy_interface(blas_name, base_type) \
400  inline void add(const std::vector<base_type> &x, \
401  std::vector<base_type> &y) { \
402  GMMLAPACK_TRACE("axpy_interface"); \
403  const size_type nn=vect_size(y); \
404  if (nn == 0) return; \
405  else if (nn < 25) add_for_short_vectors(x, y, nn); \
406  else { const BLAS_INT n=BLAS_INT(nn), inc(1); const base_type a(1); \
407  blas_name(&n, &a, &x[0], &inc, &y[0], &inc); } \
408  } \
409  inline void add(const scaled_vector_const_ref<std::vector<base_type>, \
410  base_type> &x_, \
411  std::vector<base_type> &y) { \
412  GMMLAPACK_TRACE("axpy_interface"); \
413  const size_type nn=vect_size(y); const base_type a(x_.r); \
414  const std::vector<base_type>& x = *(linalg_origin(x_)); \
415  if (nn == 0) return; \
416  else if (nn < 25) add_for_short_vectors(x, y, a, nn); \
417  else { const BLAS_INT n=BLAS_INT(nn), inc(1); \
418  blas_name(&n, &a, &x[0], &inc, &y[0], &inc); } \
419  }
420 
421  axpy_interface(saxpy_, BLAS_S)
422  axpy_interface(daxpy_, BLAS_D)
423  axpy_interface(caxpy_, BLAS_C)
424  axpy_interface(zaxpy_, BLAS_Z)
425 
426 
/* ********************************************************************* */
/*   mult_add(A, x, z).                                                  */
/*   mult(A, x, y).                                                      */
/* ********************************************************************* */

// Generates, through BLAS ?gemv, the overloads
//   mult_add_spec(A, x, z, orien) : z <- alpha * op(A) * x + z   (beta = 1)
//   mult_spec(A, x, z, orien)     : z <- alpha * op(A) * x       (beta = 0)
// param1/trans1 declare the matrix parameter and bind `A` (the stored
// dense_matrix) and the op character `t`.  param2/trans2 declare the vector
// parameter and bind `x` and `alpha` (1 or the scaling factor).  The leading
// dimension passed to BLAS is nrows(A).
// When A has no rows or no columns, op(A) * x is zero: mult_add_spec then
// leaves z unchanged (it must not clear the value being accumulated into),
// while mult_spec sets z to zero.
# define gemv_interface(param1, trans1, param2, trans2, blas_name,         \
                        base_type, orien)                                  \
  inline void mult_add_spec(param1(base_type), param2(base_type),          \
                            std::vector<base_type> &z, orien) {            \
    GMMLAPACK_TRACE("gemv_interface");                                     \
    trans1(base_type); trans2(base_type); const base_type beta(1);         \
    const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m),                       \
                   n=BLAS_INT(mat_ncols(A)), inc(1);                       \
    if (m && n) blas_name(&t, &m, &n, &alpha, &A(0,0), &lda, &x[0], &inc,  \
                          &beta, &z[0], &inc);                             \
  }                                                                        \
  inline void mult_spec(param1(base_type), param2(base_type),              \
                        std::vector<base_type> &z, orien) {                \
    GMMLAPACK_TRACE("gemv_interface2");                                    \
    trans1(base_type); trans2(base_type); const base_type beta(0);         \
    const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m),                       \
                   n=BLAS_INT(mat_ncols(A)), inc(1);                       \
    if (m && n) blas_name(&t, &m, &n, &alpha, &A(0,0), &lda,               \
                          &x[0], &inc, &beta, &z[0], &inc);                \
    else gmm::clear(z);                                                    \
  }
454 
// ---- First parameter of gemv_interface: the matrix operand. ----
// gem_p1_* declares the parameter.  gem_trans1_* binds `t` (the BLAS op
// character) and, for the wrapped forms, `A` to the stored dense_matrix.

// op = 'N': A used directly.
# define gem_p1_n(base_type) const dense_matrix<base_type> &A
# define gem_trans1_n(base_type) const char t = 'N'

// op = 'T': transposed view over a mutable or const dense_matrix.
# define gem_p1_t(base_type) \
  const transposed_col_ref<dense_matrix<base_type> *> &A_
# define gem_p1_tc(base_type) \
  const transposed_col_ref<const dense_matrix<base_type> *> &A_
# define gem_trans1_t(base_type) \
  const char t = 'T'; \
  const dense_matrix<base_type> &A = *(linalg_origin(A_))

// op = 'C': conjugated view.
# define gem_p1_c(base_type) \
  const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &A_
# define gem_trans1_c(base_type) \
  const char t = 'C'; \
  const dense_matrix<base_type> &A = *(linalg_origin(A_))

// ---- Second parameter: the vector operand. ----
// gemv_p2_* declares the parameter; gemv_trans2_* binds `x` and `alpha`.

// plain vector: alpha = 1.
# define gemv_p2_n(base_type) const std::vector<base_type> &x
# define gemv_trans2_n(base_type) base_type alpha(1)

// scaled vector: alpha = scaling factor, x = the underlying vector.
# define gemv_p2_s(base_type) \
  const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_
# define gemv_trans2_s(base_type) \
  base_type alpha(x_.r); \
  const std::vector<base_type> &x = *(linalg_origin(x_))
479 
480  // Z <- AX + Z.
481  // Y <- AX.
482  gemv_interface(gem_p1_n, gem_trans1_n,
483  gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, col_major)
484  gemv_interface(gem_p1_n, gem_trans1_n,
485  gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, col_major)
486  gemv_interface(gem_p1_n, gem_trans1_n,
487  gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, col_major)
488  gemv_interface(gem_p1_n, gem_trans1_n,
489  gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, col_major)
490 
491  // Z <- transposed(A)X + Z.
492  // Y <- transposed(A)X.
493  gemv_interface(gem_p1_t, gem_trans1_t,
494  gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
495  gemv_interface(gem_p1_t, gem_trans1_t,
496  gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
497  gemv_interface(gem_p1_t, gem_trans1_t,
498  gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
499  gemv_interface(gem_p1_t, gem_trans1_t,
500  gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
501 
502  // Z <- transposed(const A)X + Z.
503  // Y <- transposed(const A)X.
504  gemv_interface(gem_p1_tc, gem_trans1_t,
505  gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
506  gemv_interface(gem_p1_tc, gem_trans1_t,
507  gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
508  gemv_interface(gem_p1_tc, gem_trans1_t,
509  gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
510  gemv_interface(gem_p1_tc, gem_trans1_t,
511  gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
512 
513  // Z <- conjugated(A)X + Z.
514  // Y <- conjugated(A)X.
515  gemv_interface(gem_p1_c, gem_trans1_c,
516  gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
517  gemv_interface(gem_p1_c, gem_trans1_c,
518  gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
519  gemv_interface(gem_p1_c, gem_trans1_c,
520  gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
521  gemv_interface(gem_p1_c, gem_trans1_c,
522  gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
523 
524  // Z <- A scaled(X) + Z.
525  // Y <- A scaled(X).
526  gemv_interface(gem_p1_n, gem_trans1_n,
527  gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, col_major)
528  gemv_interface(gem_p1_n, gem_trans1_n,
529  gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, col_major)
530  gemv_interface(gem_p1_n, gem_trans1_n,
531  gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, col_major)
532  gemv_interface(gem_p1_n, gem_trans1_n,
533  gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, col_major)
534 
535  // Z <- transposed(A) scaled(X) + Z.
536  // Y <- transposed(A) scaled(X).
537  gemv_interface(gem_p1_t, gem_trans1_t,
538  gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
539  gemv_interface(gem_p1_t, gem_trans1_t,
540  gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
541  gemv_interface(gem_p1_t, gem_trans1_t,
542  gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
543  gemv_interface(gem_p1_t, gem_trans1_t,
544  gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
545 
546  // Z <- transposed(const A) scaled(X) + Z.
547  // Y <- transposed(const A) scaled(X).
548  gemv_interface(gem_p1_tc, gem_trans1_t,
549  gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
550  gemv_interface(gem_p1_tc, gem_trans1_t,
551  gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
552  gemv_interface(gem_p1_tc, gem_trans1_t,
553  gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
554  gemv_interface(gem_p1_tc, gem_trans1_t,
555  gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
556 
557  // Z <- conjugated(A) scaled(X) + Z.
558  // Y <- conjugated(A) scaled(X).
559  gemv_interface(gem_p1_c, gem_trans1_c,
560  gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
561  gemv_interface(gem_p1_c, gem_trans1_c,
562  gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
563  gemv_interface(gem_p1_c, gem_trans1_c,
564  gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
565  gemv_interface(gem_p1_c, gem_trans1_c,
566  gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
567 
568 
569  /* ********************************************************************* */
570  /* Rank one update. */
571  /* ********************************************************************* */
572 
573 # define ger_interface(blas_name, base_type) \
574  inline void rank_one_update(dense_matrix<base_type> &A, \
575  const std::vector<base_type> &V, \
576  const std::vector<base_type> &W) { \
577  GMMLAPACK_TRACE("ger_interface"); \
578  const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m), \
579  n=BLAS_INT(mat_ncols(A)), inc(1); \
580  const base_type alpha(1); \
581  if (m && n) \
582  blas_name(&m, &n, &alpha, &V[0], &inc, &W[0], &inc, &A(0,0), &lda); \
583  } \
584  inline void \
585  rank_one_update(dense_matrix<base_type> &A, \
586  const scaled_vector_const_ref<std::vector<base_type>, \
587  base_type> &x_, \
588  const std::vector<base_type> &W) { \
589  GMMLAPACK_TRACE("ger_interface"); \
590  const std::vector<base_type> &x = (*(linalg_origin(x_))); \
591  const base_type alpha(x_.r); \
592  const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m), \
593  n=BLAS_INT(mat_ncols(A)), inc(1); \
594  if (m && n) \
595  blas_name(&m, &n, &alpha, &x[0], &inc, &W[0], &inc, &A(0,0), &lda); \
596  } \
597  inline void \
598  rank_one_update(dense_matrix<base_type> &A, \
599  const std::vector<base_type> &V, \
600  const scaled_vector_const_ref<std::vector<base_type>, \
601  base_type> &x_) { \
602  GMMLAPACK_TRACE("ger_interface"); \
603  const std::vector<base_type> &x = (*(linalg_origin(x_))); \
604  const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m), \
605  n=BLAS_INT(mat_ncols(A)), inc(1); \
606  const base_type alpha0(x_.r), alpha=gmm::conj(alpha0); \
607  if (m && n) \
608  blas_name(&m, &n, &alpha, &V[0], &inc, &x[0], &inc, &A(0,0), &lda); \
609  }
610 
611  ger_interface(sger_, BLAS_S)
612  ger_interface(dger_, BLAS_D)
613  ger_interface(cgerc_, BLAS_C)
614  ger_interface(zgerc_, BLAS_Z)
615 
616 
617  /* ********************************************************************* */
618  /* dense matrix x dense matrix multiplication. */
619  /* ********************************************************************* */
620 
621 # define gemm_interface_nn(blas_name, base_type) \
622  inline void mult_spec(const dense_matrix<base_type> &A, \
623  const dense_matrix<base_type> &B, \
624  dense_matrix<base_type> &C, c_mult) { \
625  GMMLAPACK_TRACE("gemm_interface_nn"); \
626  const char t='N'; const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m), \
627  k=BLAS_INT(mat_ncols(A)), ldb(k), \
628  n=BLAS_INT(mat_ncols(B)), ldc(m); \
629  const base_type alpha(1), beta(0); \
630  if (m && k && n) \
631  blas_name(&t, &t, &m, &n, &k, &alpha, \
632  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
633  else gmm::clear(C); \
634  }
635 
636  gemm_interface_nn(sgemm_, BLAS_S)
637  gemm_interface_nn(dgemm_, BLAS_D)
638  gemm_interface_nn(cgemm_, BLAS_C)
639  gemm_interface_nn(zgemm_, BLAS_Z)
640 
641  /* ********************************************************************* */
642  /* transposed(dense matrix) x dense matrix multiplication. */
643  /* dense matrix x transposed(dense matrix) multiplication. */
644  /* ********************************************************************* */
645 
646 # define gemm_interface_tn_nt(blas_name, base_type, mat_type) \
647  inline void mult_spec( \
648  const transposed_col_ref<mat_type<base_type> *> &A_, \
649  const dense_matrix<base_type> &B, \
650  dense_matrix<base_type> &C, rcmult) { \
651  GMMLAPACK_TRACE("gemm_interface_tn"); \
652  const dense_matrix<base_type> &A = *(linalg_origin(A_)); \
653  const char t = 'T', u = 'N'; \
654  const BLAS_INT m=BLAS_INT(mat_ncols(A)), k=BLAS_INT(mat_nrows(A)), \
655  n=BLAS_INT(mat_ncols(B)), lda(k), ldb(k), ldc(m); \
656  const base_type alpha(1), beta(0); \
657  if (m && k && n) blas_name(&t, &u, &m, &n, &k, &alpha, &A(0,0), &lda, \
658  &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
659  else gmm::clear(C); \
660  } \
661  inline void \
662  mult_spec(const dense_matrix<base_type> &A, \
663  const transposed_col_ref<mat_type<base_type> *> &B_, \
664  dense_matrix<base_type> &C, r_mult) { \
665  GMMLAPACK_TRACE("gemm_interface_nt"); \
666  const dense_matrix<base_type> &B = *(linalg_origin(B_)); \
667  const char t = 'N', u = 'T'; \
668  const BLAS_INT m=BLAS_INT(mat_nrows(A)), k=BLAS_INT(mat_ncols(A)), \
669  n=BLAS_INT(mat_nrows(B)), lda(m), ldb(n), ldc(m); \
670  const base_type alpha(1), beta(0); \
671  if (m && k && n) blas_name(&t, &u, &m, &n, &k, &alpha, &A(0,0), &lda, \
672  &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
673  else gmm::clear(C); \
674  }
675 
676  gemm_interface_tn_nt(sgemm_, BLAS_S, dense_matrix)
677  gemm_interface_tn_nt(dgemm_, BLAS_D, dense_matrix)
678  gemm_interface_tn_nt(cgemm_, BLAS_C, dense_matrix)
679  gemm_interface_tn_nt(zgemm_, BLAS_Z, dense_matrix)
680  gemm_interface_tn_nt(sgemm_, BLAS_S, const dense_matrix)
681  gemm_interface_tn_nt(dgemm_, BLAS_D, const dense_matrix)
682  gemm_interface_tn_nt(cgemm_, BLAS_C, const dense_matrix)
683  gemm_interface_tn_nt(zgemm_, BLAS_Z, const dense_matrix)
684 
685 
686  /* ********************************************************************* */
687  /* transposed(dense matrix) x transposed(dense matrix) multiplication. */
688  /* ********************************************************************* */
689 
690 # define gemm_interface_tt(blas_name, base_type, matA_type, matB_type) \
691  inline void \
692  mult_spec(const transposed_col_ref<matA_type<base_type> *> &A_, \
693  const transposed_col_ref<matB_type<base_type> *> &B_, \
694  dense_matrix<base_type> &C, r_mult) { \
695  GMMLAPACK_TRACE("gemm_interface_tt"); \
696  const dense_matrix<base_type> &A = *(linalg_origin(A_)); \
697  const dense_matrix<base_type> &B = *(linalg_origin(B_)); \
698  const char t = 'T', u = 'T'; \
699  const BLAS_INT m=BLAS_INT(mat_ncols(A)), k=BLAS_INT(mat_nrows(A)), \
700  n=BLAS_INT(mat_nrows(B)), lda(k), ldb(n), ldc(m); \
701  base_type alpha(1), beta(0); \
702  if (m && k && n) \
703  blas_name(&t, &u, &m, &n, &k, &alpha, \
704  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
705  else gmm::clear(C); \
706  }
707 
708  gemm_interface_tt(sgemm_, BLAS_S, dense_matrix, dense_matrix)
709  gemm_interface_tt(dgemm_, BLAS_D, dense_matrix, dense_matrix)
710  gemm_interface_tt(cgemm_, BLAS_C, dense_matrix, dense_matrix)
711  gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix, dense_matrix)
712  gemm_interface_tt(sgemm_, BLAS_S, const dense_matrix, dense_matrix)
713  gemm_interface_tt(dgemm_, BLAS_D, const dense_matrix, dense_matrix)
714  gemm_interface_tt(cgemm_, BLAS_C, const dense_matrix, dense_matrix)
715  gemm_interface_tt(zgemm_, BLAS_Z, const dense_matrix, dense_matrix)
716  gemm_interface_tt(sgemm_, BLAS_S, dense_matrix, const dense_matrix)
717  gemm_interface_tt(dgemm_, BLAS_D, dense_matrix, const dense_matrix)
718  gemm_interface_tt(cgemm_, BLAS_C, dense_matrix, const dense_matrix)
719  gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix, const dense_matrix)
720  gemm_interface_tt(sgemm_, BLAS_S, const dense_matrix, const dense_matrix)
721  gemm_interface_tt(dgemm_, BLAS_D, const dense_matrix, const dense_matrix)
722  gemm_interface_tt(cgemm_, BLAS_C, const dense_matrix, const dense_matrix)
723  gemm_interface_tt(zgemm_, BLAS_Z, const dense_matrix, const dense_matrix)
724 
725 
726  /* ********************************************************************* */
727  /* conjugated(dense matrix) x dense matrix multiplication. */
728  /* dense matrix x conjugated(dense matrix) multiplication. */
729  /* conjugated(dense matrix) x conjugated(dense matrix) multiplication. */
730  /* ********************************************************************* */
731 
732 # define gemm_interface_cn_nc_cc(blas_name, base_type) \
733  inline void mult_spec( \
734  const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &A_, \
735  const dense_matrix<base_type> &B, \
736  dense_matrix<base_type> &C, rcmult) { \
737  GMMLAPACK_TRACE("gemm_interface_cn"); \
738  const dense_matrix<base_type> &A = *(linalg_origin(A_)); \
739  const char t = 'C', u = 'N'; \
740  const BLAS_INT m=BLAS_INT(mat_ncols(A)), k=BLAS_INT(mat_nrows(A)), \
741  n=BLAS_INT(mat_ncols(B)), lda(k), ldb(k), ldc(m); \
742  const base_type alpha(1), beta(0); \
743  if (m && k && n) \
744  blas_name(&t, &u, &m, &n, &k, &alpha, \
745  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
746  else gmm::clear(C); \
747  } \
748  inline void mult_spec( \
749  const dense_matrix<base_type> &A, \
750  const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &B_, \
751  dense_matrix<base_type> &C, c_mult, row_major) { \
752  GMMLAPACK_TRACE("gemm_interface_nc"); \
753  const dense_matrix<base_type> &B = *(linalg_origin(B_)); \
754  const char t = 'N', u = 'C'; \
755  const BLAS_INT m=BLAS_INT(mat_nrows(A)), k=BLAS_INT(mat_ncols(A)), \
756  n=BLAS_INT(mat_nrows(B)), lda(m), ldb(n), ldc(m); \
757  const base_type alpha(1), beta(0); \
758  if (m && k && n) \
759  blas_name(&t, &u, &m, &n, &k, &alpha, \
760  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
761  else gmm::clear(C); \
762  } \
763  inline void mult_spec( \
764  const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &A_, \
765  const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &B_, \
766  dense_matrix<base_type> &C, r_mult) { \
767  GMMLAPACK_TRACE("gemm_interface_cc"); \
768  const dense_matrix<base_type> &A = *(linalg_origin(A_)); \
769  const dense_matrix<base_type> &B = *(linalg_origin(B_)); \
770  const char t = 'C', u = 'C'; \
771  const BLAS_INT m=BLAS_INT(mat_ncols(A)), k=BLAS_INT(mat_nrows(A)), \
772  n=BLAS_INT(mat_nrows(B)), lda(k), ldb(n), ldc(m); \
773  const base_type alpha(1), beta(0); \
774  if (m && k && n) \
775  blas_name(&t, &u, &m, &n, &k, &alpha, \
776  &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
777  else gmm::clear(C); \
778  }
779 
780  gemm_interface_cn_nc_cc(sgemm_, BLAS_S)
781  gemm_interface_cn_nc_cc(dgemm_, BLAS_D)
782  gemm_interface_cn_nc_cc(cgemm_, BLAS_C)
783  gemm_interface_cn_nc_cc(zgemm_, BLAS_Z)
784 
785 
786  /* ********************************************************************* */
787  /* Tri solve. */
788  /* ********************************************************************* */
789 
790 # define trsv_interface(LorU1, LorU2, param1, trans1, blas_name, base_type)\
791  inline void \
792  lower_tri_solve(param1(base_type), std::vector<base_type> &x, \
793  size_type k, bool is_unit) { \
794  GMMLAPACK_TRACE("trsv_interface"); \
795  const char l = LorU1; trans1(base_type); char d = is_unit ? 'U' : 'N'; \
796  const BLAS_INT lda=BLAS_INT(mat_nrows(A)), inc(1), n=BLAS_INT(k); \
797  if (lda) blas_name(&l, &t, &d, &n, &A(0,0), &lda, &x[0], &inc); \
798  } \
799  inline void \
800  upper_tri_solve(param1(base_type), std::vector<base_type> &x, \
801  size_type k, bool is_unit) { \
802  GMMLAPACK_TRACE("trsv_interface"); \
803  const char l = LorU2; trans1(base_type); char d = is_unit ? 'U' : 'N'; \
804  const BLAS_INT lda=BLAS_INT(mat_nrows(A)), inc(1), n=BLAS_INT(k); \
805  if (lda) blas_name(&l, &t, &d, &n, &A(0,0), &lda, &x[0], &inc); \
806  }
807 
808  // X <- LOWER(A)^{-1}X.
809  // X <- UPPER(A)^{-1}X.
810  trsv_interface('L', 'U', gem_p1_n, gem_trans1_n, strsv_, BLAS_S)
811  trsv_interface('L', 'U', gem_p1_n, gem_trans1_n, dtrsv_, BLAS_D)
812  trsv_interface('L', 'U', gem_p1_n, gem_trans1_n, ctrsv_, BLAS_C)
813  trsv_interface('L', 'U', gem_p1_n, gem_trans1_n, ztrsv_, BLAS_Z)
814 
815  // X <- LOWER(transposed(A))^{-1}X.
816  // X <- UPPER(transposed(A))^{-1}X.
817  trsv_interface('U', 'L', gem_p1_t, gem_trans1_t, strsv_, BLAS_S)
818  trsv_interface('U', 'L', gem_p1_t, gem_trans1_t, dtrsv_, BLAS_D)
819  trsv_interface('U', 'L', gem_p1_t, gem_trans1_t, ctrsv_, BLAS_C)
820  trsv_interface('U', 'L', gem_p1_t, gem_trans1_t, ztrsv_, BLAS_Z)
821 
822  // X <- LOWER(transposed(const A))^{-1}X.
823  // X <- UPPER(transposed(const A))^{-1}X.
824  trsv_interface('U', 'L', gem_p1_tc, gem_trans1_t, strsv_, BLAS_S)
825  trsv_interface('U', 'L', gem_p1_tc, gem_trans1_t, dtrsv_, BLAS_D)
826  trsv_interface('U', 'L', gem_p1_tc, gem_trans1_t, ctrsv_, BLAS_C)
827  trsv_interface('U', 'L', gem_p1_tc, gem_trans1_t, ztrsv_, BLAS_Z)
828 
829  // X <- LOWER(conjugated(A))^{-1}X.
830  // X <- UPPER(conjugated(A))^{-1}X.
831  trsv_interface('U', 'L', gem_p1_c, gem_trans1_c, strsv_, BLAS_S)
832  trsv_interface('U', 'L', gem_p1_c, gem_trans1_c, dtrsv_, BLAS_D)
833  trsv_interface('U', 'L', gem_p1_c, gem_trans1_c, ctrsv_, BLAS_C)
834  trsv_interface('U', 'L', gem_p1_c, gem_trans1_c, ztrsv_, BLAS_Z)
835 }
836 
837 #endif // GMM_BLAS_INTERFACE_H
838 
839 #endif // GMM_USES_BLAS
Basic linear algebra functions.
gmm interface for STL vectors.
Declaration of some matrix types (gmm::dense_matrix, gmm::row_matrix, gmm::col_matrix,...
size_t size_type
used as the common size type in the library
Definition: bgeot_poly.h:48