This documentation is automatically generated by online-judge-tools/verification-helper
#include "library/math/floor-sum.hpp"
$f(a, b, c, n) = \sum_{i = 0}^{n - 1} \lfloor \frac{ai + b}{c} \rfloor$
int a, b, c, n;
// Argument order is (n, a, b, c): the number of terms comes first and the
// divisor c comes last, matching the declaration of felix::floor_sum.
long long ans = floor_sum(n, a, b, c);
$g(a, b, c, n) = \sum_{i = 0}^{n - 1} i \lfloor \frac{ai + b}{c} \rfloor$
$h(a, b, c, n) = \sum_{i = 0}^{n - 1} \lfloor \frac{ai + b}{c} \rfloor^2$
$g$ 和 $h$ 變形的做法請參考 Reference 裡 OI wiki 的文章。
時間複雜度:$O(\log n)$
#pragma once
#include <algorithm>
#include <cassert>
#include "safe-mod.hpp"
namespace felix {

/// @brief Computes sum_{i = 0}^{n - 1} floor((a * i + b) / c).
///
/// Negative `a` and `b` are first replaced by their non-negative residues
/// modulo `c`, and the amount removed that way is subtracted from the answer.
/// The remaining work is a loop that, from its second iteration on, performs
/// one step of the Euclidean algorithm on the pair (A, C) per iteration, so it
/// runs O(log c) times.
///
/// All accumulation happens in `unsigned long long`, so the result is exact
/// modulo 2^64.
///
/// @param n number of terms; must satisfy 0 <= n < 2^32 (asserted)
/// @param a slope of the numerator; may be negative
/// @param b offset of the numerator; may be negative
/// @param c divisor; must satisfy 1 <= c < 2^32 (asserted)
/// @return the sum, reduced modulo 2^64 and converted to `long long`
///
/// Declared `inline` so that the definition may appear in more than one
/// translation unit without violating the one-definition rule.
inline long long floor_sum(long long n, long long a, long long b, long long c) {
    assert(0 <= n && n < (1LL << 32));
    assert(1 <= c && c < (1LL << 32));

    // Unsigned accumulator: wrap-around is well defined, unlike signed overflow.
    unsigned long long ans = 0;

    // With a = a2 - c * k and k = (a2 - a) / c, term i shrinks by k * i,
    // which totals k * n * (n - 1) / 2 over i = 0 .. n - 1.
    if(a < 0) {
        unsigned long long a2 = internal::safe_mod(a, c);
        ans -= 1ULL * n * (n - 1) / 2 * ((a2 - a) / c);
        a = a2;
    }

    // With b = b2 - c * k and k = (b2 - b) / c, every term shrinks by k.
    if(b < 0) {
        unsigned long long b2 = internal::safe_mod(b, c);
        ans -= 1ULL * n * ((b2 - b) / c);
        b = b2;
    }

    // From here on every quantity is non-negative.
    unsigned long long N = n, C = c, A = a, B = b;
    while(true) {
        // Peel off the whole multiples of C contained in the slope ...
        if(A >= C) {
            ans += N * (N - 1) / 2 * (A / C);
            A %= C;
        }
        // ... and in the offset, leaving A < C and B < C.
        if(B >= C) {
            ans += N * (B / C);
            B %= C;
        }

        // Numerator at index N; if it is below C, every remaining term
        // floors to zero and nothing more can be added.
        unsigned long long y_max = A * N + B;
        if(y_max < C) {
            break;
        }

        // Continue on the transposed subproblem: the new term count is
        // floor(y_max / C), the new offset is the remainder, and the slope
        // and divisor trade places.
        N = y_max / C;
        B = y_max % C;
        std::swap(C, A);
    }

    return static_cast<long long>(ans);
}

} // namespace felix
#line 2 "library/math/floor-sum.hpp"
#include <algorithm>
#include <cassert>
#line 2 "library/math/safe-mod.hpp"
namespace felix {
namespace internal {

/// Reduces `x` modulo `m`; a negative remainder is shifted up by `m` once.
/// For m > 0 the returned value therefore lies in [0, m).
template<class T>
constexpr T safe_mod(T x, T m) {
    const T remainder = x % m;
    return remainder < 0 ? static_cast<T>(remainder + m) : remainder;
}

} // namespace internal
} // namespace felix
#line 5 "library/math/floor-sum.hpp"
namespace felix {

/// @brief Computes sum_{i = 0}^{n - 1} floor((a * i + b) / c).
///
/// Negative `a` and `b` are first replaced by their non-negative residues
/// modulo `c`, and the amount removed that way is subtracted from the answer.
/// The remaining work is a loop that, from its second iteration on, performs
/// one step of the Euclidean algorithm on the pair (A, C) per iteration, so it
/// runs O(log c) times.
///
/// All accumulation happens in `unsigned long long`, so the result is exact
/// modulo 2^64.
///
/// @param n number of terms; must satisfy 0 <= n < 2^32 (asserted)
/// @param a slope of the numerator; may be negative
/// @param b offset of the numerator; may be negative
/// @param c divisor; must satisfy 1 <= c < 2^32 (asserted)
/// @return the sum, reduced modulo 2^64 and converted to `long long`
///
/// Declared `inline` so that the definition may appear in more than one
/// translation unit without violating the one-definition rule.
inline long long floor_sum(long long n, long long a, long long b, long long c) {
    assert(0 <= n && n < (1LL << 32));
    assert(1 <= c && c < (1LL << 32));

    // Unsigned accumulator: wrap-around is well defined, unlike signed overflow.
    unsigned long long ans = 0;

    // With a = a2 - c * k and k = (a2 - a) / c, term i shrinks by k * i,
    // which totals k * n * (n - 1) / 2 over i = 0 .. n - 1.
    if(a < 0) {
        unsigned long long a2 = internal::safe_mod(a, c);
        ans -= 1ULL * n * (n - 1) / 2 * ((a2 - a) / c);
        a = a2;
    }

    // With b = b2 - c * k and k = (b2 - b) / c, every term shrinks by k.
    if(b < 0) {
        unsigned long long b2 = internal::safe_mod(b, c);
        ans -= 1ULL * n * ((b2 - b) / c);
        b = b2;
    }

    // From here on every quantity is non-negative.
    unsigned long long N = n, C = c, A = a, B = b;
    while(true) {
        // Peel off the whole multiples of C contained in the slope ...
        if(A >= C) {
            ans += N * (N - 1) / 2 * (A / C);
            A %= C;
        }
        // ... and in the offset, leaving A < C and B < C.
        if(B >= C) {
            ans += N * (B / C);
            B %= C;
        }

        // Numerator at index N; if it is below C, every remaining term
        // floors to zero and nothing more can be added.
        unsigned long long y_max = A * N + B;
        if(y_max < C) {
            break;
        }

        // Continue on the transposed subproblem: the new term count is
        // floor(y_max / C), the new offset is the remainder, and the slope
        // and divisor trade places.
        N = y_max / C;
        B = y_max % C;
        std::swap(C, A);
    }

    return static_cast<long long>(ans);
}

} // namespace felix