Felix's Library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub fffelix-huang/CP-stuff

✔ test/tree/hld/yosupo-Vertex-Set-Path-Composite.test.cpp (verified)

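Depends on: library/data-structure/segtree.hpp, library/tree/hld.hpp, library/data-structure/sparse-table.hpp, library/modint/modint.hpp, library/misc/type-traits.hpp, library/math/safe-mod.hpp, library/math/inv-gcd.hpp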

Code

#define PROBLEM "https://judge.yosupo.jp/problem/vertex_set_path_composite"

#include <iostream>

#include <vector>

#include "../../../library/data-structure/segtree.hpp"

#include "../../../library/tree/hld.hpp"

#include "../../../library/modint/modint.hpp"

using namespace std;
using namespace felix;

using mint = modint998244353;

struct S {
	pair<mint, mint> f, g;

	S() : S(1, 0) {}
	S(mint a, mint b) : f(a, b), g(a, b) {}
	S(pair<mint, mint> a, pair<mint, mint> b) : f(a), g(b) {}
};

// Composes two affine maps stored as (slope, intercept); the result is x -> f(g(x)).
pair<mint, mint> composition(pair<mint, mint> f, pair<mint, mint> g) {
	mint slope = f.first * g.first;
	mint intercept = f.first * g.second + f.second;
	return {slope, intercept};
}

// Monoid identity: both composites are the identity map x -> 1 * x + 0.
S e() {
	return S{};
}
// Monoid product of a left range a and the adjacent right range b:
// f is a.f applied after b.f, g is b.g applied after a.g.
S op(S a, S b) {
	pair<mint, mint> f_part = composition(a.f, b.f);
	pair<mint, mint> g_part = composition(b.g, a.g);
	return S(f_part, g_part);
}

int main() {
	ios::sync_with_stdio(false);
	cin.tie(0);
	int n, q;
	cin >> n >> q;
	vector<S> a(n);
	for(int i = 0; i < n; i++) {
		mint c, d;
		cin >> c >> d;
		a[i] = S(c, d);
	}
	HLD hld(n);
	for(int i = 0; i < n - 1; i++) {
		int u, v;
		cin >> u >> v;
		hld.add_edge(u, v);
	}
	hld.build();
	segtree<S, e, op> seg(n);
	for(int i = 0; i < n; i++) {
		seg.set(hld.id[i], a[i]);
	}
	while(q--) {
		int type, x, y, z;
		cin >> type >> x >> y >> z;
		if(type == 0) {
			seg.set(hld.id[x], S(y, z));
		} else {
			pair<mint, mint> res = {1, 0};
			for(auto [u, v] : hld.get_path(x, y, true)) {
				if(hld.id[u] <= hld.id[v]) {
					res = composition(seg.prod(hld.id[u], hld.id[v] + 1).g, res);
				} else {
					res = composition(seg.prod(hld.id[v], hld.id[u] + 1).f, res);
				}
			}
			cout << res.first * z + res.second << "\n";
		}
	}
	return 0;
}
#line 1 "test/tree/hld/yosupo-Vertex-Set-Path-Composite.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/vertex_set_path_composite"

#include <iostream>

#include <vector>

#line 3 "library/data-structure/segtree.hpp"
#include <algorithm>
#include <functional>
#include <cassert>

namespace felix {

// Segment tree over a monoid (S, op, e): op must be associative and e() its identity.
// Supports point assignment and half-open range products [l, r) in O(log n) op calls,
// plus binary searches max_right / min_left over monotone predicates.
template<class S, S (*e)(), S (*op)(S, S)>
struct segtree {
public:
    // Empty tree (n == 0). Delegating keeps every member, including d, initialized,
    // so all_prod() / prod(0, 0) are valid on a default-constructed tree.
    segtree() : segtree(0) {}
    explicit segtree(int _n) : segtree(std::vector<S>(_n, e())) {}
    explicit segtree(const std::vector<S>& a): n((int) a.size()) {
        // size = smallest power of two >= max(n, 1). The former std::__lg(2 * n - 1)
        // evaluated __lg(-1) when n == 0 and produced 1 << 31.
        log = 0;
        while((1 << log) < n) {
            ++log;
        }
        size = 1 << log;
        d.resize(size * 2, e());
        for(int i = 0; i < n; ++i) {
            d[size + i] = a[i];
        }
        for(int i = size - 1; i >= 1; i--) {
            update(i);
        }
    }
    
    // a[p] = val, then recompute every ancestor of the leaf.
    void set(int p, S val) {
        assert(0 <= p && p < n);
        p += size;
        d[p] = val;
        for(int i = 1; i <= log; ++i) {
            update(p >> i);
        }
    }

    S get(int p) const {
        assert(0 <= p && p < n);
        return d[p + size];
    }

    S operator[](int p) const { return get(p); }
    
    // op(a[l], ..., a[r - 1]); e() for an empty range.
    S prod(int l, int r) const {
        assert(0 <= l && l <= r && r <= n);
        S sml = e(), smr = e();
        for(l += size, r += size; l < r; l >>= 1, r >>= 1) {
            if(l & 1) {
                sml = op(sml, d[l++]);
            }
            if(r & 1) {
                smr = op(d[--r], smr);
            }
        }
        return op(sml, smr);
    }

    S all_prod() const { return d[1]; }

    template<bool (*f)(S)> int max_right(int l) {
        return max_right(l, [](S x) { return f(x); });
    }

    // Largest r such that f(op(a[l], ..., a[r - 1])) holds; requires f(e()).
    template<class F> int max_right(int l, F f) {
        assert(0 <= l && l <= n);
        assert(f(e()));
        if(l == n) {
            return n;
        }
        l += size;
        S sm = e();
        do {
            while(~l & 1) {
                l >>= 1;
            }
            if(!f(op(sm, d[l]))) {
                while(l < size) {
                    push(l);
                    l <<= 1;
                    if(f(op(sm, d[l]))) {
                        sm = op(sm, d[l++]);
                    }
                }
                return l - size;
            }
            sm = op(sm, d[l++]);
        } while((l & -l) != l);
        return n;
    }

    template<bool (*f)(S)> int min_left(int r) {
        return min_left(r, [](S x) { return f(x); });
    }

    // Smallest l such that f(op(a[l], ..., a[r - 1])) holds; requires f(e()).
    template<class F> int min_left(int r, F f) {
        assert(0 <= r && r <= n);
        assert(f(e()));
        if(r == 0) {
            return 0;
        }
        r += size;
        S sm = e();
        do {
            r--;
            while(r > 1 && (r & 1)) {
                r >>= 1;
            }
            if(!f(op(d[r], sm))) {
                while(r < size) {
                    push(r);
                    r = 2 * r + 1;
                    if(f(op(d[r], sm))) {
                        sm = op(d[r--], sm);
                    }
                }
                return r + 1 - size;
            }
            sm = op(d[r], sm);
        } while((r & -r) != r);
        return 0;
    }
    
protected:
    int n, size, log;
    std::vector<S> d;

    void update(int v) {
        d[v] = op(d[2 * v], d[2 * v + 1]);
    }

    // Hook for derived trees that must propagate pending data to children before
    // descending (called in max_right / min_left); no-op here.
    virtual void push(int p) {}
};

} // namespace felix
#line 3 "library/tree/hld.hpp"
#include <array>

#line 6 "library/tree/hld.hpp"
#include <cmath>

#line 4 "library/data-structure/sparse-table.hpp"

namespace felix {

// Static range queries for an idempotent, associative op (e.g. min, max, gcd):
// O(n log n) preprocessing, then O(1) per query on a closed range [from, to].
template<class S, S (*op)(S, S)>
struct sparse_table {
public:
	sparse_table() {}
	// mat[j][i] = op over a[i .. i + 2^j - 1]. An empty input yields an empty table.
	explicit sparse_table(const std::vector<S>& a) {
		n = (int) a.size();
		if(n == 0) {
			// std::__lg(0) is undefined and there is nothing to precompute.
			return;
		}
		int max_log = std::__lg(n) + 1;
		mat.resize(max_log);
		mat[0] = a;
		for(int j = 1; j < max_log; ++j) {
			mat[j].resize(n - (1 << j) + 1);
			for(int i = 0; i <= n - (1 << j); ++i) {
				mat[j][i] = op(mat[j - 1][i], mat[j - 1][i + (1 << (j - 1))]);
			}
		}
	}

	// op over a[from .. to] (inclusive). The two power-of-two blocks combined here may
	// overlap, which is why op must be idempotent.
	S prod(int from, int to) const {
		assert(0 <= from && from <= to && to <= n - 1);
		int lg = std::__lg(to - from + 1);
		return op(mat[lg][from], mat[lg][to - (1 << lg) + 1]);
	}

private:
	int n = 0;
	std::vector<std::vector<S>> mat;
};

} // namespace felix
#line 8 "library/tree/hld.hpp"

namespace felix {

// Heavy-light decomposition of a tree on vertices 0 .. n-1. After build(), each heavy
// chain and each subtree occupies a contiguous range of id[], and LCA queries are
// answered with a sparse table over an Euler tour.
struct HLD {
private:
	// Sparse-table operation over euler_tour entries (depth, vertex): keeps the smaller
	// pair, i.e. the shallowest vertex (ties broken by smaller index). Named without a
	// "__" because identifiers containing a double underscore are reserved.
	static constexpr std::pair<int, int> lca_op_(std::pair<int, int> a, std::pair<int, int> b) {
		return std::min(a, b);
	}

public:
	int n;
	std::vector<std::vector<int>> g;  // adjacency; after build(): children only, heavy child first
	std::vector<int> subtree_size;
	std::vector<int> parent;  // parent[root] == -1
	std::vector<int> depth;
	std::vector<int> top;  // head of the heavy chain containing each vertex
	std::vector<int> tour;  // vertices in preorder: tour[id[v]] == v
	std::vector<int> first_occurrence;  // first index of each vertex in euler_tour
	std::vector<int> id;  // preorder position
	std::vector<std::pair<int, int>> euler_tour;  // (depth, vertex) on entry and after each child
	sparse_table<std::pair<int, int>, lca_op_> st;

	HLD() : n(0) {}
	explicit HLD(int _n) : n(_n), g(_n), subtree_size(_n), parent(_n), depth(_n), top(_n), first_occurrence(_n), id(_n) {
		tour.reserve(n);
		// 2 * n - 1 would be -1 (i.e. SIZE_MAX after conversion) for n == 0 and throw.
		euler_tour.reserve(std::max(0, 2 * n - 1));
	}

	void add_edge(int u, int v) {
		assert(0 <= u && u < n);
		assert(0 <= v && v < n);
		g[u].push_back(v);
		g[v].push_back(u);
	}

	// Roots the tree at `root` and fills every table. Intended to be called once.
	void build(int root = 0) {
		assert(0 <= root && root < n);
		parent[root] = -1;
		top[root] = root;
		dfs_sz(root);
		dfs_link(root);
		st = sparse_table<std::pair<int, int>, lca_op_>(euler_tour);
	}

	// LCA = shallowest vertex on the Euler tour between the first visits of u and v.
	int get_lca(int u, int v) {
		assert(0 <= u && u < n);
		assert(0 <= v && v < n);
		int L = first_occurrence[u];
		int R = first_occurrence[v];
		if(L > R) {
			std::swap(L, R);
		}
		return st.prod(L, R).second;
	}

	// True iff u is an ancestor of v (a vertex is its own ancestor).
	bool is_ancestor(int u, int v) {
		assert(0 <= u && u < n);
		assert(0 <= v && v < n);
		return id[u] <= id[v] && id[v] < id[u] + subtree_size[u];
	}

	// True iff x lies on the path between a and b.
	bool on_path(int a, int x, int b) {
		return (is_ancestor(x, a) || is_ancestor(x, b)) && is_ancestor(get_lca(a, b), x);
	}

	// Number of edges between u and v.
	int get_distance(int u, int v) {
		return depth[u] + depth[v] - 2 * depth[get_lca(u, v)];
	}

	// {length in edges, {endpoint, endpoint}} of a longest path, found by maximizing
	// depth[u] - 2 * depth[x] + depth[v] over Euler-tour positions u <= x <= v.
	std::pair<int, std::array<int, 2>> get_diameter() const {
		std::pair<int, int> u_max = {-1, -1};
		std::pair<int, int> ux_max = {-1, -1};
		std::pair<int, std::array<int, 2>> uxv_max = {-1, std::array<int, 2>{-1, -1}};
		for(auto [d, u] : euler_tour) {
			u_max = std::max(u_max, std::make_pair(d, u));
			ux_max = std::max(ux_max, std::make_pair(u_max.first - 2 * d, u_max.second));
			uxv_max = std::max(uxv_max, std::make_pair(ux_max.first + d, std::array<int, 2>{ux_max.second, u}));
		}
		return uxv_max;
	}

	// The ancestor k edges above u, or -1 if k < 0 or k > depth[u].
	int get_kth_ancestor(int u, int k) {
		assert(0 <= u && u < n);
		// A negative k used to fall through and index tour[id[u] - k], an unrelated vertex.
		if(k < 0 || depth[u] < k) {
			return -1;
		}
		int d = depth[u] - k;
		while(depth[top[u]] > d) {
			u = parent[top[u]];
		}
		return tour[id[u] + d - depth[u]];
	}

	// The k-th vertex (0-indexed) on the path a -> b, or -1 if k is out of range.
	int get_kth_node_on_path(int a, int b, int k) {
		int z = get_lca(a, b);
		int fi = depth[a] - depth[z];
		int se = depth[b] - depth[z];
		if(k < 0 || k > fi + se) {
			return -1;
		}
		if(k < fi) {
			return get_kth_ancestor(a, k);
		} else {
			return get_kth_ancestor(b, fi + se - k);
		}
	}

	// The path u -> v as (from, to) vertex pairs in travel order, each pair lying on a
	// single heavy chain. With include_lca == false the LCA itself is left out.
	std::vector<std::pair<int, int>> get_path(int u, int v, bool include_lca) {
		if(u == v && !include_lca) {
			return {};
		}
		std::vector<std::pair<int, int>> lhs, rhs;
		while(top[u] != top[v]) {
			if(depth[top[u]] > depth[top[v]]) {
				lhs.emplace_back(u, top[u]);
				u = parent[top[u]];
			} else {
				rhs.emplace_back(top[v], v);
				v = parent[top[v]];
			}
		}
		if(u != v || include_lca) {
			if(include_lca) {
				lhs.emplace_back(u, v);
			} else {
				int d = std::abs(depth[u] - depth[v]);
				if(depth[u] < depth[v]) {
					rhs.emplace_back(tour[id[v] - d + 1], v);
				} else {
					lhs.emplace_back(u, tour[id[u] - d + 1]);
				}
			}
		}
		std::reverse(rhs.begin(), rhs.end());
		lhs.insert(lhs.end(), rhs.begin(), rhs.end());
		return lhs;
	}

private:
	// Removes the parent from each adjacency list, computes subtree sizes and depths,
	// and moves the largest child to g[u][0]. Recursive.
	void dfs_sz(int u) {
		if(parent[u] != -1) {
			g[u].erase(std::find(g[u].begin(), g[u].end(), parent[u]));
		}
		subtree_size[u] = 1;
		for(auto& v : g[u]) {
			parent[v] = u;
			depth[v] = depth[u] + 1;
			dfs_sz(v);
			subtree_size[u] += subtree_size[v];
			if(subtree_size[v] > subtree_size[g[u][0]]) {
				std::swap(v, g[u][0]);
			}
		}
	}

	// Preorder visiting the heavy child first, so each chain gets consecutive ids;
	// also records the Euler tour used for LCA.
	void dfs_link(int u) {
		first_occurrence[u] = (int) euler_tour.size();
		id[u] = (int) tour.size();
		euler_tour.emplace_back(depth[u], u);
		tour.push_back(u);
		for(auto v : g[u]) {
			top[v] = (v == g[u][0] ? top[u] : v);
			dfs_link(v);
			euler_tour.emplace_back(depth[u], u);
		}
	}
};

} // namespace felix

#line 6 "library/modint/modint.hpp"
#include <type_traits>

#line 3 "library/misc/type-traits.hpp"
#include <numeric>

#line 5 "library/misc/type-traits.hpp"

namespace felix {

namespace internal {

// Integer-classification traits. On non-MSVC compilers they also treat
// __int128 / unsigned __int128 as signed / unsigned integers.
#ifndef _MSC_VER
template<class T> using is_signed_int128 = typename std::conditional<std::is_same<T, __int128_t>::value || std::is_same<T, __int128>::value, std::true_type, std::false_type>::type;
template<class T> using is_unsigned_int128 = typename std::conditional<std::is_same<T, __uint128_t>::value || std::is_same<T, unsigned __int128>::value, std::true_type, std::false_type>::type;
template<class T> using make_unsigned_int128 = typename std::conditional<std::is_same<T, __int128_t>::value, __uint128_t, unsigned __int128>;
template<class T> using is_integral = typename std::conditional<std::is_integral<T>::value || is_signed_int128<T>::value || is_unsigned_int128<T>::value, std::true_type, std::false_type>::type;
template<class T> using is_signed_int = typename std::conditional<(is_integral<T>::value && std::is_signed<T>::value) || is_signed_int128<T>::value, std::true_type, std::false_type>::type;
template<class T> using is_unsigned_int = typename std::conditional<(is_integral<T>::value && std::is_unsigned<T>::value) || is_unsigned_int128<T>::value, std::true_type, std::false_type>::type;
template<class T> using to_unsigned = typename std::conditional<is_signed_int128<T>::value, make_unsigned_int128<T>, typename std::conditional<std::is_signed<T>::value, std::make_unsigned<T>, std::common_type<T>>::type>::type;
#else
template<class T> using is_integral = typename std::is_integral<T>;
template<class T> using is_signed_int = typename std::conditional<is_integral<T>::value && std::is_signed<T>::value, std::true_type, std::false_type>::type;
template<class T> using is_unsigned_int = typename std::conditional<is_integral<T>::value && std::is_unsigned<T>::value, std::true_type, std::false_type>::type;
template<class T> using to_unsigned = typename std::conditional<is_signed_int<T>::value, std::make_unsigned<T>, std::common_type<T>>::type;
#endif

// SFINAE helpers: usable as `is_signed_int_t<T>* = nullptr` template parameters.
template<class T> using is_signed_int_t = std::enable_if_t<is_signed_int<T>::value>;
template<class T> using is_unsigned_int_t = std::enable_if_t<is_unsigned_int<T>::value>;
template<class T> using to_unsigned_t = typename to_unsigned<T>::type;

// safely_multipliable<T>::type: a wider type that can hold the product of two Ts.
template<class T> struct safely_multipliable {};
template<> struct safely_multipliable<short> { using type = int; };
template<> struct safely_multipliable<unsigned short> { using type = unsigned int; };
template<> struct safely_multipliable<int> { using type = long long; };
template<> struct safely_multipliable<unsigned int> { using type = unsigned long long; };
#ifndef _MSC_VER
// __int128 is a GCC/Clang extension, so these must be guarded exactly like the
// 128-bit traits above; otherwise the MSVC branch fails to compile.
template<> struct safely_multipliable<long long> { using type = __int128; };
template<> struct safely_multipliable<unsigned long long> { using type = __uint128_t; };
#endif

template<class T> using safely_multipliable_t = typename safely_multipliable<T>::type;

}  // namespace internal


}  // namespace felix

#line 2 "library/math/safe-mod.hpp"

namespace felix {

namespace internal {

// x mod m, normalized into [0, m) even when x is negative (assumes m > 0).
template<class T>
constexpr T safe_mod(T x, T m) {
	const T r = x % m;  // in (-m, m)
	return r < 0 ? r + m : r;
}

} // namespace internal


} // namespace felix
#line 3 "library/math/inv-gcd.hpp"

namespace felix {

namespace internal {

// Extended Euclid on (a mod b, b), assuming b >= 1. Returns {g, x} with
// g = gcd(a mod b, b) and x * a ≡ g (mod b); a negative x is shifted up by b / g.
// For a ≡ 0 (mod b) the result is {b, 0}.
template<class T>
constexpr std::pair<T, T> inv_gcd(T a, T b) {
	a = safe_mod(a, b);
	if(a == 0) {
		return {b, 0};
	}
	// Invariant: rem_hi ≡ coef_hi * a and rem_lo ≡ coef_lo * a (mod b).
	T rem_hi = b, rem_lo = a;
	T coef_hi = 0, coef_lo = 1;
	while(rem_lo != 0) {
		const T q = rem_hi / rem_lo;
		const T next_rem = rem_hi - rem_lo * q;
		const T next_coef = coef_hi - coef_lo * q;
		rem_hi = rem_lo;
		rem_lo = next_rem;
		coef_hi = coef_lo;
		coef_lo = next_coef;
	}
	if(coef_hi < 0) {
		coef_hi += b / rem_hi;
	}
	return {rem_hi, coef_hi};
}

} // namespace internal


} // namespace felix

#line 9 "library/modint/modint.hpp"

namespace felix {

// Integer modulo mod(). id > 0 fixes the modulus at compile time; id <= 0 uses the
// runtime modulus md (changed with set_mod). Values are kept in [0, mod()).
template<int id>
struct modint {
public:
	static constexpr int mod() { return (id > 0 ? id : md); }

	// Changes the runtime modulus (no-op for a static modulus or an unchanged value)
	// and drops the cached tables, which depend on the modulus.
	static constexpr void set_mod(int m) {
		if(id > 0 || md == m) {
			return;
		}
		md = m;
		fact.resize(1);
		inv_fact.resize(1);
		invs.resize(1);
	}

	// Extends fact / inv_fact / invs so they cover index n: grows to a power of two,
	// at least doubling, and never past index mod() - 1. Assumes the factorials up to
	// that index are invertible modulo mod() (asserted below).
	static constexpr void prepare(int n) {
		int sz = (int) fact.size();
		if(sz == mod()) {
			return;
		}
		// Already covered. Checking before rounding also keeps n <= 0 out of the
		// rounding step (the former 1 << std::__lg(2 * n - 1) is undefined there).
		if(n < sz) {
			return;
		}
		long long target = 1;
		while(target < n) {
			target <<= 1;
		}
		target = std::max<long long>(target, 2LL * (sz - 1));
		// fact[mod()] would be 0, which has no inverse.
		n = (int) std::min<long long>(target, mod() - 1);
		fact.resize(n + 1);
		inv_fact.resize(n + 1);
		invs.resize(n + 1);
		for(int i = sz; i <= n; i++) {
			fact[i] = fact[i - 1] * i;
		}
		auto eg = internal::inv_gcd(fact.back().val(), mod());
		assert(eg.first == 1);
		inv_fact[n] = eg.second;
		for(int i = n - 1; i >= sz; i--) {
			inv_fact[i] = inv_fact[i + 1] * (i + 1);
		}
		for(int i = n; i >= sz; i--) {
			invs[i] = inv_fact[i] * fact[i - 1];
		}
	}
 
	constexpr modint() : v(0) {} 
	// x % mod() lies in (-mod(), mod()); negatives are shifted up exactly once, so a
	// negative multiple of mod() maps to 0 (previously it became mod() itself).
	template<class T, internal::is_signed_int_t<T>* = nullptr> constexpr modint(T x) : v(static_cast<int>(x % mod() < 0 ? x % mod() + mod() : x % mod())) {}
	template<class T, internal::is_unsigned_int_t<T>* = nullptr> constexpr modint(T x) : v(x % mod()) {}
 
	constexpr int val() const { return v; }

	// Multiplicative inverse. Small nonzero values of a static modulus use the cached
	// table; everything else goes through extended Euclid, which asserts invertibility.
	constexpr modint inv() const {
		// v == 0 must not take the table path: it would return invs[0] == 0 silently
		// instead of failing the non-invertibility assert.
		if(id > 0 && 0 < v && v < std::min(mod() >> 1, 1 << 18)) {
			prepare(v);
			return invs[v];
		} else {
			auto eg = internal::inv_gcd(v, mod());
			assert(eg.first == 1);
			return eg.second;
		}
	}
 
	constexpr modint& operator+=(const modint& rhs) & {
		v += rhs.v;
		if(v >= mod()) {
			v -= mod();
		}
		return *this;
	}
 
	constexpr modint& operator-=(const modint& rhs) & {
		v -= rhs.v;
		if(v < 0) {
			v += mod();
		}
		return *this;
	}

	constexpr modint& operator*=(const modint& rhs) & {
		v = 1LL * v * rhs.v % mod();
		return *this;
	}

	constexpr modint& operator/=(const modint& rhs) & {
		return *this *= rhs.inv();
	}

	friend constexpr modint operator+(modint lhs, modint rhs) { return lhs += rhs; }
	friend constexpr modint operator-(modint lhs, modint rhs) { return lhs -= rhs; }
	friend constexpr modint operator*(modint lhs, modint rhs) { return lhs *= rhs; }
	friend constexpr modint operator/(modint lhs, modint rhs) { return lhs /= rhs; }

	constexpr modint operator+() const { return *this; }
	constexpr modint operator-() const { return modint() - *this; } 
	constexpr bool operator==(const modint& rhs) const { return v == rhs.v; } 
	constexpr bool operator!=(const modint& rhs) const { return v != rhs.v; }

	// Binary exponentiation; a negative exponent inverts first.
	constexpr modint pow(long long p) const {
		modint a(*this), res(1);
		if(p < 0) {
			a = a.inv();
			p = -p;
		}
		while(p) {
			if(p & 1) {
				res *= a;
			}
			a *= a;
			p >>= 1;
		}
		return res;
	}

	// Euler's criterion (0 counts as having a square root).
	constexpr bool has_sqrt() const {
		if(mod() == 2 || v == 0) {
			return true;
		}
		if(pow((mod() - 1) / 2).val() != 1) {
			return false;
		}
		return true;
	}

	// A square root via Tonelli-Shanks; asserts that one exists.
	constexpr modint sqrt() const {
		if(mod() == 2 || v < 2) {
			return *this;
		}
		assert(pow((mod() - 1) / 2).val() == 1);
		modint b = 1;
		while(b.pow((mod() - 1) >> 1).val() == 1) {
			b += 1;
		}
		int m = mod() - 1, e = __builtin_ctz(m);
		m >>= e;
		modint x = modint(*this).pow((m - 1) >> 1);
		modint y = modint(*this) * x * x;
		x *= v;
		modint z = b.pow(m);
		while(y.val() != 1) {
			int j = 0;
			modint t = y;
			while(t.val() != 1) {
				t *= t;
				j++;
			}
			z = z.pow(1LL << (e - j - 1));
			x *= z, z *= z, y *= z;
			e = j;
		}
		return x;
	}

	friend std::istream& operator>>(std::istream& in, modint& num) {
		long long x;
		in >> x;
		num = modint<id>(x);
		return in;
	}
	
	friend std::ostream& operator<<(std::ostream& out, const modint& num) {
		return out << num.val();
	}

public:
	// Caches filled by prepare(): factorials, inverse factorials, and inverses.
	static std::vector<modint> fact, inv_fact, invs;
 
private:
	int v;
	static int md;
};

// Out-of-class definitions of modint's static members. md starts at 998244353 and is
// only consulted when id <= 0; each cache table starts with just its index-0 entry.
template<int id> int modint<id>::md{998244353};
template<int id> std::vector<modint<id>> modint<id>::fact{modint<id>(1)};
template<int id> std::vector<modint<id>> modint<id>::inv_fact{modint<id>(1)};
template<int id> std::vector<modint<id>> modint<id>::invs{modint<id>(0)};

// Aliases for the static moduli 998244353 and 1000000007.
using modint998244353 = modint<998244353>;
using modint1000000007 = modint<1000000007>;

namespace internal {

// is_modint<T>: true exactly when T is some modint<id>.
template<class T> struct is_modint : std::false_type {};
template<int id> struct is_modint<modint<id>> : std::true_type {};

// is_static_modint<T>: true for modint<id> with a compile-time modulus (id > 0).
template<class T, class = void> struct is_static_modint : std::false_type {};
template<int id> struct is_static_modint<modint<id>> : std::bool_constant<(id > 0)> {};
template<class T> using is_static_modint_t = std::enable_if_t<is_static_modint<T>::value>;

// is_dynamic_modint<T>: true for modint<id> whose modulus is set at runtime (id <= 0).
template<class T, class = void> struct is_dynamic_modint : std::false_type {};
template<int id> struct is_dynamic_modint<modint<id>> : std::bool_constant<(id <= 0)> {};
template<class T> using is_dynamic_modint_t = std::enable_if_t<is_dynamic_modint<T>::value>;

} // namespace internal


} // namespace felix

#line 8 "test/tree/hld/yosupo-Vertex-Set-Path-Composite.test.cpp"
using namespace std;
using namespace felix;

using mint = modint998244353;

struct S {
	pair<mint, mint> f, g;

	S() : S(1, 0) {}
	S(mint a, mint b) : f(a, b), g(a, b) {}
	S(pair<mint, mint> a, pair<mint, mint> b) : f(a), g(b) {}
};

// Composes two affine maps stored as (slope, intercept); the result is x -> f(g(x)).
pair<mint, mint> composition(pair<mint, mint> f, pair<mint, mint> g) {
	mint slope = f.first * g.first;
	mint intercept = f.first * g.second + f.second;
	return {slope, intercept};
}

// Monoid identity: both composites are the identity map x -> 1 * x + 0.
S e() {
	return S{};
}
// Monoid product of a left range a and the adjacent right range b:
// f is a.f applied after b.f, g is b.g applied after a.g.
S op(S a, S b) {
	pair<mint, mint> f_part = composition(a.f, b.f);
	pair<mint, mint> g_part = composition(b.g, a.g);
	return S(f_part, g_part);
}

int main() {
	ios::sync_with_stdio(false);
	cin.tie(0);
	int n, q;
	cin >> n >> q;
	vector<S> a(n);
	for(int i = 0; i < n; i++) {
		mint c, d;
		cin >> c >> d;
		a[i] = S(c, d);
	}
	HLD hld(n);
	for(int i = 0; i < n - 1; i++) {
		int u, v;
		cin >> u >> v;
		hld.add_edge(u, v);
	}
	hld.build();
	segtree<S, e, op> seg(n);
	for(int i = 0; i < n; i++) {
		seg.set(hld.id[i], a[i]);
	}
	while(q--) {
		int type, x, y, z;
		cin >> type >> x >> y >> z;
		if(type == 0) {
			seg.set(hld.id[x], S(y, z));
		} else {
			pair<mint, mint> res = {1, 0};
			for(auto [u, v] : hld.get_path(x, y, true)) {
				if(hld.id[u] <= hld.id[v]) {
					res = composition(seg.prod(hld.id[u], hld.id[v] + 1).g, res);
				} else {
					res = composition(seg.prod(hld.id[v], hld.id[u] + 1).f, res);
				}
			}
			cout << res.first * z + res.second << "\n";
		}
	}
	return 0;
}
Back to top page