Storm
A Modern Probabilistic Model Checker
Loading...
Searching...
No Matches
PartialQuotientExtractor.cpp
Go to the documentation of this file.
2
4
8
11
14
16
17namespace storm {
18namespace dd {
19namespace bisimulation {
20
// Constructs a partial-quotient extractor for the given symbolic model and requested quotient format.
// Only the DD-based format is implemented; any other requested format is reported via STORM_LOG_ERROR.
// NOTE(review): original source lines 22 (start of the signature) and 27 are absent from this listing. The error message
// announces a switch to DD-based extraction and extract() throws unless quotientFormat is Dd, so line 27 presumably
// reassigns this->quotientFormat to QuotientFormat::Dd — confirm against the real source.
21template<storm::dd::DdType DdType, typename ValueType, typename ExportValueType>
23 storm::dd::bisimulation::QuotientFormat const& quotientFormat)
24 : model(model), quotientFormat(quotientFormat) {
25 if (this->quotientFormat != storm::dd::bisimulation::QuotientFormat::Dd) {
26 STORM_LOG_ERROR("Only DD-based partial quotient extraction is currently supported. Switching to DD-based extraction.");
28 }
29}
30
31template<storm::dd::DdType DdType, typename ValueType, typename ExportValueType>
32std::shared_ptr<storm::models::Model<ExportValueType>> PartialQuotientExtractor<DdType, ValueType, ExportValueType>::extract(
33 Partition<DdType, ValueType> const& partition, PreservationInformation<DdType, ValueType> const& preservationInformation) {
34 auto start = std::chrono::high_resolution_clock::now();
35 std::shared_ptr<storm::models::Model<ExportValueType>> result;
36
37 STORM_LOG_THROW(this->quotientFormat == storm::dd::bisimulation::QuotientFormat::Dd, storm::exceptions::NotSupportedException,
38 "Only DD-based partial quotient extraction is currently supported.");
39 result = extractDdQuotient(partition, preservationInformation);
40 auto end = std::chrono::high_resolution_clock::now();
41 STORM_LOG_TRACE("Quotient extraction completed in " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms.");
42
43 STORM_LOG_THROW(result, storm::exceptions::NotSupportedException, "Quotient could not be extracted.");
44
45 return result;
46}
47
48template<storm::dd::DdType DdType, typename ValueType, typename ExportValueType>
49std::shared_ptr<storm::models::symbolic::Model<DdType, ExportValueType>> PartialQuotientExtractor<DdType, ValueType, ExportValueType>::extractDdQuotient(
50 Partition<DdType, ValueType> const& partition, PreservationInformation<DdType, ValueType> const& preservationInformation) {
51 auto modelType = model.getType();
52 if (modelType == storm::models::ModelType::Dtmc || modelType == storm::models::ModelType::Mdp) {
53 // Sanity checks.
54 STORM_LOG_ASSERT(partition.getNumberOfStates() == model.getNumberOfStates(), "Mismatching partition size.");
55 STORM_LOG_ASSERT(partition.getStates().renameVariables(model.getColumnVariables(), model.getRowVariables()) == model.getReachableStates(),
56 "Mismatching partition.");
57
58 std::set<storm::expressions::Variable> blockVariableSet = {partition.getBlockVariable()};
59 std::set<storm::expressions::Variable> blockPrimeVariableSet = {partition.getPrimedBlockVariable()};
60 std::vector<std::pair<storm::expressions::Variable, storm::expressions::Variable>> blockMetaVariablePairs = {
61 std::make_pair(partition.getBlockVariable(), partition.getPrimedBlockVariable())};
62
63 storm::dd::Bdd<DdType> partitionAsBdd = partition.storedAsBdd() ? partition.asBdd() : partition.asAdd().notZero();
64
65 auto start = std::chrono::high_resolution_clock::now();
66 partitionAsBdd = partitionAsBdd.renameVariables(model.getColumnVariables(), model.getRowVariables());
67 storm::dd::Bdd<DdType> reachableStates = partitionAsBdd.existsAbstract(model.getRowVariables());
68 storm::dd::Bdd<DdType> initialStates = (model.getInitialStates() && partitionAsBdd).existsAbstract(model.getRowVariables());
69
70 std::map<std::string, storm::dd::Bdd<DdType>> preservedLabelBdds;
71 for (auto const& label : preservationInformation.getLabels()) {
72 preservedLabelBdds.emplace(label, (model.getStates(label) && partitionAsBdd).existsAbstract(model.getRowVariables()));
73 }
74 for (auto const& expression : preservationInformation.getExpressions()) {
75 std::stringstream stream;
76 stream << expression;
77 std::string expressionAsString = stream.str();
78
79 auto it = preservedLabelBdds.find(expressionAsString);
80 if (it != preservedLabelBdds.end()) {
81 STORM_LOG_WARN("Duplicate label '" << expressionAsString << "', dropping second label definition.");
82 } else {
83 preservedLabelBdds.emplace(stream.str(), (model.getStates(expression) && partitionAsBdd).existsAbstract(model.getRowVariables()));
84 }
85 }
86 auto end = std::chrono::high_resolution_clock::now();
87 STORM_LOG_TRACE("Quotient labels extracted in " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms.");
88
89 start = std::chrono::high_resolution_clock::now();
90 std::set<storm::expressions::Variable> blockAndRowVariables;
91 std::set_union(blockVariableSet.begin(), blockVariableSet.end(), model.getRowVariables().begin(), model.getRowVariables().end(),
92 std::inserter(blockAndRowVariables, blockAndRowVariables.end()));
93 std::set<storm::expressions::Variable> blockPrimeAndColumnVariables;
94 std::set_union(blockPrimeVariableSet.begin(), blockPrimeVariableSet.end(), model.getColumnVariables().begin(), model.getColumnVariables().end(),
95 std::inserter(blockPrimeAndColumnVariables, blockPrimeAndColumnVariables.end()));
96 storm::dd::Add<DdType, ValueType> partitionAsAdd = partitionAsBdd.template toAdd<ValueType>();
97 storm::dd::Add<DdType, ValueType> quotientTransitionMatrix = model.getTransitionMatrix().multiplyMatrix(
98 partitionAsAdd.renameVariables(blockAndRowVariables, blockPrimeAndColumnVariables), model.getColumnVariables());
99
100 quotientTransitionMatrix = quotientTransitionMatrix * partitionAsAdd;
101 end = std::chrono::high_resolution_clock::now();
102
103 // Check quotient matrix for sanity.
104 if (std::is_same<ValueType, storm::RationalNumber>::value) {
105 STORM_LOG_ASSERT(quotientTransitionMatrix.greater(storm::utility::one<ValueType>()).isZero(), "Illegal entries in quotient matrix.");
106 } else {
107 STORM_LOG_ASSERT(quotientTransitionMatrix.greater(storm::utility::one<ValueType>() + storm::utility::convertNumber<ValueType>(1e-6)).isZero(),
108 "Illegal entries in quotient matrix.");
109 }
110
111 STORM_LOG_TRACE("Quotient transition matrix extracted in " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms.");
112
113 storm::dd::Bdd<DdType> quotientTransitionMatrixBdd = quotientTransitionMatrix.notZero();
114 std::set<storm::expressions::Variable> nonSourceVariables;
115 std::set_union(blockPrimeVariableSet.begin(), blockPrimeVariableSet.end(), model.getRowVariables().begin(), model.getRowVariables().end(),
116 std::inserter(nonSourceVariables, nonSourceVariables.begin()));
117 storm::dd::Bdd<DdType> deadlockStates = !quotientTransitionMatrixBdd.existsAbstract(nonSourceVariables) && reachableStates;
118
119 start = std::chrono::high_resolution_clock::now();
120 std::unordered_map<std::string, storm::models::symbolic::StandardRewardModel<DdType, ValueType>> quotientRewardModels;
121 for (auto const& rewardModelName : preservationInformation.getRewardModelNames()) {
122 auto const& rewardModel = model.getRewardModel(rewardModelName);
123
124 boost::optional<storm::dd::Add<DdType, ValueType>> quotientStateRewards;
125 if (rewardModel.hasStateRewards()) {
126 quotientStateRewards = rewardModel.getStateRewardVector() * partitionAsAdd;
127 }
128
129 boost::optional<storm::dd::Add<DdType, ValueType>> quotientStateActionRewards;
130 if (rewardModel.hasStateActionRewards()) {
131 quotientStateActionRewards = rewardModel.getStateActionRewardVector() * partitionAsAdd;
132 }
133
134 quotientRewardModels.emplace(rewardModelName, storm::models::symbolic::StandardRewardModel<DdType, ValueType>(
135 quotientStateRewards, quotientStateActionRewards, boost::none));
136 }
137 end = std::chrono::high_resolution_clock::now();
138 STORM_LOG_TRACE("Reward models extracted in " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms.");
139
140 std::shared_ptr<storm::models::symbolic::Model<DdType, ValueType>> result;
141 if (modelType == storm::models::ModelType::Dtmc) {
142 result = std::make_shared<storm::models::symbolic::Mdp<DdType, ValueType>>(
143 model.getManager().asSharedPointer(), reachableStates, initialStates, deadlockStates, quotientTransitionMatrix, blockVariableSet,
144 blockPrimeVariableSet, blockMetaVariablePairs, model.getRowVariables(), preservedLabelBdds, quotientRewardModels);
145 } else if (modelType == storm::models::ModelType::Mdp) {
146 std::set<storm::expressions::Variable> allNondeterminismVariables;
147 std::set_union(model.getRowVariables().begin(), model.getRowVariables().end(), model.getNondeterminismVariables().begin(),
148 model.getNondeterminismVariables().end(), std::inserter(allNondeterminismVariables, allNondeterminismVariables.begin()));
149
150 result = std::make_shared<storm::models::symbolic::StochasticTwoPlayerGame<DdType, ValueType>>(
151 model.getManager().asSharedPointer(), reachableStates, initialStates, deadlockStates, quotientTransitionMatrix, blockVariableSet,
152 blockPrimeVariableSet, blockMetaVariablePairs, model.getRowVariables(), model.getNondeterminismVariables(), allNondeterminismVariables,
153 preservedLabelBdds, quotientRewardModels);
154 } else {
155 STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Unsupported quotient type.");
156 }
157
158 return result->template toValueType<ExportValueType>();
159 } else {
160 STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Cannot extract partial quotient for this model type.");
161 }
162}
163
// Explicit instantiations of the extractor for the supported DD backends and value types. Where the third template
// argument is omitted, the export value type takes its default (NOTE(review): presumably ValueType — check the header).
164template class PartialQuotientExtractor<storm::dd::DdType::CUDD, double>;
165template class PartialQuotientExtractor<storm::dd::DdType::Sylvan, double>;
166
167#ifdef STORM_HAVE_CARL
// Exact (RationalNumber) and parametric (RationalFunction) variants are compiled only when STORM_HAVE_CARL is defined, and
// only for the Sylvan backend. The <RationalNumber, double> variant builds the quotient exactly and exports a double-valued model.
168template class PartialQuotientExtractor<storm::dd::DdType::Sylvan, storm::RationalNumber>;
169template class PartialQuotientExtractor<storm::dd::DdType::Sylvan, storm::RationalNumber, double>;
170template class PartialQuotientExtractor<storm::dd::DdType::Sylvan, storm::RationalFunction>;
171#endif
172
173} // namespace bisimulation
174} // namespace dd
175} // namespace storm
Bdd< LibraryType > greater(Add< LibraryType, ValueType > const &other) const
Retrieves the function that maps all evaluations to one whose function values in the first ADD are greater than those in the given ADD.
Definition Add.cpp:116
Add< LibraryType, ValueType > renameVariables(std::set< storm::expressions::Variable > const &from, std::set< storm::expressions::Variable > const &to) const
Renames the given meta variables in the ADD.
Definition Add.cpp:216
Add< LibraryType, ValueType > multiplyMatrix(Add< LibraryType, ValueType > const &otherMatrix, std::set< storm::expressions::Variable > const &summationMetaVariables) const
Multiplies the current ADD (representing a matrix) with the given matrix by summing over the given meta variables.
Definition Add.cpp:372
Bdd< LibraryType > notZero() const
Computes a BDD that represents the function in which all assignments with a function value unequal to zero are mapped to one (true) and all other assignments to zero (false).
Definition Add.cpp:431
Bdd< LibraryType > existsAbstract(std::set< storm::expressions::Variable > const &metaVariables) const
Existentially abstracts from the given meta variables.
Definition Bdd.cpp:178
Bdd< LibraryType > renameVariables(std::set< storm::expressions::Variable > const &from, std::set< storm::expressions::Variable > const &to) const
Renames the given meta variables in the BDD.
Definition Bdd.cpp:347
PartialQuotientExtractor(storm::models::symbolic::Model< DdType, ValueType > const &model, storm::dd::bisimulation::QuotientFormat const &quotientFormat)
std::shared_ptr< storm::models::Model< ExportValueType > > extract(Partition< DdType, ValueType > const &partition, PreservationInformation< DdType, ValueType > const &preservationInformation)
storm::expressions::Variable const & getBlockVariable() const
storm::dd::Bdd< DdType > const & asBdd() const
storm::expressions::Variable const & getPrimedBlockVariable() const
storm::dd::Bdd< DdType > getStates() const
storm::dd::Add< DdType, ValueType > const & asAdd() const
Base class for all symbolic models.
Definition Model.h:46
#define STORM_LOG_WARN(message)
Definition logging.h:30
#define STORM_LOG_TRACE(message)
Definition logging.h:17
#define STORM_LOG_ERROR(message)
Definition logging.h:31
#define STORM_LOG_ASSERT(cond, message)
Definition macros.h:11
#define STORM_LOG_THROW(cond, exception, message)
Definition macros.h:30
std::pair< storm::RationalNumber, storm::RationalNumber > count(std::vector< storm::storage::BitVector > const &origSets, std::vector< storm::storage::BitVector > const &intersects, std::vector< storm::storage::BitVector > const &intersectsInfo, storm::RationalNumber val, bool plus, uint64_t remdepth)
LabParser.cpp.
Definition cli.cpp:18