# library python imports
from math import floor,ceil,sqrt
from operator import add
from array import array
from heap import Heap

### MAP_SHAPE determines which distance cost algorithm gets used:
### Pathfinder.calculatePath looks up dc_algorithm[MAP_SHAPE] to get its
### A* heuristic. Leave exactly one assignment uncommented; the commented-out
### values are presumably the other keys dc_algorithm supports (verify in
### distance_cost).
#MAP_SHAPE = 'diamond_v' # diamond shaped - stacked vertically
#MAP_SHAPE = 'diamond_h' # diamond shaped - stacked horizontally
#MAP_SHAPE = 'square_sg' # grid maps that don't allow diagonal movement
#MAP_SHAPE = 'square_sd' # grid maps that allow diagonal movement
MAP_SHAPE = 'rectangle' # array-like (row/column) layout

# A* distance-cost (heuristic) functions, keyed by map shape; implemented either in C or in Python.
from distance_cost import dc_algorithm


class Pathfinder(object):
    """A* pathfinder over a hex/grid map.

    ``map`` must provide:

    * ``size`` -- an ``(xsize, ysize)`` pair;
    * ``getHex(x, y)`` -- an object whose ``terrain.movementSpeedModifier(unit)``
      gives that hex's movement speed modifier;
    * ``getNeighbors(x, y)`` -- an iterable of ``(x, y, extra)`` triples.

    Points are ``(x, y)`` tuples.
    """

    def __init__(self, map):
        """Store ``map`` and build its terrain cost table."""
        self.map = map
        self._init_terrain_cost()

    # The static terrain cost table makes two assumptions: first, that the
    # terrain doesn't change; second, that all units have the same movement
    # costs. The cost calculation is oversimplified, but works well for now.
    #
    # Note: this no longer holds, as terrain costs differ for each type of
    # unit. Should a terrain map be created for each unit type, or should
    # costs be calculated on the fly? Classic memory/speed trade-off. Terrain
    # maps don't use much memory, so I'm leaning toward multiple maps.

    # Class-level default; _init_terrain_cost() replaces it per instance.
    tc_map = None

    def _init_terrain_cost(self):
        """Build ``self.tc_map``, the movement speed modifier of every hex.

        ``self.tc_map[y][x]`` holds the modifier of hex ``(x, y)``, stored as
        one ``array('f')`` row per y. Does nothing if the table is already
        built (non-empty).
        """
        if self.tc_map:
            return
        xsize, ysize = self.map.size
        getHex = self.map.getHex

        # Temporary hack until units can be handled: every modifier is
        # computed for a fake infantry unit (unit type 0).
        class Unit:
            def getType(self):
                return 0
        unit = Unit()

        # Row-major: the outer list is indexed by y, each array by x.
        self.tc_map = [
            array('f', [getHex(x, y).terrain.movementSpeedModifier(unit)
                        for x in range(xsize)])
            for y in range(ysize)
        ]

    # node = (g_cost, depth, point, parent_node)
    #   g_cost      -- cost accumulated from the start to ``point``
    #   depth       -- number of steps from the start node
    #   point       -- (x, y) coordinate, which is also the search state
    #   parent_node -- the node this one was reached from (None for the root)
    def _new_node(self, p, parent, g_cost):
        """Return a search node for point ``p`` reached from ``parent``.

        The root node (``parent`` is None) always has cost 0 and depth 0;
        ``g_cost`` is ignored for it.
        """
        if parent is None:
            return (0, 0, p, None)
        return (g_cost, parent[1] + 1, p, parent)

    def _in_bounds(self, point):
        """Return True if ``point`` (an ``(x, y)`` pair) lies on the map.

        Adjacency helpers may include hexes off the edge of the map, so
        neighbours must be filtered through this before indexing tc_map.
        """
        x, y = point
        xedge, yedge = self.map.size
        return 0 <= x < xedge and 0 <= y < yedge

    def calculatePath(self, start, end, _test=0):
        """Return the path from ``start`` to ``end`` as a list of points.

        Runs an A* search over the hexes yielded by ``self.map.getNeighbors``.
        The step cost of entering a hex is ``1 / movement speed modifier``
        (read from ``self.tc_map``); the heuristic is
        ``dc_algorithm[MAP_SHAPE]``. Off-map neighbours and hexes with a
        non-positive modifier are never entered.

        A hex is pushed at most once: the first time it is reached fixes its
        cost. This assumes a hex costs the same however it is entered, and
        means a cheaper route found later to an already-seen hex is ignored,
        so the result may not be optimal.

        :param start: ``(x, y)`` of the starting hex.
        :param end: ``(x, y)`` of the goal hex.
        :param _test: if true, return ``(path, skip_nodes)`` instead, where
            ``skip_nodes`` maps every hex pushed during the search to its
            ``(f, g)`` costs.
        :returns: list of points from ``start`` to ``end`` inclusive, or
            ``[]`` when the open list runs out before ``end`` is reached.
        """
        tc_map = self.tc_map
        # Micro-optimisations: bind hot lookups to locals.
        distance_cost = dc_algorithm[MAP_SHAPE]
        getNeighbors = self.map.getNeighbors
        in_bounds = self._in_bounds
        new_node = self._new_node

        # Priority queue of open nodes, keyed on f = g + h.
        node_queue = Heap()
        # Every point ever pushed -> its (f, g) costs. A point in here is
        # never pushed again (see the docstring for the trade-off).
        skip_nodes = {}

        # Classic A* loops while the queue is non-empty and breaks when the
        # goal is popped. Here the two tests are swapped so that while/else
        # handles the "goal reached" case.
        node = new_node(start, None, 0)
        while node[2] != end:
            x, y = node[2]
            parent_cost = node[0]
            for (nx, ny, _extra) in getNeighbors(x, y):
                pN = (nx, ny)
                if pN in skip_nodes:
                    continue
                # Off-map neighbours would raise IndexError or, for negative
                # coordinates, silently wrap around to the other edge.
                if not in_bounds(pN):
                    continue
                # tc_map is row-major: index by y first, then x.
                terrain = tc_map[ny][nx]
                if terrain <= 0:
                    # 1/terrain would be infinite or negative; treat the hex
                    # as impassable. NOTE(review): confirm a 0 modifier is
                    # meant to mean "cannot enter".
                    continue
                # g: cost so far. h: heuristic estimate of the remaining
                # cost (assumes an average movement cost of 1).
                g = parent_cost + 1.0 / terrain
                f = g + distance_cost(pN, end)
                node_queue.hpush(f, new_node(pN, node, g))
                skip_nodes[pN] = f, g
            if not node_queue:
                # Open list exhausted without reaching the goal: no path.
                path = []
                break
            _, node = node_queue.hpop()
        else:
            # Goal reached: follow parent links back to the start.
            path = []
            while node:
                path.append(node[2])
                node = node[3]
            path.reverse()

        if _test:
            return path, skip_nodes
        return path


