Storm
A Modern Probabilistic Model Checker
Loading...
Searching...
No Matches
OneShotPolicySearch.cpp
Go to the documentation of this file.
1#include "storm/io/file.h"
2
4
5namespace storm {
6namespace pomdp {
7
8template<typename ValueType>
9void OneShotPolicySearch<ValueType>::initialize(uint64_t k) {
10 if (maxK == std::numeric_limits<uint64_t>::max()) {
11 // not initialized at all.
12 // Create some data structures.
13 for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
14 actionSelectionVars.push_back(std::vector<storm::expressions::Variable>());
15 actionSelectionVarExpressions.push_back(std::vector<storm::expressions::Expression>());
16 statesPerObservation.push_back(std::vector<uint64_t>()); // Consider using bitvectors instead.
17 }
18
19 // Fill the states-per-observation mapping,
20 // declare the reachability variables,
21 // declare the path variables.
22 uint64_t stateId = 0;
23 for (auto obs : pomdp.getObservations()) {
24 pathVars.push_back(std::vector<storm::expressions::Expression>());
25 for (uint64_t i = 0; i < k; ++i) {
26 pathVars.back().push_back(expressionManager->declareBooleanVariable("P-" + std::to_string(stateId) + "-" + std::to_string(i)).getExpression());
27 }
28 reachVars.push_back(expressionManager->declareBooleanVariable("C-" + std::to_string(stateId)));
29 reachVarExpressions.push_back(reachVars.back().getExpression());
30 statesPerObservation.at(obs).push_back(stateId++);
31 }
32 assert(pathVars.size() == pomdp.getNumberOfStates());
33
34 // Create the action selection variables.
35 uint64_t obs = 0;
36 for (auto const& statesForObservation : statesPerObservation) {
37 for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front()); ++a) {
38 std::string varName = "A-" + std::to_string(obs) + "-" + std::to_string(a);
39 actionSelectionVars.at(obs).push_back(expressionManager->declareBooleanVariable(varName));
40 actionSelectionVarExpressions.at(obs).push_back(actionSelectionVars.at(obs).back().getExpression());
41 }
42 ++obs;
43 }
44 } else {
45 assert(false);
46 }
47
48 for (auto const& actionVars : actionSelectionVarExpressions) {
49 smtSolver->add(storm::expressions::disjunction(actionVars));
50 }
51
52 uint64_t rowindex = 0;
53 for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
54 for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
55 std::vector<storm::expressions::Expression> subexprreach;
56 subexprreach.push_back(!reachVarExpressions[state]);
57 subexprreach.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]);
58 for (auto const& entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
59 subexprreach.push_back(reachVarExpressions.at(entries.getColumn()));
60 smtSolver->add(storm::expressions::disjunction(subexprreach));
61 subexprreach.pop_back();
62 }
63 rowindex++;
64 }
65 }
66
67 rowindex = 0;
68 for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
69 if (targetStates.get(state)) {
70 smtSolver->add(pathVars[state][0]);
71 } else {
72 smtSolver->add(!pathVars[state][0]);
73 }
74
75 if (surelyReachSinkStates.get(state)) {
76 smtSolver->add(!reachVarExpressions[state]);
77 rowindex += pomdp.getNumberOfChoices(state);
78 } else if (!targetStates.get(state)) {
79 std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs;
80 for (uint64_t j = 1; j < k; ++j) {
81 pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>());
82 for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
83 pathsubsubexprs.back().push_back(std::vector<storm::expressions::Expression>());
84 }
85 }
86
87 for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
88 std::vector<storm::expressions::Expression> subexprreach;
89 for (auto const& entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
90 for (uint64_t j = 1; j < k; ++j) {
91 pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]);
92 }
93 }
94 rowindex++;
95 }
96
97 for (uint64_t j = 1; j < k; ++j) {
98 std::vector<storm::expressions::Expression> pathsubexprs;
99
100 for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
101 pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) &&
102 storm::expressions::disjunction(pathsubsubexprs[j - 1][action]));
103 }
104 smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs)));
105 }
106
107 smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back()));
108
109 } else {
110 rowindex += pomdp.getNumberOfChoices(state);
111 }
112 }
113}
114
115template<typename ValueType>
116bool OneShotPolicySearch<ValueType>::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) {
117 STORM_LOG_TRACE("Use lookahead of " << k);
118 if (k < maxK) {
119 initialize(k);
120 }
121
122 std::vector<storm::expressions::Expression> atLeastOneOfStates;
123
124 for (uint64_t state : oneOfTheseStates) {
125 atLeastOneOfStates.push_back(reachVarExpressions[state]);
126 }
127 assert(atLeastOneOfStates.size() > 0);
128 smtSolver->add(storm::expressions::disjunction(atLeastOneOfStates));
129
130 for (uint64_t state : allOfTheseStates) {
131 smtSolver->add(reachVarExpressions[state]);
132 }
133
134 STORM_LOG_TRACE(smtSolver->getSmtLibString());
135
136 STORM_LOG_DEBUG("Call to SMT Solver");
137 stats.smtCheckTimer.start();
138 auto result = smtSolver->check();
139 stats.smtCheckTimer.stop();
140
142 STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result");
144 STORM_LOG_DEBUG("Unsatisfiable!");
145 return false;
146 }
147
148 STORM_LOG_DEBUG("Satisfying assignment: ");
149 auto model = smtSolver->getModel();
150 size_t i = 0;
151 storm::storage::BitVector observations(pomdp.getNrObservations());
152 storm::storage::BitVector remainingstates(pomdp.getNumberOfStates());
153 for (auto rv : reachVars) {
154 if (model->getBooleanValue(rv)) {
155 observations.set(pomdp.getObservation(i));
156 } else {
157 remainingstates.set(i);
158 }
159 ++i;
160 }
161 std::vector<std::set<uint64_t>> scheduler;
162 for (auto const& actionSelectionVarsForObs : actionSelectionVars) {
163 uint64_t act = 0;
164 scheduler.push_back(std::set<uint64_t>());
165 for (auto const& asv : actionSelectionVarsForObs) {
166 if (model->getBooleanValue(asv)) {
167 scheduler.back().insert(act);
168 }
169 act++;
170 }
171 }
172
173 return true;
174}
175
176template class OneShotPolicySearch<double>;
177template class OneShotPolicySearch<storm::RationalNumber>;
178} // namespace pomdp
179} // namespace storm
A bit vector that is internally represented as a vector of 64-bit values.
Definition BitVector.h:18
#define STORM_LOG_DEBUG(message)
Definition logging.h:23
#define STORM_LOG_TRACE(message)
Definition logging.h:17
#define STORM_LOG_THROW(cond, exception, message)
Definition macros.h:30
Expression iff(Expression const &first, Expression const &second)
Expression disjunction(std::vector< storm::expressions::Expression > const &expressions)
Expression implies(Expression const &first, Expression const &second)
void initialize(int *argc, char **argv)
Definition storm_gtest.h:49
LabParser.cpp.
Definition cli.cpp:18