from GN.FunctionLibraries import TransformsLibrary
from GN.FunctionLibraries import CurveLibrary
from polygonflow.math import Vector3f
from GN.FunctionLibraries import SceneLibrary
from GN.Standard.Libraries.SceneContainer import SceneContainer, AssetContainer
from GN.Standard.ObjectsInterface import SceneObjectInterface
from GN import Generic
from GN.FunctionLibraries import PointsLibrary
from GN.FunctionLibraries.PointsLibrary import PointOps
from GN.FunctionLibraries.MeshLibrary import GNMesh
from GN.FunctionLibraries import MathLibrary
from GN.FunctionLibraries import BaseLibrary
from GN.FunctionLibraries import ArrayLibrary
from GN.FunctionLibraries import MeshLibrary
from GN.FunctionLibraries.NoiseLibrary import NoiseLib
from polygonflow.primitives import Lines

from GN.FunctionLibraries import RandomLibrary
from GN.FunctionLibraries import TexturesLibrary
from GN.Generic import NodeConnectionInfo

from GN.Core.Data import DataRegistry
from GN.Standard.DCCUtils import DCCUtils
import math, os, sys, json
from polygonflow.array import BitArray, FloatArray, IntArray, VectorArray

from GN.GNL.ToolCreation import ToolDescriptor, PropertyDescriptor, ToolRunner

# Default column layout of the "Curve Table" data-table property.
# GraphNTool.run reads each table entry positionally in this column order:
# entry[0]=Curve, entry[1]=Width, entry[2]=Falloff, entry[3]=Resample.
CURVE_TABLE = {
    "Values": [
        {
            "Name": "Curve",
            "Value": [],
            "Type": "Container",
            "Description": "The curves to use for shaping the terrain",
        },
        {
            "Name": "Width",
            "Value": [5.0, 0.0, 1000.0],
            "Type": "Float",
            "Description": "The distance of curve influence on surrounding terrain",
        },
        {
            "Name": "Falloff",
            "Value": [0.3, 0.0, 1.0],
            "Type": "Float",
            "Description": "Smoothness of transition between curve influence and terrain",
        },
        {
            "Name": "Resample",
            "Value": [180.0, 0.0, 10000.0],
            "Type": "Float",
            "Description": "The distance between points along guide curves",
        },
    ],
}


