0%

[SNOI2017] 礼物

这道题是一道十分标准的矩阵快速幂题目。

题目链接

我们来看这道题数据范围 ,我们很容易可以得到这道题算法为矩阵快速幂。

那么我们先来看一下两个矩阵是怎么相乘的。

首先,两个矩阵 相乘的必要条件为 的列等于 的行数,相乘的结果矩阵 的行数等于 的行数, 的列数等于 的列数。

既然这道题需要矩阵,为了方便起见,我们可以把矩阵结构体进行封装,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
struct mat
{
long long m,n;
long long a[30][30];
mat operator * (mat& b) //重构乘法运算符
{
mat c;
memset(c.a,0,sizeof(c.a));
c.m=m;
c.n=b.n;
for(int i=1;i<=m;i++)
for(int j=1;j<=b.n;j++)
for(int p=1;p<=n;p++)
{
c.a[i][j]+=(a[i][p]*b.a[p][j])%1000000007;
c.a[i][j]%=1000000007;
}
return c;
}
};
但是我们会有一个问题:矩阵里面应该放什么变量呢?

首先我们会想到放入第 个人的礼物数量,

其次我们会用到之前所有人的礼物数量之和,因为计算第 个人礼物数量会用到。

我们在计算第 个人的礼物数量需要用到 ,所以在矩阵中还要有

同时需要计算 还需要 ,而 还需要

所以我们找到了矩阵中需要放入的变量:当前朋友的礼物、

我构造的矩阵是这样的:

那么接下来的一步就变成矩阵快速幂了,

那么这种矩阵应该乘一个怎样的矩阵才能变化到下一步呢?

这个矩阵的下一步是这样的:

我们来一步一步地看:

,

那么 等又应该如何构造呢?

由二项式定理得,这个地方应该是一个杨辉三角,我们可以写一个函数来求。

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
long long yh[20][20];
long long yhsj(long long x,long long y)
{
if(!yh[x][y])
{
if(y==1||y==x)
yh[x][y]=1;
else
yh[x][y]=yh[x-1][y-1]+yh[x-1][y];
}
return yh[x][y];
}
void nb(mat& a,long long k)
{
a.m=a.n=k+3;
memset(a.a,0,sizeof(a.a));
a.a[k+3][1]=1;
a.a[k+3][k+3]=2;
for(int i=1;i<=k+1;i++)
for(int j=1;j<=i;j++)
a.a[k+3-j][k+3-i]=yhsj(i,j);
for(int i=2;i<=k+2;i++)
a.a[i][1]=a.a[i][k+3]=a.a[i][2];
}

其中 函数用于计算杨辉三角,数组 用于记忆化搜索杨辉三角的某一项。

接下来就是快速幂的部分了。

实际上这个部分很简单,这要把整数的快速幂稍微修修改可以了。

递归形式:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
mat mut(mat a,long long t)
{
if(t==1)
return a;
if(t==0)
{
mat b;
b.m=b.n=k+3;
memset(b.a,0,sizeof(b.a));
for(int i=1;i<=k+3;i++)
b.a[i][i]=1;
return b;
}
mat ans;
ans.m=ans.n=3;
memset(ans.a,0,sizeof(ans.a));
ans=mut(a,t/2);
ans=ans*ans;
if(t%2==1)
ans=ans*a;
return ans;
}
循环形式:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
mat mut(mat a,long long t)
{
mat ans;
memset(ans.a,0,sizeof(ans.a));
ans.m=ans.n=k+3;
for(int i=1;i<=k+3;i++)
ans.a[i][i]=1;
while(t)
{
if(t&1)
ans=ans*a;
a=a*a;
t=t/2;
}
return ans;
}
下面给出最终代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include<cstring>
#include<cstdio>
using namespace std;
long long yh[20][20];
long long n,k;
struct mat
{
long long m,n;
long long a[101][101];
mat operator * (mat& b)
{
mat c;
memset(c.a,0,sizeof(c.a));
c.m=m;
c.n=b.n;
for(int i=1;i<=m;i++)
for(int j=1;j<=n;j++)
for(int p=1;p<=n;p++)
{
c.a[i][j]+=(a[i][p]*b.a[p][j])%1000000007;
c.a[i][j]%=1000000007;
}
return c;
}
mat operator + (mat& b)
{
mat c;
c.n=b.n;
c.m=b.m;
for(int i=1;i<=m;i++)
for(int j=1;j<=n;j++)
c.a[i][j]=a[i][j]+b.a[i][j];
return c;
}
};
long long yhsj(long long x,long long y)
{
if(!yh[x][y])
{
if(y==1||y==x)
yh[x][y]=1;
else
yh[x][y]=yh[x-1][y-1]+yh[x-1][y];
}
return yh[x][y];
}
void nb(mat& a,long long k)
{
a.m=a.n=k+3;
memset(a.a,0,sizeof(a.a));
a.a[k+3][1]=1;
a.a[k+3][k+3]=2;
for(int i=1;i<=k+1;i++)
for(int j=1;j<=i;j++)
a.a[k+3-j][k+3-i]=yhsj(i,j);
for(int i=2;i<=k+2;i++)
a.a[i][1]=a.a[i][k+3]=a.a[i][2];
}
void nbnb(mat& a,long long k)
{
memset(a.a,0,sizeof(a.a));
a.m=1;
a.n=k+3;
for(int i=1;i<=k+3;i++)
a.a[1][i]=1;
}
mat mut(mat a,long long t)
{
if(t==1)
return a;
if(t==0)
{
mat b;
b.m=b.n=k+3;
memset(b.a,0,sizeof(b.a));
for(int i=1;i<=k+3;i++)
b.a[i][i]=1;
return b;
}
mat ans;
ans.m=ans.n=3;
memset(ans.a,0,sizeof(ans.a));
ans=mut(a,t/2);
ans=ans*ans;
if(t%2==1)
ans=ans*a;
return ans;
}
int main()
{
scanf("%lld%lld",&n,&k);
mat ans;
nb(ans,k);
mat op;
nbnb(op,k);
mat num=mut(ans,n-1);
num=op*num;
printf("%lld",num.a[1][1]);
return 0;
}
不愧是我,最终代码正好打了100行。

本人第一次写题解,欢迎各位大佬指出我的错误。

这真的是一篇早期题解,我 年寒假写的。