use bevy::{
    asset::RenderAssetUsages,
    color::palettes::basic::YELLOW,
    core_pipeline::core_2d::{Transparent2d, CORE_2D_DEPTH_FORMAT},
    math::{ops, FloatOrd},
    mesh::{Indices, MeshVertexAttribute, VertexBufferLayout},
    prelude::*,
    render::{
        mesh::RenderMesh,
        render_asset::RenderAssets,
        render_phase::{
            AddRenderCommand, DrawFunctions, PhaseItemExtraIndex, SetItemPipeline,
            ViewSortedRenderPhases,
        },
        render_resource::{
            BlendState, ColorTargetState, ColorWrites, CompareFunction, DepthBiasState,
            DepthStencilState, Face, FragmentState, MultisampleState, PipelineCache,
            PrimitiveState, PrimitiveTopology, RenderPipelineDescriptor, SpecializedRenderPipeline,
            SpecializedRenderPipelines, StencilFaceState, StencilState, TextureFormat,
            VertexFormat, VertexState, VertexStepMode,
        },
        sync_component::SyncComponentPlugin,
        sync_world::{MainEntityHashMap, RenderEntity},
        view::{ExtractedView, RenderVisibleEntities, ViewTarget},
        Extract, Render, RenderApp, RenderStartup, RenderSystems,
    },
    sprite_render::{
        extract_mesh2d, init_mesh_2d_pipeline, DrawMesh2d, Material2dBindGroupId, Mesh2dPipeline,
        Mesh2dPipelineKey, Mesh2dTransforms, MeshFlags, RenderMesh2dInstance, SetMesh2dBindGroup,
        SetMesh2dViewBindGroup,
    },
};
use std::f32::consts::PI;
/// Entry point: default Bevy plugins plus the custom colored-mesh renderer, with the
/// `star` system spawning the scene at startup.
fn main() {
    let mut app = App::new();
    app.add_plugins((DefaultPlugins, ColoredMesh2dPlugin));
    app.add_systems(Startup, star);
    app.run();
}
/// Startup system that builds a five-pointed star mesh and spawns it, plus a 2D camera.
///
/// The mesh uses a custom `Vertex_Color` attribute (a packed `u32` per vertex) instead of the
/// built-in color attribute, so it is drawn by the `ColoredMesh2d` pipeline.
fn star(
    mut commands: Commands,
    mut meshes: ResMut<Assets<Mesh>>,
) {
    // Triangle list: every triangle names all three of its vertices. `RENDER_WORLD` means the
    // CPU-side copy is not kept in the main world once the mesh has been uploaded.
    let mut star = Mesh::new(
        PrimitiveTopology::TriangleList,
        RenderAssetUsages::RENDER_WORLD,
    );

    // Vertex 0 is the centre. Vertices 1..=10 walk clockwise around the rim in 36° steps
    // (angle measured from +Y). They alternate between a tip at radius 200 (even steps) and an
    // inner corner at radius 100 (odd steps).
    let positions: Vec<[f32; 3]> = std::iter::once([0.0, 0.0, 0.0])
        .chain((0..10_i32).map(|step| {
            let angle = step as f32 * PI / 5.0;
            let radius: f32 = if step % 2 == 0 { 200.0 } else { 100.0 };
            [radius * ops::sin(angle), radius * ops::cos(angle), 0.0]
        }))
        .collect();
    star.insert_attribute(Mesh::ATTRIBUTE_POSITION, positions);

    // The centre is black and all rim vertices are yellow; the shader interpolates between them.
    let rim_color = LinearRgba::from(YELLOW).as_u32();
    let mut colors = vec![rim_color; 11];
    colors[0] = LinearRgba::BLACK.as_u32();
    star.insert_attribute(
        MeshVertexAttribute::new("Vertex_Color", 1, VertexFormat::Uint32),
        colors,
    );

    // A fan of ten triangles around the centre: (0, 1, 10), (0, 2, 1), (0, 3, 2), …,
    // (0, 10, 9). Each is counter-clockwise in y-up space, so it survives back-face culling.
    let indices: Vec<u32> = (1..=10_u32)
        .flat_map(|current| {
            let previous = if current == 1 { 10 } else { current - 1 };
            [0, current, previous]
        })
        .collect();
    star.insert_indices(Indices::U32(indices));

    // The marker component routes the entity to the custom pipeline; `Mesh2d` wraps the handle.
    let mesh_handle = meshes.add(star);
    commands.spawn((ColoredMesh2d, Mesh2d(mesh_handle)));
    commands.spawn(Camera2d);
}
/// Marker component. Entities that carry it together with a `Mesh2d` are extracted by
/// `extract_colored_mesh2d` and drawn with the custom colored-mesh pipeline.
#[derive(Default, Component)]
pub struct ColoredMesh2d;
/// Render-world resource with everything `specialize` needs to build a pipeline descriptor.
#[derive(Resource)]
pub struct ColoredMesh2dPipeline {
    /// Stock 2D mesh pipeline; only its view and mesh bind group layouts are reused.
    mesh2d_pipeline: Mesh2dPipeline,
    /// Handle to the custom WGSL shader used for both the vertex and fragment stages.
    shader: Handle<Shader>,
}
fn init_colored_mesh_2d_pipeline(
    mut commands: Commands,
    mesh2d_pipeline: Res<Mesh2dPipeline>,
    colored_mesh2d_shader: Res<ColoredMesh2dShader>,
) {
    commands.insert_resource(ColoredMesh2dPipeline {
        mesh2d_pipeline: mesh2d_pipeline.clone(),
                shader: colored_mesh2d_shader.0.clone(),
    });
}
impl SpecializedRenderPipeline for ColoredMesh2dPipeline {
    type Key = Mesh2dPipelineKey;

