Storm
A Modern Probabilistic Model Checker
Loading...
Searching...
No Matches
IterativePolicySearch.h
Go to the documentation of this file.
1#include <sstream>
2#include <vector>
9
12
13namespace storm {
14namespace pomdp {
15
18 if (in == "int") {
20 } else if (in == "real") {
22 } else {
23 assert(in == "bool");
25 }
26}
27
29 public:
30 void setExportSATCalls(std::string const& path) {
31 exportSATcalls = path;
32 }
33
34 std::string const& getExportSATCallsPath() const {
35 return exportSATcalls;
36 }
37
38 bool isExportSATSet() const {
39 return exportSATcalls != "";
40 }
41
42 void setDebugLevel(uint64_t level = 1) {
43 debugLevel = level;
44 }
45
46 bool computeInfoOutput() const {
47 return debugLevel > 0;
48 }
49
50 bool computeDebugOutput() const {
51 return debugLevel > 1;
52 }
53
54 bool computeTraceOutput() const {
55 return debugLevel > 2;
56 }
57
59 bool forceLookahead = false;
60 bool validateEveryStep = false;
61 bool validateResult = false;
64 uint64_t extensionCallTimeout = 0u;
65 uint64_t localIterationMaximum = 600;
66
67 private:
68 std::string exportSATcalls = "";
69 uint64_t debugLevel = 0;
70};
71
73 std::vector<storm::storage::BitVector> actions;
74 std::vector<uint64_t> schedulerRef;
76
77 void reset(uint64_t nrObservations, uint64_t nrActions) {
78 actions = std::vector<storm::storage::BitVector>(nrObservations, storm::storage::BitVector(nrActions));
79 schedulerRef = std::vector<uint64_t>(nrObservations, 0);
81 }
82
83 bool empty() const {
84 return actions.empty();
85 }
86
87 void printForObservations(storm::storage::BitVector const& observations, storm::storage::BitVector const& observationsAfterSwitch) const {
88 for (uint64_t obs = 0; obs < observations.size(); ++obs) {
89 if (observations.get(obs) || observationsAfterSwitch.get(obs)) {
90 STORM_LOG_INFO("For observation: " << obs);
91 }
92 if (observations.get(obs)) {
93 std::stringstream ss;
94 ss << "actions:";
95 for (auto act : actions[obs]) {
96 ss << " " << act;
97 }
98 if (switchObservations.get(obs)) {
99 ss << " and switch.";
100 }
101 STORM_LOG_INFO(ss.str());
102 }
103 if (observationsAfterSwitch.get(obs)) {
104 STORM_LOG_INFO("scheduler ref: " << schedulerRef[obs]);
105 }
106 }
107 }
108};
109
110template<typename ValueType>
112 // Implements an extension to the Chatterjee, Chmelik, Davies (AAAI-16) paper.
113
114 public:
116 public:
117 Statistics() = default;
118 void print() const;
119
127
129
131 outerIterations++;
132 }
133
135 satCalls++;
136 }
137
138 uint64_t getChecks() {
139 return satCalls;
140 }
141
142 uint64_t getIterations() {
143 return outerIterations;
144 }
145
147 return graphBasedAnalysisWinOb;
148 }
149
151 graphBasedAnalysisWinOb++;
152 }
153
154 private:
155 uint64_t satCalls = 0;
156 uint64_t outerIterations = 0;
157 uint64_t graphBasedAnalysisWinOb = 0;
158 };
159
161 storm::storage::BitVector const& surelyReachSinkStates,
162
163 std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory, MemlessSearchOptions const& options);
164
165 bool analyzeForInitialStates(uint64_t k) {
166 stats.totalTimer.start();
167 STORM_LOG_TRACE("Bad states: " << surelyReachSinkStates);
168 STORM_LOG_TRACE("Target states: " << targetStates);
169 STORM_LOG_TRACE("Questionmark states: " << (~surelyReachSinkStates & ~targetStates));
170 bool result = analyze(k, ~surelyReachSinkStates & ~targetStates, pomdp.getInitialStates());
171 stats.totalTimer.stop();
172 return result;
173 }
174
175 void computeWinningRegion(uint64_t k) {
176 stats.totalTimer.start();
177 analyze(k, ~surelyReachSinkStates & ~targetStates);
178 stats.totalTimer.stop();
179 }
180
182 return winningRegion;
183 }
184
185 uint64_t getOffsetFromObservation(uint64_t state, uint64_t observation) const;
186
187 bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates,
188 storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector());
189
190 Statistics const& getStatistics() const;
191 void finalizeStatistics();
192
193 private:
194 storm::expressions::Expression const& getDoneActionExpression(uint64_t obs) const;
195
196 void reset() {
197 STORM_LOG_INFO("Reset solver to restart with current winning region");
198 schedulerForObs.clear();
199 finalSchedulers.clear();
200 smtSolver->reset();
201 }
202 void printScheduler(std::vector<InternalObservationScheduler> const&);
203 void coveredStatesToStream(std::ostream& os, storm::storage::BitVector const& remaining) const;
204
205 bool initialize(uint64_t k);
206
207 bool smtCheck(uint64_t iteration, std::set<storm::expressions::Expression> const& assumptions = {});
208
209 std::unique_ptr<storm::solver::SmtSolver> smtSolver;
211 std::shared_ptr<storm::expressions::ExpressionManager> expressionManager;
212 uint64_t maxK = std::numeric_limits<uint64_t>::max();
213
214 storm::storage::BitVector surelyReachSinkStates;
215 storm::storage::BitVector targetStates;
216 std::vector<std::vector<uint64_t>> statesPerObservation;
217
218 std::vector<storm::expressions::Variable> schedulerVariables;
219 std::vector<storm::expressions::Expression> schedulerVariableExpressions;
220 std::vector<std::vector<storm::expressions::Expression>> actionSelectionVarExpressions; // A_{z,a}
221 std::vector<std::vector<storm::expressions::Variable>> actionSelectionVars; // A_{z,a}
222
223 std::vector<storm::expressions::Variable> reachVars;
224 std::vector<storm::expressions::Expression> reachVarExpressions;
225 std::vector<std::vector<storm::expressions::Expression>> reachVarExpressionsPerObservation;
226
227 std::vector<storm::expressions::Variable> observationUpdatedVariables;
228 std::vector<storm::expressions::Expression> observationUpdatedExpressions;
229
230 std::vector<storm::expressions::Variable> switchVars;
231 std::vector<storm::expressions::Expression> switchVarExpressions;
232 std::vector<storm::expressions::Variable> followVars;
233 std::vector<storm::expressions::Expression> followVarExpressions;
234 std::vector<storm::expressions::Variable> continuationVars;
235 std::vector<storm::expressions::Expression> continuationVarExpressions;
236 std::vector<std::vector<storm::expressions::Variable>> pathVars;
237 std::vector<std::vector<storm::expressions::Expression>> pathVarExpressions;
238
239 std::vector<InternalObservationScheduler> finalSchedulers;
240 std::vector<uint64_t> schedulerForObs;
241 WinningRegion winningRegion;
242
243 MemlessSearchOptions options;
244 Statistics stats;
245
246 std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory;
247 std::shared_ptr<WinningRegionQueryInterface<ValueType>> validator;
248
249 mutable bool useFindOffset = false;
250};
251} // namespace pomdp
252} // namespace storm
This class represents a partially observable Markov decision process.
Definition Pomdp.h:15
uint64_t getOffsetFromObservation(uint64_t state, uint64_t observation) const
WinningRegion const & getLastWinningRegion() const
bool analyze(uint64_t k, storm::storage::BitVector const &oneOfTheseStates, storm::storage::BitVector const &allOfTheseStates=storm::storage::BitVector())
std::string const & getExportSATCallsPath() const
void setExportSATCalls(std::string const &path)
MemlessSearchPathVariables pathVariableType
A bit vector that is internally represented as a vector of 64-bit values.
Definition BitVector.h:18
void clear()
Removes all set bits from the bit vector.
size_t size() const
Retrieves the number of bits this bit vector can store.
bool get(uint_fast64_t index) const
Retrieves the truth value of the bit at the given index and performs a bound check.
A class that provides convenience operations to display run times.
Definition Stopwatch.h:14
void start()
Start stopwatch (again) and start measuring time.
Definition Stopwatch.cpp:48
void stop()
Stop stopwatch and add measured time to total time.
Definition Stopwatch.cpp:42
#define STORM_LOG_INFO(message)
Definition logging.h:29
#define STORM_LOG_TRACE(message)
Definition logging.h:17
MemlessSearchPathVariables pathVariableTypeFromString(std::string const &in)
LabParser.cpp.
Definition cli.cpp:18
std::vector< storm::storage::BitVector > actions
void printForObservations(storm::storage::BitVector const &observations, storm::storage::BitVector const &observationsAfterSwitch) const
void reset(uint64_t nrObservations, uint64_t nrActions)