Storm 1.10.0.1
A Modern Probabilistic Model Checker
Loading...
Searching...
No Matches
GradientDescentInstantiationSearcherTest.cpp
Go to the documentation of this file.
1#include "storm-config.h"
2#include "test/storm_gtest.h"
3
5#include "storm/api/builder.h"
6#include "storm/api/storm.h"
18
23
29
30using namespace storm::pars;
31
32namespace {
33class RationalGmmxxEnvironment {
34 public:
35 typedef storm::RationalFunction FunctionType;
36 typedef storm::RationalNumber ConstantType;
37 static storm::Environment createEnvironment() {
39 env.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Gmmxx);
40 return env;
41 }
42};
43class DoubleGmmxxEnvironment {
44 public:
45 typedef storm::RationalFunction FunctionType;
46 typedef double ConstantType;
47 static storm::Environment createEnvironment() {
49 env.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Gmmxx);
50 return env;
51 }
52};
53class RationalEigenEnvironment {
54 public:
55 typedef storm::RationalFunction FunctionType;
56 typedef storm::RationalNumber ConstantType;
57 static storm::Environment createEnvironment() {
59 env.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Eigen);
60 return env;
61 }
62};
63class DoubleEigenEnvironment {
64 public:
65 typedef storm::RationalFunction FunctionType;
66 typedef double ConstantType;
67 static storm::Environment createEnvironment() {
69 env.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Eigen);
70 return env;
71 }
72};
73class DoubleEigenTopologicalEnvironment {
74 public:
75 typedef storm::RationalFunction FunctionType;
76 typedef double ConstantType;
77 static storm::Environment createEnvironment() {
79 env.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Topological);
80 env.solver().topological().setUnderlyingEquationSolverType(storm::solver::EquationSolverType::Eigen);
81 return env;
82 }
83};
84template<typename TestType>
85class GradientDescentInstantiationSearcherTest : public ::testing::Test {
86 public:
87 typedef typename TestType::FunctionType FunctionType;
88 typedef typename TestType::ConstantType ConstantType;
89 GradientDescentInstantiationSearcherTest() : _environment(TestType::createEnvironment()) {}
90 storm::Environment const& env() const {
91 return _environment;
92 }
93 virtual void SetUp() {
94#ifndef STORM_HAVE_Z3
95 GTEST_SKIP() << "Z3 not available.";
96#endif
97 carl::VariablePool::getInstance().clear();
98 }
99 virtual void TearDown() {
100 carl::VariablePool::getInstance().clear();
101 }
102
103 private:
104 storm::Environment _environment;
105};
106
107typedef ::testing::Types<
108// The rational environments take ages... GD is just not made for rational arithmetic.
109#ifdef STORM_HAVE_GMM
110 DoubleGmmxxEnvironment,
111#endif
112 DoubleEigenEnvironment, DoubleEigenTopologicalEnvironment>
114} // namespace
115
116TYPED_TEST_SUITE(GradientDescentInstantiationSearcherTest, TestingTypes, );
117
// Simple: builds the parametric DTMC from pdtmc/gradient1.pm, simplifies it for
// "P>=0.2499 [F s=2]", runs gradient descent, and checks that the value returned
// by gradientDescent() is approximately 0.25 (asserted as value * 4 == 1 within 1e-6).
// NOTE(review): this extract is missing upstream lines 127, 131, 134, 136 and 138
// (see the gaps in the numbering). From their uses below they presumably hold the
// initializer of `formulas`, the declaration of `simplifier`, and the declaration
// of `checker` -- TODO restore from the upstream file.
118TYPED_TEST(GradientDescentInstantiationSearcherTest, Simple) {
119 std::string programFile = STORM_TEST_RESOURCES_DIR "/pdtmc/gradient1.pm";
120 std::string formulaAsString = "P>=0.2499 [F s=2]";
121 std::string constantsAsString = ""; // e.g. pL=0.9,TOACK=0.5
122
123 // Program and formula
124 storm::prism::Program program = storm::api::parseProgram(programFile);
125 program = storm::utility::prism::preprocess(program, constantsAsString);
    // NOTE(review): the initializer of `formulas` (upstream line 127) is missing here.
126 std::vector<std::shared_ptr<const storm::logic::Formula>> formulas =
    // Build the sparse parametric model for the formulas and view it as a DTMC.
128 std::shared_ptr<storm::models::sparse::Dtmc<storm::RationalFunction>> model =
129 storm::api::buildSparseModel<storm::RationalFunction>(program, formulas)->as<storm::models::sparse::Dtmc<storm::RationalFunction>>();
130 std::shared_ptr<storm::models::sparse::Dtmc<storm::RationalFunction>> dtmc = model->as<storm::models::sparse::Dtmc<storm::RationalFunction>>();
    // Simplify the model with respect to the first formula; the test aborts if that fails.
    // NOTE(review): `simplifier` is declared on the missing upstream line 131.
132 ASSERT_TRUE(simplifier.simplify(*formulas[0]));
133 model = simplifier.getSimplifiedModel();
135
137
    // Build a feasibility-synthesis task from a clone of the formula and attach the
    // bound of the original operator formula.
    // NOTE(review): despite the variable names, nothing visible here removes the
    // bound from the clone.
139 std::shared_ptr<storm::logic::Formula> formulaWithoutBounds = formulas[0]->clone();
140 std::shared_ptr<storm::logic::Formula const> formulaNoBound = formulaWithoutBounds->asSharedPointer();
141 std::shared_ptr<FeasibilitySynthesisTask> t = std::make_shared<FeasibilitySynthesisTask>(formulaNoBound);
142 t->setBound(formulas[0]->asOperatorFormula().getBound());
    // `*t` is moved into the const task; `t` is not used afterwards.
143 std::shared_ptr<FeasibilitySynthesisTask const> feasibilityTask = std::make_shared<FeasibilitySynthesisTask const>(std::move(*t));
144
    // NOTE(review): `checker` is declared on a missing upstream line (136 or 138).
145 checker.setup(this->env(), feasibilityTask);
    // `.second` is the ConstantType component of gradientDescent()'s result; despite
    // the variable name it is a value, presumably the instantiation itself is the
    // map in `.first`.
146 typename TestFixture::ConstantType doubleInstantiation = checker.gradientDescent().second;
147 ASSERT_NEAR(doubleInstantiation * 4, 1, 1e-6);
148}
149
// Crowds: builds the parametric DTMC from pdtmc/crowds3_5.pm, simplifies and
// bisimulation-minimizes it for P<=0.00000001 [F "observe0Greater1"], then runs
// gradient descent with ADAM, RADAM, MOMENTUM and NESTEROV. Each run's recorded
// visualization walk is compared point-by-point (first 41 points) for the parameters
// "badC" and "PF" against reference values.
// NOTE(review): this extract is missing upstream lines 159, 163, 175, 179, 181, 290,
// 292, 387, 389, 415 and 417. From their uses they presumably hold the initializer
// of `formulas`, the declaration of `simplifier`, and the opening/closing parts of
// the adamChecker, radamChecker, momentumChecker and nesterovChecker declarations
// (line 175's content is unknown) -- TODO restore from the upstream file.
150TYPED_TEST(GradientDescentInstantiationSearcherTest, Crowds) {
151 std::string programFile = STORM_TEST_RESOURCES_DIR "/pdtmc/crowds3_5.pm";
152 std::string formulaAsString = "P<=0.00000001 [F \"observe0Greater1\"]";
153 std::string constantsAsString = ""; // e.g. pL=0.9,TOACK=0.5
154
155 // Program and formula
156 storm::prism::Program program = storm::api::parseProgram(programFile);
157 program = storm::utility::prism::preprocess(program, constantsAsString);
    // NOTE(review): the initializer of `formulas` (upstream line 159) is missing here.
158 std::vector<std::shared_ptr<const storm::logic::Formula>> formulas =
160 std::shared_ptr<storm::models::sparse::Dtmc<storm::RationalFunction>> model =
161 storm::api::buildSparseModel<storm::RationalFunction>(program, formulas)->as<storm::models::sparse::Dtmc<storm::RationalFunction>>();
162 std::shared_ptr<storm::models::sparse::Dtmc<storm::RationalFunction>> dtmc = model->as<storm::models::sparse::Dtmc<storm::RationalFunction>>();
    // NOTE(review): `simplifier` is declared on the missing upstream line 163.
164 ASSERT_TRUE(simplifier.simplify(*formulas[0]));
165 model = simplifier.getSimplifiedModel();
166
    // The searchers below operate on the bisimulation-minimized DTMC.
167 dtmc = storm::api::performBisimulationMinimization<storm::RationalFunction>(model, formulas)->as<storm::models::sparse::Dtmc<storm::RationalFunction>>();
168
    // Build a feasibility-synthesis task from a clone of the formula and attach the
    // bound of the original operator formula; `*t` is moved into the const task.
    // NOTE(review): despite the names, nothing visible here removes the bound from the clone.
169 std::shared_ptr<storm::logic::Formula> formulaWithoutBounds = formulas[0]->clone();
170 std::shared_ptr<storm::logic::Formula const> formulaNoBound = formulaWithoutBounds->asSharedPointer();
171 std::shared_ptr<FeasibilitySynthesisTask> t = std::make_shared<FeasibilitySynthesisTask>(formulaNoBound);
172 t->setBound(formulas[0]->asOperatorFormula().getBound());
173 std::shared_ptr<FeasibilitySynthesisTask const> feasibilityTask = std::make_shared<FeasibilitySynthesisTask const>(std::move(*t));
174
176
    // NOTE(review): each searcher is constructed with (*dtmc, method, a, b, c, 2, 1e-6,
    // boost::none, ...); the meaning of the numeric arguments (presumably learning
    // rate and decay parameters) should be confirmed against the
    // GradientDescentInstantiationSearcher constructor.
177 // First, test an ADAM instance. We will check that we have implemented ADAM correctly by comparing our results to results gathered by an ADAM
178 // implementation in tensorflow :)
180 *dtmc, storm::derivative::GradientDescentMethod::ADAM, 0.01, 0.9, 0.999, 2, 1e-6, boost::none,
182 adamChecker.setup(this->env(), feasibilityTask);
183 auto doubleInstantiation = adamChecker.gradientDescent();
184 auto walk = adamChecker.getVisualizationWalk();
185
    // Look up the carl variables of the parameters named "badC" and "PF".
    // NOTE(review): if either name is absent, the variable stays default-constructed
    // and is used unchecked in the walk lookups below.
186 carl::Variable badCVar;
187 carl::Variable pfVar;
188 for (auto parameter : storm::models::sparse::getProbabilityParameters(*dtmc)) {
189 if (parameter.name() == "badC") {
190 badCVar = parameter;
191 } else if (parameter.name() == "PF") {
192 pfVar = parameter;
193 }
194 }
    // NOTE(review): `cache`, `badC` and `pf` are not referenced by any visible line below.
195 std::shared_ptr<storm::RawPolynomialCache> cache = std::make_shared<storm::RawPolynomialCache>();
196 auto const badC = storm::RationalFunction(storm::Polynomial(storm::RawPolynomial(badCVar), cache));
197 auto const pf = storm::RationalFunction(storm::Polynomial(storm::RawPolynomial(pfVar), cache));
198
    // Reference ADAM walk (41 points per parameter).
199 const double badCValues[] = {0.5,
200 0.49000033736228943,
201 0.4799894690513611,
202 0.4699603021144867,
203 0.4599062204360962,
204 0.4498208165168762,
205 0.43969807028770447,
206 0.4295324981212616,
207 0.41931894421577454,
208 0.4090527594089508,
209 0.39872977137565613,
210 0.3883463442325592,
211 0.3778994083404541,
212 0.3673863708972931,
213 0.35680532455444336,
214 0.3461548984050751,
215 0.3354344069957733,
216 0.3246437907218933,
217 0.31378373503685,
218 0.30285564064979553,
219 0.2918616533279419,
220 0.28080472350120544,
221 0.2696886956691742,
222 0.258518248796463,
223 0.24729910492897034,
224 0.23603792488574982,
225 0.22474248707294464,
226 0.21342173218727112,
227 0.20208582282066345,
228 0.1907462775707245,
229 0.17941603064537048,
230 0.16810956597328186,
231 0.15684303641319275,
232 0.14563430845737457,
233 0.13450318574905396,
234 0.1234714537858963,
235 0.11256305128335953,
236 0.10180411487817764,
237 0.0912230908870697,
238 0.080850750207901,
239 0.07072017341852188};
240 const double pfValues[] = {0.5,
241 0.49000218510627747,
242 0.47999337315559387,
243 0.46996763348579407,
244 0.45991936326026917,
245 0.4498434364795685,
246 0.4397353231906891,
247 0.4295912981033325,
248 0.4194081425666809,
249 0.4091835021972656,
250 0.3989158570766449,
251 0.3886045515537262,
252 0.3782498240470886,
253 0.3678528368473053,
254 0.3574157953262329,
255 0.3469419777393341,
256 0.33643582463264465,
257 0.32590293884277344,
258 0.3153502345085144,
259 0.30478596687316895,
260 0.2942197024822235,
261 0.28366249799728394,
262 0.2731269598007202,
263 0.26262718439102173,
264 0.25217896699905396,
265 0.24179960787296295,
266 0.2315080612897873,
267 0.2213248461484909,
268 0.2112719565629959,
269 0.20137275755405426,
270 0.1916518211364746,
271 0.18213464319705963,
272 0.17284737527370453,
273 0.1638164073228836,
274 0.15506798028945923,
275 0.14662761986255646,
276 0.13851958513259888,
277 0.13076619803905487,
278 0.12338729947805405,
279 0.11639954894781113,
280 0.10981585085391998};
281
    // Compare the first 41 ADAM walk points (tolerance 1e-4).
    // NOTE(review): getVisualizationWalk() returns a std::vector that is indexed here
    // without first checking its size; a walk shorter than 41 points would read out of
    // bounds instead of failing an assertion. The same applies to the loops below.
282 for (uint_fast64_t i = 0; i < 41; i++) {
283 ASSERT_NEAR(storm::utility::convertNumber<double>(walk[i].position[badCVar]), badCValues[i], 1e-4);
284 ASSERT_NEAR(storm::utility::convertNumber<double>(walk[i].position[pfVar]), pfValues[i], 1e-4);
285 }
286
    // The ADAM run's returned value must be 0 within 1e-6.
287 ASSERT_NEAR(storm::utility::convertNumber<double>(doubleInstantiation.second), 0, 1e-6);
288
    // NOTE(review): unlike the ADAM run, the results of the RADAM, MOMENTUM and
    // NESTEROV runs (radamInstantiation, momentumInstantiation, nesterovInstantiation)
    // are never asserted; only their walks are compared (tolerance 1e-5).
289 // Same thing with RAdam
291 *dtmc, storm::derivative::GradientDescentMethod::RADAM, 0.01, 0.9, 0.999, 2, 1e-6, boost::none,
293 radamChecker.setup(this->env(), feasibilityTask);
294 auto radamInstantiation = radamChecker.gradientDescent();
295 auto radamWalk = radamChecker.getVisualizationWalk();
296
297 const double badCValuesRadam[] = {0.5,
298 0.49060654640197754,
299 0.48096320033073425,
300 0.47105303406715393,
301 0.4608582556247711,
302 0.46068474650382996,
303 0.4604234993457794,
304 0.460092693567276,
305 0.4597006142139435,
306 0.4592539370059967,
307 0.4587569832801819,
308 0.45821380615234375,
309 0.4576275646686554,
310 0.4570005536079407,
311 0.45633500814437866,
312 0.4556325674057007,
313 0.45489493012428284,
314 0.454123318195343,
315 0.4533190429210663,
316 0.45248323678970337,
317 0.45161670446395874,
318 0.4507202208042145,
319 0.44979462027549744,
320 0.448840469121933,
321 0.44785845279693604,
322 0.4468490481376648,
323 0.4458127021789551,
324 0.4447498917579651,
325 0.44366100430488586,
326 0.4425463378429413,
327 0.4414062798023224,
328 0.4402410686016083,
329 0.4390510618686676,
330 0.43783634901046753,
331 0.43659722805023193,
332 0.43533384799957275,
333 0.4340464472770691,
334 0.4327350854873657,
335 0.43139997124671936,
336 0.43004119396209717,
337 0.4286588430404663};
338
339 const double pfValuesRadam[] = {0.5,
340 0.4985547959804535,
341 0.4970662295818329,
342 0.4955315589904785,
343 0.4939480423927307,
344 0.4937744438648224,
345 0.4935130178928375,
346 0.4931819438934326,
347 0.49278953671455383,
348 0.492342472076416,
349 0.49184510111808777,
350 0.4913014769554138,
351 0.49071478843688965,
352 0.4900873303413391,
353 0.48942139744758606,
354 0.4887186288833618,
355 0.48798075318336487,
356 0.4872090220451355,
357 0.48640477657318115,
358 0.48556917905807495,
359 0.48470309376716614,
360 0.4838073253631592,
361 0.48288270831108093,
362 0.48192986845970154,
363 0.4809495508670807,
364 0.4799422323703766,
365 0.47890838980674744,
366 0.4778485894203186,
367 0.4767632484436035,
368 0.4756527245044708,
369 0.47451743483543396,
370 0.4733576774597168,
371 0.4721738398075104,
372 0.470966100692749,
373 0.4697347581386566,
374 0.46848005056381226,
375 0.4672022759914398,
376 0.46590155363082886,
377 0.46457815170288086,
378 0.46323221921920776,
379 0.4618639349937439};
380
381 for (uint_fast64_t i = 0; i < 41; i++) {
382 ASSERT_NEAR(storm::utility::convertNumber<double>(radamWalk[i].position[badCVar]), badCValuesRadam[i], 1e-5);
383 ASSERT_NEAR(storm::utility::convertNumber<double>(radamWalk[i].position[pfVar]), pfValuesRadam[i], 1e-5);
384 }
385
    // MOMENTUM and NESTEROV use 0.001 as the third constructor argument (0.01 above),
    // and their reference walks start at 0.5 + 1e-6 rather than 0.5.
386 // Same thing with momentum
388 *dtmc, storm::derivative::GradientDescentMethod::MOMENTUM, 0.001, 0.9, 0.999, 2, 1e-6, boost::none,
390 momentumChecker.setup(this->env(), feasibilityTask);
391 auto momentumInstantiation = momentumChecker.gradientDescent();
392 auto momentumWalk = momentumChecker.getVisualizationWalk();
393
394 const double badCValuesMomentum[] = {
395 0.5 + 1e-6, 0.4990617036819458, 0.4972723126411438, 0.4947088360786438, 0.4914357662200928, 0.4875074326992035, 0.48296919465065,
396 0.4778585731983185, 0.47220611572265625, 0.4660361409187317, 0.45936742424964905, 0.45221376419067383, 0.44458451867103577, 0.43648505210876465,
397 0.42791709303855896, 0.41887906193733215, 0.4093664884567261, 0.3993722200393677, 0.3888867497444153, 0.37789851427078247, 0.3663942813873291,
398 0.3543594181537628, 0.34177839756011963, 0.32863524556159973, 0.3149142265319824, 0.30060049891471863, 0.28568127751350403, 0.27014678716659546,
399 0.253991961479187, 0.23721818625926971, 0.2198355346918106, 0.201865553855896, 0.18334446847438812, 0.16432702541351318, 0.1448907107114792,
400 0.12514044344425201, 0.10521329939365387, 0.08528289198875427, 0.06556269526481628, 0.04630732536315918, 0.027810536324977875};
401 const double pfValuesMomentum[] = {
402 0.5 + 1e-6, 0.49985647201538086, 0.49958109855651855, 0.4991863965988159, 0.49868202209472656, 0.49807605147361755, 0.49737516045570374,
403 0.49658480286598206, 0.4957093894481659, 0.49475228786468506, 0.4937160611152649, 0.49260255694389343, 0.4914129376411438, 0.49014779925346375,
404 0.4888072907924652, 0.4873911142349243, 0.4858987331390381, 0.4843292832374573, 0.4826817214488983, 0.4809550344944, 0.47914814949035645,
405 0.47726020216941833, 0.47529059648513794, 0.4732392132282257, 0.4711065888404846, 0.4688941240310669, 0.4666043817996979, 0.46424129605293274,
406 0.46181055903434753, 0.45931991934776306, 0.4567795991897583, 0.4542025327682495, 0.4516047239303589, 0.4490053057670593, 0.44642651081085205,
407 0.4438934028148651, 0.44143298268318176, 0.43907299637794495, 0.43684011697769165, 0.43475744128227234, 0.43284183740615845};
408
409 for (uint_fast64_t i = 0; i < 41; i++) {
410 ASSERT_NEAR(storm::utility::convertNumber<double>(momentumWalk[i].position[badCVar]), badCValuesMomentum[i], 1e-5);
411 ASSERT_NEAR(storm::utility::convertNumber<double>(momentumWalk[i].position[pfVar]), pfValuesMomentum[i], 1e-5);
412 }
413
414 // Same thing with nesterov
416 *dtmc, storm::derivative::GradientDescentMethod::NESTEROV, 0.001, 0.9, 0.999, 2, 1e-6, boost::none,
418 nesterovChecker.setup(this->env(), feasibilityTask);
419 auto nesterovInstantiation = nesterovChecker.gradientDescent();
420 auto nesterovWalk = nesterovChecker.getVisualizationWalk();
421
422 const double badCValuesNesterov[] = {
423 0.5 + 1e-6, 0.49821633100509644, 0.49565380811691284, 0.4923747181892395, 0.48843076825141907, 0.4838651120662689, 0.4787132740020752,
424 0.473004013299942, 0.4667600393295288, 0.45999863743782043, 0.452732115983963, 0.44496846199035645, 0.43671154975891113, 0.4279615581035614,
425 0.4187155067920685, 0.4089673161506653, 0.3987082839012146, 0.387927383184433, 0.376611590385437, 0.36474621295928955, 0.35231542587280273,
426 0.33930283784866333, 0.3256921172142029, 0.3114677667617798, 0.2966162860393524, 0.2811274528503418, 0.2649959325790405, 0.24822339415550232,
427 0.23082087934017181, 0.21281197667121887, 0.19423632323741913, 0.1751537024974823, 0.15564869344234467, 0.1358354091644287, 0.11586219072341919,
428 0.09591540694236755, 0.07622162997722626, 0.05704677850008011, 0.03869114816188812, 0.021478772163391113, 0.005740571767091751};
429 const double pfValuesNesterov[] = {
430 0.5 + 1e-6, 0.49972638487815857, 0.4993318021297455, 0.4988263249397278, 0.4982176721096039, 0.4975121319293976, 0.4967147409915924,
431 0.4958295524120331, 0.4948597252368927, 0.4938075542449951, 0.49267470836639404, 0.49146217107772827, 0.49017035961151123, 0.48879921436309814,
432 0.48734840750694275, 0.48581716418266296, 0.48420456051826477, 0.4825096130371094, 0.48073118925094604, 0.4788684844970703, 0.47692081332206726,
433 0.4748881161212921, 0.4727708697319031, 0.47057047486305237, 0.4682896137237549, 0.4659323990345001, 0.4635048806667328, 0.4610152840614319,
434 0.4584745764732361, 0.45589667558670044, 0.45329880714416504, 0.45070162415504456, 0.4481291174888611, 0.44560813903808594, 0.44316741824150085,
435 0.440836101770401, 0.43864157795906067, 0.4366067945957184, 0.4347473978996277, 0.433069109916687, 0.43156614899635315};
436
437 for (uint_fast64_t i = 0; i < 41; i++) {
438 ASSERT_NEAR(storm::utility::convertNumber<double>(nesterovWalk[i].position[badCVar]), badCValuesNesterov[i], 1e-5);
439 ASSERT_NEAR(storm::utility::convertNumber<double>(nesterovWalk[i].position[pfVar]), pfValuesNesterov[i], 1e-5);
440 }
441}
TYPED_TEST_SUITE(GradientDescentInstantiationSearcherTest, TestingTypes,)
TYPED_TEST(GradientDescentInstantiationSearcherTest, Simple)
SolverEnvironment & solver()
TopologicalSolverEnvironment & topological()
void setLinearEquationSolverType(storm::solver::EquationSolverType const &value, bool isSetFromDefault=false)
void setUnderlyingEquationSolverType(storm::solver::EquationSolverType value)
void setup(Environment const &env, std::shared_ptr< storm::pars::FeasibilitySynthesisTask const > const &task)
This will set up the matrices used for computing the derivatives by constructing the SparseDerivativeI...
std::vector< VisualizationPoint > getVisualizationWalk()
Get the visualization walk that is recorded if recordRun is set to true in the constructor (false by ...
std::pair< std::map< typename utility::parametric::VariableType< FunctionType >::type, typename utility::parametric::CoefficientType< FunctionType >::type >, ConstantType > gradientDescent()
Perform Gradient Descent.
std::shared_ptr< ModelType > as()
Casts the model into the model type given by the template parameter.
Definition ModelBase.h:36
This class represents a discrete-time Markov chain.
Definition Dtmc.h:14
This class performs different steps to simplify the given (parametric) model.
std::vector< storm::jani::Property > parsePropertiesForPrismProgram(std::string const &inputString, storm::prism::Program const &program, boost::optional< std::set< std::string > > const &propertyFilter)
storm::prism::Program parseProgram(std::string const &filename, bool prismCompatibility, bool simplify)
std::vector< std::shared_ptr< storm::logic::Formula const > > extractFormulasFromProperties(std::vector< storm::jani::Property > const &properties)
std::set< storm::RationalFunctionVariable > getProbabilityParameters(Model< storm::RationalFunction > const &model)
Get all probability parameters occurring on transitions.
Definition Model.cpp:703
storm::prism::Program preprocess(storm::prism::Program const &program, std::map< storm::expressions::Variable, storm::expressions::Expression > const &constantDefinitions)
Definition prism.cpp:19
carl::RationalFunction< Polynomial, true > RationalFunction
carl::MultivariatePolynomial< RationalFunctionCoefficient > RawPolynomial
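::testing::Types< Cudd, Sylvan > TestingTypes
Definition GraphTest.cpp:46 (note: this cross-reference resolves to an unrelated typedef in GraphTest.cpp; the TestingTypes used by this file is its own anonymous-namespace typedef of DoubleGmmxxEnvironment (only when STORM_HAVE_GMM is defined), DoubleEigenEnvironment and DoubleEigenTopologicalEnvironment)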