Storm
A Modern Probabilistic Model Checker
Loading...
Searching...
No Matches
GradientDescentInstantiationSearcherTest.cpp
Go to the documentation of this file.
1#include "storm-config.h"
2#include "test/storm_gtest.h"
3
5#include "storm/api/builder.h"
6#include "storm/api/storm.h"
18
23
29
30using namespace storm::pars;
31
32namespace {
33class RationalGmmxxEnvironment {
34 public:
35 typedef storm::RationalFunction FunctionType;
36 typedef storm::RationalNumber ConstantType;
37 static storm::Environment createEnvironment() {
39 env.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Gmmxx);
40 return env;
41 }
42};
43class DoubleGmmxxEnvironment {
44 public:
45 typedef storm::RationalFunction FunctionType;
46 typedef double ConstantType;
47 static storm::Environment createEnvironment() {
49 env.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Gmmxx);
50 return env;
51 }
52};
53class RationalEigenEnvironment {
54 public:
55 typedef storm::RationalFunction FunctionType;
56 typedef storm::RationalNumber ConstantType;
57 static storm::Environment createEnvironment() {
59 env.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Eigen);
60 return env;
61 }
62};
63class DoubleEigenEnvironment {
64 public:
65 typedef storm::RationalFunction FunctionType;
66 typedef double ConstantType;
67 static storm::Environment createEnvironment() {
69 env.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Eigen);
70 return env;
71 }
72};
73class DoubleEigenTopologicalEnvironment {
74 public:
75 typedef storm::RationalFunction FunctionType;
76 typedef double ConstantType;
77 static storm::Environment createEnvironment() {
79 env.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Topological);
80 env.solver().topological().setUnderlyingEquationSolverType(storm::solver::EquationSolverType::Eigen);
81 return env;
82 }
83};
84template<typename TestType>
85class GradientDescentInstantiationSearcherTest : public ::testing::Test {
86 public:
87 typedef typename TestType::FunctionType FunctionType;
88 typedef typename TestType::ConstantType ConstantType;
89 GradientDescentInstantiationSearcherTest() : _environment(TestType::createEnvironment()) {}
90 storm::Environment const& env() const {
91 return _environment;
92 }
93 virtual void SetUp() {
94#ifndef STORM_HAVE_Z3
95 GTEST_SKIP() << "Z3 not available.";
96#endif
97 carl::VariablePool::getInstance().clear();
98 }
99 virtual void TearDown() {
100 carl::VariablePool::getInstance().clear();
101 }
102
103 private:
104 storm::Environment _environment;
105};
106
107typedef ::testing::Types<
108 // The rational environments take ages... sorry, but GD is just not made for rational arithmetic.
109 DoubleGmmxxEnvironment, DoubleEigenEnvironment, DoubleEigenTopologicalEnvironment>
111} // namespace
112
113TYPED_TEST_SUITE(GradientDescentInstantiationSearcherTest, TestingTypes, );
114
// Simple: builds the parametric DTMC from pdtmc/gradient1.pm, simplifies it for the property
// "P>=0.2499 [F s=2]", runs gradient descent on a feasibility task for that property and checks
// that the value component of the result is 1/4 (value * 4 within 1e-6 of 1).
// NOTE(review): this listing is a documentation-page extraction. The fused leading numbers are the
// original source line numbers, and source lines 124, 128, 131, 133 and 135 are missing from it.
115TYPED_TEST(GradientDescentInstantiationSearcherTest, Simple) {
116 std::string programFile = STORM_TEST_RESOURCES_DIR "/pdtmc/gradient1.pm";
117 std::string formulaAsString = "P>=0.2499 [F s=2]";
118 std::string constantsAsString = ""; // e.g. pL=0.9,TOACK=0.5
119
120 // Program and formula
121 storm::prism::Program program = storm::api::parseProgram(programFile);
122 program = storm::utility::prism::preprocess(program, constantsAsString);
123 std::vector<std::shared_ptr<const storm::logic::Formula>> formulas =
    // NOTE(review): the initializer of `formulas` (source line 124) is missing here. The page's
    // cross-references suggest storm::api::extractFormulasFromProperties(
    // storm::api::parsePropertiesForPrismProgram(formulaAsString, program)) -- confirm.
125 std::shared_ptr<storm::models::sparse::Dtmc<storm::RationalFunction>> model =
126 storm::api::buildSparseModel<storm::RationalFunction>(program, formulas)->as<storm::models::sparse::Dtmc<storm::RationalFunction>>();
127 std::shared_ptr<storm::models::sparse::Dtmc<storm::RationalFunction>> dtmc = model->as<storm::models::sparse::Dtmc<storm::RationalFunction>>();
    // NOTE(review): the declaration of `simplifier` (source line 128) is missing. It is presumably
    // the parametric-model simplifier applied to *dtmc -- confirm. The test requires simplification
    // for formulas[0] to succeed and then continues with the simplified model.
129 ASSERT_TRUE(simplifier.simplify(*formulas[0]));
130 model = simplifier.getSimplifiedModel();
    // NOTE(review): source lines 131, 133 and 135 are missing. `checker`, used below, must be
    // declared in one of them (presumably a gradient-descent instantiation searcher over the
    // simplified model -- confirm).
132
134
    // Build the feasibility task: the formula and its bound are passed separately.
    // NOTE(review): despite its name, `formulaWithoutBounds` is a plain clone of formulas[0]; nothing
    // visible here strips the bound, which is instead re-attached via setBound(). Verify that the
    // task/checker does not rely on the formula itself being bound-free.
136 std::shared_ptr<storm::logic::Formula> formulaWithoutBounds = formulas[0]->clone();
137 std::shared_ptr<storm::logic::Formula const> formulaNoBound = formulaWithoutBounds->asSharedPointer();
138 std::shared_ptr<FeasibilitySynthesisTask> t = std::make_shared<FeasibilitySynthesisTask>(formulaNoBound);
139 t->setBound(formulas[0]->asOperatorFormula().getBound());
    // The configured mutable task is moved into the const task that is handed to the searcher.
140 std::shared_ptr<FeasibilitySynthesisTask const> feasibilityTask = std::make_shared<FeasibilitySynthesisTask const>(std::move(*t));
141
    // gradientDescent() returns a pair (instantiation, value). Only the value (.second) is checked.
    // NOTE(review): despite its name, `doubleInstantiation` holds that value, not an instantiation.
142 checker.setup(this->env(), feasibilityTask);
143 typename TestFixture::ConstantType doubleInstantiation = checker.gradientDescent().second;
144 ASSERT_NEAR(doubleInstantiation * 4, 1, 1e-6);
145}
146
// Crowds: builds the parametric DTMC from pdtmc/crowds3_5.pm for P<=0.00000001 [F "observe0Greater1"],
// simplifies it, applies bisimulation minimisation and then runs gradient descent four times
// (ADAM, RAdam, momentum, Nesterov). For each run, the first 41 points of the recorded
// visualization walk are compared per parameter (`badC`, `PF`) against hard-coded reference
// trajectories. For ADAM, the final value is additionally required to be 0 (within 1e-6).
// NOTE(review): this listing is a documentation-page extraction. The fused leading numbers are the
// original source line numbers, and source lines 156, 160, 172, 176, 178, 287, 289, 384, 386, 412
// and 414 are missing from it.
147TYPED_TEST(GradientDescentInstantiationSearcherTest, Crowds) {
148 std::string programFile = STORM_TEST_RESOURCES_DIR "/pdtmc/crowds3_5.pm";
149 std::string formulaAsString = "P<=0.00000001 [F \"observe0Greater1\"]";
150 std::string constantsAsString = ""; // e.g. pL=0.9,TOACK=0.5
151
152 // Program and formula
153 storm::prism::Program program = storm::api::parseProgram(programFile);
154 program = storm::utility::prism::preprocess(program, constantsAsString);
155 std::vector<std::shared_ptr<const storm::logic::Formula>> formulas =
    // NOTE(review): the initializer of `formulas` (source line 156) is missing here. It is presumably
    // built from formulaAsString and program, as in the Simple test -- confirm.
157 std::shared_ptr<storm::models::sparse::Dtmc<storm::RationalFunction>> model =
158 storm::api::buildSparseModel<storm::RationalFunction>(program, formulas)->as<storm::models::sparse::Dtmc<storm::RationalFunction>>();
159 std::shared_ptr<storm::models::sparse::Dtmc<storm::RationalFunction>> dtmc = model->as<storm::models::sparse::Dtmc<storm::RationalFunction>>();
    // NOTE(review): the declaration of `simplifier` (source line 160) is missing -- confirm its type.
161 ASSERT_TRUE(simplifier.simplify(*formulas[0]));
162 model = simplifier.getSimplifiedModel();
163
    // All searchers below operate on the bisimulation-minimised version of the simplified model.
164 dtmc = storm::api::performBisimulationMinimization<storm::RationalFunction>(model, formulas)->as<storm::models::sparse::Dtmc<storm::RationalFunction>>();
165
    // Build one feasibility task shared by all four searchers: formula plus separately set bound.
    // NOTE(review): `formulaWithoutBounds` is a plain clone of formulas[0]; nothing visible strips
    // its bound.
166 std::shared_ptr<storm::logic::Formula> formulaWithoutBounds = formulas[0]->clone();
167 std::shared_ptr<storm::logic::Formula const> formulaNoBound = formulaWithoutBounds->asSharedPointer();
168 std::shared_ptr<FeasibilitySynthesisTask> t = std::make_shared<FeasibilitySynthesisTask>(formulaNoBound);
169 t->setBound(formulas[0]->asOperatorFormula().getBound());
170 std::shared_ptr<FeasibilitySynthesisTask const> feasibilityTask = std::make_shared<FeasibilitySynthesisTask const>(std::move(*t));
171
173
174 // First, test an ADAM instance. We will check that we have implemented ADAM correctly by comparing our results to results gathered by an ADAM
175 // implementation in tensorflow :)
    // NOTE(review): source lines 176 and 178 are missing, so only the middle of the `adamChecker`
    // constructor call is visible. The positional arguments (0.01, 0.9, 0.999, 2, 1e-6, boost::none, ...)
    // presumably are the learning rate, ADAM's two decay rates and further tuning options -- confirm
    // against the searcher's constructor. Per its API doc, the visualization walk is only recorded
    // if recordRun is set to true in the constructor, so a missing line presumably enables it.
177 *dtmc, storm::derivative::GradientDescentMethod::ADAM, 0.01, 0.9, 0.999, 2, 1e-6, boost::none,
179 adamChecker.setup(this->env(), feasibilityTask);
    // gradientDescent() returns a pair (instantiation, value); the value is checked further below.
180 auto doubleInstantiation = adamChecker.gradientDescent();
181 auto walk = adamChecker.getVisualizationWalk();
182
    // Look up the carl variables of the parameters `badC` and `PF` by name.
    // NOTE(review): if a name is not found the variable simply stays default-constructed and the
    // lookups below silently use it; asserting that both were found would fail more clearly.
183 carl::Variable badCVar;
184 carl::Variable pfVar;
185 for (auto parameter : storm::models::sparse::getProbabilityParameters(*dtmc)) {
186 if (parameter.name() == "badC") {
187 badCVar = parameter;
188 } else if (parameter.name() == "PF") {
189 pfVar = parameter;
190 }
191 }
    // NOTE(review): `cache`, `badC` and `pf` are not used anywhere in the rest of this test --
    // likely leftovers; consider removing them.
192 std::shared_ptr<storm::RawPolynomialCache> cache = std::make_shared<storm::RawPolynomialCache>();
193 auto const badC = storm::RationalFunction(storm::Polynomial(storm::RawPolynomial(badCVar), cache));
194 auto const pf = storm::RationalFunction(storm::Polynomial(storm::RawPolynomial(pfVar), cache));
195
    // Reference ADAM trajectory (41 points, starting at 0.5) for each parameter.
196 const double badCValues[] = {0.5,
197 0.49000033736228943,
198 0.4799894690513611,
199 0.4699603021144867,
200 0.4599062204360962,
201 0.4498208165168762,
202 0.43969807028770447,
203 0.4295324981212616,
204 0.41931894421577454,
205 0.4090527594089508,
206 0.39872977137565613,
207 0.3883463442325592,
208 0.3778994083404541,
209 0.3673863708972931,
210 0.35680532455444336,
211 0.3461548984050751,
212 0.3354344069957733,
213 0.3246437907218933,
214 0.31378373503685,
215 0.30285564064979553,
216 0.2918616533279419,
217 0.28080472350120544,
218 0.2696886956691742,
219 0.258518248796463,
220 0.24729910492897034,
221 0.23603792488574982,
222 0.22474248707294464,
223 0.21342173218727112,
224 0.20208582282066345,
225 0.1907462775707245,
226 0.17941603064537048,
227 0.16810956597328186,
228 0.15684303641319275,
229 0.14563430845737457,
230 0.13450318574905396,
231 0.1234714537858963,
232 0.11256305128335953,
233 0.10180411487817764,
234 0.0912230908870697,
235 0.080850750207901,
236 0.07072017341852188};
237 const double pfValues[] = {0.5,
238 0.49000218510627747,
239 0.47999337315559387,
240 0.46996763348579407,
241 0.45991936326026917,
242 0.4498434364795685,
243 0.4397353231906891,
244 0.4295912981033325,
245 0.4194081425666809,
246 0.4091835021972656,
247 0.3989158570766449,
248 0.3886045515537262,
249 0.3782498240470886,
250 0.3678528368473053,
251 0.3574157953262329,
252 0.3469419777393341,
253 0.33643582463264465,
254 0.32590293884277344,
255 0.3153502345085144,
256 0.30478596687316895,
257 0.2942197024822235,
258 0.28366249799728394,
259 0.2731269598007202,
260 0.26262718439102173,
261 0.25217896699905396,
262 0.24179960787296295,
263 0.2315080612897873,
264 0.2213248461484909,
265 0.2112719565629959,
266 0.20137275755405426,
267 0.1916518211364746,
268 0.18213464319705963,
269 0.17284737527370453,
270 0.1638164073228836,
271 0.15506798028945923,
272 0.14662761986255646,
273 0.13851958513259888,
274 0.13076619803905487,
275 0.12338729947805405,
276 0.11639954894781113,
277 0.10981585085391998};
278
    // Compare the first 41 walk points against the reference (tolerance 1e-4 here, looser than
    // the 1e-5 used for the other three methods; the reason is not visible in this file).
    // NOTE(review): this loop, and the three like it below, index walk[0..40] without checking the
    // walk's size. If fewer than 41 points were recorded this reads past the end of the vector (UB)
    // instead of failing cleanly; an ASSERT_GE(walk.size(), 41u) beforehand would prevent that.
279 for (uint_fast64_t i = 0; i < 41; i++) {
280 ASSERT_NEAR(storm::utility::convertNumber<double>(walk[i].position[badCVar]), badCValues[i], 1e-4);
281 ASSERT_NEAR(storm::utility::convertNumber<double>(walk[i].position[pfVar]), pfValues[i], 1e-4);
282 }
283
    // The ADAM run must end with a value of 0 (within 1e-6).
284 ASSERT_NEAR(storm::utility::convertNumber<double>(doubleInstantiation.second), 0, 1e-6);
285
286 // Same thing with RAdam
    // NOTE(review): source lines 287 and 289 (start and end of the `radamChecker` constructor
    // call) are missing. Only the walk is checked for RAdam; `radamInstantiation` is unused.
288 *dtmc, storm::derivative::GradientDescentMethod::RADAM, 0.01, 0.9, 0.999, 2, 1e-6, boost::none,
290 radamChecker.setup(this->env(), feasibilityTask);
291 auto radamInstantiation = radamChecker.gradientDescent();
292 auto radamWalk = radamChecker.getVisualizationWalk();
293
    // Reference RAdam trajectory (41 points, starting at 0.5) for each parameter.
294 const double badCValuesRadam[] = {0.5,
295 0.49060654640197754,
296 0.48096320033073425,
297 0.47105303406715393,
298 0.4608582556247711,
299 0.46068474650382996,
300 0.4604234993457794,
301 0.460092693567276,
302 0.4597006142139435,
303 0.4592539370059967,
304 0.4587569832801819,
305 0.45821380615234375,
306 0.4576275646686554,
307 0.4570005536079407,
308 0.45633500814437866,
309 0.4556325674057007,
310 0.45489493012428284,
311 0.454123318195343,
312 0.4533190429210663,
313 0.45248323678970337,
314 0.45161670446395874,
315 0.4507202208042145,
316 0.44979462027549744,
317 0.448840469121933,
318 0.44785845279693604,
319 0.4468490481376648,
320 0.4458127021789551,
321 0.4447498917579651,
322 0.44366100430488586,
323 0.4425463378429413,
324 0.4414062798023224,
325 0.4402410686016083,
326 0.4390510618686676,
327 0.43783634901046753,
328 0.43659722805023193,
329 0.43533384799957275,
330 0.4340464472770691,
331 0.4327350854873657,
332 0.43139997124671936,
333 0.43004119396209717,
334 0.4286588430404663};
335
336 const double pfValuesRadam[] = {0.5,
337 0.4985547959804535,
338 0.4970662295818329,
339 0.4955315589904785,
340 0.4939480423927307,
341 0.4937744438648224,
342 0.4935130178928375,
343 0.4931819438934326,
344 0.49278953671455383,
345 0.492342472076416,
346 0.49184510111808777,
347 0.4913014769554138,
348 0.49071478843688965,
349 0.4900873303413391,
350 0.48942139744758606,
351 0.4887186288833618,
352 0.48798075318336487,
353 0.4872090220451355,
354 0.48640477657318115,
355 0.48556917905807495,
356 0.48470309376716614,
357 0.4838073253631592,
358 0.48288270831108093,
359 0.48192986845970154,
360 0.4809495508670807,
361 0.4799422323703766,
362 0.47890838980674744,
363 0.4778485894203186,
364 0.4767632484436035,
365 0.4756527245044708,
366 0.47451743483543396,
367 0.4733576774597168,
368 0.4721738398075104,
369 0.470966100692749,
370 0.4697347581386566,
371 0.46848005056381226,
372 0.4672022759914398,
373 0.46590155363082886,
374 0.46457815170288086,
375 0.46323221921920776,
376 0.4618639349937439};
377
378 for (uint_fast64_t i = 0; i < 41; i++) {
379 ASSERT_NEAR(storm::utility::convertNumber<double>(radamWalk[i].position[badCVar]), badCValuesRadam[i], 1e-5);
380 ASSERT_NEAR(storm::utility::convertNumber<double>(radamWalk[i].position[pfVar]), pfValuesRadam[i], 1e-5);
381 }
382
383 // Same thing with momentum
    // NOTE(review): source lines 384 and 386 (start and end of the `momentumChecker` constructor
    // call) are missing. Note the first numeric argument is 0.001 here versus 0.01 for ADAM/RAdam.
    // Only the walk is checked for momentum; `momentumInstantiation` is unused.
385 *dtmc, storm::derivative::GradientDescentMethod::MOMENTUM, 0.001, 0.9, 0.999, 2, 1e-6, boost::none,
387 momentumChecker.setup(this->env(), feasibilityTask);
388 auto momentumInstantiation = momentumChecker.gradientDescent();
389 auto momentumWalk = momentumChecker.getVisualizationWalk();
390
    // Reference momentum trajectory (41 points). Unlike ADAM/RAdam, the first expected point is
    // 0.5 + 1e-6 rather than 0.5; the reason is not visible in this file.
391 const double badCValuesMomentum[] = {
392 0.5 + 1e-6, 0.4990617036819458, 0.4972723126411438, 0.4947088360786438, 0.4914357662200928, 0.4875074326992035, 0.48296919465065,
393 0.4778585731983185, 0.47220611572265625, 0.4660361409187317, 0.45936742424964905, 0.45221376419067383, 0.44458451867103577, 0.43648505210876465,
394 0.42791709303855896, 0.41887906193733215, 0.4093664884567261, 0.3993722200393677, 0.3888867497444153, 0.37789851427078247, 0.3663942813873291,
395 0.3543594181537628, 0.34177839756011963, 0.32863524556159973, 0.3149142265319824, 0.30060049891471863, 0.28568127751350403, 0.27014678716659546,
396 0.253991961479187, 0.23721818625926971, 0.2198355346918106, 0.201865553855896, 0.18334446847438812, 0.16432702541351318, 0.1448907107114792,
397 0.12514044344425201, 0.10521329939365387, 0.08528289198875427, 0.06556269526481628, 0.04630732536315918, 0.027810536324977875};
398 const double pfValuesMomentum[] = {
399 0.5 + 1e-6, 0.49985647201538086, 0.49958109855651855, 0.4991863965988159, 0.49868202209472656, 0.49807605147361755, 0.49737516045570374,
400 0.49658480286598206, 0.4957093894481659, 0.49475228786468506, 0.4937160611152649, 0.49260255694389343, 0.4914129376411438, 0.49014779925346375,
401 0.4888072907924652, 0.4873911142349243, 0.4858987331390381, 0.4843292832374573, 0.4826817214488983, 0.4809550344944, 0.47914814949035645,
402 0.47726020216941833, 0.47529059648513794, 0.4732392132282257, 0.4711065888404846, 0.4688941240310669, 0.4666043817996979, 0.46424129605293274,
403 0.46181055903434753, 0.45931991934776306, 0.4567795991897583, 0.4542025327682495, 0.4516047239303589, 0.4490053057670593, 0.44642651081085205,
404 0.4438934028148651, 0.44143298268318176, 0.43907299637794495, 0.43684011697769165, 0.43475744128227234, 0.43284183740615845};
405
406 for (uint_fast64_t i = 0; i < 41; i++) {
407 ASSERT_NEAR(storm::utility::convertNumber<double>(momentumWalk[i].position[badCVar]), badCValuesMomentum[i], 1e-5);
408 ASSERT_NEAR(storm::utility::convertNumber<double>(momentumWalk[i].position[pfVar]), pfValuesMomentum[i], 1e-5);
409 }
410
411 // Same thing with nesterov
    // NOTE(review): source lines 412 and 414 (start and end of the `nesterovChecker` constructor
    // call) are missing. Only the walk is checked for Nesterov; `nesterovInstantiation` is unused.
413 *dtmc, storm::derivative::GradientDescentMethod::NESTEROV, 0.001, 0.9, 0.999, 2, 1e-6, boost::none,
415 nesterovChecker.setup(this->env(), feasibilityTask);
416 auto nesterovInstantiation = nesterovChecker.gradientDescent();
417 auto nesterovWalk = nesterovChecker.getVisualizationWalk();
418
    // Reference Nesterov trajectory (41 points, first point 0.5 + 1e-6 as for momentum).
419 const double badCValuesNesterov[] = {
420 0.5 + 1e-6, 0.49821633100509644, 0.49565380811691284, 0.4923747181892395, 0.48843076825141907, 0.4838651120662689, 0.4787132740020752,
421 0.473004013299942, 0.4667600393295288, 0.45999863743782043, 0.452732115983963, 0.44496846199035645, 0.43671154975891113, 0.4279615581035614,
422 0.4187155067920685, 0.4089673161506653, 0.3987082839012146, 0.387927383184433, 0.376611590385437, 0.36474621295928955, 0.35231542587280273,
423 0.33930283784866333, 0.3256921172142029, 0.3114677667617798, 0.2966162860393524, 0.2811274528503418, 0.2649959325790405, 0.24822339415550232,
424 0.23082087934017181, 0.21281197667121887, 0.19423632323741913, 0.1751537024974823, 0.15564869344234467, 0.1358354091644287, 0.11586219072341919,
425 0.09591540694236755, 0.07622162997722626, 0.05704677850008011, 0.03869114816188812, 0.021478772163391113, 0.005740571767091751};
426 const double pfValuesNesterov[] = {
427 0.5 + 1e-6, 0.49972638487815857, 0.4993318021297455, 0.4988263249397278, 0.4982176721096039, 0.4975121319293976, 0.4967147409915924,
428 0.4958295524120331, 0.4948597252368927, 0.4938075542449951, 0.49267470836639404, 0.49146217107772827, 0.49017035961151123, 0.48879921436309814,
429 0.48734840750694275, 0.48581716418266296, 0.48420456051826477, 0.4825096130371094, 0.48073118925094604, 0.4788684844970703, 0.47692081332206726,
430 0.4748881161212921, 0.4727708697319031, 0.47057047486305237, 0.4682896137237549, 0.4659323990345001, 0.4635048806667328, 0.4610152840614319,
431 0.4584745764732361, 0.45589667558670044, 0.45329880714416504, 0.45070162415504456, 0.4481291174888611, 0.44560813903808594, 0.44316741824150085,
432 0.440836101770401, 0.43864157795906067, 0.4366067945957184, 0.4347473978996277, 0.433069109916687, 0.43156614899635315};
433
434 for (uint_fast64_t i = 0; i < 41; i++) {
435 ASSERT_NEAR(storm::utility::convertNumber<double>(nesterovWalk[i].position[badCVar]), badCValuesNesterov[i], 1e-5);
436 ASSERT_NEAR(storm::utility::convertNumber<double>(nesterovWalk[i].position[pfVar]), pfValuesNesterov[i], 1e-5);
437 }
438}
TYPED_TEST_SUITE(GradientDescentInstantiationSearcherTest, TestingTypes,)
TYPED_TEST(GradientDescentInstantiationSearcherTest, Simple)
SolverEnvironment & solver()
TopologicalSolverEnvironment & topological()
void setLinearEquationSolverType(storm::solver::EquationSolverType const &value, bool isSetFromDefault=false)
void setUnderlyingEquationSolverType(storm::solver::EquationSolverType value)
void setup(Environment const &env, std::shared_ptr< storm::pars::FeasibilitySynthesisTask const > const &task)
This will setup the matrices used for computing the derivatives by constructing the SparseDerivativeI...
std::vector< VisualizationPoint > getVisualizationWalk()
Get the visualization walk that is recorded if recordRun is set to true in the constructor (false by default).
std::pair< std::map< typename utility::parametric::VariableType< FunctionType >::type, typename utility::parametric::CoefficientType< FunctionType >::type >, ConstantType > gradientDescent()
Perform Gradient Descent.
std::shared_ptr< ModelType > as()
Casts the model into the model type given by the template parameter.
Definition ModelBase.h:36
This class represents a discrete-time Markov chain.
Definition Dtmc.h:14
This class performs different steps to simplify the given (parametric) model.
std::vector< storm::jani::Property > parsePropertiesForPrismProgram(std::string const &inputString, storm::prism::Program const &program, boost::optional< std::set< std::string > > const &propertyFilter)
storm::prism::Program parseProgram(std::string const &filename, bool prismCompatibility, bool simplify)
std::vector< std::shared_ptr< storm::logic::Formula const > > extractFormulasFromProperties(std::vector< storm::jani::Property > const &properties)
std::set< storm::RationalFunctionVariable > getProbabilityParameters(Model< storm::RationalFunction > const &model)
Get all probability parameters occurring on transitions.
Definition Model.cpp:703
storm::prism::Program preprocess(storm::prism::Program const &program, std::map< storm::expressions::Variable, storm::expressions::Expression > const &constantDefinitions)
Definition prism.cpp:19
carl::RationalFunction< Polynomial, true > RationalFunction
carl::MultivariatePolynomial< RationalFunctionCoefficient > RawPolynomial
::testing::Types< Cudd, Sylvan > TestingTypes
Definition GraphTest.cpp:46