    /// Builds the pipeline descriptor for one combination of MSAA sample count, HDR flag and
    /// primitive topology, as encoded in `key`.
    fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor {
        // One interleaved vertex buffer: position at location 0 and the packed color at
        // location 1, matching the WGSL `Vertex` struct.
        let vertex_layout = VertexBufferLayout::from_vertex_formats(
            VertexStepMode::Vertex,
            vec![VertexFormat::Float32x3, VertexFormat::Uint32],
        );

        // The color target must match the view's texture format.
        let target_format = if key.contains(Mesh2dPipelineKey::HDR) {
            ViewTarget::TEXTURE_FORMAT_HDR
        } else {
            TextureFormat::bevy_default()
        };
        let color_target = ColorTargetState {
            format: target_format,
            blend: Some(BlendState::ALPHA_BLENDING),
            write_mask: ColorWrites::ALL,
        };

        // Depth is tested (GreaterEqual) but never written; stencil and depth bias are unused.
        let depth_stencil = DepthStencilState {
            format: CORE_2D_DEPTH_FORMAT,
            depth_write_enabled: false,
            depth_compare: CompareFunction::GreaterEqual,
            stencil: StencilState {
                front: StencilFaceState::IGNORE,
                back: StencilFaceState::IGNORE,
                read_mask: 0,
                write_mask: 0,
            },
            bias: DepthBiasState {
                constant: 0,
                slope_scale: 0.0,
                clamp: 0.0,
            },
        };

        let primitive = PrimitiveState {
            cull_mode: Some(Face::Back),
            topology: key.primitive_topology(),
            ..default()
        };

        let multisample = MultisampleState {
            count: key.msaa_samples(),
            mask: !0,
            alpha_to_coverage_enabled: false,
        };

        RenderPipelineDescriptor {
            label: Some("colored_mesh2d_pipeline".into()),
            // Group 0: view uniforms; group 1: per-mesh uniforms (see `DrawColoredMesh2d`).
            layout: vec![
                self.mesh2d_pipeline.view_layout.clone(),
                self.mesh2d_pipeline.mesh_layout.clone(),
            ],
            vertex: VertexState {
                shader: self.shader.clone(),
                buffers: vec![vertex_layout],
                ..default()
            },
            fragment: Some(FragmentState {
                shader: self.shader.clone(),
                targets: vec![Some(color_target)],
                ..default()
            }),
            primitive,
            depth_stencil: Some(depth_stencil),
            multisample,
            ..default()
        }
    }
}
/// Render command sequence for drawing one colored 2D mesh.
type DrawColoredMesh2d = (
    // Bind the specialized pipeline chosen when the item was queued.
    SetItemPipeline,
    // Bind group 0: view uniforms.
    SetMesh2dViewBindGroup<0>,
    // Bind group 1: per-mesh uniforms.
    SetMesh2dBindGroup<1>,
    // Issue the draw call for the mesh.
    DrawMesh2d,
);
/// WGSL source for the colored 2D mesh pipeline, turned into a shader asset in
/// `ColoredMesh2dPlugin::build`.
///
/// The vertex stage has the following inputs and outputs:
/// - It reads a position (`@location(0)`) and a packed `u32` color (`@location(1)`).
/// - It projects the position to clip space with the `mesh2d_functions` helpers.
/// - It unpacks the color into four 8-bit channels, lowest byte first, normalized to `[0, 1]`.
///
/// The fragment stage outputs the interpolated color unchanged.
const COLORED_MESH2D_SHADER: &str = r"
// Import the standard 2d mesh uniforms and set their bind groups
#import bevy_sprite_render::mesh2d_functions
// The structure of the vertex buffer is as specified in `specialize()`
struct Vertex {
    @builtin(instance_index) instance_index: u32,
    @location(0) position: vec3<f32>,
    @location(1) color: u32,
};
struct VertexOutput {
    // The vertex shader must set the on-screen position of the vertex
    @builtin(position) clip_position: vec4<f32>,
    // We pass the vertex color to the fragment shader in location 0
    @location(0) color: vec4<f32>,
};
/// Entry point for the vertex shader
@vertex
fn vertex(vertex: Vertex) -> VertexOutput {
    var out: VertexOutput;
    // Project the world position of the mesh into screen position
    let model = mesh2d_functions::get_world_from_local(vertex.instance_index);
    out.clip_position = mesh2d_functions::mesh2d_position_local_to_clip(model, vec4<f32>(vertex.position, 1.0));
    // Unpack the `u32` from the vertex buffer into the `vec4<f32>` used by the fragment shader
    out.color = vec4<f32>((vec4<u32>(vertex.color) >> vec4<u32>(0u, 8u, 16u, 24u)) & vec4<u32>(255u)) / 255.0;
    return out;
}
// The input of the fragment shader must correspond to the output of the vertex shader for all `location`s
struct FragmentInput {
    // The color is interpolated between vertices by default
    @location(0) color: vec4<f32>,
};
/// Entry point for the fragment shader
@fragment
fn fragment(in: FragmentInput) -> @location(0) vec4<f32> {
    return in.color;
}
";
/// Plugin that sets up custom colored 2D mesh rendering. It creates the shader asset, syncs
/// `ColoredMesh2d` to the render world, and registers the draw command, resources and
/// render-world systems.
pub struct ColoredMesh2dPlugin;
/// Render-world resource holding the handle of the shader created in
/// `ColoredMesh2dPlugin::build`; read by `init_colored_mesh_2d_pipeline`.
#[derive(Resource)]
struct ColoredMesh2dShader(Handle<Shader>);
/// Render-world map from main-world entity to its per-instance draw data (mesh asset id and
/// transform). It is filled by `extract_colored_mesh2d` and read by `queue_colored_mesh2d`.
#[derive(Default, Resource, Deref, DerefMut)]
pub struct RenderColoredMesh2dInstances(MainEntityHashMap<RenderMesh2dInstance>);
impl Plugin for ColoredMesh2dPlugin {
    /// Creates the shader asset, syncs `ColoredMesh2d` to the render world, and registers the
    /// draw command, resources and systems on the render sub-app.
    ///
    /// # Panics
    /// Panics if `Assets<Shader>` is not present in the main world, or if the app has no
    /// `RenderApp` sub-app.
    fn build(&self, app: &mut App) {
        // The WGSL lives in a string constant, so the shader asset is created directly instead
        // of being loaded from a file.
        let shader_handle = app
            .world_mut()
            .resource_mut::<Assets<Shader>>()
            .add(Shader::from_wgsl(COLORED_MESH2D_SHADER, file!()));

        // Mirror the marker component onto the corresponding render-world entities.
        app.add_plugins(SyncComponentPlugin::<ColoredMesh2d>::default());

        let render_app = app.get_sub_app_mut(RenderApp).unwrap();
        render_app
            .insert_resource(ColoredMesh2dShader(shader_handle))
            .add_render_command::<Transparent2d, DrawColoredMesh2d>()
            .init_resource::<SpecializedRenderPipelines<ColoredMesh2dPipeline>>()
            .init_resource::<RenderColoredMesh2dInstances>();

        // The custom pipeline copies parts of `Mesh2dPipeline`, so build it afterwards.
        render_app.add_systems(
            RenderStartup,
            init_colored_mesh_2d_pipeline.after(init_mesh_2d_pipeline),
        );
        // Extraction is ordered after the stock 2D mesh extraction.
        render_app.add_systems(
            ExtractSchedule,
            extract_colored_mesh2d.after(extract_mesh2d),
        );
        render_app.add_systems(
            Render,
            queue_colored_mesh2d.in_set(RenderSystems::QueueMeshes),
        );
    }
}
/// Copies the data needed to draw every visible `ColoredMesh2d` from the main world into the
/// render world.
///
/// For each visible entity it does two things:
/// - It inserts the `ColoredMesh2d` marker on the matching render entity.
/// - It records a `RenderMesh2dInstance` (mesh asset id and world transform), keyed by the
///   main-world entity, in `RenderColoredMesh2dInstances`.
///
/// The instance map is a render-world resource that persists across frames. It is therefore
/// cleared first, so it only holds entities that are visible this frame. Without the clear,
/// entries for despawned or hidden entities would accumulate indefinitely.
pub fn extract_colored_mesh2d(
    mut commands: Commands,
    // Number of entities extracted last frame; used to pre-size this frame's batch.
    mut previous_len: Local<usize>,
    // `Extract` marks parameters that are read from the main world.
    query: Extract<
        Query<
            (
                Entity,
                RenderEntity,
                &ViewVisibility,
                &GlobalTransform,
                &Mesh2d,
            ),
            With<ColoredMesh2d>,
        >,
    >,
    mut render_mesh_instances: ResMut<RenderColoredMesh2dInstances>,
) {
    // Drop last frame's instances so stale entries cannot leak.
    render_mesh_instances.clear();

    let mut values = Vec::with_capacity(*previous_len);
    for (entity, render_entity, view_visibility, transform, handle) in &query {
        if !view_visibility.get() {
            continue;
        }
        let transforms = Mesh2dTransforms {
            world_from_local: (&transform.affine()).into(),
            flags: MeshFlags::empty().bits(),
        };
        values.push((render_entity, ColoredMesh2d));
        render_mesh_instances.insert(
            entity.into(),
            RenderMesh2dInstance {
                mesh_asset_id: handle.0.id(),
                transforms,
                material_bind_group_id: Material2dBindGroupId::default(),
                // Each item is drawn individually (see `batch_range: 0..1` when queued).
                automatic_batching: false,
                tag: 0,
            },
        );
    }
    *previous_len = values.len();
    commands.try_insert_batch(values);
}
/// Adds every visible colored 2D mesh to the `Transparent2d` render phase of each view.
///
/// For each view (camera), it computes the pipeline key from the view's MSAA and HDR settings
/// plus the mesh's primitive topology, specializes the pipeline, and pushes a phase item
/// sorted by the mesh's world-space Z translation. Entities without a recorded instance, and
/// meshes whose GPU data is not ready, are skipped.
pub fn queue_colored_mesh2d(
    transparent_draw_functions: Res<DrawFunctions<Transparent2d>>,
    colored_mesh2d_pipeline: Res<ColoredMesh2dPipeline>,
    mut pipelines: ResMut<SpecializedRenderPipelines<ColoredMesh2dPipeline>>,
    pipeline_cache: Res<PipelineCache>,
    render_meshes: Res<RenderAssets<RenderMesh>>,
    render_mesh_instances: Res<RenderColoredMesh2dInstances>,
    mut transparent_render_phases: ResMut<ViewSortedRenderPhases<Transparent2d>>,
    views: Query<(&RenderVisibleEntities, &ExtractedView, &Msaa)>,
) {
    if render_mesh_instances.is_empty() {
        return;
    }

    for (visible_entities, view, msaa) in &views {
        let Some(phase) = transparent_render_phases.get_mut(&view.retained_view_entity) else {
            continue;
        };
        let draw_function = transparent_draw_functions.read().id::<DrawColoredMesh2d>();
        let view_key = Mesh2dPipelineKey::from_msaa_samples(msaa.samples())
            | Mesh2dPipelineKey::from_hdr(view.hdr);

        for (render_entity, main_entity) in visible_entities.iter::<Mesh2d>() {
            let Some(instance) = render_mesh_instances.get(main_entity) else {
                continue;
            };
            let Some(gpu_mesh) = render_meshes.get(instance.mesh_asset_id) else {
                continue;
            };

            let key = view_key
                | Mesh2dPipelineKey::from_primitive_topology(gpu_mesh.primitive_topology());
            let pipeline = pipelines.specialize(&pipeline_cache, &colored_mesh2d_pipeline, key);

            phase.add(Transparent2d {
                entity: (*render_entity, *main_entity),
                draw_function,
                pipeline,
                // Items are ordered back-to-front by Z for correct alpha blending.
                sort_key: FloatOrd(instance.transforms.world_from_local.translation.z),
                // No batching: one item per draw.
                batch_range: 0..1,
                extra_index: PhaseItemExtraIndex::None,
                extracted_index: usize::MAX,
                indexed: gpu_mesh.indexed(),
            });
        }
    }
}