Subversion Repositories Kolibri OS

Rev

Blame | Last modification | View Log | RSS feed

  1. /*
  2.  * Copyright © 2014 Intel Corporation
  3.  *
  4.  * Permission is hereby granted, free of charge, to any person obtaining a
  5.  * copy of this software and associated documentation files (the "Software"),
  6.  * to deal in the Software without restriction, including without limitation
  7.  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
  8.  * and/or sell copies of the Software, and to permit persons to whom the
  9.  * Software is furnished to do so, subject to the following conditions:
  10.  *
  11.  * The above copyright notice and this permission notice (including the next
  12.  * paragraph) shall be included in all copies or substantial portions of the
  13.  * Software.
  14.  *
  15.  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  16.  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  17.  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
  18.  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  19.  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
  20.  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
  21.  * IN THE SOFTWARE.
  22.  *
  23.  * Authors:
  24.  *    Connor Abbott (cwabbott0@gmail.com)
  25.  *
  26.  */
  27.  
  28. #include "nir.h"
  29. #include <stdlib.h>
  30. #include <unistd.h>
  31.  
/*
 * Implements the classic to-SSA algorithm described by Cytron et al. in
 * "Efficiently Computing Static Single Assignment Form and the Control
 * Dependence Graph."
 */
  37.  
  38. /* inserts a phi node of the form reg = phi(reg, reg, reg, ...) */
  39.  
  40. static void
  41. insert_trivial_phi(nir_register *reg, nir_block *block, void *mem_ctx)
  42. {
  43.    nir_phi_instr *instr = nir_phi_instr_create(mem_ctx);
  44.  
  45.    instr->dest.reg.reg = reg;
  46.    struct set_entry *entry;
  47.    set_foreach(block->predecessors, entry) {
  48.       nir_block *pred = (nir_block *) entry->key;
  49.  
  50.       nir_phi_src *src = ralloc(instr, nir_phi_src);
  51.       src->pred = pred;
  52.       src->src.is_ssa = false;
  53.       src->src.reg.base_offset = 0;
  54.       src->src.reg.indirect = NULL;
  55.       src->src.reg.reg = reg;
  56.       exec_list_push_tail(&instr->srcs, &src->node);
  57.    }
  58.  
  59.    nir_instr_insert_before_block(block, &instr->instr);
  60. }
  61.  
/* Phi-placement pass: for every non-array register, inserts trivial phi
 * nodes at the iterated dominance frontier of the set of blocks that
 * define it (the placement step of the Cytron et al. algorithm).
 */
static void
insert_phi_nodes(nir_function_impl *impl)
{
   void *mem_ctx = ralloc_parent(impl);

   /* Per-block counters, compared against iter_count so that the arrays
    * never need to be cleared between registers:
    *   work[b]        - last iteration in which block b was put on W
    *   has_already[b] - last iteration in which a phi was placed in b
    */
   unsigned *work = calloc(impl->num_blocks, sizeof(unsigned));
   unsigned *has_already = calloc(impl->num_blocks, sizeof(unsigned));

   /*
    * Since the work flags already prevent us from inserting a node that has
    * ever been inserted into W, we don't need to use a set to represent W.
    * Also, since no block can ever be inserted into W more than once, we know
    * that the maximum size of W is the number of basic blocks in the
    * function. So all we need to handle W is an array and a pointer to the
    * next element to be inserted and the next element to be removed.
    */
   nir_block **W = malloc(impl->num_blocks * sizeof(nir_block *));
   unsigned w_start, w_end;

   unsigned iter_count = 0;

   nir_index_blocks(impl);

   foreach_list_typed(nir_register, reg, node, &impl->registers) {
      /* Array registers are not converted to SSA; skip them. */
      if (reg->num_array_elems != 0)
         continue;

      w_start = w_end = 0;
      iter_count++;

      /* Seed W with every block containing a definition of reg. */
      nir_foreach_def(reg, dest) {
         nir_instr *def = dest->reg.parent_instr;
         if (work[def->block->index] < iter_count)
            W[w_end++] = def->block;
         work[def->block->index] = iter_count;
      }

      while (w_start != w_end) {
         nir_block *cur = W[w_start++];
         struct set_entry *entry;
         set_foreach(cur->dom_frontier, entry) {
            nir_block *next = (nir_block *) entry->key;

            /*
             * If there's more than one return statement, then the end block
             * can be a join point for some definitions. However, there are
             * no instructions in the end block, so nothing would use those
             * phi nodes. Of course, we couldn't place those phi nodes
             * anyways due to the restriction of having no instructions in the
             * end block...
             */
            if (next == impl->end_block)
               continue;

            if (has_already[next->index] < iter_count) {
               insert_trivial_phi(reg, next, mem_ctx);
               has_already[next->index] = iter_count;
               /* A new phi is itself a new definition, so next's dominance
                * frontier must be processed too (iterated frontier).
                */
               if (work[next->index] < iter_count) {
                  work[next->index] = iter_count;
                  W[w_end++] = next;
               }
            }
         }
      }
   }

   free(work);
   free(has_already);
   free(W);
}
  132.  
