Storm 1.11.1.1
A Modern Probabilistic Model Checker
Loading...
Searching...
No Matches
OneShotPolicySearch.h
Go to the documentation of this file.
1#pragma once
2
3#include <vector>
4
10
11namespace storm {
12namespace pomdp {
13
// extractObservations: returns the set of distinct observation ids (pomdp.getObservation(state)) of the states in `states`.
// NOTE(review): the signature line (original line 15) is missing from this listing; the extractObservations
// cross-reference entry at the end of this page records it as
//   std::set<uint32_t> extractObservations(Pomdp<ValueType> const& pomdp, BitVector const& states).
14template<typename ValueType>
16 // TODO move.
17 std::set<uint32_t> observations;
// Iterating the BitVector presumably yields the indices of its set bits (confirm against BitVector);
// std::set keeps each observation id only once.
18 for (auto state : states) {
19 observations.insert(pomdp.getObservation(state));
20 }
21 return observations;
22}
23
// OneShotPolicySearch<ValueType>: checks whether a memoryless policy can be found from the initial states of a
// POMDP, using an SMT solver created in the constructor (analyze/initialize are defined elsewhere).
// NOTE(review): the class header and its doc comment (original lines 25-29) are missing from this listing;
// the class name is taken from the constructor cross-reference entry at the end of this page.
24template<typename ValueType>
// Statistics: holds two stopwatches; by their names, presumably timing SMT checks and solver initialization — confirm.
// NOTE(review): original line 34 (one more line of this class) is missing from this listing.
30 class Statistics {
31 public:
32 Statistics() = default;
33
35 storm::utility::Stopwatch smtCheckTimer;
36 storm::utility::Stopwatch initializeSolverTimer;
37 };
38
39 public:
// Constructor: initializes the pomdp member from the argument, precomputes the observations of the target states
// via extractObservations, copies the target and surely-reach-sink state sets, then creates a fresh
// ExpressionManager and an SMT solver over it through smtSolverFactory->create.
// NOTE(review): the first signature line (original line 40) is missing; the full signature is in the constructor
// cross-reference entry at the end of this page. smtSolverFactory is taken by non-const reference although
// only ->create is called on it here.
41 storm::storage::BitVector const& surelyReachSinkStates, std::shared_ptr<storm::utility::solver::SmtSolverFactory>& smtSolverFactory)
42 : pomdp(pomdp), targetObservations(extractObservations(pomdp, targetStates)), targetStates(targetStates), surelyReachSinkStates(surelyReachSinkStates) {
43 this->expressionManager = std::make_shared<storm::expressions::ExpressionManager>();
44 smtSolver = smtSolverFactory->create(*expressionManager);
45 }
46
// setSurelyReachSinkStates(surelyReachSink): overwrites the stored surely-reach-sink states.
// NOTE(review): the signature line (original line 47) is missing from this listing.
48 surelyReachSinkStates = surelyReachSink;
49 }
50
// analyzeForInitialStates(k): check if you can find a memoryless policy from the initial states.
// Logs the bad (surely-reach-sink), target and remaining "question mark" states at trace level, then delegates to
// analyze with the states that are neither sink nor target as oneOfTheseStates and the POMDP's initial states
// as allOfTheseStates. The meaning of k and of the returned bool is defined by analyze (not visible here).
// NOTE(review): the doc comment lines (original 51-55) are missing from this listing.
56 bool analyzeForInitialStates(uint64_t k) {
57 STORM_LOG_TRACE("Bad states: " << surelyReachSinkStates);
58 STORM_LOG_TRACE("Target states: " << targetStates);
59 STORM_LOG_TRACE("Questionmark states: " << (~surelyReachSinkStates & ~targetStates));
60 return analyze(k, ~surelyReachSinkStates & ~targetStates, pomdp.getInitialStates());
61 }
62
63 private:
// analyze: declared here, defined elsewhere; allOfTheseStates defaults to an empty BitVector.
64 bool analyze(uint64_t k, storm::storage::BitVector const& oneOfTheseStates,
65 storm::storage::BitVector const& allOfTheseStates = storm::storage::BitVector());
66
// initialize: declared here, defined elsewhere; presumably builds the solver encoding for bound k — confirm.
67 void initialize(uint64_t k);
68
69 Statistics stats;
// SMT solver created in the constructor from the factory, over expressionManager.
70 std::unique_ptr<storm::solver::SmtSolver> smtSolver;
// NOTE(review): original line 71 is missing from this listing; it presumably declares the pomdp member
// used by the constructor and analyzeForInitialStates — confirm.
72 std::shared_ptr<storm::expressions::ExpressionManager> expressionManager;
// Initialized to the largest uint64_t value; not read anywhere in the code visible here.
73 uint64_t maxK = std::numeric_limits<uint64_t>::max();
74
// Observations of targetStates, computed once in the constructor via extractObservations.
75 std::set<uint32_t> targetObservations;
76 storm::storage::BitVector targetStates;
77 storm::storage::BitVector surelyReachSinkStates;
78
// Solver-encoding data; populated outside this listing, so the exact indexing should be confirmed there.
// A_{z,a} presumably denotes the selection of action a under observation z.
79 std::vector<std::vector<uint64_t>> statesPerObservation;
80 std::vector<std::vector<storm::expressions::Expression>> actionSelectionVarExpressions; // A_{z,a}
81 std::vector<std::vector<storm::expressions::Variable>> actionSelectionVars;
82 std::vector<storm::expressions::Variable> reachVars;
83 std::vector<storm::expressions::Expression> reachVarExpressions;
84 std::vector<std::vector<storm::expressions::Expression>> pathVars;
85};
86} // namespace pomdp
87} // namespace storm
This class represents a partially observable Markov decision process.
Definition Pomdp.h:13
uint32_t getObservation(uint64_t state) const
Definition Pomdp.cpp:65
OneShotPolicySearch(storm::models::sparse::Pomdp< ValueType > const &pomdp, storm::storage::BitVector const &targetStates, storm::storage::BitVector const &surelyReachSinkStates, std::shared_ptr< storm::utility::solver::SmtSolverFactory > &smtSolverFactory)
bool analyzeForInitialStates(uint64_t k)
Check if you can find a memoryless policy from the initial states.
void setSurelyReachSinkStates(storm::storage::BitVector const &surelyReachSink)
A bit vector that is internally represented as a vector of 64-bit values.
Definition BitVector.h:16
A class that provides convenience operations to display run times.
Definition Stopwatch.h:14
#define STORM_LOG_TRACE(message)
Definition logging.h:12
std::set< uint32_t > extractObservations(storm::models::sparse::Pomdp< ValueType > const &pomdp, storm::storage::BitVector const &states)