Storm
A Modern Probabilistic Model Checker
Loading...
Searching...
No Matches
GradientDescentInstantiationSearcher.h
Go to the documentation of this file.
1#pragma once
2
3#include <map>
4#include <memory>
13#include "storm/logic/Formula.h"
15
19
20namespace storm {
21namespace derivative {
22template<typename FunctionType, typename ConstantType>
24 public:
43 storm::models::sparse::Dtmc<FunctionType> const& model, GradientDescentMethod method = GradientDescentMethod::ADAM, ConstantType learningRate = 0.1,
44 ConstantType averageDecay = 0.9, ConstantType squaredAverageDecay = 0.999, uint_fast64_t miniBatchSize = 32, ConstantType terminationEpsilon = 1e-6,
45 boost::optional<
47 startPoint = boost::none,
49 : model(model),
50 derivativeEvaluationHelper(std::make_unique<SparseDerivativeInstantiationModelChecker<FunctionType, ConstantType>>(model)),
51 instantiationModelChecker(
52 std::make_unique<modelchecker::SparseDtmcInstantiationModelChecker<models::sparse::Dtmc<FunctionType>, ConstantType>>(model)),
53 startPoint(startPoint),
54 miniBatchSize(miniBatchSize),
55 terminationEpsilon(terminationEpsilon),
56 constraintMethod(constraintMethod),
57 recordRun(recordRun) {
58 // TODO: should the method-specific initialization in the switch below be moved into subclasses?
59 switch (method) {
61 Adam adam;
62 adam.learningRate = learningRate;
63 adam.averageDecay = averageDecay;
64 adam.averageDecay = averageDecay;
65 adam.squaredAverageDecay = squaredAverageDecay;
66 gradientDescentType = adam;
67 break;
68 }
70 RAdam radam;
71 radam.learningRate = learningRate;
72 radam.averageDecay = averageDecay;
73 radam.squaredAverageDecay = squaredAverageDecay;
74 gradientDescentType = radam;
75 break;
76 }
78 RmsProp rmsProp;
79 rmsProp.learningRate = learningRate;
80 rmsProp.averageDecay = averageDecay;
81 gradientDescentType = rmsProp;
82 break;
83 }
86 Plain plain;
87 plain.learningRate = learningRate;
88 gradientDescentType = plain;
90 useSignsOnly = true;
91 } else {
92 useSignsOnly = false;
93 }
94 break;
95 }
98 Momentum momentum;
99 momentum.learningRate = learningRate;
100 // TODO Document this: the momentum term is taken from the constructor's averageDecay argument.
101 momentum.momentumTerm = averageDecay;
102 gradientDescentType = momentum;
104 useSignsOnly = true;
105 } else {
106 useSignsOnly = false;
107 }
108 break;
109 }
112 Nesterov nesterov;
113 nesterov.learningRate = learningRate;
114 // TODO Document this: the momentum term is taken from the constructor's averageDecay argument.
115 nesterov.momentumTerm = averageDecay;
116 gradientDescentType = nesterov;
118 useSignsOnly = true;
119 } else {
120 useSignsOnly = false;
121 }
122 break;
123 }
124 }
125 }
126
/**
 * Prepares the searcher for a feasibility synthesis task.
 *
 * Stores the environment and the task, collects the model parameters to optimize over, derives a
 * bound-free copy of the task's formula and hands the resulting check task to both the
 * instantiation model checker and the derivative evaluation helper.
 *
 * @param env  Environment; copied into the `env` member and forwarded to
 *             `derivativeEvaluationHelper->specifyFormula`.
 * @param task Task whose formula is used. The formula is asserted to be a probability operator
 *             or reward operator formula. The pointer is dereferenced without a null check.
 */