class GraphNTool(ToolRunner):

    def __init__(self, *args):
        """Initialise the tool's identifiers, seed and DCC environment flags.

        Positional arguments are accepted but are not forwarded to the base
        class initialiser.
        """
        super().__init__()

        # Identity of this tool / graph; graph_index is the base for func_ids.
        self._id = "41ec71a533fb5b9b2adasdasdam"
        self.graph_index = 848230

        # Overwritten with the user's seed on every run().
        self.seed = 0

        # Host-application facts queried once at construction time.
        self.world_vector = DCCUtils.WORLD_UP_VECTOR
        self.is_unreal = DCCUtils.isUnreal()

    def smoothstep(self, values, edge0, edge1):
        denom = edge1 - edge0
        if denom == 0.0:
            denom = 0.0001
        t = (values - edge0) * (1.0 / denom)
        t.clamp_(0.0, 1.0)
        ss = (t * t) * (3.0 - 2.0 * t)

        return ss

    def almost_unit_identity(self, values, n):
        return (values * values + n).sqrt()

    def run(self, material=None, uv_scale=1.0, scale=1.7, curved=0.5, sink=0.25,
            subdivision=3, turbulence=0.5, height=0.8, mid_turbulance=0.75, seed=0,
            height_texture=None, height_multiplier=0.0, falloff_curves=None, falloff=5.0, falloff_interp=0.3,
            falloff_project=True,
            relax_weight=0.0, curve_distance=180.0, curve_table="",
            *args, **kwargs):
        """Generate the terrain mesh and create it in the scene.

        Pipeline: build a subdivided plane, optionally displace it by a height
        map, add radial curvature, noise and a sink offset to the z values, then
        optionally conform the heights to meshes and curves, relax the result,
        recompute normals and create the scene mesh.

        Args:
            material: Material passed through to the created scene mesh.
            uv_scale: UV scale; multiplied by ``scale`` before being applied.
            scale: Overall terrain size. Drives both the plane scale factor
                (``scale * 100``) and the base division count.
            curved: Curvature strength, normalised against a maximum of 2.5.
            sink: Vertical offset; ``sink * 1000`` is subtracted from z.
            subdivision: Multiplier on the plane's division factor.
            turbulence: Frequency control of the primary noise. A value of
                (effectively) zero disables noise altogether.
            height: Scales the maximum value of the primary noise.
            mid_turbulance: Frequency control of the secondary noise
                (the spelling is part of the public interface).
            seed: Noise seed; coerced to ``int`` because the UI may send a float.
            height_texture: Optional container with the height-map texture.
            height_multiplier: Height-map strength; the map is only sampled
                when this is greater than zero.
            falloff_curves: Optional container of guide curves.
            falloff: Curve influence width forwarded to the curve deformation.
            falloff_interp: Smoothness of the curve falloff.
            falloff_project: Whether guide curves are projected onto the terrain.
            relax_weight: When greater than zero, relaxes the mesh by this weight.
            curve_distance: Resampling distance for ``falloff_curves``.
            curve_table: Iterable of rows ``(curves, width, falloff, resample)``;
                an empty value disables table processing.
            **kwargs: May carry ``md_meshes``, ``md_width`` and ``md_falloff`` for
                the (currently unexposed) mesh deformation.

        Returns:
            None. The result is the scene mesh created as a side effect.
        """
        self.seed = seed

        default_size = 50.0

        origin = Vector3f(0.0, 0.0, 0.0)

        scale_100 = scale * 100.0
        division_factor = max(1, int(scale * 5.0)) * subdivision
        orig_mesh: GNMesh = MeshLibrary.GNMesh.createPlane(position=origin, width=default_size, height=default_size,
                                                           width_divisions=4, height_divisions=4,
                                                           scale_factor=scale_100, divisions_factor=division_factor,
                                                           func_id=self.graph_index + 132123)
        uv_scale = uv_scale * scale
        if uv_scale != 1.0:
            copied_mesh = orig_mesh.transformUVs(0.0, 0.0, angle=0.0, u_scale=uv_scale, v_scale=uv_scale,
                                                 translate_relative=True, mesh_copy=True,
                                                 func_id=self.graph_index + 132125)
        else:
            copied_mesh = orig_mesh.copyMesh(func_id=self.graph_index + 132124)

        # NOTE: We can triangulate here as we don't do other ops that require quads, so we can cache this
        copied_mesh = copied_mesh.triangulate(mode=1, func_id=self.graph_index + 132126)
        positions = orig_mesh.data[0].points
        plane_mesh = copied_mesh.data[0]

        ## HEIGHT MAP
        if height_texture and height_multiplier > 0.0:
            texture_interfaces = height_texture.toTextureInterfaces()
            tex_invert = self.getPropertyMask('height_texture')
            if texture_interfaces:
                texture_iface = texture_interfaces[0]
                tex = TexturesLibrary.Textures.grayscale(texture_iface.texture, func_id=self.graph_index + 132130)
                hmap = TexturesLibrary.Textures.sample_grayscale(tex, positions.x, positions.y, tex_invert,
                                                                 func_id=self.graph_index + 132129)
                positions.z += hmap * height_multiplier * 1000.0

        ## CURVATURE CALCULATION
        # For curvature it depends how curvy we want it, if we want a parabola like shape
        distances = (origin - positions).length2().remap(0.0, 1.0)

        curved_max = 2.5  # this is the min max value of it, but we need it in the -1, 1 range
        curvature_multiplier = distances * (curved / curved_max * default_size * scale_100 * 0.25)
        # TODO: This is a UI problem, for whatever reason if we go max, it becomes a float
        seed = int(seed)

        ## NOISE MESH
        turb = turbulence * 0.001
        SMART_NOISE = False
        # Noise is only worth generating when turbulence is non-zero. The test
        # used to be inverted (`<`), which disabled noise for every real
        # turbulence value and made Turbulence/Height/Seed have no effect.
        has_noise = abs(turb) > 1e-8
        if has_noise:
            if SMART_NOISE:
                perlin_noise = NoiseLib.noise_perlin()
                simplex_noise = NoiseLib.noise_opensimplex2()
                fbm_noise1 = NoiseLib.noise_FractalFBm(perlin_noise, lacunarity=1.2, octaves=5)
                n1 = NoiseLib.noise_DomainScale(fbm_noise1, turb)
                n1 = NoiseLib.noise_Multiply(n1, 0.7)

                fbm_noise2 = NoiseLib.noise_FractalFBm(simplex_noise, lacunarity=1.0, octaves=3)
                n2 = NoiseLib.noise_DomainScale(fbm_noise2, turb * 3.0)
                n2 = NoiseLib.noise_DomainRotate(n2, 0.3, 0.7, 0.1)
                n2 = NoiseLib.noise_Multiply(n2, 0.05)

                fin_noise = NoiseLib.noise_Subtract(n1, n2)
                fin_noise = NoiseLib.noise_Fade(a=fin_noise, b=n1, fade=0.3)

                noise_data, min_noise, max_noise = NoiseLib.generate_3d_noise(fin_noise, positions, frequency=1.0)

                noise_data = noise_data.remap(0, 1)
                noise_data = noise_data * (100.0 * height)
            else:
                # just copied the data from terrain creation
                two_d_noise = True
                noise_a = RandomLibrary.RandomGenerators.createNoise(points=positions, frequency=turbulence / 100.0,
                                                                     multiplier=5.0, seed=seed,
                                                                     noise_type="NOISE_PERLIN",
                                                                     rotation_type="ROTATION_NONE",
                                                                     fractal_type="FRACTAL_FBM",
                                                                     cellular_dist_func="CELLULAR_DIST_HYBRID",
                                                                     cellular_return="CELLULAR_RET_CELLVALUE",
                                                                     octaves=1, lacunarity=0.5, gain=2,
                                                                     ping_pong_strength=2.0, weighted_strength=5.0,
                                                                     min_value=-10.0, max_value=height * 100.0,
                                                                     return_augmented=False, two_d=two_d_noise)
                mid_turb5 = mid_turbulance * 5.0

                noise_b = RandomLibrary.RandomGenerators.createNoise(points=positions,
                                                                     frequency=mid_turb5 * turbulence / 100.0,
                                                                     multiplier=5.0, seed=seed,
                                                                     noise_type="NOISE_OPENSIMPLEX2S",
                                                                     rotation_type="ROTATION_NONE",
                                                                     fractal_type="FRACTAL_FBM",
                                                                     cellular_dist_func="CELLULAR_DIST_HYBRID",
                                                                     cellular_return="CELLULAR_RET_CELLVALUE",
                                                                     octaves=1, lacunarity=0.5, gain=2,
                                                                     ping_pong_strength=2.0, weighted_strength=5.0,
                                                                     min_value=-10.0, max_value=25.0,
                                                                     return_augmented=False, two_d=two_d_noise)

                # Blend 25% of the secondary noise into the primary one.
                true_noise = ArrayLibrary.ArrayMethods.array_lerp(array_a=noise_a, array_b=noise_b, amount=0.25)
                min_, max_ = ArrayLibrary.ArrayMethods.array_min_max(array=true_noise)
                # NOTE(review): minValue=200.0 looks larger than the noise range requested
                # above (max_value is height * 100.0 / 25.0) - confirm this lower bound.
                clamped_value, _ = BaseLibrary.clamp_array_value(value=true_noise, minValue=200.0, maxValue=max_)
                noise_data = ArrayLibrary.ArrayMethods.array_lerp(array_a=clamped_value, array_b=true_noise, amount=0.3)

        up_vector = BaseLibrary.get_world_vector(invert=False)

        # Combine noise (when present), curvature and sink into one z displacement.
        if has_noise:
            compound_data = noise_data + curvature_multiplier - (sink * 1000.0)
        else:
            compound_data = curvature_multiplier - (sink * 1000.0)
        positions.z += compound_data
        # Alternative kept for reference:
        # new_positions = TransformsLibrary.Transform.moveAlongNormals(points=positions, direction=up_vector,
        #                                                              multiplier=compound_data)
        new_positions = positions

        plane_mesh.points = new_positions

        # The deformers below work in the mesh's local space, so fetch the position
        # of the existing output object (if any) to translate world-space inputs.
        output_object = self.tool.activeInstance.outputObjects()
        offset = Vector3f(0, 0, 0)
        if output_object:
            # output_object is a set[id]. Peek at an element instead of pop()-ing it:
            # pop() removes the id from the returned set, which would leave it empty
            # for later queries if the instance hands out its own set.
            soi: SceneObjectInterface = DataRegistry.SceneObjects.get(next(iter(output_object)))
            # Assumption(for now) is that output object is _not_ rotated or scaled as we use terrain
            if soi is not None:
                offset = soi.position

        ## MESH DEFORMATION (inputs only arrive through kwargs)
        md_meshes = kwargs.get('md_meshes')
        if md_meshes:
            md_width = kwargs.get('md_width', 5.0)
            md_falloff = kwargs.get('md_falloff', 0.3)
            md_sampling = 100.0  # meh
            self.process_mesh_deformation(
                meshes=md_meshes,
                sampling=md_sampling,
                falloff_width=md_width,
                falloff_interp=md_falloff,
                new_positions=new_positions,
                plane_mesh=plane_mesh,
                offset=offset
            )
            plane_mesh.points = new_positions

        ## CURVE DEFORMATION
        if falloff_curves or curve_table:
            # Process curves from the falloff_curves input
            if falloff_curves:
                self.process_curve_deformation(
                    curves=falloff_curves,
                    distance=curve_distance,
                    falloff_dist=falloff,
                    falloff_interp=falloff_interp,
                    project_curves=falloff_project,
                    new_positions=new_positions,
                    plane_mesh=plane_mesh,
                    offset=offset
                )

            # Process curves from the data table
            if curve_table:
                for entry in curve_table:
                    _curves = entry[0]
                    _falloff = entry[1]
                    _interp = entry[2]
                    _resample = entry[3]

                    if _curves:
                        self.process_curve_deformation(
                            curves=_curves,
                            distance=_resample,
                            falloff_dist=_falloff,
                            falloff_interp=_interp,
                            project_curves=True,
                            new_positions=new_positions,
                            plane_mesh=plane_mesh,
                            offset=offset
                        )

            plane_mesh.points = new_positions

        if relax_weight > 0.0:
            plane_mesh = plane_mesh.relax(weight=relax_weight)
            new_positions = plane_mesh.points

        ## NORMALS
        # The terrain is a square grid, so the per-side count is the square root of the point count.
        xdiv = int(math.sqrt(len(new_positions)))
        # if indices are passed its calculated multiple times but theres no bound check so its fast
        new_normals = VectorArray.height_map_normal(new_positions, xdiv=xdiv, ydiv=xdiv)
        plane_mesh.per_vertex_normals = new_normals

        instanceName = self.getActiveInstance().name
        # Dont cache as our input can be the same just w/ different points as we defer copy
        SceneLibrary.Scene.createMesh(meshes=GNMesh(plane_mesh), name=f'{instanceName}_Terrain',
                                      material=material,
                                      identifier=self.graph_index + 416993)  # , func_id=self.graph_index+416993)

    def process_mesh_deformation(self, meshes: SceneContainer, sampling, falloff_width, falloff_interp, new_positions,
                                 plane_mesh,
                                 offset):
        """Pull terrain heights towards scene meshes, editing ``new_positions`` in place.

        Two passes are made:

        1. Terrain points inside the XY bounding box of any mesh are matched to
           their closest point on the meshes and, when within the influence
           distance, take that point's z.
        2. Terrain points inside the bounding boxes expanded by the influence
           distance (but not covered by pass 1) are matched to the nearest hit
           from pass 1 and either snapped or smoothly blended towards its z.

        Args:
            meshes: Scene container; only MESH entries are used.
            sampling: Not referenced in this method.
            falloff_width: Influence width; multiplied by 100 to get the distance used.
            falloff_interp: Blend smoothness for pass 2; ``<= 0`` snaps instead of blending.
            new_positions: Terrain points in the mesh's local space; modified in place.
            plane_mesh: Not referenced in this method.
            offset: Position of the terrain object, used to convert between
                local terrain space and the space the meshes live in.

        Returns:
            None. Returns early when there are no meshes, no affected points,
            no hits, or nothing inside the falloff ring.
        """

        meshes = meshes.filterType(SceneContainer.Types.MESH)
        if not meshes:
            return

        terrain_min, terrain_max = new_positions.minmax()

        def get_min_z(meshes):
            # Lowest z over all mesh vertices, expressed relative to `offset`.
            mesh_min = 99999999999999.0
            for soi in meshes:
                sc = SceneContainer([soi])
                mesh = sc.toMeshInterfaces()[0]
                # Terrain might not be in world space so we adjust it here
                mesh_verts = mesh.vertices - offset

                bbox_min, bbox_max = mesh_verts.minmax()
                mesh_min = min(bbox_min.z, mesh_min)
            return mesh_min


        def get_affected_points(meshes, new_positions, offset, xy_offset=0.0):
            # Mask of terrain points that fall inside the XY bounding box of at
            # least one mesh, with each box grown by `xy_offset` on x and y.
            affected_points = BitArray.fill(False, len(new_positions))

            # filter the points as we are only interested in those who are nearby, we dont need extra calculations
            for soi in meshes:
                sc = SceneContainer([soi])
                mesh = sc.toMeshInterfaces()[0]
                # Terrain might not be in world space so we adjust it here
                mesh_verts = mesh.vertices - offset

                bbox_min, bbox_max = mesh_verts.minmax()
                bbox_min.x -= xy_offset
                bbox_min.y -= xy_offset

                bbox_max.x += xy_offset
                bbox_max.y += xy_offset

                # we need the terrain points within the bbox of the objects but in the height of the terrain(like a top down projection)
                bbox_min.z = terrain_min.z - 10.0
                bbox_max.z = terrain_max.z + 10.0
                inside_mask = new_positions.inside_bounding_box(bbox_min, bbox_max)
                affected_points |= inside_mask

            return affected_points

        # points directly inside the mesh bounding boxes
        affected_points = get_affected_points(meshes, new_positions, offset, xy_offset=0.0)
        if not affected_points.any():
            return

        filtered_positions = new_positions[affected_points]
        # Influence distance used for both passes.
        distance = falloff_width * 100.0


        # needed so we can offset during search
        # NOTE(review): fmin/fmax are not read anywhere below.
        fmin, fmax = filtered_positions.minmax()

        # Previous ray-intersection based approach, kept for reference:
        # mask, pos_data, normals, distances, facemask = PointsLibrary.PointOps.getPointsMeshIntersections(meshes=meshes, points=filtered_positions,
        #                                                                       return_type=1, put_up=0,
        #                                                                       normal_override=1, # shoots upwards
        #                                                                       angle=0.0, max_distance=0.0,
        #                                                                       flip=False, func_id=109738)


        # temporarily move the filtered positions so meshes are always above in this query
        min_z = get_min_z(meshes)
        zz = Vector3f(0, 0, min_z-100.0)
        # Query points are shifted by zz and moved by `offset` before the lookup.
        hit_mask, hit_positions, _, hit_distances, _ = PointOps.getClosestPointsOnMeshes(
            meshes=meshes,
            points=filtered_positions+zz+offset,
            distance=1000000.0,
            distance_filter=-1,  # return as is, dont filter by distance
            func_id=self.graph_index + 1337
        )
        # hit_positions = hit_positions-z_offset
        # hit_distances = hit_distances-min_z
        # readjust the hit_positions + hit_distances

        # Debug aid:
        # SceneLibrary.Scene.createPointsVisualizer(filtered_positions+zz+offset,
        #                                           draw_mode=0,
        #                                           size=10)

        if len(hit_positions) == 0:
            return


        # Debug aid:
        # SceneLibrary.Scene.createPointsVisualizer(hit_positions,
        #                                           draw_mode=0,
        #                                           size=10)

        # TODO: fix this
        # Attempts to undo the temporary z shift applied to the query points.
        hit_distances = hit_distances - zz.z - offset.z
        influence_mask = hit_distances < distance
        if influence_mask.any():
            # NOTE(review): hit_positions has not had `offset` subtracted at this point
            # (that only happens further down) - confirm these z values are in the
            # same space as new_positions.
            attractor_z = hit_positions[influence_mask].z

            orig_z = filtered_positions.z
            orig_z[influence_mask] = attractor_z

            new_positions.z[affected_points] = orig_z

        # Pass 2: ring of points within `distance` of the mesh boxes but outside them.
        affected_points_offset = get_affected_points(meshes, new_positions, offset, xy_offset=distance)
        nearby_affected = affected_points_offset & (~affected_points)

        if not nearby_affected.any():
            return

        # for the other points that were not found with a very close proximity, use the found points instead

        # Debug aid:
        # SceneLibrary.Scene.createPointsVisualizer(hit_positions,
        #                                           draw_mode=0,
        #                                           size=10)
        kd_tree = PointsLibrary.get_cached_kd_tree(hit_positions)

        # The tree holds un-offset hit positions, so the search points get `offset` added.
        search_pnts = new_positions[nearby_affected] + offset
        indices, distsq = kd_tree.nearestNeighbour(search_pnts)

        dists = distsq.sqrt()
        falloff_mask = dists < distance

        if not falloff_mask.any():
            return

        # Bring the hits back into the terrain's local space before reading their z.
        hit_positions = hit_positions - offset
        target_z = hit_positions.z[indices[falloff_mask]]

        if falloff_interp > 0.0:
            # Blend: the smoothstep window is centred on 0.5 with half-width falloff_interp / 2.
            falloff_interp *= 0.5
            strength = dists[falloff_mask].remap(0, 1)
            strength = self.smoothstep(strength, 0.5 - falloff_interp, 0.5 + falloff_interp)
            local_indices = nearby_affected.to_indices()[falloff_mask]
            new_positions.z[local_indices] = target_z.lerp(new_positions.z[local_indices], strength)
        else:
            # No smoothing requested: snap straight to the nearest hit's height.
            local_indices = nearby_affected.to_indices()[falloff_mask]
            new_positions.z[local_indices] = target_z


    def process_curve_deformation(self, curves, distance, falloff_dist, falloff_interp, project_curves, new_positions,
                                  plane_mesh, offset):
        """Pull terrain heights towards guide curves, editing ``new_positions`` in place.

        Shared by the direct curve input and the curve-table rows.

        Args:
            curves: Scene container; only CURVE entries are used.
            distance: Resample value handed to the curve resampler.
            falloff_dist: Influence width; multiplied by 100 to get the distance used.
            falloff_interp: Blend smoothness; ``<= 0`` snaps instead of blending.
            project_curves: When true, curve points are first sampled onto the
                terrain grid and only the points that landed on it are kept.
            new_positions: Terrain points in the mesh's local space; modified in place.
            plane_mesh: Not referenced in this method.
            offset: Position of the terrain object, used to convert between
                the terrain's local space and the curves' space.

        Returns:
            None. Returns early when there are no curves or nothing projects
            onto the terrain.
        """
        curves = curves.filterType(SceneContainer.Types.CURVE)
        if not curves:
            return

        lines: Lines = CurveLibrary.GNCurve.fromContainer(container=curves, resolution=0,
                                                          func_id=self.graph_index + 934030)
        lines = CurveLibrary.GNCurve.resample(lines=lines, resample_data=distance,
                                              func_id=self.graph_index + 410539)

        source_pts = lines.get_points()
        # Lines used for the closest-segment lookup and the points fed to the kd-tree;
        # both are swapped for their projected versions when projection is enabled.
        lookup_lines = lines
        tree_pts = source_pts

        if project_curves:
            # we offset the terrain points as new_positions are local space
            terrain_pts = new_positions + offset
            # TODO: currently we have uniform grid, but if do it on non-uniform we gotta calculate xdiv/ydiv
            grid_div = int(math.sqrt(len(terrain_pts)) - 1)

            # Sample grid here is a acting like a raycast, because its a uniform grid we can optimize it
            projected_segments = IntArray()
            projected_cvs = VectorArray()
            cursor = 0
            for seg_len in lines.get_segments():
                segment_pts = source_pts[cursor:cursor + seg_len]
                landed_pts, landed_mask = VectorArray.sample_grid(terrain_pts, segment_pts,
                                                                  xdiv=grid_div, ydiv=grid_div)
                cursor += seg_len
                if landed_mask.none():
                    continue
                # we only need the projected ones
                landed_pts = landed_pts[landed_mask]
                projected_cvs.extend(landed_pts)
                projected_segments << len(landed_pts)

            # no need for frameaverage as we don't use any TBN data
            if len(projected_cvs) == 0:
                return
            tree_pts = projected_cvs
            lookup_lines = Lines(projected_cvs, projected_segments)
            # Debug aid: SceneLibrary.Scene.createCurve(lookup_lines.get_points(), identifier=1231,
            #                                           name='projected_curve') shows the projected curve.

        # Move the curve into "local" space of mesh as new_positions is not offset, its the GNMesh's data, not the Unreal object's
        local_curve_pts = tree_pts - offset
        tree = PointsLibrary.get_cached_kd_tree(local_curve_pts)
        falloff_dist = falloff_dist * 100.0

        # prune points that are way too far so we don't query too many pnts
        # curve might be totally flat, having it a z axis of 0, so we just assume its huge in Z
        bbox_min, bbox_max = local_curve_pts.minmax()
        margin = Vector3f(falloff_dist, falloff_dist, falloff_dist)
        bbox_min -= margin
        bbox_max += margin
        bbox_min.z = -99999999.0
        bbox_max.z = 99999999.0

        inside_mask = new_positions.inside_bounding_box(bbox_min, bbox_max)
        # the nearest-neighbour results are sized to the pruned candidate set
        candidates = new_positions[inside_mask]
        indices, distsq = tree.nearestNeighbour(candidates)

        start, end, t = lookup_lines.closest_data(candidates, indices)
        in_range = distsq < (falloff_dist * falloff_dist)
        current_z = candidates.z[in_range]

        # Target height: interpolate the unprojected curve's z between the segment end points.
        curve_z = source_pts.z
        target_z = curve_z[start].lerp(curve_z[end], t)
        target_z -= offset.z

        if falloff_interp > 0.0:
            # Blend: the smoothstep window is centred on 0.5 with half-width falloff_interp / 2.
            falloff_interp *= 0.5
            strength = distsq[in_range].remap(0, 1)
            strength = self.smoothstep(strength, 0.5 - falloff_interp, 0.5 + falloff_interp)
            candidates.z[in_range] = target_z[in_range].lerp(current_z, strength)
        else:
            # No smoothing requested: snap straight to the curve height.
            candidates.z[in_range] = target_z[in_range]

        new_positions[inside_mask] = candidates

    @staticmethod
    def createDescriptor():
        """Build and return the ToolDescriptor that declares this tool's UI.

        Each property's ``var_name`` matches a keyword argument of ``run``.
        Properties are registered in display order; ``addGroup`` starts a new
        group for the properties that follow it.
        """
        descriptor = ToolDescriptor(
            identifier=1887775, name='Terrain Tool', version="1.0",
            description='Creates a terrain mesh. This tool can be used with height maps and curves to shape it.',
            access='')
        descriptor.setHasMeshOutput(True)

        add_property = descriptor.createProperty
        add_group = descriptor.addGroup

        # --- General (ungrouped) properties ---
        add_property(name='Scale', value=(1.5, 0.01, 32.0, 0.0, 256.0), var_name='scale',
                     description="Overall size of the terrain")
        add_property(name='Material', value=AssetContainer(), var_name='material',
                     description="Material to apply to the terrain")
        add_property(name='Curved', value=(0.5, -2.5, 2.5), var_name='curved',
                     description="Controls the curvature of the terrain from concave to convex")
        add_property(name='Sink', value=(0.25, -1.0, 1.0), var_name='sink',
                     description="Sinks the terrain up or down")
        add_property(name='Relax Weight', value=(0.0, 0.0, 1.0, 0.0, 1.0), var_name='relax_weight',
                     description="Smooths the terrain by averaging nearby vertices")
        add_property(name='Subdivision', value=(3, 1, 10, 1, 64), var_name='subdivision',
                     description="Controls mesh density - higher values create more detailed geometry")
        add_property(name='Height', value=(0.8, 0.0, 5.0, 0.0, 100.0), var_name='height',
                     description="Base height of the terrain")
        add_property(name='Uv Scale', value=(1.0, 0.0, 10.0, 0.0, 100.0), var_name='uv_scale',
                     description="Scales the UV coordinates for texturing")
        add_property(name='Seed', value=(0, 0, 100000), var_name='seed',
                     description="Creates new variations of the terrain")

        # --- Noise ---
        add_group('Noise Deformation')
        add_property(name='Turbulence', value=(0.5, 0.0, 10.0, 0.0, 100.0), var_name='turbulence',
                     description="Controls the frequency of the primary noise pattern")
        add_property(name='Mid Turbulance', value=(0.75, 0.0, 10.0, 0.0, 100.0), var_name='mid_turbulance',
                     description="Controls the frequency of secondary noise details")

        # --- Curves ---
        add_group('Curve Deformation')
        add_property(name='Curves', value=SceneContainer(), var_name='falloff_curves',
                     description="You can use curves to change the shape of the terrain")
        add_property(name='Sampling', value=(180.0, 10.0, 100.0, 0.0, 1000.0), var_name='curve_distance',
                     description="Distance between points along guide curves")
        add_property(name='Width', value=(5.0, 0.0, 100.0, 0.0, 1000.0), var_name='falloff',
                     description="Distance of curve influence on surrounding terrain")
        add_property(name='Falloff', value=(0.3, 0.0, 1.0), var_name='falloff_interp',
                     description="Smoothness of transition between curve influence and terrain")
        add_property(name='Project Curves', value=True, var_name='falloff_project',
                     description="Projects curves onto terrain surface before applying influence")

        add_property(name="Curve Table", value="", widget="DataTable", var_name="curve_table",
                     description="You can use curves to change the shape of the terrain",
                     extra={"tableDefault": CURVE_TABLE})

        # Mesh deformation is not exposed in the UI yet; run() only picks these up from kwargs.
        # add_group('Mesh Deformation')
        # add_property(name='Meshes', value=SceneContainer(), var_name='md_meshes', description="")
        # add_property(name='Width', value=(5.0, 0.0, 100.0, 0.0, 1000.0), var_name='md_width', description="")
        # add_property(name='Falloff', value=(0.3, 0.0, 1.0), var_name='md_falloff', description="")

        # --- Height map ---
        add_group('Height Map')
        height_map = add_property(
            name='Texture', value=AssetContainer(), var_name='height_texture',
            description="Select a height map in the UE content browser, and hit the + button to use it as a deformer on this terrain",
            mask_state=False)
        height_map.setContainerCount(1)
        add_property(name='Intensity', value=(0.5, 0.0, 5.0, 0.0, 100.0), var_name='height_multiplier',
                     description="Strength of height map influence")

        return descriptor
