00001 #include "cs.h"
00002
00003 static int cs_bfs (const cs *A, int n, int *wi, int *wj, int *queue,
00004 const int *imatch, const int *jmatch, int mark)
00005 {
00006 int *Ap, *Ai, head = 0, tail = 0, j, i, p, j2 ;
00007 cs *C ;
00008 for (j = 0 ; j < n ; j++)
00009 {
00010 if (imatch [j] >= 0) continue ;
00011 wj [j] = 0 ;
00012 queue [tail++] = j ;
00013 }
00014 if (tail == 0) return (1) ;
00015 C = (mark == 1) ? ((cs *) A) : cs_transpose (A, 0) ;
00016 if (!C) return (0) ;
00017 Ap = C->p ; Ai = C->i ;
00018 while (head < tail)
00019 {
00020 j = queue [head++] ;
00021 for (p = Ap [j] ; p < Ap [j+1] ; p++)
00022 {
00023 i = Ai [p] ;
00024 if (wi [i] >= 0) continue ;
00025 wi [i] = mark ;
00026 j2 = jmatch [i] ;
00027 if (wj [j2] >= 0) continue ;
00028 wj [j2] = mark ;
00029 queue [tail++] = j2 ;
00030 }
00031 }
00032 if (mark != 1) cs_spfree (C) ;
00033 return (1) ;
00034 }
00035
00036
/* Append the matched pairs belonging to one coarse set to the permutations.
 *
 * For every column j in [0,n) with wj[j] == mark, taken in increasing order,
 * imatch[j] is appended to p starting at position rr[set-1], and j is
 * appended to q starting at position cc[set].  The end positions are then
 * stored in cc[set+1] and rr[set].  set must be >= 1 because rr[set-1] is
 * read.
 */
static void cs_matched (int n, const int *wj, const int *imatch, int *p, int *q,
    int *cc, int *rr, int set, int mark)
{
    int col ;
    int *pout = p + rr [set-1] ;    /* next free slot for a row index */
    int *qout = q + cc [set] ;      /* next free slot for a column index */

    for (col = 0 ; col < n ; col++)
    {
        if (wj [col] == mark)
        {
            *pout++ = imatch [col] ;
            *qout++ = col ;
        }
    }
    cc [set+1] = (int) (qout - q) ;
    rr [set] = (int) (pout - p) ;
}
00051
00052
/* Append every index i in [0,m) with wi[i] == 0 to p, in increasing order,
 * starting at position rr[set].  The end position is stored in rr[set+1].
 */
static void cs_unmatched (int m, const int *wi, int *p, int *rr, int set)
{
    int row = 0 ;
    int next = rr [set] ;       /* next free slot in p */

    while (row < m)
    {
        if (!wi [row]) p [next++] = row ;
        row++ ;
    }
    rr [set+1] = next ;
}
00059
00060
/* Entry filter that cs_dmperm passes to cs_fkeep, with other = rr.
 *
 * Returns 1 when row i lies in [rr[1], rr[2]), and 0 otherwise.  The column
 * index and value are ignored.  Presumably cs_fkeep keeps the entries for
 * which this returns nonzero -- confirm against cs_fkeep.
 */
static int cs_rprune (int i, int j, double aij, void *other)
{
    const int *bounds = (const int *) other ;
    (void) j ;
    (void) aij ;
    return ((bounds [1] <= i) && (i < bounds [2])) ;
}
00066
00067
00068 csd *cs_dmperm (const cs *A, int seed)
00069 {
00070 int m, n, i, j, k, cnz, nc, *jmatch, *imatch, *wi, *wj, *pinv, *Cp, *Ci,
00071 *ps, *rs, nb1, nb2, *p, *q, *cc, *rr, *r, *s, ok ;
00072 cs *C ;
00073 csd *D, *scc ;
00074
00075 if (!CS_CSC (A)) return (NULL) ;
00076 m = A->m ; n = A->n ;
00077 D = cs_dalloc (m, n) ;
00078 if (!D) return (NULL) ;
00079 p = D->p ; q = D->q ; r = D->r ; s = D->s ; cc = D->cc ; rr = D->rr ;
00080 jmatch = cs_maxtrans (A, seed) ;
00081 imatch = jmatch + m ;
00082 if (!jmatch) return (cs_ddone (D, NULL, jmatch, 0)) ;
00083
00084 wi = r ; wj = s ;
00085 for (j = 0 ; j < n ; j++) wj [j] = -1 ;
00086 for (i = 0 ; i < m ; i++) wi [i] = -1 ;
00087 cs_bfs (A, n, wi, wj, q, imatch, jmatch, 1) ;
00088 ok = cs_bfs (A, m, wj, wi, p, jmatch, imatch, 3) ;
00089 if (!ok) return (cs_ddone (D, NULL, jmatch, 0)) ;
00090 cs_unmatched (n, wj, q, cc, 0) ;
00091 cs_matched (n, wj, imatch, p, q, cc, rr, 1, 1) ;
00092 cs_matched (n, wj, imatch, p, q, cc, rr, 2, -1) ;
00093 cs_matched (n, wj, imatch, p, q, cc, rr, 3, 3) ;
00094 cs_unmatched (m, wi, p, rr, 3) ;
00095 cs_free (jmatch) ;
00096
00097 pinv = cs_pinv (p, m) ;
00098 if (!pinv) return (cs_ddone (D, NULL, NULL, 0)) ;
00099 C = cs_permute (A, pinv, q, 0) ;
00100 cs_free (pinv) ;
00101 if (!C) return (cs_ddone (D, NULL, NULL, 0)) ;
00102 Cp = C->p ;
00103 nc = cc [3] - cc [2] ;
00104 if (cc [2] > 0) for (j = cc [2] ; j <= cc [3] ; j++) Cp [j-cc[2]] = Cp [j] ;
00105 C->n = nc ;
00106 if (rr [2] - rr [1] < m)
00107 {
00108 cs_fkeep (C, cs_rprune, rr) ;
00109 cnz = Cp [nc] ;
00110 Ci = C->i ;
00111 if (rr [1] > 0) for (k = 0 ; k < cnz ; k++) Ci [k] -= rr [1] ;
00112 }
00113 C->m = nc ;
00114 scc = cs_scc (C) ;
00115 if (!scc) return (cs_ddone (D, C, NULL, 0)) ;
00116
00117 ps = scc->p ;
00118 rs = scc->r ;
00119 nb1 = scc->nb ;
00120 for (k = 0 ; k < nc ; k++) wj [k] = q [ps [k] + cc [2]] ;
00121 for (k = 0 ; k < nc ; k++) q [k + cc [2]] = wj [k] ;
00122 for (k = 0 ; k < nc ; k++) wi [k] = p [ps [k] + rr [1]] ;
00123 for (k = 0 ; k < nc ; k++) p [k + rr [1]] = wi [k] ;
00124 nb2 = 0 ;
00125 r [0] = s [0] = 0 ;
00126 if (cc [2] > 0) nb2++ ;
00127 for (k = 0 ; k < nb1 ; k++)
00128 {
00129 r [nb2] = rs [k] + rr [1] ;
00130 s [nb2] = rs [k] + cc [2] ;
00131 nb2++ ;
00132 }
00133 if (rr [2] < m)
00134 {
00135 r [nb2] = rr [2] ;
00136 s [nb2] = cc [3] ;
00137 nb2++ ;
00138 }
00139 r [nb2] = m ;
00140 s [nb2] = n ;
00141 D->nb = nb2 ;
00142 cs_dfree (scc) ;
00143 return (cs_ddone (D, C, NULL, 1)) ;
00144 }