Storm
A Modern Probabilistic Model Checker
Loading...
Searching...
No Matches
OneShotPolicySearch.h
Go to the documentation of this file.
1#include <vector>
8
9namespace storm {
10namespace pomdp {
11
12template<typename ValueType>
14 // TODO move.
15 std::set<uint32_t> observations;
16 for (auto state : states) {
17 observations.insert(pomdp.getObservation(state));
18 }
19 return observations;
20}
21
22template<typename ValueType>
28 class Statistics {
29 public:
30 Statistics() = default;
31
33 storm::utility::Stopwatch smtCheckTimer;
34 storm::utility::Stopwatch initializeSolverTimer;
35 };
36
37 public:
39 storm::storage::BitVector const& surelyReachSinkStates, std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory)
40 : pomdp(pomdp), targetObservations(extractObservations(pomdp, targetStates)), targetStates(targetStates), surelyReachSinkStates(surelyReachSinkStates) {
41 this->expressionManager = std::make_shared<storm::expressions::ExpressionManager>();
42 smtSolver = smtSolverFactory->create(*expressionManager);
43 }
44
46 surelyReachSinkStates = surelyReachSink;
47 }
48
54 bool analyzeForInitialStates(uint64_t k) {
55 STORM_LOG_TRACE("Bad states: " << surelyReachSinkStates);
56 STORM_LOG_TRACE("Target states: " << targetStates);
57 STORM_LOG_TRACE("Questionmark states: " << (~surelyReachSinkStates & ~targetStates));
58 return analyze(k, ~surelyReachSinkStates & ~targetStates, pomdp.getInitialStates());
59 }
60
61 private:
62 bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates,
63 storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector());
64
65 void initialize(uint64_t k);
66
67 Statistics stats;
68 std::unique_ptr<storm::solver::SmtSolver> smtSolver;
70 std::shared_ptr<storm::expressions::ExpressionManager> expressionManager;
71 uint64_t maxK = std::numeric_limits<uint64_t>::max();
72
73 std::set<uint32_t> targetObservations;
74 storm::storage::BitVector targetStates;
75 storm::storage::BitVector surelyReachSinkStates;
76
77 std::vector<std::vector<uint64_t>> statesPerObservation;
78 std::vector<std::vector<storm::expressions::Expression>> actionSelectionVarExpressions; // A_{z,a}
79 std::vector<std::vector<storm::expressions::Variable>> actionSelectionVars;
80 std::vector<storm::expressions::Variable> reachVars;
81 std::vector<storm::expressions::Expression> reachVarExpressions;
82 std::vector<std::vector<storm::expressions::Expression>> pathVars;
83};
84} // namespace pomdp
85} // namespace storm
This class represents a partially observable Markov decision process.
Definition Pomdp.h:15
uint32_t getObservation(uint64_t state) const
Definition Pomdp.cpp:63
OneShotPolicySearch(storm::models::sparse::Pomdp< ValueType > const &pomdp, storm::storage::BitVector const &targetStates, storm::storage::BitVector const &surelyReachSinkStates, std::shared_ptr< storm::utility::solver::SmtSolverFactory > &smtSolverFactory)
bool analyzeForInitialStates(uint64_t k)
Check if you can find a memoryless policy from the initial states.
void setSurelyReachSinkStates(storm::storage::BitVector const &surelyReachSink)
A bit vector that is internally represented as a vector of 64-bit values.
Definition BitVector.h:18
A class that provides convenience operations to display run times.
Definition Stopwatch.h:14
#define STORM_LOG_TRACE(message)
Definition logging.h:17
std::set< uint32_t > extractObservations(storm::models::sparse::Pomdp< ValueType > const &pomdp, storm::storage::BitVector const &states)
LabParser.cpp.
Definition cli.cpp:18