Description
题目链接: https://codeforces.com/contest/1107/problem/E
题目给出一个长度为 $n$ 的 01 字符串,消除连续长度为 len 的 0/1 能够获得相应 $a[len]$ 的价值,问将字符串完全消除后能够获得的最大价值?
题目限制:
$1 \le n \le 100$
$a_1,a_2,…,a_n(1 \le a_i \le 10^9)$
Solution
容易想到这是一个区间dp的问题,困难的地方在于如何定义状态
可以定义 $dp[l][r][cnt][dig]$ 为 $[l,r]$ 区间内缩减成 $cnt$ 个 $dig$ (如0/1)数所能产生的最大价值,定义 $ans[l][r]$ 为 $[l,r]$ 区间完全消除后的最大价值,可以知道 $ans[1][n]$ 即为所求
那么状态转移方程可以写作:
$dp[l][r][cnt][dig]=max_{l \le k \le r,s[k]=dig}(ans[l][k-1]+dp[k+1][r][cnt-1][dig])$
$ans[l][r]=max_{1 \le cnt \le r-l+1,0 \le dig \le 1}(dp[l][r][cnt][dig]+a[cnt])$
可以理解成,我如果想求 $[l,r]$ 区间完全消除的最大值,可以先将 $[l,r]$ 区间消除到剩下 $cnt$ 个相同的数,再将这些数一起消除
而如果想知道 $[l,r]$ 区间消除到 $cnt$ 个相同的数( $dig$ )的最大值,我们可以枚举 $[l,r]$ 区间中第一个 $dig$,假定这个位置是 $k$,那么这个总价值显然为先完全消除 $[l,k-1]$ 区间的数,再将 $[k+1,r]$ 区间的数消除到 $cnt-1$ 个相同数( $dig$ )
状态数为 $n^3$,枚举第一个 $dig$ 的位置 $k$ 需要 $O(n)$的时间,总时间复杂度为 $O(n^4)$
AC code
Time: 77ms
Memory: 18500KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include <bits/stdc++.h>
using namespace std;
#define mem(a,i) memset(a,i,sizeof(a))
#define rep(i,a,b) for(int i=a;i<=b;++i)
#define per(i,a,b) for(int i=a;i>=b;--i)
#define pb push_back
#define mp make_pair
#define fi first
#define se second
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef pair<int,int> pii;
const ll mod=1e9+7;
const db exps=1e-8;
const db pi=acos(-1.0);
const int maxn=105;
const ll INF=0x3f3f3f3f3f3f3f3f;
ll ans[maxn][maxn];
ll dp[maxn][maxn][maxn][2];
ll a[maxn];
char s[maxn];
int n;
int main() {
scanf("%d",&n);
scanf("%s",s+1);
rep(i,1,n) scanf("%lld",&a[i]);
rep(i,1,n) {
rep(j,1,n) {
int l=j,r=j+i-1;
if(r>n) continue;
rep(dig,0,1) {
rep(cnt,1,n) {
int sum=0;
rep(k,l,r) if(s[k]==dig+'0') sum++;
if(sum<cnt) {
dp[l][r][cnt][dig]=-INF;
continue;
}
rep(k,l,r) {
if(s[k]==dig+'0') {
if(k+1>r&&cnt-1!=0) continue;
dp[l][r][cnt][dig]=max(ans[l][k-1]+dp[k+1][r][cnt-1][dig],dp[l][r][cnt][dig]);
}
}
ans[l][r]=max(ans[l][r],dp[l][r][cnt][dig]+a[cnt]);
}
}
rep(dig,0,1) {
dp[l][r][0][dig]=ans[l][r];
}
}
}
printf("%lld\n",ans[1][n]);
return 0;
}