Package net.javlov.example.rooms

Source Code of net.javlov.example.rooms.Main

/*
* Javlov - a Java toolkit for reinforcement learning with multi-agent support.
*
* Copyright (c) 2009 Matthijs Snel
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/
package net.javlov.example.rooms;

import java.awt.Point;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import javax.swing.Timer;

import net.javlov.*;
import net.javlov.policy.EGreedyPolicy;
import net.javlov.world.*;
import net.javlov.world.grid.*;
import net.javlov.world.phys2d.*;
import net.javlov.world.ui.GridWorldView;
import net.phys2d.raw.shapes.Circle;
import net.phys2d.raw.shapes.StaticBox;
import net.javlov.example.ExperimentGUI;
import net.javlov.example.GridLimitedOptionsWorld;

public class Main implements Runnable {
 
  GridLimitedOptionsWorld world;
  GridRewardFunction rf;
  Simulator sim;
  TabularQFunction qf;
  boolean gui;
  int cellwidth, cellheight;
 
  public static void main(String[] args) {
    Main m = new Main();
    m.gui = true;
    m.init();
    m.start();
  }
 
  public void init() {
    cellwidth = 50; cellheight = cellwidth;

    makeWorld();   
   
    List<? extends Option> optionPool = makeOptions();
    world.setOptionPool(optionPool);
   
    Agent a = makeAgent(optionPool);
    AgentBody aBody = makeAgentBody();
   
    sim = new Simulator();
    sim.setEnvironment(world);
   
    world.add(a, aBody);
    sim.setAgent(a);
  }
 
  protected Agent makeAgent(List<? extends Option> optionPool) {
    qf = TabularQFunction.getInstance(optionPool.size());
    SarsaAgent a = new QLearningAgent(qf, 1, optionPool);
    a.setLearnRate(new DecayingLearningRate(1, optionPool.size(), 0.8));
    Policy pi = new EGreedyPolicy(qf, 0.1, optionPool);
    a.setPolicy(pi);
    a.setSMDPMode(false);
    return a;
  }
 
  protected AgentBody makeAgentBody() {
    Phys2DAgentBody aBody = new Phys2DAgentBody(new Circle(20), 0.5f);
    GridGPSSensor gps = new GridGPSSensor(cellwidth, cellheight);
    gps.setBody(aBody);
    aBody.add(gps);
    return aBody;
  }
 
  protected List<? extends Option> makeOptions() {
    List<Action> primitiveActions = new ArrayList<Action>();
    primitiveActions.add(GridMove.getNorthInstance(world));
    primitiveActions.add(GridMove.getEastInstance(world));
    primitiveActions.add(GridMove.getSouthInstance(world));
    primitiveActions.add(GridMove.getWestInstance(world));
   
    List<Option> optionPool = new ArrayList<Option>();
    Option o = new ReachHallOption("R1H1", 0, 8, 0, 9, new Point(4,10), new Point(9,2), primitiveActions);
    o.setID(optionPool.size());
    optionPool.add( o );
   
    o = new ReachHallOption("R1H4", 0, 8, 0, 9, new Point(9,2), new Point(4,10), primitiveActions);
    o.setID(optionPool.size());
    optionPool.add( o );
   
    o = new ReachHallOption("R2H1", 10, 19, 0, 7, new Point(16,8), new Point(9,2), primitiveActions);
    o.setID(optionPool.size());
    optionPool.add( o );
   
    o = new ReachHallOption("R2H2", 10, 19, 0, 7, new Point(9,2), new Point(16,8), primitiveActions);
    o.setID(optionPool.size());
    optionPool.add( o );
   
    o = new ReachHallOption("R3H2", 10, 19, 9, 19, new Point(9,15), new Point(16,8), primitiveActions);
    o.setID(optionPool.size());
    optionPool.add( o );
   
    o = new ReachHallOption("R3H3", 10, 19, 9, 19, new Point(16,8), new Point(9,15), primitiveActions);
    o.setID(optionPool.size());
    optionPool.add( o );
   
    o = new ReachHallOption("R4H3", 0, 8, 11, 19, new Point(4,10), new Point(9,15), primitiveActions);
    o.setID(optionPool.size());
    optionPool.add( o );
   
    o = new ReachHallOption("R4H4", 0, 8, 11, 19, new Point(9,15), new Point(4,10), primitiveActions);
    o.setID(optionPool.size());
    optionPool.add( o );
   
    return optionPool;
  }
 
  protected void makeWorld() {
    world = new GridLimitedOptionsWorld(20, 20, cellwidth, cellheight);
   
    Body wall = new Phys2DBody( new StaticBox(4*cellwidth, cellheight), 10, true );
    wall.setLocation(2*cellwidth, 10*cellheight+0.5*cellheight);
    wall.setType(Body.OBSTACLE);
    world.addFixedBody(wall);
   
    wall = new Phys2DBody( new StaticBox(4*cellwidth, cellheight), 10, true );
    wall.setLocation(7*cellwidth, 10*cellheight+0.5*cellheight);
    wall.setType(Body.OBSTACLE);
    world.addFixedBody(wall);
   
    wall = new Phys2DBody( new StaticBox(cellwidth, 2*cellheight), 10, true );
    wall.setLocation(9*cellwidth+0.5*cellwidth, cellheight);
    wall.setType(Body.OBSTACLE);
    world.addFixedBody(wall);
   
    wall = new Phys2DBody( new StaticBox(cellwidth, 12*cellheight), 10, true );
    wall.setLocation(9*cellwidth+0.5*cellwidth, 9*cellheight);
    wall.setType(Body.OBSTACLE);
    world.addFixedBody(wall);
   
    wall = new Phys2DBody( new StaticBox(cellwidth, 4*cellheight), 10, true );
    wall.setLocation(9*cellwidth+0.5*cellwidth, 18*cellheight);
    wall.setType(Body.OBSTACLE);
    world.addFixedBody(wall);
   
    wall = new Phys2DBody( new StaticBox(6*cellwidth, cellheight), 10, true );
    wall.setLocation(13*cellwidth, 8*cellheight+0.5*cellheight);
    wall.setType(Body.OBSTACLE);
    world.addFixedBody(wall);
   
    wall = new Phys2DBody( new StaticBox(3*cellwidth, cellheight), 10, true );
    wall.setLocation(18.5*cellwidth, 8*cellheight+0.5*cellheight);
    wall.setType(Body.OBSTACLE);
    world.addFixedBody(wall);
   
    rf = new GridRewardFunction();
    world.setRewardFunction(rf);
    world.addCollisionListener(rf);
   
    GoalBody goal = new GoalBody(825, 425);
    goal.setReward(0);
    world.addFixedBody(goal);
  }
 
  public void start() {
    if ( gui ) {
      GridWorldView wv = new GridWorldView(world);
      Timer timer = new Timer(1000/24, wv);
      ExperimentGUI g = new ExperimentGUI("Rooms example", wv, sim);
      timer.start();
      new Thread(this).start();
    } else
      run();
  }

  @Override
  public void run() {
    int episodes = 5000;
    EpisodicRewardStepStatistic stat = new EpisodicRewardStepStatistic(episodes);
    sim.addStatistic(stat);
    sim.init();
    sim.suspend();
    sim.runEpisodes(episodes);
    System.out.println(Arrays.toString(stat.getRewards()));
    System.out.println(qf);
   
  }

}
TOP

Related Classes of net.javlov.example.rooms.Main

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.