Storm
A Modern Probabilistic Model Checker
Loading...
Searching...
No Matches
ViOperatorMultiplier.cpp
Go to the documentation of this file.
2
6
10
12
13namespace storm::solver {
14
15namespace detail {
17
22template<typename ValueType, BackendOptimizationDirection Dir = BackendOptimizationDirection::None, bool TrackChoices = false>
24 public:
26 requires(!TrackChoices)
27 : choiceTracking(std::nullopt) {};
28
// Constructor of the choice-tracking variant; only available when TrackChoices is true (see the requires-clause).
// The two vectors are not copied: choiceTracking holds references to them, so the chosen rows are written
// back into the caller's `choices` vector and both vectors have to stay alive while this backend is used.
29 MultiplierBackend(std::vector<uint64_t>& choices, std::vector<uint64_t> const& rowGroupIndices)
30 requires TrackChoices
31 : choiceTracking({choices, rowGroupIndices}) {
32 // intentionally left empty.
33 }
34
36 // intentionally left empty.
37 }
38
39 void firstRow(ValueType&& value, [[maybe_unused]] uint64_t rowGroup, [[maybe_unused]] uint64_t row) {
40 best = std::move(value);
41 if constexpr (TrackChoices) {
42 choiceTracking.currentBestRow = row;
43 }
44 }
45
46 void nextRow(ValueType&& value, [[maybe_unused]] uint64_t rowGroup, [[maybe_unused]] uint64_t row) {
47 if constexpr (TrackChoices) {
48 if (best &= value) {
49 choiceTracking.currentBestRow = row;
50 } else if (*best == value) {
51 // Reaching this point means that there are multiple 'best' values
52 // For rows that are equally good, we prefer to not change the currently selected choice.
53 // This is necessary, e.g., for policy iteration correctness and helps keeping schedulers simple.
54 if (row == choiceTracking.choices[rowGroup] + choiceTracking.rowGroupIndices[rowGroup]) {
55 choiceTracking.currentBestRow = row;
56 }
57 }
58 } else if constexpr (HasDir) {
59 best &= value;
60 } else {
61 STORM_LOG_ASSERT(false, "This backend does not support optimization direction.");
62 }
63 }
64
65 void applyUpdate(ValueType& currValue, [[maybe_unused]] uint64_t rowGroup) {
66 if constexpr (HasDir) {
67 currValue = std::move(*best);
68 if constexpr (TrackChoices) {
69 choiceTracking.choices[rowGroup] = choiceTracking.currentBestRow - choiceTracking.rowGroupIndices[rowGroup];
70 }
71 } else {
72 currValue = std::move(best);
73 }
74 }
75
// End-of-iteration hook (presumably invoked by the operator that drives this backend); nothing to do here.
76 void endOfIteration() const {
77 // intentionally left empty.
78 }
79
// Always reports convergence: this backend performs no convergence check of its own.
80 bool converged() const {
81 return true;
82 }
83
// Never requests an abort.
84 bool constexpr abort() const {
85 return false;
86 }
87
88 private:
// True iff an optimization direction (minimize or maximize) was selected via the Dir template parameter.
89 static constexpr bool HasDir = Dir != BackendOptimizationDirection::None;
// Direction handed to Extremum. For Dir == None this evaluates to Minimize, but it is then unused
// because `best` is a plain ValueType in that case (see the conditional_t below).
90 static constexpr storm::OptimizationDirection OptDir =
91 Dir == BackendOptimizationDirection::Maximize ? storm::OptimizationDirection::Maximize : storm::OptimizationDirection::Minimize;
92 static_assert(!TrackChoices || HasDir, "If TrackChoices is true, Dir must be set to either Minimize or Maximize.");
93
// Best value seen so far: an Extremum when a direction is set, otherwise the plain value.
94 std::conditional_t<HasDir, storm::utility::Extremum<OptDir, ValueType>, ValueType> best;
95
// Bookkeeping for scheduler extraction. Holds references only, so it does not own the two vectors.
96 struct SchedulerTrackingData {
97 std::vector<uint64_t>& choices; // Storage for the scheduler choices.
98 std::vector<uint64_t> const& rowGroupIndices; // Indices of the row groups.
// Row currently considered best (compared against absolute row indices in nextRow).
99 uint64_t currentBestRow{0};
100 };
101 std::conditional_t<TrackChoices, SchedulerTrackingData, std::nullopt_t> choiceTracking; // Only used if TrackChoices is true.
102};
103
108template<typename ValueType>
110 public:
// Binds the backend to the vector that receives the row results. Only a reference is stored (no copy),
// so the vector has to outlive the backend.
111 PlainMultiplicationBackend(std::vector<ValueType>& rowResults) : rowResults(rowResults) {};
112
114 // intentionally left empty.
115 }
116
// Stores the value of the given row at index `row` of the result vector; rowGroup is ignored.
117 void firstRow(ValueType&& value, [[maybe_unused]] uint64_t rowGroup, uint64_t row) {
118 rowResults[row] = std::move(value);
119 }
120
// Same as firstRow: no reduction takes place, every row simply gets its own entry in the result vector.
121 void nextRow(ValueType&& value, [[maybe_unused]] uint64_t rowGroup, uint64_t row) {
122 rowResults[row] = std::move(value);
123 }
124
// No-op: the results were already written row by row in firstRow/nextRow, so currValue is left untouched.
125 void applyUpdate([[maybe_unused]] ValueType& currValue, [[maybe_unused]] uint64_t rowGroup) {
126 // intentionally left empty.
127 }
128
// End-of-iteration hook; nothing to do for plain multiplication.
129 void endOfIteration() const {
130 // intentionally left empty.
131 }
132
// Always reports convergence: this backend performs no convergence check of its own.
133 bool converged() const {
134 return true;
135 }
136
// Never requests an abort.
137 bool constexpr abort() const {
138 return false;
139 }
140
141 private:
142 std::vector<ValueType>& rowResults;
143};
144
145} // namespace detail
146
147template<typename ValueType, bool TrivialRowGrouping>
152
// Returns the forward operator if it already exists and otherwise creates it via initialize(false).
// NOTE(review): the signature line of this definition is not part of this listing (numbering jumps from 153
// to 155); presumably this is the parameterless initialize() used by multiply/multiplyAndReduce -- confirm.
153template<typename ValueType, bool TrivialRowGrouping>
155 if (!viOperatorFwd) {
156 return initialize(false); // NOTE(review): `false` requests the forward operator (see initialize(bool)), although this comment used to read "default to backward operator" -- confirm which default is intended.
157 } else {
158 return *viOperatorFwd;
159 }
160}
161
162template<typename ValueType, bool TrivialRowGrouping>
163typename ViOperatorMultiplier<ValueType, TrivialRowGrouping>::ViOpT& ViOperatorMultiplier<ValueType, TrivialRowGrouping>::initialize(bool backwards) const {
164 auto& viOp = backwards ? viOperatorBwd : viOperatorFwd;
165 if (!viOp) {
166 viOp = std::make_unique<ViOpT>();
167 if (backwards) {
168 viOp->setMatrixBackwards(this->matrix);
169 } else {
170 viOp->setMatrixForwards(this->matrix);
171 }
172 }
173 return *viOp;
174}
175
// Performs a matrix-vector multiplication x' = A*x + b and stores the outcome in `result`.
// `b` is optional: if it is a null pointer, zero is used as the offset.
// `result` may be the same vector as `x`; this case is routed through a temporary.
176template<typename ValueType, bool TrivialRowGrouping>
177void ViOperatorMultiplier<ValueType, TrivialRowGrouping>::multiply(Environment const& env, std::vector<ValueType> const& x, std::vector<ValueType> const* b,
178 std::vector<ValueType>& result) const {
// Aliasing case: compute into a cached temporary vector first and swap it into `result` afterwards,
// so that `x` is not overwritten while it is still being read.
179 if (&result == &x) {
180 auto& tmpResult = this->provideCachedVector(x.size());
181 multiply(env, x, b, tmpResult);
182 std::swap(result, tmpResult);
183 return;
184 }
185 auto const& viOp = initialize();
// NOTE(review): the declaration of `backend` is not part of this listing (numbering jumps from 185 to 187);
// presumably a detail::PlainMultiplicationBackend writing into `result` -- confirm.
187 // Below, we just add 'result' as a dummy argument to the apply method.
188 // The backend already takes care of filling the result vector while processing the rows.
189 if (b) {
190 viOp.apply(x, result, *b, backend);
191 } else {
192 viOp.apply(x, result, storm::utility::zero<ValueType>(), backend);
193 }
194}
195
// Gauss-Seidel style multiplication that updates `x` in place (applyInPlace), optionally adding `b`.
// Throws storm::exceptions::NotSupportedException unless TrivialRowGrouping is true.
// NOTE(review): the first signature line and the declaration of `backend` are not part of this listing
// (lines 197 and 201 are missing) -- consult the original source for them.
196template<typename ValueType, bool TrivialRowGrouping>
198 std::vector<ValueType> const* b, bool backwards) const {
199 STORM_LOG_THROW(TrivialRowGrouping, storm::exceptions::NotSupportedException,
200 "This multiplier does not support multiplications without reduction when invoked with non-trivial row groups");
202 auto const& viOp = initialize(backwards);
// A null `b` means "no offset": zero is passed instead.
203 if (b) {
204 viOp.applyInPlace(x, *b, backend);
205 } else {
206 viOp.applyInPlace(x, storm::utility::zero<ValueType>(), backend);
207 }
208}
209
// Multiplies and reduces over row groups according to `dir`, writing the outcome to `result`.
// `b` is optional (null means zero offset). The code branches on whether `choices` is null to pick a backend.
// `rowGroupIndices` has to be the very same object as the matrix' row group indices (checked by address);
// otherwise storm::exceptions::NotSupportedException is thrown.
// NOTE(review): the first signature line and the four `backend` declarations (lines 211, 233, 236, 241, 244)
// are not part of this listing; presumably each branch instantiates a detail::MultiplierBackend with the
// matching direction and choice tracking -- confirm against the original source.
210template<typename ValueType, bool TrivialRowGrouping>
212 std::vector<uint64_t> const& rowGroupIndices, std::vector<ValueType> const& x,
213 std::vector<ValueType> const* b, std::vector<ValueType>& result,
214 std::vector<uint64_t>* choices) const {
// Aliasing case: compute into a cached temporary and swap, so `x` is not overwritten while it is read.
215 if (&result == &x) {
216 auto& tmpResult = this->provideCachedVector(x.size());
217 multiplyAndReduce(env, dir, rowGroupIndices, x, b, tmpResult, choices);
218 std::swap(result, tmpResult);
219 return;
220 }
221 STORM_LOG_THROW(&rowGroupIndices == &this->matrix.getRowGroupIndices(), storm::exceptions::NotSupportedException,
222 "The row group indices must be the same as the ones stored in the matrix of this multiplier");
223 auto const& viOp = initialize();
// Generic helper so that the same call can be issued for each of the differently typed backends below.
224 auto apply = [&]<typename BT>(BT& backend) {
225 if (b) {
226 viOp.apply(x, result, *b, backend);
227 } else {
228 viOp.apply(x, result, storm::utility::zero<ValueType>(), backend);
229 }
230 };
// Dispatch on (minimize / maximize) x (with / without choices).
231 if (storm::solver::minimize(dir)) {
232 if (choices) {
234 apply(backend);
235 } else {
237 apply(backend);
238 }
239 } else {
240 if (choices) {
242 apply(backend);
243 } else {
245 apply(backend);
246 }
247 }
248}
249
// In-place (Gauss-Seidel style) variant of multiply-and-reduce: `x` is updated via applyInPlace.
// `b` is optional (null means zero offset). The code branches on whether `choices` is null to pick a backend.
// `backwards` selects which of the two cached operators is used (see initialize(bool)).
// `rowGroupIndices` has to be the very same object as the matrix' row group indices (checked by address);
// otherwise storm::exceptions::NotSupportedException is thrown.
// NOTE(review): the first signature line and the four `backend` declarations (lines 251, 267, 270, 275, 278)
// are not part of this listing -- confirm against the original source.
250template<typename ValueType, bool TrivialRowGrouping>
252 std::vector<uint64_t> const& rowGroupIndices, std::vector<ValueType>& x,
253 std::vector<ValueType> const* b, std::vector<uint_fast64_t>* choices,
254 bool backwards) const {
255 STORM_LOG_THROW(&rowGroupIndices == &this->matrix.getRowGroupIndices(), storm::exceptions::NotSupportedException,
256 "The row group indices must be the same as the ones stored in the matrix of this multiplier");
257 auto const& viOp = initialize(backwards);
// Generic helper so that the same call can be issued for each of the differently typed backends below.
258 auto apply = [&]<typename BT>(BT& backend) {
259 if (b) {
260 viOp.applyInPlace(x, *b, backend);
261 } else {
262 viOp.applyInPlace(x, storm::utility::zero<ValueType>(), backend);
263 }
264 };
// Dispatch on (minimize / maximize) x (with / without choices).
265 if (storm::solver::minimize(dir)) {
266 if (choices) {
268 apply(backend);
269 } else {
271 apply(backend);
272 }
273 } else {
274 if (choices) {
276 apply(backend);
277 } else {
279 apply(backend);
280 }
281 }
282}
283
// Drops both cached value iteration operators; initialize(bool) re-creates them on the next use.
// NOTE(review): the signature line and line 288 are not part of this listing; presumably the latter
// forwards to the base class' clearCache -- confirm against the original source.
284template<typename ValueType, bool TrivialRowGrouping>
286 viOperatorBwd.reset();
287 viOperatorFwd.reset();
289};
290
293
296
297} // namespace storm::solver
virtual void clearCache() const
virtual void multiplyAndReduceGaussSeidel(Environment const &env, OptimizationDirection const &dir, std::vector< uint64_t > const &rowGroupIndices, std::vector< ValueType > &x, std::vector< ValueType > const *b, std::vector< uint_fast64_t > *choices=nullptr, bool backwards=true) const override
virtual void multiplyGaussSeidel(Environment const &env, std::vector< ValueType > &x, std::vector< ValueType > const *b, bool backwards=true) const override
Performs a matrix-vector multiplication in Gauss-Seidel style.
virtual void multiplyAndReduce(Environment const &env, OptimizationDirection const &dir, std::vector< uint64_t > const &rowGroupIndices, std::vector< ValueType > const &x, std::vector< ValueType > const *b, std::vector< ValueType > &result, std::vector< uint_fast64_t > *choices=nullptr) const override
ViOperatorMultiplier(storm::storage::SparseMatrix< ValueType > const &matrix)
virtual void clearCache() const override
virtual void multiply(Environment const &env, std::vector< ValueType > const &x, std::vector< ValueType > const *b, std::vector< ValueType > &result) const override
Performs a matrix-vector multiplication x' = A*x + b.
This backend stores the best (maximal or minimal) value of the current row group.
void applyUpdate(ValueType &currValue, uint64_t rowGroup)
void firstRow(ValueType &&value, uint64_t rowGroup, uint64_t row)
void nextRow(ValueType &&value, uint64_t rowGroup, uint64_t row)
MultiplierBackend(std::vector< uint64_t > &choices, std::vector< uint64_t > const &rowGroupIndices)
This backend simply stores the row results in a vector.
void applyUpdate(ValueType &currValue, uint64_t rowGroup)
PlainMultiplicationBackend(std::vector< ValueType > &rowResults)
void firstRow(ValueType &&value, uint64_t rowGroup, uint64_t row)
void nextRow(ValueType &&value, uint64_t rowGroup, uint64_t row)
This class represents the Value Iteration Operator (also known as Bellman operator).
A class that holds a possibly non-square matrix in the compressed row storage format.
#define STORM_LOG_ASSERT(cond, message)
Definition macros.h:11
#define STORM_LOG_THROW(cond, exception, message)
Definition macros.h:30
bool constexpr minimize(OptimizationDirection d)