// Sets up the SMT encoding for a search with path length bound k: declares the
// boolean variables (path, reachability, action selection) and adds constraints
// to smtSolver.
// NOTE(review): this listing skips several original source lines (the embedded
// line numbers jump), so some declarations, branches and closing braces are not
// visible here; comments below describe only the statements that are shown.
8template<
typename ValueType>
9void OneShotPolicySearch<ValueType>::initialize(uint64_t k) {
// maxK still holding the uint64_t maximum is presumably the "never initialized"
// sentinel -- TODO confirm against the class definition.
10 if (maxK == std::numeric_limits<uint64_t>::max()) {
// One empty slot per observation in each per-observation container.
13 for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
14 actionSelectionVars.push_back(std::vector<storm::expressions::Variable>());
15 actionSelectionVarExpressions.push_back(std::vector<storm::expressions::Expression>());
16 statesPerObservation.push_back(std::vector<uint64_t>());
// Walk over the entries of pomdp.getObservations(); stateId (declared on a line
// not shown here) is post-incremented once per entry, so it acts as the state index.
23 for (
auto obs : pomdp.getObservations()) {
24 pathVars.push_back(std::vector<storm::expressions::Expression>());
// k boolean path variables per state, named "P-<stateId>-<i>".
25 for (uint64_t i = 0;
i < k; ++
i) {
26 pathVars.back().push_back(expressionManager->declareBooleanVariable(
"P-" + std::to_string(stateId) +
"-" + std::to_string(i)).getExpression());
// One boolean variable per state, named "C-<stateId>", kept both as variable
// and as expression.
28 reachVars.push_back(expressionManager->declareBooleanVariable(
"C-" + std::to_string(stateId)));
29 reachVarExpressions.push_back(reachVars.back().getExpression());
// Record the state under its observation, then advance stateId.
30 statesPerObservation.at(obs).push_back(stateId++);
32 assert(pathVars.size() == pomdp.getNumberOfStates());
// Declare one boolean action-selection variable "A-<obs>-<a>" per choice; the
// number of choices is read from the first state of each observation group.
// NOTE(review): the declaration/increment of obs for this loop is not visible here.
36 for (
auto const& statesForObservation : statesPerObservation) {
37 for (uint64_t a = 0; a < pomdp.getNumberOfChoices(statesForObservation.front()); ++a) {
38 std::string varName =
"A-" + std::to_string(obs) +
"-" + std::to_string(a);
39 actionSelectionVars.at(obs).push_back(expressionManager->declareBooleanVariable(varName));
40 actionSelectionVarExpressions.at(obs).push_back(actionSelectionVars.at(obs).back().getExpression());
// Loop over the per-observation action expressions; its body is not shown here.
48 for (
auto const& actionVars : actionSelectionVarExpressions) {
// rowindex is used below as the row of the transition matrix for the current
// (state, action) pair.
52 uint64_t rowindex = 0;
53 for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
54 for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
// Collect: negated reach variable of the state, negated selection variable of
// the action, then (one at a time) the reach variable of each successor column.
55 std::vector<storm::expressions::Expression> subexprreach;
56 subexprreach.push_back(!reachVarExpressions[state]);
57 subexprreach.push_back(!actionSelectionVarExpressions[pomdp.getObservation(state)][action]);
58 for (
auto const& entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
59 subexprreach.push_back(reachVarExpressions.at(entries.getColumn()));
// Removes the successor literal again; what happens between the push and the
// pop is on a line not shown here.
61 subexprreach.pop_back();
68 for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
// Constraint on the step-0 path variable: asserted for target states; the
// negated form below presumably belongs to the else branch (not visible here).
69 if (targetStates.get(state)) {
70 smtSolver->add(pathVars[state][0]);
72 smtSolver->add(!pathVars[state][0]);
// States in surelyReachSinkStates get their reach variable forced to false and
// their matrix rows are skipped.
75 if (surelyReachSinkStates.get(state)) {
76 smtSolver->add(!reachVarExpressions[state]);
77 rowindex += pomdp.getNumberOfChoices(state);
78 }
else if (!targetStates.get(state)) {
// pathsubsubexprs[j - 1][action] collects, for j in 1..k-1, the path variables
// at index j - 1 of all successors of (state, action).
79 std::vector<std::vector<std::vector<storm::expressions::Expression>>> pathsubsubexprs;
80 for (uint64_t j = 1; j < k; ++j) {
81 pathsubsubexprs.push_back(std::vector<std::vector<storm::expressions::Expression>>());
82 for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
83 pathsubsubexprs.back().push_back(std::vector<storm::expressions::Expression>());
87 for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
88 std::vector<storm::expressions::Expression> subexprreach;
89 for (
auto const& entries : pomdp.getTransitionMatrix().getRow(rowindex)) {
90 for (uint64_t j = 1; j < k; ++j) {
91 pathsubsubexprs[j - 1][action].push_back(pathVars[entries.getColumn()][j - 1]);
// Per step j, combine each action's selection variable with a second operand
// that is cut off in this listing.
97 for (uint64_t j = 1; j < k; ++j) {
98 std::vector<storm::expressions::Expression> pathsubexprs;
100 for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) {
101 pathsubexprs.push_back(actionSelectionVarExpressions.at(pomdp.getObservation(state)).at(action) &&
// Advance rowindex past the rows of this state.
110 rowindex += pomdp.getNumberOfChoices(state);
// Member function template of OneShotPolicySearch<ValueType>.
// NOTE(review): the function signature (original lines 116-121) is not part of
// this listing, and further lines are skipped below; comments describe only the
// statements that are shown.
115template<
typename ValueType>
// Gather the reach-variable expressions of every state in oneOfTheseStates;
// the assert requires that at least one such state exists.
122 std::vector<storm::expressions::Expression> atLeastOneOfStates;
124 for (uint64_t state : oneOfTheseStates) {
125 atLeastOneOfStates.push_back(reachVarExpressions[state]);
127 assert(atLeastOneOfStates.size() > 0);
// Every state in allOfTheseStates gets its reach variable asserted in the solver.
130 for (uint64_t state : allOfTheseStates) {
131 smtSolver->add(reachVarExpressions[state]);
// Run the solver check, timed via stats.smtCheckTimer.
137 stats.smtCheckTimer.start();
138 auto result = smtSolver->check();
139 stats.smtCheckTimer.stop();
// Unconditional throw; the branch on `result` guarding it is not visible here.
142 STORM_LOG_THROW(
false, storm::exceptions::UnexpectedException,
"SMT solver yielded an unexpected result");
149 auto model = smtSolver->getModel();
// Inspect the model value of each reach variable. `i` is declared and advanced
// on lines not shown here; presumably it is the state index matching rv.
// Whether the remainingstates update sits in an else branch is not visible.
153 for (
auto rv : reachVars) {
154 if (model->getBooleanValue(rv)) {
155 observations.set(pomdp.getObservation(i));
157 remainingstates.set(i);
// Build one set per observation holding the values of `act` for which the
// action-selection variable is true in the model. `act` is declared on a line
// not shown here; presumably it is the action index within the observation.
161 std::vector<std::set<uint64_t>> scheduler;
162 for (
auto const& actionSelectionVarsForObs : actionSelectionVars) {
164 scheduler.push_back(std::set<uint64_t>());
165 for (
auto const& asv : actionSelectionVarsForObs) {
166 if (model->getBooleanValue(asv)) {
167 scheduler.back().insert(act);
// Explicit instantiations of OneShotPolicySearch for the value types double
// and storm::RationalNumber.
176template class OneShotPolicySearch<double>;
177template class OneShotPolicySearch<storm::RationalNumber>;
A bit vector that is internally represented as a vector of 64-bit values.
#define STORM_LOG_DEBUG(message)
#define STORM_LOG_TRACE(message)
#define STORM_LOG_THROW(cond, exception, message)
Expression iff(Expression const &first, Expression const &second)
Expression disjunction(std::vector< storm::expressions::Expression > const &expressions)
Expression implies(Expression const &first, Expression const &second)
void initialize(int *argc, char **argv)