D. Count the Arrays

Link: D. Count the Arrays

Description

Your task is to calculate the number of arrays such that:

  • each array contains n elements;
  • each element is an integer from 1 to m;
  • for each array, there is exactly one pair of equal elements;
  • for each array a, there exists an index i such that the array is strictly ascending before the i-th element and strictly descending after it (formally, it means that , if , and , if ).

Input
The first line contains two integers n and m ().

Output
Print one integer — the number of arrays that meet all of the aforementioned conditions, taken modulo 998244353.

Examples
input
3 4
output
6
input
3 5
output
10
input
42 1337
output
806066790
input
100000 200000
output
707899035
Note
The arrays in the first example are:

  • ;
  • ;
  • ;
  • ;
  • ;
  • .

Problem solving

这道题的意思是给你两个数,让你求有多少个长度为的序列满足,序列中最大值的左边严格单调递增,最大值的右边严格单调递减。并且序列中有且仅有一对相等的数,序列中的数的范围是

首先你要知道,有一对相等的数。那么这两个数一定是一个在最大值的左边一个在最大值的右边,因为题目要求中的单调性都是严格单调。
我们考虑,将最大值右边的数全部移动到最大值左边,所以此时除了最大值,我们还需要放个数(因为总共n个数,有一个数是出现两次的)。然后我们考虑排列组合,假设最大值为,那么这个数的取值就有个可能,因为每个数都不相同,所以此时的总方案数为
但是我们这是移动过之后放置的方案数,每个数的原来的位置是可以在最大值左边也可以在最大值右边的。所以总方案数应该再乘上,为什么是次方呢,因为这个数中有一个数是重复出现的,因此它的位置是固定的,即最大值左边一个最大值右边一个,所以只有个数是即可以放在左边也可以放在右边的,并且因为这个数中每个数都有可能是那个重复的数,所以总方案数应该再乘上,这时的方案数即为最大值为时的最终答案。
然后我们枚举i可能的取值,累加计算答案即可。因为前面至少要放个数,所以i的枚举范围为
注意,的时候是无解的。
即:

答案我们列出来了,但是还需要一些细节的处理。因为需要计算组合数,如果你只是暴力计算组合数,那你肯定会T死。不难发现,我们需要的组合数都是。并且求和的时候第一项如果有一定是,然后依次是 , 。不难发现,后面的每一项都可以由前面的那一项递推得来,以达到优化时间复杂度的效果。

我们观察可得

并且每相邻两项都满足这样的规律。
因此我们如果知道最大值为得时候的组合数,就可以最大值为时的组合数为
因为数据范围很大,所以这里的除法需要用到逆元。

具体操作请看代码注释

Code

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll maxn = 2e5+10;
const ll mod = 998244353;
ll a[maxn];
ll poww(ll x,ll y){
    ll ans=1;
    while(y){
        if(y&1)    ans=(ans*x)%mod;
        x=(x*x)%mod;
        y>>=1;
    }
    return ans;
}//快速幂
ll inv(ll x){
    return poww(x,mod-2);
}//逆元
int main()
{
    ll n,m;
    cin>>n>>m;
    if(n==2){
        cout<<"0"<<endl;
        return 0;
    }//如果n为2,此时无解
    ll ans=0,mid,er=poww(2,n-3);//ans为答案,mid记录每一次组合数的值
    mid=1;ans=er*(n-2)%mod;//我们以i=n-1的值为初始情况,将mid和ans初始化
    for(ll i=n;i<=m;i++)
        mid=(i-1)*mid%mod*inv(i-n+1)%mod,ans=(ans+mid*er%mod*(n-2))%mod;//每次更新mid(即组合数)和ans,注意这里三个数相乘可能会爆long long,总之能取模就多去几次就行了。这里对应的就是上面的分析
    cout<<ans<<endl;
    return 0;
}