Storm 1.11.1.1
A Modern Probabilistic Model Checker
Loading...
Searching...
No Matches
GradientDescentInstantiationSearcher.h
Go to the documentation of this file.
1#pragma once
2
3#include <map>
4#include <memory>
12#include "storm/logic/Formula.h"
14
18
19namespace storm {
20namespace derivative {
21template<typename FunctionType, typename ConstantType>
23 public:
44 storm::models::sparse::Dtmc<FunctionType> const& model, GradientDescentMethod method = GradientDescentMethod::ADAM, ConstantType learningRate = 0.1,
45 ConstantType averageDecay = 0.9, ConstantType squaredAverageDecay = 0.999, uint_fast64_t miniBatchSize = 32, ConstantType terminationEpsilon = 1e-6,
46 std::optional<
48 startPoint = std::nullopt,
50 std::optional<storage::ParameterRegion<FunctionType>> region = std::nullopt, bool recordRun = false)
51 : model(model),
52 derivativeEvaluationHelper(std::make_unique<SparseDerivativeInstantiationModelChecker<FunctionType, ConstantType>>(model)),
53 instantiationModelChecker(
54 std::make_unique<modelchecker::SparseDtmcInstantiationModelChecker<models::sparse::Dtmc<FunctionType>, ConstantType>>(model)),
55 startPoint(startPoint),
56 miniBatchSize(miniBatchSize),
57 terminationEpsilon(terminationEpsilon),
58 constraintMethod(constraintMethod),
59 region(region),
60 recordRun(recordRun) {
61 STORM_LOG_ERROR_COND(region == std::nullopt || (constraintMethod == GradientDescentConstraintMethod::PROJECT ||
63 "Specifying a region is only supported if you are constraining by projection.");
64 // TODO should we put this in subclasses?
65 switch (method) {
67 Adam adam;
68 adam.learningRate = learningRate;
69 adam.averageDecay = averageDecay;
70 adam.averageDecay = averageDecay;
71 adam.squaredAverageDecay = squaredAverageDecay;
72 gradientDescentType = adam;
73 break;
74 }
76 RAdam radam;
77 radam.learningRate = learningRate;
78 radam.averageDecay = averageDecay;
79 radam.squaredAverageDecay = squaredAverageDecay;
80 gradientDescentType = radam;
81 break;
82 }
84 RmsProp rmsProp;
85 rmsProp.learningRate = learningRate;
86 rmsProp.averageDecay = averageDecay;
87 gradientDescentType = rmsProp;
88 break;
89 }
92 Plain plain;
93 plain.learningRate = learningRate;
94 gradientDescentType = plain;
96 useSignsOnly = true;
97 } else {
98 useSignsOnly = false;
99 }
100 break;
101 }
104 Momentum momentum;
105 momentum.learningRate = learningRate;
106 // TODO Document this
107 momentum.momentumTerm = averageDecay;
108 gradientDescentType = momentum;
110 useSignsOnly = true;
111 } else {
112 useSignsOnly = false;
113 }
114 break;
115 }
118 Nesterov nesterov;
119 nesterov.learningRate = learningRate;
120 // TODO Document this
121 nesterov.momentumTerm = averageDecay;
122 gradientDescentType = nesterov;
124 useSignsOnly = true;
125 } else {
126 useSignsOnly = false;
127 }
128 break;
129 }
130 }
131 }
132
141 void setup(Environment const& env, std::shared_ptr<storm::pars::FeasibilitySynthesisTask const> const& task) {
142 this->env = env;
143 this->parameters = storm::models::sparse::getProbabilityParameters(model);
144 this->synthesisTask = task;
145 STORM_LOG_ASSERT(task->getFormula().isProbabilityOperatorFormula() || task->getFormula().isRewardOperatorFormula(),
146 "Formula must be either a reward or a probability operator formula");
147
148 std::shared_ptr<storm::logic::Formula> formulaWithoutBounds = task->getFormula().clone();
149 formulaWithoutBounds->asOperatorFormula().removeBound();
150 this->currentFormulaNoBound = formulaWithoutBounds->asSharedPointer();
151
152 if (task->getFormula().isRewardOperatorFormula()) {
153 auto rewardParameters = storm::models::sparse::getRewardParameters(model);
154 this->parameters.insert(rewardParameters.begin(), rewardParameters.end());
155 }
156
157 this->currentCheckTaskNoBound = std::make_unique<storm::modelchecker::CheckTask<storm::logic::Formula, FunctionType>>(*currentFormulaNoBound);
158 this->currentCheckTaskNoBoundConstantType =
159 std::make_unique<storm::modelchecker::CheckTask<storm::logic::Formula, ConstantType>>(*currentFormulaNoBound);
160
161 instantiationModelChecker->specifyFormula(*this->currentCheckTaskNoBound);
162 derivativeEvaluationHelper->specifyFormula(env, *this->currentCheckTaskNoBound);
163 }
164
168 std::pair<std::map<typename utility::parametric::VariableType<FunctionType>::type, typename utility::parametric::CoefficientType<FunctionType>::type>,
169 ConstantType>
171
175 void printRunAsJson();
176
181 std::map<typename utility::parametric::VariableType<FunctionType>::type, typename utility::parametric::CoefficientType<FunctionType>::type> position;
182 ConstantType value;
183 };
187 std::vector<VisualizationPoint> getVisualizationWalk();
188
189 private:
190 void resetDynamicValues();
191
192 Environment env;
193 std::shared_ptr<storm::pars::FeasibilitySynthesisTask const> synthesisTask;
194 std::unique_ptr<modelchecker::CheckTask<storm::logic::Formula, FunctionType>> currentCheckTaskNoBound;
195 std::unique_ptr<modelchecker::CheckTask<storm::logic::Formula, ConstantType>> currentCheckTaskNoBoundConstantType;
196 std::shared_ptr<storm::logic::Formula const> currentFormulaNoBound;
197
199 std::set<typename utility::parametric::VariableType<FunctionType>::type> parameters;
200 const std::unique_ptr<storm::derivative::SparseDerivativeInstantiationModelChecker<FunctionType, ConstantType>> derivativeEvaluationHelper;
201 std::unique_ptr<storm::analysis::MonotonicityHelper<FunctionType, ConstantType>> monotonicityHelper;
202 const std::unique_ptr<modelchecker::SparseDtmcInstantiationModelChecker<models::sparse::Dtmc<FunctionType>, ConstantType>> instantiationModelChecker;
203 std::optional<std::map<typename utility::parametric::VariableType<FunctionType>::type, typename utility::parametric::CoefficientType<FunctionType>::type>>
204 startPoint;
205 const uint_fast64_t miniBatchSize;
206 const ConstantType terminationEpsilon;
207 const GradientDescentConstraintMethod constraintMethod;
208 const std::optional<storage::ParameterRegion<FunctionType>> region;
209
210 // This is for visualizing data
211 const bool recordRun;
212 std::vector<VisualizationPoint> walk;
213
214 // Gradient Descent types and data that belongs to them, with hyperparameters and running data.
215 struct Adam {
216 ConstantType averageDecay;
217 ConstantType squaredAverageDecay;
218 ConstantType learningRate;
219 std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType> decayingStepAverageSquared;
220 std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType> decayingStepAverage;
221 };
222 struct RAdam {
223 ConstantType averageDecay;
224 ConstantType squaredAverageDecay;
225 ConstantType learningRate;
226 std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType> decayingStepAverageSquared;
227 std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType> decayingStepAverage;
228 };
229 struct RmsProp {
230 ConstantType averageDecay;
231 ConstantType learningRate;
232 std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType> rootMeanSquare;
233 };
234 struct Plain {
235 ConstantType learningRate;
236 };
237 struct Momentum {
238 ConstantType learningRate;
239 ConstantType momentumTerm;
240 std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType> pastStep;
241 };
242 struct Nesterov {
243 ConstantType learningRate;
244 ConstantType momentumTerm;
245 std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType> pastStep;
246 };
247 typedef boost::variant<Adam, RAdam, RmsProp, Plain, Momentum, Nesterov> GradientDescentType;
248 GradientDescentType gradientDescentType;
249 // Only respected by some Gradient Descent methods, the ones that have a "sign" version in the GradientDescentMethod enum
250 bool useSignsOnly;
251
252 ConstantType logarithmicBarrierTerm;
253
254 ConstantType stochasticGradientDescent(
256 ConstantType doStep(
259 const std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType>& gradient, uint_fast64_t stepNum);
260 ConstantType constantTypeSqrt(ConstantType input) {
261 if (std::is_same<ConstantType, double>::value) {
262 return utility::sqrt(input);
263 } else {
264 return carl::sqrt(input);
265 }
266 }
267
268 utility::Stopwatch stochasticWatch;
269 utility::Stopwatch batchWatch;
270 utility::Stopwatch startingPointCalculationWatch;
271};
272} // namespace derivative
273} // namespace storm
void setup(Environment const &env, std::shared_ptr< storm::pars::FeasibilitySynthesisTask const > const &task)
This will set up the matrices used for computing the derivatives by constructing the SparseDerivativeI...
GradientDescentInstantiationSearcher(storm::models::sparse::Dtmc< FunctionType > const &model, GradientDescentMethod method=GradientDescentMethod::ADAM, ConstantType learningRate=0.1, ConstantType averageDecay=0.9, ConstantType squaredAverageDecay=0.999, uint_fast64_t miniBatchSize=32, ConstantType terminationEpsilon=1e-6, std::optional< std::map< typename utility::parametric::VariableType< FunctionType >::type, typename utility::parametric::CoefficientType< FunctionType >::type > > startPoint=std::nullopt, GradientDescentConstraintMethod constraintMethod=GradientDescentConstraintMethod::PROJECT_WITH_GRADIENT, std::optional< storage::ParameterRegion< FunctionType > > region=std::nullopt, bool recordRun=false)
The GradientDescentInstantiationSearcher can find extrema and feasible instantiations in pMCs,...
std::vector< VisualizationPoint > getVisualizationWalk()
Get the visualization walk that is recorded if recordRun is set to true in the constructor (false by ...
std::pair< std::map< typename utility::parametric::VariableType< FunctionType >::type, typename utility::parametric::CoefficientType< FunctionType >::type >, ConstantType > gradientDescent()
Perform Gradient Descent.
This class represents a discrete-time Markov chain.
Definition Dtmc.h:14
#define STORM_LOG_ASSERT(cond, message)
Definition macros.h:11
#define STORM_LOG_ERROR_COND(cond, message)
Definition macros.h:52
GradientDescentMethod
GradientDescentMethod is the method of Gradient Descent the GradientDescentInstantiationSearcher shal...
GradientDescentConstraintMethod
GradientDescentConstraintMethod is the method for mitigating constraints that the GradientDescentInst...
std::set< storm::RationalFunctionVariable > getRewardParameters(Model< storm::RationalFunction > const &model)
Get all parameters occurring in rewards.
Definition Model.cpp:707
std::set< storm::RationalFunctionVariable > getProbabilityParameters(Model< storm::RationalFunction > const &model)
Get all probability parameters occurring on transitions.
Definition Model.cpp:703
ValueType sqrt(ValueType const &number)
LabParser.cpp.
A point in the Gradient Descent walk, recorded if recordRun is set to true in the constructor (false ...
std::map< typename utility::parametric::VariableType< FunctionType >::type, typename utility::parametric::CoefficientType< FunctionType >::type > position