135 void setup(Environment const& env, std::shared_ptr<storm::pars::FeasibilitySynthesisTask const> const& task) {
136 this->env = env;
// Start with the parameters that occur in transition probabilities.
137 this->parameters = storm::models::sparse::getProbabilityParameters(model);
138 this->synthesisTask = task;
139 STORM_LOG_ASSERT(task->getFormula().isProbabilityOperatorFormula() || task->getFormula().isRewardOperatorFormula(),
140 "Formula must be either a reward or a probability operator formula");
141
// Work on a clone so the task's own formula keeps its bound; only the clone has it removed.
142 std::shared_ptr<storm::logic::Formula> formulaWithoutBounds = task->getFormula().clone();
143 formulaWithoutBounds->asOperatorFormula().removeBound();
144 this->currentFormulaNoBound = formulaWithoutBounds->asSharedPointer();
145
// For reward formulas, parameters occurring in rewards are added to the parameter set as well.
146 if (task->getFormula().isRewardOperatorFormula()) {
147 auto rewardParameters = storm::models::sparse::getRewardParameters(model);
148 this->parameters.insert(rewardParameters.begin(), rewardParameters.end());
149 }
150
// Two check tasks are built from the same bound-free formula: one over FunctionType and one over ConstantType.
151 this->currentCheckTaskNoBound = std::make_unique<storm::modelchecker::CheckTask<storm::logic::Formula, FunctionType>>(*currentFormulaNoBound);
152 this->currentCheckTaskNoBoundConstantType =
153 std::make_unique<storm::modelchecker::CheckTask<storm::logic::Formula, ConstantType>>(*currentFormulaNoBound);
154
// Only the FunctionType check task is passed on here; the ConstantType one is merely stored.
155 instantiationModelChecker->specifyFormula(*this->currentCheckTaskNoBound);
156 derivativeEvaluationHelper->specifyFormula(env, *this->currentCheckTaskNoBound);
157 }
158
162 std::pair<std::map<typename utility::parametric::VariableType<FunctionType>::type, typename utility::parametric::CoefficientType<FunctionType>::type>,
163 ConstantType>
165
169 void printRunAsJson();
170
175 std::map<typename utility::parametric::VariableType<FunctionType>::type, typename utility::parametric::CoefficientType<FunctionType>::type> position;
176 ConstantType value;
177 };
181 std::vector<VisualizationPoint> getVisualizationWalk();
182
183 private:
184 void resetDynamicValues();
185
186 Environment env;
187 std::shared_ptr<storm::pars::FeasibilitySynthesisTask const> synthesisTask;
188 std::unique_ptr<modelchecker::CheckTask<storm::logic::Formula, FunctionType>> currentCheckTaskNoBound;
189 std::unique_ptr<modelchecker::CheckTask<storm::logic::Formula, ConstantType>> currentCheckTaskNoBoundConstantType;
190 std::shared_ptr<storm::logic::Formula const> currentFormulaNoBound;
191
193 std::set<typename utility::parametric::VariableType<FunctionType>::type> parameters;
194 const std::unique_ptr<storm::derivative::SparseDerivativeInstantiationModelChecker<FunctionType, ConstantType>> derivativeEvaluationHelper;
195 std::unique_ptr<storm::analysis::MonotonicityHelper<FunctionType, ConstantType>> monotonicityHelper;
196 const std::unique_ptr<modelchecker::SparseDtmcInstantiationModelChecker<models::sparse::Dtmc<FunctionType>, ConstantType>> instantiationModelChecker;
197 boost::optional<std::map<typename utility::parametric::VariableType<FunctionType>::type, typename utility::parametric::CoefficientType<FunctionType>::type>>
198 startPoint;
199 const uint_fast64_t miniBatchSize;
200 const ConstantType terminationEpsilon;
201 const GradientDescentConstraintMethod constraintMethod;
202
203 // This is for visualizing data
204 const bool recordRun;
205 std::vector<VisualizationPoint> walk;
206
207 // Gradient Descent types and data that belongs to them, with hyperparameters and running data.
// Data for the Adam gradient descent type.
// learningRate, averageDecay and squaredAverageDecay are filled in by the constructor from its
// arguments of the same names.
208 struct Adam {
209 ConstantType averageDecay;
210 ConstantType squaredAverageDecay;
211 ConstantType learningRate;
// Per-parameter values, keyed by the model's parameter variable. Presumably the running (decaying)
// averages of the squared step and of the step - TODO confirm against the implementation file.
212 std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType> decayingStepAverageSquared;
213 std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType> decayingStepAverage;
214 };
// Data for the RAdam gradient descent type. It has the same fields as Adam.
// learningRate, averageDecay and squaredAverageDecay are filled in by the constructor from its
// arguments of the same names.
215 struct RAdam {
216 ConstantType averageDecay;
217 ConstantType squaredAverageDecay;
218 ConstantType learningRate;
// Per-parameter values, keyed by the model's parameter variable. Presumably the running (decaying)
// averages of the squared step and of the step - TODO confirm against the implementation file.
219 std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType> decayingStepAverageSquared;
220 std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType> decayingStepAverage;
221 };
// Data for the RmsProp gradient descent type.
// averageDecay and learningRate are filled in by the constructor from its arguments of the same names.
222 struct RmsProp {
223 ConstantType averageDecay;
224 ConstantType learningRate;
// Per-parameter value, keyed by the model's parameter variable. Presumably the running root mean
// square of past gradients - TODO confirm against the implementation file.
225 std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType> rootMeanSquare;
226 };
// Data for the plain gradient descent type: only a learning rate, which the constructor sets from
// its learningRate argument. No per-parameter state is kept.
227 struct Plain {
228 ConstantType learningRate;
229 };
// Data for the momentum gradient descent type.
// The constructor sets learningRate from its learningRate argument and momentumTerm from its
// averageDecay argument.
230 struct Momentum {
231 ConstantType learningRate;
232 ConstantType momentumTerm;
// Per-parameter value, keyed by the model's parameter variable. Presumably the previous step taken
// for that parameter - TODO confirm against the implementation file.
233 std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType> pastStep;
234 };
// Data for the Nesterov gradient descent type. It has the same fields as Momentum.
// The constructor sets learningRate from its learningRate argument and momentumTerm from its
// averageDecay argument.
235 struct Nesterov {
236 ConstantType learningRate;
237 ConstantType momentumTerm;
// Per-parameter value, keyed by the model's parameter variable. Presumably the previous step taken
// for that parameter - TODO confirm against the implementation file.
238 std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType> pastStep;
239 };
240 typedef boost::variant<Adam, RAdam, RmsProp, Plain, Momentum, Nesterov> GradientDescentType;
241 GradientDescentType gradientDescentType;
242 // Only respected by some Gradient Descent methods, the ones that have a "sign" version in the GradientDescentMethod enum
243 bool useSignsOnly;
244
245 ConstantType logarithmicBarrierTerm;
246
247 ConstantType stochasticGradientDescent(
249 ConstantType doStep(
252 const std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType>& gradient, uint_fast64_t stepNum);
/**
 * Square root for ConstantType.
 *
 * Calls utility::sqrt when ConstantType is double and carl::sqrt for every other ConstantType.
 *
 * @param input Value to take the square root of.
 * @return The result of the selected sqrt call, converted to ConstantType.
 */
253 ConstantType constantTypeSqrt(ConstantType input) {
// NOTE(review): this is an ordinary `if` on a compile-time constant, so both branches are
// instantiated and both sqrt calls must compile for every ConstantType in use. `if constexpr`
// would discard the unused branch - consider it if the project's language standard allows.
254 if (std::is_same<ConstantType, double>::value) {
255 return utility::sqrt(input);
256 } else {
257 return carl::sqrt(input);
258 }
259 }
260
261 utility::Stopwatch stochasticWatch;
262 utility::Stopwatch batchWatch;
263 utility::Stopwatch startingPointCalculationWatch;
264};
265} // namespace derivative
266} // namespace storm
GradientDescentInstantiationSearcher(storm::models::sparse::Dtmc< FunctionType > const &model, GradientDescentMethod method=GradientDescentMethod::ADAM, ConstantType learningRate=0.1, ConstantType averageDecay=0.9, ConstantType squaredAverageDecay=0.999, uint_fast64_t miniBatchSize=32, ConstantType terminationEpsilon=1e-6, boost::optional< std::map< typename utility::parametric::VariableType< FunctionType >::type, typename utility::parametric::CoefficientType< FunctionType >::type > > startPoint=boost::none, GradientDescentConstraintMethod constraintMethod=GradientDescentConstraintMethod::PROJECT_WITH_GRADIENT, bool recordRun=false)
The GradientDescentInstantiationSearcher can find extrema and feasible instantiations in pMCs,...
void setup(Environment const &env, std::shared_ptr< storm::pars::FeasibilitySynthesisTask const > const &task)
This will setup the matrices used for computing the derivatives by constructing the SparseDerivativeI...
std::vector< VisualizationPoint > getVisualizationWalk()
Get the visualization walk that is recorded if recordRun is set to true in the constructor (false by ...
std::pair< std::map< typename utility::parametric::VariableType< FunctionType >::type, typename utility::parametric::CoefficientType< FunctionType >::type >, ConstantType > gradientDescent()
Perform Gradient Descent.
This class represents a discrete-time Markov chain.
Definition Dtmc.h:14
#define STORM_LOG_ASSERT(cond, message)
Definition macros.h:11
GradientDescentMethod
GradientDescentMethod is the method of Gradient Descent the GradientDescentInstantiationSearcher shal...
GradientDescentConstraintMethod
GradientDescentConstraintMethod is the method for mitigating constraints that the GradientDescentInst...
std::set< storm::RationalFunctionVariable > getRewardParameters(Model< storm::RationalFunction > const &model)
Get all parameters occurring in rewards.
Definition Model.cpp:707
std::set< storm::RationalFunctionVariable > getProbabilityParameters(Model< storm::RationalFunction > const &model)
Get all probability parameters occurring on transitions.
Definition Model.cpp:703
ValueType sqrt(ValueType const &number)
LabParser.cpp.
Definition cli.cpp:18
A point in the Gradient Descent walk, recorded if recordRun is set to true in the constructor (false ...
std::map< typename utility::parametric::VariableType< FunctionType >::type, typename utility::parametric::CoefficientType< FunctionType >::type > position