Slalom selezioni nazionali 2006

Sto provando a risolvere Slalom, ma non ottengo più di 5/100, anche se non capisco perché il mio algoritmo possa sbagliare.
Ecco quello che faccio:

  1. scelgo il nodo 1 come radice, poi da lì faccio partire una dfs che costruisce un ordinamento topologico e crea una nuova lista di adiacenza che trasforma, in pratica, l’albero in un grafo diretto (togliendo i collegamenti che risalirebbero al contrario).
  2. scorrendo sul vettore dove ho memorizzato l’ordinamento, faccio una sorta di DP in questo modo: per ogni nodo faccio questo:
    • A) per ognuno dei suoi figli, se nella soluzione ottima per quel figlio il figlio stesso non è presente, aggiungo il suo costo a una certa variabile.
    • B) ora, se la somma di tutti questi costi è maggiore del valore del nodo che sto valutando, significa che conviene scegliere anche il nodo in questione, e questo valore sommato alla somma delle sol. ottime dei figli è la sua soluzione ottima. Altrimenti, non scelgo questo nodo, ma la sua soluzione ottima è pari alla somma delle sol.ottime dei figli, a cui aggiungo la somma dei costi per i nodi che prima di ora non erano presenti nella loro soluzione ottima.

Questo è il codice:

    #include <iostream>
    #include <fstream>
    #include <vector>
    using namespace std;

    ifstream in("input.txt");
    ofstream out("output.txt");

    typedef long long ll;
    ll n;
    vector<ll> u[500000];
    vector<ll> v[500000];
    vector<ll> t;
    bool e[500000];
    ll val[500000];
    ll sol[500000];
    vector<ll> nf;

    void dfs(ll k,ll p){
    	bool f=1;
    	for(int i=0;i<u[k].size();i++){
    		if(u[k][i]!=p){
    			f=0;
    			dfs(u[k][i],k);
    			v[k].push_back(u[k][i]);
    		}
    	}
    	t.push_back(k);
    	if(f)sol[k]=val[k];
    }

    int main(){
    	in>>n;
    	for(int i=0;i<n;i++){
    		in>>val[i];
    	}
    	for(int i=0;i<n-1;i++){
    		int a,b; in>>a>>b;
    		u[a-1].push_back(b-1);
    		u[b-1].push_back(a-1);
    		e[i]=0;
    	}
    	e[n-1]=0;
    	dfs(0,-1);
    	for(int i=0;i<n;i++){
    		ll no=t[i];
    		ll sn=0;
    		ll sum=0;
    		vector<ll> nd;
    		for(int y=0;y<v[no].size();y++){
    			if(!e[v[no][y]]){
    				sn+=val[v[no][y]];
    				nd.push_back(v[no][y]);
    			}
    			sum+=sol[v[no][y]];
    		}
    		if(val[no]<sn){
    			e[no]=1;
    			sol[no]=sum+val[no];
    		}
    		else if(sn<val[no]){
    			for(int j=0;j<nd.size();j++)e[nd[j]]=1;
    			sol[no]=sum+sn;
    		}
    		else{
    			e[no]=1;
    			sol[no]=sum+val[no];
    		}
    		
    	}
    	for(int i=0;i<n;i++){
    		if(e[i])nf.push_back(i);
    	}
    	out<<nf.size()<<endl;
    	for(int i=0;i<nf.size();i++)out<<nf[i]+1<<" ";
    	
        
    }

Grazie per l’aiuto!

1 Mi Piace