Storm 1.11.1.1
A Modern Probabilistic Model Checker
Loading...
Searching...
No Matches
IterativePolicySearch.h
Go to the documentation of this file.
1#pragma once
2
3#include <sstream>
4#include <vector>
5
13
14namespace storm {
15namespace pomdp {
16
19 if (in == "int") {
21 } else if (in == "real") {
23 } else {
24 assert(in == "bool");
26 }
27}
28
30 public:
31 void setExportSATCalls(std::string const& path) {
32 exportSATcalls = path;
33 }
34
35 std::string const& getExportSATCallsPath() const {
36 return exportSATcalls;
37 }
38
39 bool isExportSATSet() const {
40 return exportSATcalls != "";
41 }
42
43 void setDebugLevel(uint64_t level = 1) {
44 debugLevel = level;
45 }
46
47 bool computeInfoOutput() const {
48 return debugLevel > 0;
49 }
50
51 bool computeDebugOutput() const {
52 return debugLevel > 1;
53 }
54
55 bool computeTraceOutput() const {
56 return debugLevel > 2;
57 }
58
// Public tuning knobs with their default values; how they are consumed is implemented outside this
// header and is not visible here.
60 bool forceLookahead = false;
61 bool validateEveryStep = false;
62 bool validateResult = false;
// NOTE(review): the time unit of extensionCallTimeout is not visible here; 0 presumably means "no timeout" — confirm.
65 uint64_t extensionCallTimeout = 0u;
66 uint64_t localIterationMaximum = 600;
67
68 private:
// Export path for SAT calls; the empty string means "not set" (isExportSATSet compares against "").
69 std::string exportSATcalls = "";
// Verbosity: > 0 enables info, > 1 debug, > 2 trace output (see the compute*Output predicates).
70 uint64_t debugLevel = 0;
71};
72
// Per observation: a bit vector over action indices; the set bits are the chosen actions
// (printForObservations iterates the set indices of actions[obs]).
74 std::vector<storm::storage::BitVector> actions;
// Per observation: a scheduler reference index (logged as "scheduler ref" by printForObservations).
75 std::vector<uint64_t> schedulerRef;
77
// Re-allocates the tables for nrObservations observations: one BitVector of length nrActions per
// observation (presumably with all bits unset — confirm BitVector's default), and scheduler refs of 0.
// NOTE(review): listing line 81 is missing here; it likely also resets switchObservations — confirm.
78 void reset(uint64_t nrObservations, uint64_t nrActions) {
79 actions = std::vector<storm::storage::BitVector>(nrObservations, storm::storage::BitVector(nrActions));
80 schedulerRef = std::vector<uint64_t>(nrObservations, 0);
82 }
83
84 bool empty() const {
85 return actions.empty();
86 }
87
88 void printForObservations(storm::storage::BitVector const& observations, storm::storage::BitVector const& observationsAfterSwitch) const {
89 for (uint64_t obs = 0; obs < observations.size(); ++obs) {
90 if (observations.get(obs) || observationsAfterSwitch.get(obs)) {
91 STORM_LOG_INFO("For observation: " << obs);
92 }
93 if (observations.get(obs)) {
94 std::stringstream ss;
95 ss << "actions:";
96 for (auto act : actions[obs]) {
97 ss << " " << act;
98 }
99 if (switchObservations.get(obs)) {
100 ss << " and switch.";
101 }
102 STORM_LOG_INFO(ss.str());
103 }
104 if (observationsAfterSwitch.get(obs)) {
105 STORM_LOG_INFO("scheduler ref: " << schedulerRef[obs]);
106 }
107 }
108 }
109};
110
111template<typename ValueType>
113 // Implements an extension to the Chatterjee, Chmelik, Davies (AAAI-16) paper.
114
115 public:
117 public:
118 Statistics() = default;
119 void print() const;
120
128
130
132 outerIterations++;
133 }
134
136 satCalls++;
137 }
138
139 uint64_t getChecks() {
140 return satCalls;
141 }
142
143 uint64_t getIterations() {
144 return outerIterations;
145 }
146
148 return graphBasedAnalysisWinOb;
149 }
150
152 graphBasedAnalysisWinOb++;
153 }
154
155 private:
// Counters, all starting at zero: satCalls is returned by getChecks(), outerIterations by
// getIterations(); graphBasedAnalysisWinOb has its own increment and getter pair.
156 uint64_t satCalls = 0;
157 uint64_t outerIterations = 0;
158 uint64_t graphBasedAnalysisWinOb = 0;
159 };
160
162 storm::storage::BitVector const& surelyReachSinkStates,
163
164 std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory, MemlessSearchOptions const& options);
165
166 bool analyzeForInitialStates(uint64_t k) {
167 stats.totalTimer.start();
168 STORM_LOG_TRACE("Bad states: " << surelyReachSinkStates);
169 STORM_LOG_TRACE("Target states: " << targetStates);
170 STORM_LOG_TRACE("Questionmark states: " << (~surelyReachSinkStates & ~targetStates));
171 bool result = analyze(k, ~surelyReachSinkStates & ~targetStates, pomdp.getInitialStates());
172 stats.totalTimer.stop();
173 return result;
174 }
175
176 void computeWinningRegion(uint64_t k) {
177 stats.totalTimer.start();
178 analyze(k, ~surelyReachSinkStates & ~targetStates);
179 stats.totalTimer.stop();
180 }
181
183 return winningRegion;
184 }
185
// NOTE(review): implemented outside this header; presumably returns the position of `state` within
// statesPerObservation[observation] — confirm.
186 uint64_t getOffsetFromObservation(uint64_t state, uint64_t observation) const;
187
// Core search with bound k. In this header, callers pass ~surelyReachSinkStates & ~targetStates as
// oneOfTheseStates; analyzeForInitialStates also passes the initial states as allOfTheseStates, which
// otherwise defaults to an empty BitVector. analyzeForInitialStates forwards the returned bool.
188 bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates,
189 storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector());
190
// Read-only access to the collected statistics; finalizeStatistics is implemented outside this header.
191 Statistics const& getStatistics() const;
192 void finalizeStatistics();
193
194 private:
// Clears the collected schedulers and resets the SMT solver. winningRegion is deliberately left
// untouched, matching the log message: the restart continues from the current winning region.
195 void reset() {
196 STORM_LOG_INFO("Reset solver to restart with current winning region");
197 schedulerForObs.clear();
198 finalSchedulers.clear();
199 smtSolver->reset();
200 }
// Private helpers; implemented outside this header.
201 void printScheduler(std::vector<InternalObservationScheduler> const&);
202 void coveredStatesToStream(std::ostream& os, storm::storage::BitVector const& remaining) const;
203
204 bool initialize(uint64_t k);
205
// assumptions defaults to the empty set.
206 bool smtCheck(uint64_t iteration, std::set<storm::expressions::Expression> const& assumptions = {});
207
// SMT backend (smtSolver is reset in reset()).
208 std::unique_ptr<storm::solver::SmtSolver> smtSolver;
210 std::shared_ptr<storm::expressions::ExpressionManager> expressionManager;
// Defaults to the largest representable uint64_t value.
211 uint64_t maxK = std::numeric_limits<uint64_t>::max();
212
// State sets; the question-mark states used by the analyze* entry points are ~surelyReachSinkStates & ~targetStates.
213 storm::storage::BitVector surelyReachSinkStates;
214 storm::storage::BitVector targetStates;
215 std::vector<std::vector<uint64_t>> statesPerObservation;
216
// SMT variables of the encoding, each paired with the expressions built from them.
217 std::vector<storm::expressions::Variable> schedulerVariables;
218 std::vector<storm::expressions::Expression> schedulerVariableExpressions;
219 std::vector<std::vector<storm::expressions::Expression>> actionSelectionVarExpressions; // A_{z,a}
220 std::vector<std::vector<storm::expressions::Variable>> actionSelectionVars; // A_{z,a}
221
222 std::vector<storm::expressions::Variable> reachVars;
223 std::vector<storm::expressions::Expression> reachVarExpressions;
224 std::vector<std::vector<storm::expressions::Expression>> reachVarExpressionsPerObservation;
225
226 std::vector<storm::expressions::Variable> observationUpdatedVariables;
227 std::vector<storm::expressions::Expression> observationUpdatedExpressions;
228
229 std::vector<storm::expressions::Variable> switchVars;
230 std::vector<storm::expressions::Expression> switchVarExpressions;
231 std::vector<storm::expressions::Variable> followVars;
232 std::vector<storm::expressions::Expression> followVarExpressions;
233 std::vector<storm::expressions::Variable> continuationVars;
234 std::vector<storm::expressions::Expression> continuationVarExpressions;
235 std::vector<std::vector<storm::expressions::Variable>> pathVars;
236 std::vector<std::vector<storm::expressions::Expression>> pathVarExpressions;
237
// Scheduler bookkeeping (both vectors are cleared by reset()); winningRegion survives reset().
238 std::vector<InternalObservationScheduler> finalSchedulers;
239 std::vector<uint64_t> schedulerForObs;
240 WinningRegion winningRegion;
241
242 MemlessSearchOptions options;
243 Statistics stats;
244
// NOTE(review): this member is a *reference* to a shared_ptr, not a shared_ptr copy. The referenced
// shared_ptr must outlive this object or the member dangles; consider storing the shared_ptr by value.
245 std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory;
246 std::shared_ptr<WinningRegionQueryInterface<ValueType>> validator;
247
// mutable, so it can be modified from const member functions (presumably getOffsetFromObservation) — confirm.
248 mutable bool useFindOffset = false;
249};
250} // namespace pomdp
251} // namespace storm
This class represents a partially observable Markov decision process.
Definition Pomdp.h:13
uint64_t getOffsetFromObservation(uint64_t state, uint64_t observation) const
WinningRegion const & getLastWinningRegion() const
bool analyze(uint64_t k, storm::storage::BitVector const &oneOfTheseStates, storm::storage::BitVector const &allOfTheseStates=storm::storage::BitVector())
std::string const & getExportSATCallsPath() const
void setExportSATCalls(std::string const &path)
MemlessSearchPathVariables pathVariableType
A bit vector that is internally represented as a vector of 64-bit values.
Definition BitVector.h:16
void clear()
Removes all set bits from the bit vector.
size_t size() const
Retrieves the number of bits this bit vector can store.
bool get(uint64_t index) const
Retrieves the truth value of the bit at the given index and performs a bound check.
A class that provides convenience operations to display run times.
Definition Stopwatch.h:14
void start()
Start stopwatch (again) and start measuring time.
Definition Stopwatch.cpp:48
void stop()
Stop stopwatch and add measured time to total time.
Definition Stopwatch.cpp:42
#define STORM_LOG_INFO(message)
Definition logging.h:24
#define STORM_LOG_TRACE(message)
Definition logging.h:12
MemlessSearchPathVariables pathVariableTypeFromString(std::string const &in)
std::vector< storm::storage::BitVector > actions
void printForObservations(storm::storage::BitVector const &observations, storm::storage::BitVector const &observationsAfterSwitch) const
void reset(uint64_t nrObservations, uint64_t nrActions)