Change in the heuristics about multi-thread usage.
[piranha:mainline.git] / src / core / base_classes / base_series_multiplier.h
1 /***************************************************************************
2  *   Copyright (C) 2007, 2008 by Francesco Biscani   *
3  *   bluescarni@gmail.com   *
4  *                                                                         *
5  *   This program is free software; you can redistribute it and/or modify  *
6  *   it under the terms of the GNU General Public License as published by  *
7  *   the Free Software Foundation; either version 2 of the License, or     *
8  *   (at your option) any later version.                                   *
9  *                                                                         *
10  *   This program is distributed in the hope that it will be useful,       *
11  *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
12  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
13  *   GNU General Public License for more details.                          *
14  *                                                                         *
15  *   You should have received a copy of the GNU General Public License     *
16  *   along with this program; if not, write to the                         *
17  *   Free Software Foundation, Inc.,                                       *
18  *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
19  ***************************************************************************/
20
21 #ifndef PIRANHA_BASE_SERIES_MULTIPLIER_H
22 #define PIRANHA_BASE_SERIES_MULTIPLIER_H
23
24 #include <algorithm>
25 #include <boost/lambda/lambda.hpp>
26 #include <boost/numeric/conversion/cast.hpp>
27 #include <boost/thread/thread.hpp>
28 #include <boost/type_traits/is_same.hpp> // For key type detection.
29 #include <boost/tuple/tuple.hpp>
30 #include <cmath>
31 #include <cstddef>
32 #include <iterator>
33 #include <vector>
34
35 #include "../config.h"
36 #include "../exceptions.h"
37 #include "../settings.h"
38 #include "../stats.h"
39 #include "base_series_multiplier_mp.h"
40 #include "null_truncator.h"
41
42 #define derived_const_cast static_cast<Derived const *>(this)
43 #define derived_cast static_cast<Derived *>(this)
44
45 namespace piranha
46 {
47         /// Base series multiplier.
48         /**
49          * This class is meant to be extended to build specific multipliers.
50          */
51         template <class Series1, class Series2, class ArgsTuple, class Truncator, class Derived>
52         class base_series_multiplier
53         {
54                         friend class base_insert_multiplication_result;
55                 protected:
56                         // Alias for term type of first input series and return value series.
57                         typedef typename Series1::term_type term_type1;
58                         // Alias for term type of second input series.
59                         typedef typename Series2::term_type term_type2;
60                         p_static_check((boost::is_same<typename term_type1::key_type, typename term_type2::key_type>::value),
61                                 "Key type mismatch in base multiplier.");
62                         /// Compute block size for multiplication.
63                         /**
64                          * Resulting block size depends on the cache memory available, and will always be in the [16,8192] range. N is
65                          * the size of the base storage unit type used to store the result of the multiplication (e.g., coefficient type in vector coded, coded hash table
66                          * term in the sparse hashed multiplication, etc.).
67                          */
68                         template <std::size_t N>
69                         static std::size_t compute_block_size()
70                         {
71                                 // NOTE: this function is used typically considering only the output storage requirements, since storage of input series
72                                 //       will be a small fraction of storage for the output series.
73                                 p_static_check(N > 0,"");
74                                 const std::size_t shift = boost::numeric_cast<std::size_t>(
75                                         std::log(std::max<double>(16.,std::sqrt(static_cast<double>((settings::cache_size * 1024) / N)))) / std::log(2.) - 1
76                                 );
77                                 return (std::size_t(2) << std::min<std::size_t>(std::size_t(12),shift));
78                         }
79                         template <class Functor>
80                         static void blocked_multiplication(const std::size_t &block_size, const std::size_t &size1, const std::size_t &size2, Functor &m)
81                         {
82                                 piranha_assert(block_size > 0);
83                                 const std::size_t nblocks1 = size1 / block_size, nblocks2 = size2 / block_size;
84                                 for (std::size_t n1 = 0; n1 < nblocks1; ++n1) {
85                                         const std::size_t i_start = n1 * block_size, i_end = i_start + block_size;
86                                         // regulars1 * regulars2
87                                         for (std::size_t n2 = 0; n2 < nblocks2; ++n2) {
88                                                 const std::size_t j_start = n2 * block_size, j_end = j_start + block_size;
89                                                 for (std::size_t i = i_start; i < i_end; ++i) {
90                                                         for (std::size_t j = j_start; j < j_end; ++j) {
91                                                                 if (!m(i,j)) {
92                                                                         break;
93                                                                 }
94                                                         }
95                                                 }
96                                         }
97                                         // regulars1 * rem2
98                                         for (std::size_t i = i_start; i < i_end; ++i) {
99                                                 for (std::size_t j = nblocks2 * block_size; j < size2; ++j) {
100                                                         if (!m(i,j)) {
101                                                                 break;
102                                                         }
103                                                 }
104                                         }
105                                 }
106                                 // rem1 * regulars2
107                                 for (std::size_t n2 = 0; n2 < nblocks2; ++n2) {
108                                         const std::size_t j_start = n2 * block_size, j_end = j_start + block_size;
109                                         for (std::size_t i = nblocks1 * block_size; i < size1; ++i) {
110                                                 for (std::size_t j = j_start; j < j_end; ++j) {
111                                                         if (!m(i,j)) {
112                                                                 break;
113                                                         }
114                                                 }
115                                         }
116                                 }
117                                 // rem1 * rem2.
118                                 for (std::size_t i = nblocks1 * block_size; i < size1; ++i) {
119                                         for (std::size_t j = nblocks2 * block_size; j < size2; ++j) {
120                                                 if (!m(i,j)) {
121                                                         break;
122                                                 }
123                                         }
124                                 }
125                         }
126                         /// Cache pointers to series' terms in the internal storage.
127                         template <class Container1, class Container2>
128                         void cache_terms_pointers(const Container1 &c1, const Container2 &c2)
129                         {
130                                 piranha_assert(m_terms1.empty() && m_terms2.empty());
131                                 std::transform(c1.begin(),c1.end(),
132                                         std::insert_iterator<std::vector<typename Series1::term_type const *> >(m_terms1,m_terms1.begin()),
133                                         &(boost::lambda::_1));
134                                 std::transform(c2.begin(),c2.end(),
135                                         std::insert_iterator<std::vector<typename Series2::term_type const *> >(m_terms2,m_terms2.begin()),
136                                         &(boost::lambda::_1));
137                         }
138                 public:
139                         base_series_multiplier(const Series1 &s1, const Series2 &s2, Series1 &retval, const ArgsTuple &args_tuple):
140                                 m_s1(s1), m_s2(s2), m_args_tuple(args_tuple), m_retval(retval)
141                         {
142                                 piranha_assert(s1.length() > 0 && s2.length() > 0);
143                         }
144                         // Plain multiplication.
145                         void perform_plain_multiplication()
146                         {
147                                 perform_plain_threaded_multiplication();
148                         }
149                 private:
150                         template <class GenericTruncator>
151                         struct plain_functor {
152                                 typedef typename term_type1::multiplication_result mult_res;
153                                 plain_functor(mult_res &res,const term_type1 **t1, const term_type2 **t2, const GenericTruncator &trunc,
154                                         Series1 &retval, const ArgsTuple &args_tuple):m_res(res),m_t1(t1),m_t2(t2),
155                                         m_trunc(trunc),m_retval(retval),m_args_tuple(args_tuple)
156                                 {}
157                                 bool operator()(const std::size_t &i, const std::size_t &j)
158                                 {
159                                         if (m_trunc.skip(&m_t1[i], &m_t2[j])) {
160                                                 return false;
161                                         }
162                                         term_type1::multiply(*m_t1[i], *m_t2[j], m_res, m_args_tuple);
163                                         insert_multiplication_result<mult_res>::run(m_res, m_retval, m_args_tuple);
164                                         return true;
165                                 }
166                                 mult_res                &m_res;
167                                 const term_type1        **m_t1;
168                                 const term_type2        **m_t2;
169                                 const GenericTruncator  &m_trunc;
170                                 Series1                 &m_retval;
171                                 const ArgsTuple         &m_args_tuple;
172                         };
173                         struct plain_worker {
174                                 plain_worker(base_series_multiplier &mult, Series1 &retval):
175                                         m_mult(mult),m_retval(retval),m_terms1(mult.m_terms1),m_terms2(mult.m_terms2)
176                                 {}
177                                 plain_worker(base_series_multiplier &mult, Series1 &retval,
178                                         std::vector<std::vector<term_type1 const *> > &split1, const std::size_t &idx):
179                                         m_mult(mult),m_retval(retval),m_terms1(split1[idx]),m_terms2(mult.m_terms2)
180                                 {}
181                                 void operator()()
182                                 {
183                                         // Build the truncator.
184                                         const typename Truncator::template get_type<Series1,Series2,ArgsTuple> trunc(m_terms1,m_terms2,m_mult.m_args_tuple);
185                                         // Use the selected truncator only if it really truncates, otherwise use the
186                                         // null truncator.
187                                         if (trunc.is_effective()) {
188                                                 plain_implementation(trunc);
189                                         } else {
190                                                 plain_implementation(
191                                                         null_truncator::template get_type<Series1,Series2,ArgsTuple>(
192                                                         m_terms1,m_mult.m_terms2,m_mult.m_args_tuple
193                                                         )
194                                                 );
195                                         }
196                                 }
197                                 template <class GenericTruncator>
198                                 void plain_implementation(const GenericTruncator &trunc)
199                                 {
200                                         typedef typename term_type1::multiplication_result mult_res;
201                                         mult_res res;
202                                         const std::size_t size1 = m_terms1.size(), size2 = m_terms2.size();
203                                         piranha_assert(size1 && size2);
204                                         const term_type1 **t1 = &m_terms1[0];
205                                         const term_type2 **t2 = &m_terms2[0];
206                                         plain_functor<GenericTruncator> pf(res,t1,t2,trunc,m_retval,m_mult.m_args_tuple);
207                                         const std::size_t block_size = compute_block_size
208                                                 <boost::tuples::length<mult_res>::value * sizeof(term_type1)>();
209                                         blocked_multiplication(block_size,size1,size2,pf);
210                                 }
211                                 base_series_multiplier          &m_mult;
212                                 Series1                         &m_retval;
213                                 std::vector<term_type1 const *> &m_terms1;
214                                 std::vector<term_type2 const *> m_terms2;
215                         };
216                         // Threaded multiplication.
217                         void perform_plain_threaded_multiplication()
218                         {
219                                 // Effective number of threads to use. If the two series are small, we want to use one single thread.
220                                 // NOTE: here the number 100 is a kind of rule-of thumb. Basically multiplications of series < 10 elements
221                                 // will use just one thread.
222                                 if (double(m_terms1.size()) * double(m_terms2.size()) < 100) {
223                                         stats::trace_stat("mult_st",std::size_t(0),boost::lambda::_1 + 1);
224                                         plain_worker w(*derived_cast,m_retval);
225                                         w();
226                                 } else {
227                                         stats::trace_stat("mult_mt",std::size_t(0),boost::lambda::_1 + 1);
228                                         // TODO: fix numeric casting here.
229                                         // If size1 is less than the number of desired threads,
230                                         // use size1 as number of threads.
231                                         const std::size_t n = std::min(boost::numeric_cast<typename std::vector<term_type1 const *>::size_type>(settings::get_nthread()),m_terms1.size());
232                                         std::vector<std::vector<term_type1 const *> > split1(n);
233                                         // m is the number of terms per thread for regular blocks.
234                                         const std::size_t m = m_terms1.size() / n;
235                                         // Iterate up to n - 1 because that's the number up to which we can divide series1 into
236                                         // regular blocks.
237                                         for (std::size_t i = 0; i < n - 1; ++i) {
238                                                 split1[i].insert(split1[i].end(),m_terms1.begin() + i * m, m_terms1.begin() + (i + 1) * m);
239                                         }
240                                         // Last iteration.
241                                         split1[n - 1].insert(split1[n - 1].end(),m_terms1.begin() + (n - 1) * m, m_terms1.end());
242                                         boost::thread_group tg;
243                                         std::vector<Series1> retvals(n,Series1());
244                                         for (std::size_t i = 0; i < n; ++i) {
245                                                 tg.create_thread(plain_worker(*derived_cast,retvals[i],split1,i));
246                                         }
247                                         tg.join_all();
248                                         // Take the retvals and insert them into final retval.
249                                         for (std::size_t i = 0; i < n; ++i) {
250                                                 m_retval.insert_range(retvals[i].begin(),retvals[i].end(),m_args_tuple);
251                                         }
252                                 }
253                         }
254                 public:
255                         // TODO: make these protected?
256                         // References to the series.
257                         const Series1                                   &m_s1;
258                         const Series2                                   &m_s2;
259                         // Reference to the arguments tuple.
260                         const ArgsTuple                                 &m_args_tuple;
261                         // Reference to the result.
262                         Series1                                         &m_retval;
263                         // Vectors of pointers to the input terms.
264                         std::vector<term_type1 const *>                 m_terms1;
265                         std::vector<term_type2 const *>                 m_terms2;
266         };
267 }
268
269 #undef derived_const_cast
270 #undef derived_cast
271
272 #endif