【数列求和】牛客练习赛42D

传送门

#include<bits/stdc++.h>
using namespace std;
// Shorthand for the signed integer type used by every function in this file.
typedef long long ll;
// Modulus applied to all results; the double literal 1e9 + 7 converts exactly to 1000000007.
ll mod = 1e9 + 7;
// Closed-form sum of the consecutive integers x, x + 1, ..., y.
// When y == x - 1 the term count is zero, so the result is 0.
ll get(ll x, ll y)
{
    const ll endpoints = x + y;      // first term + last term
    const ll terms = y - x + 1;      // how many integers lie in [x, y]
    return endpoints * terms / 2;
}
// Modular exponentiation by repeated squaring: returns n^k modulo the global `mod`.
//
// n  base; may be negative or exceed `mod`. It is reduced into [0, mod) up front
//    so every intermediate product stays below mod^2 and the result is never negative.
// k  exponent; any k <= 0 yields 1 % mod (the empty product). Looping on `k > 0`
//    instead of `k` matters: right-shifting a negative value never reaches zero,
//    so the previous `while (k)` form could spin forever on a negative exponent.
ll ksm(ll n, ll k)
{
    n %= mod;
    if (n < 0) n += mod;

    ll res = 1;
    while (k > 0)
    {
        if (k & 1) res = (res * n) % mod;   // fold in the current bit of the exponent
        n = (n * n) % mod;                  // square the base for the next bit
        k >>= 1;
    }
    return res % mod;
}
// Sum of an arithmetic-times-geometric series, modulo the global `mod`:
//
//     sum_{i=1..n} (a1 + (i-1)*d1) * (b1 * q^(i-1))
//
// a1, d1  first term and common difference of the arithmetic factor
// b1, q   first term and common ratio of the geometric factor
// n       number of terms; n <= 0 means an empty sum and returns 0
//
// Method: with inv = 1/(1-q), define x_i = x1 + (i-1)*dx where
//     x1 = a1*inv + d1*q*inv^2,   dx = d1*inv.
// Then x_i*b_i - x_{i+1}*b_{i+1} equals the i-th term, so the sum telescopes to
// x1*b1 - x_{n+1}*b_{n+1}. The inverse is taken via ksm(., mod - 2), which
// relies on `mod` being prime.
//
// Returns a value in [0, mod).
ll sum(ll a1, ll d1, ll b1, ll q, ll n)
{
    // A negative count would otherwise be passed to ksm as an exponent.
    if (n <= 0) return 0;

    // Reduce every input into [0, mod) so no product below can overflow.
    const ll a = ((a1 % mod) + mod) % mod;
    const ll d = ((d1 % mod) + mod) % mod;
    const ll b = ((b1 % mod) + mod) % mod;
    const ll r = ((q % mod) + mod) % mod;

    if (r == 1)
    {
        // 1 - q has no inverse here; the geometric factor is constant, so the
        // sum is simply b * (n*a + d * n*(n-1)/2).
        ll u = n, v = n - 1;
        if (u % 2 == 0) u /= 2; else v /= 2;   // halve the even factor exactly
        const ll tri = (u % mod) * (v % mod) % mod;
        const ll total = ((n % mod) * a % mod + d * tri % mod) % mod;
        return b * total % mod;
    }

    const ll inv = ksm((1 - r + mod) % mod, mod - 2);   // 1 / (1 - q)
    const ll dx = d * inv % mod;
    const ll x1 = (a * inv % mod + d * r % mod * inv % mod * inv % mod) % mod;
    const ll xnp1 = (x1 + dx * (n % mod)) % mod;        // x_{n+1}
    const ll bnp1 = b * ksm(r, n) % mod;                // b_{n+1}
    return (((x1 * b - xnp1 * bnp1) % mod) + mod) % mod;
}
// Reads n and k, then prints (without a trailing newline) the value, modulo
// `mod`, of a weighted sum assembled from four pieces:
//   1. a descending run n, n-1, ... of length l, weighted by (n+1)^1 .. (n+1)^l;
//   2. a single term k' * (n+1)^(l+1);
//   3. the ascending run 1 .. k'-1, weighted from (n+1)^(l+2) onward;
//   4. the ascending run k'+1 .. n-l, weighted from (n+1)^(l+k'+1) onward.
// Here l is the largest prefix length with (n-1) + (n-2) + ... + (n-l) <= k and
// k' is the remainder of k after removing that amount, plus one.
//
// Exits with status 1 and no output if the input cannot be read or lies outside
// 1 <= n, 0 <= k <= n*(n-1)/2; outside that range piece 4 would get a negative length.
int main()
{
    ll n, k;
    if (!(cin >> n >> k)) return 1;
    if (n < 1 || k < 0 || k > get(0, n - 1)) return 1;

    // Binary search for the first prefix length whose cost exceeds k.
    // The upper bound is n - 1, not n: a prefix of length n costs the same as one
    // of length n - 1, and accepting it would add a spurious term for piece 2 and
    // give piece 4 a length of -1.
    ll l = 0, r = n - 1;
    while (l < r)
    {
        ll mid = (l + r) >> 1;
        if (get(n - mid, n - 1) <= k) l = mid + 1;
        else r = mid;
    }
    // Step back if the search stopped on a length that is already too costly.
    if (get(n - l, n - 1) > k) l -= 1;

    // Remainder after the descending prefix, shifted to be 1-based.
    k -= get(n - l, n - 1);
    k++;

    ll ans1 = sum(n, -1, n + 1, n + 1, l);

    ll ans2 = (k % mod) * ksm(n + 1, l + 1) % mod;

    ll ans3 = sum(1, 1, ksm(n + 1, l + 2), n + 1, k - 1);

    ll ans4 = sum(k + 1, 1, ksm(n + 1, l + k + 1), n + 1, n - l - k);

    cout << (ans1 + ans2 + ans3 + ans4) % mod;
}