/* Per-register renaming state for the rewrite walk. */
typedef struct {
   /* Stack of SSA defs; stack[index] is the currently-reaching definition. */
   nir_ssa_def **stack;
   /* Top-of-stack index, or -1 when no definition has been pushed yet. */
   int index;
   unsigned num_defs; /** < used to add indices to debug names */
#ifndef NDEBUG
   unsigned stack_size;
#endif
} reg_state;
  141.  
/* Shared state threaded through the register -> SSA rewrite walk. */
typedef struct {
   reg_state *states; /* indexed by nir_register::index */
   void *mem_ctx;
   /* While rewriting a use, exactly one of these identifies the consumer:
    * the instruction being rewritten, or (when parent_instr is NULL) the
    * if statement whose condition is being rewritten.
    */
   nir_instr *parent_instr;
   nir_if *parent_if;
   nir_function_impl *impl;

   /* map from SSA value -> original register */
   struct hash_table *ssa_map;
} rewrite_state;
  152.  
  153. static nir_ssa_def *get_ssa_src(nir_register *reg, rewrite_state *state)
  154. {
  155.    unsigned index = reg->index;
  156.  
  157.    if (state->states[index].index == -1) {
  158.       /*
  159.        * We're using an undefined register, create a new undefined SSA value
  160.        * to preserve the information that this source is undefined
  161.        */
  162.       nir_ssa_undef_instr *instr =
  163.          nir_ssa_undef_instr_create(state->mem_ctx, reg->num_components);
  164.  
  165.       /*
  166.        * We could just insert the undefined instruction before the instruction
  167.        * we're rewriting, but we could be rewriting a phi source in which case
  168.        * we can't do that, so do the next easiest thing - insert it at the
  169.        * beginning of the program. In the end, it doesn't really matter where
  170.        * the undefined instructions are because they're going to be ignored
  171.        * in the backend.
  172.        */
  173.       nir_instr_insert_before_cf_list(&state->impl->body, &instr->instr);
  174.       return &instr->def;
  175.    }
  176.  
  177.    return state->states[index].stack[state->states[index].index];
  178. }
  179.  
  180. static bool
  181. rewrite_use(nir_src *src, void *_state)
  182. {
  183.    rewrite_state *state = (rewrite_state *) _state;
  184.  
  185.    if (src->is_ssa)
  186.       return true;
  187.  
  188.    unsigned index = src->reg.reg->index;
  189.  
  190.    if (state->states[index].stack == NULL)
  191.       return true;
  192.  
  193.    nir_ssa_def *def = get_ssa_src(src->reg.reg, state);
  194.    if (state->parent_instr)
  195.       nir_instr_rewrite_src(state->parent_instr, src, nir_src_for_ssa(def));
  196.    else
  197.       nir_if_rewrite_condition(state->parent_if, nir_src_for_ssa(def));
  198.  
  199.    return true;
  200. }
  201.  
  202. static bool
  203. rewrite_def_forwards(nir_dest *dest, void *_state)
  204. {
  205.    rewrite_state *state = (rewrite_state *) _state;
  206.  
  207.    if (dest->is_ssa)
  208.       return true;
  209.  
  210.    nir_register *reg = dest->reg.reg;
  211.    unsigned index = reg->index;
  212.  
  213.    if (state->states[index].stack == NULL)
  214.       return true;
  215.  
  216.    char *name = NULL;
  217.    if (dest->reg.reg->name)
  218.       name = ralloc_asprintf(state->mem_ctx, "%s_%u", dest->reg.reg->name,
  219.                              state->states[index].num_defs);
  220.  
  221.    list_del(&dest->reg.def_link);
  222.    nir_ssa_dest_init(state->parent_instr, dest, reg->num_components, name);
  223.  
  224.    /* push our SSA destination on the stack */
  225.    state->states[index].index++;
  226.    assert(state->states[index].index < state->states[index].stack_size);
  227.    state->states[index].stack[state->states[index].index] = &dest->ssa;
  228.    state->states[index].num_defs++;
  229.  
  230.    _mesa_hash_table_insert(state->ssa_map, &dest->ssa, reg);
  231.  
  232.    return true;
  233. }
  234.  
