Storm 1.11.1.1
A Modern Probabilistic Model Checker
Loading...
Searching...
No Matches
PartialQuotientExtractor.cpp
Go to the documentation of this file.
2
12
13namespace storm {
14namespace dd {
15namespace bisimulation {
16
17template<storm::dd::DdType DdType, typename ValueType, typename ExportValueType>
19 storm::dd::bisimulation::QuotientFormat const& quotientFormat)
20 : model(model), quotientFormat(quotientFormat) {
21 if (this->quotientFormat != storm::dd::bisimulation::QuotientFormat::Dd) {
22 STORM_LOG_ERROR("Only DD-based partial quotient extraction is currently supported. Switching to DD-based extraction.");
24 }
25}
26
27template<storm::dd::DdType DdType, typename ValueType, typename ExportValueType>
28std::shared_ptr<storm::models::Model<ExportValueType>> PartialQuotientExtractor<DdType, ValueType, ExportValueType>::extract(
29 Partition<DdType, ValueType> const& partition, PreservationInformation<DdType, ValueType> const& preservationInformation) {
30 auto start = std::chrono::high_resolution_clock::now();
31 std::shared_ptr<storm::models::Model<ExportValueType>> result;
32
33 STORM_LOG_THROW(this->quotientFormat == storm::dd::bisimulation::QuotientFormat::Dd, storm::exceptions::NotSupportedException,
34 "Only DD-based partial quotient extraction is currently supported.");
35 result = extractDdQuotient(partition, preservationInformation);
36 auto end = std::chrono::high_resolution_clock::now();
37 STORM_LOG_TRACE("Quotient extraction completed in " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms.");
38
39 STORM_LOG_THROW(result, storm::exceptions::NotSupportedException, "Quotient could not be extracted.");
40
41 return result;
42}
43
44template<storm::dd::DdType DdType, typename ValueType, typename ExportValueType>
45std::shared_ptr<storm::models::symbolic::Model<DdType, ExportValueType>> PartialQuotientExtractor<DdType, ValueType, ExportValueType>::extractDdQuotient(
46 Partition<DdType, ValueType> const& partition, PreservationInformation<DdType, ValueType> const& preservationInformation) {
47 auto modelType = model.getType();
48 if (modelType == storm::models::ModelType::Dtmc || modelType == storm::models::ModelType::Mdp) {
49 // Sanity checks.
50 STORM_LOG_ASSERT(partition.getNumberOfStates() == model.getNumberOfStates(), "Mismatching partition size.");
51 STORM_LOG_ASSERT(partition.getStates().renameVariables(model.getColumnVariables(), model.getRowVariables()) == model.getReachableStates(),
52 "Mismatching partition.");
53
54 std::set<storm::expressions::Variable> blockVariableSet = {partition.getBlockVariable()};
55 std::set<storm::expressions::Variable> blockPrimeVariableSet = {partition.getPrimedBlockVariable()};
56 std::vector<std::pair<storm::expressions::Variable, storm::expressions::Variable>> blockMetaVariablePairs = {
57 std::make_pair(partition.getBlockVariable(), partition.getPrimedBlockVariable())};
58
59 storm::dd::Bdd<DdType> partitionAsBdd = partition.storedAsBdd() ? partition.asBdd() : partition.asAdd().notZero();
60
61 auto start = std::chrono::high_resolution_clock::now();
62 partitionAsBdd = partitionAsBdd.renameVariables(model.getColumnVariables(), model.getRowVariables());
63 storm::dd::Bdd<DdType> reachableStates = partitionAsBdd.existsAbstract(model.getRowVariables());
64 storm::dd::Bdd<DdType> initialStates = (model.getInitialStates() && partitionAsBdd).existsAbstract(model.getRowVariables());
65
66 std::map<std::string, storm::dd::Bdd<DdType>> preservedLabelBdds;
67 for (auto const& label : preservationInformation.getLabels()) {
68 preservedLabelBdds.emplace(label, (model.getStates(label) && partitionAsBdd).existsAbstract(model.getRowVariables()));
69 }
70 for (auto const& expression : preservationInformation.getExpressions()) {
71 std::stringstream stream;
72 stream << expression;
73 std::string expressionAsString = stream.str();
74
75 auto it = preservedLabelBdds.find(expressionAsString);
76 if (it != preservedLabelBdds.end()) {
77 STORM_LOG_WARN("Duplicate label '" << expressionAsString << "', dropping second label definition.");
78 } else {
79 preservedLabelBdds.emplace(stream.str(), (model.getStates(expression) && partitionAsBdd).existsAbstract(model.getRowVariables()));
80 }
81 }
82 auto end = std::chrono::high_resolution_clock::now();
83 STORM_LOG_TRACE("Quotient labels extracted in " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms.");
84
85 start = std::chrono::high_resolution_clock::now();
86 std::set<storm::expressions::Variable> blockAndRowVariables;
87 std::set_union(blockVariableSet.begin(), blockVariableSet.end(), model.getRowVariables().begin(), model.getRowVariables().end(),
88 std::inserter(blockAndRowVariables, blockAndRowVariables.end()));
89 std::set<storm::expressions::Variable> blockPrimeAndColumnVariables;
90 std::set_union(blockPrimeVariableSet.begin(), blockPrimeVariableSet.end(), model.getColumnVariables().begin(), model.getColumnVariables().end(),
91 std::inserter(blockPrimeAndColumnVariables, blockPrimeAndColumnVariables.end()));
92 storm::dd::Add<DdType, ValueType> partitionAsAdd = partitionAsBdd.template toAdd<ValueType>();
93 storm::dd::Add<DdType, ValueType> quotientTransitionMatrix = model.getTransitionMatrix().multiplyMatrix(
94 partitionAsAdd.renameVariables(blockAndRowVariables, blockPrimeAndColumnVariables), model.getColumnVariables());
95
96 quotientTransitionMatrix = quotientTransitionMatrix * partitionAsAdd;
97 end = std::chrono::high_resolution_clock::now();
98
99 // Check quotient matrix for sanity.
100 if (std::is_same<ValueType, storm::RationalNumber>::value) {
101 STORM_LOG_ASSERT(quotientTransitionMatrix.greater(storm::utility::one<ValueType>()).isZero(), "Illegal entries in quotient matrix.");
102 } else {
103 STORM_LOG_ASSERT(quotientTransitionMatrix.greater(storm::utility::one<ValueType>() + storm::utility::convertNumber<ValueType>(1e-6)).isZero(),
104 "Illegal entries in quotient matrix.");
105 }
106
107 STORM_LOG_TRACE("Quotient transition matrix extracted in " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms.");
108
109 storm::dd::Bdd<DdType> quotientTransitionMatrixBdd = quotientTransitionMatrix.notZero();
110 std::set<storm::expressions::Variable> nonSourceVariables;
111 std::set_union(blockPrimeVariableSet.begin(), blockPrimeVariableSet.end(), model.getRowVariables().begin(), model.getRowVariables().end(),
112 std::inserter(nonSourceVariables, nonSourceVariables.begin()));
113 storm::dd::Bdd<DdType> deadlockStates = !quotientTransitionMatrixBdd.existsAbstract(nonSourceVariables) && reachableStates;
114
115 start = std::chrono::high_resolution_clock::now();
116 std::unordered_map<std::string, storm::models::symbolic::StandardRewardModel<DdType, ValueType>> quotientRewardModels;
117 for (auto const& rewardModelName : preservationInformation.getRewardModelNames()) {
118 auto const& rewardModel = model.getRewardModel(rewardModelName);
119
120 boost::optional<storm::dd::Add<DdType, ValueType>> quotientStateRewards;
121 if (rewardModel.hasStateRewards()) {
122 quotientStateRewards = rewardModel.getStateRewardVector() * partitionAsAdd;
123 }
124
125 boost::optional<storm::dd::Add<DdType, ValueType>> quotientStateActionRewards;
126 if (rewardModel.hasStateActionRewards()) {
127 quotientStateActionRewards = rewardModel.getStateActionRewardVector() * partitionAsAdd;
128 }
129
130 quotientRewardModels.emplace(rewardModelName, storm::models::symbolic::StandardRewardModel<DdType, ValueType>(
131 quotientStateRewards, quotientStateActionRewards, boost::none));
132 }
133 end = std::chrono::high_resolution_clock::now();
134 STORM_LOG_TRACE("Reward models extracted in " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms.");
135
136 std::shared_ptr<storm::models::symbolic::Model<DdType, ValueType>> result;
137 if (modelType == storm::models::ModelType::Dtmc) {
138 result = std::make_shared<storm::models::symbolic::Mdp<DdType, ValueType>>(
139 model.getManager().asSharedPointer(), reachableStates, initialStates, deadlockStates, quotientTransitionMatrix, blockVariableSet,
140 blockPrimeVariableSet, blockMetaVariablePairs, model.getRowVariables(), preservedLabelBdds, quotientRewardModels);
141 } else if (modelType == storm::models::ModelType::Mdp) {
142 std::set<storm::expressions::Variable> allNondeterminismVariables;
143 std::set_union(model.getRowVariables().begin(), model.getRowVariables().end(), model.getNondeterminismVariables().begin(),
144 model.getNondeterminismVariables().end(), std::inserter(allNondeterminismVariables, allNondeterminismVariables.begin()));
145
146 result = std::make_shared<storm::models::symbolic::StochasticTwoPlayerGame<DdType, ValueType>>(
147 model.getManager().asSharedPointer(), reachableStates, initialStates, deadlockStates, quotientTransitionMatrix, blockVariableSet,
148 blockPrimeVariableSet, blockMetaVariablePairs, model.getRowVariables(), model.getNondeterminismVariables(), allNondeterminismVariables,
149 preservedLabelBdds, quotientRewardModels);
150 } else {
151 STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Unsupported quotient type.");
152 }
153
154 return result->template toValueType<ExportValueType>();
155 } else {
156 STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Cannot extract partial quotient for this model type.");
157 }
158}
159
160template class PartialQuotientExtractor<storm::dd::DdType::CUDD, double>;
161template class PartialQuotientExtractor<storm::dd::DdType::Sylvan, double>;
162
163template class PartialQuotientExtractor<storm::dd::DdType::Sylvan, storm::RationalNumber>;
164template class PartialQuotientExtractor<storm::dd::DdType::Sylvan, storm::RationalNumber, double>;
165template class PartialQuotientExtractor<storm::dd::DdType::Sylvan, storm::RationalFunction>;
166
167} // namespace bisimulation
168} // namespace dd
169} // namespace storm
Bdd< LibraryType > greater(Add< LibraryType, ValueType > const &other) const
Retrieves the function that maps all evaluations to one whose function value in the first ADD are gre...
Definition Add.cpp:109
Add< LibraryType, ValueType > renameVariables(std::set< storm::expressions::Variable > const &from, std::set< storm::expressions::Variable > const &to) const
Renames the given meta variables in the ADD.
Definition Add.cpp:209
Add< LibraryType, ValueType > multiplyMatrix(Add< LibraryType, ValueType > const &otherMatrix, std::set< storm::expressions::Variable > const &summationMetaVariables) const
Multiplies the current ADD (representing a matrix) with the given matrix by summing over the given me...
Definition Add.cpp:365
Bdd< LibraryType > notZero() const
Computes a BDD that represents the function in which all assignments with a function value unequal to...
Definition Add.cpp:424
Bdd< LibraryType > existsAbstract(std::set< storm::expressions::Variable > const &metaVariables) const
Existentially abstracts from the given meta variables.
Definition Bdd.cpp:172
Bdd< LibraryType > renameVariables(std::set< storm::expressions::Variable > const &from, std::set< storm::expressions::Variable > const &to) const
Renames the given meta variables in the BDD.
Definition Bdd.cpp:341
PartialQuotientExtractor(storm::models::symbolic::Model< DdType, ValueType > const &model, storm::dd::bisimulation::QuotientFormat const &quotientFormat)
std::shared_ptr< storm::models::Model< ExportValueType > > extract(Partition< DdType, ValueType > const &partition, PreservationInformation< DdType, ValueType > const &preservationInformation)
storm::expressions::Variable const & getBlockVariable() const
storm::dd::Bdd< DdType > const & asBdd() const
storm::expressions::Variable const & getPrimedBlockVariable() const
storm::dd::Bdd< DdType > getStates() const
storm::dd::Add< DdType, ValueType > const & asAdd() const
Base class for all symbolic models.
Definition Model.h:42
#define STORM_LOG_WARN(message)
Definition logging.h:25
#define STORM_LOG_TRACE(message)
Definition logging.h:12
#define STORM_LOG_ERROR(message)
Definition logging.h:26
#define STORM_LOG_ASSERT(cond, message)
Definition macros.h:11
#define STORM_LOG_THROW(cond, exception, message)
Definition macros.h:30
std::pair< storm::RationalNumber, storm::RationalNumber > count(std::vector< storm::storage::BitVector > const &origSets, std::vector< storm::storage::BitVector > const &intersects, std::vector< storm::storage::BitVector > const &intersectsInfo, storm::RationalNumber val, bool plus, uint64_t remdepth)