STRASSEN算法 ^^^

by 曾经沧海
626 阅读

一、STRASSEN算法 
#include<iostream.h>
#include<math.h>
#include<memory.h>

//enum   error   {wrong,right,overflow};

void   mutrixMul(int   **a,int   **b,int   **c,int   n);
void   main(void)
{
int   n;

cout<<"please   intput   the   demi   of   matrix:"<<endl;
cin>>n;
int   *a=new   int[n*n];
int   *b=new   int[n*n];
int   *c=new   int[n*n];
//初始化
cout<<"Input   the   elements   of   the   first:"<<endl;
for(int   counter=0;counter<n*n;counter++)
{
cin>>a[counter];
}
//需要清空、!!!
cout<<"your   intput   is:"<<endl;
for(counter=0;counter<n*n;counter++)
{
if(counter%n<n/2)
cout<<endl;
cout<<a[counter]<<"   ";
}

cout<<endl<<"Input   the   elements   of   the   secong:"<<endl;
for(counter=0;counter<n*n;counter++)
{
cin>>b[counter];
}
cout<<"your   intput   is:"<<endl;
for(counter=0;counter<n*n;counter++)
{
if(counter%n<n/2)
cout<<endl;
cout<<b[counter]<<"   ";
}

mutrixMul(&a,&b,&c,n);

cout<<endl<<"the   answer   is:"<<endl;
for(counter=0;counter<n*n;counter++)
{
if(counter%n<n/2)
cout<<endl;
cout<<c[counter]<<"   ";
}
cout<<endl;
}
void   mutrixMul(int   **a,int   **b,int   **c,int   n)
{
if(n==1)
{
// **c=2;
(*c)[0]=(*a)[0]*(*a)[0];
}
else
{
int   *a1=new   int[n*n/4];
int   *a2=new   int[n*n/4];
int   *a3=new   int[n*n/4];
int   *a4=new   int[n*n/4];
//为a申请四部分动态空间。

int   *b1=new   int[n*n/4];
int   *b2=new   int[n*n/4];
int   *b3=new   int[n*n/4];
int   *b4=new   int[n*n/4];
//为b申请。

int   *c1=new   int[n*n/4];
int   *c2=new   int[n*n/4];
int   *c3=new   int[n*n/4];
int   *c4=new   int[n*n/4];
int   *c5=new   int[n*n/4];
int   *c6=new   int[n*n/4];
int   *c7=new   int[n*n/4];
int   *c8=new   int[n*n/4];

int   a1counter,a2counter,a3counter,a4counter;
a1counter=a2counter=a3counter=a4counter=0;
int   b1counter,b2counter,b3counter,b4counter;
b1counter=b2counter=b3counter=b4counter=0;
int   c1counter,c2counter,c3counter,c4counter;
c1counter=c2counter=c3counter=c4counter=0;
int   c5counter,c6counter,c7counter,c8counter;
c5counter=c6counter=c7counter=c8counter=0;

//把a细分到a1,a2,a3,a4四个矩阵;
for(int   i=0;i<n*n/2;i++)
{
if(i%n<n/2)
{
a1[a1counter]=(*a)[i];
a1counter++;
}
else
{
a2[a2counter]=(*a)[i];
a2counter++;
}
}
for(i=n*n/2;i<n*n;i++)
{
if(i%n<n/2)
{
a3[a3counter]=(*a)[i];
a3counter++;
}
else
{
a4[a4counter]=(*a)[i];
a4counter++;
}
}
//把b细分到b1,b2,b3,b4四个矩阵;
for(i=0;i<n*n/2;i++)
{
if(i%n<n/2)
{
b1[b1counter]=(*b)[i];
b1counter++;
}
else
{
b2[b2counter]=(*b)[i];
b2counter++;
}
}
for(i=n*n/2;i<n*n;i++)
{
if(i%n<n/2)
{
b3[b3counter]=(*b)[i];
b3counter++;
}
else
{
b4[b4counter]=(*b)[i];
b4counter++;
}
}
        mutrixMul(&a1,&b1,&c1,n/2);
        mutrixMul(&a2,&b3,&c2,n/2);

        mutrixMul(&a1,&b2,&c3,n/2);
        mutrixMul(&a2,&b4,&c4,n/2);

mutrixMul(&a3,&b1,&c5,n/2);
        mutrixMul(&a4,&b3,&c6,n/2);

        mutrixMul(&a3,&b2,&c7,n/2);
        mutrixMul(&a4,&b4,&c8,n/2);

//
for(i=0;i<n*n/2;i++)
{
if(i%n<n/2)
{
(*c)[i]=c1[c1counter]+c2[c2counter];
c1counter++;
c2counter++;

}
else
{
(*c)[i]=c3[c3counter]+c4[c4counter];
c3counter++;
c4counter++;
}
}
for(i=n*n/2;i<n*n;i++)
{
if(i%n<n/2)
{
(*c)[i]=c5[c5counter]+c6[c6counter];
c5counter++;
c6counter++;
}
else
{
(*c)[i]=c7[c7counter]+c8[c8counter];
c7counter++;
c8counter++;
}
}

delete   a1,a2,a3,a4,b1,b2,b3,b4;
delete   c1,c2,c3,c4,c5,c6,c7,c8;
}

/////////////////////////////////////

*此程序为用改进Strassen分治法来解决矩阵乘法*/
/*新增功能:能禁止对阶数为非2的N次方的矩阵进行运算*/
#include<stdio.h>
#define   M   100
struct   matrix
{
int   m[32][32];
};

int   Judgment(int   n) /*判断是否为2的N次方的函数*/
{
int   flag,temp=n;
while(temp!=1   &&   temp%2==0)
{
if(temp%2==0)   temp/=2;
else   flag=1;
}
if(temp==1)   flag=0;
return   flag;
}

void   Divide(matrix   &d,matrix   &d11,matrix   &d12,matrix   &d21,matrix   &d22,int   n)
/*将一个大矩阵拆分成四个小矩阵的函数*/
{
int   i,j;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
{
d11.m[i][j]=d.m[i][j];
d12.m[i][j]=d.m[i][j+n];
d21.m[i][j]=d.m[i+n][j];
d22.m[i][j]=d.m[i+n][j+n];
}
}

matrix   Merge(matrix   a11,matrix   a12,matrix   a21,matrix   a22,int   n)
/*将四个小矩阵合并成一个大矩阵的函数*/
{
int   i,j;
matrix   a;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
{
a.m[i][j]=a11.m[i][j];
a.m[i][j+n]=a12.m[i][j];
a.m[i+n][j]=a21.m[i][j];
a.m[i+n][j+n]=a22.m[i][j];
}
return   a;
}

matrix   AdhocMatrixMultiply(matrix   x,matrix   y) /*阶数为2的矩阵乘法函数*/
{
int   m1,m2,m3,m4,m5,m6,m7;
matrix   z;

m1=(x.m[1][1]+x.m[1][2])*y.m[1][1];
m2=x.m[1][2]*(y.m[2][1]-y.m[1][1]);
m3=(x.m[2][1]+x.m[2][2])*y.m[2][2];
m4=x.m[2][1]*(y.m[1][2]-y.m[2][2]);

m5=(x.m[1][2]+x.m[2][1])*(y.m[1][1]+y.m[2][2]);
m6=(x.m[2][1]-x.m[1][1])*(y.m[1][1]+y.m[1][2]);
m7=(x.m[1][2]-x.m[2][2])*(y.m[2][2]+y.m[2][1]);
z.m[1][1]=m1+m2;
z.m[1][2]=m5-m1+m4-m6;
z.m[2][1]=m5-m3+m2-m7;
z.m[2][2]=m3+m4;

return   z;
}

matrix   MatrixPlus(matrix   f,matrix   g,int   n) /*矩阵加法函数*/
{
int   i,j;
matrix   h;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
h.m[i][j]=f.m[i][j]+g.m[i][j];
return   h;
}

matrix   MatrixMinus(matrix   f,matrix   g,int   n) /*矩阵减法函数*/
{
int   i,j;
matrix   h;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
h.m[i][j]=f.m[i][j]-g.m[i][j];
return   h;
}

matrix   MatrixMultiply(matrix   a,matrix   b,int   n) /*矩阵乘法函数*/
{
int   k;
matrix   a11,a12,a21,a22;
matrix   b11,b12,b21,b22;
matrix   c11,c12,c21,c22,c;
matrix   m1,m2,m3,m4,m5,m6,m7;
k=n;
if(k==2)
{
c=AdhocMatrixMultiply(a,b);
return   c;
}
else

k=n/2;
Divide(a,a11,a12,a21,a22,k); //拆分A、B、C矩阵
Divide(b,b11,b12,b21,b22,k);
Divide(c,c11,c12,c21,c22,k);

m1=MatrixMultiply(MatrixPlus(a11,a12,n/2),b11,k);
m2=MatrixMultiply(a12,MatrixMinus(b21,b11,k),k);
m3=MatrixMultiply(MatrixPlus(a21,a22,k),b22,k);
m4=MatrixMultiply(a21,MatrixMinus(b12,b22,k),k);
m5=MatrixMultiply(MatrixPlus(a12,a21,k),MatrixPlus(b11,b22,k),k);
m6=MatrixMultiply(MatrixMinus(a21,a11,k),MatrixPlus(b11,b12,k),k);
m7=MatrixMultiply(MatrixMinus(a12,a22,k),MatrixPlus(b22,b21,k),k);
c11=MatrixPlus(m1,m2,k);
c12=MatrixPlus(MatrixMinus(m5,m1,k),MatrixMinus(m4,m6,k),k);
c21=MatrixPlus(MatrixMinus(m5,m3,k),MatrixMinus(m2,m7,k),k);
c22=MatrixPlus(m3,m4,k);

c=Merge(c11,c12,c21,c22,k); //合并C矩阵
return   c;

}

void   main()
{
int   i,j,n;
matrix   A,B,C={0};
while(n!=0)
{
printf("请输入矩阵的阶数N:\n");
scanf("%d",&n);
if(n==0)   break;
else
if(Judgment(n)==0) //判断矩阵的阶是否为2的N次方
{
printf("请输入矩阵A:\n");
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
scanf("%d",&A.m[i][j]);
printf("请输入矩阵B:\n");
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
scanf("%d",&B.m[i][j]);
if(n==1)   C.m[1][1]=A.m[1][1]*B.m[1][1]; //矩阵阶数为1时的特殊处理 
else   C=MatrixMultiply(A,B,n);

printf("矩阵C为:\n");
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
printf("%8d%c",C.m[i][j],j==n?'\n':'   ');
}
else   printf("矩阵的阶数不是2的N次方!\n\n\n");
}
}
/*此程序用递归分治法解决矩阵乘法问题。
    当N=2时,矩阵可直接计算出来。
    当N>2时,可以继续将矩阵分块,直到子矩阵的阶降为2。*/

发表评论