## Path smoothing: drop waypoints that a direct move can skip.
def simpleSmooth(self, path, traversable=None):
    """Remove redundant waypoints from ``path`` in place and return it.

    Walks the path keeping a *check point* (the last waypoint kept). A
    waypoint is dropped when ``traversable(check_point, next_waypoint)``
    reports that the check point can reach the following waypoint directly.
    Otherwise it is kept and becomes the new check point. The first and last
    points are always kept.

    This is a module-level function; ``self`` is unused and kept only so the
    original call signature still works.

    :param path: list of points, e.g. as returned by
        ``Pathfinder.calculatePath``. Modified in place.
    :param traversable: callable ``traversable(a, b) -> bool`` reporting
        whether a direct move from point ``a`` to point ``b`` is possible.
    :returns: ``path`` (the same list object).
    :raises TypeError: if ``path`` is not a list or ``traversable`` is not
        given.
    """
    # Raise rather than assert: asserts are stripped under ``python -O``.
    if not isinstance(path, list):
        raise TypeError('Bad arg type: simpleSmooth([list])')
    if traversable is None:
        raise TypeError('simpleSmooth() requires a traversable(a, b) callable')
    # With fewer than three points there is nothing between the endpoints.
    if len(path) < 3:
        return path
    smoothed = [path[0]]
    for i in range(1, len(path) - 1):
        # smoothed[-1] is the check point; path[i] is the current waypoint.
        if not traversable(smoothed[-1], path[i + 1]):
            smoothed.append(path[i])
    smoothed.append(path[-1])
    # Rebuild in one pass instead of repeated ``del path[i]`` (O(n) each).
    path[:] = smoothed
    return path








    
