00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #include <string.h>
00018 #include <math.h>
00019 #include <stdio.h>
00020 #include <cblas.h>
00021 #include <clapack.h>
00022 #include "amb_kf.h"
00023
00024
00025
00026 #define LOOPMAX 10000
00027
00028 #define SGN(x) ((x)<=0.0?-1.0:1.0)
00029 #define ROUND(x) (floor((x)+0.5))
00030 #define SWAP(x,y) do {double tmp_; tmp_=x; x=y; y=tmp_;} while (0)
00031
00032
00033 static void triu2(u32 n, double *M)
00034 {
00035 for (u32 i=1; i<n; i++) {
00036 for (u32 j=0; j<i; j++) {
00037 M[i*n + j] = 0;
00038 }
00039 }
00040 }
00041
00042 static void eye2(u32 n, double *M)
00043 {
00044 memset(M, 0, n * n * sizeof(double));
00045 for (u32 i=0; i<n; i++) {
00046 M[i*n + i] = 1;
00047 }
00048 }
00049
00050 static s8 udu2(u32 n, double *M, double *U, double *D)
00051 {
00052 double alpha, beta;
00053 triu2(n, M);
00054 eye2(n, U);
00055 memset(D, 0, n * sizeof(double));
00056
00057 for (u32 j=n; j>=2; j--) {
00058 D[j - 1] = M[(j-1)*n + j-1];
00059 if (D[j-1] > 0) {
00060 alpha = 1.0 / D[j-1];
00061 } else {
00062 alpha = 0.0;
00063 }
00064 for (u32 k=1; k<j; k++) {
00065 beta = M[(k-1)*n + j-1];
00066 U[(k-1)*n + j-1] = alpha * beta;
00067 for (u32 kk = 0; kk < k; kk++) {
00068 M[kk*n + k-1] = M[kk*n + k-1] - beta * U[kk*n + j-1];
00069 }
00070 }
00071
00072 }
00073 D[0] = M[0];
00074 return 0;
00075 }
00076
00077
00078
00079 int LD(int n, const double *Q, double *L, double *D)
00080 {
00081 int i,j,k,info=0;
00082 double a;
00083 double A[n*n];
00084 memset(L, 0, sizeof(double)*n*n);
00085 memset(D, 0, sizeof(double)*n);
00086
00087 memcpy(A,Q,sizeof(double)*n*n);
00088 for (i=n-1;i>=0;i--) {
00089 if ((D[i]=A[i+i*n])<=0.0) {info=-1; break;}
00090 a=sqrt(D[i]);
00091 for (j=0;j<=i;j++) L[i+j*n]=A[i+j*n]/a;
00092 for (j=0;j<=i-1;j++) for (k=0;k<=j;k++) A[j+k*n]-=L[i+k*n]*L[i+j*n];
00093 for (j=0;j<=i;j++) L[i+j*n]/=L[i+i*n];
00094 }
00095 if (info) {
00096 printf("%s : LD factorization error, trying UD from Gibbs (col major UD = LD)\n",__FILE__);
00097 double Qcopy[n * n];
00098 memcpy(Qcopy, Q, n * n * sizeof(double));
00099 udu2(n, Qcopy, L, D);
00100 }
00101 return info;
00102 }
00103
00104 void gauss(int n, double *L, double *Z, int i, int j)
00105 {
00106 int k,mu;
00107
00108 if ((mu=(int)ROUND(L[i+j*n]))!=0) {
00109 for (k=i;k<n;k++) L[k+n*j]-=(double)mu*L[k+i*n];
00110 for (k=0;k<n;k++) Z[k+n*j]-=(double)mu*Z[k+i*n];
00111 }
00112 }
00113
00114 void perm(int n, double *L, double *D, int j, double del, double *Z)
00115 {
00116 int k;
00117 double eta,lam,a0,a1;
00118
00119 eta=D[j]/del;
00120 lam=D[j+1]*L[j+1+j*n]/del;
00121 D[j]=eta*D[j+1]; D[j+1]=del;
00122 for (k=0;k<=j-1;k++) {
00123 a0=L[j+k*n]; a1=L[j+1+k*n];
00124 L[j+k*n]=-L[j+1+j*n]*a0+a1;
00125 L[j+1+k*n]=eta*a0+lam*a1;
00126 }
00127 L[j+1+j*n]=lam;
00128 for (k=j+2;k<n;k++) SWAP(L[k+j*n],L[k+(j+1)*n]);
00129 for (k=0;k<n;k++) SWAP(Z[k+j*n],Z[k+(j+1)*n]);
00130 }
00131
00132 void reduction(int n, double *L, double *D, double *Z)
00133 {
00134 int i,j,k;
00135 double del;
00136
00137 j=n-2; k=n-2;
00138 while (j>=0) {
00139 if (j<=k) for (i=j+1;i<n;i++) gauss(n,L,Z,i,j);
00140 del=D[j]+L[j+1+j*n]*L[j+1+j*n]*D[j+1];
00141 if (del+1E-6<D[j+1]) {
00142 perm(n,L,D,j,del,Z);
00143 k=j; j=n-2;
00144 }
00145 else j--;
00146 }
00147 }
00148
00149 static int search(int n, int m, const double *L, const double *D,
00150 const double *zs, double *zn, double *s)
00151 {
00152 int i,j,k,c,nn=0,imax=0;
00153 double newdist,maxdist=1E99,y;
00154 double S[n*n];
00155 double dist[n];
00156 double zb[n];
00157 double z[n];
00158 double step[n];
00159 memset(S, 0, sizeof(double)*n*n);
00160
00161 k=n-1; dist[k]=0.0;
00162 zb[k]=zs[k];
00163 z[k]=ROUND(zb[k]); y=zb[k]-z[k]; step[k]=SGN(y);
00164 for (c=0;c<LOOPMAX;c++) {
00165 newdist=dist[k]+y*y/D[k];
00166 if (newdist<maxdist) {
00167 if (k!=0) {
00168 dist[--k]=newdist;
00169 for (i=0;i<=k;i++)
00170 S[k+i*n]=S[k+1+i*n]+(z[k+1]-zb[k+1])*L[k+1+i*n];
00171 zb[k]=zs[k]+S[k+k*n];
00172 z[k]=ROUND(zb[k]); y=zb[k]-z[k]; step[k]=SGN(y);
00173 }
00174 else {
00175 if (nn<m) {
00176 if (nn==0||newdist>s[imax]) imax=nn;
00177 for (i=0;i<n;i++) zn[i+nn*n]=z[i];
00178 s[nn++]=newdist;
00179 }
00180 else {
00181 if (newdist<s[imax]) {
00182 for (i=0;i<n;i++) zn[i+imax*n]=z[i];
00183 s[imax]=newdist;
00184 for (i=imax=0;i<m;i++) if (s[imax]<s[i]) imax=i;
00185 }
00186 maxdist=s[imax];
00187 }
00188 z[0]+=step[0]; y=zb[0]-z[0]; step[0]=-step[0]-SGN(step[0]);
00189 }
00190 }
00191 else {
00192 if (k==n-1) break;
00193 else {
00194 k++;
00195 z[k]+=step[k]; y=zb[k]-z[k]; step[k]=-step[k]-SGN(step[k]);
00196 }
00197 }
00198 }
00199 for (i=0;i<m-1;i++) {
00200 for (j=i+1;j<m;j++) {
00201 if (s[i]<s[j]) continue;
00202 SWAP(s[i],s[j]);
00203 for (k=0;k<n;k++) SWAP(zn[k+i*n],zn[k+j*n]);
00204 }
00205 }
00206
00207 if (c>=LOOPMAX) {
00208 fprintf(stderr,"%s : search loop count overflow\n",__FILE__);
00209 return -1;
00210 }
00211 return 0;
00212 }
00213
00214
00215
00216
00217
00218
00219
00220
00221
00222
00223 int lambda_reduction(int n, const double *Q, double *Z)
00224 {
00225 int info;
00226
00227 if (n<=0) return -1;
00228
00229 double L[n*n];
00230 double D[n];
00231
00232
00233 memset(L, 0, sizeof(double)*n*n);
00234
00235
00236 memset(Z, 0, sizeof(double)*n*n);
00237 for (int i=0; i<n; i++)
00238 Z[i+n*i] = 1;
00239
00240
00241 if (!(info=LD(n,Q,L,D))) {
00242
00243 reduction(n,L,D,Z);
00244 }
00245
00246 return info;
00247 }
00248
00249
00250
00251
00252
00253
00254
00255
00256
00257
00258
00259 void matmul(const char *tr, integer n, integer k, integer m, double alpha,
00260 const double *A, const double *B, double beta, double *C)
00261 {
00262 integer lda=tr[0]=='T'?m:n,ldb=tr[1]=='T'?k:m;
00263
00264 dgemm_((char *)tr,(char *)tr+1,&n,&k,&m,&alpha,(double *)A,&lda,(double *)B,
00265 &ldb,&beta,C,&n);
00266 }
00267
00268
00269
00270
00271
00272
00273
00274
00275
00276
00277
00278
00279 int solve(const char *tr, const double *A, const double *Y, integer n,
00280 integer m, double *X)
00281 {
00282 double B[n*n];
00283 integer info;
00284 integer ipiv[n];
00285
00286 memcpy(B, A, sizeof(double)*n*n);
00287 memcpy(X, Y, sizeof(double)*n*m);
00288 dgetrf_(&n,&n,B,&n,ipiv,&info);
00289 if (!info) dgetrs_((char *)tr,&n,&m,B,&n,ipiv,X,&n,&info);
00290 return info;
00291 }
00292
00293
00294
00295
00296
00297
00298
00299
00300
00301
00302
00303
00304
00305
00306 int lambda_solution(int n, int m, const double *a, const double *Q, double *F,
00307 double *s)
00308 {
00309 int info;
00310
00311 if (n<=0||m<=0) return -1;
00312 double L[n*n];
00313 double D[n];
00314 double Z[n*n];
00315 double z[n];
00316 double E[n*m];
00317
00318
00319 memset(L, 0, sizeof(double)*n*n);
00320
00321
00322 memset(Z, 0, sizeof(double)*n*n);
00323 for (int i=0; i<n; i++)
00324 Z[i+n*i] = 1;
00325
00326
00327 if (!(info=LD(n,Q,L,D))) {
00328
00329
00330 reduction(n,L,D,Z);
00331 matmul("TN",n,1,n,1.0,Z,a,0.0,z);
00332
00333
00334 if (!(info=search(n,m,L,D,z,E,s))) {
00335
00336 info=solve("T",Z,E,n,m,F);
00337 }
00338 }
00339 return info;
00340 }