Storm 1.11.1.1
A Modern Probabilistic Model Checker
Loading...
Searching...
No Matches
GradientDescentInstantiationSearcherTest.cpp
Go to the documentation of this file.
1#include "storm-config.h"
2#include "test/storm_gtest.h"
3
10#include "storm/api/builder.h"
11#include "storm/api/storm.h"
16
17using namespace storm::pars;
18
19namespace {
20class RationalGmmxxEnvironment {
21 public:
22 typedef storm::RationalFunction FunctionType;
23 typedef storm::RationalNumber ConstantType;
24 static storm::Environment createEnvironment() {
26 env.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Gmmxx);
27 return env;
28 }
29};
30class DoubleGmmxxEnvironment {
31 public:
32 typedef storm::RationalFunction FunctionType;
33 typedef double ConstantType;
34 static storm::Environment createEnvironment() {
36 env.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Gmmxx);
37 return env;
38 }
39};
40class RationalEigenEnvironment {
41 public:
42 typedef storm::RationalFunction FunctionType;
43 typedef storm::RationalNumber ConstantType;
44 static storm::Environment createEnvironment() {
46 env.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Eigen);
47 return env;
48 }
49};
50class DoubleEigenEnvironment {
51 public:
52 typedef storm::RationalFunction FunctionType;
53 typedef double ConstantType;
54 static storm::Environment createEnvironment() {
56 env.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Eigen);
57 return env;
58 }
59};
60class DoubleEigenTopologicalEnvironment {
61 public:
62 typedef storm::RationalFunction FunctionType;
63 typedef double ConstantType;
64 static storm::Environment createEnvironment() {
66 env.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Topological);
67 env.solver().topological().setUnderlyingEquationSolverType(storm::solver::EquationSolverType::Eigen);
68 return env;
69 }
70};
71template<typename TestType>
72class GradientDescentInstantiationSearcherTest : public ::testing::Test {
73 public:
74 typedef typename TestType::FunctionType FunctionType;
75 typedef typename TestType::ConstantType ConstantType;
76 GradientDescentInstantiationSearcherTest() : _environment(TestType::createEnvironment()) {}
77 storm::Environment const& env() const {
78 return _environment;
79 }
80 virtual void SetUp() {
81#ifndef STORM_HAVE_Z3
82 GTEST_SKIP() << "Z3 not available.";
83#endif
84 carl::VariablePool::getInstance().clear();
85 }
86 virtual void TearDown() {
87 carl::VariablePool::getInstance().clear();
88 }
89
90 private:
91 storm::Environment _environment;
92};
93
94typedef ::testing::Types<
95// The rational environments take ages... GD is just not made for rational arithmetic.
96#ifdef STORM_HAVE_GMM
97 DoubleGmmxxEnvironment, /*RationalGmmxxEnvironment,*/
98#endif
99 DoubleEigenEnvironment, DoubleEigenTopologicalEnvironment /*, RationalEigenEnvironment*/>
101} // namespace
102
// Registers GradientDescentInstantiationSearcherTest as a typed test suite over TestingTypes.
// The trailing comma leaves the macro's last argument empty.
103TYPED_TEST_SUITE(GradientDescentInstantiationSearcherTest, TestingTypes, );
104
// Typed test "Simple": parses the PRISM program pdtmc/gradient1.pm, builds a parametric DTMC for the
// property P>=0.2499 [F s=2], simplifies it, wraps the formula and its bound in a
// FeasibilitySynthesisTask, and runs gradient descent. The value returned as the second component of
// gradientDescent() is expected to be approximately 0.25 (value * 4 is within 1e-6 of 1).
// NOTE(review): the embedded numbering of this listing skips 114, 118, 121, 123 and 125, so several
// statements appear to be missing (at least the initialiser of `formulas` and the declarations of
// `simplifier` and `checker`). Restore them from the upstream file before compiling.
105TYPED_TEST(GradientDescentInstantiationSearcherTest, Simple) {
106 std::string programFile = STORM_TEST_RESOURCES_DIR "/pdtmc/gradient1.pm";
107 std::string formulaAsString = "P>=0.2499 [F s=2]";
108 std::string constantsAsString = ""; // e.g. pL=0.9,TOACK=0.5
109
110 // Program and formula
111 storm::prism::Program program = storm::api::parseProgram(programFile);
112 program = storm::utility::prism::preprocess(program, constantsAsString);
113 std::vector<std::shared_ptr<const storm::logic::Formula>> formulas =
// NOTE(review): the right-hand side of the `formulas` initialisation is missing here.
115 std::shared_ptr<storm::models::sparse::Dtmc<storm::RationalFunction>> model =
116 storm::api::buildSparseModel<storm::RationalFunction>(program, formulas)->as<storm::models::sparse::Dtmc<storm::RationalFunction>>();
117 std::shared_ptr<storm::models::sparse::Dtmc<storm::RationalFunction>> dtmc = model->as<storm::models::sparse::Dtmc<storm::RationalFunction>>();
// NOTE(review): `simplifier` is used below but its declaration is missing from this listing.
119 ASSERT_TRUE(simplifier.simplify(*formulas[0]));
120 model = simplifier.getSimplifiedModel();
122
124
// The task is first built from a clone of the formula, then the original formula's bound is set on
// it, and finally it is moved into a shared pointer to a const task, which is what setup() receives.
126 std::shared_ptr<storm::logic::Formula> formulaWithoutBounds = formulas[0]->clone();
127 std::shared_ptr<storm::logic::Formula const> formulaNoBound = formulaWithoutBounds->asSharedPointer();
128 std::shared_ptr<FeasibilitySynthesisTask> t = std::make_shared<FeasibilitySynthesisTask>(formulaNoBound);
129 t->setBound(formulas[0]->asOperatorFormula().getBound());
130 std::shared_ptr<FeasibilitySynthesisTask const> feasibilityTask = std::make_shared<FeasibilitySynthesisTask const>(std::move(*t));
131
// NOTE(review): `checker` is used below but its declaration is missing from this listing.
132 checker.setup(this->env(), feasibilityTask);
133 typename TestFixture::ConstantType doubleInstantiation = checker.gradientDescent().second;
134 ASSERT_NEAR(doubleInstantiation * 4, 1, 1e-6);
135}
136
// Typed test "Crowds": parses the PRISM program pdtmc/crowds3_5.pm, builds a parametric DTMC for the
// property P<=0.00000001 [F "observe0Greater1"], simplifies it and applies bisimulation minimisation.
// Gradient descent is then run with four methods (ADAM, RADAM, MOMENTUM, NESTEROV). For each method
// the first 41 points of the recorded visualization walk are compared, for the parameters named
// "badC" and "PF", against hard-coded reference trajectories of 41 values each.
// NOTE(review): the embedded numbering of this listing skips 146, 150, 162, 166, 168, 277, 279, 374,
// 376, 402 and 404. The initialiser of `formulas`, the declaration of `simplifier`, and the opening
// and closing parts of the four searcher constructor calls are therefore missing. Restore them from
// the upstream file before compiling.
137TYPED_TEST(GradientDescentInstantiationSearcherTest, Crowds) {
138 std::string programFile = STORM_TEST_RESOURCES_DIR "/pdtmc/crowds3_5.pm";
139 std::string formulaAsString = "P<=0.00000001 [F \"observe0Greater1\"]";
140 std::string constantsAsString = ""; // e.g. pL=0.9,TOACK=0.5
141
142 // Program and formula
143 storm::prism::Program program = storm::api::parseProgram(programFile);
144 program = storm::utility::prism::preprocess(program, constantsAsString);
145 std::vector<std::shared_ptr<const storm::logic::Formula>> formulas =
// NOTE(review): the right-hand side of the `formulas` initialisation is missing here.
147 std::shared_ptr<storm::models::sparse::Dtmc<storm::RationalFunction>> model =
148 storm::api::buildSparseModel<storm::RationalFunction>(program, formulas)->as<storm::models::sparse::Dtmc<storm::RationalFunction>>();
149 std::shared_ptr<storm::models::sparse::Dtmc<storm::RationalFunction>> dtmc = model->as<storm::models::sparse::Dtmc<storm::RationalFunction>>();
// NOTE(review): `simplifier` is used below but its declaration is missing from this listing.
151 ASSERT_TRUE(simplifier.simplify(*formulas[0]));
152 model = simplifier.getSimplifiedModel();
153
// From here on `dtmc` refers to the bisimulation-minimised version of the simplified model.
154 dtmc = storm::api::performBisimulationMinimization<storm::RationalFunction>(model, formulas)->as<storm::models::sparse::Dtmc<storm::RationalFunction>>();
155
// Build the feasibility task from a clone of the formula and set the original formula's bound on it.
156 std::shared_ptr<storm::logic::Formula> formulaWithoutBounds = formulas[0]->clone();
157 std::shared_ptr<storm::logic::Formula const> formulaNoBound = formulaWithoutBounds->asSharedPointer();
158 std::shared_ptr<FeasibilitySynthesisTask> t = std::make_shared<FeasibilitySynthesisTask>(formulaNoBound);
159 t->setBound(formulas[0]->asOperatorFormula().getBound());
160 std::shared_ptr<FeasibilitySynthesisTask const> feasibilityTask = std::make_shared<FeasibilitySynthesisTask const>(std::move(*t));
161
163
164 // First, test an ADAM instance. We will check that we have implemented ADAM correctly by comparing our results to results gathered by an ADAM
165 // implementation in tensorflow :)
// NOTE(review): the start and end of the `adamChecker` constructor call are missing around the next line.
167 *dtmc, storm::derivative::GradientDescentMethod::ADAM, 0.01, 0.9, 0.999, 2, 1e-6, std::nullopt,
169 adamChecker.setup(this->env(), feasibilityTask);
170 auto doubleInstantiation = adamChecker.gradientDescent();
171 auto walk = adamChecker.getVisualizationWalk();
172
// Look up the model's probability parameters named "badC" and "PF"; they index the walk positions below.
173 carl::Variable badCVar;
174 carl::Variable pfVar;
175 for (auto parameter : storm::models::sparse::getProbabilityParameters(*dtmc)) {
176 if (parameter.name() == "badC") {
177 badCVar = parameter;
178 } else if (parameter.name() == "PF") {
179 pfVar = parameter;
180 }
181 }
// NOTE(review): `badC` and `pf` below are constructed but not referenced again in this listing.
182 std::shared_ptr<storm::RawPolynomialCache> cache = std::make_shared<storm::RawPolynomialCache>();
183 auto const badC = storm::RationalFunction(storm::Polynomial(storm::RawPolynomial(badCVar), cache));
184 auto const pf = storm::RationalFunction(storm::Polynomial(storm::RawPolynomial(pfVar), cache));
185
// Reference ADAM trajectory for "badC" (41 values, starting at 0.5).
186 const double badCValues[] = {0.5,
187 0.49000033736228943,
188 0.4799894690513611,
189 0.4699603021144867,
190 0.4599062204360962,
191 0.4498208165168762,
192 0.43969807028770447,
193 0.4295324981212616,
194 0.41931894421577454,
195 0.4090527594089508,
196 0.39872977137565613,
197 0.3883463442325592,
198 0.3778994083404541,
199 0.3673863708972931,
200 0.35680532455444336,
201 0.3461548984050751,
202 0.3354344069957733,
203 0.3246437907218933,
204 0.31378373503685,
205 0.30285564064979553,
206 0.2918616533279419,
207 0.28080472350120544,
208 0.2696886956691742,
209 0.258518248796463,
210 0.24729910492897034,
211 0.23603792488574982,
212 0.22474248707294464,
213 0.21342173218727112,
214 0.20208582282066345,
215 0.1907462775707245,
216 0.17941603064537048,
217 0.16810956597328186,
218 0.15684303641319275,
219 0.14563430845737457,
220 0.13450318574905396,
221 0.1234714537858963,
222 0.11256305128335953,
223 0.10180411487817764,
224 0.0912230908870697,
225 0.080850750207901,
226 0.07072017341852188};
// Reference ADAM trajectory for "PF" (41 values, starting at 0.5).
227 const double pfValues[] = {0.5,
228 0.49000218510627747,
229 0.47999337315559387,
230 0.46996763348579407,
231 0.45991936326026917,
232 0.4498434364795685,
233 0.4397353231906891,
234 0.4295912981033325,
235 0.4194081425666809,
236 0.4091835021972656,
237 0.3989158570766449,
238 0.3886045515537262,
239 0.3782498240470886,
240 0.3678528368473053,
241 0.3574157953262329,
242 0.3469419777393341,
243 0.33643582463264465,
244 0.32590293884277344,
245 0.3153502345085144,
246 0.30478596687316895,
247 0.2942197024822235,
248 0.28366249799728394,
249 0.2731269598007202,
250 0.26262718439102173,
251 0.25217896699905396,
252 0.24179960787296295,
253 0.2315080612897873,
254 0.2213248461484909,
255 0.2112719565629959,
256 0.20137275755405426,
257 0.1916518211364746,
258 0.18213464319705963,
259 0.17284737527370453,
260 0.1638164073228836,
261 0.15506798028945923,
262 0.14662761986255646,
263 0.13851958513259888,
264 0.13076619803905487,
265 0.12338729947805405,
266 0.11639954894781113,
267 0.10981585085391998};
268
// The ADAM walk is compared with a tolerance of 1e-4; the other three methods below use 1e-5.
269 for (uint_fast64_t i = 0; i < 41; i++) {
270 ASSERT_NEAR(storm::utility::convertNumber<double>(walk[i].position[badCVar]), badCValues[i], 1e-4);
271 ASSERT_NEAR(storm::utility::convertNumber<double>(walk[i].position[pfVar]), pfValues[i], 1e-4);
272 }
273
// The value returned by the ADAM run (second component of the pair) must be within 1e-6 of 0.
274 ASSERT_NEAR(storm::utility::convertNumber<double>(doubleInstantiation.second), 0, 1e-6);
275
276 // Same thing with RAdam
// NOTE(review): the start and end of the `radamChecker` constructor call are missing around the next line.
278 *dtmc, storm::derivative::GradientDescentMethod::RADAM, 0.01, 0.9, 0.999, 2, 1e-6, std::nullopt,
280 radamChecker.setup(this->env(), feasibilityTask);
// Only the walk is checked for RAdam; `radamInstantiation` is not asserted on in this listing.
281 auto radamInstantiation = radamChecker.gradientDescent();
282 auto radamWalk = radamChecker.getVisualizationWalk();
283
284 const double badCValuesRadam[] = {0.5,
285 0.49060654640197754,
286 0.48096320033073425,
287 0.47105303406715393,
288 0.4608582556247711,
289 0.46068474650382996,
290 0.4604234993457794,
291 0.460092693567276,
292 0.4597006142139435,
293 0.4592539370059967,
294 0.4587569832801819,
295 0.45821380615234375,
296 0.4576275646686554,
297 0.4570005536079407,
298 0.45633500814437866,
299 0.4556325674057007,
300 0.45489493012428284,
301 0.454123318195343,
302 0.4533190429210663,
303 0.45248323678970337,
304 0.45161670446395874,
305 0.4507202208042145,
306 0.44979462027549744,
307 0.448840469121933,
308 0.44785845279693604,
309 0.4468490481376648,
310 0.4458127021789551,
311 0.4447498917579651,
312 0.44366100430488586,
313 0.4425463378429413,
314 0.4414062798023224,
315 0.4402410686016083,
316 0.4390510618686676,
317 0.43783634901046753,
318 0.43659722805023193,
319 0.43533384799957275,
320 0.4340464472770691,
321 0.4327350854873657,
322 0.43139997124671936,
323 0.43004119396209717,
324 0.4286588430404663};
325
326 const double pfValuesRadam[] = {0.5,
327 0.4985547959804535,
328 0.4970662295818329,
329 0.4955315589904785,
330 0.4939480423927307,
331 0.4937744438648224,
332 0.4935130178928375,
333 0.4931819438934326,
334 0.49278953671455383,
335 0.492342472076416,
336 0.49184510111808777,
337 0.4913014769554138,
338 0.49071478843688965,
339 0.4900873303413391,
340 0.48942139744758606,
341 0.4887186288833618,
342 0.48798075318336487,
343 0.4872090220451355,
344 0.48640477657318115,
345 0.48556917905807495,
346 0.48470309376716614,
347 0.4838073253631592,
348 0.48288270831108093,
349 0.48192986845970154,
350 0.4809495508670807,
351 0.4799422323703766,
352 0.47890838980674744,
353 0.4778485894203186,
354 0.4767632484436035,
355 0.4756527245044708,
356 0.47451743483543396,
357 0.4733576774597168,
358 0.4721738398075104,
359 0.470966100692749,
360 0.4697347581386566,
361 0.46848005056381226,
362 0.4672022759914398,
363 0.46590155363082886,
364 0.46457815170288086,
365 0.46323221921920776,
366 0.4618639349937439};
367
368 for (uint_fast64_t i = 0; i < 41; i++) {
369 ASSERT_NEAR(storm::utility::convertNumber<double>(radamWalk[i].position[badCVar]), badCValuesRadam[i], 1e-5);
370 ASSERT_NEAR(storm::utility::convertNumber<double>(radamWalk[i].position[pfVar]), pfValuesRadam[i], 1e-5);
371 }
372
373 // Same thing with momentum
// NOTE(review): the start and end of the `momentumChecker` constructor call are missing around the next line.
// The momentum and nesterov runs pass 0.001 where the ADAM and RAdam runs pass 0.01.
375 *dtmc, storm::derivative::GradientDescentMethod::MOMENTUM, 0.001, 0.9, 0.999, 2, 1e-6, std::nullopt,
377 momentumChecker.setup(this->env(), feasibilityTask);
// Only the walk is checked for momentum; `momentumInstantiation` is not asserted on in this listing.
378 auto momentumInstantiation = momentumChecker.gradientDescent();
379 auto momentumWalk = momentumChecker.getVisualizationWalk();
380
// The momentum and nesterov reference trajectories start at 0.5 + 1e-6 instead of exactly 0.5.
381 const double badCValuesMomentum[] = {
382 0.5 + 1e-6, 0.4990617036819458, 0.4972723126411438, 0.4947088360786438, 0.4914357662200928, 0.4875074326992035, 0.48296919465065,
383 0.4778585731983185, 0.47220611572265625, 0.4660361409187317, 0.45936742424964905, 0.45221376419067383, 0.44458451867103577, 0.43648505210876465,
384 0.42791709303855896, 0.41887906193733215, 0.4093664884567261, 0.3993722200393677, 0.3888867497444153, 0.37789851427078247, 0.3663942813873291,
385 0.3543594181537628, 0.34177839756011963, 0.32863524556159973, 0.3149142265319824, 0.30060049891471863, 0.28568127751350403, 0.27014678716659546,
386 0.253991961479187, 0.23721818625926971, 0.2198355346918106, 0.201865553855896, 0.18334446847438812, 0.16432702541351318, 0.1448907107114792,
387 0.12514044344425201, 0.10521329939365387, 0.08528289198875427, 0.06556269526481628, 0.04630732536315918, 0.027810536324977875};
388 const double pfValuesMomentum[] = {
389 0.5 + 1e-6, 0.49985647201538086, 0.49958109855651855, 0.4991863965988159, 0.49868202209472656, 0.49807605147361755, 0.49737516045570374,
390 0.49658480286598206, 0.4957093894481659, 0.49475228786468506, 0.4937160611152649, 0.49260255694389343, 0.4914129376411438, 0.49014779925346375,
391 0.4888072907924652, 0.4873911142349243, 0.4858987331390381, 0.4843292832374573, 0.4826817214488983, 0.4809550344944, 0.47914814949035645,
392 0.47726020216941833, 0.47529059648513794, 0.4732392132282257, 0.4711065888404846, 0.4688941240310669, 0.4666043817996979, 0.46424129605293274,
393 0.46181055903434753, 0.45931991934776306, 0.4567795991897583, 0.4542025327682495, 0.4516047239303589, 0.4490053057670593, 0.44642651081085205,
394 0.4438934028148651, 0.44143298268318176, 0.43907299637794495, 0.43684011697769165, 0.43475744128227234, 0.43284183740615845};
395
396 for (uint_fast64_t i = 0; i < 41; i++) {
397 ASSERT_NEAR(storm::utility::convertNumber<double>(momentumWalk[i].position[badCVar]), badCValuesMomentum[i], 1e-5);
398 ASSERT_NEAR(storm::utility::convertNumber<double>(momentumWalk[i].position[pfVar]), pfValuesMomentum[i], 1e-5);
399 }
400
401 // Same thing with nesterov
// NOTE(review): the start and end of the `nesterovChecker` constructor call are missing around the next line.
403 *dtmc, storm::derivative::GradientDescentMethod::NESTEROV, 0.001, 0.9, 0.999, 2, 1e-6, std::nullopt,
405 nesterovChecker.setup(this->env(), feasibilityTask);
// Only the walk is checked for nesterov; `nesterovInstantiation` is not asserted on in this listing.
406 auto nesterovInstantiation = nesterovChecker.gradientDescent();
407 auto nesterovWalk = nesterovChecker.getVisualizationWalk();
408
409 const double badCValuesNesterov[] = {
410 0.5 + 1e-6, 0.49821633100509644, 0.49565380811691284, 0.4923747181892395, 0.48843076825141907, 0.4838651120662689, 0.4787132740020752,
411 0.473004013299942, 0.4667600393295288, 0.45999863743782043, 0.452732115983963, 0.44496846199035645, 0.43671154975891113, 0.4279615581035614,
412 0.4187155067920685, 0.4089673161506653, 0.3987082839012146, 0.387927383184433, 0.376611590385437, 0.36474621295928955, 0.35231542587280273,
413 0.33930283784866333, 0.3256921172142029, 0.3114677667617798, 0.2966162860393524, 0.2811274528503418, 0.2649959325790405, 0.24822339415550232,
414 0.23082087934017181, 0.21281197667121887, 0.19423632323741913, 0.1751537024974823, 0.15564869344234467, 0.1358354091644287, 0.11586219072341919,
415 0.09591540694236755, 0.07622162997722626, 0.05704677850008011, 0.03869114816188812, 0.021478772163391113, 0.005740571767091751};
416 const double pfValuesNesterov[] = {
417 0.5 + 1e-6, 0.49972638487815857, 0.4993318021297455, 0.4988263249397278, 0.4982176721096039, 0.4975121319293976, 0.4967147409915924,
418 0.4958295524120331, 0.4948597252368927, 0.4938075542449951, 0.49267470836639404, 0.49146217107772827, 0.49017035961151123, 0.48879921436309814,
419 0.48734840750694275, 0.48581716418266296, 0.48420456051826477, 0.4825096130371094, 0.48073118925094604, 0.4788684844970703, 0.47692081332206726,
420 0.4748881161212921, 0.4727708697319031, 0.47057047486305237, 0.4682896137237549, 0.4659323990345001, 0.4635048806667328, 0.4610152840614319,
421 0.4584745764732361, 0.45589667558670044, 0.45329880714416504, 0.45070162415504456, 0.4481291174888611, 0.44560813903808594, 0.44316741824150085,
422 0.440836101770401, 0.43864157795906067, 0.4366067945957184, 0.4347473978996277, 0.433069109916687, 0.43156614899635315};
423
424 for (uint_fast64_t i = 0; i < 41; i++) {
425 ASSERT_NEAR(storm::utility::convertNumber<double>(nesterovWalk[i].position[badCVar]), badCValuesNesterov[i], 1e-5);
426 ASSERT_NEAR(storm::utility::convertNumber<double>(nesterovWalk[i].position[pfVar]), pfValuesNesterov[i], 1e-5);
427 }
428}
TYPED_TEST_SUITE(GradientDescentInstantiationSearcherTest, TestingTypes,)
TYPED_TEST(GradientDescentInstantiationSearcherTest, Simple)
SolverEnvironment & solver()
TopologicalSolverEnvironment & topological()
void setLinearEquationSolverType(storm::solver::EquationSolverType const &value, bool isSetFromDefault=false)
void setUnderlyingEquationSolverType(storm::solver::EquationSolverType value)
void setup(Environment const &env, std::shared_ptr< storm::pars::FeasibilitySynthesisTask const > const &task)
This will setup the matrices used for computing the derivatives by constructing the SparseDerivativeI...
std::vector< VisualizationPoint > getVisualizationWalk()
Get the visualization walk that is recorded if recordRun is set to true in the constructor (false by ...
std::pair< std::map< typename utility::parametric::VariableType< FunctionType >::type, typename utility::parametric::CoefficientType< FunctionType >::type >, ConstantType > gradientDescent()
Perform Gradient Descent.
std::shared_ptr< ModelType > as()
Casts the model into the model type given by the template parameter.
Definition ModelBase.h:37
This class represents a discrete-time Markov chain.
Definition Dtmc.h:13
This class performs different steps to simplify the given (parametric) model.
std::vector< storm::jani::Property > parsePropertiesForPrismProgram(std::string const &inputString, storm::prism::Program const &program, boost::optional< std::set< std::string > > const &propertyFilter)
storm::prism::Program parseProgram(std::string const &filename, bool prismCompatibility, bool simplify)
std::vector< std::shared_ptr< storm::logic::Formula const > > extractFormulasFromProperties(std::vector< storm::jani::Property > const &properties)
std::set< storm::RationalFunctionVariable > getProbabilityParameters(Model< storm::RationalFunction > const &model)
Get all probability parameters occurring on transitions.
Definition Model.cpp:695
storm::prism::Program preprocess(storm::prism::Program const &program, std::map< storm::expressions::Variable, storm::expressions::Expression > const &constantDefinitions)
Definition prism.cpp:13
carl::RationalFunction< Polynomial, true > RationalFunction
carl::MultivariatePolynomial< RationalFunctionCoefficient > RawPolynomial
::testing::Types< Cudd, Sylvan > TestingTypes
Definition GraphTest.cpp:59