368 STORM_LOG_DEBUG(
"Surely reach sink states: " << surelyReachSinkStates);
370 STORM_LOG_DEBUG(
"Questionmark states " << (~surelyReachSinkStates & ~targetStates));
371 stats.initializeSolverTimer.start();
373 bool lookaheadConstraintsRequired = initialize(k);
374 if (lookaheadConstraintsRequired) {
378 stats.winningRegionUpdatesTimer.start();
382 for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
383 if (winningRegion.observationIsWinning(observation)) {
386 bool observationIsWinning =
true;
387 for (uint64_t state : statesPerObservation[observation]) {
388 if (!targetStates.get(state)) {
389 observationIsWinning =
false;
390 observationsWithPartialWinners.
set(observation);
392 potentialWinner.
set(observation);
395 if (observationIsWinning) {
397 stats.incrementGraphBasedWinningObservations();
398 winningRegion.setObservationIsWinning(observation);
399 updated.
set(observation);
402 STORM_LOG_DEBUG(
"Graph based winning obs: " << stats.getGraphBasedwinningObservations());
403 observationsWithPartialWinners &= potentialWinner;
404 for (
auto const observation : observationsWithPartialWinners) {
405 uint64_t nrStatesForObs = statesPerObservation[observation].
size();
407 for (uint64_t i = 0; i < nrStatesForObs; ++i) {
408 uint64_t state = statesPerObservation[observation][i];
409 if (targetStates.get(state)) {
413 assert(!update.
empty());
414 STORM_LOG_TRACE(
"Extend winning region for observation " << observation <<
" with target states/offsets" << update);
415 winningRegion.addTargetStates(observation, update);
416 assert(winningRegion.query(observation, update));
418 updated.
set(observation);
422 for (
auto const& state : targetStates) {
423 STORM_LOG_ASSERT(winningRegion.isWinning(pomdp.getObservation(state), getOffsetFromObservation(state, pomdp.getObservation(state))),
424 "Target state " << state <<
" , observation " << pomdp.getObservation(state) <<
" is not reflected as winning.");
428 stats.winningRegionUpdatesTimer.stop();
430 uint64_t maximalNrActions = 0;
431 for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
432 maximalNrActions = std::max(pomdp.getTransitionMatrix().getRowGroupSize(state), maximalNrActions);
434 std::vector<storm::expressions::Expression> atLeastOneOfStates;
435 for (uint64_t state : oneOfTheseStates) {
436 STORM_LOG_ASSERT(reachVarExpressions.size() > state,
"state id " << state <<
" exceeds number of states (" << reachVarExpressions.size() <<
")");
437 atLeastOneOfStates.push_back(reachVarExpressions[state]);
439 if (!atLeastOneOfStates.empty()) {
443 std::set<storm::expressions::Expression> allOfTheseAssumption;
444 std::vector<storm::expressions::Expression> updateForObservationExpressions;
446 for (uint64_t state : allOfTheseStates) {
447 assert(reachVarExpressions.size() > state);
448 allOfTheseAssumption.insert(reachVarExpressions[state]);
451 if (winningRegion.empty()) {
453 for (uint64_t ob = 0; ob < pomdp.getNrObservations(); ++ob) {
455 schedulerForObs.push_back(0);
459 for (
auto const& statesForObservation : statesPerObservation) {
460 schedulerForObs.push_back(0);
461 for (
auto const& winningSet : winningRegion.getWinningSetsPerObservation(obs)) {
462 assert(!winningSet.empty());
463 assert(obs < schedulerForObs.size());
464 ++(schedulerForObs[obs]);
465 auto constant = expressionManager->integer(schedulerForObs[obs]);
466 for (
auto const& stateOffset : ~winningSet) {
467 uint64_t state = statesForObservation[stateOffset];
468 STORM_LOG_TRACE(
"State " << state <<
" with observation " << obs <<
" does not allow scheduler " << constant);
469 smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant)));
471 !(reachVarExpressions[state] && followVarExpressions[pomdp.getObservation(state)] && (schedulerVariableExpressions[obs] == constant)));
474 if (winningRegion.getWinningSetsPerObservation(obs).empty()) {
479 updateForObservationExpressions.push_back(winningRegion.extensionExpression(obs, reachVarExpressionsPerObservation[obs]));
486 for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
487 auto constant = expressionManager->integer(schedulerForObs[obs]);
488 smtSolver->add(schedulerVariableExpressions[obs] <= constant);
489 smtSolver->add(
storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs]));
492 assert(pomdp.getNrObservations() == schedulerForObs.size());
505 stats.initializeSolverTimer.stop();
508 uint64_t iterations = 0;
510 bool foundWhatWeLookFor =
false;
512 stats.incrementOuterIterations();
514 scheduler.
reset(pomdp.getNrObservations(), maximalNrActions);
515 observations.
clear();
516 observationsAfterSwitch.
clear();
517 coveredStates = targetStates;
518 coveredStatesAfterSwitch.
clear();
519 observationUpdated.
clear();
520 bool newSchedulerDiscovered =
false;
521 if (!allOfTheseAssumption.empty()) {
523 bool foundResult = this->smtCheck(iterations, allOfTheseAssumption);
526 foundWhatWeLookFor =
true;
529 uint64_t localIterations = 0;
534 bool foundScheduler = foundWhatWeLookFor;
535 if (!foundScheduler) {
536 foundScheduler = this->smtCheck(iterations);
538 if (!foundScheduler) {
541 newSchedulerDiscovered =
true;
542 stats.evaluateExtensionSolverTime.start();
543 auto const& model = smtSolver->getModel();
545 newObservationsAfterSwitch.
clear();
546 newObservations.
clear();
549 for (
auto const& ov : observationUpdatedVariables) {
550 if (!observationUpdated.
get(obs) && model->getBooleanValue(ov)) {
552 observationUpdated.
set(obs);
557 uncoveredStates = ~coveredStates;
558 for (uint64_t i : uncoveredStates) {
559 auto const& rv = reachVars[i];
560 auto const& rvExpr = reachVarExpressions[i];
561 if (observationUpdated.
get(pomdp.getObservation(i)) && model->getBooleanValue(rv)) {
563 smtSolver->add(rvExpr);
564 assert(!surelyReachSinkStates.get(i));
565 newObservations.
set(pomdp.getObservation(i));
566 coveredStates.
set(i);
567 if (lookaheadConstraintsRequired) {
569 smtSolver->add(pathVarExpressions[i][0] == expressionManager->integer(model->getIntegerValue(pathVars[i][0])));
571 smtSolver->add(pathVarExpressions[i][0] == expressionManager->rational(model->getRationalValue(pathVars[i][0])));
578 for (uint64_t i : uncoveredStatesAfterSwitch) {
579 auto const& cv = continuationVars[i];
580 if (model->getBooleanValue(cv)) {
581 uint64_t obs = pomdp.getObservation(i);
582 STORM_LOG_ASSERT(winningRegion.isWinning(obs, getOffsetFromObservation(i, obs)),
583 "Cannot continue: No scheduler known for state " << i <<
" (observation " << obs <<
").");
584 auto const& cvExpr = continuationVarExpressions[i];
585 smtSolver->add(cvExpr);
586 if (!observationsAfterSwitch.
get(obs)) {
587 newObservationsAfterSwitch.
set(obs);
591 stats.evaluateExtensionSolverTime.stop();
593 if (options.computeTraceOutput()) {
596 stats.encodeExtensionSolverTime.start();
597 for (
auto obs : newObservations) {
598 auto const& actionSelectionVarsForObs = actionSelectionVars[obs];
599 observations.
set(obs);
600 for (uint64_t act = 0; act < actionSelectionVarsForObs.size(); ++act) {
601 if (model->getBooleanValue(actionSelectionVarsForObs[act])) {
602 scheduler.
actions[obs].set(act);
603 smtSolver->add(actionSelectionVarExpressions[obs][act]);
605 smtSolver->add(!actionSelectionVarExpressions[obs][act]);
608 if (model->getBooleanValue(switchVars[obs])) {
610 smtSolver->add(switchVarExpressions[obs]);
612 smtSolver->add(!switchVarExpressions[obs]);
614 if (model->getBooleanValue(followVars[obs])) {
615 smtSolver->add(followVarExpressions[obs]);
617 smtSolver->add(!followVarExpressions[obs]);
620 for (
auto obs : newObservationsAfterSwitch) {
621 observationsAfterSwitch.
set(obs);
622 scheduler.
schedulerRef[obs] = model->getIntegerValue(schedulerVariables[obs]);
623 smtSolver->add(schedulerVariableExpressions[obs] == expressionManager->integer(scheduler.
schedulerRef[obs]));
626 if (options.computeTraceOutput()) {
633 if (foundWhatWeLookFor ||
634 (options.localIterationMaximum > 0 && (localIterations % (options.localIterationMaximum + 1) == options.localIterationMaximum))) {
635 stats.encodeExtensionSolverTime.stop();
639 std::vector<storm::expressions::Expression> remainingExpressions;
640 for (
auto index : ~coveredStates) {
641 if (observationUpdated.
get(pomdp.getObservation(index))) {
642 remainingExpressions.push_back(reachVarExpressions[index]);
645 for (
auto index : ~observationUpdated) {
646 remainingExpressions.push_back(observationUpdatedExpressions[index]);
649 if (remainingExpressions.empty()) {
650 stats.encodeExtensionSolverTime.stop();
654 stats.encodeExtensionSolverTime.stop();
657 if (!newSchedulerDiscovered) {
663 if (options.computeDebugOutput()) {
664 std::stringstream strstr;
665 coveredStatesToStream(strstr, ~coveredStates);
673 stats.winningRegionUpdatesTimer.start();
675 uint64_t newTargetObservations = 0;
676 for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
680 for (uint64_t state : statesPerObservation[observation]) {
681 if (coveredStates.
get(state)) {
682 assert(!surelyReachSinkStates.get(state));
687 if (!update.
empty()) {
688 STORM_LOG_TRACE(
"Update Winning Region: Observation " << observation <<
" with update " << update);
689 bool updateResult = winningRegion.update(observation, update);
692 if (winningRegion.observationIsWinning(observation)) {
693 ++newTargetObservations;
694 for (uint64_t state : statesPerObservation[observation]) {
695 targetStates.set(state);
696 assert(!surelyReachSinkStates.get(state));
699 updated.
set(observation);
700 updateForObservationExpressions[observation] =
701 winningRegion.extensionExpression(observation, reachVarExpressionsPerObservation[observation]);
705 stats.winningRegionUpdatesTimer.stop();
706 if (foundWhatWeLookFor) {
709 if (newTargetObservations > 0) {
710 stats.graphSearchTime.start();
712 uint64_t targetStatesBefore = targetStates.getNumberOfSetBits();
713 STORM_LOG_DEBUG(
"Target states before graph based analysis " << targetStates.getNumberOfSetBits());
714 targetStates = graphanalysis.
analyseProb1Max(~surelyReachSinkStates, targetStates);
716 STORM_LOG_DEBUG(
"Target states after graph based analysis " << targetStates.getNumberOfSetBits());
717 stats.graphSearchTime.stop();
718 if (targetStatesAfter - targetStatesBefore > 0) {
719 stats.winningRegionUpdatesTimer.start();
722 for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
723 if (winningRegion.observationIsWinning(observation)) {
726 bool observationIsWinning =
true;
727 for (uint64_t state : statesPerObservation[observation]) {
728 if (!targetStates.get(state)) {
729 observationIsWinning =
false;
730 observationsWithPartialWinners.
set(observation);
732 potentialWinner.
set(observation);
735 if (observationIsWinning) {
736 stats.incrementGraphBasedWinningObservations();
737 winningRegion.setObservationIsWinning(observation);
738 updated.
set(observation);
741 STORM_LOG_DEBUG(
"Graph-based winning obs: " << stats.getGraphBasedwinningObservations());
742 observationsWithPartialWinners &= potentialWinner;
743 for (
auto const observation : observationsWithPartialWinners) {
744 uint64_t nrStatesForObs = statesPerObservation[observation].
size();
746 for (uint64_t i = 0; i < nrStatesForObs; ++i) {
747 uint64_t state = statesPerObservation[observation][i];
748 if (targetStates.get(state)) {
752 assert(!update.
empty());
753 STORM_LOG_TRACE(
"Extend winning region for observation " << observation <<
" with target states/offsets" << update);
754 winningRegion.addTargetStates(observation, update);
755 assert(winningRegion.query(observation, update));
756 updated.
set(observation);
758 stats.winningRegionUpdatesTimer.stop();
762 return analyze(k, ~targetStates & ~surelyReachSinkStates, allOfTheseStates);
767 if (options.computeDebugOutput()) {
768 winningRegion.print();
770 if (options.validateEveryStep) {
771 STORM_LOG_WARN(
"Validating every step, for debug purposes only!");
772 validator->validate();
774 if (stats.getIterations() % options.restartAfterNIterations == options.restartAfterNIterations - 1) {
776 return analyze(k, ~targetStates & ~surelyReachSinkStates, allOfTheseStates);
778 stats.updateNewStrategySolverTime.start();
779 for (uint64_t observation : updated) {
780 updateForObservationExpressions[observation] = winningRegion.extensionExpression(observation, reachVarExpressionsPerObservation[observation]);
784 for (
auto const& statesForObservation : statesPerObservation) {
785 if (observations.
get(obs) && updated.
get(obs)) {
786 STORM_LOG_DEBUG(
"We have a new policy ( " << finalSchedulers.size() <<
" ) for states with observation " << obs <<
".");
787 assert(schedulerForObs.size() > obs);
788 (schedulerForObs[obs])++;
789 STORM_LOG_DEBUG(
"We now have " << schedulerForObs[obs] <<
" policies for states with observation " << obs);
790 if (winningRegion.observationIsWinning(obs)) {
791 for (
auto const& state : statesForObservation) {
792 smtSolver->add(reachVarExpressions[state]);
794 auto constant = expressionManager->integer(schedulerForObs[obs]);
795 smtSolver->add(schedulerVariableExpressions[obs] == constant);
797 auto constant = expressionManager->integer(schedulerForObs[obs]);
798 for (
auto const& state : statesForObservation) {
799 if (!coveredStates.
get(state)) {
800 smtSolver->add(!(continuationVarExpressions[state] && (schedulerVariableExpressions[obs] == constant)));
801 smtSolver->add(!(reachVarExpressions[state] && followVarExpressions[pomdp.getObservation(state)] &&
802 (schedulerVariableExpressions[obs] == constant)));
809 finalSchedulers.push_back(scheduler);
813 for (uint64_t obs = 0; obs < pomdp.getNrObservations(); ++obs) {
814 if (winningRegion.observationIsWinning(obs)) {
815 auto constant = expressionManager->integer(schedulerForObs[obs]);
818 smtSolver->add(!observationUpdatedExpressions[obs]);
820 auto constant = expressionManager->integer(schedulerForObs[obs]);
821 smtSolver->add(schedulerVariableExpressions[obs] <= constant);
822 smtSolver->add(
storm::expressions::iff(observationUpdatedExpressions[obs], updateForObservationExpressions[obs]));
825 stats.updateNewStrategySolverTime.stop();
827 STORM_LOG_INFO(
"... after iteration " << stats.getIterations() <<
" so far " << stats.getChecks() <<
" checks.");
829 if (options.validateResult) {
830 STORM_LOG_WARN(
"Validating result is a winning region, only for debugging purposes.");
831 validator->validate();
832 STORM_LOG_WARN(
"Validating result is a fixed point, only for debugging purposes.");
833 validator->validateIsMaximal(surelyReachSinkStates);
836 if (!allOfTheseStates.
empty()) {
837 for (uint64_t observation = 0; observation < pomdp.getNrObservations(); ++observation) {
840 for (uint64_t state : statesPerObservation[observation]) {
841 if (allOfTheseStates.
get(state)) {
846 if (!winningRegion.query(observation, check)) {