Storm
A Modern Probabilistic Model Checker
Loading...
Searching...
No Matches
LpMinMaxLinearEquationSolver.cpp
Go to the documentation of this file.
2
14
15namespace storm {
16namespace solver {
17
// Constructor that only takes an LP solver factory: it initializes nothing but the member
// lpSolverFactory, which is moved from the argument.
// NOTE(review): the signature line (line 19 of the original listing) is not rendered on this page.
18template<typename ValueType>
20 : lpSolverFactory(std::move(lpSolverFactory)) {
21 // Intentionally left empty.
22}
23
// Constructor taking the system matrix A and an LP solver factory. A is passed unchanged (not moved)
// to the StandardMinMaxLinearEquationSolver base, and the factory is moved into lpSolverFactory.
// NOTE(review): the first signature line (line 25 of the original) is not rendered; presumably A is
// taken by const reference in this overload. Confirm against the header.
24template<typename ValueType>
26 std::unique_ptr<storm::utility::solver::LpSolverFactory<ValueType>>&& lpSolverFactory)
27 : StandardMinMaxLinearEquationSolver<ValueType>(A), lpSolverFactory(std::move(lpSolverFactory)) {
28 // Intentionally left empty.
29}
30
// Constructor taking the system matrix A and an LP solver factory. A is moved into the
// StandardMinMaxLinearEquationSolver base, and the factory is moved into lpSolverFactory.
// NOTE(review): the first signature line (line 32 of the original) is not rendered; presumably A is
// taken by rvalue reference in this overload. Confirm against the header.
31template<typename ValueType>
33 std::unique_ptr<storm::utility::solver::LpSolverFactory<ValueType>>&& lpSolverFactory)
34 : StandardMinMaxLinearEquationSolver<ValueType>(std::move(A)), lpSolverFactory(std::move(lpSolverFactory)) {
35 // Intentionally left empty.
36}
37
// Solving entry point that dispatches on env.solver().minMax().getMethod():
//  - MinMaxMethod::LinearProgramming: a single LP via solveEquationsLp, without hint bounds;
//  - MinMaxMethod::ViToLp: value iteration followed by an LP via solveEquationsViToLp;
//  - any other method throws storm::exceptions::InvalidEnvironmentException.
// Returns whatever the selected routine returns.
// NOTE(review): the signature line (line 39 of the original) is not rendered. Judging by the declaration
// listed at the end of this page, it is presumably internalSolveEquations(env, dir, x, b). Confirm.
38template<typename ValueType>
40 std::vector<ValueType> const& b) const {
41 if (env.solver().minMax().getMethod() == MinMaxMethod::LinearProgramming) {
42 return solveEquationsLp(env, dir, x, b);
43 } else {
44 STORM_LOG_THROW(env.solver().minMax().getMethod() == MinMaxMethod::ViToLp, storm::exceptions::InvalidEnvironmentException,
45 "This min max solver does not support the selected technique.");
46 return solveEquationsViToLp(env, dir, x, b);
47 }
48}
49
// Solves the MinMax system in two phases:
//  1. Value iteration, always run in double precision, writes an approximate solution into x.
//  2. That approximation is handed to solveEquationsLp as a hint: an upper-bound hint when minimizing,
//     a lower-bound hint when maximizing.
// Returns false if the hinted LP call returns false. If the validation check below detects a
// mismatch, returns the result of an LP solve without hints. Otherwise returns true.
// Throws storm::exceptions::NotImplementedException if choices are fixed for some row groups.
// NOTE(review): the first signature line (line 51 of the original) is not rendered on this page.
50template<typename ValueType>
52 std::vector<ValueType> const& b) const {
53 // First create an (imprecise) vi solver to get a good initial bound.
54 STORM_LOG_THROW(!this->choiceFixedForRowGroup, storm::exceptions::NotImplementedException, "Fixed choices not implemented for this solution method.");
55 {
// The VI operator always works on a double copy of the matrix. For non-double ValueType the matrix is
// converted first, and the row-group indices are passed along explicitly.
56 auto viOperator = std::make_shared<helper::ValueIterationOperator<double, false>>();
57 if constexpr (std::is_same_v<ValueType, double>) {
58 viOperator->setMatrixBackwards(*this->A);
59 } else {
60 viOperator->setMatrixBackwards(this->A->template toValueType<double>(), &this->A->getRowGroupIndices());
61 }
// NOTE(review): viHelper (used below) is declared on a line not rendered in this listing (line 62 of the
// original). Presumably it is a ValueIterationHelper<double, false> built from viOperator. Confirm.
63 uint64_t numIterations{0};
// The callback reports progress and forwards the current status to updateStatus, together with the
// configured maximal number of iterations.
64 auto viCallback = [&](SolverStatus const& current) {
65 this->showProgressIterative(numIterations);
66 return this->updateStatus(current, false, numIterations, env.solver().minMax().getMaximalNumberOfIterations());
67 };
// VI starts from the solver's upper bounds when minimizing and from its lower bounds when maximizing.
68 if (minimize(dir)) {
69 this->createUpperBoundsVector(x);
70 } else {
71 this->createLowerBoundsVector(x);
72 }
73 this->startMeasureProgress();
74 if constexpr (std::is_same_v<ValueType, double>) {
75 viHelper.VI(x, b, numIterations, env.solver().minMax().getRelativeTerminationCriterion(),
76 storm::utility::convertNumber<double>(env.solver().minMax().getPrecision()), dir, viCallback);
77 } else {
78 // convert from/to double
79 auto xVi = storm::utility::vector::convertNumericVector<double>(x);
80 auto bVi = storm::utility::vector::convertNumericVector<double>(b);
81 double const precision = storm::utility::convertNumber<double>(env.solver().minMax().getPrecision());
82 bool const relative = env.solver().minMax().getRelativeTerminationCriterion();
83 viHelper.VI(xVi, bVi, numIterations, relative, precision, dir, viCallback);
84 auto xIt = xVi.cbegin();
85 for (auto& xi : x) {
86 xi = storm::utility::convertNumber<ValueType>(*xIt);
87 ++xIt;
88 }
89 }
90 }
91 STORM_LOG_DEBUG("Found initial values using Value Iteration. Starting LP solving now.");
92 bool res = false;
// x serves both as the bound hint and as the output vector. solveEquationsLp only reads the hint while
// creating LP variables, before it overwrites x, and it retries without the hint if the LP is infeasible.
93 if (minimize(dir)) {
94 res = solveEquationsLp(env, dir, x, b, nullptr, &x); // upper bounds
95 } else {
96 res = solveEquationsLp(env, dir, x, b, &x, nullptr); // lower bounds
97 }
98
99 if (!res) {
100 return false;
101 }
102
// NOTE(review): line 103 of the original (not rendered here) opens the scope that is closed at line 119,
// so the check below may only run under some condition. Confirm against the full source.
// For every row group, the check recomputes the optimal row value A[row]*x + b[row] (minimum or
// maximum depending on dir) and compares it with x. On any mismatch, the LP is solved again without hints.
// NOTE(review): the comparison uses exact inequality. For floating-point ValueType, round-off in the LP
// solution may trigger a restart even when x is correct. Confirm whether that is intended.
104 // The above-computed bounds might be incorrect. To obtain a correct procedure, we catch those cases here!
105 for (uint64_t rowGroup = 0; rowGroup < this->A->getRowGroupCount(); ++rowGroup) {
106 uint64_t row = this->A->getRowGroupIndices()[rowGroup];
107 ValueType optimalGroupValue = this->A->multiplyRowWithVector(row, x) + b[row];
108 for (++row; row < this->A->getRowGroupIndices()[rowGroup + 1]; ++row) {
109 ValueType rowValue = this->A->multiplyRowWithVector(row, x) + b[row];
110 if ((minimize(dir) && rowValue < optimalGroupValue) || (maximize(dir) && rowValue > optimalGroupValue)) {
111 optimalGroupValue = rowValue;
112 }
113 }
114 if (x[rowGroup] != optimalGroupValue) {
115 STORM_LOG_WARN("LP with provided bounds is incorrect. Restarting without bounds.");
116 return solveEquationsLp(env, dir, x, b); // no bounds
117 }
118 }
119 }
120
121 return true;
122}
123
// Solves the MinMax equation system given by this->A and b with a single LP and writes the result into x.
//
// Parameters:
//   env          - solver environment. Its minMax().lp() settings select variants of the encoding.
//   dir          - whether the minimal or maximal solution is sought. The LP itself is optimized in the
//                  opposite direction (invert(dir)). Each row-group variable has objective coefficient 1,
//                  or 0 for non-relevant row groups when optimizing only for relevant values.
//   x            - output vector with one entry per row group. Variable indices coincide with
//                  row-group indices.
//   b            - right-hand side with one entry per row of A.
//   lowerBounds  - optional per-row-group lower-bound hint (nullptr means none). If this solver also
//                  has its own lower bound, the two are combined with std::max.
//   upperBounds  - optional per-row-group upper-bound hint (nullptr means none). If this solver also
//                  has its own upper bound, the two are combined with std::min.
// Returns true once the solution has been written to x. Failures are reported via exceptions.
// Throws storm::exceptions::UnexpectedException if the LP is infeasible without hint bounds, is
// unbounded, or is not solved to optimality. If the LP is infeasible while hint bounds were given,
// the hints are discarded and the call is repeated without them.
// x may alias *lowerBounds or *upperBounds (solveEquationsViToLp passes &x). This is safe because the
// hints are only read while LP variables are created, before x is overwritten.
124template<typename ValueType>
125bool LpMinMaxLinearEquationSolver<ValueType>::solveEquationsLp(Environment const& env, OptimizationDirection dir, std::vector<ValueType>& x,
126 std::vector<ValueType> const& b, std::vector<ValueType> const* lowerBounds,
127 std::vector<ValueType> const* upperBounds) const {
128 // Determine the variant of the encoding
129 // Enforcing a global bound is only enabled if there is a single initial state.
130 // Otherwise, cases where one state satisfies the bound but another does not will be difficult.
// NOTE(review): the two lines above mention a global bound that this function does not visibly enforce.
// The only variant implemented here restricts the objective to the relevant values.
131 bool const optimizeOnlyRelevant = this->hasRelevantValues() && env.solver().minMax().lp().getOptimizeOnlyForInitialState();
132 STORM_LOG_DEBUG("Optimize only for relevant state requested:" << env.solver().minMax().lp().getOptimizeOnlyForInitialState());
133 if (optimizeOnlyRelevant) {
134 STORM_LOG_TRACE("Relevant values " << this->getRelevantValues());
135 } else if (!this->hasRelevantValues()) {
136 STORM_LOG_DEBUG("No relevant values set! Optimizing over all states.");
137 }
138
139 // Set-up lower/upper bounds
// lower/upper remain empty std::function objects when no bound of that kind is available.
// The lambdas capture the pointer parameters by reference, which is fine because they are only
// invoked within this call.
140 std::function<ValueType(uint64_t const&)> lower, upper;
141 if (this->hasLowerBound() && lowerBounds == nullptr) {
142 lower = [this](uint64_t const& i) { return this->getLowerBound(i); };
143 } else if (!this->hasLowerBound() && lowerBounds != nullptr) {
144 STORM_LOG_ASSERT(lowerBounds->size() == x.size(), "lower bounds vector has invalid size.");
145 lower = [&lowerBounds](uint64_t const& i) { return (*lowerBounds)[i]; };
146 } else if (this->hasLowerBound() && lowerBounds != nullptr) {
147 STORM_LOG_ASSERT(lowerBounds->size() == x.size(), "lower bounds vector has invalid size.");
148 lower = [&lowerBounds, this](uint64_t const& i) { return std::max(this->getLowerBound(i), (*lowerBounds)[i]); };
149 }
150 if (this->hasUpperBound() && upperBounds == nullptr) {
151 upper = [this](uint64_t const& i) { return this->getUpperBound(i); };
152 } else if (!this->hasUpperBound() && upperBounds != nullptr) {
153 STORM_LOG_ASSERT(upperBounds->size() == x.size(), "upper bounds vector has invalid size.");
154 upper = [&upperBounds](uint64_t const& i) { return (*upperBounds)[i]; };
155 } else if (this->hasUpperBound() && upperBounds != nullptr) {
156 STORM_LOG_ASSERT(upperBounds->size() == x.size(), "upper bounds vector has invalid size.");
157 upper = [&upperBounds, this](uint64_t const& i) { return std::min(this->getUpperBound(i), (*upperBounds)[i]); };
158 }
159 bool const useBounds = lower || upper;
160
161 // Set up the LP solver
162 auto solver = lpSolverFactory->createRaw("");
163 solver->setOptimizationDirection(invert(dir));
164 using VariableIndex = typename LpSolver<ValueType, true>::Variable;
165 std::map<VariableIndex, ValueType> constantRowGroups; // Keep track of the rows that are known to be constants
166
167 // Create a variable for each row group
168 for (uint64_t rowGroup = 0; rowGroup < this->A->getRowGroupCount(); ++rowGroup) {
169 ValueType const objValue =
170 (optimizeOnlyRelevant && !this->getRelevantValues().get(rowGroup)) ? storm::utility::zero<ValueType>() : storm::utility::one<ValueType>();
171 std::optional<ValueType> lowerBound, upperBound;
172 if (useBounds) {
173 if (lower) {
174 lowerBound = lower(rowGroup);
175 }
176 if (upper) {
177 upperBound = upper(rowGroup);
178 if (lowerBound) {
179 STORM_LOG_ASSERT(*lowerBound <= *upperBound, "Lower Bound at row group " << rowGroup << " is " << *lowerBound
180 << " which exceeds the upper bound " << *upperBound << ".");
181 if (*lowerBound == *upperBound) {
182 // Some solvers (like glpk) don't support variables with bounds [x,x]. We therefore just use a constant instead. This should be more
183 // efficient anyway.
184 constantRowGroups.emplace(rowGroup, *lowerBound);
185 // Still, a dummy variable is added so that variable-indices coincide with state indices
186 solver->addContinuousVariable("dummy" + std::to_string(rowGroup));
187 continue; // with next rowGroup
188 }
189 }
190 }
191 }
192 solver->addContinuousVariable("x" + std::to_string(rowGroup), lowerBound, upperBound, objValue);
193 }
194 solver->update();
195 STORM_LOG_DEBUG("Use eq if there is a single action: " << env.solver().minMax().lp().getUseEqualityForSingleActions());
196 bool const useEqualityForSingleAction = env.solver().minMax().lp().getUseEqualityForSingleActions();
197
198 // Add a set of constraints for each row group
// Row i of row group s is encoded as  sum_j A[i][j] * x_j - x_s  REL  -b[i],  i.e.  A[i]*x + b[i]  REL  x_s.
// The diagonal coefficient becomes -1 + A[i][s]. Row groups with a constant value contribute to the
// right-hand side instead of the left-hand side.
// REL is Equal for row groups encoded by a single row if getUseEqualityForSingleActions() is set;
// otherwise it is defaultRelationType.
// NOTE(review): defaultRelationType is defined on a line not rendered in this listing (line 199).
// Presumably it is GreaterOrEqual when minimizing and LessOrEqual when maximizing. Confirm.
200 for (uint64_t rowGroup = 0; rowGroup < this->A->getRowGroupCount(); ++rowGroup) {
201 // The rowgroup refers to the state number
// For a row group with a fixed choice, only the row selected by getInitialScheduler() is encoded.
202 uint64_t rowIndex, rowGroupEnd;
203 if (this->choiceFixedForRowGroup && this->choiceFixedForRowGroup.get()[rowGroup]) {
204 rowIndex = this->A->getRowGroupIndices()[rowGroup] + this->getInitialScheduler()[rowGroup];
205 rowGroupEnd = rowIndex + 1;
206 } else {
207 rowIndex = this->A->getRowGroupIndices()[rowGroup];
208 rowGroupEnd = this->A->getRowGroupIndices()[rowGroup + 1];
209 }
210 bool const singleAction = (rowIndex + 1 == rowGroupEnd);
211 auto const relationType = (useEqualityForSingleAction && singleAction) ? storm::expressions::RelationType::Equal : defaultRelationType;
212 // Add a constraint for each row in the current row group
213 for (; rowIndex < rowGroupEnd; ++rowIndex) {
214 auto row = this->A->getRow(rowIndex);
215 RawLpConstraint<ValueType> constraint(relationType, -b[rowIndex], row.getNumberOfEntries());
216 auto addToConstraint = [&constraint, &constantRowGroups](VariableIndex const& var, ValueType const& val) {
217 if (auto findRes = constantRowGroups.find(var); findRes != constantRowGroups.end()) {
218 constraint.rhs -= findRes->second * val;
219 } else {
220 constraint.addToLhs(var, val);
221 }
222 };
// Entries are added in the row's iteration order, with the diagonal coefficient inserted at column
// rowGroup. The diagonal term is added even if the row has no explicit diagonal entry.
223 auto entryIt = row.begin();
224 auto const entryItEnd = row.end();
225 for (; entryIt != entryItEnd && entryIt->getColumn() < rowGroup; ++entryIt) {
226 addToConstraint(entryIt->getColumn(), entryIt->getValue());
227 }
228 ValueType diagVal = -storm::utility::one<ValueType>();
229 if (entryIt != entryItEnd && entryIt->getColumn() == rowGroup) {
230 diagVal += entryIt->getValue();
231 ++entryIt;
232 }
233 addToConstraint(rowGroup, diagVal);
234 for (; entryIt != entryItEnd; ++entryIt) {
235 addToConstraint(entryIt->getColumn(), entryIt->getValue());
236 }
237 solver->addConstraint("", constraint);
238 }
239 }
240
241 // Invoke optimization
242 STORM_LOG_TRACE("Run solver...");
243 solver->optimize();
244 STORM_LOG_TRACE("...done.");
245
246 bool const infeasible = solver->isInfeasible();
247 if (infeasible && (lowerBounds || upperBounds)) {
248 // The provided bounds (not the ones set for this solver object) are meant as a hint and might thus be wrong. We restart without bounds.
249 STORM_LOG_WARN("LP with provided bounds is infeasible. Restarting without bounds.");
250 return solveEquationsLp(env, dir, x, b);
251 }
252 STORM_LOG_THROW(!infeasible, storm::exceptions::UnexpectedException, "The MinMax equation system is infeasible.");
253 STORM_LOG_THROW(!solver->isUnbounded(), storm::exceptions::UnexpectedException, "The MinMax equation system is unbounded.");
254 STORM_LOG_THROW(solver->isOptimal(), storm::exceptions::UnexpectedException, "Unable to find optimal solution for MinMax equation system.");
255
256 // write the solution into the solution vector
// Constant row groups take their recorded constant. Their dummy LP variables are never read.
257 auto xIt = x.begin();
258 VariableIndex i = 0;
259 for (; xIt != x.end(); ++xIt, ++i) {
260 if (auto findRes = constantRowGroups.find(i); findRes != constantRowGroups.end()) {
261 *xIt = findRes->second;
262 } else {
263 *xIt = solver->getContinuousValue(i);
264 }
265 }
266
267 // If requested, we store the scheduler for retrieval.
// For each row group, the chosen row is the first one (lowest index) attaining the optimal value of
// A[row]*x + b[row]. Ties keep the earlier row because the comparisons are strict.
// NOTE(review): schedulerChoices is re-created zero-initialized on the next line. Row groups with a fixed
// choice therefore end up with choice 0, not getInitialScheduler()[rowGroup] (the row actually encoded
// above). Confirm whether any caller relies on those entries.
268 if (this->isTrackSchedulerSet()) {
269 this->schedulerChoices = std::vector<uint_fast64_t>(this->A->getRowGroupCount());
270 for (uint64_t rowGroup = 0; rowGroup < this->A->getRowGroupCount(); ++rowGroup) {
271 if (!this->choiceFixedForRowGroup || !this->choiceFixedForRowGroup.get()[rowGroup]) {
272 // Only update scheduler choice for the states that don't have a fixed choice
273 uint64_t row = this->A->getRowGroupIndices()[rowGroup];
274 uint64_t optimalChoiceIndex = 0;
275 uint64_t currChoice = 0;
276 ValueType optimalGroupValue = this->A->multiplyRowWithVector(row, x) + b[row];
277 for (++row, ++currChoice; row < this->A->getRowGroupIndices()[rowGroup + 1]; ++row, ++currChoice) {
278 ValueType rowValue = this->A->multiplyRowWithVector(row, x) + b[row];
279 if ((minimize(dir) && rowValue < optimalGroupValue) || (maximize(dir) && rowValue > optimalGroupValue)) {
280 optimalGroupValue = rowValue;
281 optimalChoiceIndex = currChoice;
282 }
283 }
284 this->schedulerChoices.get()[rowGroup] = optimalChoiceIndex;
285 }
286 }
287 }
288 // Reaching this point means that the solution was found.
289 return true;
290}
291
// NOTE(review): the definition following this template header (lines 293-295 of the original listing)
// is not rendered on this page. Judging by the declarations listed at the end of the page, it may be
// clearCache(). Confirm against the full source.
292template<typename ValueType>
296
// Reports the requirements this LP-based solver places on the equation system, given the environment
// and the optimization direction (if known):
//  - a unique solution, when none is already known to exist and either the environment forces it or a
//    scheduler is being tracked;
//  - non-critical bounds, when the LP environment requests non-trivial bounds;
//  - for the ViToLp method: critical upper bounds when minimizing, critical lower bounds when maximizing,
//    or both when the direction is unknown. solveEquationsViToLp starts value iteration from these vectors.
// NOTE(review): the first signature line (298) and the declaration of `requirements` (300) are not
// rendered in this listing. hasInitialScheduler is not used by any of the visible lines.
297template<typename ValueType>
299 Environment const& env, boost::optional<storm::solver::OptimizationDirection> const& direction, bool const& hasInitialScheduler) const {
301
302 if (!this->hasUniqueSolution() && (env.solver().minMax().isForceRequireUnique() || this->isTrackSchedulerSet())) {
303 requirements.requireUniqueSolution();
304 }
305
306 if (env.solver().minMax().lp().getUseNonTrivialBounds()) {
307 requirements.requireBounds(false); // not critical
308 }
309
310 if (env.solver().minMax().getMethod() == MinMaxMethod::ViToLp) {
311 if (direction) {
312 if (minimize(*direction)) {
313 requirements.requireUpperBounds(true);
314 } else {
315 requirements.requireLowerBounds(true);
316 }
317 } else {
318 requirements.requireBounds(true);
319 }
320 }
321
322 return requirements;
323}
324
326
// NOTE(review): lines 325 and 328 of the original listing are not rendered on this page. They are
// presumably explicit template instantiations, the latter guarded by STORM_HAVE_CARL. Confirm.
327#ifdef STORM_HAVE_CARL
329#endif
330} // namespace solver
331} // namespace storm
SolverEnvironment & solver()
uint64_t const & getMaximalNumberOfIterations() const
MinMaxLpSolverEnvironment const & lp() const
storm::RationalNumber const & getPrecision() const
storm::solver::MinMaxMethod const & getMethod() const
bool const & getRelativeTerminationCriterion() const
MinMaxSolverEnvironment & minMax()
Solves a MinMaxLinearEquationSystem using a linear programming solver.
LpMinMaxLinearEquationSolver(std::unique_ptr< storm::utility::solver::LpSolverFactory< ValueType > > &&lpSolverFactory)
virtual bool internalSolveEquations(Environment const &env, OptimizationDirection dir, std::vector< ValueType > &x, std::vector< ValueType > const &b) const override
virtual MinMaxLinearEquationSolverRequirements getRequirements(Environment const &env, boost::optional< storm::solver::OptimizationDirection > const &direction=boost::none, bool const &hasInitialScheduler=false) const override
Retrieves the requirements of this solver for solving equations with the current settings.
virtual void clearCache() const override
Clears the currently cached data that has been stored during previous calls of the solver.
std::conditional_t< RawMode, typename RawLpConstraint< ValueType >::VariableIndexType, storm::expressions::Variable > Variable
Definition LpSolver.h:53
virtual void clearCache() const
Clears the currently cached data that has been stored during previous calls of the solver.
MinMaxLinearEquationSolverRequirements & requireBounds(bool critical=true)
MinMaxLinearEquationSolverRequirements & requireUniqueSolution(bool critical=true)
MinMaxLinearEquationSolverRequirements & requireLowerBounds(bool critical=true)
MinMaxLinearEquationSolverRequirements & requireUpperBounds(bool critical=true)
A class that holds a possibly non-square matrix in the compressed row storage format.
#define STORM_LOG_WARN(message)
Definition logging.h:30
#define STORM_LOG_DEBUG(message)
Definition logging.h:23
#define STORM_LOG_TRACE(message)
Definition logging.h:17
#define STORM_LOG_ASSERT(cond, message)
Definition macros.h:11
#define STORM_LOG_THROW(cond, exception, message)
Definition macros.h:30
SFTBDDChecker::ValueType ValueType
bool constexpr maximize(OptimizationDirection d)
OptimizationDirection constexpr invert(OptimizationDirection d)
bool constexpr minimize(OptimizationDirection d)
LabParser.cpp.
Definition cli.cpp:18