comparison libs/commons-math-2.1/docs/apidocs/src-html/org/apache/commons/math/estimation/GaussNewtonEstimator.html @ 13:cbf34dd4d7e6

commons-math-2.1 added
author dwinter
date Tue, 04 Jan 2011 10:02:07 +0100
parents
children
comparison
equal deleted inserted replaced
12:970d26a94fb7 13:cbf34dd4d7e6
1 <HTML>
2 <BODY BGCOLOR="white">
3 <PRE>
4 <FONT color="green">001</FONT> /*<a name="line.1"></a>
5 <FONT color="green">002</FONT> * Licensed to the Apache Software Foundation (ASF) under one or more<a name="line.2"></a>
6 <FONT color="green">003</FONT> * contributor license agreements. See the NOTICE file distributed with<a name="line.3"></a>
7 <FONT color="green">004</FONT> * this work for additional information regarding copyright ownership.<a name="line.4"></a>
8 <FONT color="green">005</FONT> * The ASF licenses this file to You under the Apache License, Version 2.0<a name="line.5"></a>
9 <FONT color="green">006</FONT> * (the "License"); you may not use this file except in compliance with<a name="line.6"></a>
10 <FONT color="green">007</FONT> * the License. You may obtain a copy of the License at<a name="line.7"></a>
11 <FONT color="green">008</FONT> *<a name="line.8"></a>
12 <FONT color="green">009</FONT> * http://www.apache.org/licenses/LICENSE-2.0<a name="line.9"></a>
13 <FONT color="green">010</FONT> *<a name="line.10"></a>
14 <FONT color="green">011</FONT> * Unless required by applicable law or agreed to in writing, software<a name="line.11"></a>
15 <FONT color="green">012</FONT> * distributed under the License is distributed on an "AS IS" BASIS,<a name="line.12"></a>
16 <FONT color="green">013</FONT> * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.<a name="line.13"></a>
17 <FONT color="green">014</FONT> * See the License for the specific language governing permissions and<a name="line.14"></a>
18 <FONT color="green">015</FONT> * limitations under the License.<a name="line.15"></a>
19 <FONT color="green">016</FONT> */<a name="line.16"></a>
20 <FONT color="green">017</FONT> <a name="line.17"></a>
21 <FONT color="green">018</FONT> package org.apache.commons.math.estimation;<a name="line.18"></a>
22 <FONT color="green">019</FONT> <a name="line.19"></a>
23 <FONT color="green">020</FONT> import java.io.Serializable;<a name="line.20"></a>
24 <FONT color="green">021</FONT> <a name="line.21"></a>
25 <FONT color="green">022</FONT> import org.apache.commons.math.linear.InvalidMatrixException;<a name="line.22"></a>
26 <FONT color="green">023</FONT> import org.apache.commons.math.linear.LUDecompositionImpl;<a name="line.23"></a>
27 <FONT color="green">024</FONT> import org.apache.commons.math.linear.MatrixUtils;<a name="line.24"></a>
28 <FONT color="green">025</FONT> import org.apache.commons.math.linear.RealMatrix;<a name="line.25"></a>
29 <FONT color="green">026</FONT> import org.apache.commons.math.linear.RealVector;<a name="line.26"></a>
30 <FONT color="green">027</FONT> import org.apache.commons.math.linear.ArrayRealVector;<a name="line.27"></a>
31 <FONT color="green">028</FONT> <a name="line.28"></a>
32 <FONT color="green">029</FONT> /**<a name="line.29"></a>
33 <FONT color="green">030</FONT> * This class implements a solver for estimation problems.<a name="line.30"></a>
34 <FONT color="green">031</FONT> *<a name="line.31"></a>
35 <FONT color="green">032</FONT> * &lt;p&gt;This class solves estimation problems using a weighted least<a name="line.32"></a>
36 <FONT color="green">033</FONT> * squares criterion on the measurement residuals. It uses a<a name="line.33"></a>
37 <FONT color="green">034</FONT> * Gauss-Newton algorithm.&lt;/p&gt;<a name="line.34"></a>
38 <FONT color="green">035</FONT> *<a name="line.35"></a>
39 <FONT color="green">036</FONT> * @version $Revision: 811685 $ $Date: 2009-09-05 13:36:48 -0400 (Sat, 05 Sep 2009) $<a name="line.36"></a>
40 <FONT color="green">037</FONT> * @since 1.2<a name="line.37"></a>
41 <FONT color="green">038</FONT> * @deprecated as of 2.0, everything in package org.apache.commons.math.estimation has<a name="line.38"></a>
42 <FONT color="green">039</FONT> * been deprecated and replaced by package org.apache.commons.math.optimization.general<a name="line.39"></a>
43 <FONT color="green">040</FONT> *<a name="line.40"></a>
44 <FONT color="green">041</FONT> */<a name="line.41"></a>
45 <FONT color="green">042</FONT> @Deprecated<a name="line.42"></a>
46 <FONT color="green">043</FONT> public class GaussNewtonEstimator extends AbstractEstimator implements Serializable {<a name="line.43"></a>
47 <FONT color="green">044</FONT> <a name="line.44"></a>
48 <FONT color="green">045</FONT> /** Serializable version identifier */<a name="line.45"></a>
49 <FONT color="green">046</FONT> private static final long serialVersionUID = 5485001826076289109L;<a name="line.46"></a>
50 <FONT color="green">047</FONT> <a name="line.47"></a>
51 <FONT color="green">048</FONT> /** Default threshold for cost steady state detection. */<a name="line.48"></a>
52 <FONT color="green">049</FONT> private static final double DEFAULT_STEADY_STATE_THRESHOLD = 1.0e-6;<a name="line.49"></a>
53 <FONT color="green">050</FONT> <a name="line.50"></a>
54 <FONT color="green">051</FONT> /** Default threshold for cost convergence. */<a name="line.51"></a>
55 <FONT color="green">052</FONT> private static final double DEFAULT_CONVERGENCE = 1.0e-6;<a name="line.52"></a>
56 <FONT color="green">053</FONT> <a name="line.53"></a>
57 <FONT color="green">054</FONT> /** Threshold for cost steady state detection. */<a name="line.54"></a>
58 <FONT color="green">055</FONT> private double steadyStateThreshold;<a name="line.55"></a>
59 <FONT color="green">056</FONT> <a name="line.56"></a>
60 <FONT color="green">057</FONT> /** Threshold for cost convergence. */<a name="line.57"></a>
61 <FONT color="green">058</FONT> private double convergence;<a name="line.58"></a>
62 <FONT color="green">059</FONT> <a name="line.59"></a>
63 <FONT color="green">060</FONT> /** Simple constructor with default settings.<a name="line.60"></a>
64 <FONT color="green">061</FONT> * &lt;p&gt;<a name="line.61"></a>
65 <FONT color="green">062</FONT> * The estimator is built with default values for all settings.<a name="line.62"></a>
66 <FONT color="green">063</FONT> * &lt;/p&gt;<a name="line.63"></a>
67 <FONT color="green">064</FONT> * @see #DEFAULT_STEADY_STATE_THRESHOLD<a name="line.64"></a>
68 <FONT color="green">065</FONT> * @see #DEFAULT_CONVERGENCE<a name="line.65"></a>
69 <FONT color="green">066</FONT> * @see AbstractEstimator#DEFAULT_MAX_COST_EVALUATIONS<a name="line.66"></a>
70 <FONT color="green">067</FONT> */<a name="line.67"></a>
71 <FONT color="green">068</FONT> public GaussNewtonEstimator() {<a name="line.68"></a>
72 <FONT color="green">069</FONT> this.steadyStateThreshold = DEFAULT_STEADY_STATE_THRESHOLD;<a name="line.69"></a>
73 <FONT color="green">070</FONT> this.convergence = DEFAULT_CONVERGENCE;<a name="line.70"></a>
74 <FONT color="green">071</FONT> }<a name="line.71"></a>
75 <FONT color="green">072</FONT> <a name="line.72"></a>
76 <FONT color="green">073</FONT> /**<a name="line.73"></a>
77 <FONT color="green">074</FONT> * Simple constructor.<a name="line.74"></a>
78 <FONT color="green">075</FONT> *<a name="line.75"></a>
79 <FONT color="green">076</FONT> * &lt;p&gt;This constructor builds an estimator and stores its convergence<a name="line.76"></a>
80 <FONT color="green">077</FONT> * characteristics.&lt;/p&gt;<a name="line.77"></a>
81 <FONT color="green">078</FONT> *<a name="line.78"></a>
82 <FONT color="green">079</FONT> * &lt;p&gt;An estimator is considered to have converged whenever either<a name="line.79"></a>
83 <FONT color="green">080</FONT> * the criterion goes below a physical threshold under which<a name="line.80"></a>
84 <FONT color="green">081</FONT> * improvements are considered useless or when the algorithm is<a name="line.81"></a>
85 <FONT color="green">082</FONT> * unable to improve it (even if it is still high). The first<a name="line.82"></a>
86 <FONT color="green">083</FONT> * condition that is met stops the iterations.&lt;/p&gt;<a name="line.83"></a>
87 <FONT color="green">084</FONT> *<a name="line.84"></a>
88 <FONT color="green">085</FONT> * &lt;p&gt;The fact an estimator has converged does not mean that the<a name="line.85"></a>
89 <FONT color="green">086</FONT> * model accurately fits the measurements. It only means no better<a name="line.86"></a>
90 <FONT color="green">087</FONT> * solution can be found, it does not mean this one is good. Such an<a name="line.87"></a>
91 <FONT color="green">088</FONT> * analysis is left to the caller.&lt;/p&gt;<a name="line.88"></a>
92 <FONT color="green">089</FONT> *<a name="line.89"></a>
93 <FONT color="green">090</FONT> * &lt;p&gt;If neither condition is fulfilled before a given number of<a name="line.90"></a>
94 <FONT color="green">091</FONT> * iterations, the algorithm is considered to have failed and an<a name="line.91"></a>
95 <FONT color="green">092</FONT> * {@link EstimationException} is thrown.&lt;/p&gt;<a name="line.92"></a>
96 <FONT color="green">093</FONT> *<a name="line.93"></a>
97 <FONT color="green">094</FONT> * @param maxCostEval maximal number of cost evaluations allowed<a name="line.94"></a>
98 <FONT color="green">095</FONT> * @param convergence criterion threshold below which we do not need<a name="line.95"></a>
99 <FONT color="green">096</FONT> * to improve the criterion anymore<a name="line.96"></a>
100 <FONT color="green">097</FONT> * @param steadyStateThreshold steady state detection threshold, the<a name="line.97"></a>
101 <FONT color="green">098</FONT> * problem is considered to have reached a steady state if<a name="line.98"></a>
102 <FONT color="green">099</FONT> * &lt;code&gt;Math.abs(J&lt;sub&gt;n&lt;/sub&gt; - J&lt;sub&gt;n-1&lt;/sub&gt;) &amp;lt;<a name="line.99"></a>
103 <FONT color="green">100</FONT> * J&lt;sub&gt;n&lt;/sub&gt; &amp;times; steadyStateThreshold&lt;/code&gt;, where &lt;code&gt;J&lt;sub&gt;n&lt;/sub&gt;&lt;/code&gt;<a name="line.100"></a>
104 <FONT color="green">101</FONT> * and &lt;code&gt;J&lt;sub&gt;n-1&lt;/sub&gt;&lt;/code&gt; are the current and preceding criterion<a name="line.101"></a>
105 <FONT color="green">102</FONT> * values (square sum of the weighted residuals of considered measurements).<a name="line.102"></a>
106 <FONT color="green">103</FONT> */<a name="line.103"></a>
107 <FONT color="green">104</FONT> public GaussNewtonEstimator(final int maxCostEval, final double convergence,<a name="line.104"></a>
108 <FONT color="green">105</FONT> final double steadyStateThreshold) {<a name="line.105"></a>
109 <FONT color="green">106</FONT> setMaxCostEval(maxCostEval);<a name="line.106"></a>
110 <FONT color="green">107</FONT> this.steadyStateThreshold = steadyStateThreshold;<a name="line.107"></a>
111 <FONT color="green">108</FONT> this.convergence = convergence;<a name="line.108"></a>
112 <FONT color="green">109</FONT> }<a name="line.109"></a>
113 <FONT color="green">110</FONT> <a name="line.110"></a>
114 <FONT color="green">111</FONT> /**<a name="line.111"></a>
115 <FONT color="green">112</FONT> * Set the convergence criterion threshold.<a name="line.112"></a>
116 <FONT color="green">113</FONT> * @param convergence criterion threshold below which we do not need<a name="line.113"></a>
117 <FONT color="green">114</FONT> * to improve the criterion anymore<a name="line.114"></a>
118 <FONT color="green">115</FONT> */<a name="line.115"></a>
119 <FONT color="green">116</FONT> public void setConvergence(final double convergence) {<a name="line.116"></a>
120 <FONT color="green">117</FONT> this.convergence = convergence;<a name="line.117"></a>
121 <FONT color="green">118</FONT> }<a name="line.118"></a>
122 <FONT color="green">119</FONT> <a name="line.119"></a>
123 <FONT color="green">120</FONT> /**<a name="line.120"></a>
124 <FONT color="green">121</FONT> * Set the steady state detection threshold.<a name="line.121"></a>
125 <FONT color="green">122</FONT> * &lt;p&gt;<a name="line.122"></a>
126 <FONT color="green">123</FONT> * The problem is considered to have reached a steady state if<a name="line.123"></a>
127 <FONT color="green">124</FONT> * &lt;code&gt;Math.abs(J&lt;sub&gt;n&lt;/sub&gt; - J&lt;sub&gt;n-1&lt;/sub&gt;) &amp;lt;<a name="line.124"></a>
128 <FONT color="green">125</FONT> * J&lt;sub&gt;n&lt;/sub&gt; &amp;times; steadyStateThreshold&lt;/code&gt;, where &lt;code&gt;J&lt;sub&gt;n&lt;/sub&gt;&lt;/code&gt;<a name="line.125"></a>
129 <FONT color="green">126</FONT> * and &lt;code&gt;J&lt;sub&gt;n-1&lt;/sub&gt;&lt;/code&gt; are the current and preceding criterion<a name="line.126"></a>
130 <FONT color="green">127</FONT> * values (square sum of the weighted residuals of considered measurements).<a name="line.127"></a>
131 <FONT color="green">128</FONT> * &lt;/p&gt;<a name="line.128"></a>
132 <FONT color="green">129</FONT> * @param steadyStateThreshold steady state detection threshold<a name="line.129"></a>
133 <FONT color="green">130</FONT> */<a name="line.130"></a>
134 <FONT color="green">131</FONT> public void setSteadyStateThreshold(final double steadyStateThreshold) {<a name="line.131"></a>
135 <FONT color="green">132</FONT> this.steadyStateThreshold = steadyStateThreshold;<a name="line.132"></a>
136 <FONT color="green">133</FONT> }<a name="line.133"></a>
137 <FONT color="green">134</FONT> <a name="line.134"></a>
138 <FONT color="green">135</FONT> /**<a name="line.135"></a>
139 <FONT color="green">136</FONT> * Solve an estimation problem using a least squares criterion.<a name="line.136"></a>
140 <FONT color="green">137</FONT> *<a name="line.137"></a>
141 <FONT color="green">138</FONT> * &lt;p&gt;This method sets the unbound parameters of the given problem<a name="line.138"></a>
142 <FONT color="green">139</FONT> * starting from their current values through several iterations. At<a name="line.139"></a>
143 <FONT color="green">140</FONT> * each step, the unbound parameters are changed in order to<a name="line.140"></a>
144 <FONT color="green">141</FONT> * minimize a weighted least square criterion based on the<a name="line.141"></a>
145 <FONT color="green">142</FONT> * measurements of the problem.&lt;/p&gt;<a name="line.142"></a>
146 <FONT color="green">143</FONT> *<a name="line.143"></a>
147 <FONT color="green">144</FONT> * &lt;p&gt;The iterations are stopped either when the criterion goes<a name="line.144"></a>
148 <FONT color="green">145</FONT> * below a physical threshold under which improvements are considered<a name="line.145"></a>
149 <FONT color="green">146</FONT> * useless or when the algorithm is unable to improve it (even if it<a name="line.146"></a>
150 <FONT color="green">147</FONT> * is still high). The first condition that is met stops the<a name="line.147"></a>
151 <FONT color="green">148</FONT> * iterations. If convergence is not reached before the maximum<a name="line.148"></a>
152 <FONT color="green">149</FONT> * number of iterations, an {@link EstimationException} is<a name="line.149"></a>
153 <FONT color="green">150</FONT> * thrown.&lt;/p&gt;<a name="line.150"></a>
154 <FONT color="green">151</FONT> *<a name="line.151"></a>
155 <FONT color="green">152</FONT> * @param problem estimation problem to solve<a name="line.152"></a>
156 <FONT color="green">153</FONT> * @exception EstimationException if the problem cannot be solved<a name="line.153"></a>
157 <FONT color="green">154</FONT> *<a name="line.154"></a>
158 <FONT color="green">155</FONT> * @see EstimationProblem<a name="line.155"></a>
159 <FONT color="green">156</FONT> *<a name="line.156"></a>
160 <FONT color="green">157</FONT> */<a name="line.157"></a>
161 <FONT color="green">158</FONT> @Override<a name="line.158"></a>
162 <FONT color="green">159</FONT> public void estimate(EstimationProblem problem)<a name="line.159"></a>
163 <FONT color="green">160</FONT> throws EstimationException {<a name="line.160"></a>
164 <FONT color="green">161</FONT> <a name="line.161"></a>
165 <FONT color="green">162</FONT> initializeEstimate(problem);<a name="line.162"></a>
166 <FONT color="green">163</FONT> <a name="line.163"></a>
167 <FONT color="green">164</FONT> // work matrices<a name="line.164"></a>
168 <FONT color="green">165</FONT> double[] grad = new double[parameters.length];<a name="line.165"></a>
169 <FONT color="green">166</FONT> ArrayRealVector bDecrement = new ArrayRealVector(parameters.length);<a name="line.166"></a>
170 <FONT color="green">167</FONT> double[] bDecrementData = bDecrement.getDataRef();<a name="line.167"></a>
171 <FONT color="green">168</FONT> RealMatrix wGradGradT = MatrixUtils.createRealMatrix(parameters.length, parameters.length);<a name="line.168"></a>
172 <FONT color="green">169</FONT> <a name="line.169"></a>
173 <FONT color="green">170</FONT> // iterate until convergence is reached<a name="line.170"></a>
174 <FONT color="green">171</FONT> double previous = Double.POSITIVE_INFINITY;<a name="line.171"></a>
175 <FONT color="green">172</FONT> do {<a name="line.172"></a>
176 <FONT color="green">173</FONT> <a name="line.173"></a>
177 <FONT color="green">174</FONT> // build the linear problem<a name="line.174"></a>
178 <FONT color="green">175</FONT> incrementJacobianEvaluationsCounter();<a name="line.175"></a>
179 <FONT color="green">176</FONT> RealVector b = new ArrayRealVector(parameters.length);<a name="line.176"></a>
180 <FONT color="green">177</FONT> RealMatrix a = MatrixUtils.createRealMatrix(parameters.length, parameters.length);<a name="line.177"></a>
181 <FONT color="green">178</FONT> for (int i = 0; i &lt; measurements.length; ++i) {<a name="line.178"></a>
182 <FONT color="green">179</FONT> if (! measurements [i].isIgnored()) {<a name="line.179"></a>
183 <FONT color="green">180</FONT> <a name="line.180"></a>
184 <FONT color="green">181</FONT> double weight = measurements[i].getWeight();<a name="line.181"></a>
185 <FONT color="green">182</FONT> double residual = measurements[i].getResidual();<a name="line.182"></a>
186 <FONT color="green">183</FONT> <a name="line.183"></a>
187 <FONT color="green">184</FONT> // compute the normal equation<a name="line.184"></a>
188 <FONT color="green">185</FONT> for (int j = 0; j &lt; parameters.length; ++j) {<a name="line.185"></a>
189 <FONT color="green">186</FONT> grad[j] = measurements[i].getPartial(parameters[j]);<a name="line.186"></a>
190 <FONT color="green">187</FONT> bDecrementData[j] = weight * residual * grad[j];<a name="line.187"></a>
191 <FONT color="green">188</FONT> }<a name="line.188"></a>
192 <FONT color="green">189</FONT> <a name="line.189"></a>
193 <FONT color="green">190</FONT> // build the contribution matrix for measurement i<a name="line.190"></a>
194 <FONT color="green">191</FONT> for (int k = 0; k &lt; parameters.length; ++k) {<a name="line.191"></a>
195 <FONT color="green">192</FONT> double gk = grad[k];<a name="line.192"></a>
196 <FONT color="green">193</FONT> for (int l = 0; l &lt; parameters.length; ++l) {<a name="line.193"></a>
197 <FONT color="green">194</FONT> wGradGradT.setEntry(k, l, weight * gk * grad[l]);<a name="line.194"></a>
198 <FONT color="green">195</FONT> }<a name="line.195"></a>
199 <FONT color="green">196</FONT> }<a name="line.196"></a>
200 <FONT color="green">197</FONT> <a name="line.197"></a>
201 <FONT color="green">198</FONT> // update the matrices<a name="line.198"></a>
202 <FONT color="green">199</FONT> a = a.add(wGradGradT);<a name="line.199"></a>
203 <FONT color="green">200</FONT> b = b.add(bDecrement);<a name="line.200"></a>
204 <FONT color="green">201</FONT> <a name="line.201"></a>
205 <FONT color="green">202</FONT> }<a name="line.202"></a>
206 <FONT color="green">203</FONT> }<a name="line.203"></a>
207 <FONT color="green">204</FONT> <a name="line.204"></a>
208 <FONT color="green">205</FONT> try {<a name="line.205"></a>
209 <FONT color="green">206</FONT> <a name="line.206"></a>
210 <FONT color="green">207</FONT> // solve the linearized least squares problem<a name="line.207"></a>
211 <FONT color="green">208</FONT> RealVector dX = new LUDecompositionImpl(a).getSolver().solve(b);<a name="line.208"></a>
212 <FONT color="green">209</FONT> <a name="line.209"></a>
213 <FONT color="green">210</FONT> // update the estimated parameters<a name="line.210"></a>
214 <FONT color="green">211</FONT> for (int i = 0; i &lt; parameters.length; ++i) {<a name="line.211"></a>
215 <FONT color="green">212</FONT> parameters[i].setEstimate(parameters[i].getEstimate() + dX.getEntry(i));<a name="line.212"></a>
216 <FONT color="green">213</FONT> }<a name="line.213"></a>
217 <FONT color="green">214</FONT> <a name="line.214"></a>
218 <FONT color="green">215</FONT> } catch(InvalidMatrixException e) {<a name="line.215"></a>
219 <FONT color="green">216</FONT> throw new EstimationException("unable to solve: singular problem");<a name="line.216"></a>
220 <FONT color="green">217</FONT> }<a name="line.217"></a>
221 <FONT color="green">218</FONT> <a name="line.218"></a>
222 <FONT color="green">219</FONT> <a name="line.219"></a>
223 <FONT color="green">220</FONT> previous = cost;<a name="line.220"></a>
224 <FONT color="green">221</FONT> updateResidualsAndCost();<a name="line.221"></a>
225 <FONT color="green">222</FONT> <a name="line.222"></a>
226 <FONT color="green">223</FONT> } while ((getCostEvaluations() &lt; 2) ||<a name="line.223"></a>
227 <FONT color="green">224</FONT> (Math.abs(previous - cost) &gt; (cost * steadyStateThreshold) &amp;&amp;<a name="line.224"></a>
228 <FONT color="green">225</FONT> (Math.abs(cost) &gt; convergence)));<a name="line.225"></a>
229 <FONT color="green">226</FONT> <a name="line.226"></a>
230 <FONT color="green">227</FONT> }<a name="line.227"></a>
231 <FONT color="green">228</FONT> <a name="line.228"></a>
232 <FONT color="green">229</FONT> }<a name="line.229"></a>
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293 </PRE>
294 </BODY>
295 </HTML>