11namespace transformer {
// Constructor of PomdpMemoryUnfolder. The leading part of its declaration (the qualified
// constructor name and the pomdp/memory parameters) and its body are not visible in this excerpt.
// Initializes the members pomdp, memory, addMemoryLabels and keepStateValuations from the
// same-named arguments.
//   addMemoryLabels     - when true, transformStateLabeling adds one "memstate_<i>" label per memory state.
//   keepStateValuations - when true (and the input POMDP has state valuations), transform carries
//                         the state valuations over to the unfolded model.
13template<
typename ValueType>
15 bool addMemoryLabels,
bool keepStateValuations)
16 : pomdp(pomdp), memory(memory), addMemoryLabels(addMemoryLabels), keepStateValuations(keepStateValuations) {
// Builds the memory-unfolded POMDP (the signature line of this definition is not visible in this
// excerpt). Each unfolded state corresponds to a (model state, memory state) pair, indexed via
// getUnfoldingState.
// Throws storm::exceptions::InvalidArgumentException if the input POMDP is not canonical.
20template<
typename ValueType>
23 STORM_LOG_THROW(pomdp.isCanonic(), storm::exceptions::InvalidArgumentException,
"POMDP must be canonical to unfold memory into it");
// NOTE(review): the body of this branch, as well as the computation of reachableStates and the
// construction of `components`, are on lines not visible in this excerpt.
32 if (dropUnreachableStates) {
// State valuations are carried over only if requested and the input POMDP actually has them.
37 if (keepStateValuations && pomdp.hasStateValuations()) {
// Map every unfolded state index to the model state it was created from.
38 std::vector<uint64_t> newToOldStates(pomdp.getNumberOfStates() * memory.getNumberOfStates(), 0);
39 for (uint64_t newState = 0; newState < newToOldStates.size(); newState++) {
40 newToOldStates[newState] = getModelState(newState);
// blowup(newToOldStates) presumably gives each unfolded state the valuation of its model state
// (confirm in StateValuations); the result is then restricted to the states in reachableStates.
42 components.
stateValuations = pomdp.getStateValuations().blowup(newToOldStates).selectStates(reachableStates);
// Iterates over the reward models of the input POMDP (loop body not visible here).
48 for (
auto const& rewModel : pomdp.getRewardModels()) {
// NOTE(review): the meaning of the trailing `true` constructor argument is not visible here —
// confirm against the storm::models::sparse::Pomdp constructor.
52 return std::make_shared<storm::models::sparse::Pomdp<ValueType>>(std::move(components),
true);
// Builds the transition matrix of the unfolded POMDP (signature not visible in this excerpt).
// The row group of unfolded state (s, m) contains rows built from the choices of s, replicated per
// memory successor m' of m, with every target t redirected to the unfolded state (t, m').
55template<
typename ValueType>
59 uint64_t numEntries = 0;
// First pass: count the rows. Every choice of a model state is replicated once per outgoing
// memory transition of the memory state.
// NOTE(review): the declaration of numRows and the update of numEntries are not visible here.
60 for (uint64_t modelState = 0; modelState < pomdp.getNumberOfStates(); ++modelState) {
61 for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++memState) {
62 numRows += origTransitions.
getRowGroupSize(modelState) * memory.getNumberOfOutgoingTransitions(memState);
// Tail of the matrix builder construction; this argument is presumably the row group count
// (one group per unfolded state) — the start of the statement is not visible.
67 pomdp.getNumberOfStates() * memory.getNumberOfStates());
// Second pass: fill the matrix in unfolded-state order (model state outer, memory state inner),
// which matches the index layout of getUnfoldingState.
70 for (uint64_t modelState = 0; modelState < pomdp.getNumberOfStates(); ++modelState) {
71 for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++memState) {
72 builder.newRowGroup(row);
// NOTE(review): the loop over origRow and the increment of row are not visible in this excerpt.
75 for (
auto const& memStatePrime : memory.getTransitions(memState)) {
76 for (
auto const& entry : origTransitions.getRow(origRow)) {
// Keep the probability; redirect the target to its copy in memory state memStatePrime.
77 builder.addNextValue(row, getUnfoldingState(entry.getColumn(), memStatePrime), entry.getValue());
84 return builder.build();
// Builds the state labeling of the unfolded POMDP (signature not visible in this excerpt).
87template<
typename ValueType>
// Lift every label of the input POMDP to the unfolded state space.
// NOTE(review): the declaration of newStates (and of labeling) is not visible here.
90 for (
auto const& labelName : pomdp.getStateLabeling().getLabels()) {
94 if (labelName ==
"init") {
// "init" only holds for the copies of the initially labelled model states that are paired with
// the memory structure's initial state.
95 for (
auto const& modelState : pomdp.getStateLabeling().
getStates(labelName)) {
96 newStates.set(getUnfoldingState(modelState, memory.getInitialState()));
// Presumably the else-branch (its `} else {` line is not visible): any other label holds in every
// memory copy of a labelled model state.
99 for (
auto const& modelState : pomdp.getStateLabeling().
getStates(labelName)) {
100 for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++memState) {
101 newStates.set(getUnfoldingState(modelState, memState));
105 labeling.addLabel(labelName, std::move(newStates));
// Optionally add one label "memstate_<i>" per memory state i, holding in every unfolded state
// whose memory component is i.
107 if (addMemoryLabels) {
108 for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++memState) {
110 for (uint64_t modelState = 0; modelState < pomdp.getNumberOfStates(); ++modelState) {
111 newStates.set(getUnfoldingState(modelState, memState));
113 labeling.addLabel(
"memstate_" + std::to_string(memState), newStates);
// Computes the observations of the unfolded POMDP: one entry per unfolded state selected by
// reachableStates, in unfolded-state order (model state outer, memory state inner). The unfolded
// observation combines the model observation with the memory state (getUnfoldingObersvation).
// The occurring observation ids are then remapped through oldToNewObservationMapping.
// NOTE(review): the initialization/increment of newObs and the return statement are not visible in
// this excerpt; presumably the ids are renumbered consecutively — confirm.
119template<
typename ValueType>
120std::vector<uint32_t> PomdpMemoryUnfolder<ValueType>::transformObservabilityClasses(
storm::storage::BitVector const& reachableStates)
const {
121 std::vector<uint32_t> observations;
122 observations.reserve(pomdp.getNumberOfStates() * memory.getNumberOfStates());
123 for (uint64_t modelState = 0; modelState < pomdp.getNumberOfStates(); ++modelState) {
124 for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++memState) {
125 if (reachableStates.
get(getUnfoldingState(modelState, memState))) {
126 observations.push_back(getUnfoldingObersvation(pomdp.getObservation(modelState), memState));
// Distinct observations in ascending order.
132 std::set<uint32_t> occuringObservations(observations.begin(), observations.end());
// NOTE(review): dereferencing rbegin() requires at least one reachable state; with an empty
// reachableStates this is undefined behavior. Also, highestObservation + 1 below wraps to 0 if
// highestObservation == std::numeric_limits<uint32_t>::max().
133 uint32_t highestObservation = *occuringObservations.rbegin();
// Observations that do not occur keep the sentinel value std::numeric_limits<uint32_t>::max().
134 std::vector<uint32_t> oldToNewObservationMapping(highestObservation + 1, std::numeric_limits<uint32_t>::max());
136 for (
auto const& oldObs : occuringObservations) {
137 oldToNewObservationMapping[oldObs] = newObs;
// Rewrite every collected observation to its new id.
140 for (
auto& obs : observations) {
141 obs = oldToNewObservationMapping[obs];
// Builds the reward model of the unfolded POMDP for one input reward model (signature not visible
// in this excerpt). Entries are produced only for unfolded states selected by reachableStates.
147template<
typename ValueType>
150 std::optional<std::vector<ValueType>> stateRewards, actionRewards;
// NOTE(review): the condition guarding this branch and the push of each state reward are not
// visible here.
152 stateRewards = std::vector<ValueType>();
153 stateRewards->reserve(pomdp.getNumberOfStates() * memory.getNumberOfStates());
154 for (uint64_t modelState = 0; modelState < pomdp.getNumberOfStates(); ++modelState) {
155 for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++memState) {
156 if (reachableStates.
get(getUnfoldingState(modelState, memState))) {
// State-action rewards: the reward of every original choice of modelState is repeated once per
// outgoing memory transition of memState, matching how choices are replicated in the transition
// matrix.
163 actionRewards = std::vector<ValueType>();
164 for (uint64_t modelState = 0; modelState < pomdp.getNumberOfStates(); ++modelState) {
165 for (uint64_t memState = 0; memState < memory.getNumberOfStates(); ++memState) {
166 if (reachableStates.
get(getUnfoldingState(modelState, memState))) {
// Iterate over the rows (choices) of modelState's row group in the original matrix.
167 for (uint64_t origRow = pomdp.getTransitionMatrix().getRowGroupIndices()[modelState];
168 origRow < pomdp.getTransitionMatrix().getRowGroupIndices()[modelState + 1]; ++origRow) {
// NOTE(review): the definition of actionReward (presumably the original reward of origRow) is
// not visible in this excerpt.
170 actionRewards->insert(actionRewards->end(), memory.getNumberOfOutgoingTransitions(memState), actionReward);
180template<
typename ValueType>
181uint64_t PomdpMemoryUnfolder<ValueType>::getUnfoldingState(uint64_t modelState, uint64_t memoryState)
const {
182 return modelState * memory.getNumberOfStates() + memoryState;
185template<
typename ValueType>
186uint64_t PomdpMemoryUnfolder<ValueType>::getModelState(uint64_t unfoldingState)
const {
187 return unfoldingState / memory.getNumberOfStates();
190template<
typename ValueType>
191uint64_t PomdpMemoryUnfolder<ValueType>::getMemoryState(uint64_t unfoldingState)
const {
192 return unfoldingState % memory.getNumberOfStates();
195template<
typename ValueType>
196uint32_t PomdpMemoryUnfolder<ValueType>::getUnfoldingObersvation(uint32_t modelObservation, uint64_t memoryState)
const {
197 return modelObservation * memory.getNumberOfStates() + memoryState;
200template<
typename ValueType>
201uint32_t PomdpMemoryUnfolder<ValueType>::getModelObersvation(uint32_t unfoldingObservation)
const {
202 return unfoldingObservation / memory.getNumberOfStates();
205template<
typename ValueType>
206uint64_t PomdpMemoryUnfolder<ValueType>::getMemoryStateFromObservation(uint32_t unfoldingObservation)
const {
207 return unfoldingObservation % memory.getNumberOfStates();
210template class PomdpMemoryUnfolder<double>;
211template class PomdpMemoryUnfolder<storm::RationalNumber>;
212template class PomdpMemoryUnfolder<storm::RationalFunction>;
This class represents a partially observable Markov decision process.
bool hasTransitionRewards() const
Retrieves whether the reward model has transition rewards.
ValueType const & getStateReward(uint_fast64_t state) const
ValueType const & getStateActionReward(uint_fast64_t choiceIndex) const
Retrieves the state-action reward for the given choice.
bool hasStateRewards() const
Retrieves whether the reward model has state rewards.
bool hasStateActionRewards() const
Retrieves whether the reward model has state-action rewards.
This class manages the labeling of the state space with a number of (atomic) labels.
storm::storage::BitVector const & getStates(std::string const &label) const
Returns the labeling of states associated with the given label.
StateLabeling getSubLabeling(storm::storage::BitVector const &states) const
Retrieves the sub labeling that represents the same labeling as the current one for all selected stat...
A bit vector that is internally represented as a vector of 64-bit values.
bool get(uint64_t index) const
Retrieves the truth value of the bit at the given index and performs a bound check.
index_type getNumberOfEntries() const
Retrieves the number of entries in the rows.
A class that can be used to build a sparse matrix by adding value by value.
A class that holds a possibly non-square matrix in the compressed row storage format.
const_rows getRowGroup(index_type rowGroup) const
Returns an object representing the given row group.
std::vector< index_type > const & getRowGroupIndices() const
Returns the grouping of rows of this matrix.
index_type getRowGroupSize(index_type group) const
Returns the size of the given row group.
#define STORM_LOG_THROW(cond, exception, message)
SFTBDDChecker::ValueType ValueType
storm::storage::BitVector getStates(storm::logic::Formula const &propositionalFormula, bool formulaInverted, PomdpType const &pomdp)
storm::storage::BitVector getReachableStates(storm::storage::SparseMatrix< T > const &transitionMatrix, storm::storage::BitVector const &initialStates, storm::storage::BitVector const &constraintStates, storm::storage::BitVector const &targetStates, bool useStepBound, uint_fast64_t maximalSteps, boost::optional< storm::storage::BitVector > const &choiceFilter)
Performs a forward depth-first search through the underlying graph structure to identify the states t...
std::optional< storm::storage::sparse::StateValuations > stateValuations
std::unordered_map< std::string, RewardModelType > rewardModels
storm::storage::SparseMatrix< ValueType > transitionMatrix
storm::models::sparse::StateLabeling stateLabeling
std::optional< std::vector< uint32_t > > observabilityClasses