Storm
A Modern Probabilistic Model Checker
Loading...
Searching...
No Matches
IterativePolicySearch.cpp
Go to the documentation of this file.
2#include "storm/io/file.h"
3
7
8namespace storm::pomdp {
9namespace detail {
10void printRelevantInfoFromModel(std::shared_ptr<storm::solver::SmtSolver::ModelReference> const& model,
11 std::vector<storm::expressions::Variable> const& reachVars, std::vector<storm::expressions::Variable> const& continuationVars) {
12 uint64_t i = 0;
13 std::stringstream ss;
14 STORM_LOG_TRACE("states which we have now: ");
15 for (auto const& rv : reachVars) {
16 if (model->getBooleanValue(rv)) {
17 ss << " " << i;
18 }
19 ++i;
20 }
21 STORM_LOG_TRACE(ss.str());
22 i = 0;
23 STORM_LOG_TRACE("states from which we continue: ");
24 std::stringstream ss2;
25 for (auto const& rv : continuationVars) {
26 if (model->getBooleanValue(rv)) {
27 ss2 << " " << i;
28 }
29 ++i;
30 }
31 STORM_LOG_TRACE(ss2.str());
32}
33} // namespace detail
34
35template<typename ValueType>
37 STORM_PRINT_AND_LOG("#STATS Total time: " << totalTimer << '\n');
38 STORM_PRINT_AND_LOG("#STATS SAT Calls: " << satCalls << '\n');
39 STORM_PRINT_AND_LOG("#STATS SAT Calls time: " << smtCheckTimer << '\n');
40 STORM_PRINT_AND_LOG("#STATS Outer iterations: " << outerIterations << '\n');
41 STORM_PRINT_AND_LOG("#STATS Solver initialization time: " << initializeSolverTimer << '\n');
42 STORM_PRINT_AND_LOG("#STATS Obtain partial scheduler time: " << evaluateExtensionSolverTime << '\n');
43 STORM_PRINT_AND_LOG("#STATS Update solver to extend partial scheduler time: " << encodeExtensionSolverTime << '\n');
44 STORM_PRINT_AND_LOG("#STATS Update solver with new scheduler time: " << updateNewStrategySolverTime << '\n');
45 STORM_PRINT_AND_LOG("#STATS Winning regions update time: " << winningRegionUpdatesTimer << '\n');
46 STORM_PRINT_AND_LOG("#STATS Graph search time: " << graphSearchTime << '\n');
47}
48
49template<typename ValueType>
51 storm::storage::BitVector const& surelyReachSinkStates,
52 std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory,
53 MemlessSearchOptions const& options)
54 : pomdp(pomdp), surelyReachSinkStates(surelyReachSinkStates), targetStates(targetStates), options(options), smtSolverFactory(smtSolverFactory) {
55 this->expressionManager = std::make_shared<storm::expressions::ExpressionManager>();
56 smtSolver = smtSolverFactory->create(*expressionManager);
57 // Initialize states per observation.
58 for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
59 statesPerObservation.emplace_back(); // TODO Consider using bitvectors instead.
60 reachVarExpressionsPerObservation.emplace_back();
61 }
62 uint64_t state = 0;
63 for (auto obs : pomdp.getObservations()) {
64 statesPerObservation.at(obs).push_back(state++);
65 }
66 // Initialize winning region
67 std::vector<uint64_t> nrStatesPerObservation;
68 for (auto const& states : statesPerObservation) {
69 nrStatesPerObservation.push_back(states.size());
70 }
71 winningRegion = WinningRegion(nrStatesPerObservation);
72 if (options.validateResult || options.validateEveryStep) {
73 STORM_LOG_WARN("The validator should only be created when the option is set.");
74 validator = std::make_shared<WinningRegionQueryInterface<ValueType>>(pomdp, winningRegion);
75 }
76}
77
78template<typename ValueType>
80 STORM_LOG_INFO("Start intializing solver...");
81 bool delayedSwitching = false; // Notice that delayed switching is currently not compatible with some of the other optimizations and it is unclear which of
82 // these optimizations causes the problem.
83 bool lookaheadConstraintsRequired;
84 if (options.forceLookahead) {
85 lookaheadConstraintsRequired = true;
86 } else {
87 lookaheadConstraintsRequired = qualitative::isLookaheadRequired(pomdp, targetStates, surelyReachSinkStates);
88 }
89 if (options.pathVariableType == MemlessSearchPathVariables::RealRanking) {
90 k = 10; // magic constant, consider moving.
91 }
92
93 if (actionSelectionVars.empty()) {
94 for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
95 actionSelectionVars.push_back(std::vector<storm::expressions::Variable>());
96 actionSelectionVarExpressions.push_back(std::vector<storm::expressions::Expression>());
97 }
98 for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) {
99 reachVars.push_back(expressionManager->declareBooleanVariable("C-" + std::to_string(stateId)));
100 reachVarExpressions.push_back(reachVars.back().getExpression());
101 reachVarExpressionsPerObservation[pomdp.getObservation(stateId)].push_back(reachVarExpressions.back());
102 continuationVars.push_back(expressionManager->declareBooleanVariable("D-" + std::to_string(stateId)));
103 continuationVarExpressions.push_back(continuationVars.back().getExpression());
104 }
105 // Create the action selection variables.
106 uint64_t obs = 0;
107 for (auto const& statesForObservation : statesPerObservation) {
108 for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front()); ++a) {
109 std::string varName = "A-" + std::to_string(obs) + "-" + std::to_string(a);
110 actionSelectionVars.at(obs).push_back(expressionManager->declareBooleanVariable(varName));
111 actionSelectionVarExpressions.at(obs).push_back(actionSelectionVars.at(obs).back().getExpression());
112 }
113 schedulerVariables.push_back(expressionManager->declareBitVectorVariable("scheduler-obs-" + std::to_string(obs), statesPerObservation.size()));
114 schedulerVariableExpressions.push_back(schedulerVariables.back());
115 switchVars.push_back(expressionManager->declareBooleanVariable("S-" + std::to_string(obs)));
116 switchVarExpressions.push_back(switchVars.back().getExpression());
117 observationUpdatedVariables.push_back(expressionManager->declareBooleanVariable("U-" + std::to_string(obs)));
118 observationUpdatedExpressions.push_back(observationUpdatedVariables.back().getExpression());
119 followVars.push_back(expressionManager->declareBooleanVariable("F-" + std::to_string(obs)));
120 followVarExpressions.push_back(followVars.back().getExpression());
121
122 ++obs;
123 }
124
125 for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) {
126 pathVars.push_back(std::vector<storm::expressions::Variable>());
127 pathVarExpressions.push_back(std::vector<storm::expressions::Expression>());
128 }
129 }
130
131 uint64_t initK = 0;
132 if (maxK != std::numeric_limits<uint64_t>::max()) {
133 initK = maxK;
134 }
135 if (initK < k) {
136 for (uint64_t stateId = 0; stateId < pomdp.getNumberOfStates(); ++stateId) {
137 if (lookaheadConstraintsRequired) {
138 if (options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) {
139 for (uint64_t i = initK; i < k; ++i) {
140 pathVars[stateId].push_back(expressionManager->declareBooleanVariable("P-" + std::to_string(stateId) + "-" + std::to_string(i)));
141 pathVarExpressions[stateId].push_back(pathVars[stateId].back().getExpression());
142 }
143 } else if (options.pathVariableType == MemlessSearchPathVariables::IntegerRanking) {
144 pathVars[stateId].push_back(expressionManager->declareIntegerVariable("P-" + std::to_string(stateId)));
145 pathVarExpressions[stateId].push_back(pathVars[stateId].back().getExpression());
146 } else {
147 assert(options.pathVariableType == MemlessSearchPathVariables::RealRanking);
148 pathVars[stateId].push_back(expressionManager->declareRationalVariable("P-" + std::to_string(stateId)));
149 pathVarExpressions[stateId].push_back(pathVars[stateId].back().getExpression());
150 }
151 }
152 }
153 }
154
155 assert(!lookaheadConstraintsRequired || pathVarExpressions.size() == pomdp.getNumberOfStates());
156 assert(reachVars.size() == pomdp.getNumberOfStates());
157 assert(reachVarExpressions.size() == pomdp.getNumberOfStates());
158
159 uint64_t obs = 0;
160
161 for (auto const& statesForObservation : statesPerObservation) {
162 if (pomdp.getNumberOfChoices(statesForObservation.front()) == 1) {
163 ++obs;
164 continue;
165 }
166 if (options.onlyDeterministicStrategies || statesForObservation.size() == 1) {
167 for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front()) - 1; ++a) {
168 for (uint64_t b = a + 1; b < pomdp.getNumberOfChoices(statesForObservation.front()); ++b) {
169 smtSolver->add(!(actionSelectionVarExpressions[obs][a]) || !(actionSelectionVarExpressions[obs][b]));
170 }
171 }
172 }
173 ++obs;
174 }
175
176 obs = 0;
177 for (auto const& actionVars : actionSelectionVarExpressions) {
178 std::vector<storm::expressions::Expression> actExprs = actionVars;
179 actExprs.push_back(followVarExpressions[obs]);
180 smtSolver->add(storm::expressions::disjunction(actExprs));
181 for (auto const& av : actionVars) {
182 smtSolver->add(!followVarExpressions[obs] || !av);
183 }
184 ++obs;
185 }
186
187 // Update at least one observation.
188 // PAPER COMMENT: 2
189 smtSolver->add(storm::expressions::disjunction(observationUpdatedExpressions));
190
191 // PAPER COMMENT: 3
192 if (lookaheadConstraintsRequired) {
193 if (options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) {
194 for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
195 if (targetStates.get(state)) {
196 smtSolver->add(pathVarExpressions[state][0]);
197 } else {
198 smtSolver->add(!pathVarExpressions[state][0] || followVarExpressions[pomdp.getObservation(state)]);
199 }
200 }
201 } else {
202 for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
203 smtSolver->add(pathVarExpressions[state][0] <= expressionManager->integer(k));
204 smtSolver->add(pathVarExpressions[state][0] >= expressionManager->integer(0));
205 }
206 }
207 }
208
209 uint64_t rowindex = 0;
210 for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
211 if (targetStates.get(state) || surelyReachSinkStates.get(state)) {
212 rowindex += pomdp.getNumberOfChoices(state);
213 continue;
214 }
215 for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
216 std::vector<storm::expressions::Expression> subexprreachSwitch;
217 std::vector<storm::expressions::Expression> subexprreachNoSwitch;
218
219 subexprreachSwitch.push_back(!reachVarExpressions[state]);
220 subexprreachSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]);
221 subexprreachSwitch.push_back(!switchVarExpressions[pomdp.getObservation(state)]);
222 subexprreachSwitch.push_back(followVarExpressions[pomdp.getObservation(state)]);
223
224 subexprreachNoSwitch.push_back(!reachVarExpressions[state]);
225 subexprreachNoSwitch.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]);
226 subexprreachNoSwitch.push_back(switchVarExpressions[pomdp.getObservation(state)]);
227 subexprreachNoSwitch.push_back(followVarExpressions[pomdp.getObservation(state)]);
228
229 for (auto const& entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
230 if (!delayedSwitching || pomdp.getObservation(entries.getColumn()) != pomdp.getObservation(state)) {
231 subexprreachSwitch.push_back(continuationVarExpressions.at(entries.getColumn()));
232 } else {
233 // TODO: This could be the spot where delayed switching is broken.
234 subexprreachSwitch.push_back(reachVarExpressions.at(entries.getColumn()));
235 subexprreachSwitch.push_back(continuationVarExpressions.at(entries.getColumn()));
236 }
237 smtSolver->add(storm::expressions::disjunction(subexprreachSwitch));
238 subexprreachSwitch.pop_back();
239 subexprreachNoSwitch.push_back(reachVarExpressions.at(entries.getColumn()));
240 smtSolver->add(storm::expressions::disjunction(subexprreachNoSwitch));
241 subexprreachNoSwitch.pop_back();
242 }
243 rowindex++;
244 }
245 }
246
247 rowindex = 0;
248 for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
249 if (surelyReachSinkStates.get(state)) {
250 smtSolver->add(!reachVarExpressions[state]);
251 smtSolver->add(!continuationVarExpressions[state]);
252 if (lookaheadConstraintsRequired) {
253 if (options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) {
254 for (uint64_t j = 1; j < k; ++j) {
255 smtSolver->add(!pathVarExpressions[state][j]);
256 }
257 } else {
258 smtSolver->add(pathVarExpressions[state][0] == expressionManager->integer(k));
259 }
260 }
261 rowindex += pomdp.getNumberOfChoices(state);
262 } else if (!targetStates.get(state)) {
263 if (lookaheadConstraintsRequired) {
264 if (options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) {
265 smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVarExpressions.at(state).back()));
266 std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs;
267 for (uint64_t j = 1; j < k; ++j) {
268 pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>());
269 for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
270 pathsubsubexprs.back().push_back(std::vector<storm::expressions::Expression>());
271 }
272 }
273
274 for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
275 std::vector<storm::expressions::Expression> subexprreach;
276 for (auto const& entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
277 for (uint64_t j = 1; j < k; ++j) {
278 pathsubsubexprs[j - 1][action].push_back(pathVarExpressions[entries.getColumn()][j - 1]);
279 }
280 }
281 rowindex++;
282 }
283
284 for (uint64_t j = 1; j < k; ++j) {
285 std::vector<storm::expressions::Expression> pathsubexprs;
286 for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
287 pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) &&
288 storm::expressions::disjunction(pathsubsubexprs[j - 1][action]));
289 }
290 if (!delayedSwitching) {
291 pathsubexprs.push_back(switchVarExpressions.at(pomdp.getObservation(state)));
292 pathsubexprs.push_back(followVarExpressions[pomdp.getObservation(state)]);
293 }
294 smtSolver->add(storm::expressions::iff(pathVarExpressions[state][j], storm::expressions::disjunction(pathsubexprs)));
295 }
296 } else {
297 std::vector<storm::expressions::Expression> actPathDisjunction;
298 for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
299 std::vector<storm::expressions::Expression> pathDisjunction;
300 for (auto const& entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
301 pathDisjunction.push_back(pathVarExpressions[entries.getColumn()][0] < pathVarExpressions[state][0]);
302 }
303 actPathDisjunction.push_back(storm::expressions::disjunction(pathDisjunction) &&
304 actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action));
305 rowindex++;
306 }
307 if (!delayedSwitching) {
308 actPathDisjunction.push_back(switchVarExpressions.at(pomdp.getObservation(state)));
309 actPathDisjunction.push_back(followVarExpressions[pomdp.getObservation(state)]);
310 }
311 actPathDisjunction.push_back(!reachVarExpressions[state]);
312 smtSolver->add(storm::expressions::disjunction(actPathDisjunction));
313 }
314 }
315 } else {
316 if (lookaheadConstraintsRequired) {
317 if (options.pathVariableType == MemlessSearchPathVariables::BooleanRanking) {
318 for (uint64_t j = 1; j < k; ++j) {
319 smtSolver->add(pathVarExpressions[state][j]);
320 }
321 } else {
322 smtSolver->add(pathVarExpressions[state][0] == expressionManager->integer(0));
323 }
324 }
325 smtSolver->add(reachVars[state]);
326 rowindex += pomdp.getNumberOfChoices(state);
327 }
328 }
329
330 obs = 0;
331 for (auto const& statesForObservation : statesPerObservation) {
332 for (auto const& state : statesForObservation) {
333 if (!targetStates.get(state)) {
334 smtSolver->add(!continuationVars[state] || schedulerVariableExpressions[obs] > 0);
335 smtSolver->add(!reachVarExpressions[state] || !followVarExpressions[obs] || schedulerVariableExpressions[obs] > 0);
336 }
337 }
338 ++obs;
339 }
340
341 for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
342 smtSolver->add(
343 storm::expressions::implies(switchVarExpressions[observation], storm::expressions::disjunction(reachVarExpressionsPerObservation[observation])));
344 }
345 return lookaheadConstraintsRequired;
346}
347
348template<typename ValueType>
349uint64_t IterativePolicySearch<ValueType>::getOffsetFromObservation(uint64_t state, uint64_t observation) const {
350 if (!useFindOffset) {
351 STORM_LOG_WARN("This code is slow and should only be used for debugging.");
352 useFindOffset = true;
353 }
354 uint64_t offset = 0;
355 for (uint64_t s : statesPerObservation[observation]) {
356 if (s == state) {
357 return offset;
358 }
359 ++offset;
360 }
361 assert(false); // State should have occured.
362 return 0;
363}
364
365template<typename ValueType>
367 storm::storage::BitVector const& allOfTheseStates) {
368 STORM_LOG_DEBUG("Surely reach sink states: " << surelyReachSinkStates);
369 STORM_LOG_DEBUG("Target states " << targetStates);
370 STORM_LOG_DEBUG("Questionmark states " << (~surelyReachSinkStates & ~targetStates));
371 stats.initializeSolverTimer.start();
372 // TODO: When do we need to reinitialize? When the solver has been reset.
373 bool lookaheadConstraintsRequired = initialize(k);
374 if (lookaheadConstraintsRequired) {
375 maxK = k;
376 }
377
378 stats.winningRegionUpdatesTimer.start();
379 storm::storage::BitVector updated(pomdp.getNrObservations());
380 storm::storage::BitVector potentialWinner(pomdp.getNrObservations());
381 storm::storage::BitVector observationsWithPartialWinners(pomdp.getNrObservations());
382 for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
383 if (winningRegion.observationIsWinning(observation)) {
384 continue;
385 }
386 bool observationIsWinning = true;
387 for (uint64_t state : statesPerObservation[observation]) {
388 if (!targetStates.get(state)) {
389 observationIsWinning = false;
390 observationsWithPartialWinners.set(observation);
391 } else {
392 potentialWinner.set(observation);
393 }
394 }
395 if (observationIsWinning) {
396 STORM_LOG_TRACE("Observation " << observation << " is winning.");
397 stats.incrementGraphBasedWinningObservations();
398 winningRegion.setObservationIsWinning(observation);
399 updated.set(observation);
400 }
401 }
402 STORM_LOG_DEBUG("Graph based winning obs: " << stats.getGraphBasedwinningObservations());
403 observationsWithPartialWinners &= potentialWinner;
404 for (auto const observation : observationsWithPartialWinners) {
405 uint64_t nrStatesForObs = statesPerObservation[observation].size();
406 storm::storage::BitVector update(nrStatesForObs);
407 for (uint64_t i = 0; i < nrStatesForObs; ++i) {
408 uint64_t state = statesPerObservation[observation][i];
409 if (targetStates.get(state)) {
410 update.set(i);
411 }
412 }
413 assert(!update.empty());
414 STORM_LOG_TRACE("Extend winning region for observation " << observation << " with target states/offsets" << update);
415 winningRegion.addTargetStates(observation, update);
416 assert(winningRegion.query(observation, update)); // "Cannot continue: No scheduler known for state " << i << " (observation " << obs << ").");
417
418 updated.set(observation);
419 }
420
421#ifndef NDEBUG
422 for (auto const& state : targetStates) {
423 STORM_LOG_ASSERT(winningRegion.isWinning(pomdp.getObservation(state), getOffsetFromObservation(state, pomdp.getObservation(state))),
424 "Target state " << state << " , observation " << pomdp.getObservation(state) << " is not reflected as winning.");
425 }
426#endif
427
428 stats.winningRegionUpdatesTimer.stop();
429
430 uint64_t maximalNrActions = 0;
431 for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
432 maximalNrActions = std::max(pomdp.getTransitionMatrix().getRowGroupSize(state), maximalNrActions);
433 }
434 std::vector<storm::expressions::Expression> atLeastOneOfStates;
435 for (uint64_t state : oneOfTheseStates) {
436 STORM_LOG_ASSERT(reachVarExpressions.size() > state, "state id " << state << " exceeds number of states (" << reachVarExpressions.size() << ")");
437 atLeastOneOfStates.push_back(reachVarExpressions[state]);
438 }
439 if (!atLeastOneOfStates.empty()) {
440 smtSolver->add(storm::expressions::disjunction(atLeastOneOfStates));
441 }
442
443 std::set<storm::expressions::Expression> allOfTheseAssumption;
444 std::vector<storm::expressions::Expression> updateForObservationExpressions;
445
446 for (uint64_t state : allOfTheseStates) {
447 assert(reachVarExpressions.size() > state);
448 allOfTheseAssumption.insert(reachVarExpressions[state]);
449 }
450
451 if (winningRegion.empty()) {
452 // Keep it simple here to help bughunting if necessary.
453 for (uint64_t ob = 0; ob < pomdp.getNrObservations(); ++ob) {
454 updateForObservationExpressions.push_back(storm::expressions::disjunction(reachVarExpressionsPerObservation[ob]));
455 schedulerForObs.push_back(0);
456 }
457 } else {
458 uint64_t obs = 0;
459 for (auto const& statesForObservation : statesPerObservation) {
460 schedulerForObs.push_back(0);
461 for (auto const& winningSet : winningRegion.getWinningSetsPerObservation(obs)) {
462 assert(!winningSet.empty());
463 assert(obs < schedulerForObs.size());
464 ++(schedulerForObs[obs]);
465 auto constant = expressionManager->integer(schedulerForObs[obs]);
466 for (auto const& stateOffset : ~winningSet) {
467 uint64_t state = statesForObservation[stateOffset];
468 STORM_LOG_TRACE("State " << state << " with observation " << obs << " does not allow scheduler " << constant);
469 smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant)));
470 smtSolver->add(
471 !(reachVarExpressions[state] && followVarExpressions[pomdp.getObservation(state)] && (schedulerVariableExpressions[obs] == constant)));
472 }
473 }
474 if (winningRegion.getWinningSetsPerObservation(obs).empty()) {
475 updateForObservationExpressions.push_back(storm::expressions::disjunction(reachVarExpressionsPerObservation[obs]));
476 // Here is some opportunity for further constraints,
477 // but one has to be careful that the constraints added here are never removed (the push happens after adding these constraints)
478 } else {
479 updateForObservationExpressions.push_back(winningRegion.extensionExpression(obs, reachVarExpressionsPerObservation[obs]));
480 }
481 ++obs;
482 }
483 }
484
485 smtSolver->push();
486 for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
487 auto constant = expressionManager->integer(schedulerForObs[obs]);
488 smtSolver->add(schedulerVariableExpressions[obs] <= constant);
489 smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs]));
490 }
491
492 assert(pomdp.getNrObservations() == schedulerForObs.size());
493
495 scheduler.switchObservations = storm::storage::BitVector(pomdp.getNrObservations());
496 storm::storage::BitVector newObservations(pomdp.getNrObservations());
497 storm::storage::BitVector newObservationsAfterSwitch(pomdp.getNrObservations());
498 storm::storage::BitVector observations(pomdp.getNrObservations());
499 storm::storage::BitVector observationsAfterSwitch(pomdp.getNrObservations());
500 storm::storage::BitVector observationUpdated(pomdp.getNrObservations());
501 storm::storage::BitVector uncoveredStates(pomdp.getNumberOfStates());
502 storm::storage::BitVector coveredStates(pomdp.getNumberOfStates());
503 storm::storage::BitVector coveredStatesAfterSwitch(pomdp.getNumberOfStates());
504
505 stats.initializeSolverTimer.stop();
506 STORM_LOG_INFO("Start iterative solver...");
507
508 uint64_t iterations = 0;
509
510 bool foundWhatWeLookFor = false;
511 while (true) {
512 stats.incrementOuterIterations();
513 // TODO consider what we really want to store about the schedulers.
514 scheduler.reset(pomdp.getNrObservations(), maximalNrActions);
515 observations.clear();
516 observationsAfterSwitch.clear();
517 coveredStates = targetStates;
518 coveredStatesAfterSwitch.clear();
519 observationUpdated.clear();
520 bool newSchedulerDiscovered = false;
521 if (!allOfTheseAssumption.empty()) {
522 ++iterations;
523 bool foundResult = this->smtCheck(iterations, allOfTheseAssumption);
524 if (foundResult) {
525 // Consider storing the scheduler
526 foundWhatWeLookFor = true;
527 }
528 }
529 uint64_t localIterations = 0;
530 while (true) {
531 ++iterations;
532 ++localIterations;
533
534 bool foundScheduler = foundWhatWeLookFor;
535 if (!foundScheduler) {
536 foundScheduler = this->smtCheck(iterations);
537 }
538 if (!foundScheduler) {
539 break;
540 }
541 newSchedulerDiscovered = true;
542 stats.evaluateExtensionSolverTime.start();
543 auto const& model = smtSolver->getModel();
544
545 newObservationsAfterSwitch.clear();
546 newObservations.clear();
547
548 uint64_t obs = 0;
549 for (auto const& ov : observationUpdatedVariables) {
550 if (!observationUpdated.get(obs) && model->getBooleanValue(ov)) {
551 STORM_LOG_TRACE("New observation updated: " << obs);
552 observationUpdated.set(obs);
553 }
554 obs++;
555 }
556
557 uncoveredStates = ~coveredStates;
558 for (uint64_t i : uncoveredStates) {
559 auto const& rv = reachVars[i];
560 auto const& rvExpr = reachVarExpressions[i];
561 if (observationUpdated.get(pomdp.getObservation(i)) && model->getBooleanValue(rv)) {
562 STORM_LOG_TRACE("New state: " << i);
563 smtSolver->add(rvExpr);
564 assert(!surelyReachSinkStates.get(i));
565 newObservations.set(pomdp.getObservation(i));
566 coveredStates.set(i);
567 if (lookaheadConstraintsRequired) {
568 if (options.pathVariableType == MemlessSearchPathVariables::IntegerRanking) {
569 smtSolver->add(pathVarExpressions[i][0] == expressionManager->integer(model->getIntegerValue(pathVars[i][0])));
570 } else if (options.pathVariableType == MemlessSearchPathVariables::RealRanking) {
571 smtSolver->add(pathVarExpressions[i][0] == expressionManager->rational(model->getRationalValue(pathVars[i][0])));
572 }
573 }
574 }
575 }
576
577 storm::storage::BitVector uncoveredStatesAfterSwitch(~coveredStatesAfterSwitch);
578 for (uint64_t i : uncoveredStatesAfterSwitch) {
579 auto const& cv = continuationVars[i];
580 if (model->getBooleanValue(cv)) {
581 uint64_t obs = pomdp.getObservation(i);
582 STORM_LOG_ASSERT(winningRegion.isWinning(obs, getOffsetFromObservation(i, obs)),
583 "Cannot continue: No scheduler known for state " << i << " (observation " << obs << ").");
584 auto const& cvExpr = continuationVarExpressions[i];
585 smtSolver->add(cvExpr);
586 if (!observationsAfterSwitch.get(obs)) {
587 newObservationsAfterSwitch.set(obs);
588 }
589 }
590 }
591 stats.evaluateExtensionSolverTime.stop();
592
593 if (options.computeTraceOutput()) {
594 detail::printRelevantInfoFromModel(model, reachVars, continuationVars);
595 }
596 stats.encodeExtensionSolverTime.start();
597 for (auto obs : newObservations) {
598 auto const& actionSelectionVarsForObs = actionSelectionVars[obs];
599 observations.set(obs);
600 for (uint64_t act = 0; act < actionSelectionVarsForObs.size(); ++act) {
601 if (model->getBooleanValue(actionSelectionVarsForObs[act])) {
602 scheduler.actions[obs].set(act);
603 smtSolver->add(actionSelectionVarExpressions[obs][act]);
604 } else {
605 smtSolver->add(!actionSelectionVarExpressions[obs][act]);
606 }
607 }
608 if (model->getBooleanValue(switchVars[obs])) {
609 scheduler.switchObservations.set(obs);
610 smtSolver->add(switchVarExpressions[obs]);
611 } else {
612 smtSolver->add(!switchVarExpressions[obs]);
613 }
614 if (model->getBooleanValue(followVars[obs])) {
615 smtSolver->add(followVarExpressions[obs]);
616 } else {
617 smtSolver->add(!followVarExpressions[obs]);
618 }
619 }
620 for (auto obs : newObservationsAfterSwitch) {
621 observationsAfterSwitch.set(obs);
622 scheduler.schedulerRef[obs] = model->getIntegerValue(schedulerVariables[obs]);
623 smtSolver->add(schedulerVariableExpressions[obs] == expressionManager->integer(scheduler.schedulerRef[obs]));
624 }
625
626 if (options.computeTraceOutput()) {
627 // generates debug output, but here we only want it for trace level.
628 // For consistency, all output on debug level.
629 STORM_LOG_DEBUG("the scheduler so far: ");
630 scheduler.printForObservations(observations, observationsAfterSwitch);
631 }
632
633 if (foundWhatWeLookFor ||
634 (options.localIterationMaximum > 0 && (localIterations % (options.localIterationMaximum + 1) == options.localIterationMaximum))) {
635 stats.encodeExtensionSolverTime.stop();
636 break;
637 }
638
639 std::vector<storm::expressions::Expression> remainingExpressions;
640 for (auto index : ~coveredStates) {
641 if (observationUpdated.get(pomdp.getObservation(index))) {
642 remainingExpressions.push_back(reachVarExpressions[index]);
643 }
644 }
645 for (auto index : ~observationUpdated) {
646 remainingExpressions.push_back(observationUpdatedExpressions[index]);
647 }
648
649 if (remainingExpressions.empty()) {
650 stats.encodeExtensionSolverTime.stop();
651 break;
652 }
653 smtSolver->add(storm::expressions::disjunction(remainingExpressions));
654 stats.encodeExtensionSolverTime.stop();
655 // smtSolver->setTimeout(options.extensionCallTimeout);
656 }
657 if (!newSchedulerDiscovered) {
658 break;
659 }
660 // smtSolver->unsetTimeout();
661 smtSolver->pop();
662
663 if (options.computeDebugOutput()) {
664 std::stringstream strstr;
665 coveredStatesToStream(strstr, ~coveredStates);
666 STORM_LOG_DEBUG(strstr.str());
667 // generates info output, but here we only want it for debug level.
668 // For consistency, all output on info level.
669 STORM_LOG_DEBUG("the scheduler: ");
670 scheduler.printForObservations(observations, observationsAfterSwitch);
671 }
672
673 stats.winningRegionUpdatesTimer.start();
674 storm::storage::BitVector updated(observations.size());
675 uint64_t newTargetObservations = 0;
676 for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
677 STORM_LOG_TRACE("consider observation " << observation);
678 storm::storage::BitVector update(statesPerObservation[observation].size());
679 uint64_t i = 0;
680 for (uint64_t state : statesPerObservation[observation]) {
681 if (coveredStates.get(state)) {
682 assert(!surelyReachSinkStates.get(state));
683 update.set(i);
684 }
685 ++i;
686 }
687 if (!update.empty()) {
688 STORM_LOG_TRACE("Update Winning Region: Observation " << observation << " with update " << update);
689 bool updateResult = winningRegion.update(observation, update);
690 STORM_LOG_TRACE("Region changed:" << updateResult);
691 if (updateResult) {
692 if (winningRegion.observationIsWinning(observation)) {
693 ++newTargetObservations;
694 for (uint64_t state : statesPerObservation[observation]) {
695 targetStates.set(state);
696 assert(!surelyReachSinkStates.get(state));
697 }
698 }
699 updated.set(observation);
700 updateForObservationExpressions[observation] =
701 winningRegion.extensionExpression(observation, reachVarExpressionsPerObservation[observation]);
702 }
703 }
704 }
705 stats.winningRegionUpdatesTimer.stop();
706 if (foundWhatWeLookFor) {
707 return true;
708 }
709 if (newTargetObservations > 0) {
710 stats.graphSearchTime.start();
712 uint64_t targetStatesBefore = targetStates.getNumberOfSetBits();
713 STORM_LOG_DEBUG("Target states before graph based analysis " << targetStates.getNumberOfSetBits());
714 targetStates = graphanalysis.analyseProb1Max(~surelyReachSinkStates, targetStates);
715 uint64_t targetStatesAfter = targetStates.getNumberOfSetBits();
716 STORM_LOG_DEBUG("Target states after graph based analysis " << targetStates.getNumberOfSetBits());
717 stats.graphSearchTime.stop();
718 if (targetStatesAfter - targetStatesBefore > 0) {
719 stats.winningRegionUpdatesTimer.start();
720 storm::storage::BitVector potentialWinner(pomdp.getNrObservations());
721 storm::storage::BitVector observationsWithPartialWinners(pomdp.getNrObservations());
722 for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
723 if (winningRegion.observationIsWinning(observation)) {
724 continue;
725 }
726 bool observationIsWinning = true;
727 for (uint64_t state : statesPerObservation[observation]) {
728 if (!targetStates.get(state)) {
729 observationIsWinning = false;
730 observationsWithPartialWinners.set(observation);
731 } else {
732 potentialWinner.set(observation);
733 }
734 }
735 if (observationIsWinning) {
736 stats.incrementGraphBasedWinningObservations();
737 winningRegion.setObservationIsWinning(observation);
738 updated.set(observation);
739 }
740 }
741 STORM_LOG_DEBUG("Graph-based winning obs: " << stats.getGraphBasedwinningObservations());
742 observationsWithPartialWinners &= potentialWinner;
743 for (auto const observation : observationsWithPartialWinners) {
744 uint64_t nrStatesForObs = statesPerObservation[observation].size();
745 storm::storage::BitVector update(nrStatesForObs);
746 for (uint64_t i = 0; i < nrStatesForObs; ++i) {
747 uint64_t state = statesPerObservation[observation][i];
748 if (targetStates.get(state)) {
749 update.set(i);
750 }
751 }
752 assert(!update.empty());
753 STORM_LOG_TRACE("Extend winning region for observation " << observation << " with target states/offsets" << update);
754 winningRegion.addTargetStates(observation, update);
755 assert(winningRegion.query(observation, update)); //
756 updated.set(observation);
757 }
758 stats.winningRegionUpdatesTimer.stop();
759
760 if (observationsWithPartialWinners.getNumberOfSetBits() > 0) {
761 reset();
762 return analyze(k, ~targetStates & ~surelyReachSinkStates, allOfTheseStates);
763 }
764 }
765 }
766 STORM_LOG_ASSERT(!updated.empty(), "The strategy should be new in at least one place");
767 if (options.computeDebugOutput()) {
768 winningRegion.print();
769 }
770 if (options.validateEveryStep) {
771 STORM_LOG_WARN("Validating every step, for debug purposes only!");
772 validator->validate();
773 }
774 if (stats.getIterations() % options.restartAfterNIterations == options.restartAfterNIterations - 1) {
775 reset();
776 return analyze(k, ~targetStates & ~surelyReachSinkStates, allOfTheseStates);
777 }
778 stats.updateNewStrategySolverTime.start();
779 for (uint64_t observation : updated) {
780 updateForObservationExpressions[observation] = winningRegion.extensionExpression(observation, reachVarExpressionsPerObservation[observation]);
781 }
782
783 uint64_t obs = 0;
784 for (auto const& statesForObservation : statesPerObservation) {
785 if (observations.get(obs) && updated.get(obs)) {
786 STORM_LOG_DEBUG("We have a new policy ( " << finalSchedulers.size() << " ) for states with observation " << obs << ".");
787 assert(schedulerForObs.size() > obs);
788 (schedulerForObs[obs])++;
789 STORM_LOG_DEBUG("We now have " << schedulerForObs[obs] << " policies for states with observation " << obs);
790 if (winningRegion.observationIsWinning(obs)) {
791 for (auto const& state : statesForObservation) {
792 smtSolver->add(reachVarExpressions[state]);
793 }
794 auto constant = expressionManager->integer(schedulerForObs[obs]);
795 smtSolver->add(schedulerVariableExpressions[obs] == constant);
796 } else {
797 auto constant = expressionManager->integer(schedulerForObs[obs]);
798 for (auto const& state : statesForObservation) {
799 if (!coveredStates.get(state)) {
800 smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant)));
801 smtSolver->add(!(reachVarExpressions[state] && followVarExpressions[pomdp.getObservation(state)] &&
802 (schedulerVariableExpressions[obs] == constant)));
803 }
804 }
805 }
806 }
807 ++obs;
808 }
809 finalSchedulers.push_back(scheduler);
810
811 smtSolver->push();
812
813 for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
814 if (winningRegion.observationIsWinning(obs)) {
815 auto constant = expressionManager->integer(schedulerForObs[obs]);
816 // Scheduler variable is already fixed.
817 // Observation will not be updated.
818 smtSolver->add(!observationUpdatedExpressions[obs]);
819 } else {
820 auto constant = expressionManager->integer(schedulerForObs[obs]);
821 smtSolver->add(schedulerVariableExpressions[obs] <= constant);
822 smtSolver->add(storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs]));
823 }
824 }
825 stats.updateNewStrategySolverTime.stop();
826
827 STORM_LOG_INFO("... after iteration " << stats.getIterations() << " so far " << stats.getChecks() << " checks.");
828 }
829 if (options.validateResult) {
830 STORM_LOG_WARN("Validating result is a winning region, only for debugging purposes.");
831 validator->validate();
832 STORM_LOG_WARN("Validating result is a fixed point, only for debugging purposes.");
833 validator->validateIsMaximal(surelyReachSinkStates);
834 }
835
836 if (!allOfTheseStates.empty()) {
837 for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
838 storm::storage::BitVector check(statesPerObservation[observation].size());
839 uint64_t i = 0;
840 for (uint64_t state : statesPerObservation[observation]) {
841 if (allOfTheseStates.get(state)) {
842 check.set(i);
843 }
844 ++i;
845 }
846 if (!winningRegion.query(observation, check)) {
847 return false;
848 }
849 }
850 }
851 return true;
852}
853
854template<typename ValueType>
855void IterativePolicySearch<ValueType>::coveredStatesToStream(std::ostream& os, storm::storage::BitVector const& remaining) const {
856 bool first = true;
857 for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
858 if (!remaining.get(state)) {
859 if (first) {
860 first = false;
861 } else {
862 os << ", ";
863 }
864 std::cout << state;
865 if (pomdp.hasStateValuations()) {
866 os << ":" << pomdp.getStateValuations().getStateInfo(state);
867 }
868 }
869 }
870 os << '\n';
871}
872
873template<typename ValueType>
874void IterativePolicySearch<ValueType>::printScheduler(std::vector<InternalObservationScheduler> const&) {}
875
876template<typename ValueType>
878
879template<typename ValueType>
883
884template<typename ValueType>
885bool IterativePolicySearch<ValueType>::smtCheck(uint64_t iteration, std::set<storm::expressions::Expression> const& assumptions) {
886 if (options.isExportSATSet()) {
887 STORM_LOG_DEBUG("Export SMT Solver Call (" << iteration << ")");
888 std::string filepath = options.getExportSATCallsPath() + "call_" + std::to_string(iteration) + ".smt2";
889 std::ofstream filestream;
890 storm::io::openFile(filepath, filestream);
891 filestream << smtSolver->getSmtLibString() << '\n';
892 storm::io::closeFile(filestream);
893 }
894
895 STORM_LOG_DEBUG("Call to SMT Solver (" << iteration << ")");
897 stats.smtCheckTimer.start();
898 if (assumptions.empty()) {
899 result = smtSolver->check();
900 } else {
901 result = smtSolver->checkWithAssumptions(assumptions);
902 }
903 stats.smtCheckTimer.stop();
904 stats.incrementSmtChecks();
905
907 STORM_LOG_DEBUG("Unknown");
908 return false;
910 STORM_LOG_DEBUG("Unsatisfiable!");
911 return false;
912 }
913
914 STORM_LOG_TRACE("Satisfying assignment: ");
915 STORM_LOG_TRACE(smtSolver->getModelAsValuation().toString(true));
916 return true;
917}
918
919template class IterativePolicySearch<double>;
920template class IterativePolicySearch<storm::RationalNumber>;
921} // namespace storm::pomdp
storm::storage::BitVector analyseProb1Max(storm::storage::BitVector const &okay, storm::storage::BitVector const &target) const
This class represents a partially observable Markov decision process.
Definition Pomdp.h:15
uint64_t getNrObservations() const
Definition Pomdp.cpp:68
std::vector< uint32_t > const & getObservations() const
Definition Pomdp.cpp:89
uint64_t getOffsetFromObservation(uint64_t state, uint64_t observation) const
IterativePolicySearch(storm::models::sparse::Pomdp< ValueType > const &pomdp, storm::storage::BitVector const &targetStates, storm::storage::BitVector const &surelyReachSinkStates, std::shared_ptr< storm::utility::solver::SmtSolverFactory > &smtSolverFactory, MemlessSearchOptions const &options)
bool analyze(uint64_t k, storm::storage::BitVector const &oneOfTheseStates, storm::storage::BitVector const &allOfTheseStates=storm::storage::BitVector())
CheckResult
possible check results
Definition SmtSolver.h:25
A bit vector that is internally represented as a vector of 64-bit values.
Definition BitVector.h:18
bool empty() const
Retrieves whether no bits are set to true in this bit vector.
void clear()
Removes all set bits from the bit vector.
void set(uint_fast64_t index, bool value=true)
Sets the given truth value at the given index.
size_t size() const
Retrieves the number of bits this bit vector can store.
uint_fast64_t getNumberOfSetBits() const
Returns the number of bits that are set to true in this bit vector.
bool get(uint_fast64_t index) const
Retrieves the truth value of the bit at the given index and performs a bound check.
#define STORM_LOG_INFO(message)
Definition logging.h:29
#define STORM_LOG_WARN(message)
Definition logging.h:30
#define STORM_LOG_DEBUG(message)
Definition logging.h:23
#define STORM_LOG_TRACE(message)
Definition logging.h:17
#define STORM_LOG_ASSERT(cond, message)
Definition macros.h:11
#define STORM_PRINT_AND_LOG(message)
Definition macros.h:68
Expression iff(Expression const &first, Expression const &second)
Expression disjunction(std::vector< storm::expressions::Expression > const &expressions)
Expression implies(Expression const &first, Expression const &second)
void closeFile(std::ofstream &stream)
Close the given file after writing.
Definition file.h:47
void openFile(std::string const &filepath, std::ofstream &filestream, bool append=false, bool silent=false)
Open the given file for writing.
Definition file.h:18
void printRelevantInfoFromModel(std::shared_ptr< storm::solver::SmtSolver::ModelReference > const &model, std::vector< storm::expressions::Variable > const &reachVars, std::vector< storm::expressions::Variable > const &continuationVars)
bool isLookaheadRequired(storm::models::sparse::Pomdp< ValueType > const &pomdp, storm::storage::BitVector const &targetStates, storm::storage::BitVector const &surelyReachSinkStates)
std::vector< storm::storage::BitVector > actions
void printForObservations(storm::storage::BitVector const &observations, storm::storage::BitVector const &observationsAfterSwitch) const
void reset(uint64_t nrObservations, uint64_t nrActions)