/* Rewrites an ALU instruction that writes a register.  A full write is
 * handled like any other definition; a partial write (write_mask doesn't
 * cover every component) shrinks the instruction to just the written
 * channels and inserts a vecN after it that merges the new channels with
 * the previous value of the register, so the merged result becomes the
 * new reaching definition.
 */
static void
rewrite_alu_instr_forward(nir_alu_instr *instr, rewrite_state *state)
{
   state->parent_instr = &instr->instr;

   nir_foreach_src(&instr->instr, rewrite_use, state);

   if (instr->dest.dest.is_ssa)
      return;

   nir_register *reg = instr->dest.dest.reg.reg;
   unsigned index = reg->index;

   /* Array register (not being converted): leave the destination alone. */
   if (state->states[index].stack == NULL)
      return;

   unsigned write_mask = instr->dest.write_mask;
   if (write_mask != (1 << instr->dest.dest.reg.reg->num_components) - 1) {
      /*
       * Calculate the number of components the final instruction, which for
       * per-component things is the number of output components of the
       * instruction and non-per-component things is the number of enabled
       * channels in the write mask.
       */
      unsigned num_components;
      if (nir_op_infos[instr->op].output_size == 0) {
         /* SWAR popcount of the (at most 4-bit) write mask. */
         unsigned temp = (write_mask & 0x5) + ((write_mask >> 1) & 0x5);
         num_components = (temp & 0x3) + ((temp >> 2) & 0x3);
      } else {
         num_components = nir_op_infos[instr->op].output_size;
      }

      char *name = NULL;
      if (instr->dest.dest.reg.reg->name)
         name = ralloc_asprintf(state->mem_ctx, "%s_%u",
                                reg->name, state->states[index].num_defs);

      /* Narrow this instruction to just the written channels. */
      instr->dest.write_mask = (1 << num_components) - 1;
      list_del(&instr->dest.dest.reg.def_link);
      nir_ssa_dest_init(&instr->instr, &instr->dest.dest, num_components, name);

      if (nir_op_infos[instr->op].output_size == 0) {
         /*
          * When we change the output writemask, we need to change the
          * swizzles for per-component inputs too
          */
         for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
            if (nir_op_infos[instr->op].input_sizes[i] != 0)
               continue;

            unsigned new_swizzle[4] = {0, 0, 0, 0};

            /*
             * We keep two indices:
             * 1. The index of the original (non-SSA) component
             * 2. The index of the post-SSA, compacted, component
             *
             * We need to map the swizzle component at index 1 to the swizzle
             * component at index 2.
             */

            unsigned ssa_index = 0;
            for (unsigned index = 0; index < 4; index++) {
               if (!((write_mask >> index) & 1))
                  continue;

               new_swizzle[ssa_index] = instr->src[i].swizzle[index];
               ssa_index++;
            }

            for (unsigned j = 0; j < 4; j++)
               instr->src[i].swizzle[j] = new_swizzle[j];
         }
      }

      /* Pick the vecN opcode matching the full register width. */
      nir_op op;
      switch (reg->num_components) {
      case 2: op = nir_op_vec2; break;
      case 3: op = nir_op_vec3; break;
      case 4: op = nir_op_vec4; break;
      default: unreachable("not reached");
      }

      nir_alu_instr *vec = nir_alu_instr_create(state->mem_ctx, op);

      vec->dest.dest.reg.reg = reg;
      vec->dest.write_mask = (1 << reg->num_components) - 1;

      /* The rename stack hasn't been pushed yet -- that happens below in
       * rewrite_def_forwards() on the vec -- so this still reads the
       * definition reaching the register *before* this instruction.
       */
      nir_ssa_def *old_src = get_ssa_src(reg, state);
      nir_ssa_def *new_src = &instr->dest.dest.ssa;

      /* Written channels come (compacted) from the new value; the rest
       * are carried through from the old value.
       */
      unsigned ssa_index = 0;
      for (unsigned i = 0; i < reg->num_components; i++) {
         vec->src[i].src.is_ssa = true;
         if ((write_mask >> i) & 1) {
            vec->src[i].src.ssa = new_src;
            if (nir_op_infos[instr->op].output_size == 0)
               vec->src[i].swizzle[0] = ssa_index;
            else
               vec->src[i].swizzle[0] = i;
            ssa_index++;
         } else {
            vec->src[i].src.ssa = old_src;
            vec->src[i].swizzle[0] = i;
         }
      }

      nir_instr_insert_after(&instr->instr, &vec->instr);

      /* The vec's result is the register's new reaching definition. */
      state->parent_instr = &vec->instr;
      rewrite_def_forwards(&vec->dest.dest, state);
   } else {
      /* Full write: convert the destination directly. */
      rewrite_def_forwards(&instr->dest.dest, state);
   }
}
  350.  
  351. static void
  352. rewrite_phi_instr(nir_phi_instr *instr, rewrite_state *state)
  353. {
  354.    state->parent_instr = &instr->instr;
  355.    rewrite_def_forwards(&instr->dest, state);
  356. }
  357.  
  358. static void
  359. rewrite_instr_forward(nir_instr *instr, rewrite_state *state)
  360. {
  361.    if (instr->type == nir_instr_type_alu) {
  362.       rewrite_alu_instr_forward(nir_instr_as_alu(instr), state);
  363.       return;
  364.    }
  365.  
  366.    if (instr->type == nir_instr_type_phi) {
  367.       rewrite_phi_instr(nir_instr_as_phi(instr), state);
  368.       return;
  369.    }
  370.  
  371.    state->parent_instr = instr;
  372.  
  373.    nir_foreach_src(instr, rewrite_use, state);
  374.    nir_foreach_dest(instr, rewrite_def_forwards, state);
  375. }
  376.  
  377. static void
  378. rewrite_phi_sources(nir_block *block, nir_block *pred, rewrite_state *state)
  379. {
  380.    nir_foreach_instr(block, instr) {
  381.       if (instr->type != nir_instr_type_phi)
  382.          break;
  383.  
  384.       nir_phi_instr *phi_instr = nir_instr_as_phi(instr);
  385.  
  386.       state->parent_instr = instr;
  387.  
  388.       nir_foreach_phi_src(phi_instr, src) {
  389.          if (src->pred == pred) {
  390.             rewrite_use(&src->src, state);
  391.             break;
  392.          }
  393.       }
  394.    }
  395. }
  396.  
  397. static bool
  398. rewrite_def_backwards(nir_dest *dest, void *_state)
  399. {
  400.    rewrite_state *state = (rewrite_state *) _state;
  401.  
  402.    if (!dest->is_ssa)
  403.       return true;
  404.  
  405.    struct hash_entry *entry =
  406.       _mesa_hash_table_search(state->ssa_map, &dest->ssa);
  407.  
  408.    if (!entry)
  409.       return true;
  410.  
  411.    nir_register *reg = (nir_register *) entry->data;
  412.    unsigned index = reg->index;
  413.  
  414.    state->states[index].index--;
  415.    assert(state->states[index].index >= -1);
  416.  
  417.    return true;
  418. }
  419.  
  420. static void
  421. rewrite_instr_backwards(nir_instr *instr, rewrite_state *state)
  422. {
  423.    nir_foreach_dest(instr, rewrite_def_backwards, state);
  424. }
  425.  
  426. static void
  427. rewrite_block(nir_block *block, rewrite_state *state)
  428. {
  429.    /* This will skip over any instructions after the current one, which is
  430.     * what we want because those instructions (vector gather, conditional
  431.     * select) will already be in SSA form.
  432.     */
  433.    nir_foreach_instr_safe(block, instr) {
  434.       rewrite_instr_forward(instr, state);
  435.    }
  436.  
  437.    if (block != state->impl->end_block &&
  438.        !nir_cf_node_is_last(&block->cf_node) &&
  439.        nir_cf_node_next(&block->cf_node)->type == nir_cf_node_if) {
  440.       nir_if *if_stmt = nir_cf_node_as_if(nir_cf_node_next(&block->cf_node));
  441.       state->parent_instr = NULL;
  442.       state->parent_if = if_stmt;
  443.       rewrite_use(&if_stmt->condition, state);
  444.    }
  445.  
  446.    if (block->successors[0])
  447.       rewrite_phi_sources(block->successors[0], block, state);
  448.    if (block->successors[1])
  449.       rewrite_phi_sources(block->successors[1], block, state);
  450.  
  451.    for (unsigned i = 0; i < block->num_dom_children; i++)
  452.       rewrite_block(block->dom_children[i], state);
  453.  
  454.    nir_foreach_instr_reverse(block, instr) {
  455.       rewrite_instr_backwards(instr, state);
  456.    }
  457. }
  458.  
  459. static void
  460. remove_unused_regs(nir_function_impl *impl, rewrite_state *state)
  461. {
  462.    foreach_list_typed_safe(nir_register, reg, node, &impl->registers) {
  463.       if (state->states[reg->index].stack != NULL)
  464.          exec_node_remove(&reg->node);
  465.    }
  466. }
  467.  
  468. static void
  469. init_rewrite_state(nir_function_impl *impl, rewrite_state *state)
  470. {
  471.    state->impl = impl;
  472.    state->mem_ctx = ralloc_parent(impl);
  473.    state->ssa_map = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
  474.                                             _mesa_key_pointer_equal);
  475.    state->states = ralloc_array(NULL, reg_state, impl->reg_alloc);
  476.  
  477.    foreach_list_typed(nir_register, reg, node, &impl->registers) {
  478.       assert(reg->index < impl->reg_alloc);
  479.       if (reg->num_array_elems > 0) {
  480.          state->states[reg->index].stack = NULL;
  481.       } else {
  482.          /*
  483.           * Calculate a conservative estimate of the stack size based on the
  484.           * number of definitions there are. Note that this function *must* be
  485.           * called after phi nodes are inserted so we can count phi node
  486.           * definitions too.
  487.           */
  488.          unsigned stack_size = list_length(&reg->defs);
  489.  
  490.          state->states[reg->index].stack = ralloc_array(state->states,
  491.                                                         nir_ssa_def *,
  492.                                                         stack_size);
  493. #ifndef NDEBUG
  494.          state->states[reg->index].stack_size = stack_size;
  495. #endif
  496.          state->states[reg->index].index = -1;
  497.          state->states[reg->index].num_defs = 0;
  498.       }
  499.    }
  500. }
  501.  
  502. static void
  503. destroy_rewrite_state(rewrite_state *state)
  504. {
  505.    _mesa_hash_table_destroy(state->ssa_map, NULL);
  506.    ralloc_free(state->states);
  507. }
  508.  
  509. void
  510. nir_convert_to_ssa_impl(nir_function_impl *impl)
  511. {
  512.    nir_metadata_require(impl, nir_metadata_dominance);
  513.  
  514.    insert_phi_nodes(impl);
  515.  
  516.    rewrite_state state;
  517.    init_rewrite_state(impl, &state);
  518.  
  519.    rewrite_block(impl->start_block, &state);
  520.  
  521.    remove_unused_regs(impl, &state);
  522.  
  523.    nir_metadata_preserve(impl, nir_metadata_block_index |
  524.                                nir_metadata_dominance);
  525.  
  526.    destroy_rewrite_state(&state);
  527. }
  528.  
  529. void
  530. nir_convert_to_ssa(nir_shader *shader)
  531. {
  532.    nir_foreach_overload(shader, overload) {
  533.       if (overload->impl)
  534.          nir_convert_to_ssa_impl(overload->impl);
  535.    }
  536. }
  537.