Skip to content

Commit 1619d02

Browse files
committed
128!:12 finished
1 parent 6000bd4 commit 1619d02

3 files changed

Lines changed: 106 additions & 49 deletions

File tree

jsrc/va2.c

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,7 @@ DF2(jtsumattymes1){
10371037
// install the shape
10381038
MCISH(zs,longs,af+commonf); MCISH(zs+af+commonf,ws+wr-wcr,wcr-1);
10391039
}
1040-
if(unlikely(fit==2))AS(w)[AR(w)-1]=2; // if +/@:*"1!.1, we store two atoms per sum
1040+
if(unlikely(fit==2))AS(z)[AR(z)-1]=2; // if +/@:*"1!.1, we store two atoms per sum
10411041

10421042
if(likely(fit==0)){RZ(jtsumattymesprods(jt,it,voidAV(a),voidAV(w),dplen,nfro,nfri,ndpo,ndpi,voidAV(z))); // eval standard dot-product, check for error
10431043
}else{
@@ -1046,7 +1046,7 @@ DF2(jtsumattymes1){
10461046
#if (C_AVX2&&SY_64) || EMU_AVX2
10471047
#if 1 // higher precision. Required when a large product is added to a small total. Dependency loop for acc is 4 clocks; for c is 4 clocks. Total 12 insts, so unrolled 2 would do
10481048
#define OGITA(in0,in1,n) TWOPROD(in0,in1,h,y) TWOSUM(acc##n,h,acc##n,q) c##n=_mm256_add_pd(_mm256_add_pd(q,y),c##n);
1049-
#else
1049+
#else // obsolete
10501050
#define OGITA(in0,in1,n) TWOPROD(in0,in1,h,y) c##n=_mm256_add_pd(y,c##n); KAHAN(h,n)
10511051
#endif
10521052
__m256i endmask; /* length mask for the last word */
@@ -1084,13 +1084,22 @@ DF2(jtsumattymes1){
10841084
c0=_mm256_add_pd(c0,c1); c2=_mm256_add_pd(c2,c3); c0=_mm256_add_pd(c0,c2); // add all the low parts together - the low bits of the low will not make it through to the result
10851085
TWOSUM(acc0,acc1,acc0,c1) TWOSUM(acc2,acc3,acc2,c2) c2=_mm256_add_pd(c1,c2); c0=_mm256_add_pd(c0,c2); // add 0+1, 2+3
10861086
TWOSUM(acc0,acc2,acc0,c1) c0=_mm256_add_pd(c0,c1); // 0+2
1087-
// acc0/c0 survive. Combine horizontally
1088-
c0=_mm256_add_pd(c0,_mm256_permute4x64_pd(c0,0b11111110)); acc1=_mm256_permute4x64_pd(acc0,0b11111110); // c0: 01+=23, acc1<-23
1089-
TWOSUM(acc0,acc1,acc0,c1); c0=_mm256_add_pd(c0,c1); // combine p=01+23
1090-
c0=_mm256_add_pd(c0,_mm256_permute_pd(c0,0xf)); acc1=_mm256_permute_pd(acc0,0xf); // combine c0+c1, acc1<-1
1091-
TWOSUM(acc0,acc1,acc0,c1); c0=_mm256_add_pd(c0,c1); // combine 0123, combine all low parts
1092-
acc0=_mm256_add_pd(acc0,c0); // add low parts back into high in case there is overlap
1093-
#else
1087+
// acc0/c0 survive. Combine horizontally. Anything the high part touches must be extended precision; the low in one float. We guarantee extended precision from
1088+
// the largest intermediate total encountered; sometimes we get a little more.
1089+
c0=_mm256_add_pd(c0,_mm256_permute4x64_pd(c0,0b11111110)); acc1=_mm256_permute4x64_pd(acc0,0b11111110); // c0: lo01+=lo23, acc1<-hi23
1090+
TWOSUM(acc0,acc1,acc0,c1); c0=_mm256_add_pd(c0,c1); // combine acc0 = hi0+2/1+3, c0 accumulates lo0+lo2+extension0, lo1+lo3+extension1
1091+
c0=_mm256_add_pd(c0,_mm256_permute_pd(c0,0xf)); acc1=_mm256_permute_pd(acc0,0xf); // c0[0] has total of all low parts, acc1=hi1+hi3
1092+
TWOSUM(acc0,acc1,acc0,c1); c0=_mm256_add_pd(c0,c1); // acc0 has sum of all hi parts, c1 sum of all low parts+extensions
1093+
if(fit==1){
1094+
// normal result. Just add the extensions into the hi part
1095+
acc0=_mm256_add_pd(acc0,c0); // add low parts back into high in case there is overlap
1096+
}else{
1097+
// extended result. We must preserve the extension bits in the total and write them out
1098+
TWOSUM(acc0,c0,acc0,c1); // extended total
1099+
zv[1]=_mm256_cvtsd_f64(c1); // store it out
1100+
1101+
}
1102+
#else // obsolete
10941103
c0=_mm256_add_pd(c0,c1); c2=_mm256_add_pd(c2,c3); c0=_mm256_add_pd(c0,c2); // add all the low parts together - the low bits of the low will not make it through to the result
10951104
acc0=_mm256_add_pd(acc0,acc1); acc2=_mm256_add_pd(acc2,acc3); acc0=_mm256_add_pd(acc0,acc2); // add all the high parts
10961105
// acc0/c0 survive. Combine horizontally
@@ -1100,7 +1109,7 @@ DF2(jtsumattymes1){
11001109
acc0=_mm256_add_pd(acc0,_mm256_permute_pd(acc0,0xf));
11011110
acc0=_mm256_add_pd(acc0,c0); // add low parts back into high in case there is overlap
11021111
#endif
1103-
*zv=_mm256_cvtsd_f64(acc0); ++zv;
1112+
*zv=_mm256_cvtsd_f64(acc0); zv+=fit; // store out high (perhaps only) part
11041113
if(!--j)break; av=av0; // repeat a if needed
11051114
}
11061115
}

jsrc/vfrom.c

Lines changed: 57 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,7 @@ struct __attribute__((aligned(CACHELINESIZE))) mvmctx {
687687
I (*axv)[2]; // pointer to ax data
688688
I *amv0; // pointer to am block, and to the selected column's indexes
689689
D *avv0; // pointer to av data, and to the selected column's values
690-
D *mv0; // pointer to M data, and to the selected column (for identity columns) or to the current row of M (for sparse columns)
690+
A qk; // original M block
691691
} ;
692692

693693

@@ -697,7 +697,7 @@ static unsigned char jtmvmsparsex(J jt,void* const ctx,UI4 ti){
697697
// transfer everything out of ctx into local names
698698
#define YC(x) typeof(((struct mvmctx*)ctx)->x) x=((struct mvmctx*)ctx)->x;
699699
YC(ndxa)YC(n)YC(minimp)YC(bv)YC(thresh)YC(bestcol)YC(bestcolrow)YC(zv)YC(Frow)
700-
YC(impfac)YC(prirow)YC(bvgrd0)YC(bvgrde)YC(exlist)YC(nexlist)YC(yk)YC(bkmin)YC(axv)YC(amv0)YC(avv0)YC(mv0)
700+
YC(impfac)YC(prirow)YC(bvgrd0)YC(bvgrde)YC(exlist)YC(nexlist)YC(yk)YC(bkmin)YC(axv)YC(amv0)YC(avv0)YC(qk)
701701
#undef YC
702702
// perform the operation
703703

@@ -746,6 +746,7 @@ static unsigned char jtmvmsparsex(J jt,void* const ctx,UI4 ti){
746746
I taskno=jt->ndxinthreadpool; // default task# to try
747747
if(unlikely(taskno>=AN(ndxa)))taskno=taskno%AN(ndxa); // if more processors than tasks, take a value in range
748748
I trymask=1LL<<taskno; // mask we will try to take
749+
D *mv0=DAV(qk); // pointer to M data
749750
while(1){
750751
I oldmask=__atomic_fetch_or(&((struct mvmctx*)ctx)->taskmask,trymask,__ATOMIC_ACQ_REL); // try to reserve a task
751752
if(!(oldmask&trymask))break; // if we reserved it, continue
@@ -758,10 +759,8 @@ static unsigned char jtmvmsparsex(J jt,void* const ctx,UI4 ti){
758759
if(bv!=0&&prirow>=0)zv=0; // if we have DIP with a priority row, signal to process ALL rows in bk order. It's not really needed but it ensures that if the prirow is tied for pivot in the first
759760
// column, we will take it
760761

761-
#define COLLPINIT I *bvgrd=bvgrd0; I i=-1; D *mv=mv0-n; D bkold=inf, cold=1.0; I bkle0=1;
762-
#define COLLP do{if(unlikely(zv!=0)){++i; mv+=n;}else{i=*bvgrd; mv=mv0+n*i;} // for each row, i is the row#, mv points to the beginning of the row of M. If we take the whole col, take it in order for cache. Prefetch next row?
763-
#define COLLPE }while(++bvgrd!=bvgrde);
764762
do{
763+
// start of processing one column
765764
I colx=*ndx; // get next column# to work on
766765
I limitrow; // the best row to use as a pivot for this column; or # qualifying Dpiv found.
767766
if(likely(bv!=0)){
@@ -772,6 +771,12 @@ static unsigned char jtmvmsparsex(J jt,void* const ctx,UI4 ti){
772771
}
773772
limitrow=-1; // init no eligible row found
774773
}else limitrow=0; // for Dpiv, init to none found
774+
#if 0
775+
// process the column NPAR values at a time
776+
#else
777+
#define COLLPINIT I *bvgrd=bvgrd0; I i=-1; D *mv=mv0-n; D bkold=inf, cold=1.0; I bkle0=1;
778+
#define COLLP do{if(unlikely(zv!=0)){++i; mv+=n;}else{i=*bvgrd; mv=mv0+n*i;} // for each row, i is the row#, mv points to the beginning of the row of M. If we take the whole col, take it in order for cache. Prefetch next row?
779+
#define COLLPE }while(++bvgrd!=bvgrde);
775780
COLLPINIT
776781
// if the column is just to be fetched from M, do so without dot-product. We can use gather down the column, but there's no gain
777782
__m256d dotprod; // place where product is assembled or read into
@@ -869,6 +874,7 @@ static unsigned char jtmvmsparsex(J jt,void* const ctx,UI4 ti){
869874
COLLPE
870875
} // end 'long product'
871876
} // end 'needed dot-product'
877+
#endif
872878
// done with one column. Collect stats and update parms for the next column
873879
if(bv==0){ // one product, or Dpiv
874880
if(zv)break; // if just one product, skip the setup for next column
@@ -952,10 +958,11 @@ static unsigned char jtmvmsparsex(J jt,void* const ctx,UI4 ti){
952958

953959
// 128!:9 matrix times sparse vector with optional early exit
954960
// product mode:
955-
// y is ndx;Ax;Am;Av;(M, shape m,n) where ndx is an atom
956-
// if ndx<m, the column is ndx {"1 M; otherwise ((((ndx-m){Ax) ];.0 Am) {"1 M) +/@:*"1 ((ndx-m){Ax) ];.0 Av
957-
// Result for product mode (exitvec is scalar) is the product
958-
// DIP mode
961+
// y is ndx;Ax;Am;Av;(M, shape m,n) where ndx is an atom
962+
// if ndx<m, the column is ndx {"1 M; otherwise ((((ndx-m){Ax) ];.0 Am) {"1 M) +/@:*"1 ((ndx-m){Ax) ];.0 Av
963+
// if M has rank 3 (with 2={.$M), do the product in extended precision
964+
// Result for product mode (exitvec is scalar) is the product, one column of M
965+
// DIP/Dpiv mode:
959966
// y is ndx;Ax;Am;Av;(M, shape m,n);bkgrd;(ColThreshold/PivTol,MinPivot,bkmin,NFreeCols,NCols,ImpFac,Virtx/Dpivdir);bk/'';Frow[;exclusion list/Dpiv;Yk]
960967
// Result is rc,best row,best col,#cols scanned,#dot-products evaluated,best gain (if rc e. 0 1 2)
961968
// rc,failing column of NTT, an element of ndx (if rc=4)
@@ -975,7 +982,7 @@ F1(jtmvmsparse){PROLOG(832);
975982
ASSERT(AR(C(AAV(w)[1]))==3&&AS(C(AAV(w)[1]))[1]==2&&AS(C(AAV(w)[1]))[2]==1,EVRANK); // Ax, shape cols,2 1
976983
ASSERT(AR(C(AAV(w)[2]))==1,EVRANK); // Am
977984
ASSERT(AR(C(AAV(w)[3]))==1,EVRANK); // Av
978-
ASSERT(AR(C(AAV(w)[4]))==2,EVRANK); // M
985+
ASSERT(AR(C(AAV(w)[4]))==2||(AR(C(AAV(w)[4]))==3&&AR(C(AAV(w)[0]))==0),EVRANK); // M
979986
// abort if no columns
980987
if(AN(C(AAV(w)[0]))==0)R num(6); // if no cols (which happens at startup, return error indic)
981988
// check types. Don't convert - force the user to get it right
@@ -985,7 +992,7 @@ F1(jtmvmsparse){PROLOG(832);
985992
ASSERT(AT(C(AAV(w)[4]))&FL,EVDOMAIN); // M
986993
// check agreement
987994
ASSERT(AS(C(AAV(w)[2]))[0]==AS(C(AAV(w)[3]))[0],EVLENGTH); // Am and Av
988-
ASSERT(AS(C(AAV(w)[4]))[0]==AS(C(AAV(w)[4]))[1],EVLENGTH); // M is square
995+
ASSERT(AS(C(AAV(w)[4]))[AR(C(AAV(w)[4]))-2]==AS(C(AAV(w)[4]))[AR(C(AAV(w)[4]))-1],EVLENGTH); // M is square
989996

990997
// indexes must be an atom, a single list of integers, or a list of boxes containing integers
991998
// we don't allow conversion so as to force the user to get it right, for speed
@@ -1001,7 +1008,7 @@ F1(jtmvmsparse){PROLOG(832);
10011008
// extract pointers to tables
10021009
D minimp=0.0; // (always neg) min improvement we will accept, best improvement in any column so far. Init to 0 so we take first column with a pivot
10031010

1004-
I n=AS(C(AAV(w)[4]))[0]; // n=#rows/cols in M
1011+
I n=AS(C(AAV(w)[4]))[1]; // n=#rows/cols in M
10051012
// convert types as needed; set ?v=pointer to data area for ?
10061013
D *bv; // pointer to b values if there are any
10071014
__m256d thresh; // ColThr Inf bkmin MinPivot validity thresholds, small positive values
@@ -1017,9 +1024,10 @@ F1(jtmvmsparse){PROLOG(832);
10171024
if(AR(C(AAV(w)[0]))==0){
10181025
// single index value. set bv=0, zv non0 as a flag that we are storing the column
10191026
bv=0; ASSERT(AN(w)==5,EVLENGTH); // if goodvec is an atom, set bv=0 to indicate that bv is not used and verify no more input
1020-
if(unlikely(n==0)){R reshape(sc(n),zeroionei(0));} // empty M, each product is 0
1021-
GATV0(z,FL,n,1); zv=DAV(z); // allocate the result area for column extraction. Set zv nonzero so we use bkgrd of i. #M
1022-
bvgrd0=0; bvgrde=bvgrd0+AS(C(AAV(w)[4]))[0]; // length of column is #M
1027+
if(unlikely(n==0)){R reshape(drop(num(-1),shape(C(AAV(w)[4]))),zeroionei(0));} // empty M, each product is 0
1028+
I epcol=AR(C(AAV(w)[4]))==3; // flag if we are doing an extended-precision column fetch
1029+
GATV(z,FL,n<<epcol,1+epcol,AS(C(AAV(w)[4]))); zv=DAV(z); // allocate the result area for column extraction. Set zv nonzero so we use bkgrd of i. #M
1030+
bvgrd0=0; bvgrde=bvgrd0+n; // length of column is #M
10231031
}else{
10241032
// A list of index values. We are doing the DIP calculation or Dpiv
10251033
ASSERT(AR(C(AAV(w)[5]))==1,EVRANK); ASSERT(AN(C(AAV(w)[5]))==0||AT(C(AAV(w)[5]))&INT,EVDOMAIN); bvgrd0=IAV(C(AAV(w)[5])); bvgrde=bvgrd0+AN(C(AAV(w)[5])); // bkgrd: the order of processing the rows, and end+1 ptr normally /: bk
@@ -1060,7 +1068,7 @@ F1(jtmvmsparse){PROLOG(832);
10601068

10611069
#define YC(n) .n=n,
10621070
struct mvmctx opctx={.ctxlock=0,.abortcolandrow=-1,.bestcolandrow={-1,-1},YC(ndxa)YC(n)YC(minimp)YC(bv)YC(thresh)YC(bestcol)YC(bestcolrow)YC(zv)YC(Frow)YC(nfreecolsd)
1063-
YC(ncolsd)YC(impfac)YC(prirow)YC(bvgrd0)YC(bvgrde)YC(exlist)YC(nexlist)YC(yk)YC(bkmin).axv=((I(*)[2])IAV(C(AAV(w)[1])))-n,.amv0=IAV(C(AAV(w)[2])),.avv0=DAV(C(AAV(w)[3])),.mv0=DAV(C(AAV(w)[4])),
1071+
YC(ncolsd)YC(impfac)YC(prirow)YC(bvgrd0)YC(bvgrde)YC(exlist)YC(nexlist)YC(yk)YC(bkmin).axv=((I(*)[2])IAV(C(AAV(w)[1])))-n,.amv0=IAV(C(AAV(w)[2])),.avv0=DAV(C(AAV(w)[3])),.qk=C(AAV(w)[4]),
10641072
.ndotprods=0,.ncolsproc=0,.taskmask=0};
10651073
#undef YC
10661074

@@ -1109,10 +1117,10 @@ static unsigned char jtekupdatex(J jt,void* const ctx,UI4 ti){
11091117

11101118
__m256d pcoldh, pcoldl=_mm256_setzero_pd(); // value from pivotcolnon0, multiplying one row
11111119
__m256d prowdh, prowdl=_mm256_setzero_pd(); // values from newrownon0
1112-
__m256d relfuzzcct=_mm256_set1_pd(1.0-relfuzz); // comparison tolerance
1120+
__m256d mrelfuzz=_mm256_set1_pd(relfuzz); // comparison tolerance
11131121
__m256d sgnbit=_mm256_broadcast_sd((D*)&Iimin);
11141122
I dpflag=0; // precision flags: 1=Qk 2=pivotcolnon0 4=newrownon0
1115-
D *qkv=DAV(qk); I qksize=AS(qk)[0]; I qksizesq=qksize*qksize; dpflag|=AR(qk)>2; // pointer to qk data, length of a row, offset to low part if present
1123+
D *qkv=DAV(qk); I qksize=AS(qk)[AR(qk)-1]; I qksizesq=qksize*qksize; dpflag|=AR(qk)>2; // pointer to qk data, length of a row, offset to low part if present
11161124
UI rowx=0, rown=AN(prx); I *rowxv=IAV(prx); D *pcn0v=DAV(pivotcolnon0); dpflag|=(AR(pivotcolnon0)>1)<<1; // current row, # rows, address of row indexes, column data
11171125
UI coln=AN(pcx); I *colxv=IAV(pcx); D *prn0v=DAV(newrownon0); dpflag|=(AR(newrownon0)>1)<<2; // # cols, address of col indexes. row data
11181126
// for each row
@@ -1141,42 +1149,52 @@ static unsigned char jtekupdatex(J jt,void* const ctx,UI4 ti){
11411149
}
11421150
// gather the high parts of Qk
11431151
__m256d qkvh=_mm256_setzero_pd(); qkvh=_mm256_mask_i64gather_pd(qkvh,qkvrow,prn0x,endmask,SZI);
1144-
// take product of high parts for fuzz comp
1145-
__m256d prodh=_mm256_mul_pd(pcoldh,prowdh); //
1146-
// calculate fuzzy not-equal. high parts are enough
1147-
// ((((a)>(cct)*(b))?1:0) == (((b)<=(cct)*(a))?1:0)) TNE
1148-
prodh=_mm256_xor_pd(_mm256_fmsub_pd(relfuzzcct,prodh,qkvh),_mm256_fmsub_pd(relfuzzcct,qkvh,prodh)); // sets sign of prodh if fuzzy ne, means keep the result
1152+
// create max(abs(qkvh),abs(pcoldh*prowdh)) which will go into threshold calc
1153+
__m256d maxabs=_mm256_max_pd(_mm256_andnot_pd(sgnbit,qkvh),_mm256_andnot_pd(sgnbit,_mm256_mul_pd(pcoldh,prowdh)));
11491154
if(!(dpflag&1)){
11501155
// single-precision calculation
11511156
// calculate old - pcol*prow
11521157
qkvh=_mm256_fnmadd_pd(prowdh,pcoldh,qkvh);
1158+
// convert maxabs to abs(qkvh) - maxabs*thresh: if < 0, means result should be forced to 0
1159+
maxabs=_mm256_fnmadd_pd(maxabs,mrelfuzz,_mm256_andnot_pd(sgnbit,qkvh));
11531160
}else{
11541161
// extended-precision calculation
1155-
__m256d qkvl; qkvl=_mm256_mask_i64gather_pd(qkvh,qkvrow+qksizesq,prn0x,endmask,SZI); // gather the low parts of Qk
11561162
__m256d iph,ipl,isl; // intermediate products and sums
1163+
__m256d qkvl; // low-order part of result
11571164

1158-
// (qkvh,qkvl) - (prowdh,prowdl) * (pcoldh,pcoldl)
11591165
// (iph,ipl) = - prowdh*pcoldh
1160-
TWOPROD(prowdh,pcoldh,iph,ipl) iph=_mm256_xor_pd(sgnbit,iph);
1161-
// Do high-precision add of qkvh and iph. If this decreases the absvalue of qkvh, we will lose precision because of insufficient
1162-
// bits of qkv. If this increases the absvalue of qkvh, all of qkvl will contribute and the limit of validity will be
1163-
// from the product. In either case it is safe to accumulate all the partial products and ipl into qkvl
1164-
qkvl=_mm256_sub_pd(qkvl,ipl); qkvl=_mm256_fnmadd_pd(prowdh,pcoldl,qkvl); qkvl=_mm256_fnmadd_pd(prowdl,pcoldh,qkvl); // the middle pps. low*low will never contribute unless qkv is exhausted & thus noise
1165-
TWOSUM(qkvh,iph,qkvh,isl) // combine the high parts
1166-
isl=_mm256_add_pd(isl,qkvl); // add the combined low parts
1167-
// Make sure qkvl is much less than qkvh
1168-
TWOSUM(qkvh,isl,qkvh,qkvl) // establish separation
1166+
TWOPROD(prowdh,pcoldh,iph,ipl) // (prowdh,prowdl) to high precision
1167+
iph=_mm256_xor_pd(sgnbit,iph); // change sign for subtract
1168+
if(_mm256_movemask_pd(_mm256_cmp_pd(qkvh,_mm256_setzero_pd(),_CMP_EQ_OQ))==0xf){
1169+
// qkvh is all 0 - the result is simply (-iph,-ipl)
1170+
qkvh=iph; qkvl=_mm256_xor_pd(sgnbit,ipl); // -iph, -ipl
1171+
maxabs=_mm256_setzero_pd(); // do not force any values to true zero
1172+
}else{
1173+
// normal case where qkvh not all 0
1174+
qkvl=_mm256_mask_i64gather_pd(qkvh,qkvrow+qksizesq,prn0x,endmask,SZI); // gather the low parts of Qk
1175+
// (qkvh,qkvl) - (prowdh,prowdl) * (pcoldh,pcoldl)
1176+
// Do high-precision add of qkvh and iph. If this decreases the absvalue of qkvh, we will lose precision because of insufficient
1177+
// bits of qkv. If this increases the absvalue of qkvh, all of qkvl will contribute and the limit of validity will be
1178+
// from the product. In either case it is safe to accumulate all the partial products and ipl into qkvl
1179+
qkvl=_mm256_sub_pd(qkvl,ipl); qkvl=_mm256_fnmadd_pd(prowdh,pcoldl,qkvl); qkvl=_mm256_fnmadd_pd(prowdl,pcoldh,qkvl); // the middle pps. low*low will never contribute unless qkv is exhausted & thus noise
1180+
TWOSUM(qkvh,iph,qkvh,isl) // combine the high parts
1181+
isl=_mm256_add_pd(isl,qkvl); // add the combined low parts
1182+
// Make sure qkvl is much less than qkvh
1183+
TWOSUM(qkvh,isl,qkvh,qkvl) // put qkvh into canonical form
1184+
// convert maxabs to abs(qkvh) - maxabs*thresh: if < 0, means result should be forced to 0
1185+
maxabs=_mm256_fnmadd_pd(maxabs,mrelfuzz,_mm256_andnot_pd(sgnbit,qkvh));
11691186

1170-
// zero if lower than fuzz
1171-
qkvl=_mm256_blendv_pd(_mm256_setzero_pd(),qkvl,prodh);
1187+
// zero if lower than fuzz (low part)
1188+
qkvl=_mm256_blendv_pd(qkvl,_mm256_setzero_pd(),maxabs);
1189+
}
11721190
// scatter the results (low part)
11731191
qkvrow[_mm256_extract_epi64(prn0x,0)+qksizesq]=_mm256_cvtsd_f64(qkvl);
11741192
if(coln-colx>1)qkvrow[_mm256_extract_epi64(prn0x,1)+qksizesq]=_mm256_cvtsd_f64(_mm256_permute_pd(qkvl,0b0001));
11751193
if(coln-colx>2)qkvrow[_mm256_extract_epi64(prn0x,2)+qksizesq]=_mm256_cvtsd_f64(qkvl=_mm256_permute4x64_pd (qkvl,0b01001110));
11761194
if(coln-colx>3)qkvrow[_mm256_extract_epi64(prn0x,3)+qksizesq]=_mm256_cvtsd_f64(_mm256_permute_pd(qkvl,0b0001));
11771195
}
1178-
// zero if lower than fuzz
1179-
qkvh=_mm256_blendv_pd(_mm256_setzero_pd(),qkvh,prodh);
1196+
// zero if lower than fuzz (high part)
1197+
qkvh=_mm256_blendv_pd(qkvh,_mm256_setzero_pd(),maxabs);
11801198
// scatter the results (high part)
11811199
// _mm256_mask_i64scatter_pd(qkvrow,endmask,prn0x,qkvh,SZI);
11821200
qkvrow[_mm256_extract_epi64(prn0x,0)]=_mm256_cvtsd_f64(qkvh);
@@ -1209,7 +1227,7 @@ F2(jtekupdate){F2PREFIP;
12091227
ASSERT(AR(newrownon0)==1||AS(newrownon0)[0]==2, EVLENGTH) // newrownon0 is float or extended list
12101228
A tmp=AAV(a)[4]; if(!(AT(tmp)&FL))RZ(tmp=cvt(FL,tmp)); ASSERT(AR(tmp)==0,EVRANK) D relfuzz=DAV(tmp)[0]; // relfuzz is a float atom
12111229
// agreement
1212-
ASSERT(AN(prx)==AN(pivotcolnon0),EVLENGTH) ASSERT(AN(pcx)==AN(newrownon0),EVLENGTH) // indexes and values must agree
1230+
ASSERT(AN(prx)==AS(pivotcolnon0)[AR(pivotcolnon0)-1],EVLENGTH) ASSERT(AN(pcx)==AS(newrownon0)[AR(newrownon0)-1],EVLENGTH) // indexes and values must agree
12131231
// do the work
12141232
#define YC(n) .n=n,
12151233
struct ekctx opctx={YC(prx)YC(qk)YC(pcx)YC(pivotcolnon0)YC(newrownon0)YC(relfuzz)};

0 commit comments

Comments
 (0)