view src/de/mpiwg/anteater/ml/MLController.java @ 0:036535fcd179

anteater
author jdamerow
date Fri, 14 Sep 2012 10:30:43 +0200
parents
children
line wrap: on
line source

package de.mpiwg.anteater.ml;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import de.mpiwg.anteater.AnteaterConfiguration;
import de.mpiwg.anteater.ml.impl.StanfordNLPTextParser;
import de.mpiwg.anteater.ml.impl.WekaMLComponent;
import de.mpiwg.anteater.ml.preprocessing.DataCreator;
import de.mpiwg.anteater.persons.APerson;
import de.mpiwg.anteater.persons.PersonsExtraction;
import de.mpiwg.anteater.persons.ml.preprocessing.ApplicantDataCreator;
import de.mpiwg.anteater.places.Place;
import de.mpiwg.anteater.places.PlaceInformation;
import de.mpiwg.anteater.places.PlacesExtraction;
import de.mpiwg.anteater.places.ml.preprocessing.LocationDataCreator;
import de.mpiwg.anteater.results.ApplicantResult;
import de.mpiwg.anteater.results.LocationResult;
import de.mpiwg.anteater.results.SpeciesScientificResult;
import de.mpiwg.anteater.text.TextInformation;

public class MLController {
	
	public final static String COMPONENT_NAME = MLController.class.getSimpleName();

	private AnteaterConfiguration configuration;
	
	public MLController(AnteaterConfiguration configuration) {
		this.configuration = configuration;
	}
	
	public List<ApplicantResult> runApplicantMLComponent(List<TextInformation> infos) {
		configuration.getLogger().logMessage(COMPONENT_NAME, "Run Machine Learning component...");
		
		DataCreator dataCreator = new ApplicantDataCreator(configuration);
		
		List<String> arffFiles = new ArrayList<String>();
		for (TextInformation info : infos) {
			String file = dataCreator.createARFFFile(info, new StanfordNLPTextParser());
			if (file != null)
				arffFiles.add(file);
		}
		
		IMLComponent mlComponent = new WekaMLComponent("Applicant_LADTree.model");
		
		List<ApplicantResult> mlresults = new ArrayList<ApplicantResult>();
		for (String arffFile : arffFiles) {
			List<Double> predictions = mlComponent.run(arffFile);
			int idx = arffFiles.indexOf(arffFile);
			TextInformation info = infos.get(idx);
			
			List<PersonsExtraction> results = info.getPersonsExtractions();
			Map<APerson, PersonsExtraction> persons = new HashMap<APerson, PersonsExtraction>();
			
			List<APerson> ps = new ArrayList<APerson>();
			for (PersonsExtraction r : results) {
				ps.addAll(r.getPersons());
				for (APerson p : r.getPersons())
					persons.put(p, r);
			}
			
			
			for (int i = 0; i < predictions.size(); i++) {
				
				ApplicantResult result = new ApplicantResult();
				
				result.setFinding(ps.get(i));
				result.setResult(persons.get(ps.get(i)));
				result.setTextInfo(info);
				result.setPrediction(predictions.get(i));
				mlresults.add(result);				
			}
		}
		
		return mlresults;
	}
	
	public List<LocationResult> runLocationMLComponent(List<TextInformation> infos, List<SpeciesScientificResult> predictedSpecies, List<ApplicantResult> predictedApplicants) {
		configuration.getLogger().logMessage(COMPONENT_NAME, "Run Machine Learning component for locations...");
		
		DataCreator dataCreator = new LocationDataCreator(configuration, predictedSpecies, predictedApplicants);
		
		List<String> arffFiles = new ArrayList<String>();
		for (TextInformation info : infos) {
			String file = dataCreator.createARFFFile(info, new StanfordNLPTextParser());
			if (file != null)
				arffFiles.add(file);
		}
		
		
		IMLComponent mlComponent = new WekaMLComponent("Location_LMT_moreTraining.model");
		
		List<LocationResult> mlresults = new ArrayList<LocationResult>();
		for (String arffFile : arffFiles) {
			List<Double> predictions = mlComponent.run(arffFile);
			int idx = arffFiles.indexOf(arffFile);
			TextInformation info = infos.get(idx);
			
			List<PlacesExtraction> results = info.getPlacesExtractions();
			List<PlaceResultMapping> mappings = new ArrayList<MLController.PlaceResultMapping>();
			
			for (PlacesExtraction r : results) {
				for (PlaceInformation pi : r.getPlaceInformation()) {
					for (Place p : pi.getPlaces())
						mappings.add(new PlaceResultMapping(pi, p, r));
				}
			}
			
			
			for (int i = 0; i < predictions.size(); i++) {
				
				LocationResult result = new LocationResult();
				PlaceResultMapping mapping = mappings.get(i);
				
				result.setFinding(mapping.placeInformation);
				result.setResult(mapping.placesExtraction);
				result.setPlace(mapping.place);
				result.setTextInfo(info);
				result.setPrediction(predictions.get(i));
				mlresults.add(result);				
			}
		}
		
		return mlresults;
	}
	
	class PlaceResultMapping {
		public PlaceInformation placeInformation;
		public Place place;
		public PlacesExtraction placesExtraction;
		
		public PlaceResultMapping(PlaceInformation placeInformation,
				Place place, PlacesExtraction placesExtraction) {
			super();
			this.placeInformation = placeInformation;
			this.place = place;
			this.placesExtraction = placesExtraction;
		}
		
		
	}
}