TLE Camera dei segreti con FFT

Buongiorno a tutti, stavo provando a risolvere camera dei segreti con fft ma già dal 3 subtask inizia ad andare in tle. la mia idea era calcolare p(x) come la produttoria di (x+Ri) per i=0,i<N, e poi sostituire a x Bi, e moltiplicando ans per p(Bi). tuttavia il codice è molto lento. qualche consiglio?
Ecco il codice, sto usando l’fft che si trova su usaco guide.

#include <bits/stdc++.h>
using namespace std;

constexpr int MOD=104857601;
using ll=long long;

namespace USACO {

using db = double;  // or double, if TL is tight
using str = string;      // yay python!

using vl = vector<ll>;
using vi = vector<int>;

#define tcT template <class T
#define tcTU tcT, class U
tcT > using V = vector<T>;
tcT, size_t SZ > using AR = array<T, SZ>;
tcT > using PR = pair<T, T>;

// pairs
#define mp make_pair
#define f first
#define s second

#define sz(x) int((x).size())

// loops
#define FOR(i, a, b) for (int i = (a); i < (b); ++i)
#define F0R(i, a) FOR(i, 0, a)
#define ROF(i, a, b) for (int i = (b)-1; i >= (a); --i)
#define R0F(i, a) ROF(i, 0, a)
#define each(a, x) for (auto &a : x)

// INPUT
#define tcTUU tcT, class... U
tcT > void re(T &x) { cin >> x; }
tcTUU > void re(T &t, U &...u) {
	re(t);
	re(u...);
}
tcT > void re(V<T> &x) { each(a, x) re(a); }

void setPrec() { cout << fixed << setprecision(15); }
void unsyncIO() { cin.tie(0)->sync_with_stdio(0); }
void setIO() {
	unsyncIO();
	setPrec();
}

#define rep(i, a, b) for (int i = a; i < (b); ++i)
typedef pair<int, int> pii;

typedef complex<double> C;
void fft(vector<C> &a) {
	int n = sz(a), L = 31 - __builtin_clz(n);
	static vector<complex<long double>> R(2, 1);
	static vector<C> rt(2, 1);  // (^ 10% faster if double)
	for (static int k = 2; k < n; k *= 2) {
		R.resize(n);
		rt.resize(n);
		auto x = polar(1.0L, acos(-1.0L) / k);
		rep(i, k, 2 * k) rt[i] = R[i] = i & 1 ? R[i / 2] * x : R[i / 2];
	}
	vi rev(n);
	rep(i, 0, n) rev[i] = (rev[i / 2] | (i & 1) << L) / 2;
	rep(i, 0, n) if (i < rev[i]) swap(a[i], a[rev[i]]);
	for (int k = 1; k < n; k *= 2)
		for (int i = 0; i < n; i += 2 * k) rep(j, 0, k) {
				// C z = rt[j+k] * a[i+j+k]; // (25% faster if hand-rolled)  ///
				// include-line
				auto x = (double *)&rt[j + k],
				     y = (double *)&a[i + j + k];  /// exclude-line
				C z(x[0] * y[0] - x[1] * y[1],
				    x[0] * y[1] + x[1] * y[0]);  /// exclude-line
				a[i + j + k] = a[i + j] - z;
				a[i + j] += z;
			}
}

typedef vector<ll> vl;
vl convMod(const vl &a, const vl &b) {
	if (a.empty() || b.empty()) return {};
	vl res(sz(a) + sz(b) - 1);
	int B = 32 - __builtin_clz(sz(res)), n = 1 << B, cut = int(sqrt(MOD));
	vector<C> L(n), R(n), outs(n), outl(n);
	rep(i, 0, sz(a)) L[i] = C((int)a[i] / cut, (int)a[i] % cut);
	rep(i, 0, sz(b)) R[i] = C((int)b[i] / cut, (int)b[i] % cut);
	fft(L), fft(R);
	rep(i, 0, n) {
		int j = -i & (n - 1);
		outl[j] = (L[i] + conj(L[j])) * R[i] / (2.0 * n);
		outs[j] = (L[i] - conj(L[j])) * R[i] / (2.0 * n) / 1i;
	}
	fft(outl), fft(outs);
	rep(i, 0, sz(res)) {
		ll av = ll(real(outl[i]) + .5), cv = ll(imag(outs[i]) + .5);
		ll bv = ll(imag(outl[i]) + .5) + ll(real(outs[i]) + .5);
		res[i] = ((av % MOD * cut + bv) % MOD * cut + cv) % MOD;
	}
	return res;
}
}

