Storm 1.11.1.1
A Modern Probabilistic Model Checker
Loading...
Searching...
No Matches
GradientDescentInstantiationSearcher.h
Go to the documentation of this file.
1#pragma once
2
3#include <map>
4#include <memory>
5
13#include "storm/logic/Formula.h"
17
18namespace storm {
19namespace derivative {
20template<typename FunctionType, typename ConstantType>
22 public:
43 storm::models::sparse::Dtmc<FunctionType> const& model, GradientDescentMethod method = GradientDescentMethod::ADAM, ConstantType learningRate = 0.1,
44 ConstantType averageDecay = 0.9, ConstantType squaredAverageDecay = 0.999, uint_fast64_t miniBatchSize = 32, ConstantType terminationEpsilon = 1e-6,
45 std::optional<
47 startPoint = std::nullopt,
49 std::optional<storage::ParameterRegion<FunctionType>> region = std::nullopt, bool recordRun = false)
50 : model(model),
51 derivativeEvaluationHelper(std::make_unique<SparseDerivativeInstantiationModelChecker<FunctionType, ConstantType>>(model)),
52 instantiationModelChecker(
53 std::make_unique<modelchecker::SparseDtmcInstantiationModelChecker<models::sparse::Dtmc<FunctionType>, ConstantType>>(model)),
54 startPoint(startPoint),
55 miniBatchSize(miniBatchSize),
56 terminationEpsilon(terminationEpsilon),
57 constraintMethod(constraintMethod),
58 region(region),
59 recordRun(recordRun) {
60 STORM_LOG_ERROR_COND(region == std::nullopt || (constraintMethod == GradientDescentConstraintMethod::PROJECT ||
62 "Specifying a region is only supported if you are constraining by projection.");
63 // TODO should we put this in subclasses?
64 switch (method) {
66 Adam adam;
67 adam.learningRate = learningRate;
68 adam.averageDecay = averageDecay;
69 adam.averageDecay = averageDecay;
70 adam.squaredAverageDecay = squaredAverageDecay;
71 gradientDescentType = adam;
72 break;
73 }
75 RAdam radam;
76 radam.learningRate = learningRate;
77 radam.averageDecay = averageDecay;
78 radam.squaredAverageDecay = squaredAverageDecay;
79 gradientDescentType = radam;
80 break;
81 }
83 RmsProp rmsProp;
84 rmsProp.learningRate = learningRate;
85 rmsProp.averageDecay = averageDecay;
86 gradientDescentType = rmsProp;
87 break;
88 }
91 Plain plain;
92 plain.learningRate = learningRate;
93 gradientDescentType = plain;
95 useSignsOnly = true;
96 } else {
97 useSignsOnly = false;
98 }
99 break;
100 }
103 Momentum momentum;
104 momentum.learningRate = learningRate;
105 // TODO Document this
106 momentum.momentumTerm = averageDecay;
107 gradientDescentType = momentum;
109 useSignsOnly = true;
110 } else {
111 useSignsOnly = false;
112 }
113 break;
114 }
117 Nesterov nesterov;
118 nesterov.learningRate = learningRate;
119 // TODO Document this
120 nesterov.momentumTerm = averageDecay;
121 gradientDescentType = nesterov;
123 useSignsOnly = true;
124 } else {
125 useSignsOnly = false;
126 }
127 break;
128 }
129 }
130 }
131
// Prepares the searcher for one feasibility synthesis task.
// Stores a copy of `env` and the task, collects the parameters to optimize, builds check tasks for a
// bound-free clone of the task's formula, and registers that formula with both model checkers.
//
// @param env  Environment; copied into this->env and forwarded to derivativeEvaluationHelper->specifyFormula.
// @param task Synthesis task; its formula is asserted to be a probability or reward operator formula.
140 void setup(Environment const& env, std::shared_ptr<storm::pars::FeasibilitySynthesisTask const> const& task) {
141 this->env = env;
// Start with the parameters occurring on transitions; reward parameters are merged in further down if needed.
142 this->parameters = storm::models::sparse::getProbabilityParameters(model);
143 this->synthesisTask = task;
144 STORM_LOG_ASSERT(task->getFormula().isProbabilityOperatorFormula() || task->getFormula().isRewardOperatorFormula(),
145 "Formula must be either a reward or a probability operator formula");
146
// Work on a clone so the task's own formula keeps its bound; only the clone has the bound removed.
147 std::shared_ptr<storm::logic::Formula> formulaWithoutBounds = task->getFormula().clone();
148 formulaWithoutBounds->asOperatorFormula().removeBound();
149 this->currentFormulaNoBound = formulaWithoutBounds->asSharedPointer();
150
// Reward formulas additionally depend on parameters that occur in the rewards.
151 if (task->getFormula().isRewardOperatorFormula()) {
152 auto rewardParameters = storm::models::sparse::getRewardParameters(model);
153 this->parameters.insert(rewardParameters.begin(), rewardParameters.end());
154 }
155
// Two check tasks over the same bound-free formula, one per value type (FunctionType and ConstantType).
156 this->currentCheckTaskNoBound = std::make_unique<storm::modelchecker::CheckTask<storm::logic::Formula, FunctionType>>(*currentFormulaNoBound);
157 this->currentCheckTaskNoBoundConstantType =
158 std::make_unique<storm::modelchecker::CheckTask<storm::logic::Formula, ConstantType>>(*currentFormulaNoBound);
159
// Both checkers receive the FunctionType task; the ConstantType task is only stored here, not passed on.
160 instantiationModelChecker->specifyFormula(*this->currentCheckTaskNoBound);
161 derivativeEvaluationHelper->specifyFormula(env, *this->currentCheckTaskNoBound);
162 }
163
167 std::pair<std::map<typename utility::parametric::VariableType<FunctionType>::type, typename utility::parametric::CoefficientType<FunctionType>::type>,
168 ConstantType>
170
174 void printRunAsJson();
175
180 std::map<typename utility::parametric::VariableType<FunctionType>::type, typename utility::parametric::CoefficientType<FunctionType>::type> position;
181 ConstantType value;
182 };
186 std::vector<VisualizationPoint> getVisualizationWalk();
187
188 private:
189 void resetDynamicValues();
190
191 Environment env;
192 std::shared_ptr<storm::pars::FeasibilitySynthesisTask const> synthesisTask;
193 std::unique_ptr<modelchecker::CheckTask<storm::logic::Formula, FunctionType>> currentCheckTaskNoBound;
194 std::unique_ptr<modelchecker::CheckTask<storm::logic::Formula, ConstantType>> currentCheckTaskNoBoundConstantType;
195 std::shared_ptr<storm::logic::Formula const> currentFormulaNoBound;
196
198 std::set<typename utility::parametric::VariableType<FunctionType>::type> parameters;
199 const std::unique_ptr<storm::derivative::SparseDerivativeInstantiationModelChecker<FunctionType, ConstantType>> derivativeEvaluationHelper;
200 std::unique_ptr<storm::analysis::MonotonicityHelper<FunctionType, ConstantType>> monotonicityHelper;
201 const std::unique_ptr<modelchecker::SparseDtmcInstantiationModelChecker<models::sparse::Dtmc<FunctionType>, ConstantType>> instantiationModelChecker;
202 std::optional<std::map<typename utility::parametric::VariableType<FunctionType>::type, typename utility::parametric::CoefficientType<FunctionType>::type>>
203 startPoint;
204 const uint_fast64_t miniBatchSize;
205 const ConstantType terminationEpsilon;
206 const GradientDescentConstraintMethod constraintMethod;
207 const std::optional<storage::ParameterRegion<FunctionType>> region;
208
209 // This is for visualizing data
210 const bool recordRun;
211 std::vector<VisualizationPoint> walk;
212
213 // Gradient Descent types and the data that belong to them: hyperparameters and running data.
// Hyperparameters and running data for the Adam alternative of GradientDescentType.
// The constructor's switch over `method` fills the three scalars from its averageDecay,
// squaredAverageDecay and learningRate arguments; the maps are left empty there.
214 struct Adam {
215 ConstantType averageDecay;
216 ConstantType squaredAverageDecay;
217 ConstantType learningRate;
// Per-parameter running data, keyed by the parameter variable.
// NOTE(review): presumably the decaying averages of the squared step and of the step - confirm in the .cpp.
218 std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType> decayingStepAverageSquared;
219 std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType> decayingStepAverage;
220 };
// Hyperparameters and running data for the RAdam alternative of GradientDescentType.
// Same fields as Adam; the constructor's switch over `method` fills the three scalars from its
// averageDecay, squaredAverageDecay and learningRate arguments.
221 struct RAdam {
222 ConstantType averageDecay;
223 ConstantType squaredAverageDecay;
224 ConstantType learningRate;
// Per-parameter running data, keyed by the parameter variable.
// NOTE(review): presumably the decaying averages of the squared step and of the step - confirm in the .cpp.
225 std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType> decayingStepAverageSquared;
226 std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType> decayingStepAverage;
227 };
// Hyperparameters and running data for the RmsProp alternative of GradientDescentType.
// The constructor's switch over `method` sets averageDecay and learningRate from its arguments of the same name.
228 struct RmsProp {
229 ConstantType averageDecay;
230 ConstantType learningRate;
// Per-parameter running data, keyed by the parameter variable.
// NOTE(review): presumably the running root mean square of the gradient - confirm in the .cpp.
231 std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType> rootMeanSquare;
232 };
// Hyperparameter for the Plain alternative of GradientDescentType; it carries no running data.
// The constructor's switch over `method` sets learningRate from its learningRate argument.
233 struct Plain {
234 ConstantType learningRate;
235 };
// Hyperparameters and running data for the Momentum alternative of GradientDescentType.
// The constructor's switch over `method` sets learningRate from its learningRate argument and
// momentumTerm from its averageDecay argument (there is no separate momentum parameter).
236 struct Momentum {
237 ConstantType learningRate;
238 ConstantType momentumTerm;
// Per-parameter running data, keyed by the parameter variable.
// NOTE(review): presumably the previous update step - confirm in the .cpp.
239 std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType> pastStep;
240 };
// Hyperparameters and running data for the Nesterov alternative of GradientDescentType.
// Same fields as Momentum; the constructor's switch over `method` sets learningRate from its
// learningRate argument and momentumTerm from its averageDecay argument.
241 struct Nesterov {
242 ConstantType learningRate;
243 ConstantType momentumTerm;
// Per-parameter running data, keyed by the parameter variable.
// NOTE(review): presumably the previous update step - confirm in the .cpp.
244 std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType> pastStep;
245 };
246 typedef boost::variant<Adam, RAdam, RmsProp, Plain, Momentum, Nesterov> GradientDescentType;
247 GradientDescentType gradientDescentType;
248 // Only respected by some Gradient Descent methods, the ones that have a "sign" version in the GradientDescentMethod enum
// Defaults to false: the constructor assigns this flag only in its Plain, Momentum and Nesterov branches,
// so without an initializer it would hold an indeterminate value for Adam, RAdam and RmsProp.
249 bool useSignsOnly = false;
250
251 ConstantType logarithmicBarrierTerm;
252
253 ConstantType stochasticGradientDescent(
255 ConstantType doStep(
258 const std::map<typename utility::parametric::VariableType<FunctionType>::type, ConstantType>& gradient, uint_fast64_t stepNum);
// Returns the square root of `input`, computed in ConstantType.
// The implementation is chosen at compile time: utility::sqrt when ConstantType is double, carl::sqrt otherwise.
// `if constexpr` discards the branch that does not apply, so each sqrt function only has to be valid for
// the ConstantType it is actually called with (a plain `if` would require both calls to compile for every type).
259 ConstantType constantTypeSqrt(ConstantType input) {
260 if constexpr (std::is_same<ConstantType, double>::value) {
261 return utility::sqrt(input);
262 } else {
263 return carl::sqrt(input);
264 }
265 }
266
267 utility::Stopwatch stochasticWatch;
268 utility::Stopwatch batchWatch;
269 utility::Stopwatch startingPointCalculationWatch;
270};
271} // namespace derivative
272} // namespace storm
void setup(Environment const &env, std::shared_ptr< storm::pars::FeasibilitySynthesisTask const > const &task)
This will set up the matrices used for computing the derivatives by constructing the SparseDerivativeI...
GradientDescentInstantiationSearcher(storm::models::sparse::Dtmc< FunctionType > const &model, GradientDescentMethod method=GradientDescentMethod::ADAM, ConstantType learningRate=0.1, ConstantType averageDecay=0.9, ConstantType squaredAverageDecay=0.999, uint_fast64_t miniBatchSize=32, ConstantType terminationEpsilon=1e-6, std::optional< std::map< typename utility::parametric::VariableType< FunctionType >::type, typename utility::parametric::CoefficientType< FunctionType >::type > > startPoint=std::nullopt, GradientDescentConstraintMethod constraintMethod=GradientDescentConstraintMethod::PROJECT_WITH_GRADIENT, std::optional< storage::ParameterRegion< FunctionType > > region=std::nullopt, bool recordRun=false)
The GradientDescentInstantiationSearcher can find extrema and feasible instantiations in pMCs,...
std::vector< VisualizationPoint > getVisualizationWalk()
Get the visualization walk that is recorded if recordRun is set to true in the constructor (false by ...
std::pair< std::map< typename utility::parametric::VariableType< FunctionType >::type, typename utility::parametric::CoefficientType< FunctionType >::type >, ConstantType > gradientDescent()
Perform Gradient Descent.
This class represents a discrete-time Markov chain.
Definition Dtmc.h:13
#define STORM_LOG_ASSERT(cond, message)
Definition macros.h:11
#define STORM_LOG_ERROR_COND(cond, message)
Definition macros.h:52
GradientDescentMethod
GradientDescentMethod is the method of Gradient Descent the GradientDescentInstantiationSearcher shal...
GradientDescentConstraintMethod
GradientDescentConstraintMethod is the method for mitigating constraints that the GradientDescentInst...
std::set< storm::RationalFunctionVariable > getRewardParameters(Model< storm::RationalFunction > const &model)
Get all parameters occurring in rewards.
Definition Model.cpp:699
std::set< storm::RationalFunctionVariable > getProbabilityParameters(Model< storm::RationalFunction > const &model)
Get all probability parameters occurring on transitions.
Definition Model.cpp:695
ValueType sqrt(ValueType const &number)
A point in the Gradient Descent walk, recorded if recordRun is set to true in the constructor (false ...
std::map< typename utility::parametric::VariableType< FunctionType >::type, typename utility::parametric::CoefficientType< FunctionType >::type > position