Storm 1.11.1.1
A Modern Probabilistic Model Checker
Loading...
Searching...
No Matches
GradientDescentInstantiationSearcherTest.cpp
Go to the documentation of this file.
1#include "storm-config.h"
2#include "test/storm_gtest.h"
3
14#include "storm/api/builder.h"
15#include "storm/api/storm.h"
27
28using namespace storm::pars;
29
30namespace {
31class RationalGmmxxEnvironment {
32 public:
33 typedef storm::RationalFunction FunctionType;
34 typedef storm::RationalNumber ConstantType;
35 static storm::Environment createEnvironment() {
37 env.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Gmmxx);
38 return env;
39 }
40};
41class DoubleGmmxxEnvironment {
42 public:
43 typedef storm::RationalFunction FunctionType;
44 typedef double ConstantType;
45 static storm::Environment createEnvironment() {
47 env.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Gmmxx);
48 return env;
49 }
50};
51class RationalEigenEnvironment {
52 public:
53 typedef storm::RationalFunction FunctionType;
54 typedef storm::RationalNumber ConstantType;
55 static storm::Environment createEnvironment() {
57 env.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Eigen);
58 return env;
59 }
60};
61class DoubleEigenEnvironment {
62 public:
63 typedef storm::RationalFunction FunctionType;
64 typedef double ConstantType;
65 static storm::Environment createEnvironment() {
67 env.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Eigen);
68 return env;
69 }
70};
71class DoubleEigenTopologicalEnvironment {
72 public:
73 typedef storm::RationalFunction FunctionType;
74 typedef double ConstantType;
75 static storm::Environment createEnvironment() {
77 env.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Topological);
78 env.solver().topological().setUnderlyingEquationSolverType(storm::solver::EquationSolverType::Eigen);
79 return env;
80 }
81};
82template<typename TestType>
83class GradientDescentInstantiationSearcherTest : public ::testing::Test {
84 public:
85 typedef typename TestType::FunctionType FunctionType;
86 typedef typename TestType::ConstantType ConstantType;
87 GradientDescentInstantiationSearcherTest() : _environment(TestType::createEnvironment()) {}
88 storm::Environment const& env() const {
89 return _environment;
90 }
91 virtual void SetUp() {
92#ifndef STORM_HAVE_Z3
93 GTEST_SKIP() << "Z3 not available.";
94#endif
95 carl::VariablePool::getInstance().clear();
96 }
97 virtual void TearDown() {
98 carl::VariablePool::getInstance().clear();
99 }
100
101 private:
102 storm::Environment _environment;
103};
104
105typedef ::testing::Types<
106// The rational environments take ages... GD is just not made for rational arithmetic.
107#ifdef STORM_HAVE_GMM
108 DoubleGmmxxEnvironment,
109#endif
110 DoubleEigenEnvironment, DoubleEigenTopologicalEnvironment>
112} // namespace
113
114TYPED_TEST_SUITE(GradientDescentInstantiationSearcherTest, TestingTypes, );
115
// Typed test "Simple".
// - Builds the parametric DTMC from pdtmc/gradient1.pm and simplifies it for "P>=0.2499 [F s=2]".
// - Runs the gradient-descent checker on a feasibility task carrying that property's bound.
// - Asserts that |4 * value - 1| <= 1e-6 for the ConstantType value returned by gradientDescent(),
//   i.e. the value is approximately 0.25.
// NOTE(review): source lines 125, 129, 132, 134 and 136 are missing from this extraction. They presumably
// contain the initializer of `formulas` and the declarations of `simplifier` and `checker`, which are used
// but never declared in the visible lines. As shown, the block does not compile.
116TYPED_TEST(GradientDescentInstantiationSearcherTest, Simple) {
117 std::string programFile = STORM_TEST_RESOURCES_DIR "/pdtmc/gradient1.pm";
118 std::string formulaAsString = "P>=0.2499 [F s=2]";
119 std::string constantsAsString = ""; // e.g. pL=0.9,TOACK=0.5
120
121 // Program and formula
122 storm::prism::Program program = storm::api::parseProgram(programFile);
123 program = storm::utility::prism::preprocess(program, constantsAsString);
124 std::vector<std::shared_ptr<const storm::logic::Formula>> formulas =
126 std::shared_ptr<storm::models::sparse::Dtmc<storm::RationalFunction>> model =
127 storm::api::buildSparseModel<storm::RationalFunction>(program, formulas)->as<storm::models::sparse::Dtmc<storm::RationalFunction>>();
128 std::shared_ptr<storm::models::sparse::Dtmc<storm::RationalFunction>> dtmc = model->as<storm::models::sparse::Dtmc<storm::RationalFunction>>();
130 ASSERT_TRUE(simplifier.simplify(*formulas[0]));
131 model = simplifier.getSimplifiedModel();
133
135
    // Build the feasibility task from a clone of the formula, then copy the original formula's bound onto it.
    // NOTE(review): despite the variable names, no visible line strips the bound from the clone.
    // Confirm the intended behaviour against FeasibilitySynthesisTask.
137 std::shared_ptr<storm::logic::Formula> formulaWithoutBounds = formulas[0]->clone();
138 std::shared_ptr<storm::logic::Formula const> formulaNoBound = formulaWithoutBounds->asSharedPointer();
139 std::shared_ptr<FeasibilitySynthesisTask> t = std::make_shared<FeasibilitySynthesisTask>(formulaNoBound);
140 t->setBound(formulas[0]->asOperatorFormula().getBound());
141 std::shared_ptr<FeasibilitySynthesisTask const> feasibilityTask = std::make_shared<FeasibilitySynthesisTask const>(std::move(*t));
142
    // gradientDescent() returns a pair of (instantiation map, ConstantType value).
    // Only the value (.second) is checked here.
143 checker.setup(this->env(), feasibilityTask);
144 typename TestFixture::ConstantType doubleInstantiation = checker.gradientDescent().second;
145 ASSERT_NEAR(doubleInstantiation * 4, 1, 1e-6);
146}
147
// Typed test "Crowds" on the crowds3_5 pDTMC with property P<=0.00000001 [F "observe0Greater1"].
// - Builds, simplifies and bisimulation-minimizes the model.
// - Constructs four checkers (ADAM, RADAM, MOMENTUM, NESTEROV) on the minimized dtmc.
// - Compares the first 41 points of each recorded visualization walk, for parameters "badC" and "PF",
//   against hard-coded reference trajectories (tolerance 1e-4 for ADAM, 1e-5 for the others).
// - Asserts only the ADAM run's returned value (approximately 0).
// NOTE(review): source lines 157, 161, 173, 177, 179, 288, 290, 385, 387, 413 and 415 are missing from this
// extraction. They presumably include the initializer of `formulas`, the declaration of `simplifier`, and the
// start and end of each checker construction. As shown, the block does not compile.
148TYPED_TEST(GradientDescentInstantiationSearcherTest, Crowds) {
149 std::string programFile = STORM_TEST_RESOURCES_DIR "/pdtmc/crowds3_5.pm";
150 std::string formulaAsString = "P<=0.00000001 [F \"observe0Greater1\"]";
151 std::string constantsAsString = ""; // e.g. pL=0.9,TOACK=0.5
152
153 // Program and formula
154 storm::prism::Program program = storm::api::parseProgram(programFile);
155 program = storm::utility::prism::preprocess(program, constantsAsString);
156 std::vector<std::shared_ptr<const storm::logic::Formula>> formulas =
158 std::shared_ptr<storm::models::sparse::Dtmc<storm::RationalFunction>> model =
159 storm::api::buildSparseModel<storm::RationalFunction>(program, formulas)->as<storm::models::sparse::Dtmc<storm::RationalFunction>>();
160 std::shared_ptr<storm::models::sparse::Dtmc<storm::RationalFunction>> dtmc = model->as<storm::models::sparse::Dtmc<storm::RationalFunction>>();
162 ASSERT_TRUE(simplifier.simplify(*formulas[0]));
163 model = simplifier.getSimplifiedModel();
164
165 dtmc = storm::api::performBisimulationMinimization<storm::RationalFunction>(model, formulas)->as<storm::models::sparse::Dtmc<storm::RationalFunction>>();
166
    // Build the feasibility task from a clone of the formula, then copy the original formula's bound onto it.
    // NOTE(review): no visible line strips the bound from the clone, despite the variable names.
167 std::shared_ptr<storm::logic::Formula> formulaWithoutBounds = formulas[0]->clone();
168 std::shared_ptr<storm::logic::Formula const> formulaNoBound = formulaWithoutBounds->asSharedPointer();
169 std::shared_ptr<FeasibilitySynthesisTask> t = std::make_shared<FeasibilitySynthesisTask>(formulaNoBound);
170 t->setBound(formulas[0]->asOperatorFormula().getBound());
171 std::shared_ptr<FeasibilitySynthesisTask const> feasibilityTask = std::make_shared<FeasibilitySynthesisTask const>(std::move(*t));
172
174
175 // First, test an ADAM instance. We will check that we have implemented ADAM correctly by comparing our results to results gathered by an ADAM
176 // implementation in tensorflow :)
178 *dtmc, storm::derivative::GradientDescentMethod::ADAM, 0.01, 0.9, 0.999, 2, 1e-6, std::nullopt,
180 adamChecker.setup(this->env(), feasibilityTask);
181 auto doubleInstantiation = adamChecker.gradientDescent();
182 auto walk = adamChecker.getVisualizationWalk();
183
    // Find the carl variables of the probability parameters named "badC" and "PF".
    // A name that is not found leaves its variable default-constructed.
184 carl::Variable badCVar;
185 carl::Variable pfVar;
186 for (auto parameter : storm::models::sparse::getProbabilityParameters(*dtmc)) {
187 if (parameter.name() == "badC") {
188 badCVar = parameter;
189 } else if (parameter.name() == "PF") {
190 pfVar = parameter;
191 }
192 }
    // NOTE(review): `badC` and `pf` (and hence `cache`) are not referenced anywhere below in this test.
193 std::shared_ptr<storm::RawPolynomialCache> cache = std::make_shared<storm::RawPolynomialCache>();
194 auto const badC = storm::RationalFunction(storm::Polynomial(storm::RawPolynomial(badCVar), cache));
195 auto const pf = storm::RationalFunction(storm::Polynomial(storm::RawPolynomial(pfVar), cache));
196
    // Reference ADAM trajectory (41 points each) for badC and PF.
197 const double badCValues[] = {0.5,
198 0.49000033736228943,
199 0.4799894690513611,
200 0.4699603021144867,
201 0.4599062204360962,
202 0.4498208165168762,
203 0.43969807028770447,
204 0.4295324981212616,
205 0.41931894421577454,
206 0.4090527594089508,
207 0.39872977137565613,
208 0.3883463442325592,
209 0.3778994083404541,
210 0.3673863708972931,
211 0.35680532455444336,
212 0.3461548984050751,
213 0.3354344069957733,
214 0.3246437907218933,
215 0.31378373503685,
216 0.30285564064979553,
217 0.2918616533279419,
218 0.28080472350120544,
219 0.2696886956691742,
220 0.258518248796463,
221 0.24729910492897034,
222 0.23603792488574982,
223 0.22474248707294464,
224 0.21342173218727112,
225 0.20208582282066345,
226 0.1907462775707245,
227 0.17941603064537048,
228 0.16810956597328186,
229 0.15684303641319275,
230 0.14563430845737457,
231 0.13450318574905396,
232 0.1234714537858963,
233 0.11256305128335953,
234 0.10180411487817764,
235 0.0912230908870697,
236 0.080850750207901,
237 0.07072017341852188};
238 const double pfValues[] = {0.5,
239 0.49000218510627747,
240 0.47999337315559387,
241 0.46996763348579407,
242 0.45991936326026917,
243 0.4498434364795685,
244 0.4397353231906891,
245 0.4295912981033325,
246 0.4194081425666809,
247 0.4091835021972656,
248 0.3989158570766449,
249 0.3886045515537262,
250 0.3782498240470886,
251 0.3678528368473053,
252 0.3574157953262329,
253 0.3469419777393341,
254 0.33643582463264465,
255 0.32590293884277344,
256 0.3153502345085144,
257 0.30478596687316895,
258 0.2942197024822235,
259 0.28366249799728394,
260 0.2731269598007202,
261 0.26262718439102173,
262 0.25217896699905396,
263 0.24179960787296295,
264 0.2315080612897873,
265 0.2213248461484909,
266 0.2112719565629959,
267 0.20137275755405426,
268 0.1916518211364746,
269 0.18213464319705963,
270 0.17284737527370453,
271 0.1638164073228836,
272 0.15506798028945923,
273 0.14662761986255646,
274 0.13851958513259888,
275 0.13076619803905487,
276 0.12338729947805405,
277 0.11639954894781113,
278 0.10981585085391998};
279
    // NOTE(review): each walk below is indexed 0..40 with operator[] and no size check.
    // A walk with fewer than 41 points would be read out of bounds.
280 for (uint_fast64_t i = 0; i < 41; i++) {
281 ASSERT_NEAR(storm::utility::convertNumber<double>(walk[i].position[badCVar]), badCValues[i], 1e-4);
282 ASSERT_NEAR(storm::utility::convertNumber<double>(walk[i].position[pfVar]), pfValues[i], 1e-4);
283 }
284
    // Only the ADAM run's returned value is asserted. radamInstantiation, momentumInstantiation and
    // nesterovInstantiation below are computed but never checked.
285 ASSERT_NEAR(storm::utility::convertNumber<double>(doubleInstantiation.second), 0, 1e-6);
286
287 // Same thing with RAdam
289 *dtmc, storm::derivative::GradientDescentMethod::RADAM, 0.01, 0.9, 0.999, 2, 1e-6, std::nullopt,
291 radamChecker.setup(this->env(), feasibilityTask);
292 auto radamInstantiation = radamChecker.gradientDescent();
293 auto radamWalk = radamChecker.getVisualizationWalk();
294
295 const double badCValuesRadam[] = {0.5,
296 0.49060654640197754,
297 0.48096320033073425,
298 0.47105303406715393,
299 0.4608582556247711,
300 0.46068474650382996,
301 0.4604234993457794,
302 0.460092693567276,
303 0.4597006142139435,
304 0.4592539370059967,
305 0.4587569832801819,
306 0.45821380615234375,
307 0.4576275646686554,
308 0.4570005536079407,
309 0.45633500814437866,
310 0.4556325674057007,
311 0.45489493012428284,
312 0.454123318195343,
313 0.4533190429210663,
314 0.45248323678970337,
315 0.45161670446395874,
316 0.4507202208042145,
317 0.44979462027549744,
318 0.448840469121933,
319 0.44785845279693604,
320 0.4468490481376648,
321 0.4458127021789551,
322 0.4447498917579651,
323 0.44366100430488586,
324 0.4425463378429413,
325 0.4414062798023224,
326 0.4402410686016083,
327 0.4390510618686676,
328 0.43783634901046753,
329 0.43659722805023193,
330 0.43533384799957275,
331 0.4340464472770691,
332 0.4327350854873657,
333 0.43139997124671936,
334 0.43004119396209717,
335 0.4286588430404663};
336
337 const double pfValuesRadam[] = {0.5,
338 0.4985547959804535,
339 0.4970662295818329,
340 0.4955315589904785,
341 0.4939480423927307,
342 0.4937744438648224,
343 0.4935130178928375,
344 0.4931819438934326,
345 0.49278953671455383,
346 0.492342472076416,
347 0.49184510111808777,
348 0.4913014769554138,
349 0.49071478843688965,
350 0.4900873303413391,
351 0.48942139744758606,
352 0.4887186288833618,
353 0.48798075318336487,
354 0.4872090220451355,
355 0.48640477657318115,
356 0.48556917905807495,
357 0.48470309376716614,
358 0.4838073253631592,
359 0.48288270831108093,
360 0.48192986845970154,
361 0.4809495508670807,
362 0.4799422323703766,
363 0.47890838980674744,
364 0.4778485894203186,
365 0.4767632484436035,
366 0.4756527245044708,
367 0.47451743483543396,
368 0.4733576774597168,
369 0.4721738398075104,
370 0.470966100692749,
371 0.4697347581386566,
372 0.46848005056381226,
373 0.4672022759914398,
374 0.46590155363082886,
375 0.46457815170288086,
376 0.46323221921920776,
377 0.4618639349937439};
378
379 for (uint_fast64_t i = 0; i < 41; i++) {
380 ASSERT_NEAR(storm::utility::convertNumber<double>(radamWalk[i].position[badCVar]), badCValuesRadam[i], 1e-5);
381 ASSERT_NEAR(storm::utility::convertNumber<double>(radamWalk[i].position[pfVar]), pfValuesRadam[i], 1e-5);
382 }
383
    // The MOMENTUM and NESTEROV runs below pass 0.001 where the ADAM/RADAM runs pass 0.01 (presumably
    // the learning rate). Their reference trajectories start at 0.5 + 1e-6 rather than 0.5.
384 // Same thing with momentum
386 *dtmc, storm::derivative::GradientDescentMethod::MOMENTUM, 0.001, 0.9, 0.999, 2, 1e-6, std::nullopt,
388 momentumChecker.setup(this->env(), feasibilityTask);
389 auto momentumInstantiation = momentumChecker.gradientDescent();
390 auto momentumWalk = momentumChecker.getVisualizationWalk();
391
392 const double badCValuesMomentum[] = {
393 0.5 + 1e-6, 0.4990617036819458, 0.4972723126411438, 0.4947088360786438, 0.4914357662200928, 0.4875074326992035, 0.48296919465065,
394 0.4778585731983185, 0.47220611572265625, 0.4660361409187317, 0.45936742424964905, 0.45221376419067383, 0.44458451867103577, 0.43648505210876465,
395 0.42791709303855896, 0.41887906193733215, 0.4093664884567261, 0.3993722200393677, 0.3888867497444153, 0.37789851427078247, 0.3663942813873291,
396 0.3543594181537628, 0.34177839756011963, 0.32863524556159973, 0.3149142265319824, 0.30060049891471863, 0.28568127751350403, 0.27014678716659546,
397 0.253991961479187, 0.23721818625926971, 0.2198355346918106, 0.201865553855896, 0.18334446847438812, 0.16432702541351318, 0.1448907107114792,
398 0.12514044344425201, 0.10521329939365387, 0.08528289198875427, 0.06556269526481628, 0.04630732536315918, 0.027810536324977875};
399 const double pfValuesMomentum[] = {
400 0.5 + 1e-6, 0.49985647201538086, 0.49958109855651855, 0.4991863965988159, 0.49868202209472656, 0.49807605147361755, 0.49737516045570374,
401 0.49658480286598206, 0.4957093894481659, 0.49475228786468506, 0.4937160611152649, 0.49260255694389343, 0.4914129376411438, 0.49014779925346375,
402 0.4888072907924652, 0.4873911142349243, 0.4858987331390381, 0.4843292832374573, 0.4826817214488983, 0.4809550344944, 0.47914814949035645,
403 0.47726020216941833, 0.47529059648513794, 0.4732392132282257, 0.4711065888404846, 0.4688941240310669, 0.4666043817996979, 0.46424129605293274,
404 0.46181055903434753, 0.45931991934776306, 0.4567795991897583, 0.4542025327682495, 0.4516047239303589, 0.4490053057670593, 0.44642651081085205,
405 0.4438934028148651, 0.44143298268318176, 0.43907299637794495, 0.43684011697769165, 0.43475744128227234, 0.43284183740615845};
406
407 for (uint_fast64_t i = 0; i < 41; i++) {
408 ASSERT_NEAR(storm::utility::convertNumber<double>(momentumWalk[i].position[badCVar]), badCValuesMomentum[i], 1e-5);
409 ASSERT_NEAR(storm::utility::convertNumber<double>(momentumWalk[i].position[pfVar]), pfValuesMomentum[i], 1e-5);
410 }
411
412 // Same thing with nesterov
414 *dtmc, storm::derivative::GradientDescentMethod::NESTEROV, 0.001, 0.9, 0.999, 2, 1e-6, std::nullopt,
416 nesterovChecker.setup(this->env(), feasibilityTask);
417 auto nesterovInstantiation = nesterovChecker.gradientDescent();
418 auto nesterovWalk = nesterovChecker.getVisualizationWalk();
419
420 const double badCValuesNesterov[] = {
421 0.5 + 1e-6, 0.49821633100509644, 0.49565380811691284, 0.4923747181892395, 0.48843076825141907, 0.4838651120662689, 0.4787132740020752,
422 0.473004013299942, 0.4667600393295288, 0.45999863743782043, 0.452732115983963, 0.44496846199035645, 0.43671154975891113, 0.4279615581035614,
423 0.4187155067920685, 0.4089673161506653, 0.3987082839012146, 0.387927383184433, 0.376611590385437, 0.36474621295928955, 0.35231542587280273,
424 0.33930283784866333, 0.3256921172142029, 0.3114677667617798, 0.2966162860393524, 0.2811274528503418, 0.2649959325790405, 0.24822339415550232,
425 0.23082087934017181, 0.21281197667121887, 0.19423632323741913, 0.1751537024974823, 0.15564869344234467, 0.1358354091644287, 0.11586219072341919,
426 0.09591540694236755, 0.07622162997722626, 0.05704677850008011, 0.03869114816188812, 0.021478772163391113, 0.005740571767091751};
427 const double pfValuesNesterov[] = {
428 0.5 + 1e-6, 0.49972638487815857, 0.4993318021297455, 0.4988263249397278, 0.4982176721096039, 0.4975121319293976, 0.4967147409915924,
429 0.4958295524120331, 0.4948597252368927, 0.4938075542449951, 0.49267470836639404, 0.49146217107772827, 0.49017035961151123, 0.48879921436309814,
430 0.48734840750694275, 0.48581716418266296, 0.48420456051826477, 0.4825096130371094, 0.48073118925094604, 0.4788684844970703, 0.47692081332206726,
431 0.4748881161212921, 0.4727708697319031, 0.47057047486305237, 0.4682896137237549, 0.4659323990345001, 0.4635048806667328, 0.4610152840614319,
432 0.4584745764732361, 0.45589667558670044, 0.45329880714416504, 0.45070162415504456, 0.4481291174888611, 0.44560813903808594, 0.44316741824150085,
433 0.440836101770401, 0.43864157795906067, 0.4366067945957184, 0.4347473978996277, 0.433069109916687, 0.43156614899635315};
434
435 for (uint_fast64_t i = 0; i < 41; i++) {
436 ASSERT_NEAR(storm::utility::convertNumber<double>(nesterovWalk[i].position[badCVar]), badCValuesNesterov[i], 1e-5);
437 ASSERT_NEAR(storm::utility::convertNumber<double>(nesterovWalk[i].position[pfVar]), pfValuesNesterov[i], 1e-5);
438 }
439}
TYPED_TEST_SUITE(GradientDescentInstantiationSearcherTest, TestingTypes,)
TYPED_TEST(GradientDescentInstantiationSearcherTest, Simple)
SolverEnvironment & solver()
TopologicalSolverEnvironment & topological()
void setLinearEquationSolverType(storm::solver::EquationSolverType const &value, bool isSetFromDefault=false)
void setUnderlyingEquationSolverType(storm::solver::EquationSolverType value)
void setup(Environment const &env, std::shared_ptr< storm::pars::FeasibilitySynthesisTask const > const &task)
This will setup the matrices used for computing the derivatives by constructing the SparseDerivativeI...
std::vector< VisualizationPoint > getVisualizationWalk()
Get the visualization walk that is recorded if recordRun is set to true in the constructor (false by ...
std::pair< std::map< typename utility::parametric::VariableType< FunctionType >::type, typename utility::parametric::CoefficientType< FunctionType >::type >, ConstantType > gradientDescent()
Perform Gradient Descent.
std::shared_ptr< ModelType > as()
Casts the model into the model type given by the template parameter.
Definition ModelBase.h:37
This class represents a discrete-time Markov chain.
Definition Dtmc.h:13
This class performs different steps to simplify the given (parametric) model.
std::vector< storm::jani::Property > parsePropertiesForPrismProgram(std::string const &inputString, storm::prism::Program const &program, boost::optional< std::set< std::string > > const &propertyFilter)
storm::prism::Program parseProgram(std::string const &filename, bool prismCompatibility, bool simplify)
std::vector< std::shared_ptr< storm::logic::Formula const > > extractFormulasFromProperties(std::vector< storm::jani::Property > const &properties)
std::set< storm::RationalFunctionVariable > getProbabilityParameters(Model< storm::RationalFunction > const &model)
Get all probability parameters occurring on transitions.
Definition Model.cpp:687
storm::prism::Program preprocess(storm::prism::Program const &program, std::map< storm::expressions::Variable, storm::expressions::Expression > const &constantDefinitions)
Definition prism.cpp:19
carl::RationalFunction< Polynomial, true > RationalFunction
carl::MultivariatePolynomial< RationalFunctionCoefficient > RawPolynomial
::testing::Types< Cudd, Sylvan > TestingTypes
Definition GraphTest.cpp:46