public class SvmSgdAdvancedRunner extends TaskWorker
Fields inherited from class TaskWorker: computeEnvironment, config, persistentVolume, taskExecutor, volatileVolume, workerController, workerEnvironment, workerId
Constructor and Description |
---|
SvmSgdAdvancedRunner() |
Modifier and Type | Method and Description |
---|---|
void | execute(): A user needs to implement this method to create the task graph and execute it. |
DataObject<double[]> | executeIterativeTrainingGraph(): Executes the iterative training graph. Training runs in parallel according to the given parallelism factor; in this implementation, the data-loading parallelism and the training (computing) parallelism are the same. |
DataObject<java.lang.Object> | executeTestingDataLoadingTaskGraph(): Loads the testing data used to evaluate the trained model. The data is loaded in parallel according to the given parallelism parameter, with as many partitions as the parallelism; these partitions are later used to run the testing in parallel in the testing task graph. |
DataObject<java.lang.Object> | executeTestingTaskGraph(): Executes the testing task graph on the data loaded by the testing data loading task graph, using the final weight vector obtained from the training task graph. Testing also runs in parallel. |
DataObject<java.lang.Object> | executeTrainingDataLoadingTaskGraph(): Loads the training data in a distributed mode; dataStreamerParallelism is the degree of parallelism used to load the data. |
DataObject<double[]> | executeTrainingGraph(): Executes the training graph. Training runs in parallel according to the given parallelism factor; in this implementation, the data-loading parallelism and the training (computing) parallelism are the same. |
DataObject<java.lang.Object> | executeWeightVectorLoadingTaskGraph(): Loads the weight vector in a distributed mode; dataStreamerParallelism is the degree of parallelism used to load it. |
void | initializeExecute(): Initializes the execute method. |
void | initializeParameters(): Initializes the parameters used in running the SVM. |
void | printTaskSummary() |
DataObject<java.lang.Object> | retrieveTestingAccuracyObject(ComputeGraph predictionGraph, ExecutionPlan predictionPlan): Retrieves the accuracy data object from the prediction task graph. |
DataObject<double[]> | retrieveWeightVectorFromTaskGraph(ComputeGraph graph1, ExecutionPlan plan1): Returns the final weight vector of the trained model. |
double | retriveFinalTestingAccuracy(DataObject<java.lang.Object> finalRes): Calculates the final accuracy by taking dataParallelism into account; the parallelism is needed to average the accuracies produced by the individual testing partitions. |
void | saveResults() |
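The retriveFinalTestingAccuracy entry above describes averaging the per-partition accuracies over dataParallelism. A minimal plain-Java sketch of that arithmetic, assuming each testing partition reports one accuracy value in percent; the class AccuracyAverage and its method finalAccuracy are hypothetical illustrations, not Twister2 APIs:

```java
import java.util.List;

// Hypothetical helper, not part of the Twister2 API: averages the accuracy
// reported by each testing partition over dataParallelism partitions.
final class AccuracyAverage {

    static double finalAccuracy(List<Double> partitionAccuracies, int dataParallelism) {
        // Expect exactly one accuracy value per testing partition
        if (dataParallelism <= 0 || partitionAccuracies.size() != dataParallelism) {
            throw new IllegalArgumentException("expected one accuracy per partition");
        }
        double sum = 0.0;
        for (double acc : partitionAccuracies) {
            sum += acc;
        }
        return sum / dataParallelism;
    }

    public static void main(String[] args) {
        // Four testing partitions (dataParallelism = 4), accuracies in percent
        double acc = finalAccuracy(List.of(90.0, 85.0, 95.0, 80.0), 4);
        System.out.println(acc); // prints 87.5
    }
}
```

A plain mean equals the accuracy over the whole test set only when every partition holds the same number of samples; with uneven partitions, weighting each partition by its sample count would be more exact.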
execute
public void execute()
Specified by: execute in class TaskWorker
public void initializeParameters()
public void initializeExecute()
public DataObject<java.lang.Object> executeTrainingDataLoadingTaskGraph()
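The data loading methods describe splitting the input into as many partitions as the configured parallelism. As an illustration only, the following sketch splits an in-memory array of rows into parallelism contiguous blocks whose sizes differ by at most one row; PartitionSketch is hypothetical, and the actual Twister2 data sources may partition input files differently:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch: split rows into `parallelism` contiguous partitions.
final class PartitionSketch {

    // Returns `parallelism` partitions whose sizes differ by at most one row.
    static List<double[][]> partition(double[][] rows, int parallelism) {
        if (parallelism <= 0) {
            throw new IllegalArgumentException("parallelism must be positive");
        }
        List<double[][]> parts = new ArrayList<>(parallelism);
        int base = rows.length / parallelism;   // minimum rows per partition
        int extra = rows.length % parallelism;  // the first `extra` partitions get one more
        int start = 0;
        for (int p = 0; p < parallelism; p++) {
            int size = base + (p < extra ? 1 : 0);
            parts.add(Arrays.copyOfRange(rows, start, start + size));
            start += size;
        }
        return parts;
    }

    public static void main(String[] args) {
        double[][] data = new double[10][3];  // 10 samples with 3 columns each
        for (double[][] part : partition(data, 4)) {
            System.out.print(part.length + " ");
        }
        System.out.println(); // partition sizes: 3 3 2 2
    }
}
```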
public DataObject<java.lang.Object> executeWeightVectorLoadingTaskGraph()
public DataObject<java.lang.Object> executeTestingDataLoadingTaskGraph()
public DataObject<double[]> executeTrainingGraph()
public DataObject<double[]> executeIterativeTrainingGraph()
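The training methods run SVM training in parallel over the data partitions. A minimal sketch of one common way to do this, assuming a linear SVM trained with stochastic (sub)gradient descent on the regularized hinge loss (lambda / 2) * ||w||^2 + max(0, 1 - y * w.x), labels in {-1, +1}, a constant learning rate eta, and simple averaging of the per-partition weight vectors after each iteration. SvmSgdSketch and all of its parameters are hypothetical; this is not necessarily the update rule or combination step this runner uses:

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch, not the Twister2 implementation: linear SVM trained with
// SGD on several data partitions, with per-partition weight vectors averaged.
final class SvmSgdSketch {

    // One SGD pass over a partition for the regularized hinge loss
    // (lambda / 2) * ||w||^2 + max(0, 1 - y * w.x), with labels y in {-1, +1}.
    static double[] sgdEpoch(double[] w, double[][] x, double[] y, double eta, double lambda) {
        double[] wNew = w.clone();
        for (int i = 0; i < x.length; i++) {
            double margin = y[i] * dot(wNew, x[i]);
            for (int j = 0; j < wNew.length; j++) {
                // Subgradient: lambda * w, minus y * x when the margin is below 1
                double grad = lambda * wNew[j] - (margin < 1 ? y[i] * x[i][j] : 0.0);
                wNew[j] -= eta * grad;
            }
        }
        return wNew;
    }

    static double dot(double[] a, double[] b) {
        double s = 0.0;
        for (int j = 0; j < a.length; j++) {
            s += a[j] * b[j];
        }
        return s;
    }

    // Each iteration: every partition starts from the same w, runs one pass,
    // and the resulting vectors are averaged (a stand-in for a reduce step).
    static double[] train(List<double[][]> xs, List<double[]> ys, int features,
                          int iterations, double eta, double lambda) {
        double[] w = new double[features];
        for (int it = 0; it < iterations; it++) {
            double[] sum = new double[features];
            for (int p = 0; p < xs.size(); p++) {
                double[] wp = sgdEpoch(w, xs.get(p), ys.get(p), eta, lambda);
                for (int j = 0; j < features; j++) {
                    sum[j] += wp[j];
                }
            }
            for (int j = 0; j < features; j++) {
                w[j] = sum[j] / xs.size();
            }
        }
        return w;
    }

    public static void main(String[] args) {
        // Two partitions of a tiny separable set; the label follows the first feature's sign
        List<double[][]> xs = List.of(
                new double[][] {{2, 1}, {-2, 1}},
                new double[][] {{3, -1}, {-1, -1}});
        List<double[]> ys = List.of(new double[] {1, -1}, new double[] {1, -1});
        double[] w = train(xs, ys, 2, 20, 0.1, 0.01);
        System.out.println(Arrays.toString(w));
    }
}
```

Averaging after every pass keeps all partitions working from one shared weight vector, which matches the idea of a single final weight vector being retrieved from the trained model; other combination schemes are equally possible.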
public DataObject<double[]> retrieveWeightVectorFromTaskGraph(ComputeGraph graph1, ExecutionPlan plan1)
graph1
- ComputeGraph from which the final weight vector is retrieved
plan1
- ExecutionPlan from which the final weight vector is retrieved
public DataObject<java.lang.Object> executeTestingTaskGraph()
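executeTestingTaskGraph evaluates the final weight vector on each testing partition. Assuming a linear SVM whose predicted label is the sign of w.x, a per-partition accuracy could be computed as in this hypothetical sketch (SvmTestSketch and its methods are not Twister2 classes):

```java
// Hypothetical sketch: evaluate a linear SVM weight vector on one test partition.
final class SvmTestSketch {

    // Predicts +1 when w.x >= 0, otherwise -1.
    static double predict(double[] w, double[] x) {
        double s = 0.0;
        for (int j = 0; j < w.length; j++) {
            s += w[j] * x[j];
        }
        return s >= 0 ? 1.0 : -1.0;
    }

    // Percentage of samples in the partition whose predicted label matches y.
    static double partitionAccuracy(double[] w, double[][] x, double[] y) {
        int correct = 0;
        for (int i = 0; i < x.length; i++) {
            if (predict(w, x[i]) == y[i]) {
                correct++;
            }
        }
        return 100.0 * correct / x.length;
    }

    public static void main(String[] args) {
        double[] w = {1.0, -0.5};  // e.g. a weight vector produced by training
        double[][] x = {{2, 1}, {-1, 0}, {0.5, 2}, {-3, -1}};
        double[] y = {1, -1, 1, -1};
        System.out.println(partitionAccuracy(w, x, y)); // prints 75.0
    }
}
```

Each partition's result could then be combined into a final accuracy by averaging over the number of partitions.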
public DataObject<java.lang.Object> retrieveTestingAccuracyObject(ComputeGraph predictionGraph, ExecutionPlan predictionPlan)
predictionGraph
- ComputeGraph from which the final accuracy is retrieved
predictionPlan
- ExecutionPlan of the prediction task graph, from which the final accuracy is retrieved
public double retriveFinalTestingAccuracy(DataObject<java.lang.Object> finalRes)
finalRes
- DataObject which contains the final accuracy
public void printTaskSummary()
public void saveResults() throws java.io.IOException
Throws: java.io.IOException