Storm 1.11.1.1
A Modern Probabilistic Model Checker
Loading...
Searching...
No Matches
OneShotPolicySearch.cpp
Go to the documentation of this file.
1#include "storm/io/file.h"
2
5
6namespace storm {
7namespace pomdp {
8
9template<typename ValueType>
10void OneShotPolicySearch<ValueType>::initialize(uint64_t k) {
11 if (maxK == std::numeric_limits<uint64_t>::max()) {
12 // not initialized at all.
13 // Create some data structures.
14 for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
15 actionSelectionVars.push_back(std::vector<storm::expressions::Variable>());
16 actionSelectionVarExpressions.push_back(std::vector<storm::expressions::Expression>());
17 statesPerObservation.push_back(std::vector<uint64_t>()); // Consider using bitvectors instead.
18 }
19
20 // Fill the states-per-observation mapping,
21 // declare the reachability variables,
22 // declare the path variables.
23 uint64_t stateId = 0;
24 for (auto obs : pomdp.getObservations()) {
25 pathVars.push_back(std::vector<storm::expressions::Expression>());
26 for (uint64_t i = 0; i < k; ++i) {
27 pathVars.back().push_back(expressionManager->declareBooleanVariable("P-" + std::to_string(stateId) + "-" + std::to_string(i)).getExpression());
28 }
29 reachVars.push_back(expressionManager->declareBooleanVariable("C-" + std::to_string(stateId)));
30 reachVarExpressions.push_back(reachVars.back().getExpression());
31 statesPerObservation.at(obs).push_back(stateId++);
32 }
33 assert(pathVars.size() == pomdp.getNumberOfStates());
34
35 // Create the action selection variables.
36 uint64_t obs = 0;
37 for (auto const& statesForObservation : statesPerObservation) {
38 for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front()); ++a) {
39 std::string varName = "A-" + std::to_string(obs) + "-" + std::to_string(a);
40 actionSelectionVars.at(obs).push_back(expressionManager->declareBooleanVariable(varName));
41 actionSelectionVarExpressions.at(obs).push_back(actionSelectionVars.at(obs).back().getExpression());
42 }
43 ++obs;
44 }
45 } else {
46 assert(false);
47 }
48
49 for (auto const& actionVars : actionSelectionVarExpressions) {
50 smtSolver->add(storm::expressions::disjunction(actionVars));
51 }
52
53 uint64_t rowindex = 0;
54 for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
55 for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
56 std::vector<storm::expressions::Expression> subexprreach;
57 subexprreach.push_back(!reachVarExpressions[state]);
58 subexprreach.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]);
59 for (auto const& entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
60 subexprreach.push_back(reachVarExpressions.at(entries.getColumn()));
61 smtSolver->add(storm::expressions::disjunction(subexprreach));
62 subexprreach.pop_back();
63 }
64 rowindex++;
65 }
66 }
67
68 rowindex = 0;
69 for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
70 if (targetStates.get(state)) {
71 smtSolver->add(pathVars[state][0]);
72 } else {
73 smtSolver->add(!pathVars[state][0]);
74 }
75
76 if (surelyReachSinkStates.get(state)) {
77 smtSolver->add(!reachVarExpressions[state]);
78 rowindex += pomdp.getNumberOfChoices(state);
79 } else if (!targetStates.get(state)) {
80 std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs;
81 for (uint64_t j = 1; j < k; ++j) {
82 pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>());
83 for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
84 pathsubsubexprs.back().push_back(std::vector<storm::expressions::Expression>());
85 }
86 }
87
88 for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
89 std::vector<storm::expressions::Expression> subexprreach;
90 for (auto const& entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
91 for (uint64_t j = 1; j < k; ++j) {
92 pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]);
93 }
94 }
95 rowindex++;
96 }
97
98 for (uint64_t j = 1; j < k; ++j) {
99 std::vector<storm::expressions::Expression> pathsubexprs;
100
101 for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
102 pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) &&
103 storm::expressions::disjunction(pathsubsubexprs[j - 1][action]));
104 }
105 smtSolver->add(storm::expressions::iff(pathVars[state][j], storm::expressions::disjunction(pathsubexprs)));
106 }
107
108 smtSolver->add(storm::expressions::implies(reachVarExpressions.at(state), pathVars.at(state).back()));
109
110 } else {
111 rowindex += pomdp.getNumberOfChoices(state);
112 }
113 }
114}
115
116template<typename ValueType>
117bool OneShotPolicySearch<ValueType>::analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates, storm::storage::BitVector const& allOfTheseStates) {
118 STORM_LOG_TRACE("Use lookahead of " << k);
119 if (k < maxK) {
120 initialize(k);
121 }
122
123 std::vector<storm::expressions::Expression> atLeastOneOfStates;
124
125 for (uint64_t state : oneOfTheseStates) {
126 atLeastOneOfStates.push_back(reachVarExpressions[state]);
127 }
128 assert(atLeastOneOfStates.size() > 0);
129 smtSolver->add(storm::expressions::disjunction(atLeastOneOfStates));
130
131 for (uint64_t state : allOfTheseStates) {
132 smtSolver->add(reachVarExpressions[state]);
133 }
134
135 STORM_LOG_TRACE(smtSolver->getSmtLibString());
136
137 STORM_LOG_DEBUG("Call to SMT Solver");
138 stats.smtCheckTimer.start();
139 auto result = smtSolver->check();
140 stats.smtCheckTimer.stop();
141
143 STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result");
145 STORM_LOG_DEBUG("Unsatisfiable!");
146 return false;
147 }
148
149 STORM_LOG_DEBUG("Satisfying assignment: ");
150 auto model = smtSolver->getModel();
151 size_t i = 0;
152 storm::storage::BitVector observations(pomdp.getNrObservations());
153 storm::storage::BitVector remainingstates(pomdp.getNumberOfStates());
154 for (auto rv : reachVars) {
155 if (model->getBooleanValue(rv)) {
156 observations.set(pomdp.getObservation(i));
157 } else {
158 remainingstates.set(i);
159 }
160 ++i;
161 }
162 std::vector<std::set<uint64_t>> scheduler;
163 for (auto const& actionSelectionVarsForObs : actionSelectionVars) {
164 uint64_t act = 0;
165 scheduler.push_back(std::set<uint64_t>());
166 for (auto const& asv : actionSelectionVarsForObs) {
167 if (model->getBooleanValue(asv)) {
168 scheduler.back().insert(act);
169 }
170 act++;
171 }
172 }
173
174 return true;
175}
176
// Explicit instantiations: the member definitions above are compiled for these value types only.
177template class OneShotPolicySearch<double>;
178template class OneShotPolicySearch<storm::RationalNumber>;
179} // namespace pomdp
180} // namespace storm
A bit vector that is internally represented as a vector of 64-bit values.
Definition BitVector.h:16
#define STORM_LOG_DEBUG(message)
Definition logging.h:18
#define STORM_LOG_TRACE(message)
Definition logging.h:12
#define STORM_LOG_THROW(cond, exception, message)
Definition macros.h:30
Expression iff(Expression const &first, Expression const &second)
Expression disjunction(std::vector< storm::expressions::Expression > const &expressions)
Expression implies(Expression const &first, Expression const &second)
void initialize(int *argc, char **argv)