00001 #ifndef SCITBX_ARRAY_FAMILY_VERSA_MATRIX_H
00002 #define SCITBX_ARRAY_FAMILY_VERSA_MATRIX_H
00003
00004 #include <scitbx/array_family/versa.h>
00005 #include <scitbx/array_family/shared.h>
00006 #include <scitbx/array_family/accessors/flex_grid.h>
00007 #include <scitbx/array_family/accessors/c_grid.h>
00008 #include <scitbx/matrix/lu_decomposition.h>
00009 #include <scitbx/matrix/inversion.h>
00010 #include <scitbx/matrix/diagonal.h>
00011 #include <scitbx/matrix/packed.h>
00012 #include <scitbx/mat_ref.h>
00013 #include <scitbx/constants.h>
00014 #include <boost/optional.hpp>
00015 #include <boost/scoped_array.hpp>
00016
00017 namespace scitbx { namespace af {
00018
00019 template <typename NumType>
00020 shared<NumType>
00021 matrix_diagonal(
00022 const_ref<NumType, c_grid<2> > const& a)
00023 {
00024 SCITBX_ASSERT(a.accessor().is_square());
00025 shared<NumType> result(a.accessor()[0], init_functor_null<NumType>());
00026 matrix::diagonal(
00027 a.begin(), a.accessor()[0], result.begin());
00028 return result;
00029 }
00030
00031 template <typename NumType>
00032 void
00033 matrix_diagonal_set_in_place(
00034 ref<NumType, c_grid<2> > const& a,
00035 NumType const& value)
00036 {
00037 SCITBX_ASSERT(a.accessor().is_square());
00038 typedef typename c_grid<2>::index_value_type ivt;
00039 ivt n = a.accessor()[0];
00040 ivt n_sq = n*n;
00041 ivt n_plus_1 = n + 1;
00042 for(ivt i=0;i<n_sq;i+=n_plus_1) {
00043 a[i] = value;
00044 }
00045 }
00046
00047 template <typename NumType>
00048 void
00049 matrix_diagonal_add_in_place(
00050 ref<NumType, c_grid<2> > const& a,
00051 NumType const& value)
00052 {
00053 SCITBX_ASSERT(a.accessor().is_square());
00054 typedef typename c_grid<2>::index_value_type ivt;
00055 ivt n = a.accessor()[0];
00056 ivt n_sq = n*n;
00057 ivt n_plus_1 = n + 1;
00058 for(ivt i=0;i<n_sq;i+=n_plus_1) {
00059 a[i] += value;
00060 }
00061 }
00062
00063 template <typename NumType>
00064 NumType
00065 matrix_diagonal_sum(
00066 const_ref<NumType, c_grid<2> > const& a)
00067 {
00068 SCITBX_ASSERT(a.accessor().is_square());
00069 return matrix::diagonal_sum(a.begin(), a.accessor()[0]);
00070 }
00071
00072 template <typename NumType>
00073 NumType
00074 matrix_diagonal_product(
00075 const_ref<NumType, c_grid<2> > const& a)
00076 {
00077 SCITBX_ASSERT(a.accessor().is_square());
00078 return matrix::diagonal_product(a.begin(), a.accessor()[0]);
00079 }
00080
00081 template <typename NumTypeA, typename NumTypeB>
00082 versa<
00083 typename binary_operator_traits<NumTypeA, NumTypeB>::arithmetic,
00084 c_grid<2> >
00085 matrix_multiply(
00086 const_ref<NumTypeA, c_grid<2> > const& a,
00087 const_ref<NumTypeB, c_grid<2> > const& b)
00088 {
00089 typedef typename
00090 binary_operator_traits<NumTypeA, NumTypeB>::arithmetic
00091 numtype_ab;
00092 versa<numtype_ab, c_grid<2> > ab(
00093 c_grid<2>(a.accessor()[0], b.accessor()[1]),
00094 init_functor_null<numtype_ab>());
00095 mat_const_ref<NumTypeA> a_(a.begin(), a.accessor()[0], a.accessor()[1]);
00096 mat_const_ref<NumTypeB> b_(b.begin(), b.accessor()[0], b.accessor()[1]);
00097 mat_ref<numtype_ab> ab_(ab.begin(), ab.accessor()[0], ab.accessor()[1]);
00098 multiply(a_, b_, ab_);
00099 return ab;
00100 }
00101
00102 template <typename NumType>
00103 shared<NumType>
00104 matrix_multiply(
00105 const_ref<NumType, c_grid<2> > const& a,
00106 const_ref<NumType> const& b)
00107 {
00108 shared<NumType> ab(a.accessor()[0], init_functor_null<NumType>());
00109 mat_const_ref<NumType> a_(a.begin(), a.accessor()[0], a.accessor()[1]);
00110 mat_const_ref<NumType> b_(b.begin(), b.size(), 1);
00111 mat_ref<NumType> ab_(ab.begin(), a.accessor()[0], 1);
00112 multiply(a_, b_, ab_);
00113 return ab;
00114 }
00115
00116 template <typename NumType>
00117 shared<NumType>
00118 matrix_multiply(
00119 const_ref<NumType> const& a,
00120 const_ref<NumType, c_grid<2> > const& b)
00121 {
00122 shared<NumType> ab(b.accessor()[1], init_functor_null<NumType>());
00123 mat_const_ref<NumType> a_(a.begin(), 1, a.size());
00124 mat_const_ref<NumType> b_(b.begin(), b.accessor()[0], b.accessor()[1]);
00125 mat_ref<NumType> ab_(ab.begin(), 1, b.accessor()[1]);
00126 multiply(a_, b_, ab_);
00127 return ab;
00128 }
00129
00130 template <typename NumType>
00131 NumType
00132 matrix_multiply(
00133 const_ref<NumType> const& a,
00134 const_ref<NumType> const& b)
00135 {
00136 NumType ab;
00137 mat_const_ref<NumType> a_(a.begin(), 1, a.size());
00138 mat_const_ref<NumType> b_(b.begin(), b.size(), 1);
00139 mat_ref<NumType> ab_(&ab, 1, 1);
00140 multiply(a_, b_, ab_);
00141 return ab;
00142 }
00143
00144 template <typename NumTypeA, typename NumTypeB>
00145 versa<
00146 typename binary_operator_traits<NumTypeA, NumTypeB>::arithmetic,
00147 c_grid<2> >
00148 matrix_multiply_packed_u(
00149 const_ref<NumTypeA, c_grid<2> > const& a,
00150 const_ref<NumTypeB> const& b)
00151 {
00152 unsigned a_n_rows = a.accessor()[0];
00153 unsigned a_n_columns = a.accessor()[1];
00154 SCITBX_ASSERT(matrix::symmetric_n_from_packed_size(b.size())
00155 == a_n_columns);
00156 typedef typename
00157 binary_operator_traits<NumTypeA, NumTypeB>::arithmetic
00158 numtype_ab;
00159 versa<numtype_ab, c_grid<2> > ab(
00160 c_grid<2>(a_n_rows, a_n_columns),
00161 init_functor_null<numtype_ab>());
00162 matrix::multiply_packed_u(
00163 a.begin(), b.begin(), a_n_rows, a_n_columns, ab.begin());
00164 return ab;
00165 }
00166
00167 template <typename NumTypeA, typename NumTypeB>
00168 shared<typename binary_operator_traits<NumTypeA, NumTypeB>::arithmetic>
00169 matrix_multiply_packed_u_multiply_lhs_transpose(
00170 const_ref<NumTypeA, c_grid<2> > const& a,
00171 const_ref<NumTypeB> const& b)
00172 {
00173 unsigned a_n_rows = a.accessor()[0];
00174 unsigned a_n_columns = a.accessor()[1];
00175 SCITBX_ASSERT(matrix::symmetric_n_from_packed_size(b.size())
00176 == a_n_columns);
00177 typedef typename
00178 binary_operator_traits<NumTypeA, NumTypeB>::arithmetic
00179 numtype_ab;
00180 shared<numtype_ab> abat(
00181 a_n_rows*(a_n_rows+1)/2, init_functor_null<numtype_ab>());
00182 boost::scoped_array<numtype_ab> ab(new numtype_ab[a_n_rows * a_n_columns]);
00183 matrix::multiply_packed_u_multiply_lhs_transpose(
00184 a.begin(),
00185 b.begin(),
00186 a_n_rows,
00187 a_n_columns,
00188 ab.get(),
00189 abat.begin());
00190 return abat;
00191 }
00192
00193 template <typename NumType>
00194 shared<NumType>
00195 matrix_transpose_multiply_as_packed_u(
00196 const_ref<NumType, c_grid<2> > const& a)
00197 {
00198 unsigned na = a.accessor()[1];
00199 shared<NumType> ata(na*(na+1)/2, init_functor_null<NumType>());
00200 matrix::transpose_multiply_as_packed_u(
00201 a.begin(), a.accessor()[0], na, ata.begin());
00202 return ata;
00203 }
00204
00205 template <typename NumType>
00206 shared<NumType>
00207 matrix_transpose_multiply_diagonal_multiply_as_packed_u(
00208 const_ref<NumType, c_grid<2> > const& a,
00209 const_ref<NumType> const& diagonal_elements)
00210 {
00211 SCITBX_ASSERT(a.accessor().is_square());
00212 unsigned n = a.accessor()[0];
00213 shared<NumType> atda(n*(n+1)/2, init_functor_null<NumType>());
00214 matrix::transpose_multiply_diagonal_multiply_as_packed_u(
00215 a.begin(), diagonal_elements.begin(), n, atda.begin());
00216 return atda;
00217 }
00218
00219 template <typename NumType>
00220 versa<NumType, c_grid<2> >
00221 matrix_transpose(const_ref<NumType, c_grid<2> > const& a)
00222 {
00223 typedef typename c_grid<2>::value_type index_value_type;
00224 index_value_type n_rows = a.accessor()[0];
00225 index_value_type n_columns = a.accessor()[1];
00226 versa<NumType, c_grid<2> > result(
00227 c_grid<2>(n_columns, n_rows), init_functor_null<NumType>());
00228 NumType* r = result.begin();
00229 for (index_value_type ic=0;ic<n_columns;ic++) {
00230 std::size_t ir_nc_ic = ic;
00231 for (index_value_type ir=0;ir<n_rows;ir++,ir_nc_ic+=n_columns) {
00232 *r++ = a[ir_nc_ic];
00233 }
00234 }
00235 return result;
00236 }
00237
00238 template <typename NumType, typename FlexGridIndexType>
00239 void
00240 matrix_transpose_in_place(versa<NumType, flex_grid<FlexGridIndexType> >& a)
00241 {
00242 SCITBX_ASSERT(a.accessor().nd() == 2);
00243 SCITBX_ASSERT(a.accessor().is_0_based());
00244 SCITBX_ASSERT(!a.accessor().is_padded());
00245 typedef typename FlexGridIndexType::value_type index_value_type;
00246 index_value_type n_rows = a.accessor().all()[0];
00247 index_value_type n_columns = a.accessor().all()[1];
00248 mat_ref<NumType> a_(a.begin(), n_rows, n_columns);
00249 a_.transpose_in_place();
00250 a.resize(flex_grid<FlexGridIndexType>(n_columns, n_rows));
00251 }
00252
00253 template <typename FloatType>
00254 shared<std::size_t>
00255 matrix_lu_decomposition_in_place(
00256 ref<FloatType, c_grid<2> > const& a)
00257 {
00258 SCITBX_ASSERT(a.accessor().is_square());
00259 shared<std::size_t>
00260 pivot_indices(a.accessor()[0]+1, init_functor_null<std::size_t>());
00261 matrix::lu_decomposition_in_place(
00262 a.begin(), a.accessor()[0], pivot_indices.begin());
00263 return pivot_indices;
00264 }
00265
00266 template <typename FloatType>
00267 shared<FloatType>
00268 matrix_lu_back_substitution(
00269 const_ref<FloatType, c_grid<2> > const& a,
00270 const_ref<std::size_t> const& pivot_indices,
00271 const_ref<FloatType> const& b)
00272 {
00273 SCITBX_ASSERT(a.accessor().is_square());
00274 SCITBX_ASSERT(pivot_indices.size() == a.accessor()[0]+1);
00275 SCITBX_ASSERT(b.size() == a.accessor()[0]);
00276 shared<FloatType> x(b.begin(), b.end());
00277 matrix::lu_back_substitution(
00278 a.begin(), a.accessor()[0], pivot_indices.begin(), x.begin());
00279 return x;
00280 }
00281
00282 template <typename FloatType>
00283 FloatType
00284 matrix_determinant_via_lu(
00285 const_ref<FloatType, c_grid<2> > const& a,
00286 const_ref<std::size_t> const& pivot_indices)
00287 {
00288 SCITBX_ASSERT(a.accessor().is_square());
00289 SCITBX_ASSERT(pivot_indices.size() == a.accessor()[0]+1);
00290 FloatType result = matrix_diagonal_product(a);
00291 if (pivot_indices[a.accessor()[0]] % 2) result = -result;
00292 return result;
00293 }
00294
00295 template <typename FloatType>
00296 FloatType
00297 matrix_determinant_via_lu(
00298 const_ref<FloatType, c_grid<2> > const& a)
00299 {
00300 SCITBX_ASSERT(a.accessor().is_square());
00301 boost::scoped_array<FloatType> a_(new FloatType[a.accessor().size_1d()]);
00302 std::copy(a.begin(), a.end(), a_.get());
00303 FloatType result;
00304 try {
00305 shared<std::size_t>
00306 pivot_indices = matrix_lu_decomposition_in_place(
00307 ref<FloatType, c_grid<2> >(a_.get(), a.accessor()));
00308 result = matrix_diagonal_product(
00309 const_ref<FloatType, c_grid<2> >(a_.get(), a.accessor()));
00310 if (pivot_indices[a.accessor()[0]] % 2) result = -result;
00311 }
00312 catch (std::runtime_error const& e) {
00313 if (std::string(e.what())
00314 != "lu_decomposition_in_place: singular matrix") throw;
00315 result = 0;
00316 }
00317 return result;
00318 }
00319
00320 template <typename FloatType>
00321 void
00322 matrix_inversion_in_place(
00323 ref<FloatType, c_grid<2> > const& a,
00324 ref<FloatType, c_grid<2> > const& b)
00325 {
00326 SCITBX_ASSERT(a.accessor().is_square());
00327 if ( b.accessor()[0] != 0
00328 && b.accessor()[1] != a.accessor()[0]) {
00329 throw std::runtime_error(
00330 "matrix_inversion_in_place: if a is a (n*n) matrix b must be (m*n)");
00331 }
00332 matrix::inversion_in_place(
00333 a.begin(),
00334 static_cast<std::size_t>(a.accessor()[0]),
00335 b.begin(),
00336 static_cast<std::size_t>(b.accessor()[0]));
00337 }
00338
00339 template <typename FloatType>
00340 void
00341 matrix_inversion_in_place(
00342 ref<FloatType, c_grid<2> > const& a)
00343 {
00344 matrix_inversion_in_place(
00345 a, ref<FloatType, c_grid<2> >(0, c_grid<2>(0,0)));
00346 }
00347
00348 template <typename FloatType>
00349 boost::optional<FloatType>
00350 cos_angle(
00351 const_ref<FloatType> const& a,
00352 const_ref<FloatType> const& b)
00353 {
00354 SCITBX_ASSERT(b.size() == a.size());
00355 FloatType a_sum_sq = 0;
00356 FloatType b_sum_sq = 0;
00357 FloatType a_dot_b = 0;
00358 for(std::size_t i=0;i<a.size();i++) {
00359 const FloatType& ai = a[i];
00360 a_sum_sq += ai * ai;
00361 const FloatType& bi = b[i];
00362 b_sum_sq += bi * bi;
00363 a_dot_b += ai * bi;
00364 }
00365 if (a_sum_sq == 0 || b_sum_sq == 0) {
00366 return boost::optional<FloatType>();
00367 }
00368 FloatType d = a_sum_sq * b_sum_sq;
00369 if (d == 0) return boost::optional<FloatType>();
00370 return boost::optional<FloatType>(a_dot_b / std::sqrt(d));
00371 }
00372
00373 template <typename FloatType>
00374 FloatType
00375 cos_angle(
00376 const_ref<FloatType> const& a,
00377 const_ref<FloatType> const& b,
00378 FloatType const& value_if_undefined)
00379 {
00380 boost::optional<FloatType> result = cos_angle(a, b);
00381 if (result) return *result;
00382 return value_if_undefined;
00383 }
00384
00385 template <typename FloatType>
00386 boost::optional<FloatType>
00387 angle(
00388 const_ref<FloatType> const& a,
00389 const_ref<FloatType> const& b)
00390 {
00391 boost::optional<FloatType> c = cos_angle(a, b);
00392 if (!c) return c;
00393 FloatType cv = *c;
00394 if (cv > 1) cv = static_cast<FloatType>(1);
00395 else if (cv < -1) cv = static_cast<FloatType>(-1);
00396 FloatType result = std::acos(cv);
00397 return boost::optional<FloatType>(result);
00398 }
00399
00400 template <typename FloatType>
00401 boost::optional<FloatType>
00402 angle(
00403 const_ref<FloatType> const& a,
00404 const_ref<FloatType> const& b,
00405 bool deg)
00406 {
00407 boost::optional<FloatType> rad = angle(a, b);
00408 if (!rad || !deg) return rad;
00409 return boost::optional<FloatType>((*rad) / constants::pi_180);
00410 }
00411
00412 template <typename ElementType>
00413 versa<ElementType, c_grid<2> >
00414 mat_const_ref_as_versa(
00415 scitbx::mat_const_ref<ElementType> const& m)
00416 {
00417 versa<ElementType, c_grid<2> > result(
00418 c_grid<2>(m.n_rows(), m.n_columns()),
00419 init_functor_null<ElementType>());
00420 if (m.begin() != 0) {
00421 std::copy(m.begin(), m.end(), result.begin());
00422 }
00423 else {
00424 SCITBX_ASSERT(m.size() == 0);
00425 }
00426 return result;
00427 }
00428
00429 }}
00430
00431 #endif // SCITBX_ARRAY_FAMILY_VERSA_MATRIX_H