Storm 1.11.1.1
A Modern Probabilistic Model Checker
Loading...
Searching...
No Matches
ViOperatorMultiplier.cpp
Go to the documentation of this file.
2
10
11namespace storm::solver {
12
13namespace detail {
15
20template<typename ValueType, BackendOptimizationDirection Dir = BackendOptimizationDirection::None, bool TrackChoices = false>
22 public:
24 requires(!TrackChoices)
25 : choiceTracking(std::nullopt) {};
26
27 MultiplierBackend(std::vector<uint64_t>& choices, std::vector<uint64_t> const& rowGroupIndices)
28 requires TrackChoices
29 : choiceTracking({choices, rowGroupIndices}) {
30 // intentionally left empty.
31 }
32
34 // intentionally left empty.
35 }
36
37 void firstRow(ValueType&& value, [[maybe_unused]] uint64_t rowGroup, [[maybe_unused]] uint64_t row) {
38 best = std::move(value);
39 if constexpr (TrackChoices) {
40 choiceTracking.currentBestRow = row;
41 }
42 }
43
44 void nextRow(ValueType&& value, [[maybe_unused]] uint64_t rowGroup, [[maybe_unused]] uint64_t row) {
45 if constexpr (TrackChoices) {
46 if (best &= value) {
47 choiceTracking.currentBestRow = row;
48 } else if (*best == value) {
49 // Reaching this point means that there are multiple 'best' values
50 // For rows that are equally good, we prefer to not change the currently selected choice.
51 // This is necessary, e.g., for policy iteration correctness and helps keeping schedulers simple.
52 if (row == choiceTracking.choices[rowGroup] + choiceTracking.rowGroupIndices[rowGroup]) {
53 choiceTracking.currentBestRow = row;
54 }
55 }
56 } else if constexpr (HasDir) {
57 best &= value;
58 } else {
59 STORM_LOG_ASSERT(false, "This backend does not support optimization direction.");
60 }
61 }
62
63 void applyUpdate(ValueType& currValue, [[maybe_unused]] uint64_t rowGroup) {
64 if constexpr (HasDir) {
65 currValue = std::move(*best);
66 if constexpr (TrackChoices) {
67 choiceTracking.choices[rowGroup] = choiceTracking.currentBestRow - choiceTracking.rowGroupIndices[rowGroup];
68 }
69 } else {
70 currValue = std::move(best);
71 }
72 }
73
74 void endOfIteration() const {
75 // intentionally left empty.
76 }
77
78 bool converged() const {
79 return true;
80 }
81
82 bool constexpr abort() const {
83 return false;
84 }
85
86 private:
87 static constexpr bool HasDir = Dir != BackendOptimizationDirection::None;
88 static constexpr storm::OptimizationDirection OptDir =
89 Dir == BackendOptimizationDirection::Maximize ? storm::OptimizationDirection::Maximize : storm::OptimizationDirection::Minimize;
90 static_assert(!TrackChoices || HasDir, "If TrackChoices is true, Dir must be set to either Minimize or Maximize.");
91
92 std::conditional_t<HasDir, storm::utility::Extremum<OptDir, ValueType>, ValueType> best;
93
94 struct SchedulerTrackingData {
95 std::vector<uint64_t>& choices; // Storage for the scheduler choices.
96 std::vector<uint64_t> const& rowGroupIndices; // Indices of the row groups.
97 uint64_t currentBestRow{0};
98 };
99 std::conditional_t<TrackChoices, SchedulerTrackingData, std::nullopt_t> choiceTracking; // Only used if TrackChoices is true.
100};
101
106template<typename ValueType>
108 public:
109 PlainMultiplicationBackend(std::vector<ValueType>& rowResults) : rowResults(rowResults) {};
110
112 // intentionally left empty.
113 }
114
115 void firstRow(ValueType&& value, [[maybe_unused]] uint64_t rowGroup, uint64_t row) {
116 rowResults[row] = std::move(value);
117 }
118
119 void nextRow(ValueType&& value, [[maybe_unused]] uint64_t rowGroup, uint64_t row) {
120 rowResults[row] = std::move(value);
121 }
122
123 void applyUpdate([[maybe_unused]] ValueType& currValue, [[maybe_unused]] uint64_t rowGroup) {
124 // intentionally left empty.
125 }
126
127 void endOfIteration() const {
128 // intentionally left empty.
129 }
130
131 bool converged() const {
132 return true;
133 }
134
135 bool constexpr abort() const {
136 return false;
137 }
138
139 private:
140 std::vector<ValueType>& rowResults;
141};
142
143} // namespace detail
144
145template<typename ValueType, bool TrivialRowGrouping>
150
151template<typename ValueType, bool TrivialRowGrouping>
153 if (!viOperatorFwd) {
    // No forward operator cached yet: build one. initialize(false) calls setMatrixForwards,
    // so the default operator here is the FORWARD one.
    // NOTE(review): confirm forward (not backward) is the intended default. The Gauss-Seidel entry points
    // default to backwards=true, so mixing calls can build and cache both operators.
154 return initialize(false);
155 } else {
156 return *viOperatorFwd;
157 }
158}
159
160template<typename ValueType, bool TrivialRowGrouping>
161typename ViOperatorMultiplier<ValueType, TrivialRowGrouping>::ViOpT& ViOperatorMultiplier<ValueType, TrivialRowGrouping>::initialize(bool backwards) const {
162 auto& viOp = backwards ? viOperatorBwd : viOperatorFwd;
163 if (!viOp) {
164 viOp = std::make_unique<ViOpT>();
165 if (backwards) {
166 viOp->setMatrixBackwards(this->matrix);
167 } else {
168 viOp->setMatrixForwards(this->matrix);
169 }
170 }
171 return *viOp;
172}
173
174template<typename ValueType, bool TrivialRowGrouping>
175void ViOperatorMultiplier<ValueType, TrivialRowGrouping>::multiply(Environment const& env, std::vector<ValueType> const& x, std::vector<ValueType> const* b,
176 std::vector<ValueType>& result) const {
177 if (&result == &x) {
178 auto& tmpResult = this->provideCachedVector(x.size());
179 multiply(env, x, b, tmpResult);
180 std::swap(result, tmpResult);
181 return;
182 }
183 auto const& viOp = initialize();
185 // Below, we just add 'result' as a dummy argument to the apply method.
186 // The backend already takes care of filling the result vector while processing the rows.
187 if (b) {
188 viOp.apply(x, result, *b, backend);
189 } else {
190 viOp.apply(x, result, storm::utility::zero<ValueType>(), backend);
191 }
192}
193
194template<typename ValueType, bool TrivialRowGrouping>
196 std::vector<ValueType> const* b, bool backwards) const {
197 STORM_LOG_THROW(TrivialRowGrouping, storm::exceptions::NotSupportedException,
198 "This multiplier does not support multiplications without reduction when invoked with non-trivial row groups");
200 auto const& viOp = initialize(backwards);
201 if (b) {
202 viOp.applyInPlace(x, *b, backend);
203 } else {
204 viOp.applyInPlace(x, storm::utility::zero<ValueType>(), backend);
205 }
206}
207
208template<typename ValueType, bool TrivialRowGrouping>
210 std::vector<uint64_t> const& rowGroupIndices, std::vector<ValueType> const& x,
211 std::vector<ValueType> const* b, std::vector<ValueType>& result,
212 std::vector<uint64_t>* choices) const {
213 if (&result == &x) {
214 auto& tmpResult = this->provideCachedVector(x.size());
215 multiplyAndReduce(env, dir, rowGroupIndices, x, b, tmpResult, choices);
216 std::swap(result, tmpResult);
217 return;
218 }
219 STORM_LOG_THROW(&rowGroupIndices == &this->matrix.getRowGroupIndices(), storm::exceptions::NotSupportedException,
220 "The row group indices must be the same as the ones stored in the matrix of this multiplier");
221 auto const& viOp = initialize();
222 auto apply = [&]<typename BT>(BT& backend) {
223 if (b) {
224 viOp.apply(x, result, *b, backend);
225 } else {
226 viOp.apply(x, result, storm::utility::zero<ValueType>(), backend);
227 }
228 };
229 if (storm::solver::minimize(dir)) {
230 if (choices) {
232 apply(backend);
233 } else {
235 apply(backend);
236 }
237 } else {
238 if (choices) {
240 apply(backend);
241 } else {
243 apply(backend);
244 }
245 }
246}
247
248template<typename ValueType, bool TrivialRowGrouping>
250 std::vector<uint64_t> const& rowGroupIndices, std::vector<ValueType>& x,
251 std::vector<ValueType> const* b, std::vector<uint_fast64_t>* choices,
252 bool backwards) const {
253 STORM_LOG_THROW(&rowGroupIndices == &this->matrix.getRowGroupIndices(), storm::exceptions::NotSupportedException,
254 "The row group indices must be the same as the ones stored in the matrix of this multiplier");
255 auto const& viOp = initialize(backwards);
256 auto apply = [&]<typename BT>(BT& backend) {
257 if (b) {
258 viOp.applyInPlace(x, *b, backend);
259 } else {
260 viOp.applyInPlace(x, storm::utility::zero<ValueType>(), backend);
261 }
262 };
263 if (storm::solver::minimize(dir)) {
264 if (choices) {
266 apply(backend);
267 } else {
269 apply(backend);
270 }
271 } else {
272 if (choices) {
274 apply(backend);
275 } else {
277 apply(backend);
278 }
279 }
280}
281
282template<typename ValueType, bool TrivialRowGrouping>
284 viOperatorBwd.reset();
285 viOperatorFwd.reset();
287};
288
291
294
295} // namespace storm::solver
virtual void clearCache() const
virtual void multiplyAndReduceGaussSeidel(Environment const &env, OptimizationDirection const &dir, std::vector< uint64_t > const &rowGroupIndices, std::vector< ValueType > &x, std::vector< ValueType > const *b, std::vector< uint_fast64_t > *choices=nullptr, bool backwards=true) const override
virtual void multiplyGaussSeidel(Environment const &env, std::vector< ValueType > &x, std::vector< ValueType > const *b, bool backwards=true) const override
Performs a matrix-vector multiplication in Gauss-Seidel style.
virtual void multiplyAndReduce(Environment const &env, OptimizationDirection const &dir, std::vector< uint64_t > const &rowGroupIndices, std::vector< ValueType > const &x, std::vector< ValueType > const *b, std::vector< ValueType > &result, std::vector< uint_fast64_t > *choices=nullptr) const override
ViOperatorMultiplier(storm::storage::SparseMatrix< ValueType > const &matrix)
virtual void clearCache() const override
virtual void multiply(Environment const &env, std::vector< ValueType > const &x, std::vector< ValueType > const *b, std::vector< ValueType > &result) const override
Performs a matrix-vector multiplication x' = A*x + b.
This backend stores the best (maximal or minimal) value of the current row group.
void applyUpdate(ValueType &currValue, uint64_t rowGroup)
void firstRow(ValueType &&value, uint64_t rowGroup, uint64_t row)
void nextRow(ValueType &&value, uint64_t rowGroup, uint64_t row)
MultiplierBackend(std::vector< uint64_t > &choices, std::vector< uint64_t > const &rowGroupIndices)
This backend simply stores the row results in a vector.
void applyUpdate(ValueType &currValue, uint64_t rowGroup)
PlainMultiplicationBackend(std::vector< ValueType > &rowResults)
void firstRow(ValueType &&value, uint64_t rowGroup, uint64_t row)
void nextRow(ValueType &&value, uint64_t rowGroup, uint64_t row)
This class represents the Value Iteration Operator (also known as the Bellman operator).
A class that holds a possibly non-square matrix in the compressed row storage format.
#define STORM_LOG_ASSERT(cond, message)
Definition macros.h:11
#define STORM_LOG_THROW(cond, exception, message)
Definition macros.h:30
bool constexpr minimize(OptimizationDirection d)