组合数学——康拓展开

yuanheci 2023年03月31日 329次浏览

  今天在洛谷刷数位DPDP题目的时候碰到了一道数位很长的题目,直接用数位DPDP的方法难以解决,这里记录一下康拓展开的方式求解。

  康托展开用于求一个排列在所有 11 ~ NN 的排列间的字典序排名。

image-1680243145941

康拓展开基本知识: =====> 传送门

板子题链接: P5367 【模板】康托展开

AC代码:

#include <bits/stdc++.h>
using namespace std;

const int N = 1e6 + 10, MOD = 998244353;
int fact[N], A[N];
int a[N];
int tr[N];
int n;

int lowbit(int x){
	return x & -x;
}

void add(int x, int c){
	for(int i = x; i <= n; i += lowbit(i)){
		tr[i] += c;
	}
}

int query(int x){
	int ans = 0;
	for(int i = x; i; i -= lowbit(i)){
		ans += tr[i];
	}
	return ans;
}

void solve(){
	scanf("%d", &n);
	for(int i = 1; i <= n; i++) scanf("%d", &a[i]);
	fact[0] = 1;
	for(int i = 1; i <= n; i++) fact[i] = 1LL* fact[i - 1] * i % MOD;
	for(int i = n; i >= 1; i--){
		A[i] = query(a[i]);
		add(a[i], 1);
	}
	int ans = 0;
	for(int i = 1; i < n; i++){
		ans = (1LL* ans + 1LL* A[i] * fact[n - i] % MOD) % MOD; 
	}
	printf("%d\n", ans + 1);
}

int main(){
	int T = 1;
//	scanf("%d", &T);
	while(T -- ){
		solve();
	}
	
	return 0;
}

P2518 [HAOI2010]计数

  这题本质上是一个可重集合的康拓展开问题,但是没有给定取模,因此如果直接按照朴素康拓展开的思想来做,需要手写高精度,或者用很多数论知识,不太好写。

  而对于可重集合求全排列的的问题,可以取模的情况下,常规方式是:

image-1680256806951

  但本题m取到 505050!50! 会爆 long longlong \ long 因此不能直接做。

我们考虑另一种想法:
  假如现在有m个位置;我们先把0放法放好 C(m,a[0])C(m,a[0]),之后就只有 ma[0]m-a[0] 个位置;然后在放11C(ma[0],a[1])C(m-a[0],a[1]);以此类推。

  所以答案是 C(m,a[0])C(ma[0],a[1])...C(ma[0]a[1]..a[8],a[9])C(m,a[0]) * C(m-a[0],a[1]) *...* C(m-a[0]-a[1]-..-a[8],a[9])

  这样就可以算出可重集合的全排列(相当于每个重复元素的情况只算一次)。

AC代码:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;
const int N = 55;
LL c[N][N], cnt[N];
char s[N]; 
int n;
LL ans;

//预处理组合数
void init(){
	for(int i = 0; i <= n; i++){
		for(int j = 0; j <= i; j++){
			if(!j) c[i][j] = 1;
			else c[i][j] = c[i - 1][j] + c[i - 1][j - 1];
		}
	}	
}

//可重集合的排列数,直接按照康拓展开需要计算阶乘,50!会爆LL, 本题还没有取模,需要写高精度。 

//也可以转换一下求全排列的思路。用组合数的思想来求出全排列。 

LL calc(int m){
	LL res = 1;
	for(int i = 0; i <= 9; i++){
		if(cnt[i]) res *= c[m][cnt[i]];
		m -= cnt[i];
	}
	return res;
}
 
void solve(){
	scanf("%s", s + 1);
	n = strlen(s + 1);
	for(int i = 1; i <= n; i++) cnt[s[i] - '0']++;
	init();
	int nn = n;
	//相当于每种重复数只考虑一次。 
	for(int i = 1; i < nn; i++){
		n--;      //每次统计后面的数的全排列 
		for(int j = 0; j < s[i] - '0'; j++){
			if(cnt[j]){
				cnt[j]--;
				ans += calc(n);
				cnt[j]++;
			}
		}
		cnt[s[i] - '0']--;
	}
	printf("%lld\n", ans); 
}

int main(){
	int T = 1;
//	scanf("%d", &T);
	while(T -- ){
		solve();
	}
	
	return 0;
}