代码写的有点长,
但思路很简单:
把n二进制拆分成数组c
从小到大凑拆分后的数
如果可以从小的nm数组凑,则凑,否则从大的nm拆。
这样贪心一定是对的
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
//typedef __int128 LL;
//typedef unsigned long long ull;
//#define F first
//#define S second
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
typedef pair<ld,ld> pdd;
const ld PI=acos(-1);
const ld eps=1e-9;
//unordered_map<int,int>mp;
#define ls (o<<1)
#define rs (o<<1|1)
#define pb push_back
//#define a(i,j) a[(i)*(m+2)+(j)] //m是矩阵的列数
//pop_back()
const int seed=131;
const int M = 1e5+7;
/*
int head[M],cnt;
void init(){cnt=0,memset(head,0,sizeof(head));}
struct EDGE{int to,nxt,val;}ee[M*2];
void add(int x,int y,int z){ee[++cnt].nxt=head[x],ee[cnt].to=y,ee[cnt].val=z,head[x]=cnt;}
*/
ll cm[100],a[M],c[100],nm[100];
int main()
{
ll tp=1;cm[0]=1;
for(int i=1;i<=100;i++)
{
tp*=2;
cm[i]=tp;
if(tp>1e18)break;
}
int t;scanf("%lld",&t);
while(t--)
{
ll n,m,sm=0,ans=0;
scanf("%lld%lld",&n,&m);ll nn=n;
memset(nm,0,sizeof(nm));memset(c,0,sizeof(c));
for(int i=1;i<=m;i++)
{
cin>>a[i];sm+=a[i];
nm[(int)(log(a[i])/log(2))]++;
}
for(int i=59;i>=0;i--)if(nn>=cm[i])nn-=cm[i],c[i]++;//2的第i次幂有一个需要配
bool f=true;//是否能配成
if(nn!=0||sm<n)
{
cout<<-1<<endl;
continue;
}
// cout<<"ok"<<endl;
for(int i=0;i<=59;i++)
{
// cout<<cm[i]<<" ---"<<nm[i]<<endl;
if(c[i]>0)//当前次幂需要配
{
if(nm[i]>=c[i])nm[i]-=c[i];
else
{
c[i]-=nm[i],nm[i]=0;
ll tp=1;
for(int j=i+1;j<=59;j++)//拆j 去提供i 需要c[i]个i
{
tp*=2;//每拆一个j 提供tp个i
if(nm[j]==0)continue;
if(tp>=c[i]||nm[j]*tp>=c[i])//可以提供需要的i
{
ll z=0;
for(int k=j;k>=i+1;k--)
{
if(k!=j)tp/=2;nm[k]+=z;
ll q=(c[i]-1)/tp+1;ans+=q;//需要拆2的j次方的数量
nm[k]-=q;
z=q*2;
}
c[i]-=z;
break;
}
else
{
ll num=nm[j];nm[j]=0;
for(int k=j;k>=i+1;k--)
{
ans+=num;
num*=2;
}
c[i]-=num;
//这种情况全拆
}
}
//上面的数已经拆完了 如果不满足说明不可能
if(c[i]>0)
{
f=false;
break;
}
else nm[i]-=c[i];
}
}
if(nm[i]>0)nm[i+1]+=nm[i]/2;
}
if(!f)puts("-1");
else
printf("%lld\n",ans);
}
return 0;
}
简洁的写法:
#include <iostream>
#include <cstring>
#include <vector>
#include <cstdio>
#include <algorithm>
using namespace std;
#define ll long long
int vis[64];
int gao(int x) {
if (x == 1)
return 0;
return 1 + gao(x / 2);
}
int main() {
//freopen("in.txt", "r", stdin);
//freopen("out.txt", "w", stdout);
int T;
scanf("%d", &T);
while (T--) {
ll n, sum = 0;
int m, x;
memset(vis, 0, sizeof(vis));
scanf("%lld%d", &n, &m);
for (int i = 1; i <= m; i++) {
scanf("%d", &x);
sum += x;
vis[gao(x)]++;
}
if (sum < n) {
puts("-1");
continue;
}
int cat = 0;
for (int i = 0; (1ll << i) <= n; i++) {
if (n >> i & 1)
{
if (vis[i])
vis[i]--;
else {
++cat;
for (int j = 1; ;j++, cat++)
{
if (vis[j + i]) {
vis[j + i]--;
break;
}
else
vis[j + i]++;
}
}
}
vis[i + 1] += vis[i] / 2;
}
printf("%d\n", cat);
}
}