int solve(int N,vector<int>a,vector<int>b){
vector<ll>A={a[0],1};
vector<ll>B={a[1],1};
vector<ll>res=USACO::convMod(A,B);


    for(int i=2;i<N;i++){
        A={a[i],1};
        res=USACO::convMod(res,A);
    }

ll ans=1;

    for(int i=0;i<N;i++){
        ll temp=0;
        ll power=b[i]%MOD;
        for(int j=0;j<res.size();j++){
            if(j>0){
                temp=(temp+(res[j]*power)%MOD)%MOD;
                power=(power*b[i])%MOD;
            }
            else
                temp=(temp+res[j])%MOD;
        }
        ans=(ans*temp)%MOD;
    }

    return ans;
}

Come hai correttamente trovato il problema si divide in due parti. La prima è trovare p(x)= \prod_{i=0}^{N-1}{(x+R_i)}, l’altra trovare ans=\prod_{i=0}^{N-1}{p(B_i)}.

Analizziamo separatamente le due parti.

Parte 1: tu al momento stai moltiplicando in ordine tutti i (x+R_i) in un accumulatore (res). Con ogni moltiplicazione il grado di res aumenta di 1, quindi questa fase ha complessità \sum_{n=2}^{N}{n\log n} = \mathcal{O}(n^2 \log n).
Ora, chiaramente servono n - 1 moltiplicazioni per calcolare p(x), però puoi cercare di ottimizzare la “lunghezza” delle moltiplicazioni per abbassare la complessità a \mathcal{O}(n \log^2 n). Hint: divide and conquer.

Parte 2: Al momento tu stai valutando in n punti il polinomio p(x) di grado n in modo naive. La complessità è quindi \mathcal{O}(n^2). Purtroppo ottimizzare questa parte è più difficile, provo a darti l’idea generale dell’algoritmo ma tralascerò necessariamente qualche dettaglio.

Multi-point evaluation: Facciamo intanto un’osservazione: se q(x_0)=0, allora, dato r(x) \equiv p(x) \pmod{q(x)}, è vero che p(x_0) = r(x_0). Questo perché per definizione esiste un a(x) tale che p(x)=a(x)\cdot q(x)+r(x) e dunque p(x_0)=a(x_0)\cdot q(x_0)+r(x_0) = a(x_0)\cdot 0+r(x_0) = r(x_0).

Supponiamo di dover valutare p(x) su n punti a_0, a_1, \dots, a_{n-1}.
Dividiamo allora gli n in due sottoinsiemi A'=\{{a_0, \dots, a_{\left \lfloor n/2 \right \rfloor}}\} e A''=\{a_{\left \lfloor n/2 \right \rfloor+1}, \dots, a_{n-1} \}. Notiamo che il polinomio p'(x)=\prod_{x_0 \in A'}^{}{(x - x_0)} si annulla in tutti i punti in A' e similmente p''(x)=\prod_{x_0 \in A''}^{}{(x - x_0)} in tutti i punti di A''.

Allora detti r'(x) \equiv p(x) \pmod{p'(x)} e r''(x) \equiv p(x) \pmod{p''(x)} il problema si è ricondotto a valutare r'(x) sugli elementi di A' e r''(x) sugli elementi di A''. Si può facilmente verificare che A' e A'' contengono la metà dei punti di A e che r'(x) e r''(x) hanno grado al più metà di quello di p(x).

Abbiamo quindi una soluzione ricorsiva, il cui caso base è valutare un polinomio di grado 0 su un singolo punto, che può essere banalmente fatto in \mathcal{O}(1).

Per trovare la complessità è necessario analizzare il tempo richiesto dalle riduzioni modulari. Notiamo che i polinomi p'(x) e p''(x) possono essere costruiti dal basso in \mathcal{O}(n \log^2 n), resta da capire con che complessità riusciamo a ridurre p(x) modulo r'(x) e r''(x).

In breve, p(x) = a(x) \cdot p'(x)+r'(x) \implies r'(x)=p(x) - a(x)\cdot p'(x). Quindi se riusciamo a trovare a(x) possiamo trovare (con una moltiplicazione con FFT e una sottrazione) r'(x). Si può trovare anche a(x) in \mathcal{O}(n \log n) adattando FFT, ma i dettagli sono abbastanza off topic, quindi ti lascio solo il link a un possibile approccio.

In conclusione ogni chiamata ricorre su due casi grossi la metà e ha un overhead \mathcal{O}(n \log n) dovuto alla riduzioni modulari di p(x). Abbiamo allora T(n)=2T(\frac{n}{2})+\mathcal{O}(n \log n) e dunque T(n) \in \mathcal{O}(n \log^2 n).

A te la scelta se valga davvero la pena di implementarlo :slight_smile:

4 Mi Piace

grazie mille per la spiegazione. proverò sicuramente a dare un occhiata anche al materiale che hai allegato.